In [10]:
import numpy as np
from torch.utils.data import DataLoader, Subset
import sys
sys.path.insert(0, '..')
from data_utils.Datasets import SerializedConcatDataset, ShiftSerializedConcatDataset, BinarySerializer
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import torch
from tqdm import tqdm
import os
import csv
import pickle

from transformers import AutoConfig, GPT2LMHeadModel

with open('serializer.pkl', 'rb') as inp:
    binser = pickle.load(inp)

In [11]:
# define model
vocab_size = binser.vocab_size
d_model = 256
num_heads = 4
num_layers = 4
d_ff = 256
max_seq_length = binser.max_seq_length
dropout = 0.3

In [12]:
# load data
npz_path = '../data/augmented_and_padded_data.npz'
dataset = SerializedConcatDataset(npz_path, pad_to_length=max_seq_length)

train_percentage = 0.9
split_idx = int( len(dataset)*train_percentage )

train_set = Subset(dataset, range(0,split_idx))
test_set = Subset(dataset, range(split_idx, len(dataset)))

batch_size = 4
epochs = 2

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

# shiftutation data
shift_dataset = ShiftSerializedConcatDataset(npz_path, pad_to_length=max_seq_length)
shift_loader = DataLoader(shift_dataset, batch_size=batch_size, shuffle=True)

dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [13]:
config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=vocab_size,
    n_positions=max_seq_length,
    n_layer=num_layers,
    n_head=num_heads,
    pad_token_id=binser.padding,
    bos_token_id=binser.padding,
    eos_token_id=binser.padding,
    n_embd=d_ff
)
transformer = GPT2LMHeadModel(config).to(dev)

In [14]:
# train model
criterion = CrossEntropyLoss(ignore_index=0)
optimizer = Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(30, 256)
    (wpe): Embedding(1063, 256)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-3): 4 x GPT2Block(
        (ln_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=256, out_features=30, bias=False)
)

In [17]:
for epoch in range(epochs):
    train_loss = 0
    running_loss = 0
    samples_num = 0
    running_accuracy = 0
    train_accuracy = 0
    with tqdm(train_loader, unit='batch') as tepoch:
        tepoch.set_description(f"Epoch {epoch} | trn")
        for seq in tepoch:
            seq = seq.to(dev)
            attention_mask = seq[:,:-1] != 0
            optimizer.zero_grad()
            target = seq[:, 1:]
            output = transformer(seq[:, :-1], labels=target, attention_mask=attention_mask)
            loss = output.loss
            loss.backward()
            optimizer.step()
            # loss = criterion(output.contiguous().view(-1, vocab_size), seq[:, 1:].contiguous().view(-1))
            # loss.backward()
            # optimizer.step()
            # # update loss
            samples_num += seq.shape[0]
            running_loss += loss.item()
            train_loss = running_loss/samples_num
            # accuracy
            prediction = output.logits.argmax(dim=2, keepdim=True).squeeze()
            running_accuracy += (prediction[attention_mask] == target[attention_mask]).sum().item()/prediction.shape[1]
            train_accuracy = running_accuracy/samples_num
            tepoch.set_postfix(loss=train_loss, accuracy=train_accuracy) # tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy)

Epoch 0 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████████| 1199/1199 [08:42<00:00,  2.30batch/s, accuracy=0.0373, loss=0.154]
Epoch 1 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████████| 1199/1199 [11:28<00:00,  1.74batch/s, accuracy=0.0425, loss=0.136]
