In [1]:
from dataclasses import dataclass
from Model import PashkoModel
import os
import torch
import time

In [2]:
@dataclass
class PashkoModelConfig:
    sequence_length: int = 1024
    vocab_size: int = 50304
    embed_dim: int = 768

    encoder = 'gpt2'

    num_heads: int = 12
    num_blocks: int = 12

    dropout: float = 0.0

    ffnn_bias: bool = False
    qkv_bias: bool = False
    layernorm_bias: bool = False

    topK: int = 10
    temperature: float = 1.0

@dataclass
class PashkoTrainConfig:
    batch_size: int = 64

    learning_rate: float = 0.001
    betas: tuple = (0.9, 0.999)
    weight_decay: float = 0.01
        
    max_iterations: int = 600000

In [3]:
modelConfig = PashkoModelConfig()
trainConfig = PashkoTrainConfig()
iter_num = 0
ckpt_num = 0
best_val_loss = 1e9

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [5]:
print("Type of initialisation...", end=' ', flush=True)
init_type = input()
print(init_type)

if init_type == "scratch":
    Pashko = PashkoModel(modelConfig)
    print(f"Initialised new model with {Pashko.num_params()[0]} parameters")

    optimiser = torch.optim.AdamW(Pashko.parameters(),
                              lr=trainConfig.learning_rate,
                              betas=trainConfig.betas,
                              weight_decay=trainConfig.weight_decay)
    
    print("Initialised optimiser using trainConfig")

    checkpoint = {
                    'model': Pashko.state_dict(),
                    'optimiser': optimiser.state_dict(),
                    'iter_num': iter_num,
                    'ckpt_num': ckpt_num,
                    'best_val_loss': best_val_loss,
                    'model_config': modelConfig,
                    'train_config': trainConfig,
                }
    
    torch.save(checkpoint, os.path.join('checkpoints', f'ckpt{ckpt_num}.pashko'))

    print(f"Saved initial checkpoint as ckpt{ckpt_num}.pt")

    print("Initialisation complete")

elif init_type == "resume":
    print("Checkpoint name...", end=' ', flush=True)
    ckpt_name = input()
    print(ckpt_name)

    ckpt_path = os.path.join('checkpoints', f'{ckpt_name}.pashko')
    ckpt = torch.load(ckpt_path, map_location=device)

    modelConfig = ckpt['model_config']
    trainConfig = ckpt['train_config']

    Pashko = PashkoModel(modelConfig)
    Pashko.load_state_dict(ckpt['model'])

    print(f"Loaded model from checkpoint {ckpt_name}")

    optimiser = torch.optim.AdamW(Pashko.parameters(),
                              lr=trainConfig.learning_rate,
                              betas=trainConfig.betas,
                              weight_decay=trainConfig.weight_decay)
    
    optimiser.load_state_dict(ckpt['optimiser'])

    print(f"Loaded optimiser from checkpoint {ckpt_name}")

    iter_num = ckpt['iter_num']
    ckpt_num = ckpt['ckpt_num']
    best_val_loss = ckpt['best_val_loss']

    print("Initialisation complete")

Type of initialisation... scratch
Initialised new model with 123.60M parameters
Initialised optimiser using trainConfig
Saved initial checkpoint as ckpt0.pt
Initialisation complete


In [None]:
checkpoint = {
                    'model': Pashko.state_dict(),
                    'optimiser': optimiser.state_dict(),
                    'iter_num': iter_num,
                    'ckpt_num': ckpt_num,
                    'best_val_loss': best_val_loss,
                    'model_config': modelConfig,
                    'train_config': trainConfig,
                }

In [None]:
t0 = time.time()
x, targets = get_batch('train')
while True:
    