<a href="https://colab.research.google.com/github/fundou/colab/blob/master/karpathy/istrain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# transformer istrain

Q: Does a Transformer know if it is being trained? This has implications for AI safety.

Hypothesis: dropout "leaks" the train/eval phase bit.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# repro: seed PyTorch's global RNG so the torch.randn data drawn in later cells is repeatable
# NOTE(review): the trailing ';' presumably keeps the notebook from echoing the call's return value
torch.manual_seed(42);

In [None]:

# create a toy transformer network doing BCE loss on last token
C = 64 # num channels

class TinyTransformer(nn.Module):
    """Small encoder-decoder transformer with a single-logit classifier head.

    The logit is computed from the decoder output at the last time step and
    is meant to be fed to a binary cross-entropy-with-logits loss.

    Args:
        dropout: dropout probability forwarded to ``nn.Transformer``.
    """

    def __init__(self, dropout):
        super().__init__()
        # random small encoder decoder transformer
        # batch_first=True is required: callers pass (B, T, C) tensors, while
        # nn.Transformer defaults to (T, B, C). Without it the batch axis is
        # treated as the sequence axis, so samples attend to each other and
        # the "last time step" below is really the last batch element.
        self.transformer = nn.Transformer(d_model=C, nhead=4,
                       num_encoder_layers=4, num_decoder_layers=4,
                       dim_feedforward=C*4, dropout=dropout,
                       batch_first=True)
        self.fc = nn.Linear(C, 1)

    def forward(self, xe, xd):
        """Return one logit per batch element.

        Args:
            xe: encoder input of shape (B, T, C).
            xd: decoder input of shape (B, T, C).

        Returns:
            Tensor of shape (B, 1) holding unnormalised logits.
        """
        # forward the transformer -> (B, T, C)
        x = self.transformer(xe, xd)
        # select the last time step to make the prediction -> (B, C)
        x = x[:, -1, :]
        # forward the classifier -> (B, 1)
        x = self.fc(x)
        return x


In [None]:
def train_model(model, *, steps=300, batch_size=8, seq_len=4, lr=1e-3,
                channels=None):
    """Train ``model`` to output whether it is currently in train mode.

    Every step runs two forward/backward passes on fresh random inputs: one
    with the model in train mode (target 1) and one in eval mode (target 0).
    The gradients of both passes are accumulated before a single Adam update.
    Losses are printed every 100 steps and on the final step.

    The model is left in eval mode on return, because the eval phase is the
    last one executed in each step.

    Args:
        model: module called as ``model(xe, xd)`` with two tensors of shape
            (batch_size, seq_len, channels), returning (batch_size, 1) logits.
        steps: number of optimizer updates.
        batch_size: number of samples per phase.
        seq_len: sequence length of the random inputs.
        lr: Adam learning rate.
        channels: feature dimension of the random inputs; defaults to the
            notebook-level constant ``C``.
    """
    if channels is None:
        channels = C  # fall back to the notebook-wide channel count

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # (train mode?, target label) for phase 1 and phase 2, in execution order
    phases = ((True, 1.0), (False, 0.0))

    for n in range(steps):
        # one zero_grad per step: both phases accumulate into the same grads
        optimizer.zero_grad()
        should_log = n % 100 == 0 or n == steps - 1

        for phase, (is_training, label) in enumerate(phases, start=1):
            # inputs are drawn before switching mode, encoder first, to keep
            # the RNG consumption order stable
            xe = torch.randn(batch_size, seq_len, channels) # B,T,C for encoder
            xd = torch.randn(batch_size, seq_len, channels) # B,T,C for decoder
            model.train(is_training)  # train(False) is the same as eval()
            logits = model(xe, xd)
            # label 1: we are training, label 0: we are not training
            target = torch.full((batch_size, 1), label)
            loss = F.binary_cross_entropy_with_logits(logits, target)
            loss.backward()
            if should_log:
                print(f"{n} loss in phase {phase}: {loss.item()}")

        # update
        optimizer.step()



In [None]:
def eval_model(model, *, num_tests=200, seq_len=4, channels=None):
    """Measure how often ``model`` correctly reports its own train/eval mode.

    Runs ``num_tests`` single-sample forward passes on random inputs,
    alternating the model between eval mode (even trials, ground truth 0) and
    train mode (odd trials, ground truth 1). The prediction is 1 when the
    sigmoid of the logit exceeds 0.5. Prints the accuracy as a percentage.

    The model is left in whichever mode the last trial used.

    Args:
        model: module called as ``model(xe, xd)`` with two tensors of shape
            (1, seq_len, channels), returning a single logit.
        num_tests: number of trials; must be at least 1.
        seq_len: sequence length of the random inputs.
        channels: feature dimension of the random inputs; defaults to the
            notebook-level constant ``C``.

    Raises:
        ValueError: if ``num_tests`` is smaller than 1.
    """
    if num_tests < 1:
        # an empty trial set would otherwise print a meaningless "nan%"
        raise ValueError("num_tests must be at least 1")
    if channels is None:
        channels = C  # fall back to the notebook-wide channel count

    # evaluate accuracy on some synthetic test data
    corrects = []
    # inference only: no autograd graph is needed for these forward passes
    with torch.no_grad():
        for test in range(num_tests):
            # dummy input, batch of one
            xe = torch.randn(1, seq_len, channels) # B,T,C for encoder
            xd = torch.randn(1, seq_len, channels) # B,T,C for decoder

            # set network into train/eval phase: odd trials train, even eval
            phase = test % 2
            model.train(phase == 1)

            # predict mode
            prob = torch.sigmoid(model(xe, xd))
            pred = 1 if prob.item() > 0.5 else 0

            corrects.append(float(phase == pred))

    print(f"test accuracy {torch.tensor(corrects).mean().item()*100}%")


In [None]:
# with dropout > 0.0 this should work, i.e. accuracy >> 50%
# (hypothesis under test: dropout is what lets the network tell train mode from eval mode)
model = TinyTransformer(dropout=0.2)
train_model(model)  # fit on labels "train mode -> 1, eval mode -> 0"
eval_model(model)   # print accuracy over alternating train/eval forward passes

0 loss in phase 1: 0.6486659049987793
0 loss in phase 2: 1.0286682844161987
100 loss in phase 1: 0.005957315675914288
100 loss in phase 2: 0.003156071063131094
200 loss in phase 1: 0.0032816240563988686
200 loss in phase 2: 0.0018672486767172813
299 loss in phase 1: 0.001967259682714939
299 loss in phase 2: 0.001231701229698956
test accuracy 95.49999833106995%


In [None]:
# with dropout of 0 this should not work, i.e. accuracy ~= 50%
# (control run: same training and evaluation as the dropout run, only the dropout probability differs)
model = TinyTransformer(dropout=0.0)
train_model(model)  # fit on labels "train mode -> 1, eval mode -> 0"
eval_model(model)   # print accuracy over alternating train/eval forward passes

0 loss in phase 1: 1.0923291444778442
0 loss in phase 2: 0.6770550608634949
100 loss in phase 1: 0.6910950541496277
100 loss in phase 2: 0.6947798132896423
200 loss in phase 1: 0.608734130859375
200 loss in phase 2: 0.7696783542633057
299 loss in phase 1: 0.6994112730026245
299 loss in phase 2: 0.6903160810470581
test accuracy 50.49999952316284%
