In [1]:
import torch
from torch import tensor, cat
from torch.nn import functional as F

In [105]:
# Functions for Bigram 27 neuron single layer implementation.

class Ngram_nn:
    
    def __init__(self, ngram = 2):
        self.ngram = ngram
        self.weight_matrix = torch.randn((27*(self.ngram-1),27), requires_grad=True) #each column of the weight matrix corresponds to one of the 27 neurons of the single layer
        self.one_hot_encoding = torch.nn.functional.one_hot(torch.arange(0, 27), num_classes=27).float() # each row of the encoding matrix corresponds to a single character.
        print(self.one_hot_encoding.shape)
        self.loss = None
        self.learning_rate = 30
        

        

    def get_name_list_with_st(self): #st - special token
        words = open('names.txt', 'r').read().splitlines()
        words_st = [] #st - special token
        for w in words:
            words_st.append("."+ w + ".")
        print('retrieved name list')
        return words_st


    def get_indices(self,words_st):
        stoi = dict(zip(sorted(set(''.join(words_st))),list(range(27))))
        itos = {value: key for key, value in stoi.items()}

        print('retrieved indices')
        return stoi, itos

    def get_encoding(self, words_st, stoi, itos):
        print('encoding started')
        xs = torch.empty((0,27*(self.ngram-1)), dtype=torch.float32)
        ys = torch.empty((0,1),dtype=torch.int64)

            

        for w in words_st:

            for i in range(len(w) - self.ngram + 1):
                input_ngram = w[i:i + self.ngram-1]
                target_char = w[i + self.ngram-1] if i + self.ngram-1 < len(w) else None
                
                
                if target_char is not None:
                    input_encoding = torch.empty((1,27*(self.ngram-1)), dtype=torch.float32)
                    i=0

                    for ch in input_ngram:

                        input_encoding[0,(27*i):(27*(i+1))] = self.one_hot_encoding[stoi[ch]].unsqueeze(0)
                        i+=1

                    xs = torch.cat((xs, input_encoding), dim=0)
                    
                    # Append the target character to ys
                    ys = torch.cat((ys, torch.tensor(stoi[target_char]).unsqueeze(0).unsqueeze(0)), dim=0)




        
        print('retrieved encodings')
        return xs, ys

    def forward_with_loss(self,xs,ys):
        layer1_output = xs @ self.weight_matrix # each row of the layer1_output corresponds to the probability vector of the next character for each of the characters corresponding to the rows in xs.
        layer1_output_probs = F.softmax(layer1_output, dim=1)
        # print(layer1_output.shape)
        self.loss = -torch.mean(torch.log(torch.gather(layer1_output_probs, dim = 1, index=ys))) #loss is average negative log likelihood
        # print(self.loss.item())
        
    
    def forward(self,xs):

        layer1_output = xs @ self.weight_matrix # each row of the layer1_output corresponds to the probability vector of the next character for each of the characters corresponding to the rows in xs.
        layer1_output_probs = F.softmax(layer1_output, dim=1)
        
        return layer1_output_probs

    def print_loss(self):
        print(self.loss.item())
        

    
    def backward(self):
        self.loss.backward()
        
    def update_weights(self):
        with torch.no_grad():
            self.weight_matrix -= self.learning_rate * self.weight_matrix.grad

        _ = self.weight_matrix.grad.zero_()
    
    def train(self, epochs, xs, ys, stoi, itos):

        print('Learning rate:\n')
        for _ in range(epochs):
            self.forward_with_loss(xs,ys)
            self.print_loss()
            self.backward()
            self.update_weights()

        
    def generate_bigram_nn_names(self, no_of_names_to_generate):
        start_token = 0
        gen_name = ''
        
        g = torch.Generator().manual_seed(2147483647)

        for _ in range(no_of_names_to_generate): 
            next_idx = start_token
            character_encoding = self.one_hot_encoding[next_idx].unsqueeze(0)
            character_encoding = torch.cat([character_encoding] * (self.ngram-1), dim=1)
            

            while True:

                character_encoding[0,0:((self.ngram-2)*27)] = character_encoding[0,27:((self.ngram-1)*27)].clone()

                character_encoding[0,-27:] = self.one_hot_encoding[next_idx].unsqueeze(0)

                layer1_output_probs = self.forward(character_encoding)
                
                

                next_idx = torch.multinomial(layer1_output_probs, num_samples=1, replacement=True, generator=g).item()
                
                if(next_idx == 0):
                    break 
                gen_name += itos[next_idx]
            
            print(gen_name)
            gen_name = ''
    



        
        
        
    


In [111]:
#Generate encodings.

ngram_nn_obj = Ngram_nn(3)
words_st = ngram_nn_obj.get_name_list_with_st()
stoi, itos = ngram_nn_obj.get_indices(words_st)

torch.Size([27, 27])
retrieved name list
retrieved indices


In [112]:
xs, ys = ngram_nn_obj.get_encoding(words_st, stoi, itos)

encoding started
retrieved encodings


In [109]:
#Train the model.
epochs = 5000
ngram_nn_obj.train(epochs, xs, ys, stoi, itos)

Learning rate:

2.183405637741089
2.1834053993225098
2.1834053993225098
2.1834051609039307
2.1834049224853516
2.1834049224853516
2.1834046840667725
2.1834046840667725
2.1834044456481934
2.1834044456481934
2.1834042072296143
2.183403968811035
2.183403968811035
2.183403968811035
2.183403491973877
2.183403253555298
2.183403253555298
2.1834030151367188
2.1834030151367188
2.1834030151367188
2.1834030151367188
2.1834025382995605
2.1834025382995605
2.1834022998809814
2.1834020614624023
2.1834020614624023
2.1834020614624023
2.1834018230438232
2.183401584625244
2.183401584625244
2.183401346206665
2.183401107788086
2.183401107788086
2.183400869369507
2.183400869369507
2.1834006309509277
2.1834006309509277
2.1834003925323486
2.1834003925323486
2.1834001541137695
2.1833999156951904
2.1833999156951904
2.1833996772766113
2.1833994388580322
2.1833994388580322
2.183399200439453
2.183399200439453
2.183399200439453
2.183399200439453
2.183398723602295
2.183398485183716
2.183398485183716
2.183398246765136

In [110]:
ngram_nn_obj.generate_bigram_nn_names(20)

wel
ryanalluraila
vey
wellin
wyna
wyllayn
wanda
wwanthellansi
vetti
welie
wwyo
verted
wen
wyle
vedgu
veavirny
wyls
wwinn
wwytahlas
wysor
