In [None]:
import gensim
from gensim.models import Word2Vec
from gensim.utils import simple_preprocess
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import pickle
import threading

In [None]:
# exquisite acronym explanation (also sounds like lean):
# R - recurrent
# E - embedding
# A - approximation
# N - network

In [None]:
# embeddings model
model_file = "./embedding_models/wiki_model_5_XL_vector.model"
embeddings_model = Word2Vec.load(model_file)
vector_size = embeddings_model.vector_size
window = embeddings_model.window

# neural net settings
context_length = 16
# every REAN block reads the previous block's output (vector_size) next to the flattened context
input_size = vector_size + context_length * vector_size
attention_heads = 4

# dataset
train_dataset_path = "./datasets/wiki_dump_train.txt"
test_dataset_path = "./datasets/wiki_dump_test.txt"
unique_examples_train = 4096# * 8 * 8
unique_examples_test = 4096
fluffed_up_size_train = unique_examples_train# * context_length
fluffed_up_size_test = unique_examples_test# * context_length
# each stored example is a float32 (context_length, vector_size) context plus a float32 vector_size target;
# the estimate is in GB (1e9 bytes) and is only as good as the fluffed_up_size guesses above
bytes_per_example = (context_length + 1) * vector_size * 4
predicted_ram_usage = (fluffed_up_size_test + fluffed_up_size_train) * bytes_per_example / 1000 / 1000 / 1000

# training
epochs = 25 // 2#6
lr = 0.00001
loss = nn.MSELoss()
optimizer = torch.optim.Adam
batch_size = 4#8#16#32#56#1024 * 1

# reproducibility: dataset augmentation draws from `random`, weight init and DataLoader shuffling from torch
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# pytorch
# fall back to the CPU so the notebook still runs on machines without a GPU
run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
storage_device = torch.device("cpu")

In [None]:
print("settings summary:\n")
print(f"--- embeddings model ---\nvec size: {vector_size}\nwindow: {window}\n")
print(f"--- neural network ---\ncontext length: {context_length}\nlayer input size: {input_size}\n")
print(f"--- dataset ---\nunique train examples: {unique_examples_train}\nfluffed up train size: {fluffed_up_size_train}\nunique test examples: {unique_examples_test}\nfluffed up test size: {fluffed_up_size_test}\npredicted ram requirements: {predicted_ram_usage:.2f} GB\n")
print(f"--- training ---\nepochs: {epochs}\nlr: {lr}\nbatch size: {batch_size}\noptimizer: {optimizer}\nloss: {loss}\n")
# only query GPU properties when actually running on CUDA; the query fails on machines without CUDA
if run_device.type == "cuda":
    gpu_properties = torch.cuda.get_device_properties(run_device)
    print(f"--- pytorch ---\ndevice: {run_device}   |   {gpu_properties.name}\nVRAM capacity: {gpu_properties.total_memory / 1024 / 1024 / 1024:.2f} GB\n")
else:
    print(f"--- pytorch ---\ndevice: {run_device}\n")

In [None]:
class attention_mechanism(nn.Module):
    """
    multi-head scaled dot-product self-attention over a sequence of word vectors

    parameters:
        embed_dim (int, optional): size of each input vector, defaults to the global vector_size
        num_heads (int, optional): number of attention heads, defaults to the global attention_heads;
            must divide embed_dim
        device (torch.device, optional): device to place the layers on, defaults to the global run_device

    raises:
        ValueError: if embed_dim is not divisible by num_heads
    """
    def __init__(self, embed_dim=None, num_heads=None, device=None):
        super(attention_mechanism, self).__init__()
        # fall back to the notebook settings so `attention_mechanism()` keeps working as before
        embed_dim = vector_size if embed_dim is None else embed_dim
        num_heads = attention_heads if num_heads is None else num_heads
        device = run_device if device is None else device
        if embed_dim % num_heads != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")

        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Linear layers to project input to queries, keys and values
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

        # Output linear layer
        self.out = nn.Linear(embed_dim, embed_dim)

        # Scaling factor sqrt(head_dim), kept as a plain float: a bare tensor attribute is not a
        # registered buffer, so module.to() would not move it along with the layers
        self.scale = self.head_dim ** 0.5

        self.to(device)

    def forward(self, x):
        """
        parameters:
            x (torch.Tensor): (batch, seq_len, embed_dim) tensor

        returns:
            torch.Tensor: (batch, seq_len, embed_dim) tensor of attended features
        """
        # follow the layers' device rather than a global, so the module works wherever it was moved
        x = x.to(self.query.weight.device)
        batch, seq_len, embed_dim = x.shape

        # split the feature axis into heads: (batch, heads, seq_len, head_dim)
        def split_heads(t):
            return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        queries = split_heads(self.query(x))
        keys = split_heads(self.key(x))
        values = split_heads(self.value(x))

        # Scaled dot-product attention
        attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / self.scale
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_output = torch.matmul(attention_weights, values)

        # merge the heads back into the feature axis and apply the output projection
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch, seq_len, embed_dim)
        return self.out(attention_output)

In [None]:
# each block takes in:
# - the output of the block before it (prev_block)
# - the original input into the network (fresh_input)
# (a third input, the user's original prompt, which need not match fresh_input, was planned but is not used)
# each block returns:
# - a vector of the same size as the embedding (which can be passed into the next block)
class REAN_block(nn.Module):
    """
    one fully connected step of the REAN chain

    concatenates the previous block's output with the flattened context and maps the result through two
    tanh hidden layers to an output_size vector, which can be fed to the next block as prev_block

    parameters:
        input_size (int): width of the concatenated input (prev_block width + fresh_input width)
        hidden_size (int, optional): width of both hidden layers, defaults to 1024
        output_size (int, optional): width of the output, defaults to the global vector_size
    """
    def __init__(self, input_size, hidden_size=1024, output_size=None):
        super(REAN_block, self).__init__()
        if output_size is None:
            output_size = vector_size
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, prev_block: torch.Tensor, fresh_input: torch.Tensor) -> torch.Tensor:
        # join along the feature axis; -1 matches dim=1 for batched 2d input and also accepts unbatched 1d input
        x = torch.cat((prev_block, fresh_input), dim=-1)

        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return self.fc3(x)

In [None]:
class REAN(nn.Module):
    """
    recurrent embedding approximation network: a chain of REAN_blocks that all re-read the same context

    forward flattens each example of the batch into one row, feeds block 1 zeros in place of a previous
    output, then passes every block's output to the next block together with the flattened context.
    The result is the last block's output, one row per example.

    parameters:
        input_size (int): input width of every block (previous output width + flattened context width)
        num_blocks (int, optional): number of chained blocks, defaults to 8; must be at least 1

    raises:
        ValueError: if num_blocks < 1
    """
    def __init__(self, input_size, num_blocks=8):
        super(REAN, self).__init__()
        if num_blocks < 1:
            raise ValueError(f"num_blocks must be at least 1, got {num_blocks}")

        #self.attention_mechanism = attention_mechanism()

        # registered as block1 ... blockN (not an nn.ModuleList) so state_dict keys stay identical to
        # checkpoints saved with the original hand-written block1..block8 attributes
        self.num_blocks = num_blocks
        for block_idx in range(1, num_blocks + 1):
            self.add_module(f"block{block_idx}", REAN_block(input_size))

    def forward(self, current_segment: torch.Tensor) -> torch.Tensor:
        batch = current_segment.shape[0]

        # one flattened context row per example; every block re-reads it (kept local, not stored on the module)
        fresh_input = current_segment.reshape(batch, -1)

        # apply attention mechanism
        #fresh_input = self.attention_mechanism(current_segment).reshape(batch, -1)

        # the first block has no predecessor: give it zeros of the blocks' output width, on the input's device
        block_output = torch.zeros((batch, self.block1.fc3.out_features), dtype=torch.float32, device=current_segment.device)
        for block_idx in range(1, self.num_blocks + 1):
            block_output = getattr(self, f"block{block_idx}")(block_output, fresh_input)

        return block_output

net = REAN(input_size).to(run_device)
# `optimizer` holds the optimizer class from the config cell and is rebound to an instance here;
# recover the class when the cell is re-run so it does not fail calling an optimizer instance
optimizer_class = optimizer if isinstance(optimizer, type) else type(optimizer)
optimizer = optimizer_class(net.parameters(), lr=lr)

In [None]:
def load_dataset_chunk(path: str, num_words: int, seek_start: int, sep: str = " ") -> tuple[list, bool, int]:
    """
    function to load a chunk of the dataset where the words are separated by "sep" (or any whitespace) into a list

    parameters:
        path (str): path to the dataset txt file
        num_words (int): number of words to load
        seek_start (int): file position to pull words from: 0, or a position previously returned by this function
        sep (str, optional): single-character separator in the dataset, defaults to space " "

    returns:
        list: list of strings (loaded words), is EOF hit, file position just past the first loaded word's
            separator (0 if no word was completed before stopping)
    """

    # some safety checks so later code looks cleaner
    num_words = max(0, num_words)
    seek_start = max(0, seek_start)

    words = []
    words_seen = 0
    word_buffer = ""
    next_seek = 0
    first_word_done = False

    with open(path, 'r', encoding='utf-8', errors='ignore') as file:
        file.seek(seek_start)

        # loop over all chars after seek_start
        while True:
            char = file.read(1)

            # end of file: a last word without a trailing separator still counts
            if not char:
                if word_buffer and words_seen < num_words:
                    words.append(word_buffer)
                return words, True, next_seek

            # is a separator between words hit
            if char == sep or char.isspace():
                if word_buffer:
                    if words_seen < num_words:
                        words.append(word_buffer)
                    words_seen += 1
                    word_buffer = ""

                    # the next chunk starts one word later; only record it once a word is actually complete
                    # (leading whitespace must not count) and before the break below (so num_words == 1 works).
                    # tell() gives a position seek() accepts in text mode; a char count does not for
                    # multi-byte UTF-8 or CRLF text
                    if not first_word_done:
                        first_word_done = True
                        next_seek = file.tell()

                if words_seen >= num_words:
                    break
            else:
                word_buffer += char

    return words, False, next_seek

In [None]:
def vectorize_sentence(sentence: list[str], model: Word2Vec, default: float = 0) -> np.ndarray:
    """
    encodes all words in a given list to corresponding vectors in given model.
    words not found in the model will be given a vector filled with "default"

    parameters:
        sentence (list): list of strings (words)
        model (Word2Vec): model to use when encoding
        default (float): fill vector with this value if word is not found in model

    returns:
        np.ndarray: float64 2d array with dim1 = len(sentence) and dim2 = model.vector_size
    """

    vectorized = np.full((len(sentence), model.vector_size), default, dtype=np.float64)

    for word_idx, word in enumerate(sentence):
        # overwrite rather than zero-then-add, so a non-finite default (nan * 0 is nan) can't leak into known words
        if word in model.wv:
            vectorized[word_idx] = model.wv[word]

    return vectorized

In [None]:
def devectorize_sentence(vectorized_sentence: np.array, model: Word2Vec) -> list:
    """
    maps every vector back to the word in model whose vector is most similar to it

    parameters:
        vectorized_sentence (np.array): 2d array, one word vector per row
        model (Word2Vec): model whose vocabulary is searched

    returns:
        list: one word (string) per input row, in the same order
    """
    # similar_by_vector yields (word, similarity) pairs with the best match first
    return [model.wv.similar_by_vector(word_vector)[0][0] for word_vector in vectorized_sentence]

In [None]:
def pad_or_truncate(suspected_tensor: torch.Tensor, target_length: int, default: int = 0) -> torch.Tensor:
    """
    pads (at the front) or truncates (keeping the last rows) a given tensor along dim 0 to target_length

    parameters:
        suspected_tensor (torch.Tensor): tensor of any rank >= 1 to pad or truncate
        target_length (int): target length of dim 0; negative values are treated as 0
        default (int): value to use for padding

    returns:
        torch.Tensor: tensor whose dim 0 has exactly target_length entries
    """
    target_length = max(0, target_length)
    current_length = suspected_tensor.shape[0]

    if current_length < target_length:
        # front-pad so the most recent entries stay at the end; padding is float32 as before
        padding = torch.full(
            (target_length - current_length, *suspected_tensor.shape[1:]),
            default,
            dtype=torch.float32,
            device=suspected_tensor.device,
        )
        return torch.cat((padding, suspected_tensor))

    # keep the last target_length rows; a start index avoids `[-0:]`, which would keep everything
    return suspected_tensor[current_length - target_length:]

In [None]:
def prepare_sentence_for_net(sentence: list, model: Word2Vec, context_length: int, flatten: bool=True, used_device: torch.device=run_device) -> torch.Tensor:
    """
    turns a sentence (list of strings) into a tensor that can be fed directly into the network

    parameters:
        sentence (list): list of strings (words)
        model (Word2Vec): model to use when encoding sentence
        context_length (int): number of word vectors to keep, should be same as network's; shorter sentences
            are front-padded with zeros, longer ones keep only their last context_length words
        flatten (bool, optional): flatten the result to 1d, defaults to True
        used_device (torch.device, optional): device the tensor ends up on; the default is the value of
            run_device when this function was defined

    returns:
        torch.Tensor: float32 tensor of shape (context_length, model.vector_size),
            or (context_length * model.vector_size,) when flattened
    """

    # encode sentence to a float64 np.ndarray
    vectorized = vectorize_sentence(sentence, model)
    # cast to float32 on the host before the device copy, so only half as many bytes are transferred
    vectorized_tensor = torch.from_numpy(vectorized).to(torch.float32).to(used_device)

    # pad or truncate
    vectorized_tensor = pad_or_truncate(vectorized_tensor, context_length)

    if flatten:
        # flatten to fit into first fc layer of the net
        vectorized_tensor = vectorized_tensor.flatten()

    return vectorized_tensor

In [None]:
def predict_word(current_segment: list, net: REAN, embeddings_model: Word2Vec) -> str:
    """
    uses the net and the embedding model to predict the word that follows the given words

    parameters:
        current_segment (list): list of strings (words); only the last context_length words are used
        net (REAN): net to use when predicting
        embeddings_model (Word2Vec): embedding model to use when encoding the words

    returns:
        str: predicted word
    """
    encoded_segment = prepare_sentence_for_net(current_segment, embeddings_model, context_length, flatten=False)

    # inference only: don't build an autograd graph
    with torch.no_grad():
        # the net outputs an offset; adding it to the last context vector gives the predicted word vector
        offset = net(encoded_segment.unsqueeze(0))
        target = offset.squeeze(0) + encoded_segment[-1]

    # decode most similar word to whatever net predicted
    return embeddings_model.wv.similar_by_vector(target.cpu().numpy())[0][0]

In [None]:
def predict_sequence(sentence: list, net: REAN, embeddings_model: Word2Vec, num_completions: int) -> list:
    """
    predicts multiple words, one at a time, after the end of the given sentence

    each predicted word is appended to the running context before the next one is predicted

    parameters:
        sentence (list): list of strings (words); it is not modified
        net (REAN): net to use when predicting
        embeddings_model (Word2Vec): embedding model to use when encoding sentence
        num_completions (int): number of words to predict

    returns:
        list: a new list holding the given sentence followed by the num_completions predicted words
    """

    # copy, so the caller's list is not extended in place
    predicted_result = list(sentence)

    for _ in tqdm(range(num_completions)):
        # give the network the full context to work with
        predicted_result.append(predict_word(predicted_result, net, embeddings_model))

    return predicted_result

In [None]:
class REAN_dataset(Dataset):
    """
    next-word dataset built by walking a text file one word at a time

    every item is (context, diff):
        context: float32 (context_length, vector_size) tensor of word vectors, front-padded with zeros
        diff: target word vector minus the last context vector, i.e. what the net is trained to output
    both are stored on storage_device

    attributes:
        current_segments (list): contexts, one per example
        targets (list): matching diffs
        seek_start (int): file position the next window would be read from
        eof (bool): whether the end of the file was reached while building
    """

    def extrapolate_and_add_example(self, path: str, seek_start: int, context_length: int, embeddings_model: Word2Vec, append_context: list, append_target: list):
        """
        loads one window of context_length + 1 words at seek_start and appends its examples

        the window is augmented into several examples by dropping 0, 1, 2, ... leading context words
        (dropped slots become padding), so the net also learns from short contexts; the number of cuts
        is drawn with random.randint

        parameters:
            path (str): path to the dataset txt file
            seek_start (int): file position to read the window from
            context_length (int): number of context words per example
            embeddings_model (Word2Vec): model used to encode the words
            append_context (list): list the contexts are appended to
            append_target (list): list the diffs are appended to

        returns:
            tuple: (is EOF hit, file position of the next window)
        """
        segment, eof, next_seek = load_dataset_chunk(path, context_length + 1, seek_start)

        # EOF right at seek_start: there is no target word, so nothing to add
        if not segment:
            return eof, next_seek

        *context_words, target_word = segment
        target = prepare_sentence_for_net([target_word], embeddings_model, 1, flatten=False, used_device=storage_device).squeeze(0)

        num_cuts = len(segment) - random.randint(context_length // 2, context_length)
        for cut in range(num_cuts):
            context = prepare_sentence_for_net(context_words[cut:], embeddings_model, context_length, flatten=False, used_device=storage_device)
            append_context.append(context)
            append_target.append(target - context[-1])

        return eof, next_seek

    def __init__(self, path, num_unique_examples, context_length, embeddings_model):
        self.seek_start = 0
        self.eof = False

        self.current_segments = []
        self.targets = []

        for _ in tqdm(range(num_unique_examples)):
            self.eof, self.seek_start = self.extrapolate_and_add_example(path, self.seek_start, context_length, embeddings_model, self.current_segments, self.targets)

            if self.eof:
                print("eof hit, early stop")
                break

        print(f"fluffed up size: {len(self.targets)}")

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        return self.current_segments[index], self.targets[index]

In [None]:
# build each split by walking its own text file
train_dataset = REAN_dataset(
    path=train_dataset_path,
    num_unique_examples=unique_examples_train,
    context_length=context_length,
    embeddings_model=embeddings_model,
)
test_dataset = REAN_dataset(
    path=test_dataset_path,
    num_unique_examples=unique_examples_test,
    context_length=context_length,
    embeddings_model=embeddings_model,
)

In [None]:
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# evaluation does not need shuffling: fixed batches keep per-batch test losses comparable between
# evaluation passes and leave the torch RNG to training
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# per-batch loss histories, filled by the training loop below
train_loss_graph, test_loss_graph = [], []

In [None]:
# evaluate every `eval_interval` epochs and always after the final epoch; with epochs = 12, `epoch % 24 == 0`
# on its own only fired at epoch 0, so the test curve never reflected the trained net
eval_interval = 24
for epoch in tqdm(range(epochs)):
    net.train()
    for current_segment, target in train_loader:
        # move batch to the run device
        current_segment = current_segment.to(run_device)
        target = target.to(run_device)

        # train batch
        optimizer.zero_grad()
        train_outputs = net(current_segment)
        train_loss_value = loss(train_outputs, target)
        train_loss_value.backward()
        optimizer.step()

        # collect performance metrics
        train_loss_graph.append(train_loss_value.item())

    if epoch % eval_interval == 0 or epoch == epochs - 1:
        net.eval()
        with torch.no_grad():
            for test_current_segment, test_target in test_loader:
                # move to the run device
                test_current_segment = test_current_segment.to(run_device)
                test_target = test_target.to(run_device)

                test_outputs = net(test_current_segment)
                test_loss_value = loss(test_outputs, test_target)
                test_loss_graph.append(test_loss_value.item())

In [None]:
fig, ax = plt.subplots()
# one point per training batch
ax.plot(train_loss_graph)
ax.set(title="train loss", xlabel="training batch", ylabel="loss");

In [None]:
# persist only the weights (state_dict), not the pickled module object
torch.save(obj=net.state_dict(), f='no_attention_mech.pth')

In [None]:
fig, ax = plt.subplots()
# one point per test batch, concatenated across every evaluation pass
ax.plot(test_loss_graph)
ax.set(title="test loss", xlabel="test batch (all evaluation passes)", ylabel="loss");

In [None]:
# prompt to continue, already split into words
sentence = ["was", "the", "first", "man"]

In [None]:
" ".join(predict_sequence(sentence, net, embeddings_model, 32))