# Import Required Modules

In [None]:
import torch
import os
from model.LLMModel import LLMModel
from model.LLMLayer import LLMLayer
from utils.tokenizer import get_tokenizer
from utils.training_data_generator import TrainingDataGen
from utils.trainer import Trainer
from utils.read_config import get_config

In [None]:
# Load the project configuration. Later cells read the keys
# 'checkpoint_path', 'data_path' and 'data_genrerator_config_name' from it.
config = get_config()

# Create Tokenizer object

In [None]:
# Build the tokenizer. Its vocabulary size (get_vocab_size()) is used as the
# model's number_of_tokens, and the tokenizer itself is passed to the Trainer.
tokenizer = get_tokenizer()

# Initialize Model

In [None]:

# Model / training hyper-parameters.
embedding_dimension = 512
max_sequence_length = 2048
number_of_tokens = tokenizer.get_vocab_size()
batch_size = 3

# The checkpoint directory doubles as the location where the data-generator
# config file is looked up, hence the second name for the same path.
checkpoint_dir = config['checkpoint_path']
data_gen_config_path = checkpoint_dir

# Look for an existing model checkpoint to hand to the Trainer.
# - An unset or missing directory means "no checkpoint" instead of a crash.
# - Entries are sorted because os.listdir() returns them in arbitrary order;
#   when several names match, the lexicographically last one is selected.
# - os.path.join keeps the path valid whether or not the configured
#   directory ends with a path separator.
checkpoint_files = (
    sorted(os.listdir(checkpoint_dir))
    if checkpoint_dir and os.path.isdir(checkpoint_dir)
    else []
)
model_checkpoint_path = None
for file_name in checkpoint_files:
    if ".pt" in file_name:
        model_checkpoint_path = os.path.join(checkpoint_dir, file_name)

# Create the model: an LLMLayer configured with the hyper-parameters above,
# wrapped in LLMModel.
model = LLMModel(LLMLayer(
    embedding_dimension=embedding_dimension,
    number_of_tokens=number_of_tokens,
    number_of_heads=4,
    number_of_layers=64,
    decoder_stacks=4,
    dropout_rate=0.1,
    max_sequence_length=max_sequence_length,
    feed_forward_dimension=4 * embedding_dimension,
    training=True,
))

# Plain SGD over all model parameters with a learning rate of 0.2.
optimizer = torch.optim.SGD(model.parameters(), lr=0.2)


# Initialize the Data Generator

In [None]:
# Use the saved generator config only when a config directory is set AND the
# config file is present in it; in every other case build the generator from
# the raw data path. The `and` short-circuits, so the config-name key is only
# read when a directory is actually set.
saved_generator_config_exists = bool(data_gen_config_path) and os.path.exists(
    os.path.join(data_gen_config_path, config['data_genrerator_config_name'])
)

if saved_generator_config_exists:
    training_data_generator = TrainingDataGen(config_path=data_gen_config_path)
else:
    training_data_generator = TrainingDataGen(base_path=config['data_path'])

# Count Trainable Model Parameters (in Billions)

In [None]:
# Add up the element counts of every parameter that requires gradients,
# then print the total expressed in billions.
trainable_params = 0
for parameter in model.parameters():
    if parameter.requires_grad:
        trainable_params += parameter.numel()

print(int(trainable_params) / 10 ** 9)

# Start Training

In [None]:
# Wire the model, data generator, tokenizer and optimizer into a Trainer,
# together with the sequence length, the discovered checkpoint path (None
# when no ".pt" file was found) and the batch size, then train for 100 epochs.
trainer = Trainer(
    model,
    training_data_generator,
    tokenizer,
    optimizer,
    max_sequence_length,
    model_checkpoint_path,
    batch_size,
)
trainer.train(epochs=100, batch_size=batch_size)