In [1]:
!git clone https://github.com/brendenlake/SCAN.git
%pip install -r requirements.txt --quiet

fatal: destination path 'SCAN' already exists and is not an empty directory.
Note: you may need to restart the kernel to use updated packages.


In [2]:
from Dataloader import Dataloader, get_dataset_path
from Seq2SeqTransformer import Seq2SeqTransformer, create_mask, generate_square_subsequent_mask
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from Lang import Lang
import torch
import wandb
import os
os.environ['WANDB_NOTEBOOK_NAME'] = 'transformer.ipynb'
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mhojmax[0m ([33mrl-msps[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
config = {
    'learning_rate': 0.001,
    'dropout': 0.1,
    'hidden_size': 200,
    'num_encoder_layers': 2,
    'num_decoder_layers': 2,
    'nhead': 2,
    'epochs': 400,
    '_folder': 'SCAN/simple_split',
    'dataset': 'simple',
    'batch_size': 64,
}

In [4]:
dataloader = Dataloader()

train_path = get_dataset_path(config['_folder'], config['dataset'], 'train')
train_X, train_Y = dataloader.fit_transform(train_path)
test_path = get_dataset_path(config['_folder'], config['dataset'], 'test')
test_X, test_Y = dataloader.transform(test_path)

config['input_size'] = dataloader.input_lang.n_words
config['output_size'] = dataloader.output_lang.n_words

In [5]:
wandb.init(
    project="individual-atnlp", 
    entity="hojmax",
    name=f"Transformer, Dataset: {config['dataset']}",
    config=config,
    tags=["test"]
)
dataloader.save(wandb.run.dir)

[34m[1mwandb[0m: Currently logged in as: [33mhojmax[0m. Use [1m`wandb login --relogin`[0m to force relogin


<Dataloader.Dataloader at 0x7fb8ba474e50>

In [6]:
transformer = Seq2SeqTransformer(
    num_encoder_layers=config['num_encoder_layers'],
    num_decoder_layers=config['num_decoder_layers'],
    emb_size=config['hidden_size'],
    nhead=config['nhead'],
    src_vocab_size=config['input_size'],
    tgt_vocab_size=config['output_size'],
    dim_feedforward=config['hidden_size'],
    dropout=config['dropout']
)

In [7]:
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=Lang.PAD_token)
optimizer = torch.optim.Adam(
    transformer.parameters(),
    lr=config['learning_rate']
)

In [8]:
train_X = [e.flatten() for e in train_X]
train_Y = [e.flatten() for e in train_Y]

In [9]:
train_inputs = pad_sequence(
    train_X,
    padding_value=Lang.PAD_token
)
train_targets = pad_sequence(
    train_Y,
    padding_value=Lang.PAD_token
)

In [10]:
train_targets.shape

torch.Size([50, 16728])

In [11]:
class CustomTensorDataset(Dataset):
    def __init__(self, *tensors) -> None:
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[:,index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)

In [12]:
train_dataset = CustomTensorDataset(train_inputs, train_targets)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
)

In [13]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [14]:
def train_epoch(model, optimizer):
    model.train()
    losses = 0

    for src, tgt in train_dataloader:
        # .T because dataloader is not returning the right shape
        src = src.T.to(device)
        tgt = tgt.T.to(device)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(
            src,
            tgt_input
        )

        logits = model(
            src,
            tgt_input,
            src_mask,
            tgt_mask,
            src_padding_mask,
            tgt_padding_mask,
            src_padding_mask
        )

        optimizer.zero_grad()

        tgt_out = tgt[1:, :]
        loss = loss_fn(
            logits.reshape(-1, logits.shape[-1]),
            tgt_out.reshape(-1)
        )
        loss.backward()

        optimizer.step()
        losses += loss.item()

        return losses / len(train_dataloader)

In [15]:
for epoch in range(config['epochs']):
    loss = train_epoch(transformer, optimizer)
    wandb.log({
        "avg_epoch_loss": loss,
    })

In [16]:
# function to generate output sequence using greedy algorithm
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(device)
    src_mask = src_mask.to(device)

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
    for i in range(max_len-1):
        memory = memory.to(device)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(device)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == Lang.EOS_token:
            break
    return ys

In [17]:
train_inputs[:,0]

tensor([0, 3, 4, 5, 6, 7, 8, 4, 5, 9, 1])

In [18]:
train_targets[:,0]

tensor([0, 3, 3, 4, 3, 3, 4, 3, 3, 3, 3, 3, 3, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2])

In [19]:
src = train_inputs[:, 0].view(-1, 1)
num_tokens = src.shape[0]
src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
tgt_tokens = greedy_decode(
    transformer,
    src,
    src_mask,
    max_len=num_tokens + 5,
    start_symbol=Lang.SOS_token
).flatten()

In [20]:
tgt_tokens

tensor([0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])