In [None]:
"""
TODO:
    - Implement a layer that pads the char embeddings to match the word model's output size, to reduce gradient computation and speed up training
      (word embeddings need to be larger than char embeddings because characters can be represented well with a small embedding)
            Possible solutions:
                1. nn.utils.rnn.pad_sequence can be used
                2. custom padding logic can be implemented using nn.functional.pad
"""

In [None]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
import torch.sparse
import os
import gc

In [None]:
# Path of the text dataset used for vocabulary building and training
DATASET_PATH = "wikisent2.txt"

# Number of samples per DataLoader batch
BATCH_SIZE = 128

# Output directories: model checkpoints and encoder vocabularies
MODEL_SAVE_PATH, VOCABULARY_SAVE_PATH = "pytorch_model_saves", "pytorch_vocab_saves"

In [None]:
# Report CUDA availability once and derive the compute device from it
cuda_available = torch.cuda.is_available()
print("is cuda available: ", cuda_available)
if cuda_available:
    # Release cached allocator blocks that earlier runs in this kernel may hold
    torch.cuda.empty_cache()
    print("cuda cache cleared")
device = torch.device("cuda:0") if cuda_available else torch.device("cpu")

In [None]:
from torch.utils.data import Dataset, DataLoader
from torch_utils_module import LazyCustomDataset

# Dataset read from DATASET_PATH; batches are later unpacked as a pair
# (text fed to the word encoder, text fed to the char encoder).
# NOTE(review): the exact meaning of random_data / max_seq_len is defined in torch_utils_module — confirm there.
dataset = LazyCustomDataset(DATASET_PATH, random_data=True, max_seq_len=25)
# Shuffled mini-batches of BATCH_SIZE samples
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
from torch_utils_module import CustomDatasetForTextVectorizer as CD
from torch_utils_module import TextVectorizer

In [None]:
# Create instances of the TextVectorization class
char_encoder = TextVectorizer(max_tokens=None, lower=True, strip_punctuation=False)
word_encoder = TextVectorizer(max_tokens=None, split=" ", lower=True, strip_punctuation=True)

#old_wiki2
char_vocabulary_save_path = os.path.join(VOCABULARY_SAVE_PATH, "reduced_chars_new_full.pth")
word_vocabulary_save_path = os.path.join(VOCABULARY_SAVE_PATH, "reduced_words_new_full.pth")


def _load_or_build_vocabulary(encoder, save_path, loaded_label, adapt_label, prune_value=250):
    """Load a saved vocabulary into ``encoder`` or build and save a new one.

    Args:
        encoder: TextVectorizer instance to populate.
        save_path: File the vocabulary is loaded from / saved to.
        loaded_label: Capitalised name used in the "loaded" message (e.g. "Char").
        adapt_label: Plural name used in the "adapting" message (e.g. "chars").
        prune_value: Argument forwarded to ``encoder.prune_vocab``.
    """
    if os.path.exists(save_path):
        encoder.load_vocabulary(save_path)
        print(f"{loaded_label} vocabulary loaded successfully")
        return
    # Build from the same corpus that is used for training (DATASET_PATH)
    dataset_for_encoder = CD(DATASET_PATH)
    print(f"Adapting vocabulary for {adapt_label}")
    encoder.adapt(dataset_for_encoder, on_labels=True)
    encoder.prune_vocab(prune_value)
    del dataset_for_encoder
    # Make sure the target directory exists before writing the vocabulary file
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    encoder.save_vocabulary(save_path)


_load_or_build_vocabulary(char_encoder, char_vocabulary_save_path, "Char", "chars")
_load_or_build_vocabulary(word_encoder, word_vocabulary_save_path, "Word", "words")

# Get vocabulary sizes
max_vocab_size_char = len(char_encoder.get_vocabulary()[0])
max_vocab_size_word = len(word_encoder.get_vocabulary()[0])

In [None]:
# Check if vocabulary sizes are correct:
# the largest id + 1 must equal the number of entries for each encoder.
word_ids = word_encoder.get_vocabulary()[0].values()
char_ids = char_encoder.get_vocabulary()[0].values()
(max(word_ids) + 1 == len(word_ids), max(char_ids) + 1 == len(char_ids))

In [None]:
# Display the (char, word) vocabulary sizes computed above
max_vocab_size_char, max_vocab_size_word

In [None]:
from torch_utils_module import Model

# Model hyperparameters (one assignment per line for easy tuning)
word_embed_dim = 256
char_embed_dim = 128
word_rnn_dim = 256
word_rnn_layers = 2
char_rnn_dim = 256
char_rnn_layers = 2
word_bidirectional = False
char_bidirectional = False
word_dense_dims = [128, 128]
char_dense_dims = []

# Word-level model: its output embedding size is set to the char embedding size
word_model_kwargs = dict(
    in_embedding_dim=word_embed_dim,
    pretrained_embedding_path=None,
    out_embedding_dim=char_embed_dim,
    rnn_dim=word_rnn_dim,
    num_rnn_layers=word_rnn_layers,
    rnn_dropout=0,
    bidirectional=word_bidirectional,
    dense_dims=word_dense_dims,
    vocab_size=max_vocab_size_word,
    mode="word",
)
model_word = Model(**word_model_kwargs)

# Char-level model: no separate output embedding size is requested
char_model_kwargs = dict(
    in_embedding_dim=char_embed_dim,
    out_embedding_dim=None,
    rnn_dim=char_rnn_dim,
    num_rnn_layers=char_rnn_layers,
    rnn_dropout=0,
    bidirectional=char_bidirectional,
    dense_dims=char_dense_dims,
    vocab_size=max_vocab_size_char,
    mode="char",
)
model_char = Model(**char_model_kwargs)

In [None]:
# Check the inputs
# Print the first batch and its word-level encoding, then stop after one batch.
for data in dataloader:
    # First element of the batch is the text given to the word encoder
    inputs = data[0]
    print("Processing batch")
    print("Batch data:", data)
    print("Encoded inputs:", word_encoder(inputs))
    break

In [None]:
# Smoke test: push a single batch through both models in "training" and
# "inference" states to surface shape errors before the real training run.
for i, batch in enumerate(dataloader):

    print("batch")
    print(batch[0])
    print(batch[1])

    # Encode the word inputs
    inputs_W = word_encoder(batch[0])
    # Pad the encoded sequences to a common length (default padding value)
    inputs_W = nn.utils.rnn.pad_sequence(inputs_W, batch_first=True)
    inputs_W = inputs_W.to(dtype=torch.int)

    # Encode the char inputs
    inputs_1 = char_encoder(batch[1])
    inputs_1 = nn.utils.rnn.pad_sequence(inputs_1, batch_first=True)
    inputs_1 = inputs_1.to(dtype=torch.int)

    # Run the models to check for errors
    print("Word Model:")
    print("inputs", "inputs_W", inputs_W.shape)
    # Batched call in training state
    outputs_= model_word(inputs_W, state="training")
    print('outputs_')
    print('Training:', 'outputs_', outputs_.shape)
    # Single (unbatched) sequence in inference state
    outputs__ = model_word(inputs_W[0], state="inference")
    print('Inference:', 'outputs__', outputs__.shape)

    print("Char Model:")
    print("inputs", "inputs_1", inputs_1.shape)
    # The char model is conditioned on the word model's output via word_embed_info
    outputs1 = model_char(inputs_1, word_embed_info=outputs_, state="training")
    print('Training:', "outputs1", outputs1.shape)
    outputs11 = model_char(inputs_1[0], word_embed_info=outputs__, state="inference")
    print('Inference:', "outputs11", outputs11.shape)
    
    # Run the models to get model inner shapes
    # (batched and unbatched inputs, in both states, with debug=True)
    model_word(inputs_W, state="training", debug=True)
    model_word(inputs_W[0], state="training", debug=True)
    model_word(inputs_W, state="inference", debug=True)
    model_word(inputs_W[0], state="inference", debug=True)
    
    model_char(inputs_1, state="training", debug=True)
    model_char(inputs_1[0], state="training", debug=True)
    model_char(inputs_1, state="inference", debug=True)
    model_char(inputs_1[0], state="inference", debug=True)
    # Only the first batch is needed for this check
    break

In [None]:
# Verify shapes
# Order: word model (training, inference), char model (training, inference)
outputs_.shape, outputs__.shape, outputs1.shape, outputs11.shape

In [None]:
# Free up memory: drop the smoke-test inputs and outputs in one statement
del inputs_W, inputs_1, outputs_, outputs__, outputs1, outputs11

In [None]:
# Train function
def _show_loss_plot(loss_plot):
    """Plot the running-average loss collected so far."""
    plt.plot(loss_plot)
    plt.show()


def _generate_sample(model_w, model_c, encoder_W, encoder_C, device, next_letters=100):
    """Generate ``next_letters`` characters starting from a random vocabulary word.

    The models are expected to already be in eval mode. Gradients are disabled
    because the sample is only used for monitoring.

    Returns:
        The seed word followed by the sampled characters.
    """
    word_vocab = encoder_W.get_vocabulary()[0]
    # Map id -> token explicitly instead of relying on dict insertion order
    id_to_char = {int(idx): token for token, idx in encoder_C.get_vocabulary()[0].items()}

    seed_text = list(word_vocab.keys())[torch.randint(0, len(word_vocab), (1,)).item()]  # Get seed text
    with torch.no_grad():
        for _ in range(next_letters):
            # Everything before the last word feeds the word model; the last
            # (partial) word, prefixed by a space, feeds the char model.
            pieces = seed_text.split(" ")
            seed_text_p1, seed_text_p2 = " ".join(pieces[:-1]), " " + pieces[-1]

            encoded_seed_text_p1 = encoder_W([seed_text_p1])[0].to(device, dtype=torch.int)
            encoded_seed_text_p2 = encoder_C([seed_text_p2])[0].to(device, dtype=torch.int)

            # NOTE(review): model_w is called without an explicit `state`; this
            # relies on the Model default — confirm it is the inference path.
            output_state = model_w(encoded_seed_text_p1)
            predict_x = model_c(encoded_seed_text_p2, word_embed_info=output_state, state="inference", dropout_allowance=None)

            # Sample from the predicted distribution (use torch.argmax for greedy decoding)
            classes_x = torch.distributions.Categorical(logits=predict_x).sample()

            # Unknown ids contribute nothing, matching the previous fallback of ""
            seed_text += id_to_char.get(int(classes_x), "")
    return seed_text


def train(model_w, model_c, dataloader, epoch, encoder_W, encoder_C, loss_function, optimizer_W, optimizer_C, device, word_save_path, char_save_path, start_train_at_path, start_train_at, checkpoint_every=1000, sample_length=100):
    """Train the word and char models for one epoch.

    Args:
        model_w: Word-level model; its output conditions the char model.
        model_c: Char-level model producing per-character logits.
        dataloader: Yields ``(word_inputs, char_targets)`` batches of raw text.
        epoch: Zero-based epoch number (used for the progress-bar label).
        encoder_W: Word encoder; must expose ``vocabulary["<pad>"]``.
        encoder_C: Char encoder; must expose ``vocabulary["<pad>"]``.
        loss_function: Per-element loss (e.g. ``CrossEntropyLoss(reduction="none")``),
            required so that padded positions can be masked out.
        optimizer_W: Optimizer for ``model_w``.
        optimizer_C: Optimizer for ``model_c``.
        device: Device the models and batches are moved to.
        word_save_path: Checkpoint path for ``model_w``.
        char_save_path: Checkpoint path for ``model_c``.
        start_train_at_path: File storing the batch index to resume from; the
            running loss is stored next to it with a ``_loss.pt`` suffix.
        start_train_at: Number of leading batches to skip (0 = start fresh).
        checkpoint_every: Generate a sample and save a checkpoint every this
            many batches (must be positive).
        sample_length: Number of characters generated for each sample.

    Returns:
        List with the running average loss after every processed batch.

    Raises:
        Exception: Any error raised inside a training step is re-raised after
            the loss plot has been shown.
    """
    try:
        model_w.to(device)
        model_c.to(device)
    except Exception as e:
        # Best effort as before, but no longer silent
        print(f"warning: could not move models to {device}: {e}")

    loss_path = start_train_at_path.replace(".pt", "_loss.pt")
    if start_train_at == 0:
        total_loss = 0
    elif os.path.exists(loss_path):
        total_loss = torch.load(loss_path)
    else:
        print(f"warning: {loss_path} not found, running loss restarts at 0")
        total_loss = 0
    loss_plot = []
    collected = 0

    # Make sure checkpoint directories exist before hours of training are spent
    for path in (word_save_path, char_save_path, start_train_at_path):
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    pad_id_W = encoder_W.vocabulary["<pad>"]
    pad_id_C = encoder_C.vocabulary["<pad>"]

    pbar = tqdm(dataloader, desc=f"Training Epoch {epoch+1}", total=len(dataloader), mininterval=5)
    i = 1
    for inputs_W, targets in pbar:
        # skip to 'start_train_at' parameter given by user to resume training
        if i <= start_train_at:
            i += 1
            continue
        try:
            # Encode the inputs; the char model input/target are the same text
            # shifted by one character (next-character prediction)
            inputs_W = encoder_W(inputs_W)
            inputs_C = encoder_C([text[:-1] for text in targets])
            targets = encoder_C([text[1:] for text in targets])

            # Flip the inputs to be right-to-left for word variable so that we can apply pre-padding
            inputs_W = [seq.flip(0) for seq in inputs_W]

            # Pad, then flip back so the padding ends up in front: [1, 2, 3] -> [0, 0, 1, 2, 3]
            inputs_W = nn.utils.rnn.pad_sequence(inputs_W, batch_first=True, padding_value=pad_id_W).flip(-1).to(device, dtype=torch.int)
            inputs_C = nn.utils.rnn.pad_sequence(inputs_C, batch_first=True, padding_value=pad_id_C).to(device, dtype=torch.int)
            targets = nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=pad_id_C).to(device, dtype=torch.long)

            optimizer_W.zero_grad(set_to_none=True)
            optimizer_C.zero_grad(set_to_none=True)

            word_info = model_w(inputs_W, state="training")
            outputs = model_c(inputs_C, word_embed_info=word_info, state="training")

            # Loss calculation: average only over non-padding target positions,
            # using the same pad id that was used for padding above
            output = outputs.view(-1, outputs.shape[-1])
            target = targets.view(-1)
            losses = torch.masked_select(loss_function(output, target), torch.ne(target, pad_id_C)).mean()

            # Backpropagation
            losses.backward()
            optimizer_W.step()
            optimizer_C.step()

            # calculate total_loss for later use
            total_loss += losses.item()
        except KeyboardInterrupt:
            _show_loss_plot(loss_plot)
            raise
        except Exception as e:
            _show_loss_plot(loss_plot)
            print("error in training:")
            print(e)
            # Re-raise the real error instead of masking it with another exception
            raise

        if i % checkpoint_every == 0:
            # Validation: generate a text sample to track what the model learned
            seed_text = ""
            model_w.eval()
            model_c.eval()
            try:
                seed_text = _generate_sample(model_w, model_c, encoder_W, encoder_C, device, next_letters=sample_length)
                # tracking model learning
                with open("output.txt", "a", encoding="utf-8") as f:
                    f.write(seed_text + "\n")
            except KeyboardInterrupt:
                _show_loss_plot(loss_plot)
                raise
            except Exception as e:
                _show_loss_plot(loss_plot)
                print("error in sample generation:")
                print(e)
            finally:
                # Always return to training mode, even if sampling failed
                model_w.train()
                model_c.train()

            # Save models and index into temp path (independent of sampling success)
            try:
                torch.save(model_w.state_dict(), word_save_path)
                torch.save(model_c.state_dict(), char_save_path)
                torch.save(i, start_train_at_path)
                torch.save(total_loss, loss_path)
                collected = gc.collect()
            except KeyboardInterrupt:
                _show_loss_plot(loss_plot)
                raise
            except Exception as e:
                print("error in checkpoint saving:")
                print(e)
            pbar.set_postfix({'total_loss': float(total_loss / i), 'sample': seed_text, 'collected': collected})

        loss_plot.append(total_loss / i)
        i += 1
    # Reset saved start_train_at index
    torch.save(0, start_train_at_path)
    # total_loss is a sum of per-batch means, so average over the batch count
    print(f"Loss: {total_loss / max(i - 1, 1)}")
    return loss_plot

In [None]:
# Unreduced (per-element) loss so that padded positions can be masked in train()
loss_function = nn.CrossEntropyLoss(reduction="none")

# One AdamW optimizer per model, both with the same learning rate
optimizer_word, optimizer_char = (
    torch.optim.AdamW(net.parameters(), lr=1e-3) for net in (model_word, model_char)
)

In [None]:
# Move both models to the selected device; the resulting tuple is shown as the cell output
model_word.to(device), model_char.to(device)

In [None]:
epochs = 30

# Temporary checkpoint names encode the hyperparameters so that incompatible
# configurations never share a file
temp_save_name_word = f"model_word_{word_embed_dim}_{max_vocab_size_word}_{word_rnn_dim}_{word_rnn_layers}_{word_bidirectional}_{word_dense_dims}_{char_embed_dim}_temp.pth"
temp_save_name_char = f"model_char_{char_embed_dim}_{max_vocab_size_char}_{char_rnn_dim}_{char_rnn_layers}_{char_bidirectional}_{char_dense_dims}_temp.pth"
# Kept in this exact form for compatibility with index files written by earlier runs
train_start_index_file = temp_save_name_char.split(".")[0]+temp_save_name_char.split(".")[1]+".pt"
train_start_index = 0

# torch.save does not create directories; create it up front so the first
# checkpoint cannot fail after a long stretch of training
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
temp_word_path = os.path.join(MODEL_SAVE_PATH, temp_save_name_word)
temp_char_path = os.path.join(MODEL_SAVE_PATH, temp_save_name_char)
train_start_index_path = os.path.join(MODEL_SAVE_PATH, train_start_index_file)
# train() stores the running loss next to the index file with this suffix
train_loss_path = train_start_index_path.replace(".pt", "_loss.pt")

# A usable checkpoint needs BOTH model files
if not (os.path.exists(temp_word_path) and os.path.exists(temp_char_path)):
    print("Model does not exist, creating new model")
    print("Starting training from scratch")
    with open("output.txt", "w") as f:
        f.write("")
else:
    if input("Model already exists, overwrite? (y/n) ") == "y":
        print("Model already exists, overwriting...")
        with open("output.txt", "w") as f:
            f.write("")
        # Remove the checkpoint together with its resume metadata
        for stale_path in (temp_word_path, temp_char_path, train_start_index_path, train_loss_path):
            if os.path.exists(stale_path):
                os.remove(stale_path)
        print("Starting training from scratch")
    else:
        print("Model already exists, loading...")
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
        model_word.load_state_dict(torch.load(temp_word_path, map_location=device))
        model_char.load_state_dict(torch.load(temp_char_path, map_location=device))
        if input("Load train start index? (y/n) ") == "y":
            if os.path.exists(train_start_index_path):
                train_start_index = torch.load(train_start_index_path)
                print(f"Starting from index {train_start_index}")
            else:
                print("No saved train start index found, starting from index 0")

# Training loop
for epoch in range(epochs):
    plot_data_for_loss = train(
        model_w=model_word,
        model_c=model_char,
        dataloader=dataloader,
        epoch=epoch,
        encoder_W=word_encoder,
        encoder_C=char_encoder,
        loss_function=loss_function,
        optimizer_W=optimizer_word,
        optimizer_C=optimizer_char,
        device=device,
        word_save_path=temp_word_path,
        char_save_path=temp_char_path,
        start_train_at_path=train_start_index_path,
        start_train_at=train_start_index,
    )
    # Only the first epoch after a resume skips batches
    train_start_index = 0
    plt.plot(plot_data_for_loss)
    plt.show()

In [None]:
# Final (non-temporary) checkpoint names; defined unconditionally so later
# cells can use them even when the user declines to save
save_name_word = f"model_word_{word_embed_dim}_{max_vocab_size_word}_{word_rnn_dim}_{word_rnn_layers}_{word_bidirectional}_{word_dense_dims}_{char_embed_dim}.pth"
save_name_char = f"model_char_{char_embed_dim}_{max_vocab_size_char}_{char_rnn_dim}_{char_rnn_layers}_{char_bidirectional}_{char_dense_dims}.pth"

# Require an explicit yes: a bare truthiness test would also save on "n"
if input("Save the model? (y/n) ").strip().lower() in ("y", "yes"):
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
    # Save from CPU so the checkpoint does not hold device-specific tensors
    print("Converting model to cpu...")
    model_char.to("cpu")
    model_word.to("cpu")
    print("Saving model...")
    torch.save(model_word.state_dict(), os.path.join(MODEL_SAVE_PATH, save_name_word))
    torch.save(model_char.state_dict(), os.path.join(MODEL_SAVE_PATH, save_name_char))
    print("Model saved!")
    print(f"Converting model to {device}...")
    model_char.to(device)
    model_word.to(device)
    print("Done")

In [None]:
if input("Load model? (y/n) ") == "y":
    if input("Load latest model in temp save? (y/n) ") == "y":
        print("Loading model from temp model save file")
        load_name_word, load_name_char = temp_save_name_word, temp_save_name_char
    else:
        print("Loading model from model save file")
        # Rebuild the final checkpoint names here so this cell does not depend
        # on the save cell having been answered with "yes" in this session
        load_name_word = f"model_word_{word_embed_dim}_{max_vocab_size_word}_{word_rnn_dim}_{word_rnn_layers}_{word_bidirectional}_{word_dense_dims}_{char_embed_dim}.pth"
        load_name_char = f"model_char_{char_embed_dim}_{max_vocab_size_char}_{char_rnn_dim}_{char_rnn_layers}_{char_bidirectional}_{char_dense_dims}.pth"
    # map_location lets a checkpoint saved on another device load on this one
    model_word.load_state_dict(torch.load(os.path.join(MODEL_SAVE_PATH, load_name_word), map_location=device))
    model_char.load_state_dict(torch.load(os.path.join(MODEL_SAVE_PATH, load_name_char), map_location=device))
    print("Model loaded successfully")
    print(f"Converting model to {device}...")
    model_word.to(device)
    model_char.to(device)
    print("Done")