Character level text generator with a single LSTM cell. My goal was to make sure that it works and that it learns, while at the same time training myself on a few more software engineering tricks. It appears to be able to learn phrases.

In [237]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.io import read_image
from torchvision.transforms import ToTensor, Lambda
import pandas as pd
import os
import torch.optim as optim
import random

In [238]:
#Dictionary class with torch.tensor encoding and decoding methods
class My_dictionary(object):
  def __init__(self, text):
    self.alphabet = list(set(text))
    self.alphabet_dictionary = {}
    for k in range(len(self.alphabet)):
      self.alphabet_dictionary[self.alphabet[k]] = torch.zeros(len(self.alphabet)).scatter_(0, torch.tensor(k), 1 )



  def text_encoding(self, text):
    data=[]
    for char in text:
     data.append(self.alphabet_dictionary[char])  
    input_tensor = torch.stack(tuple(data), dim =0)
    return input_tensor



  def text_decoding(self, input):
    text_output= ""
    input_det = input.detach()    
    for k in range(len(input_det)):
      inp_list = input_det[k].tolist()
      probability = [elem/sum(inp_list) for elem in inp_list] #This renormalization of probability is needed due to a truncation error in going from torch.tensors to lists. It is such a small difference that is irrelevant.
      sampled_char = np.random.choice(list(self.alphabet_dictionary.keys()), p= probability)
      text_output += sampled_char 
    return text_output
  




#Building an iterator over minibatches of inputed text after tensor encoding
class Dataloader_iter(object):
  def __init__(self, text, dictionary, batch_size, shuffle_batch=False):
    self.batch_size = batch_size
    self.shuffle_batch = shuffle_batch
    self.length= len(text)
    self.text = text

    self.dictionary = dictionary
    self.encoding = self.dictionary.text_encoding(text)

  

  def get_batches(self):
    bs= self.batch_size
    input, expectation = self.encoding[:-1], self.encoding[1:]
    minibatches = [(input[k*bs:(k+1)*bs], expectation[k*bs : (k+1)*bs]) for k in range(int(self.length/bs))]

    if (self.shuffle_batch):
      random.shuffle(minibatches)
    
    self.minibatches = minibatches
  

  def __iter__(self):
    self.indx =0
    self.get_batches()
    return self


  def __next__(self):
    if (self.indx >= int(self.length/self.batch_size)):
      raise StopIteration
    self.indx +=1
    return self.minibatches[self.indx -1]





In [239]:
#LSTM cell

class Lstm_cell(nn.Module):
  def __init__(self, cell_size, inp_size):
    super().__init__()
    self.cell_size = cell_size
    self.inp_size = inp_size

    self.memory_state= torch.zeros(cell_size)
    self.cell= torch.zeros(cell_size)

    self.fgate = nn.Linear(cell_size + inp_size, cell_size, bias=True) #forget gate
    self.ingate = nn.Linear(cell_size + inp_size, cell_size, bias=True) #in gate
    self.intomem = nn.Linear(cell_size + inp_size, cell_size, bias=True)  #input layer

    self.outgate = nn.Linear(cell_size + inp_size, cell_size, bias=True) #out gate
    self.celltoout = nn.Linear(cell_size, inp_size, bias=False) #output layer

    self.activation = nn.Sigmoid() #I seem to be getting better information propagation with the sigmoid as compared to relu
    self.amplifier= torch.tensor(2.) #rescales some activations below


  #reseting memory and cell state
  def forget_everything(self):
    self.memory_state =torch.zeros(self.cell_size)
    self.cell =torch.zeros(self.cell_size)

  
  def tanh_act(self,x):
    return torch.tanh(x)



  def forward(self, inp):
    #removing input+cell state gradient for truncated backpropagation algorithm
    cell_trunc = self.cell.detach()
    inp_trunc = inp.detach()
    #total cell input
    y = torch.cat((cell_trunc, inp_trunc), dim = 0)
    

    #in-gate activation
    write_control= self.activation(self.ingate(y))
    #write-in-memory activation
    write_input = self.amplifier*self.tanh_act(self.intomem(y))
    

    #forget-and-write memory state update
    regulator= 1. #adding the option of renormalizing the memory state to prevent it from saturating the tanh activation
    self.memory_state =  ((write_control * write_input) + self.memory_state*self.activation(self.fgate(y)))/regulator 


    #out-gate activation
    read_control = self.activation(self.outgate(y))
    #read-from-memory activation
    read_output= self.tanh_act(self.memory_state) #renormalization of output in order to prevent softmax blow-ups --value chosen by trial and error

    #total cell output
    out = read_control * read_output*self.amplifier 
    self.cell = out
    out = self.celltoout(self.cell)
    
    return out

In [240]:
class Network(nn.Module):
  def __init__(self, cells, activation= nn.Softmax(dim=1)):
    super().__init__()
    self.cells = cells
    self.loss = nn.CrossEntropyLoss()
    self.softmax= activation
  

  #forward method
  def forward(self, text, activate=True, forget= True):
    if forget: self.cells.forget_everything()
    output = (self.cells.forward(text[0])).unsqueeze(0)

    for elem in range(1,len(text)):
      out = self.cells.forward(text[elem])
      output = torch.cat((output, out.unsqueeze(0)), dim = 0)   
    if activate: output= self.softmax(output)
    return output

  

  #Evaluate progress
  def evaluate(self, test_batch):
    self.cells.forget_everything()
    prediction = self.forward(test_batch[0]).detach()
    accuracy= (((torch.argmax(test_batch[1],dim=1)==torch.argmax(prediction, dim=1))).sum())/len(test_batch[0])
    return accuracy
  

  #generate text --method 1: Picking most likely element at every step
  def generate_max(self, start, length, forget=True):
    if forget: self.cells.forget_everything()
    idx = torch.argmax(self.cells.forward(start).detach())
    out = torch.zeros_like(start)
    out[idx] =torch.tensor(1.)
    output=out.unsqueeze(0)

    for elem in range(length-1):
      idx = torch.argmax(self.cells.forward(out))
      out = torch.zeros_like(start)
      out[idx] =torch.tensor(1.)
      output = torch.cat((output, out.unsqueeze(0)), dim = 0)  
    output = torch.cat((start.unsqueeze(0), output), dim = 0) 
    return output


  #generate text --method 2: Using output probabilities for random character sampling at every step
  def generate_prob(self, start, length, forget=True):
    if forget: self.cells.forget_everything()

    out = start
    output = start.unsqueeze(0) 

    for k in range(length):
      out = self.cells.forward(out)
      out= self.softmax(out.unsqueeze(0))

      probability =(out.squeeze()).tolist()
      probability = [elem/sum(probability) for elem in probability]
      idx = np.random.choice(len(start), p= probability)
      out =torch.zeros_like(start)
      out[idx] =1
      
      output = torch.cat((output, out.unsqueeze(0)), dim = 0) 
    return output

  


In [246]:
#Constructing the model, data

def get_data(text, bs):
  dictionary = My_dictionary(text)
  train_dataloader = Dataloader_iter(text, My_dictionary(text), batch_size=bs, shuffle_batch=True)
  return dictionary, train_dataloader



def get_model(cell_size, inp_size, lr):
  model = Network(Lstm_cell(cell_size,inp_size))
  opt = optim.SGD(model.cells.parameters(), lr)
  return model, opt


In [242]:
#Text generator
def generate_text(model, dictionary, start, length, forget=True):
  start_vec = dictionary.alphabet_dictionary[start]
  text_instance_encoded = model.generate_prob(start_vec, length, forget)
  text_instance = dictionary.text_decoding(text_instance_encoded)
  print(text_instance)

In [243]:
#training function
def fit(epochs, model, opt, dictionary, training_dataloader, test_data=None, length=None, start=None, forget=True):
  for epoch in range(epochs):
    train_data_iter = iter(training_dataloader)
    for input_batch, expect_batch in train_data_iter:
      #training method for text minibatch
      total_loss =0
      #reset memory and cell state
      model.cells.forget_everything()
      #forward pass
      out = model.forward(input_batch, activate=False)
      #backward pass
      loss = model.loss(out , torch.argmax(expect_batch, dim=1))
      total_loss += loss
      loss.backward()
      #print(f"{[(torch.min(p.grad), torch.max(p.grad)) for p in model.parameters() ]}")

      #update text sequence minibatch    
      opt.step()
      opt.zero_grad()
    #printing out some data to make sure activations are not blowing up during training
    print(f"memory: {(torch.min(model.cells.memory_state), torch.max(model.cells.memory_state))} \n softmax: {model.softmax(model.cells.cell.unsqueeze(0))}")

    
    #compute loss on test data
    if test_data:
      test_data_iter = iter(test_data)
      accuracy=0
      with torch.no_grad():
        for exp, pred in test_data_iter:
          accuracy += model.evaluate((exp, pred))
      print(f"Epoch {epoch} achieved accuracy {accuracy/len(test_data_iter)}")
    
    #generate new text
    if (length is not None and start is not None):
      with torch.no_grad():
        generate_text(model, dictionary, start, length, forget)


In [244]:
example ="i think i may have actually managed to make this thing train! great news. now i can go on to study the transformer at last. "


In [247]:
dictionary, train_dataloader= get_data(example*500, 20)

In [248]:
model, opt =get_model(2*len(dictionary.alphabet), len(dictionary.alphabet), 0.1)

In [249]:
generate_text(model, dictionary, "w", 100)

w.moygmnakndyw.ftaoes!lovitvawg!mkmtrwytoitakdfled!dagknkfdgya.dtylfdkhfghhsruhaavrfifg.tkcnt!tyv!gdg


In [254]:
model.cells.forget_everything()
for char in dictionary.alphabet: generate_text(model, dictionary, char, 100) 

ve actually managed to make this thing train! grean nos tudy to in thi k yan go oran gormer at last. 
o thin think great new lcan so mau!y train! great new sran grean new i can go on to study the transfo
 thing train! great new in tuaink gret at last. i think i may have actually managed to make this thin
this thing train! great new sonmar at last. i think i may have actually managed to make this thing tr
ke trans. trat news. i this thing train! great news. now i can go on to study the transformer at last
ged to make this thing train! great new tcin think train! great news. now i can go on to study the tr
news. now i can go on to  thdy the transformer at last. i think i may have actually managed to make t
former at last. i think i may have actually managed to make this this this thinggrratraat new traistf
w i can go on to study the transformer at last. i think i may have actually managed to make this thin
ink i may have actually managed to make this thing train! great news. now i can go

In [252]:
inp = dictionary.text_encoding(example)
text_output= model.forward(inp)
dictionary.text_decoding(text_output)


'ckd!g ggahks  ihromfwofegoss rckatudttohmddm wf a!mf!wdcrdwgmsmdt!vwwckng owuy yvmrkms.rmkrutld!mnfdgakags ahoeshs!yg! gghrs'

In [253]:
model.cells.forget_everything()
fit(100,model,opt,dictionary,train_dataloader,test_data =None, length= 100, start= "h")

memory: (tensor(-2.5514, grad_fn=<MinBackward1>), tensor(2.4136, grad_fn=<MaxBackward1>)) 
 softmax: tensor([[0.0032, 0.0021, 0.0258, 0.0147, 0.0017, 0.0408, 0.0035, 0.0039, 0.0636,
         0.0238, 0.0043, 0.0269, 0.0040, 0.0297, 0.0024, 0.0713, 0.0428, 0.0229,
         0.0062, 0.0020, 0.0037, 0.0427, 0.0021, 0.0017, 0.0030, 0.0016, 0.0487,
         0.0171, 0.0220, 0.0485, 0.0507, 0.0017, 0.0509, 0.0295, 0.0360, 0.0199,
         0.0473, 0.0380, 0.0245, 0.0369, 0.0024, 0.0057, 0.0018, 0.0521, 0.0085,
         0.0074]], grad_fn=<SoftmaxBackward0>)
he transformer at last. i think i may have actuall make acage actually managed to make thing train! g
memory: (tensor(-2.9552, grad_fn=<MinBackward1>), tensor(3.2677, grad_fn=<MaxBackward1>)) 
 softmax: tensor([[0.0025, 0.0187, 0.0276, 0.0045, 0.0088, 0.0056, 0.0289, 0.0243, 0.0030,
         0.0033, 0.0325, 0.0027, 0.0068, 0.0023, 0.0590, 0.0835, 0.0075, 0.0199,
         0.0064, 0.0062, 0.0620, 0.0064, 0.0051, 0.0025, 0.0146, 0.0038, 0.0230,
 

KeyboardInterrupt: ignored