In [8]:
import os
import sys

from tqdm.notebook import tqdm
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import tiktoken
import itertools

import numpy as np
import pandas as pd
import pickle as pkl

sys.path.insert(1, '../../models/')
from model import GPT, GPTConfig

# Use the GPU when torch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device("cpu")

In [3]:
from datasets import load_dataset

In [4]:
# Task names; the 'mnli' entry matches the GLUE config this notebook loads.
tasks = [
    'cola',
    'sst2',
    'mrpc',
    'qqp',
    'stsb',
    'mnli',
    'qnli',
    'rte',
    'wnli',
]

In [44]:
# Restore the saved checkpoint and rebuild the model from the arguments stored in it.
checkpoint = torch.load('../../models/out/ckpt-5-5-2.5-48-mean-0.pt', map_location=device)
model_args = checkpoint['model_args']
# Override two of the saved arguments for this run.
model_args['position_dir'] = '../../models/gpt2-positions-5-5'
model_args['dropout'] = 0.1

gptconf = GPTConfig(**model_args)
model = GPT(gptconf)

state_dict = checkpoint['model']
# Some checkpoints carry an '_orig_mod.' prefix on their keys (presumably left by
# torch.compile wrapping the module -- TODO confirm); strip it so the keys match `model`.
unwanted_prefix = '_orig_mod.'
for key in list(state_dict):
    if key.startswith(unwanted_prefix):
        state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)

# strict=False tolerates key mismatches, so report them instead of discarding the result.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"missing keys: {load_result.missing_keys}")
    print(f"unexpected keys: {load_result.unexpected_keys}")

tokenizer = tiktoken.get_encoding('gpt2')
# The id of '<|endoftext|>' is reused as the padding id.
pad_token = tokenizer.encode('<|endoftext|>', allowed_special="all")[0]

number of parameters: 127.97M


In [45]:
# Load the 'mnli' config of GLUE and request torch-formatted outputs.
ds = (
    load_dataset("nyu-mll/glue", "mnli")
    .with_format("torch")
)
# One class per name listed on the train split's 'label' feature.
num_labels = len(
    ds['train'].features['label'].names
)

In [46]:
def tokenize_batch(sents, pad_id=None, encoder=None):
    """Encode a batch of strings into one right-padded tensor of token ids.

    Args:
        sents: list of strings to encode.
        pad_id: id used to fill rows shorter than the longest one; defaults to
            the notebook-level `pad_token`.
        encoder: object exposing `encode_batch(texts, allowed_special=...)`;
            defaults to the notebook-level `tokenizer`. Passing it explicitly
            keeps the function independent of later rebinding of that name.

    Returns:
        A `torch.long` tensor of shape (len(sents), longest encoded length).
        An empty batch yields shape (0, 0); a batch of n empty strings yields
        shape (n, 0).
    """
    if encoder is None:
        encoder = tokenizer
    if pad_id is None:
        pad_id = pad_token

    tokens = encoder.encode_batch(sents, allowed_special='all')
    if not tokens:
        return torch.empty((0, 0), dtype=torch.long)

    width = max(len(row) for row in tokens)
    # Right-pad every row to the common width.
    padded = [list(row) + [pad_id] * (width - len(row)) for row in tokens]
    # Explicit dtype: do not depend on numpy's platform-default integer width.
    return torch.tensor(padded, dtype=torch.long)

In [64]:
dataloader = DataLoader(ds['train'], batch_size=48)

# Pull exactly one batch to sanity-check the tokenization pipeline.
# next(iter(...)) raises StopIteration on an empty loader instead of silently
# leaving X / Y unset (or stale from an earlier run).
batch = next(iter(dataloader))
# Each example is the premise and hypothesis joined by the end-of-text marker.
sents = ['<|endoftext|>'.join(pair) for pair in zip(batch['premise'], batch['hypothesis'])]
X = tokenize_batch(sents)
Y = batch['label']

# Show the input and label shapes for the sampled batch.
print(X.shape)
print(Y.shape)

torch.Size([48, 97])
torch.Size([48])


In [20]:
# Display the train split; its repr lists the feature names and the row count.
ds['train']

Dataset({
    features: ['premise', 'hypothesis', 'label', 'idx'],
    num_rows: 392702
})

In [60]:
# Replace the language-model head with a freshly initialised classification head
# that has one output per label.
model.lm_head = torch.nn.Linear(in_features = model.lm_head.in_features,
                                out_features = num_labels,
                                # nn.Linear expects a bool here, not the bias tensor (or None) itself
                                bias = model.lm_head.bias is not None)
torch.nn.init.normal_(model.lm_head.weight, mean=0.0, std=0.02)

model.train()

Parameter containing:
tensor([[ 0.0148,  0.0097, -0.0079,  ..., -0.0020,  0.0441, -0.0429],
        [-0.0126, -0.0165,  0.0046,  ...,  0.0212, -0.0085,  0.0061],
        [-0.0013, -0.0004,  0.0201,  ..., -0.0292, -0.0147,  0.0081]],
       requires_grad=True)

In [68]:
# Optimisation hyper-parameters; all but `grad_clip` are handed to
# `model.configure_optimizers` in this notebook.
learning_rate = 6e-4  # max learning rate
weight_decay = 1e-1
beta1, beta2 = 0.9, 0.95
grad_clip = 1.0  # clip gradients at this value, or disable if == 0.0

In [71]:
# Low-precision dtype name: 'bfloat16' when CUDA is available and reports bf16
# support, otherwise 'float16'.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
# NOTE(review): hard-coded to 'cpu' even though `device` may resolve to CUDA -- confirm this is intended.
device_type = 'cpu' # for later use in torch.autocast

# Gradient scaling belongs to CUDA float16 training. Enable it only when the
# run targets CUDA, so a CPU run never ends up with an active scaler.
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16' and device_type == 'cuda'))
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)

num decayed parameter tensors: 51, with 128,753,968 parameters
num non-decayed parameter tensors: 25, with 19,600 parameters
using fused AdamW: False


In [36]:
from transformers import AutoTokenizer
# NOTE(review): this rebinds the notebook-level name `tokenizer`, which earlier held
# the tiktoken 'gpt2' encoding and is read as a global by `tokenize_batch`. Calling
# that function after this cell runs will pick up this object instead -- consider
# giving it a distinct name.
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')

In [38]:
premise = 'one day I will see the world'
hypothesis = 'This example is travel.'

# Encode the (premise, hypothesis) pair as one sequence of token ids; no model is run here.
# `truncation='only_first'` replaces the deprecated `truncation_strategy` keyword:
# if the pair is too long, only the first sequence (the premise) is shortened.
x = tokenizer.encode(premise, hypothesis, return_tensors='pt',
                     truncation='only_first')

In [43]:
# Decode the first encoded row back to text to see where the special tokens were inserted.
tokenizer.decode(x[0])

'<s>one day I will see the world</s></s>This example is travel.</s>'