In [None]:
import torch
import os

In [None]:
# Directory that holds the saved model checkpoints (joined with the checkpoint
# file names when the models are loaded).
MODEL_SAVE_PATH = "pytorch_model_saves"
# Directory that holds the saved vocabulary files for the char and word encoders.
VOCABULARY_SAVE_PATH = "pytorch_vocab_saves"

In [None]:
# Report whether CUDA is usable and, if so, release cached GPU memory.
cuda_available = torch.cuda.is_available()
print("is cuda available: ", cuda_available)
if cuda_available:
    torch.cuda.empty_cache()
    print("cuda cache cleared")

# The notebook is deliberately pinned to the CPU. To run on a GPU when one is
# present, use: torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = "cpu"

In [None]:
from torch_utils_module import TextVectorizer

# Create the character-level and word-level vectorizers.
char_encoder = TextVectorizer(max_tokens=None, lower=True, strip_punctuation=False)
word_encoder = TextVectorizer(max_tokens=None, split=" ", lower=True, strip_punctuation=True)

# Locations of the previously saved vocabularies.
# NOTE(review): the original cell carried an "#old_wiki2" marker here; its meaning
# is not evident from this notebook -- confirm whether it is still relevant.
char_vocabulary_save_path = VOCABULARY_SAVE_PATH + "/reduced_chars_new_full.pth"
word_vocabulary_save_path = VOCABULARY_SAVE_PATH + "/reduced_words_new_full.pth"

# Fail fast with a specific exception that names the missing file.
# FileNotFoundError is a subclass of Exception, so code that caught the
# previous generic Exception keeps working.
if not os.path.exists(char_vocabulary_save_path):
    raise FileNotFoundError(f"Char vocabulary not found: {char_vocabulary_save_path}")
char_encoder.load_vocabulary(char_vocabulary_save_path)
print("Char vocabulary loaded successfully")

if not os.path.exists(word_vocabulary_save_path):
    raise FileNotFoundError(f"Word vocabulary not found: {word_vocabulary_save_path}")
word_encoder.load_vocabulary(word_vocabulary_save_path)
print("Word vocabulary loaded successfully")

# Vocabulary sizes: number of entries in the first element returned by
# get_vocabulary() (the token -> index mapping used throughout the notebook).
max_vocab_size_char = len(char_encoder.get_vocabulary()[0])
max_vocab_size_word = len(word_encoder.get_vocabulary()[0])

In [None]:
# Sanity check: in each vocabulary the largest stored index plus one must equal
# the number of entries. Displays (word_ok, char_ok).
tuple(
    max(vocabulary.values()) + 1 == len(vocabulary.values())
    for vocabulary in (word_encoder.get_vocabulary()[0], char_encoder.get_vocabulary()[0])
)

In [None]:
# Display the vocabulary sizes as (char, word).
(max_vocab_size_char, max_vocab_size_word)

In [None]:
from torch_utils_module import Model

# Model hyperparameters.
# These values are also embedded in the checkpoint file names, so they must
# match the configuration the checkpoints were saved with.
word_embed_dim = 256
char_embed_dim = 128  # also passed to the word model as its out_embedding_dim

word_rnn_dim = 256
word_rnn_layers = 2
char_rnn_dim = 256
char_rnn_layers = 2

word_bidirectional = False
char_bidirectional = False

word_dense_dims = [128, 128]
char_dense_dims = []

# Word-level model (mode="word").
model_word = Model(
    mode="word",
    vocab_size=max_vocab_size_word,
    in_embedding_dim=word_embed_dim,
    out_embedding_dim=char_embed_dim,
    pretrained_embedding_path=None,
    rnn_dim=word_rnn_dim,
    num_rnn_layers=word_rnn_layers,
    rnn_dropout=0,
    bidirectional=word_bidirectional,
    dense_dims=word_dense_dims,
)

# Character-level model (mode="char"); it has no out_embedding_dim.
model_char = Model(
    mode="char",
    vocab_size=max_vocab_size_char,
    in_embedding_dim=char_embed_dim,
    out_embedding_dim=None,
    rnn_dim=char_rnn_dim,
    num_rnn_layers=char_rnn_layers,
    rnn_dropout=0,
    bidirectional=char_bidirectional,
    dense_dims=char_dense_dims,
)

In [None]:
# Move both models to the selected device; the resulting (word, char) tuple is
# the cell's displayed value.
tuple(model.to(device) for model in (model_word, model_char))

In [None]:
# Load model
# Checkpoint file names encode the hyperparameters, so a checkpoint is only
# found when the current configuration matches the one it was saved with.
temp_save_path_word = f"model_word_{word_embed_dim}_{max_vocab_size_word}_{word_rnn_dim}_{word_rnn_layers}_{word_bidirectional}_{word_dense_dims}_{char_embed_dim}_temp.pth"
temp_save_path_char = f"model_char_{char_embed_dim}_{max_vocab_size_char}_{char_rnn_dim}_{char_rnn_layers}_{char_bidirectional}_{char_dense_dims}_temp.pth"
save_path_word = f"model_word_{word_embed_dim}_{max_vocab_size_word}_{word_rnn_dim}_{word_rnn_layers}_{word_bidirectional}_{word_dense_dims}_{char_embed_dim}.pth"
save_path_char = f"model_char_{char_embed_dim}_{max_vocab_size_char}_{char_rnn_dim}_{char_rnn_layers}_{char_bidirectional}_{char_dense_dims}.pth"

if input("Load model? (y/n) ") == "y":
    # Choose between the temporary checkpoint and the regular one.
    if input("Load latest model in temp save? (y/n) ") == "y":
        print("Loading model from temp model save file")
        checkpoint_word, checkpoint_char = temp_save_path_word, temp_save_path_char
    else:
        print("Loading model from model save file")
        checkpoint_word, checkpoint_char = save_path_word, save_path_char

    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint saved on a GPU can still be loaded when running on the CPU.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    model_word.load_state_dict(
        torch.load(os.path.join(MODEL_SAVE_PATH, checkpoint_word), map_location=device)
    )
    model_char.load_state_dict(
        torch.load(os.path.join(MODEL_SAVE_PATH, checkpoint_char), map_location=device)
    )
    print("Model loaded successfully")

    print(f"Converting model to {device}...")
    model_word.to(device)
    model_char.to(device)
    print("Done")

In [None]:
# control parameters
device = device            # no-op kept on purpose: edit the right-hand side to change device from this cell
prompt = "Politics is "
lines = 10                 # number of independent generations printed, one per output line
next_words = 10            # number of words generated after the prompt on each line
use_distribution = False   # True: sample from the predicted distribution; False: take the argmax

prompt = prompt.lower()

# move models to device and set them to eval mode
model_word.to(device)
model_char.to(device)
model_word.eval()
model_char.eval()

# Fetch the token -> index mappings once instead of on every loop iteration.
word_vocabulary = word_encoder.get_vocabulary()[0]
char_vocabulary = char_encoder.get_vocabulary()[0]

# Inverse char mapping (index -> token), built from the stored indices rather
# than from dict iteration order, so a predicted class id always maps to the
# token that actually carries that index. Also makes each lookup O(1).
index_to_char = {int(char_vocabulary[token]): token for token in char_vocabulary}

# print info on whether each word in the prompt is in the vocabulary
print("Word\t\t\tIs_in_vocabulary\t\tIndex")
for word in prompt.split(" "):
    if word in word_vocabulary:
        print(f"{word}\t\t\tO\t\t\t{word_vocabulary[word]}")
    else:
        print(f"{word}\t\t\tX")

# Tokens that end the word currently being generated.
stop_tokens = ("<pad>", "<unk>", " ")

# generate text; no gradients are needed for inference
with torch.no_grad():
    for _ in range(lines):
        seed_text = prompt
        print(seed_text, end="")
        for _ in range(next_words):
            # Word model: encode everything except the last space-separated chunk
            # (the word in progress, which is handled by the char model).
            seed_text_p1 = " ".join(seed_text.split(" ")[:-1])
            encoded_seed_text_p1 = word_encoder([seed_text_p1])
            encoded_seed_text_p1 = encoded_seed_text_p1[0].to(device, dtype=torch.int)
            output_state = model_word(encoded_seed_text_p1, state="inference")

            output_letter = ""
            # Keep emitting characters until the char model produces a stop token.
            while output_letter not in stop_tokens:
                # Char model: predict the next character of the word in progress.
                seed_text_p2 = seed_text.split(" ")[-1]
                encoded_seed_text_p2 = char_encoder([seed_text_p2])
                encoded_seed_text_p2 = encoded_seed_text_p2[0].to(device, dtype=torch.int)
                # NOTE(review): dropout_allowance semantics are defined in
                # torch_utils_module.Model -- confirm the intended value there.
                predict_x = model_char(
                    encoded_seed_text_p2,
                    word_embed_info=output_state,
                    dropout_allowance=0.075,
                    state="inference",
                )
                if use_distribution:
                    # sample the next character from the predicted distribution
                    classes_x = torch.distributions.Categorical(logits=predict_x).sample().item()
                else:
                    # greedy: pick the highest-scoring character
                    classes_x = torch.argmax(predict_x).item()

                # Unknown class ids map to "" (same fallback as before).
                output_letter = index_to_char.get(classes_x, "")

                print(output_letter, end="")
                seed_text += output_letter
        print()  # end the current generated line