In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np

In [31]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [32]:
device

device(type='cuda')

In [33]:
def count_true(a1, a2):
  e = torch.eq(a1, a2)
  e = e.cpu().numpy()
  c = np.count_nonzero(e)
  return c

In [34]:
def preprocess():
   # load the text file
    data = open("dataset.txt", 'r').read()
    chars = sorted(list(set(data)))
    data_size, vocab_size = len(data), len(chars)
    
    # char to index and index to char maps
    char_to_ix = { ch:i for i,ch in enumerate(chars) }
    ix_to_char = { i:ch for i,ch in enumerate(chars) }
    
    # convert data from chars to indices
    data = list(data)
    for i, ch in enumerate(data):
        data[i] = char_to_ix[ch]
    # data tensor on device
    data = torch.tensor(data).to(device)
    data = torch.unsqueeze(data, dim=1)
    one_hot_encoded = F.one_hot(data).float()
    
    return data , ix_to_char , data_size, vocab_size, one_hot_encoded

In [35]:
class RNN(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, num_layers):
        super(RNN, self).__init__()
        self.embedding = nn.Embedding(input_size, input_size)
        self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, input_seq, hidden_state):
        embedding = self.embedding(input_seq)
        output, hidden_state = self.rnn(embedding, hidden_state)
        output = self.fc(output)
        return output, (hidden_state[0].detach(), hidden_state[1].detach())

In [36]:
def test(data,data_size,rnn,ix_to_char):

        data_ptr = 0
        hidden_state = None
        # random character
        rand_index = np.random.randint(data_size-1)
        input_seq = data[rand_index : rand_index+1]
        
        for i in range(400):
          
            # forward pass
            output, hidden_state = rnn(input_seq, hidden_state)
            
            # construct categorical distribution and sample a character
            output = F.softmax(torch.squeeze(output), dim=0)
            dist = Categorical(output)
            index = dist.sample()
            
            # print the sampled character
            print(ix_to_char[index.item()], end='')
            
            # next input is current output
            input_seq[0][0] = index.item()
            data_ptr += 1
            

In [37]:
def train(data , ix_to_char, data_size, vocab_size,rnn,epochs,seq_len,loss_fn,optimizer):
    acc = []
    loss_list = []
    for i_epoch in range(1, epochs+1):
        
        # random starting point (1st 100 chars) from data to begin
        data_ptr = np.random.randint(100)
        n = 0
        running_loss = 0
        hidden_state = None
        true_predicts = 0
        acc = []
        loss_list = [] 
        while True:
            input_seq = data[data_ptr : data_ptr+seq_len]
            target_seq = data[data_ptr+1 : data_ptr+seq_len+1]
            #print("target_seq : ",torch.squeeze( target_seq))
            
            # forward pass
            output, hidden_state = rnn(input_seq, hidden_state)
            
            # compute loss
            loss = loss_fn(torch.squeeze(output), torch.squeeze(target_seq))
            running_loss += loss.item()
            output = F.softmax(torch.squeeze(output), dim=0)
            #print("output : ", output)
            dist = Categorical(output)
            #print("dist : ", dist)
            index = dist.sample()
            #print("index : ", index)
            # compute gradients and take optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # compute true predicts
            true_predicts += count_true(torch.squeeze( target_seq), index)
            
            # update the data pointer
            data_ptr += seq_len
            n +=1
            
            # if at end of data : break
            if data_ptr + seq_len + 1 > data_size:
                break
        acc.append(true_predicts*100/n)  
        loss_list.append(running_loss/n)  
        # print loss and save weights after every epoch
        print("Epoch: {0} \t Loss: {1:.4f} \t accuracy: {2:.4f}".format(i_epoch, running_loss/n,true_predicts*100/data_size))
        #torch.save(rnn.state_dict(), save_path)
        
    return acc,loss_list


In [38]:
def runn(hidden_size, seq_len, num_layers, lr, epochs):
  # Hyperparameters
  hidden_size = hidden_size  
  # length of LSTM sequence 
  seq_len = seq_len
  # num of layers in LSTM layer stack      
  num_layers =num_layers      
  lr = lr   
  epochs = epochs
  data , ix_to_char, data_size, vocab_size , one_hot_encoded=  preprocess()
  # model 
  rnn = RNN(vocab_size, vocab_size, hidden_size, num_layers).to(device)
  # loss function and optimizer
  loss_fn = nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(rnn.parameters(), lr=lr, momentum=0.9)

  acc_list,loss_list = train(data , ix_to_char, data_size, vocab_size,rnn,epochs,seq_len,loss_fn,optimizer)
  print("generate text ------------------------------")
  test(data,data_size,rnn,ix_to_char)
  return acc_list,loss_list

In [39]:
acc_list,loss_list = runn(512, 250, 3, 0.01, 20)

Epoch: 1 	 Loss: 2.9722 	 accuracy: 1.6665
Epoch: 2 	 Loss: 2.2304 	 accuracy: 4.6031
Epoch: 3 	 Loss: 1.8061 	 accuracy: 7.9990
Epoch: 4 	 Loss: 1.5785 	 accuracy: 10.1823
Epoch: 5 	 Loss: 1.4397 	 accuracy: 11.4275
Epoch: 6 	 Loss: 1.3485 	 accuracy: 12.2495
Epoch: 7 	 Loss: 1.2833 	 accuracy: 12.8821
Epoch: 8 	 Loss: 1.2331 	 accuracy: 13.3571
Epoch: 9 	 Loss: 1.1920 	 accuracy: 13.6936
Epoch: 10 	 Loss: 1.1572 	 accuracy: 14.0096
Epoch: 11 	 Loss: 1.1263 	 accuracy: 14.3101
Epoch: 12 	 Loss: 1.0982 	 accuracy: 14.5413
Epoch: 13 	 Loss: 1.0722 	 accuracy: 14.7159
Epoch: 14 	 Loss: 1.0473 	 accuracy: 14.8133
Epoch: 15 	 Loss: 1.0233 	 accuracy: 14.9405
Epoch: 16 	 Loss: 1.0000 	 accuracy: 15.0789
Epoch: 17 	 Loss: 0.9770 	 accuracy: 15.1254
Epoch: 18 	 Loss: 0.9540 	 accuracy: 15.1197
Epoch: 19 	 Loss: 0.9310 	 accuracy: 15.1636
Epoch: 20 	 Loss: 0.9077 	 accuracy: 15.2136
generate text ------------------------------
ng fouty "Wais Diggory."
"He," said Hermione.  Malfoy panced as Har

In [40]:
def runn(hidden_size, seq_len, num_layers, lr, epochs):
  # Hyperparameters
  hidden_size = hidden_size  
  # length of LSTM sequence 
  seq_len = seq_len
  # num of layers in LSTM layer stack      
  num_layers =num_layers      
  lr = lr   
  epochs = epochs
  data , ix_to_char, data_size, vocab_size , one_hot_encoded=  preprocess()
  # model 
  rnn = RNN(vocab_size, vocab_size, hidden_size, num_layers).to(device)
  # loss function and optimizer
  loss_fn = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(rnn.parameters(), lr=lr)
  
  acc_list,loss_list = train(data , ix_to_char, data_size, vocab_size,rnn,epochs,seq_len,loss_fn,optimizer)
  print("generate text ------------------------------")
  test(data,data_size,rnn,ix_to_char)
  return acc_list,loss_list

In [41]:
acc_list,loss_list = runn(64, 250, 3, 0.01, 40)

Epoch: 1 	 Loss: 1.8494 	 accuracy: 13.4557
Epoch: 2 	 Loss: 1.6121 	 accuracy: 17.8940
Epoch: 3 	 Loss: 1.5882 	 accuracy: 18.9166
Epoch: 4 	 Loss: 1.5752 	 accuracy: 19.4338
Epoch: 5 	 Loss: 1.5634 	 accuracy: 20.1993
Epoch: 6 	 Loss: 1.5538 	 accuracy: 20.9433
Epoch: 7 	 Loss: 1.5485 	 accuracy: 21.3221
Epoch: 8 	 Loss: 1.5463 	 accuracy: 21.8767
Epoch: 9 	 Loss: 1.5507 	 accuracy: 21.9727
Epoch: 10 	 Loss: 1.5381 	 accuracy: 22.2864
Epoch: 11 	 Loss: 1.5439 	 accuracy: 22.6077
Epoch: 12 	 Loss: 1.5341 	 accuracy: 22.9477
Epoch: 13 	 Loss: 1.5311 	 accuracy: 23.0985
Epoch: 14 	 Loss: 1.5427 	 accuracy: 22.9727
Epoch: 15 	 Loss: 1.5523 	 accuracy: 22.7914
Epoch: 16 	 Loss: 1.5513 	 accuracy: 23.0826
Epoch: 17 	 Loss: 1.5490 	 accuracy: 23.3239
Epoch: 18 	 Loss: 1.5563 	 accuracy: 23.5421
Epoch: 19 	 Loss: 1.5469 	 accuracy: 23.9405
Epoch: 20 	 Loss: 1.5596 	 accuracy: 23.7353
Epoch: 21 	 Loss: 1.5556 	 accuracy: 24.0492
Epoch: 22 	 Loss: 1.5504 	 accuracy: 24.0721
Epoch: 23 	 Loss: 1