In [1]:
import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.data as data
import torch.optim as optim

from tqdm import tqdm

import numpy as np

In [2]:
# CUDA device indices: DEVICES[0] hosts the model and every batch; the other
# entries are only used if the model is wrapped in nn.DataParallel.
DEVICES = [2, 3]
print(f'Using devices {DEVICES}')

# Warn up front when the requested GPUs are not visible to this process;
# otherwise the first symptom is an opaque CUDA error while creating the model.
_n_visible_gpus = torch.cuda.device_count()
if max(DEVICES) >= _n_visible_gpus:
    print(f'WARNING: DEVICES={DEVICES} but only {_n_visible_gpus} CUDA device(s) are visible')

Using devices [2, 3]


In [3]:
import open_clip

# One architecture name shared by the model and its tokenizer.
CLIP_ARCH = 'ViT-H-14'

# Returns (model, train transform, eval transform); the train transform is unused here.
og_model, _, preprocess = open_clip.create_model_and_transforms(
    CLIP_ARCH,
    pretrained='laion2b_s32b_b79k',
    device=DEVICES[0],
)
tokenizer = open_clip.get_tokenizer(CLIP_ARCH)

  from .autonotebook import tqdm as notebook_tqdm


RuntimeError: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero.

In [None]:
# Text-tower properties that the replacement modules below depend on.
for text_tower_value in (og_model.context_length, og_model.vocab_size, og_model.token_embedding):
    print(text_tower_value)

77
49408
Embedding(49408, 1024)


In [None]:
# Both sizes come from the loaded CLIP model, so the replacement modules stay in
# sync with whichever architecture was created above.
SEQ_LEN = og_model.context_length
EMBED_DIM = og_model.token_embedding.embedding_dim

class EncoderBlock(nn.Module):
    """Residual conv/FFN block over LND tensors ([seq_len, batch_size, EMBED_DIM]).

    A learned positional encoding is added, a Conv1d stack mixes neighbouring
    positions, a position-wise Linear stack mixes the embedding dimension, an
    optional LayerNorm is applied, and the result is added back to the input.

    Args:
        conv_depth: number of Conv1d(EMBED_DIM, EMBED_DIM, 3, padding=1) layers.
        ffn_depth: number of Linear(EMBED_DIM, EMBED_DIM) layers.
        norm: if truthy, apply LayerNorm(EMBED_DIM) before the residual add.
    """
    def __init__(self, conv_depth=3, ffn_depth=3, norm=True):
        super().__init__()
        self.pos_encoding = nn.Parameter(torch.zeros(SEQ_LEN, EMBED_DIM), requires_grad=True)

        # NOTE(review): a ReLU is appended only after the *last* layer of each
        # stack, so the stacked convs / linears compose to one affine map.
        # BadNet places ReLUs *between* layers instead; confirm which is intended
        # (changing it would also renumber this Sequential's state_dict keys).
        conv_block = []
        for i in range(conv_depth):
            conv_block.append(nn.Conv1d(EMBED_DIM, EMBED_DIM, 3, padding=1))
            if i + 1 == conv_depth:
                conv_block.append(nn.ReLU())
        self.conv_block = nn.Sequential(*conv_block)

        pos_wise_ffn = []
        for i in range(ffn_depth):
            pos_wise_ffn.append(nn.Linear(EMBED_DIM, EMBED_DIM))
            if i + 1 == ffn_depth:
                pos_wise_ffn.append(nn.ReLU())
        self.pos_wise_ffn = nn.Sequential(*pos_wise_ffn)

        self.norm = nn.LayerNorm(EMBED_DIM) if norm else None

    def forward(self, x):
        """[seq_len, batch_size, EMBED_DIM] -> same shape; seq_len must be <= SEQ_LEN."""
        seq_len = x.shape[0]

        z = x + self.pos_encoding[:seq_len, None, :]  # broadcast over the batch axis

        # Conv1d expects channels first: [batch_size, EMBED_DIM, seq_len].
        z = self.conv_block(z.permute(1, 2, 0))

        # Move back to channels last before the Linear layers, so each one acts on
        # the EMBED_DIM features of a position (not on the seq_len axis).
        z = z.permute(0, 2, 1)  # [batch_size, seq_len, EMBED_DIM]
        z = self.pos_wise_ffn(z)

        if self.norm is not None:
            z = self.norm(z)

        return x + z.permute(1, 0, 2)  # back to [seq_len, batch_size, EMBED_DIM]

class ConvIsAllYouNeed(nn.Module):
    """Stack of EncoderBlocks used as a drop-in replacement for the text transformer.

    Operates on LND tensors ([n_ctx, batch_size, d_model]) and L2-normalises the
    output along d_model.

    Args:
        conv_depth, ffn_depth: forwarded to every EncoderBlock.
        blocks: number of EncoderBlocks; all but the last use LayerNorm.
        dropout: if > 0, an nn.Dropout is inserted after every block except the last.
    """
    def __init__(self, conv_depth=3, ffn_depth=3, blocks=3, dropout=0):
        super().__init__()

        seq = []
        for i in range(blocks):
            is_not_last = (i + 1 < blocks)
            seq.append(EncoderBlock(conv_depth, ffn_depth, norm=is_not_last))
            if is_not_last and dropout > 0:
                seq.append(nn.Dropout(dropout))
        # Register the stack so its parameters belong to the model and it can be
        # applied in forward(); a plain local list would be discarded.
        self.blocks = nn.Sequential(*seq)

    def get_cast_dtype(self):
        return torch.float32

    def forward(self, x, attn_mask = None):
        """LND -> LND, i.e. [n_ctx, batch_size, d_model].

        attn_mask is accepted for call compatibility and ignored.
        """
        x = self.blocks(x)
        # F.normalize clamps the norm at eps=1e-12, so an all-zero vector maps to
        # zeros instead of NaN.
        return F.normalize(x, p=2, dim=-1)

class BadNet(nn.Module):
    """Bottleneck MLP replacement for the text transformer.

    Each token embedding is projected EMBED_DIM -> latent_dim by a kernel-size-1
    Conv1d. The whole sequence is flattened into one vector of length
    latent_dim * SEQ_LEN and passed through `depth` residual Linear layers (ReLU
    between layers, none after the last). Each token is then projected back to
    EMBED_DIM. Input and output layout is LND: [n_ctx, batch_size, d_model].
    """
    def __init__(self, latent_dim=32, depth=3):
        super().__init__()

        self.down = nn.Conv1d(EMBED_DIM, latent_dim, 1)
        self.up = nn.Conv1d(latent_dim, EMBED_DIM, 1)

        flat_dim = latent_dim * SEQ_LEN
        self.ffn = nn.ModuleList(nn.Linear(flat_dim, flat_dim) for _ in range(depth))

    def get_cast_dtype(self):
        return torch.float32

    def forward(self, x, attn_mask = None):
        """LND -> LND, i.e. [n_ctx, batch_size, d_model]; attn_mask is ignored."""
        n_ctx, n_batch, _ = x.shape

        # Every (token, sample) pair becomes its own length-1 signal for the conv.
        tokens = self.down(x.reshape(n_ctx * n_batch, -1, 1))                        # [L*N, latent, 1]
        flat = tokens.view(n_ctx, n_batch, -1).transpose(0, 1).reshape(n_batch, -1)  # [N, L*latent]

        last_idx = len(self.ffn) - 1
        for idx, layer in enumerate(self.ffn):
            flat = flat + layer(flat)
            if idx != last_idx:
                flat = F.relu(flat)

        tokens = flat.view(n_batch, n_ctx, -1).reshape(n_batch * n_ctx, -1, 1)  # [N*L, latent, 1]
        restored = self.up(tokens)                                              # [N*L, EMBED_DIM, 1]
        return restored.view(n_batch, n_ctx, EMBED_DIM).transpose(0, 1)

class NothingIsAllYouNeed(nn.Module):
    """Identity stand-in for the text transformer: forward returns its input unchanged.

    Has no parameters; useful as a baseline for the replacement experiments.
    """
    def __init__(self) -> None:
        super(NothingIsAllYouNeed, self).__init__()

    def get_cast_dtype(self):
        return torch.float32

    def forward(self, x, attn_mask = None):
        # attn_mask exists only so the call signature matches the other modules.
        del attn_mask
        return x


class UNet(nn.Module):
    """U-Net-style replacement for the text transformer; resolution changes act on channels.

    Input and output layout is LND: [n_ctx, batch_size, d_model].

    Down path: each of `halvings` stages applies `conv_depth` Conv1d blocks (over
    n_ctx), then halves the channel count with a stride-2 max-pool taken across
    channels. Bottleneck: EMBED_DIM // 2**halvings channels x SEQ_LEN positions
    are flattened through `latent_depth` dense layers. Up path: channels are
    doubled by nearest-neighbour interpolation, conv blocks are applied, and the
    matching down-path activation is added as a skip connection.

    Args:
        halvings: number of down/up stages.
        conv_depth: conv blocks per stage.
        latent_depth: dense layers in the bottleneck.
        conv_dropout: dropout in the down-path convs, and of the extra bottleneck dropout.
        dense_dropout: dropout in the bottleneck dense layers and the up-path convs.
    """
    def __init__(self, halvings=2, conv_depth=2, latent_depth=3, conv_dropout=0.1, dense_dropout=0.5):
        super().__init__()

        self.pos_encoding = nn.Parameter(torch.zeros(SEQ_LEN, EMBED_DIM), requires_grad=True)
        latent_dim = EMBED_DIM // (2**halvings)

        self.down = nn.ModuleList([
            nn.Sequential(*[nn.Sequential(
                    nn.Conv1d(EMBED_DIM // (2**i), EMBED_DIM // (2**i), 3, padding=1),
                    nn.ReLU(),
                    nn.Dropout(conv_dropout),
                    nn.BatchNorm1d(EMBED_DIM // (2**i)),
                ) for _ in range(conv_depth)])
            for i in range(halvings)])

        # NOTE(review): the up path uses dense_dropout (not conv_dropout) in its
        # convs. Because of the `or`, every conv in the final full-width stage
        # (i == 0) and the last conv of every stage stay linear. Confirm both are intended.
        self.up   = nn.ModuleList([
            nn.Sequential(*[nn.Sequential(
                nn.Conv1d(EMBED_DIM // (2**i), EMBED_DIM // (2**i), 3, padding=1),
                nn.Sequential(
                    nn.ReLU(),
                    nn.Dropout(dense_dropout),
                    nn.BatchNorm1d(EMBED_DIM // (2**i)),
                ) if not (i == 0 or j == conv_depth-1) else nn.Identity(),
             ) for j in range(conv_depth)])
            for i in reversed(range(halvings))])
        # Rate of the extra dropout applied to the bottleneck output in forward().
        self.conv_dropout = conv_dropout

        self.ffn = nn.Sequential(*[nn.Sequential(
                nn.Linear(latent_dim * SEQ_LEN, latent_dim * SEQ_LEN),
                nn.ReLU(),
                nn.Dropout(dense_dropout),
            ) for _ in range(latent_depth)])

    def get_cast_dtype(self):
        return torch.float32

    def forward(self, x, attn_mask = None):
        """LND -> LND, i.e. [n_ctx, batch_size, d_model]; n_ctx must equal SEQ_LEN.

        attn_mask is accepted for call compatibility and ignored.
        """
        seq_len, batch_size, _ = x.shape

        x = x.permute(1, 0, 2)  # [batch_size, n_ctx, d_model]
        x = x + self.pos_encoding[None, :, :]

        x = x.permute(0, 2, 1)  # [batch_size, d_model, n_ctx]
        residues = [x]
        for block in self.down:
            x = block(x)

            # max_pool1d pools the last axis: move channels last to halve them.
            x = x.permute(0, 2, 1)  # [batch_size, n_ctx, C]
            x = F.max_pool1d(x, 3, stride=2, padding=1)
            x = x.permute(0, 2, 1)  # [batch_size, C / 2, n_ctx]

            residues.append(x)

        x = x.permute(0, 2, 1).flatten(1, 2)  # [batch_size, n_ctx * latent_dim]
        x = self.ffn(x)
        # Uses the rate stored in __init__; `self.dropout` was never defined.
        x = F.dropout(x, self.conv_dropout, self.training)

        x = x.view(batch_size, seq_len, -1).permute(0, 2, 1)  # [batch_size, latent_dim, n_ctx]
        x = x + residues[-1]
        for i, block in enumerate(self.up):
            # Double the channel count by nearest-neighbour interpolation over channels.
            x = x.permute(0, 2, 1)  # [batch_size, n_ctx, C]
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            x = x.permute(0, 2, 1)  # [batch_size, 2 * C, n_ctx]

            x = block(x)

            x = x + residues[-2 - i]  # skip connection from the matching down stage

        return x.permute(2, 0, 1)  # [n_ctx, batch_size, d_model]

In [None]:
# Replace the CLIP text transformer with a UNet.
# Keyword arguments are required here: passed positionally, 0.5 landed in
# `latent_depth` and range(0.5) raised a TypeError. The third positional value
# was presumably meant as the bottleneck (dense) dropout — confirm.
og_model.transformer = UNet(halvings=2, conv_depth=2, dense_dropout=0.5).to(DEVICES[0])

# model = nn.DataParallel(og_model.to(DEVICES[0]), device_ids=DEVICES)
model = og_model

print(model)

In [None]:
class BabySet(data.Dataset):
    """Map-style dataset of (token ids, target feature) pairs held fully in memory.

    Both tensors are read with torch.load; row i of the token tensor is paired with
    row i of the feature tensor. NOTE(review): depending on the torch version,
    torch.load may unpickle arbitrary objects — only load trusted files.
    """

    def __init__(self, token_file, feature_file):
        self.features = torch.load(feature_file)
        self.tokens = torch.load(token_file)
        n_features, n_tokens = self.features.shape[0], self.tokens.shape[0]
        assert n_features == n_tokens

    def __len__(self):
        return self.features.size(0)

    def __getitem__(self, index):
        token_ids = self.tokens[index].int()  # int32 token ids
        return token_ids, self.features[index]

babyset = BabySet('./data/tokens.pt', './data/features.pt')

# Fixed generator so the train/valid/test partition is identical after every
# kernel restart (an unseeded split leaks test samples into training on re-runs).
SPLIT_SEED = 0
trainset, validset, testset = data.random_split(
    babyset, [0.8, 0.1, 0.1], generator=torch.Generator().manual_seed(SPLIT_SEED))

BATCH_SIZE = 512
NUM_WORKERS = 2

trainloader = data.DataLoader(trainset, shuffle=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
# Evaluation does not need shuffling; a fixed order makes per-batch metrics reproducible.
validloader = data.DataLoader(validset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
testloader  = data.DataLoader(testset,  shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
print(len(trainloader), len(validloader), len(testloader))

591753
(tensor([49406,   320,   786,   593,   320,   736, 11122,   525,   320,  2442,
          617,  2966,   525,   320, 11795,  1759,   269, 49407,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0], dtype=torch.int32), tensor([ 0.0338,  0.0446, -0.0328,  ..., -0.0190,  0.0113, -0.0351]))
tensor(0.1606)


In [None]:
def run_epoch(model, dataloader, optimizer, scheduler, criterion, metric, train, verbose=True, device=None):
    """Run one pass over `dataloader`, either training or evaluating `model`.

    Args:
        model: called as model(None, X); the second element of its 3-tuple output
            is the prediction compared against y.
        dataloader: iterable of (X, y) batches with a len().
        optimizer: zeroed/stepped every batch when train is True; may be None otherwise.
        scheduler: optional LR scheduler, stepped after every optimizer step.
        criterion: criterion(pred, y) -> scalar loss tensor.
        metric: metric(pred, y) -> per-sample tensor, averaged per batch
            (reported as "Accuracy").
        train: True for a training pass (gradients + updates), False for evaluation.
        verbose: show the running loss/metric in the progress bar.
        device: where batches are moved; defaults to DEVICES[0].

    Returns:
        (losses, accs): numpy arrays with one entry per batch.
    """
    if device is None:
        device = DEVICES[0]
    model.train(bool(train))  # nn.Module.train requires a real bool
    with torch.set_grad_enabled(train):
        t = tqdm(dataloader)
        losses = np.zeros(len(t))
        accs = np.zeros(len(t))
        for i, (X, y) in enumerate(t):
            if train:
                optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            _, pred, _ = model(None, X)
            loss = criterion(pred, y)
            if train:
                loss.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()
            acc = metric(pred, y).mean()
            # One host sync per value; reused for both the arrays and the progress bar.
            losses[i] = loss.item()
            accs[i] = acc.item()
            if verbose:
                t.set_description(f'Loss = {losses[i]:.4f}, Accuracy = {accs[i] * 100:02.2f}%')
        print('Done with epoch task')
        return losses, accs

In [None]:
optimizer = optim.Adam(model.parameters(), lr=1e-5)
criterion = nn.MSELoss()
# Cosine similarity between predicted and target features; logged as "Accuracy".
metric = nn.CosineSimilarity()

EPOCHS = 5
# run_epoch steps the scheduler once per training batch, so total_steps must
# cover every batch of every epoch.
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=EPOCHS*len(trainloader))

train_losses = []
train_accs   = []
valid_losses = []
valid_accs   = []

for epoch in range(EPOCHS):
    print(f'===== EPOCH {epoch+1:02} =====')
    print('Training...')
    epoch_train_losses, epoch_train_accs = run_epoch(model, trainloader, optimizer, scheduler, criterion, metric, train=True)
    train_losses.append(epoch_train_losses)
    train_accs.append(epoch_train_accs)
    print(f'Epoch Train Loss = {epoch_train_losses.mean():.4f}, Epoch Train Accuracy = {epoch_train_accs.mean() * 100:02.2f}%')

    # Rewritten every epoch so an interrupted run still leaves its history on disk.
    np.savetxt('train_losses.csv', np.asarray(train_losses))
    np.savetxt('train_accs.csv',   np.asarray(train_accs))

    print('Validating...')
    epoch_valid_losses, epoch_valid_accs = run_epoch(model, validloader, None, None, criterion, metric, train=False)
    valid_losses.append(epoch_valid_losses.mean())
    valid_accs.append(epoch_valid_accs.mean())
    print(f'Epoch Validation Loss = {epoch_valid_losses.mean():.4f}, Epoch Validation Accuracy = {epoch_valid_accs.mean() * 100:02.2f}%')

    np.savetxt('valid_losses.csv', np.asarray(valid_losses))
    np.savetxt('valid_accs.csv',   np.asarray(valid_accs))

print('===== TESTING =====')
# Evaluation needs no optimizer/scheduler; passing None for both keeps criterion
# and metric in their own positional slots.
test_losses, test_accs = run_epoch(model, testloader, None, None, criterion, metric, train=False)
print(f'Test Loss = {test_losses.mean():.4f}, Test Accuracy = {test_accs.mean() * 100:02.2f}%')

===== EPOCH 01 =====
Training...


  0%|                                                                                                                                                          | 0/463 [00:00<?, ?it/s]

Loss = 0.0010, Accuracy = 46.77%:  39%|███████████████████████████████████████████▏                                                                  | 182/463 [03:01<04:39,  1.00it/s]


KeyboardInterrupt: 

In [None]:
# `model` only has a `.module` attribute when it is wrapped in nn.DataParallel
# (that wrap is currently commented out), so unwrap it only if present.
# NOTE: Module.cpu() moves the model in place; move it back to DEVICES[0] before
# further training or evaluation. This pickles the whole module, not a state_dict.
torch.save(getattr(model, 'module', model).cpu(), 'model.pt')