In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import transformers

In [8]:
from dataset import JSONLDataset
from ul2_dataset import UL2Dataset

# Start from the pretrained GPT-2 vocabulary, cap sequences at 64 tokens and
# reuse the end-of-sequence token for padding.
tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.model_max_length = 64
tokenizer.pad_token = tokenizer.eos_token

# Extra markers registered on top of the base vocabulary. Presumably these are
# the UL2 mode prefixes, the sentinel placeholders for masked spans, and the
# input/target delimiters -- confirm against ul2_dataset.py.
# The groups are registered one call per group, in this exact order, so the
# ids handed out to the new tokens stay the same as before.
extra_token_groups = (
    ['<R>','<X>','<S>'],
    [f'<extra_id_{i}>' for i in range(200)],
    ['<B>','<E>'],
)
for token_group in extra_token_groups:
    tokenizer.add_tokens(token_group, special_tokens=True)

dataset = UL2Dataset('data/proj2_data.jsonl', tokenizer)

# Quick inspection: dataset size, then the length and decoded text of the
# input half of one (input, target) pair.
print(len(dataset))
x, y = dataset[6]
print(len(x))
print(tokenizer.decode(x))

1000
58
<R> Coronary artery disease: diagnostic and prognostic models for reducing patient <extra_id_0>  accurate diagnostic <extra_id_1>  factor <extra_id_2>  management of <extra_id_3>  disease (CAD); thus, noninvasive cardiac imaging has <B> <extra_id_0>  risk.
Early and <extra_id_1>  testing is a critical <extra_id_2>  in the detection and optimal <extra_id_3>  coronary artery


2

In [8]:
# Sanity check: the added '<R>' marker should come back as one id of its own
# instead of being broken into sub-word pieces.
probe = tokenizer("<R> So", return_tensors="pt")
probe["input_ids"]

tensor([[50257,  1406]])

In [9]:
from mingpt.model import GPT
from mingpt.trainer import Trainer

# --- Model -----------------------------------------------------------------
# Build a GPT from the "gpt-nano" preset, sized to match the tokenizer.
model_config = GPT.get_default_config()
model_config.model_type = "gpt-nano"
# Use len(tokenizer), not tokenizer.vocab_size: vocab_size only counts the
# base GPT-2 vocabulary, while len() also counts the tokens added with
# add_tokens(). Those added ids must be valid indices into the embedding.
model_config.vocab_size = len(tokenizer)
# The context window equals the tokenizer's maximum sequence length.
model_config.block_size = tokenizer.model_max_length
model = GPT(model_config)

# --- Trainer ---------------------------------------------------------------
trainer_config = Trainer.get_default_config()
trainer_config.max_iters = 10
# Prefer Apple's MPS backend, then CUDA, and fall back to the CPU so the
# notebook also runs on machines without an Apple GPU.
trainer_config.device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
# Checkpoint settings; presumably the trainer resumes from `checkpoint_dir`
# when `load_checkpoint` is set and saves every `checkpoint_iters`
# iterations -- confirm against mingpt/trainer.py.
trainer_config.load_checkpoint = True
trainer_config.checkpoint_dir = "checkpoints"
trainer_config.checkpoint_iters = 1
trainer = Trainer(trainer_config, model, dataset)

def print_loss(trainer):
    """Trainer callback that reports the loss on even-numbered iterations.

    Args:
        trainer: object exposing ``iter_num`` (the current iteration counter)
            and ``loss`` (a value with an ``.item()`` method).
    """
    # Odd iterations are skipped to halve the amount of log output.
    if trainer.iter_num % 2 != 0:
        return
    print(f"Batch: {trainer.iter_num}, Loss: {trainer.loss.item()}")

# Attach print_loss to the trainer's "on_batch_end" event so the loss is
# logged while training runs.
trainer.add_callback("on_batch_end", print_loss)

64
48
number of parameters: 2.51M
running on device mps


In [10]:
# Run the trainer; progress is reported by the callbacks registered on it.
trainer.run()

Batch: 0, Loss: 10.834650039672852
Batch: 2, Loss: 10.808938980102539
Batch: 4, Loss: 10.776081085205078
Batch: 6, Loss: 10.738039016723633
Batch: 8, Loss: 10.701078414916992
