In [None]:
from collections import Counter
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from collections import Counter
import numpy as np
import random

import torch
import torch.nn as nn
import math
import re
import numpy as np

In [2]:


class TextProcessor:
    """Load a text file and turn it into a flat list of lowercase tokens.

    The file is read once, in the constructor, as UTF-8 (undecodable bytes are
    ignored) and lowercased. ``process_text`` then cleans and tokenizes it.

    Args:
        file_path: Path of the text file to read.
        glove_vectors: Optional embedding lookup used only by
            ``_embed_tokens``. It must support ``token in glove_vectors``,
            ``glove_vectors[token]`` and expose ``vector_size``.
    """

    def __init__(self, file_path, glove_vectors=None):
        self.file_path = file_path
        # Previously never assigned, which made _embed_tokens fail with
        # AttributeError on every call.
        self.glove_vectors = glove_vectors
        self.text = self._load_file()

    def _load_file(self):
        """Return the whole file content, lowercased."""
        with open(self.file_path, 'r', encoding='utf-8', errors='ignore') as file:
            return file.read().lower()

    def _clean_text(self, text):
        """Replace everything except ASCII letters, ``'`` and ``,.!?`` with a space."""
        return re.sub(r"[^a-zA-Z',.!?]", ' ', text)

    def _tokenize(self, text):
        """Split ``text`` into words (optionally with one ``'suffix``) and
        single-character punctuation tokens ``. ! , ?``."""
        return re.findall(r"\b\w+(?:'\w+)?\b|[.!,?]", text)

    def _embed_tokens(self, tokens):
        """Map each token to its vector; unknown tokens get a zero vector.

        Raises:
            ValueError: if no ``glove_vectors`` were supplied to the constructor.
        """
        if self.glove_vectors is None:
            raise ValueError(
                "TextProcessor was created without glove_vectors; "
                "pass them to the constructor before embedding tokens."
            )
        dim = self.glove_vectors.vector_size
        return [
            self.glove_vectors[token] if token in self.glove_vectors else np.zeros(dim)
            for token in tokens
        ]

    def process_text(self):
        """Clean and tokenize the text loaded in the constructor.

        Returns:
            list[str]: tokens in document order.
        """
        return self._tokenize(self._clean_text(self.text))

def _clean_text(text):
        text = re.sub(r"[^a-zA-Z']", ' ', text)
        return text

def _tokenize(text):
    tokens = re.findall(r"\b\w+\b", text)
    return tokens


def process_text(text):
    """Clean ``text`` with the module-level ``_clean_text`` and split the
    result into tokens with the module-level ``_tokenize``."""
    return _tokenize(_clean_text(text))

In [None]:
from torch.utils.data import Dataset, DataLoader
import torch

from collections import Counter

class TextDataset(Dataset):
    """Sliding-window next-token dataset over an encoded token sequence.

    Item ``i`` is the pair
    ``(encoded_text[i:i + seq_length], encoded_text[i + seq_length])``,
    both returned as tensors built with ``torch.tensor``.

    Args:
        encoded_text: Sequence of token ids.
        seq_length: Number of context tokens per sample.
        pad_token: Stored on the instance; this class applies no padding itself.
    """

    def __init__(self, encoded_text, seq_length, pad_token=0):
        self.encoded_text = encoded_text
        self.seq_length = seq_length
        self.pad_token = pad_token

    def __len__(self):
        # Clamp at zero: a text shorter than seq_length would otherwise give a
        # negative length, which makes len() raise ValueError.
        return max(0, len(self.encoded_text) - self.seq_length)

    def __getitem__(self, idx):
        """Return ``(context, target)`` for window ``idx``.

        Raises:
            IndexError: if ``idx`` is outside ``[0, len(self))``. Negative
                indices are rejected because slicing with them would silently
                produce a truncated context paired with an unrelated target.
        """
        if not 0 <= idx < len(self):
            raise IndexError(
                f"index {idx} out of range for TextDataset of length {len(self)}"
            )
        end = idx + self.seq_length
        context = self.encoded_text[idx:end]
        target = self.encoded_text[end]
        return torch.tensor(context), torch.tensor(target)

# Tokenize the training corpus.
TextProc = TextProcessor('train_file.txt')
tokens = TextProc.process_text()
print("Number of words in training text: ", len(tokens))

# Vocabulary: ids 0 and 1 are reserved, every distinct token then gets the
# next free id in first-seen order.
token_counts = Counter(tokens)
_w2i = {'<PAD>': 0, '<UNK>': 1}
for word in token_counts.keys():
    _w2i[word] = len(_w2i)

# Reverse lookup (id -> token) and the id-encoded training text.
_i2w = dict((idx, word) for word, idx in _w2i.items())
encoded_text = list(map(lambda tok: _w2i.get(tok, _w2i['<UNK>']), tokens))

In [1]:


class TransformerModel(nn.Module):
    """Transformer-encoder model that maps token ids to per-position vocabulary logits.

    Pipeline: embedding (scaled by sqrt(embed_size)) -> positional encoding ->
    ``num_layers`` batch-first encoder layers -> linear projection to
    ``vocab_size``.

    Args:
        vocab_size: Number of rows in the embedding table and width of the output.
        embed_size: Embedding / model dimension.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        hidden_dim: Feed-forward width inside each encoder layer.
        dropout: Dropout used by the positional encoding and encoder layers.
        max_len: Longest sequence the positional encoding is built for.
    """

    def __init__(self, vocab_size, embed_size, num_heads, num_layers, hidden_dim, dropout=0.1, max_len=64):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.positional_encoding = PositionalEncoding(embed_size, dropout, max_len)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_size,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc_out = nn.Linear(embed_size, vocab_size)

    def forward(self, x, mask=None):
        """Return logits for every position of ``x``.

        Args:
            x: Token ids, batch first (callers in this notebook pass
                ``(batch, seq_len)`` tensors).
            mask: Optional key-padding mask forwarded to the encoder as
                ``src_key_padding_mask``; ``None`` means no padding mask.
        """
        x = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
        x = self.positional_encoding(x)
        # `mask == None` was replaced: None is the default of
        # src_key_padding_mask, so the mask can be passed through directly
        # without comparing a tensor to None.
        x = self.transformer_encoder(x, src_key_padding_mask=mask)
        return self.fc_out(x)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch-first input, then apply dropout.

    The table is precomputed for ``max_len`` positions and stored as the
    buffer ``pe`` with shape ``(1, max_len, d_model)``: even feature columns
    hold sines, odd columns hold cosines.

    Args:
        d_model: Feature dimension of the inputs (odd values are supported).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, dropout=0.1, max_len=32):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine column;
        # without this slice the assignment fails with a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])``.

        Raises:
            ValueError: if the sequence dimension of ``x`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional encoding "
                f"max_len {self.pe.size(1)}"
            )
        return self.dropout(x + self.pe[:, :seq_len, :])

In [None]:
def get_next_k_word(model, input_text, k):
    """Return the ``k`` highest-scoring next-word candidates, best first.

    Relies on the module-level ``seq_length``, ``_w2i``, ``_i2w`` and
    ``device``. The model is not switched to eval mode here; the caller is
    expected to do that.

    Args:
        model: Callable returning logits indexed as ``[batch, position, vocab]``.
        input_text: List of context tokens (strings); only the last
            ``seq_length`` are used. Unknown tokens map to ``<UNK>``.
        k: Number of candidates requested; capped at the vocabulary size.

    Returns:
        list[str]: candidate words ordered by descending logit.
    """
    tokens = input_text[-seq_length:]
    unk_id = _w2i['<UNK>']
    input_ids = torch.tensor(
        [[_w2i.get(word, unk_id) for word in tokens]],
        dtype=torch.long,
        device=device,
    )
    # Pure inference: skip autograd bookkeeping for every prediction.
    with torch.no_grad():
        logits = model(input_ids)[0, -1]
    # topk raises if k exceeds the number of logits, so cap it.
    k = min(k, logits.size(-1))
    # topk already returns its results sorted in descending order, so the
    # previous explicit re-sort was redundant.
    top_k_indices = torch.topk(logits, k).indices
    return [_i2w[index] for index in top_k_indices.tolist()]

In [None]:
def evaluate_options(target_word, options, nr_of_suggestions=5):
    """Score one prediction against the word the user actually wanted.

    ``options`` is the model's ranked candidate list. If ``target_word`` is
    already among the first ``nr_of_suggestions`` entries, every keystroke of
    the word is saved. Otherwise the candidates are narrowed one typed
    character at a time (keeping only words sharing the typed prefix) until
    the target appears among the first ``nr_of_suggestions`` survivors.

    Returns:
        tuple[int, int]: ``(found_word, saved_keystrokes)`` where
        ``found_word`` is 1 if the target was ever suggested, else 0, and
        ``saved_keystrokes`` is the number of characters that did not have to
        be typed (0 when the target was never suggested).
    """
    not_found = (0, 0)
    if target_word not in options:
        return not_found

    word_length = len(target_word)
    if target_word in options[:nr_of_suggestions]:
        return 1, word_length

    remaining = options
    for typed in range(word_length + 1):
        prefix = target_word[:typed]
        remaining = [candidate for candidate in remaining if candidate[:typed] == prefix]
        if target_word in remaining[:nr_of_suggestions]:
            return 1, word_length - typed
    return not_found


def evaluate_text(model, tokens, neighbors=100):
    """Print how many keystrokes and words the model's suggestions would save on ``tokens``.

    For every token after the first, the model is asked for ``neighbors``
    candidates given up to ``seq_length`` preceding tokens (module-level
    ``seq_length``). The candidates are scored with ``evaluate_options`` using
    5 visible suggestions. The model is put in eval mode. Nothing is returned;
    the two percentages are printed.

    Args:
        model: Model passed through to ``get_next_k_word``.
        tokens: List of string tokens to evaluate on.
        neighbors: Number of candidates requested per prediction.
    """
    saved_keystrokes = 0
    total_keystrokes = 0
    found_words = 0
    total_words = 0
    model.eval()
    # Evaluation only: no gradients are needed for any prediction.
    with torch.no_grad():
        for index, word in enumerate(tokens[1:], 1):
            # Context is a fresh slice of at most seq_length preceding tokens.
            # The old code appended `index` to this slice afterwards, which
            # had no effect and has been removed.
            context = tokens[max(0, index - seq_length):index]
            top_words = get_next_k_word(model, context, neighbors)
            f_w, s_k = evaluate_options(word, top_words, nr_of_suggestions=5)
            saved_keystrokes += s_k
            found_words += f_w
            total_words += 1
            total_keystrokes += len(word)

    # Fewer than two tokens leaves nothing to predict; avoid dividing by zero.
    if total_words == 0 or total_keystrokes == 0:
        print('Nothing to evaluate: at least two tokens are required.')
        return

    print(f'Saved keystroke percentage: {100*saved_keystrokes/total_keystrokes:.2f}%')
    print(f'Found words percentage: {100*found_words/total_words:.2f}%')


In [None]:
from tqdm import tqdm

def train(model, nr_of_epochs, lr, dataloader):
    """Train ``model`` for next-token prediction and return the per-epoch mean loss.

    Each batch is moved to the module-level ``device``. Only the logits of
    the last sequence position are compared with the targets using
    cross-entropy, and Adam with learning rate ``lr`` updates the parameters.

    Returns:
        list[float]: mean batch loss of every epoch, in order.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    epoch_losses = []

    for epoch in range(nr_of_epochs):
        model.train()
        running_loss = 0
        for batch_inputs, batch_targets in tqdm(dataloader):
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            optimizer.zero_grad()
            # Keep only the prediction made at the final position.
            last_step_logits = model(batch_inputs)[:, -1, :]
            batch_loss = criterion(last_step_logits, batch_targets)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()

        mean_loss = running_loss / len(dataloader)
        print(f'Epoch {epoch + 1}, Loss: {mean_loss}')
        epoch_losses.append(mean_loss)

    return epoch_losses

# --- Hyperparameters -------------------------------------------------------
seq_length = 7
embed_size = 256
num_heads = 8
multiplier = 4
num_layers = 4
hidden_dim = multiplier * embed_size
lr = 0.0001
vocab_size = len(_w2i)

# --- Training data ---------------------------------------------------------
traindataset = TextDataset(encoded_text, seq_length)
traindataloader = DataLoader(traindataset, batch_size=32, shuffle=True)

val_loss = list()
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Held-out text (note: rebinds TextProc and tokens) ----------------------
TextProc = TextProcessor('test_file.txt')
tokens = TextProc.process_text()
temp = list()

# --- Train, then report suggestion quality on the held-out text ------------
model = TransformerModel(vocab_size, embed_size, num_heads, num_layers, hidden_dim).to(device)
val_loss = train(model, nr_of_epochs=4, lr=lr, dataloader=traindataloader)
evaluate_text(model, tokens, 100)


In [None]:
# Persist the weights together with the vocabulary, the hyperparameters
# needed to rebuild the model, and the evaluation tokens.
torch.save(
    dict(
        model_state_dict=model.state_dict(),
        w2i=_w2i,
        i2w=_i2w,
        embedding_dim=embed_size,
        hidden_size=hidden_dim,
        drop_out=0.1,
        num_layers=num_layers,
        num_heads=num_heads,
        vocab_size=vocab_size,
        sequence_length=seq_length,
        eval_text=tokens,
    ),
    "model.pth",
)

In [None]:

def load_model(file_path):
    """Rebuild a ``TransformerModel`` from a checkpoint written by this notebook.

    The checkpoint must contain the keys ``model_state_dict``, ``w2i``,
    ``i2w``, ``embedding_dim``, ``hidden_size``, ``num_layers``,
    ``num_heads``, ``vocab_size`` and ``drop_out``. The model is created on
    the module-level ``device``.

    Args:
        file_path: Path of the checkpoint file.

    Returns:
        tuple: ``(model, w2i, i2w)``, the restored model and both vocabulary maps.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    # map_location makes a checkpoint saved on a GPU loadable on a CPU-only
    # machine (and vice versa).
    checkpoint = torch.load(file_path, map_location=device)

    model = TransformerModel(
        checkpoint['vocab_size'],
        checkpoint['embedding_dim'],
        checkpoint['num_heads'],
        checkpoint['num_layers'],
        checkpoint['hidden_size'],
        # The stored dropout was previously read but never passed to the model.
        dropout=checkpoint['drop_out'],
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model, checkpoint['w2i'], checkpoint['i2w']

# Context length used when slicing evaluation windows.
seq_length = 7
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Restore the trained model and its vocabulary from disk.
(model, _w2i, _i2w) = load_model("model.pth")

# Tokenize the held-out text and report suggestion quality on it.
TextProc = TextProcessor('test_file.txt')
tokens = TextProc.process_text()

evaluate_text(model, tokens, 100)
