In [2]:
import os
import sys
import pickle
import shutil
import torch

sys.path.append('../')

from accelerate import Accelerator, notebook_launcher
from dataset import GameDataset, collate_fn
from model import Config, GPTModel
from trainer import train_model, validate_model
from torch.utils.data import DataLoader

In [3]:
token_to_idx = {i: i + 1 for i in range(7)}

In [4]:
token_to_idx['<pad>'] = 0  # Padding token

vocab_size = 8
block_size = 42
embed_size = 512
num_heads = 8
num_layers = 8
dropout = 0.1

In [11]:
path = ''

In [12]:
with open(os.path.join(path, 'TrainingPairs.pkl'), 'rb') as f:
    training_data = pickle.load(f)

In [13]:
d = len(training_data)

train_ratio = 0.8
valid_ratio = 0.1

train = training_data[:int(train_ratio * d)]
valid = training_data[int(train_ratio * d) : int((train_ratio + valid_ratio) * d)]
test = training_data[int((train_ratio + valid_ratio) * d):]

In [14]:
train_dataset = GameDataset(train, token_to_idx)
valid_dataset = GameDataset(valid, token_to_idx)
test_dataset = GameDataset(test, token_to_idx)

In [15]:
def train_main(save_directory = None, epochs = 20):
    
    accelerator = Accelerator()

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)
    valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)

    config = Config(vocab_size, block_size, n_layer=num_layers, n_head=num_layers, n_embd=embed_size)
    model = GPTModel(config)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = epochs)
    
    train_loader, valid_loader, model, scheduler, optimizer = accelerator.prepare(train_loader, valid_loader, model, scheduler, optimizer)

    epoch = 0

    model_path = None
    min_loss = 1e10
    
    for epoch in range(epochs):
        accelerator.print(f'Epoch {epoch}')

        train_model(model, train_loader, optimizer, accelerator)
        valid_loss = validate_model(model, valid_loader, accelerator)
        scheduler.step()

        if accelerator.is_main_process:
            print(f'Validation Loss: {valid_loss:.4f}')

            model_save_path = f"Model_{epoch+1}.pth"
            accelerator.save(accelerator.unwrap_model(model).state_dict(), model_save_path)

            if valid_loss < min_loss:
                min_loss = valid_loss
                model_path = model_save_path

        accelerator.wait_for_everyone()

    if accelerator.is_main_process:
        shutil.copy(model_path, save_directory)


In [16]:
notebook_launcher(train_main, (os.path.join(path, 'best_model'), 20), num_processes = 8)

Launching training on 8 GPUs.
Epoch 0

Training:   0%|          | 0/3125 [00:00<?, ?it/s]




Training: 100%|██████████| 3125/3125 [01:15<00:00, 41.27it/s]






Training: 100%|██████████| 3125/3125 [01:15<00:00, 41.27it/s]


Training Loss: 1.6878
Validation Loss: 1.6127
Epoch 1


Training: 100%|██████████| 3125/3125 [01:12<00:00, 43.34it/s]






Training: 100%|██████████| 3125/3125 [01:12<00:00, 43.33it/s]


Training Loss: 1.6114
Validation Loss: 1.5922
Epoch 2


Training: 100%|██████████| 3125/3125 [01:11<00:00, 43.91it/s]






Training: 100%|██████████| 3125/3125 [01:11<00:00, 43.91it/s]


Training Loss: 1.5927
Validation Loss: 1.5827
Epoch 3


Training: 100%|██████████| 3125/3125 [01:09<00:00, 44.74it/s]


Training: 100%|██████████| 3125/3125 [01:09<00:00, 44.74it/s]



Training: 100%|██████████| 3125/3125 [01:09<00:00, 44.74it/s]


Training Loss: 1.5890
Validation Loss: 1.5798
Epoch 4


Training: 100%|██████████| 3125/3125 [01:09<00:00, 44.68it/s]






Training: 100%|██████████| 3125/3125 [01:09<00:00, 44.67it/s]


Training Loss: 1.5911
Validation Loss: 1.5771
Epoch 5


Training: 100%|██████████| 3125/3125 [01:09<00:00, 44.73it/s]






Training: 100%|██████████| 3125/3125 [01:09<00:00, 44.72it/s]


Training Loss: 1.5842
Validation Loss: 1.5711
Epoch 6


Training: 100%|██████████| 3125/3125 [01:10<00:00, 44.46it/s]




Training: 100%|██████████| 3125/3125 [01:10<00:00, 44.46it/s]
Training: 100%|██████████| 3125/3125 [01:10<00:00, 44.46it/s]



Training Loss: 1.5716
Validation Loss: 1.5641
Epoch 7


Training: 100%|██████████| 3125/3125 [01:09<00:00, 44.79it/s]






Training: 100%|██████████| 3125/3125 [01:09<00:00, 44.78it/s]


Training Loss: 1.5614
Validation Loss: 1.5588
Epoch 8


Training: 100%|██████████| 3125/3125 [01:10<00:00, 44.58it/s]





Training: 100%|██████████| 3125/3125 [01:10<00:00, 44.58it/s]
Training: 100%|██████████| 3125/3125 [01:10<00:00, 44.58it/s]


Training Loss: 1.5596
Validation Loss: 1.5578
Epoch 9


Training: 100%|██████████| 3125/3125 [01:11<00:00, 43.77it/s]

Training: 100%|██████████| 3125/3125 [01:11<00:00, 43.77it/s]



Training: 100%|██████████| 3125/3125 [01:11<00:00, 31.89it/s]
Training: 100%|██████████| 3125/3125 [01:11<00:00, 43.76it/s]


Training Loss: 1.5648
Validation Loss: 1.5607
Epoch 10


Training: 100%|██████████| 3125/3125 [01:10<00:00, 44.19it/s]





Training: 100%|██████████| 3125/3125 [01:10<00:00, 44.18it/s]
Training: 100%|██████████| 3125/3125 [01:10<00:00, 44.18it/s]


Training Loss: 1.5648
Validation Loss: 1.5601
Epoch 11


Training: 100%|██████████| 3125/3125 [01:10<00:00, 44.37it/s]






Training: 100%|██████████| 3125/3125 [01:10<00:00, 44.37it/s]


Training Loss: 1.5576
Validation Loss: 1.5571
Epoch 12


Training: 100%|██████████| 3125/3125 [01:09<00:00, 44.85it/s]






Training: 100%|██████████| 3125/3125 [01:09<00:00, 44.84it/s]


Training Loss: 1.5499
Validation Loss: 1.5529
Epoch 13


Training: 100%|██████████| 3125/3125 [01:09<00:00, 44.88it/s]
Training: 100%|██████████| 3125/3125 [01:09<00:00, 44.88it/s]
Training: 100%|██████████| 3125/3125 [01:09<00:00, 44.88it/s]




Training: 100%|██████████| 3125/3125 [01:09<00:00, 44.88it/s]


Training Loss: 1.5484
Validation Loss: 1.5524
Epoch 14


Training: 100%|██████████| 3125/3125 [01:10<00:00, 44.58it/s]






Training: 100%|██████████| 3125/3125 [01:10<00:00, 44.58it/s]


Training Loss: 1.5540
Validation Loss: 1.5553
Epoch 15


Training: 100%|██████████| 3125/3125 [01:09<00:00, 44.81it/s]




Training: 100%|██████████| 3125/3125 [01:09<00:00, 44.80it/s]

Training: 100%|██████████| 3125/3125 [01:09<00:00, 44.80it/s]


Training Loss: 1.5557
Validation Loss: 1.5559
Epoch 16


Training: 100%|██████████| 3125/3125 [01:09<00:00, 44.81it/s]






Training: 100%|██████████| 3125/3125 [01:09<00:00, 44.80it/s]


Training Loss: 1.5499
Validation Loss: 1.5537
Epoch 17


Training: 100%|██████████| 3125/3125 [01:11<00:00, 43.98it/s]





Training: 100%|██████████| 3125/3125 [01:11<00:00, 43.98it/s]



Training Loss: 1.5427
Validation Loss: 1.5513
Epoch 18


Training: 100%|██████████| 3125/3125 [01:10<00:00, 44.58it/s]






Training: 100%|██████████| 3125/3125 [01:10<00:00, 44.58it/s]


Training Loss: 1.5412
Validation Loss: 1.5518
Epoch 19

Training:   0%|          | 0/3125 [00:00<?, ?it/s]




Training: 100%|██████████| 3125/3125 [01:10<00:00, 44.12it/s]






Training: 100%|██████████| 3125/3125 [01:10<00:00, 44.11it/s]


Training Loss: 1.5469
Validation Loss: 1.5536
