In [12]:
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import numpy as np
import pandas as pd
import string
from nltk import word_tokenize
import re
punct = string.punctuation
from nltk.stem import WordNetLemmatizer

# Single WordNet lemmatizer instance, reused for every token in flat_input_gen.
lemmatizer = WordNetLemmatizer()
# NOTE(review): absolute, machine-specific path; this cell only runs where that
# exact file exists. Later cells read the 'Character' and 'Line' columns.
df = pd.read_csv('/Users/gauneg/datasets/fun_dsets/sourth_park_ds.csv')

### helper functions:
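1. `context_ngrams`: builds every (context, word) pair for a tokenized input sentence
2. `flat_input_gen`: strips punctuation, tokenizes and lemmatizes each sentence, then applies `context_ngrams` to the entire input dataset
3. `create_batch`: splits the (context, word) pairs into batches for learning
4. `separate_x_y_encode`: splits a batch into contexts and target words and encodes both as index tensors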

In [13]:

def context_ngrams(word_arr, context_size, mode="uni_dir"):
    """Build every (context, target) pair for one tokenized sentence.

    Args:
        word_arr: Sliceable sequence of tokens (e.g. a list of strings).
        context_size: Number of tokens taken as context on each used side.
        mode: "uni_dir" uses the ``context_size`` tokens before the target;
            "bi_dir" uses ``context_size`` tokens before and after it.

    Returns:
        A list of ``[context, target]`` pairs. A sentence too short to supply
        a full context yields an empty list.

    Raises:
        ValueError: If ``mode`` is neither "uni_dir" nor "bi_dir". Previously
            an unknown mode silently produced truncated right-hand contexts.
    """
    if mode not in ("uni_dir", "bi_dir"):
        raise ValueError(
            "mode must be 'uni_dir' or 'bi_dir', got {!r}".format(mode)
        )

    bidirectional = mode == "bi_dir"
    # In bi_dir mode stop early so every target still has a full right context.
    stop = len(word_arr) - context_size if bidirectional else len(word_arr)

    pairs = []
    for pos in range(context_size, stop):
        context = word_arr[pos - context_size: pos]
        if bidirectional:
            context = context + word_arr[pos + 1: pos + context_size + 1]
        pairs.append([context, word_arr[pos]])
    return pairs

def flat_input_gen(inp_text_list, context_function):
    """Turn raw sentences into one flat list of context outputs.

    Each sentence has its punctuation characters removed, is tokenized with
    ``word_tokenize`` and lemmatized token by token; ``context_function`` is
    then called on the resulting token list and everything it returns is
    collected into a single list.

    Args:
        inp_text_list: Iterable of sentence strings.
        context_function: Callable taking a list of tokens and returning an
            iterable of items to collect.

    Returns:
        A flat list with the items produced for every sentence, in order.
    """
    collected = []
    for raw_sentence in inp_text_list:
        # Drop punctuation character by character before tokenizing.
        cleaned = "".join(ch for ch in raw_sentence if ch not in punct)
        lemmas = [lemmatizer.lemmatize(tok) for tok in word_tokenize(cleaned)]
        collected.extend(context_function(lemmas))
    return collected

def create_batch(flat_arr, batch_size):
    """Yield consecutive slices of ``flat_arr`` holding up to ``batch_size`` items.

    Every batch has exactly ``batch_size`` items except possibly the last one,
    which holds the remainder. No empty batch is ever yielded: the previous
    version emitted a trailing ``[]`` whenever ``len(flat_arr)`` was a multiple
    of ``batch_size`` (including for empty input).

    Args:
        flat_arr: Sliceable sequence of examples.
        batch_size: Positive number of examples per batch.

    Yields:
        Slices of ``flat_arr`` in order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (raised when iteration
            starts, since this is a generator).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # A stepped range covers the full batches and the remainder in one pass.
    for start in range(0, len(flat_arr), batch_size):
        yield flat_arr[start: start + batch_size]

def separate_x_y_encode(xy_comb, word_to_ix):
    """Split (context, word) pairs and encode both sides as index tensors.

    Tokens missing from ``word_to_ix`` are mapped to the index of ``'<unk>'``.

    Args:
        xy_comb: Iterable of ``(context, word)`` pairs, where ``context`` is an
            iterable of tokens and ``word`` is a single token.
        word_to_ix: Mapping from token to integer index; must contain the key
            ``'<unk>'``.

    Returns:
        A tuple ``(contexts, words)`` of ``torch.long`` tensors. The dtype is
        set explicitly so that an empty batch still produces integer tensors
        (``torch.tensor([])`` would otherwise default to float).

    Raises:
        KeyError: If ``word_to_ix`` has no ``'<unk>'`` entry.
    """
    # Look the fallback index up once instead of on every token.
    unk_ix = word_to_ix['<unk>']

    ctx = []
    words = []
    for context, word in xy_comb:
        ctx.append([word_to_ix.get(token, unk_ix) for token in context])
        words.append(word_to_ix.get(word, unk_ix))
    return (
        torch.tensor(ctx, dtype=torch.long),
        torch.tensor(words, dtype=torch.long),
    )

In [14]:
# Sizes shared by the context builder and the model constructor.
CONTEXT_SIZE = 2
EMBEDDINGS_DIM = 10

# Keep only Cartman's rows, then trim and lowercase each of his lines.
cartman_df = df.loc[df['Character'] == 'Cartman']
cartman_lines = [line.strip().lower() for line in cartman_df['Line'].values]
crt_dry = cartman_lines


def part_applied_context(txt_arr):
    """Call ``context_ngrams`` on ``txt_arr`` with the notebook's CONTEXT_SIZE."""
    return context_ngrams(txt_arr, CONTEXT_SIZE)

In [15]:
# Select the vocabulary and batch the training pairs.

X = flat_input_gen(crt_dry, part_applied_context)

# Only context tokens are counted here; the target word of each pair is not.
vocab = [token for context, _target in X for token in context]

# Case-insensitive frequency count.
word_freq = {}
for word in vocab:
    word_freq[word.lower()] = word_freq.get(word.lower(), 0) + 1

word_frx = list(word_freq.items())
sort_vocab = sorted(word_frx, key=lambda pair: pair[1], reverse=True)

# The 1600 most frequent tokens get indices 1..1600; index 0 is kept for '<unk>'.
word_to_ix = {
    token: rank for rank, (token, _count) in enumerate(sort_vocab[:1600], start=1)
}
word_to_ix['<unk>'] = 0
ix_to_word = {index: token for token, index in word_to_ix.items()}

batched_inps = list(create_batch(X, 32))

In [16]:
class NGramLanguageModelerBatched(nn.Module):
    """Feed-forward n-gram language model that scores batches of contexts.

    Each context token is embedded, the embeddings of one context are
    flattened into a single vector, passed through one ReLU hidden layer and
    projected to log-probabilities over the vocabulary.
    """

    def __init__(self, vocab_size, embedding_dim, context_size, hidden_dim=128) -> None:
        """Create the embedding table and the two linear layers.

        Args:
            vocab_size: Number of distinct token indices (embedding rows and
                output classes).
            embedding_dim: Size of each token embedding.
            context_size: Number of context tokens per example.
            hidden_dim: Width of the hidden layer. Defaults to 128, the value
                that used to be hard-coded.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.context_size = context_size
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        # The attribute names `liner1`/`liner2` are kept unchanged because they
        # are part of the module's parameter names.
        self.liner1 = nn.Linear(context_size * embedding_dim, hidden_dim)
        self.liner2 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, inputs):
        """Return log-probabilities over the vocabulary for each context.

        Args:
            inputs: Integer tensor of token indices; its embeddings are
                reshaped to rows of ``embedding_dim * context_size`` values.

        Returns:
            Tensor with one row of ``vocab_size`` log-probabilities per
            flattened context (log-softmax over dim 1).
        """
        flat_embeds = self.embeddings(inputs).view(
            (-1, self.embedding_dim * self.context_size)
        )
        hidden = F.relu(self.liner1(flat_embeds))
        logits = self.liner2(hidden)
        return F.log_softmax(logits, dim=1)
        

In [17]:
# Seed PyTorch's global RNG before the model is built so the randomly
# initialised weights are the same on every fresh kernel run.
SEED = 0
torch.manual_seed(SEED)

# NLLLoss consumes log-probabilities, which is what the model's forward returns.
loss_function = nn.NLLLoss()
model = NGramLanguageModelerBatched(len(word_to_ix), EMBEDDINGS_DIM, CONTEXT_SIZE)
optimizer = optim.SGD(model.parameters(), lr=0.005)

In [18]:
# Encode every batch once up front: the encoding does not change between
# epochs, so doing it inside the epoch loop repeated identical work 100 times.
# Empty batches are skipped because they carry no training examples.
encoded_batches = [
    separate_x_y_encode(batch, word_to_ix)
    for batch in batched_inps
    if len(batch) > 0
]

for epoch in range(100):
    total_loss = 0
    for enc_x, enc_y in encoded_batches:
        # Clear gradients left over from the previous step.
        model.zero_grad()
        log_probs = model(enc_x)
        loss = loss_function(log_probs, enc_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Sum of the per-batch losses for this epoch.
    print(total_loss)


22949.436532497406
20674.334274291992
19984.43032836914
19652.853738069534
19443.432136058807
19288.52066373825
19163.700829982758
19057.56549358368
