In [2]:
import os
import sys
import pickle
import shutil
import torch

sys.path.append('../')

from accelerate import Accelerator, notebook_launcher
from dataset import EpisodeDataset, collate_fn
from model import Config, GPTModel
from trainer import train_model, validate_model
from torch.utils.data import DataLoader

In [3]:
token_to_idx = {(i, j): i * 9 + j + 1 for i in range(9) for j in range(9)} | {"up": 82, "down": 83, "left": 84, "right": 85}

In [4]:
token_to_idx['<pad>'] = 0  # Padding token

vocab_size = 86
block_size = 200
embed_size = 512
num_heads = 8
num_layers = 8
dropout = 0.1

In [5]:
path = ''

In [6]:
with open(os.path.join(path, 'train00.pkl'), 'rb') as f:
    agent00 = pickle.load(f)
with open(os.path.join(path, 'train08.pkl'), 'rb') as f:
    agent08 = pickle.load(f)
with open(os.path.join(path, 'train80.pkl'), 'rb') as f:
    agent80 = pickle.load(f)
with open(os.path.join(path, 'train88.pkl'), 'rb') as f:
    agent88 = pickle.load(f)

In [7]:
train_ratio = 0.8
valid_ratio = 0.1

d00 = len(agent00)
d08 = len(agent08)
d80 = len(agent80)
d88 = len(agent88)

train00 = agent00[:int(train_ratio * d00)]
valid00 = agent00[int(train_ratio * d00):int((train_ratio + valid_ratio) * d00) ]
test00 = agent00[int((train_ratio + valid_ratio) * d00): ]

train08 = agent08[:int(train_ratio * d08)]
valid08 = agent08[int(train_ratio * d08):int((train_ratio + valid_ratio) * d08) ]
test08 = agent08[int((train_ratio + valid_ratio) * d08): ]

train80 = agent80[:int(train_ratio * d80)]
valid80 = agent80[int(train_ratio * d80):int((train_ratio + valid_ratio) * d80) ]
test80 = agent80[int((train_ratio + valid_ratio) * d80): ]

train88 = agent88[:int(train_ratio * d88)]
valid88 = agent88[int(train_ratio * d88):int((train_ratio + valid_ratio) * d88) ]
test88 = agent88[int((train_ratio + valid_ratio) * d88): ]

In [8]:
train = train00 + train08 + train80 + train88
valid = valid00 + valid08 + valid80 + valid88
test = test00 + test08 + test80 + test88

print(len(train))
print(len(valid))
print(len(test))

1588602
198575
198577


In [9]:
train_dataset = EpisodeDataset(train, token_to_idx)
valid_dataset = EpisodeDataset(valid, token_to_idx)
test_dataset = EpisodeDataset(test, token_to_idx)

In [10]:
def train_main(save_directory = None, epochs = 15):
    
    accelerator = Accelerator()

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
    valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

    config = Config(vocab_size, block_size, n_layer=num_layers, n_head=num_layers, n_embd=embed_size)
    model = GPTModel(config)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = epochs)
    
    train_loader, valid_loader, model, scheduler, optimizer = accelerator.prepare(train_loader, valid_loader, model, scheduler, optimizer)

    epoch = 0

    model_path = None
    min_loss = 1e10
    
    for epoch in range(epochs):
        accelerator.print(f'Epoch {epoch}')

        train_model(model, train_loader, optimizer, accelerator)
        valid_loss = validate_model(model, valid_loader, accelerator)
        scheduler.step()

        if accelerator.is_main_process:
            print(f'Validation Loss: {valid_loss:.8f}')

            model_save_path = f"Model_{epoch+1}.pth"
            accelerator.save(accelerator.unwrap_model(model).state_dict(), model_save_path)

            if valid_loss < min_loss:
                min_loss = valid_loss
                model_path = model_save_path

        accelerator.wait_for_everyone()

    if accelerator.is_main_process:
        shutil.copy(model_path, save_directory)


In [12]:
notebook_launcher(train_main, (os.path.join(path, 'best_model'), 15), num_processes = 4)

Launching training on 4 GPUs.
Epoch 0


Training: 100%|██████████| 12411/12411 [21:04<00:00,  9.81it/s]





Training Loss: 0.5037913918495178
Validation Loss: 0.47267327
Epoch 1

Training:   0%|          | 0/12411 [00:00<?, ?it/s]




Training: 100%|██████████| 12411/12411 [21:05<00:00,  9.80it/s]


Training: 100%|██████████| 12411/12411 [21:05<00:00,  9.80it/s]


Training Loss: 0.46917539834976196
Validation Loss: 0.46782213
Epoch 2

Training:   0%|          | 0/12411 [00:00<?, ?it/s]




Training: 100%|██████████| 12411/12411 [21:06<00:00,  9.80it/s]





Training Loss: 0.46598953008651733
Validation Loss: 0.46719846
Epoch 3

Training:   0%|          | 0/12411 [00:00<?, ?it/s]




Training: 100%|██████████| 12411/12411 [21:02<00:00,  9.83it/s]





Training Loss: 0.4646003246307373
Validation Loss: 0.46600494
Epoch 4

Training:   0%|          | 0/12411 [00:00<?, ?it/s]




Training: 100%|██████████| 12411/12411 [21:01<00:00,  9.84it/s]

Training: 100%|██████████| 12411/12411 [21:01<00:00,  9.84it/s]



Training Loss: 0.46415749192237854
Validation Loss: 0.46596977
Epoch 5

Training:   0%|          | 0/12411 [00:00<?, ?it/s]




Training: 100%|██████████| 12411/12411 [21:04<00:00,  9.82it/s]
Training: 100%|██████████| 12411/12411 [21:04<00:00,  9.82it/s]




Training Loss: 0.46470674872398376
Validation Loss: 0.46652344
Epoch 6

Training:   0%|          | 0/12411 [00:00<?, ?it/s]




Training: 100%|██████████| 12411/12411 [21:02<00:00,  9.83it/s]





Training Loss: 0.465568482875824
Validation Loss: 0.46668565
Epoch 7

Training:   0%|          | 0/12411 [00:00<?, ?it/s]




Training: 100%|██████████| 12411/12411 [21:03<00:00,  9.83it/s]





Training Loss: 0.4658709466457367
Validation Loss: 0.46701115
Epoch 8

Training:   0%|          | 0/12411 [00:00<?, ?it/s]




Training: 100%|██████████| 12411/12411 [21:02<00:00,  9.83it/s]





Training Loss: 0.46543562412261963
Validation Loss: 0.46701172
Epoch 9

Training:   0%|          | 0/12411 [00:00<?, ?it/s]




Training: 100%|██████████| 12411/12411 [21:03<00:00,  9.82it/s]

Training: 100%|██████████| 12411/12411 [21:03<00:00,  9.82it/s]



Training Loss: 0.4645695388317108
Validation Loss: 0.46618474
Epoch 10

Training:   0%|          | 0/12411 [00:00<?, ?it/s]




Training: 100%|██████████| 12411/12411 [21:03<00:00,  9.83it/s]





Training Loss: 0.46356722712516785
Validation Loss: 0.46592587
Epoch 11

Training:   0%|          | 0/12411 [00:00<?, ?it/s]




Training: 100%|██████████| 12411/12411 [21:01<00:00,  8.37it/s]

Training: 100%|██████████| 12411/12411 [21:01<00:00,  9.84it/s]



Training Loss: 0.4629788100719452
Validation Loss: 0.46569046
Epoch 12

Training:   0%|          | 0/12411 [00:00<?, ?it/s]




Training: 100%|██████████| 12411/12411 [21:03<00:00,  9.82it/s]

Training: 100%|██████████| 12411/12411 [21:03<00:00,  9.82it/s]



Training Loss: 0.4629948139190674
Validation Loss: 0.46567193
Epoch 13

Training:   0%|          | 0/12411 [00:00<?, ?it/s]




Training: 100%|██████████| 12411/12411 [21:02<00:00,  9.83it/s]





Training Loss: 0.4636186361312866
Validation Loss: 0.46616063
Epoch 14

Training:   0%|          | 0/12411 [00:00<?, ?it/s]




Training: 100%|██████████| 12411/12411 [21:02<00:00,  9.83it/s]





Training Loss: 0.46427425742149353
Validation Loss: 0.46617603
