This notebook is the entry point for the project.
It is written as an .ipynb so that training can run faster on a Google Colab or Kaggle machine.

## Dataset loading
In the following code block we instantiate the dataset and use some of its methods to analyze and preprocess it.

### Comments
- The analysis output is currently too verbose; consider adding a verbose parameter to the dataset_analysis method.
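
One way the verbose parameter suggested above could look is sketched here. The real `dataset_analysis` is a method of `CharDataset` and is not shown in this notebook, so the free function, its `verbose` flag, and the logger name are all hypothetical stand-ins rather than the project's code.

```python
import logging
from collections import Counter

# Hypothetical logger name; the project logs under "CharDataset".
logger = logging.getLogger("CharDatasetSketch")


def dataset_analysis(text, verbose=False):
    """Log summary statistics; log per-character counts only if verbose."""
    counts = Counter(text)
    logger.info("Total Characters: %d", len(text))
    logger.info("Unique Characters: %d", len(counts))
    if verbose:
        # The long frequency table is emitted only on request.
        for char, freq in counts.most_common():
            logger.info("%r: %d", char, freq)
    return counts


counts = dataset_analysis("hello world", verbose=False)
```

With `verbose=False` only the two summary lines are logged, which would keep the cell output short while the full table stays one argument away.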

In [9]:
# Loading the dataset
import torch
from dataset import CharDataset
from torch.utils.data import DataLoader, random_split

torch.manual_seed(0)

# Try the primary dataset path first, then fall back to the alternative location
try:
    with open('dataset/dataset.txt', 'r') as file:
        text = file.read()
except FileNotFoundError:
    with open('data/dataset.txt', 'r') as file:
        text = file.read()

block_size = 128  # context length of the model; the dataset must be built with the same value

# Create dataset
dataset = CharDataset(text, block_size)

# Analyze original dataset
dataset.dataset_analysis()

# Remove less frequent characters and analyze the dataset
dataset.remove_less_frequent_chars(10).dataset_analysis()

dataset.preprocess()

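# Use train_dataset if a split exists; otherwise train on the whole dataset.
# The loader is always named `loader` because the later cells refer to it by that name.
train_source = train_dataset if 'train_dataset' in globals() else dataset
loader = DataLoader(train_source, batch_size=2048, shuffle=True)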

# TODO: make the split below work; for now the whole dataset is used for training
# # Split the dataset into training and validation sets
# percentage = 0.9
# split_point = int(len(text) * percentage)

# # Slice the text for training and validation
# train_text = text[:split_point]
# val_text = text[split_point:]

# # Naive contiguous split for now
# train_dataset = CharDataset(train_text, block_size)
# val_dataset = CharDataset(val_text, block_size)

# # Instantiate DataLoader objects
# train_loader = DataLoader(train_dataset, batch_size=2048, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=2048, shuffle=False)

INFO:CharDataset:### Dataset Analysis ###
INFO:CharDataset:### Comparison with Original Dataset ###
INFO:CharDataset:Original Total Characters: 1115394
INFO:CharDataset:Current Total Characters: 1115394
INFO:CharDataset:Original Unique Characters: 65
INFO:CharDataset:Current Unique Characters: 65
INFO:CharDataset:### Detailed Current Dataset Analysis ###
INFO:CharDataset:### Characters and Frequencies ###
INFO:CharDataset:'␣': 169892
INFO:CharDataset:'e': 94611
INFO:CharDataset:'t': 67009
INFO:CharDataset:'o': 65798
INFO:CharDataset:'a': 55507
INFO:CharDataset:'h': 51310
INFO:CharDataset:'s': 49696
INFO:CharDataset:'r': 48889
INFO:CharDataset:'n': 48529
INFO:CharDataset:'i': 45537
INFO:CharDataset:'\n': 40000
INFO:CharDataset:'l': 33339
INFO:CharDataset:'d': 31358
INFO:CharDataset:'u': 26584
INFO:CharDataset:'m': 22243
INFO:CharDataset:'y': 20448
INFO:CharDataset:',': 19846
INFO:CharDataset:'w': 17585
INFO:CharDataset:'f': 15770
INFO:CharDataset:'c': 15623
INFO:CharDataset:'g': 13356
I
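
The commented-out split in the cell above slices the raw text before building the datasets. A minimal sketch of that idea follows, assuming the split point is taken from the length of the text; `split_text` and `train_fraction` are hypothetical names, and the short sample string stands in for the real corpus.

```python
def split_text(text, train_fraction=0.9):
    """Split text into contiguous training and validation parts (no shuffling)."""
    if not 0.0 < train_fraction < 1.0:
        raise ValueError("train_fraction must be strictly between 0 and 1")
    split_point = int(len(text) * train_fraction)
    return text[:split_point], text[split_point:]


sample = "abcdefghij" * 10  # 100 characters standing in for the corpus
train_text, val_text = split_text(sample, train_fraction=0.9)
print(len(train_text), len(val_text))
```

A contiguous split is probably safer here than `random_split` over windows: neighbouring windows of a character dataset overlap almost entirely, so a random split would place near-duplicates in both sets. If `CharDataset` builds its vocabulary from the text it receives, it would also be worth checking that the validation dataset reuses the training vocabulary.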

In [19]:
max_index = len(loader)  # number of batches in the loader

# Try the dataloader: print the first input/target pair of one batch
x, y = next(iter(loader))
print(x[0])
print(y[0])


tensor([[32, 30, 27,  ..., 18,  1, 37],
        [27, 34, 17,  ..., 30,  1, 26],
        [23, 17,  1,  ..., 17,  1, 32],
        ...,
        [14, 37,  1,  ..., 37,  6,  0],
        [30, 13, 24,  ...,  0, 23, 21],
        [ 1, 27, 18,  ..., 27, 18,  1]])
tensor([[30, 27, 33,  ...,  1, 37, 27],
        [34, 17, 30,  ...,  1, 26, 27],
        [17,  1, 34,  ...,  1, 32, 27],
        ...,
        [37,  1, 20,  ...,  6,  0, 32],
        [13, 24,  1,  ..., 23, 21, 26],
        [27, 18,  1,  ..., 18,  1, 25]])
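
In the tensors printed above, each target row appears to be the input row shifted one position to the left (for example 32, 30, 27 in the input against 30, 27 in the target), which is the usual next-character setup. The sketch below shows how such pairs can be built from a string; `make_pairs` is a hypothetical helper and is not taken from `CharDataset`, whose implementation is not shown here.

```python
def make_pairs(text, block_size):
    """Build (input, target) index lists where target is input shifted by one."""
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer id
    ids = [stoi[ch] for ch in text]
    pairs = []
    # Each window needs block_size inputs plus one extra character for the last target.
    for start in range(len(ids) - block_size):
        x = ids[start:start + block_size]
        y = ids[start + 1:start + block_size + 1]
        pairs.append((x, y))
    return pairs, stoi


pairs, stoi = make_pairs("hello", block_size=3)
for x, y in pairs:
    print(x, y)
```

Under this scheme a text of N characters yields N - block_size windows, and position t of the target is the character the model should predict after seeing positions 0..t of the input.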


## Instantiate the model
Here we set the minimal set of hyperparameters the model needs and call its summary method.

### Comments
- The model summary does not work yet (TODO)

In [10]:
from model import CharTransformer

# usually should be train_dataset, but for now just use the whole dataset for training
if 'train_dataset' in globals():
    vocabulary_size = train_dataset.vocabulary_size
else:
    vocabulary_size = dataset.vocabulary_size

# block_size (the sequence length) was already defined in the dataset cell
embedding_dim = 128  # Embedding dimensions
num_heads = 8  # Number of attention heads
num_layers = 2  # Number of transformer blocks

# Initialize the model
model = CharTransformer(vocabulary_size, block_size, embedding_dim, num_heads, num_layers, ff_hid_dim=256)

# Display the model summary
model.summary()

INFO:CharTransformer:Model initialized
INFO:CharTransformer:### Model summary###
Layer (type:depth-idx)                        Output Shape              Param #
CharTransformerSummaryWrapper                 [32, 128, 39]             --
├─CharTransformer: 1-1                        [32, 128, 39]             16,384
│    └─Embedding: 2-1                         [32, 128, 128]            4,992
│    └─Dropout: 2-2                           [32, 128, 128]            --
│    └─ModuleList: 2-3                        --                        --
│    │    └─TransformerBlock: 3-1             [32, 128, 128]            132,480
│    │    └─TransformerBlock: 3-2             [32, 128, 128]            132,480
│    └─LayerNorm: 2-4                         [32, 128, 128]            256
│    └─Linear: 2-5                            [32, 128, 39]             5,031
Total params: 291,623
Trainable params: 291,623
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 8.81
Input size (MB): 0.03
Forward/b
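
The parameter counts in the summary above can be reproduced by hand, which is a useful check on the hyperparameters. The breakdown of a `TransformerBlock` used below (four biased `d x d` attention projections, a two-layer feed-forward network, two LayerNorms) and the reading of the 16,384 parameters as a learned positional table are assumptions chosen because they match the reported numbers; the model code itself is not shown.

```python
vocab_size = 39   # last dimension of the output shape [32, 128, 39]
block_size = 128
d = 128           # embedding_dim
ff = 256          # ff_hid_dim
num_heads = 8
num_layers = 2

# Multi-head attention needs the embedding to split evenly across heads.
assert d % num_heads == 0

token_embedding = vocab_size * d             # 4,992
positional = block_size * d                  # 16,384 (assumed positional parameters)
attention = 4 * (d * d + d)                  # Q, K, V and output projections (assumed, with bias)
feed_forward = (d * ff + ff) + (ff * d + d)  # two linear layers with bias
block_norms = 2 * (2 * d)                    # two LayerNorms, weight and bias each
block = attention + feed_forward + block_norms
final_norm = 2 * d                           # 256
head = d * vocab_size + vocab_size           # 5,031

total = token_embedding + positional + num_layers * block + final_norm + head
print(block, total)  # 132480 291623
```

Both figures agree with the summary (132,480 per block and 291,623 in total), so most of the capacity sits in the two transformer blocks rather than in the embeddings.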

## Training of the model

In [11]:
from train import Trainer
from torch import optim
from torch import nn

# Set optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

# Initialize the trainer
trainer = Trainer(model, optimizer, loss_fn, loader)

# Train the model
# trainer.train()
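
`nn.CrossEntropyLoss` combines a log-softmax with the negative log-likelihood, so a model that spreads probability evenly over the vocabulary should score about ln(V). The pure-Python sketch below computes that reference value for the 39-character vocabulary reported in the model summary; `cross_entropy` is a hypothetical helper written for illustration, not part of `Trainer`.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one position: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target]


vocab_size = 39
uniform_logits = [0.0] * vocab_size  # every character equally likely
baseline = cross_entropy(uniform_logits, target=0)
print(round(baseline, 2))  # 3.66
```

A training loss that starts near 3.66 and then falls would suggest the loop is wired correctly. Note that `nn.CrossEntropyLoss` expects the class dimension second, so logits of shape (batch, sequence, vocabulary) are usually flattened to (batch * sequence, vocabulary) first; whether `Trainer` does this is not visible here.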

## Generate text

In [18]:
from generate import TextGenerator

generator = TextGenerator(model, dataset)

generator.load_model("models/model_checkpoint_0_kaggle.pth")
# Generate text
generated_text = generator.generate("Oh god oh god", length=100, temperature=0.7)

print(generated_text)

Model loaded from models/model_checkpoint_0_kaggle.pth
Oh god oh gododododoooooooooooooooooooooooooooooooooooooooooooodooooooooooooooooooooooooooooooooooooooooooooooooo
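
The `temperature` argument used above is commonly implemented by dividing the logits by the temperature before the softmax. The sketch below illustrates that technique in pure Python; it is not the code of `TextGenerator`, which is not shown, and both function names are hypothetical.

```python
import math
import random


def softmax_with_temperature(logits, temperature=1.0):
    """Turn logits into probabilities after scaling them by 1 / temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [v / temperature for v in logits]
    m = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_index(logits, temperature=1.0, rng=random):
    """Sample one index from the temperature-scaled distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.0]
p_default = softmax_with_temperature(logits, 1.0)
p_sharp = softmax_with_temperature(logits, 0.7)
idx = sample_index(logits, 0.7, random.Random(0))
```

Temperatures below 1 sharpen the distribution toward the most likely character, so with a model that already favours one character a low temperature makes repeated output more likely; values above 1 flatten the distribution and add variety.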
