# Initialization

In [4]:
%load_ext autoreload
%autoreload 2

from jlib.transformer_char_predictor import TransformerCharPredictor
import jlib.data_utils as data_utils
import torch
import torch.nn as nn
import numpy as np
from torchprofile import profile_macs
# Read the whole training corpus into memory as a single string.
# `text` is intentionally not pre-assigned to "": if the file cannot be read,
# this cell fails loudly instead of leaving an empty corpus behind for the
# cells that follow.
# NOTE(review): assumes the corpus file is UTF-8 encoded — confirm.
with open('data/sequence.txt', 'r', encoding='utf-8') as f:
    text = f.read()


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# SeqLen 10

In [None]:
# NOTE(review): this global is never read by train_and_plot (the function's
# own `seqlen` parameter shadows it); kept in case other cells rely on it.
seqlen = 10


def train_and_plot(
    seqlen: int,
    *,
    epochs: int = 100,
    batch_size: int = 32,
    learning_rate: float = 3e-4,
) -> None:
    """Train a TransformerCharPredictor on the corpus and plot the training run.

    Builds the train/validation datasets from the module-level ``text`` via
    ``data_utils.gen_datasets``, wraps them with ``data_utils.gen_data_loader``,
    trains a model whose ``max_len`` equals ``seqlen`` and finally calls
    ``model.plot_training``.

    Args:
        seqlen: Sequence length passed to ``gen_datasets`` and used as the
            model's ``max_len``. Must be at least 1.
        epochs: Forwarded to ``train_model`` as ``epochs``.
        batch_size: Batch size of the training loader. The validation loader
            always uses a single batch containing the whole validation set.
        learning_rate: ``lr`` handed to the Adam optimizer.

    Raises:
        ValueError: If ``seqlen`` is smaller than 1.
    """
    # Fail fast, before any dataset, loader or model is built.
    if seqlen < 1:
        raise ValueError(f"seqlen must be a positive integer, got {seqlen!r}")

    import os

    data = data_utils.gen_datasets(text, seqlen)
    train_data = data['train_dataset']
    val_data = data['val_dataset']
    alphabet: data_utils.Alphabet = data['alphabet']

    train_fetcher = data_utils.gen_data_loader(
        train_data,
        batch_size=batch_size,
        workers=6,
        cpu_prefetch=20,
        gpu_prefetch=20,
    )

    # One batch spanning the entire validation set.
    val_fetcher = data_utils.gen_data_loader(
        val_data,
        batch_size=len(val_data),
        workers=6,
        cpu_prefetch=10,
        gpu_prefetch=10,
    )

    model = TransformerCharPredictor(
        alphabet_size=len(alphabet),
        max_len=seqlen,
        hidden_dim=1024,
        inner_dim=4096,
        num_attn_heads=8,
        num_attn_layers=6,
        cls_head_dims=[1024, 512],
        dropout=0.1,
    )

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameter count: {param_count:,}")

    # Reference figures noted in an earlier revision of this cell:
    # 77,219,372 parameters and 49,602,887,704 MACs. The MAC count came from a
    # now-disabled torchprofile.profile_macs call on the first element of one
    # training batch — presumably for a single sequence length; TODO confirm which.

    save_path = f'models/p1-{seqlen}.pth'
    # Presumably train_model writes a checkpoint to save_path; make sure the
    # target directory exists so a long run does not fail at save time.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    model.train_model(
        epochs=epochs,
        train_fetcher=train_fetcher,
        val_fetcher=val_fetcher,
        optimizer=torch.optim.Adam,
        optimizer_kwargs={
            'lr': learning_rate,
            'betas': (0.9, 0.98),
            'eps': 1e-9,
            'weight_decay': 1e-5
        },
        min_accuracy=1,
        max_negative_diff_count=10,
        save_path=save_path,
        stop_on_plateau=True
    )

    model.plot_training(f'Small Corpus, Sequence Length {seqlen}')





In [6]:
# NOTE(review): the section heading says "SeqLen 10" and the global `seqlen`
# is 10, but this run trains with a sequence length of 20 — confirm which is intended.
train_and_plot(20)


Begin init data loader
Batch Size: 0.0048828125 MiB
Data Loader init time: 0.122642 s
Begin init fetcher
Fetcher init time: 0.216636 s
Begin init data loader
Batch Size: 0.07232666015625 MiB
Data Loader init time: 0.125589 s
Begin init fetcher
Fetcher init time: 0.213814 s
Model parameter count: 77,219,372
Training TransformerCharPredictor

----------------------------------------------------------------------------------------------------------------------------------------------------
|       Epoch        |   Epoch Time (s)   |   Training Loss    |  Validation Loss   |Validation Accuracy |   Δ Accuracy (%)   |    Memory Usage    |
----------------------------------------------------------------------------------------------------------------------------------------------------
|         0          |      3.212504      |      2.928128      |      2.579155      |     27.911392      |      0.000000      |      5.393523      |
-------------------------------------------------------------

KeyboardInterrupt: 