# Initialization

In [None]:
%load_ext autoreload
%autoreload 2

from jlib.transformer_char_predictor import TransformerCharPredictor
import jlib.data_utils as data_utils
import torch
import torch.nn as nn
import numpy as np
from torchprofile import profile_macs
# Load the Shakespeare corpus (redownload=False: presumably reuses a local copy
# instead of fetching it again) and show its first 50 characters as a sanity check.
text = data_utils.get_text(
    'data/shakespeare.txt',
    redownload=False,
)
print(text[:50])


# Training runs (sequence lengths 10, 20, 30)

In [None]:
seqlen = 10
def train_and_plot(seqlen: int):
    data = data_utils.gen_datasets(text, seqlen)
    train_data = data['train_dataset']
    val_data = data['val_dataset']
    alphabet: data_utils.Alphabet = data['alphabet']

    train_fetcher = data_utils.gen_data_loader(
        train_data,
        batch_size=len(train_data)//128,
        workers = 6,
        cpu_prefetch= 10,
        gpu_prefetch=10
    )

    val_fetcher = data_utils.gen_data_loader(
        val_data,
        batch_size=len(val_data)//32,
        workers = 6,
        cpu_prefetch= 10,
        gpu_prefetch=10
    )

    # model

    model = TransformerCharPredictor(
        alphabet_size = len(alphabet),
        max_len = seqlen,
        hidden_dim = 128,
        inner_dim = 2048,
        num_attn_heads = 2,
        num_attn_layers=3,
        cls_head_dims=[],
        dropout = 0.1
    )

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameter count: {param_count:,}")
    
    

    # test_input = next(iter(train_fetcher))[0]

    # print(f"Model MACs: {profile_macs(model, test_input):,}")
    
# Model parameter count: 1,790,380
# Model MACs: 568,279,052



    model.train_model(
        epochs=100,
        train_fetcher=train_fetcher,
        val_fetcher=val_fetcher,
        optimizer = torch.optim.Adam,
        optimizer_kwargs={
            'lr': 3e-3,
            'betas': (0.9, 0.98),
            'eps': 1e-9,
            'weight_decay': 1e-5
        },
        min_accuracy=1,
        max_negative_diff_count=10,
        save_path=f'models/p2-{seqlen}.pth',
        stop_on_plateau=True,
    )

    fig = model.plot_training(f'Shakespeare Corpus, Sequence Length {seqlen}')
    fig.savefig(f'latex/images/p2-{seqlen}.png')
    
    del train_fetcher, val_fetcher, train_data, val_data, data, model, alphabet
    





: 

In [None]:
train_and_plot(seqlen=10)



Begin init data loader
Batch Size: 1.063690185546875 MiB
Data Loader init time: 6.452997 s
Begin init fetcher
Fetcher init time: 6.623583 s
Begin init data loader
Batch Size: 2.12738037109375 MiB
Data Loader init time: 8.337456 s
Begin init fetcher
Fetcher init time: 8.595949 s
Model parameter count: 939
Training TransformerCharPredictor

----------------------------------------------------------------------------------------------------------------------------------------------------
Begin Training
|       Epoch        |   Epoch Time (s)   |   Training Loss    |  Validation Loss   |Validation Accuracy |   Δ Accuracy (%)   |    Memory Usage    |
----------------------------------------------------------------------------------------------------------------------------------------------------
|         0          |     11.226316      |      4.001647      |      3.759078      |     15.314308      |      0.000000      |      0.799158      |
------------------------------------------------

In [None]:
train_and_plot(seqlen=20)


In [None]:
train_and_plot(seqlen=30)