In [None]:
# Mount Google Drive (Colab-only) so the tensor files under
# /content/drive/MyDrive/TSS Tensors, loaded in a later cell, are reachable.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
%matplotlib inline

In [None]:
import math
import random
import time
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

## Training and Evaluation


In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def gen_nopeek_mask(length):
    """Build a (length, length) additive causal attention mask.

    Entries on and below the diagonal are 0.0 (attend); entries above it are
    -inf (a position may not look at later positions).
    """
    blocked = torch.full((length, length), float('-inf'))
    return blocked.triu(diagonal=1)

In [None]:
class BrainwaveDataset(torch.utils.data.Dataset):
    """Paired dataset of source time series ('src') and target code sequences ('tgt').

    Both tensors are loaded entirely onto the CPU and truncated to their first
    `limit` entries along dim 0. Item `idx` is {'src': tss[idx], 'tgt': codes[idx]}.

    Args:
        codes_tensor_path: path to a torch-saved tensor of target codes (one sample per dim-0 row).
        tss_tensor_path: path to a torch-saved tensor of source time series (one sample per dim-0 row).
        limit: number of leading samples to keep; defaults to 1000 (the previous
            hard-coded cap). Pass None to keep every sample.

    Attributes:
        min / max: element-wise min / max of the time series reduced over dims 0
            and 1 (per feature, assuming tss is (num_samples, seq_len, num_features)
            — TODO confirm). Min-max normalization with them is currently disabled.
    """

    def __init__(self, codes_tensor_path, tss_tensor_path, limit=1000):
        # NOTE(review): torch.load unpickles its input — only load trusted files.
        self.tss = torch.load(tss_tensor_path, map_location=torch.device('cpu'))[:limit]
        self.codes = torch.load(codes_tensor_path, map_location=torch.device('cpu'))[:limit]
        self.min = self.tss.min(0).values.min(0).values
        self.max = self.tss.max(0).values.max(0).values
        # Normalization is currently disabled:
        # self.tss = (self.tss - self.min) / (self.max - self.min)

    def __len__(self):
        """Number of samples (rows of the time-series tensor)."""
        return self.tss.shape[0]

    def __getitem__(self, idx):
        """Return the idx-th sample as {'src': time series, 'tgt': codes}."""
        return {'src': self.tss[idx], 'tgt': self.codes[idx]}

# Source: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor, then apply dropout.

    Args:
        d_model: feature size of the input. Odd sizes are supported; the last
            column then carries only a sine term.
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence covered by the precomputed table.

    forward(x) expects x shaped (seq_len, batch, d_model) with seq_len <= max_len.
    """

    def __init__(self, d_model, dropout=0.1, max_len=1024):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine column,
        # so only the first d_model // 2 frequencies are used here.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as (max_len, 1, d_model) so it broadcasts over the batch dimension.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return dropout(x + pe[:seq_len]) for x of shape (seq_len, batch, d_model)."""
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerModel(nn.Module):
    """Encoder-decoder Transformer from continuous source sequences to target-token logits.

    Hyper-parameters are read from the module-level constants SRC_INPUT_DIM,
    TGT_INPUT_DIM, ENC_EMB_DIM, NHEAD and N_ENC_LAYERS at construction time, so the
    cell that defines them must run before the model is instantiated. The
    module-level helper gen_nopeek_mask is used during free-running decoding.

    NOTE(review): only the encoder depth is configured. num_decoder_layers,
    dim_feedforward and dropout stay at nn.Transformer's defaults; ENC_DROPOUT /
    DEC_DROPOUT are not used by any visible cell.
    """

    def __init__(self):
        super(TransformerModel, self).__init__()
        # Source steps are feature vectors -> linear projection; target steps are token ids -> embedding table.
        self.src_embedding = nn.Linear(SRC_INPUT_DIM, ENC_EMB_DIM)
        self.tgt_embedding = nn.Embedding(TGT_INPUT_DIM, ENC_EMB_DIM)
        self.transformer = nn.Transformer(nhead=NHEAD, num_encoder_layers=N_ENC_LAYERS, d_model=ENC_EMB_DIM)
        self.linear = nn.Linear(ENC_EMB_DIM, TGT_INPUT_DIM)
        pos_dropout = 0.1
        # Upper bound on sequence length covered by the positional-encoding table.
        self.max_seq_length = 1024
        self.pos_enc = PositionalEncoding(ENC_EMB_DIM, pos_dropout, self.max_seq_length)

    def _embed_src(self, src):
        """Project batch-first src (B, S, SRC_INPUT_DIM) to scaled, position-encoded (S, B, E)."""
        return self.pos_enc(self.src_embedding(src).permute(1, 0, 2) * math.sqrt(ENC_EMB_DIM))

    def _embed_tgt(self, tokens):
        """Embed batch-first token ids (B, T) to scaled, position-encoded (T, B, E)."""
        return self.pos_enc(self.tgt_embedding(tokens).permute(1, 0, 2) * math.sqrt(ENC_EMB_DIM))

    def _greedy_inputs(self, memory, first_tokens, length):
        """Extend first_tokens (B, 1) to (B, length) with the model's own argmax predictions.

        Runs without gradient: the generated ids only serve as decoder inputs for
        the differentiable pass in forward(). Costs length - 1 sequential decoder calls.
        """
        tokens = first_tokens
        with torch.no_grad():
            for _ in range(length - 1):
                step_mask = gen_nopeek_mask(tokens.size(1)).to(memory.device)
                last_state = self.transformer.decoder(self._embed_tgt(tokens), memory, tgt_mask=step_mask)[-1]
                next_token = self.linear(last_state).argmax(dim=-1, keepdim=True)
                tokens = torch.cat([tokens, next_token], dim=1)
        return tokens

    def forward(self, src, tgt, teacher_forcing_ratio=0.5, tgt_mask=None, train=True):
        """Return target logits of shape (tgt_len, batch, TGT_INPUT_DIM) (sequence-first).

        Args:
            src: float tensor (batch, src_len, SRC_INPUT_DIM), batch-first.
            tgt: token ids (batch, tgt_len), batch-first. In teacher-forcing mode it
                is the decoder input; in free-running mode only tgt[:, 0] is used,
                as the seed token.
            teacher_forcing_ratio: probability of teacher forcing for this call
                (decided once per call with Python's `random`). Otherwise the decoder
                input is the model's own greedy predictions.
            tgt_mask: optional (tgt_len, tgt_len) additive mask (e.g. gen_nopeek_mask)
                for the decoder pass that produces the returned logits.
            train: accepted for backward compatibility and ignored; the number of
                output steps always equals tgt_len.
        """
        src_emb = self._embed_src(src)

        if random.random() < teacher_forcing_ratio:
            outs = self.transformer(src_emb, self._embed_tgt(tgt), tgt_mask=tgt_mask)
            return self.linear(outs)

        # Free-running: encode once, greedily generate decoder inputs without grad,
        # then decode them in a single differentiable pass so the loss can
        # backpropagate through the whole network.
        memory = self.transformer.encoder(src_emb)
        generated = self._greedy_inputs(memory, tgt[:, :1], tgt.size(1))
        outs = self.transformer.decoder(self._embed_tgt(generated), memory, tgt_mask=tgt_mask)
        return self.linear(outs)

In [None]:
def train(model, iterator, optimizer, scheduler, criterion, teacher_forcing_ratio=0.5, clip=0.5):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: called as model(src, tgt_inp, teacher_forcing_ratio=..., tgt_mask=...)
            and expected to return sequence-first logits (tgt_len, batch, vocab),
            as TransformerModel does.
        iterator: iterable of dicts holding batch-first 'src' and 'tgt' tensors
            (e.g. a DataLoader over BrainwaveDataset).
        optimizer: stepped once per batch.
        scheduler: stepped once per batch with that batch's loss (e.g. ReduceLROnPlateau).
        criterion: loss taking (N, vocab) logits and (N,) target ids.
        teacher_forcing_ratio: forwarded to the model.
        clip: max gradient norm passed to clip_grad_norm_.

    Returns:
        Mean loss over the batches, or nan if iterator yields no batches.
    """
    model.train()

    epoch_loss = 0.0
    n_batches = 0

    for batch in iterator:
        src = batch['src'].to(device)
        tgt = batch['tgt'].to(device)

        # Shift by one: the decoder reads tokens [0, T-1) and predicts tokens [1, T).
        tgt_inp, tgt_out = tgt[:, :-1], tgt[:, 1:]
        tgt_mask = gen_nopeek_mask(tgt_inp.shape[1]).to(device)

        optimizer.zero_grad()
        output = model(src, tgt_inp, teacher_forcing_ratio=teacher_forcing_ratio, tgt_mask=tgt_mask)
        # output is sequence-first (T, B, V) but tgt_out is batch-first (B, T):
        # transpose the targets so logits and targets flatten in the same order.
        logits = output.reshape(-1, output.shape[-1])
        targets = tgt_out.transpose(0, 1).reshape(-1)
        loss = criterion(logits, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        batch_loss = loss.item()
        scheduler.step(batch_loss)
        epoch_loss += batch_loss
        n_batches += 1
        print(f"loss {batch_loss:.4f} | lr {optimizer.param_groups[0]['lr']:.2e}")

    return epoch_loss / n_batches if n_batches else float('nan')


def evaluate(model, iterator, criterion):
    """Return the mean per-batch loss of free-running decoding (teacher forcing off).

    Uses the same one-step shift as train(): decoder inputs are tgt[:, :-1] and
    targets are tgt[:, 1:]. Every batch contributes; a shape mismatch raises
    instead of being silently skipped.

    Args:
        model: called as model(src, tgt_inp, teacher_forcing_ratio=0, tgt_mask=...)
            and expected to return sequence-first logits (tgt_len, batch, vocab).
        iterator: iterable of dicts holding batch-first 'src' and 'tgt' tensors.
        criterion: loss taking (N, vocab) logits and (N,) target ids.

    Returns:
        Mean loss over the batches, or nan if iterator yields no batches.
    """
    model.eval()

    epoch_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for batch in iterator:
            src = batch['src'].to(device)
            tgt = batch['tgt'].to(device)

            tgt_inp, tgt_out = tgt[:, :-1], tgt[:, 1:]
            tgt_mask = gen_nopeek_mask(tgt_inp.shape[1]).to(device)
            output = model(src, tgt_inp, teacher_forcing_ratio=0, tgt_mask=tgt_mask)

            # Align sequence-first logits (T, B, V) with batch-first targets (B, T).
            logits = output.reshape(-1, output.shape[-1])
            targets = tgt_out.transpose(0, 1).reshape(-1)
            batch_loss = criterion(logits, targets).item()

            epoch_loss += batch_loss
            n_batches += 1
            print(f'loss {batch_loss:.4f}')

    return epoch_loss / n_batches if n_batches else float('nan')

In [None]:
# Data sizes: SRC_INPUT_DIM features per source time step (input of the source
# Linear projection); TGT_INPUT_DIM target-token vocabulary (embedding rows and
# output logits).
SRC_INPUT_DIM, TGT_INPUT_DIM = 5, 8192

# Transformer width, attention heads and encoder depth.
ENC_EMB_DIM, NHEAD, N_ENC_LAYERS = 32, 4, 2

# NOTE(review): neither dropout constant is read by any visible cell.
ENC_DROPOUT = DEC_DROPOUT = 0.5

# Training schedule and gradient-norm clip.
N_EPOCHS, CLIP = 3, 1

In [None]:
# Seed Python's `random` (the per-call teacher-forcing coin flip) and torch
# (DataLoader shuffling, dropout, the Xavier initialisation below) so runs are reproducible.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

bs = 16
tss_train_tensor_path = '/content/drive/MyDrive/TSS Tensors/train_tss.pt'
tss_test_tensor_path = '/content/drive/MyDrive/TSS Tensors/test_tss.pt'
codes_train_tensor_path = '/content/drive/MyDrive/TSS Tensors/train_codes.pt'
codes_test_tensor_path = '/content/drive/MyDrive/TSS Tensors/test_codes.pt'
# NOTE(review): not used by any visible cell.
new_imgs_csvs_labels = '/content/drive/MyDrive/TSS Tensors/new_imgs_csvs_labels.pt'

ds_train = BrainwaveDataset(codes_train_tensor_path, tss_train_tensor_path)
ds_test = BrainwaveDataset(codes_test_tensor_path, tss_test_tensor_path)

train_dataloader = DataLoader(ds_train, batch_size=bs, shuffle=True)
test_dataloader = DataLoader(ds_test, batch_size=bs, shuffle=False)

model = TransformerModel().to(device)
# Xavier-normal init for every parameter with more than one dimension (weight
# matrices, embedding tables); 1-D parameters keep their default init.
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_normal_(p)

In [None]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), betas=(0.9, 0.98), lr=1e-2)
# train() calls scheduler.step(loss) after every batch, so `patience` counts batches, not epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', min_lr=1e-7, patience=200, threshold=0.001)

# Anneal teacher forcing linearly from 1.0 (first epoch) to 0.0 (last epoch).
# max(..., 1) avoids a ZeroDivisionError when N_EPOCHS == 1.
tf_denominator = max(N_EPOCHS - 1, 1)
train_losses = []
for epoch in range(N_EPOCHS):
    tf_ratio = 1 - epoch / tf_denominator
    epoch_loss = train(model, train_dataloader, optimizer, scheduler, criterion,
                       teacher_forcing_ratio=tf_ratio, clip=CLIP)
    train_losses.append(epoch_loss)

train_losses

tensor(9.0213, device='cuda:0', grad_fn=<NllLossBackward0>) 0.01
tensor(8.9496, device='cuda:0', grad_fn=<NllLossBackward0>) 0.01
tensor(8.2296, device='cuda:0', grad_fn=<NllLossBackward0>) 0.01
tensor(7.7305, device='cuda:0', grad_fn=<NllLossBackward0>) 0.01
tensor(7.3608, device='cuda:0', grad_fn=<NllLossBackward0>) 0.01
tensor(7.0021, device='cuda:0', grad_fn=<NllLossBackward0>) 0.01
tensor(6.6864, device='cuda:0', grad_fn=<NllLossBackward0>) 0.01
tensor(6.4366, device='cuda:0', grad_fn=<NllLossBackward0>) 0.01
tensor(6.2574, device='cuda:0', grad_fn=<NllLossBackward0>) 0.01
tensor(6.1425, device='cuda:0', grad_fn=<NllLossBackward0>) 0.01
tensor(6.0624, device='cuda:0', grad_fn=<NllLossBackward0>) 0.01
tensor(6.0195, device='cuda:0', grad_fn=<NllLossBackward0>) 0.01
tensor(5.9817, device='cuda:0', grad_fn=<NllLossBackward0>) 0.01
tensor(5.9720, device='cuda:0', grad_fn=<NllLossBackward0>) 0.01
tensor(5.9683, device='cuda:0', grad_fn=<NllLossBackward0>) 0.01
tensor(5.9513, device='cu

100%|██████████| 1023/1023 [02:17<00:00,  7.45it/s]


tensor(5.9413, device='cuda:0', grad_fn=<NllLossBackward0>) 0.01


  2%|▏         | 19/1023 [00:02<01:58,  8.49it/s]


RuntimeError: ignored

In [None]:
# Smoke test: push one training batch through the free-running path
# (teacher_forcing_ratio=0) and check the batch and vocabulary dimensions.
# no_grad: the output is only inspected, so no autograd graph is needed.
batch = next(iter(train_dataloader))

src = batch['src'].to(device)
tgt = batch['tgt'].to(device)

tgt_inp, tgt_out = tgt[:, :-1], tgt[:, 1:]
tgt_mask = gen_nopeek_mask(tgt_inp.shape[1]).to(device)
with torch.no_grad():
    output = model(src, tgt_inp, teacher_forcing_ratio=0, tgt_mask=tgt_mask)

assert output.shape[1:] == (src.shape[0], TGT_INPUT_DIM), output.shape
output.shape

In [None]:
evaluate(model=model, iterator=test_dataloader, criterion=criterion)

In [None]:
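# Shape reference for nn.Transformer (batch_first=False), written for a generic
# token-to-token task:
# src: src seq length x bs (indices) [<sos> word word word <eos> <pad> <pad> <pad>]
# tgt: tgt seq length x bs (indices) [<sos> word word word word <eos> <pad> <pad>]
# src_key_padding_mask: bs x src seq length [0 0 0 0 0 1 1 1]
# tgt_key_padding_mask: bs x tgt seq length [0 0 0 0 0 0 1 1]
# memory_key_padding_mask = src_key_padding_mask
# tgt_mask: tgt seq length x tgt seq length
#
# In this notebook, src is not token indices. It is a batch-first float tensor
# (bs x src seq length x SRC_INPUT_DIM) that TransformerModel projects with a
# Linear layer and permutes to seq-first. Only tgt_mask is passed; no padding masks are used.
# src embedding: seq length x batch size x embedding dim
#tgt_key_padding_mask.shape, tgt_inp.shape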

In [None]:
N_EPOCHS = 10
CLIP = 1

# No validation split is built in this notebook, so the test loader doubles as
# the validation set. TODO: hold out part of ds_train so the final test loss is unbiased.
valid_dataloader = test_dataloader

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()

    # Keyword for clip: train() takes the scheduler before the criterion.
    train_loss = train(model, train_dataloader, optimizer, scheduler, criterion, clip=CLIP)
    valid_loss = evaluate(model, valid_dataloader, criterion)

    epoch_mins, epoch_secs = divmod(int(time.time() - start_time), 60)
    best_valid_loss = min(best_valid_loss, valid_loss)

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

print(f'Best Val. Loss: {best_valid_loss:.3f}')

test_loss = evaluate(model, test_dataloader, criterion)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')