In [1]:
import torch

import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.optim as optim
import torch.utils.data as data
import numpy as np
import random


In [2]:
if torch.cuda.is_available():
    DTYPE = torch.cuda.FloatTensor
else:
    DTYPE = torch.FloatTensor
    

In [3]:
text = open('texts/Lovecraft.txt','r').read()
alphabet = set(text)

ix_to_char = {k:v for k,v in enumerate(alphabet)}
char_to_ix = {k:v for v,k in enumerate(alphabet)}

In [4]:
NUM_LAYERS = 1
BATCH_SIZE = 64
HIDDEN_DIM = 128
SEQ_LEN = 65


In [5]:
def sequence_to_tensor(sequence):
    tensor = torch.zeros(len(sequence),len(alphabet)).type(DTYPE)
    for i, c in enumerate(sequence):
        tensor[i][char_to_ix[c]] = 1
    return tensor.view(BATCH_SIZE,SEQ_LEN,len(alphabet))


In [6]:
class TxtLoader(data.Dataset):
    
    def __init__(self,text):
        super(TxtLoader,self).__init__()
        self.data = text
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self,index):
        return self.data[index]

In [7]:
class LSTM(nn.Module):

    def __init__(self,alphabet_size, hidden_dim, output_size):
        super(LSTM, self).__init__()
        self.hidden_dim = hidden_dim
        
        self.i2h = nn.Linear(alphabet_size,hidden_dim)
        self.lstm = nn.LSTM(hidden_dim,hidden_dim,NUM_LAYERS,batch_first=True,dropout=True)
        self.h2O = nn.Linear(hidden_dim, output_size)
        
        self.hidden = self.init_hidden()
        
        
    def init_hidden(self):
       
        return (autograd.Variable(torch.zeros(NUM_LAYERS, BATCH_SIZE, self.hidden_dim).type(DTYPE)),
                autograd.Variable(torch.zeros(NUM_LAYERS, BATCH_SIZE, self.hidden_dim).type(DTYPE)))

    def forward(self, sequence):
        out = self.i2h(sequence)
        lstm_out, self.hidden = self.lstm(out.view(BATCH_SIZE,SEQ_LEN-1,-1),self.hidden)
        out = self.h2O(lstm_out.contiguous().view(-1,self.hidden_dim))
        return out
    
    
    def gen_text(self, batch,t=None):
            
        inputs = autograd.Variable(sequence_to_tensor(batch))
        idxs = torch.zeros(inputs.data.size())
        out = self(inputs[:,:-1,:])

        if t != None:
            soft_out = F.softmax(out/t,dim=1)
            
            for i in range(soft_out.size()[0]):
                idxs[i] = np.random.choice(soft_out.size()[1],p=soft_out.data.numpy()[i])
                
        else:
            idxs = out.max(1)[1].data

        
        return out,idxs
    


In [8]:
rnn = LSTM(len(alphabet),HIDDEN_DIM,len(alphabet)).type(DTYPE)
optimizer = optim.Adam(rnn.parameters(),lr=0.01)
criterion = nn.CrossEntropyLoss()

epochs = 1000

In [9]:
def train(data_loader):
    
    best_loss = float('inf')
    
    rnn.train(True)
    for epoch in range(epochs):

        losses = np.array([])
        
        for batch in data_loader:

            rnn.zero_grad()
            rnn.hidden = rnn.init_hidden()
            
            inputs = autograd.Variable(sequence_to_tensor(batch))
            
                        
            out = rnn(inputs[:,:-1,:])
                        
            _,target = inputs[:,1:,:].topk(1)
            
            
            loss = criterion(out.view(-1,len(alphabet)),target.view(-1))
            losses = np.append(losses,loss.data[0])
            
            loss.backward()
            optimizer.step()
        
        if losses.mean() < best_loss:
            best_loss = losses.mean()
            best_wts = rnn.state_dict()
            

        print("Epoch {}/{}\nLoss: {:.2f}".format(epoch+1,epochs,losses.mean()))
        print("="*15)
        
    
    return best_wts
    


In [10]:
dataset = TxtLoader(text[:100000])
loader = data.DataLoader(dataset,batch_size=BATCH_SIZE*SEQ_LEN,drop_last=True,num_workers=4)

In [11]:
best_wts = train(loader)
rnn.load_state_dict(best_wts)

Epoch 1/1000
Loss: 3.20
Epoch 2/1000
Loss: 2.79
Epoch 3/1000
Loss: 2.51
Epoch 4/1000
Loss: 2.36
Epoch 5/1000
Loss: 2.25
Epoch 6/1000
Loss: 2.16
Epoch 7/1000
Loss: 2.09
Epoch 8/1000
Loss: 2.03
Epoch 9/1000
Loss: 1.98
Epoch 10/1000
Loss: 1.93
Epoch 11/1000
Loss: 1.90
Epoch 12/1000
Loss: 1.86
Epoch 13/1000
Loss: 1.83
Epoch 14/1000
Loss: 1.81
Epoch 15/1000
Loss: 1.78
Epoch 16/1000
Loss: 1.76
Epoch 17/1000
Loss: 1.74
Epoch 18/1000
Loss: 1.72
Epoch 19/1000
Loss: 1.70
Epoch 20/1000
Loss: 1.69
Epoch 21/1000
Loss: 1.67
Epoch 22/1000
Loss: 1.66
Epoch 23/1000
Loss: 1.64
Epoch 24/1000
Loss: 1.63
Epoch 25/1000
Loss: 1.62
Epoch 26/1000
Loss: 1.61
Epoch 27/1000
Loss: 1.60
Epoch 28/1000
Loss: 1.59
Epoch 29/1000
Loss: 1.58
Epoch 30/1000
Loss: 1.57
Epoch 31/1000
Loss: 1.56
Epoch 32/1000
Loss: 1.55
Epoch 33/1000
Loss: 1.54
Epoch 34/1000
Loss: 1.53
Epoch 35/1000
Loss: 1.53
Epoch 36/1000
Loss: 1.52
Epoch 37/1000
Loss: 1.51
Epoch 38/1000
Loss: 1.50
Epoch 39/1000
Loss: 1.49
Epoch 40/1000
Loss: 1.49
Epoch 41/

Epoch 199/1000
Loss: 1.21
Epoch 200/1000
Loss: 1.21
Epoch 201/1000
Loss: 1.20
Epoch 202/1000
Loss: 1.20
Epoch 203/1000
Loss: 1.19
Epoch 204/1000
Loss: 1.19
Epoch 205/1000
Loss: 1.18
Epoch 206/1000
Loss: 1.19
Epoch 207/1000
Loss: 1.19
Epoch 208/1000
Loss: 1.19
Epoch 209/1000
Loss: 1.19
Epoch 210/1000
Loss: 1.19
Epoch 211/1000
Loss: 1.19
Epoch 212/1000
Loss: 1.19
Epoch 213/1000
Loss: 1.19
Epoch 214/1000
Loss: 1.19
Epoch 215/1000
Loss: 1.19
Epoch 216/1000
Loss: 1.19
Epoch 217/1000
Loss: 1.19
Epoch 218/1000
Loss: 1.19
Epoch 219/1000
Loss: 1.19
Epoch 220/1000
Loss: 1.19
Epoch 221/1000
Loss: 1.19
Epoch 222/1000
Loss: 1.19
Epoch 223/1000
Loss: 1.18
Epoch 224/1000
Loss: 1.18
Epoch 225/1000
Loss: 1.18
Epoch 226/1000
Loss: 1.17
Epoch 227/1000
Loss: 1.16
Epoch 228/1000
Loss: 1.16
Epoch 229/1000
Loss: 1.18
Epoch 230/1000
Loss: 1.18
Epoch 231/1000
Loss: 1.17
Epoch 232/1000
Loss: 1.16
Epoch 233/1000
Loss: 1.15
Epoch 234/1000
Loss: 1.14
Epoch 235/1000
Loss: 1.14
Epoch 236/1000
Loss: 1.14
Epoch 237/10

Epoch 395/1000
Loss: 1.08
Epoch 396/1000
Loss: 1.08
Epoch 397/1000
Loss: 1.07
Epoch 398/1000
Loss: 1.06
Epoch 399/1000
Loss: 1.08
Epoch 400/1000
Loss: 1.08
Epoch 401/1000
Loss: 1.08
Epoch 402/1000
Loss: 1.08
Epoch 403/1000
Loss: 1.08
Epoch 404/1000
Loss: 1.10
Epoch 405/1000
Loss: 1.12
Epoch 406/1000
Loss: 1.10
Epoch 407/1000
Loss: 1.08
Epoch 408/1000
Loss: 1.08
Epoch 409/1000
Loss: 1.09
Epoch 410/1000
Loss: 1.09
Epoch 411/1000
Loss: 1.08
Epoch 412/1000
Loss: 1.07
Epoch 413/1000
Loss: 1.06
Epoch 414/1000
Loss: 1.07
Epoch 415/1000
Loss: 1.07
Epoch 416/1000
Loss: 1.10
Epoch 417/1000
Loss: 1.10
Epoch 418/1000
Loss: 1.10
Epoch 419/1000
Loss: 1.11
Epoch 420/1000
Loss: 1.10
Epoch 421/1000
Loss: 1.09
Epoch 422/1000
Loss: 1.08
Epoch 423/1000
Loss: 1.07
Epoch 424/1000
Loss: 1.07
Epoch 425/1000
Loss: 1.08
Epoch 426/1000
Loss: 1.09
Epoch 427/1000
Loss: 1.09
Epoch 428/1000
Loss: 1.08
Epoch 429/1000
Loss: 1.08
Epoch 430/1000
Loss: 1.09
Epoch 431/1000
Loss: 1.09
Epoch 432/1000
Loss: 1.11
Epoch 433/10

Epoch 591/1000
Loss: 1.15
Epoch 592/1000
Loss: 1.17
Epoch 593/1000
Loss: 1.14
Epoch 594/1000
Loss: 1.12
Epoch 595/1000
Loss: 1.12
Epoch 596/1000
Loss: 1.11
Epoch 597/1000
Loss: 1.09
Epoch 598/1000
Loss: 1.08
Epoch 599/1000
Loss: 1.12
Epoch 600/1000
Loss: 1.16
Epoch 601/1000
Loss: 1.16
Epoch 602/1000
Loss: 1.16
Epoch 603/1000
Loss: 1.13
Epoch 604/1000
Loss: 1.10
Epoch 605/1000
Loss: 1.08
Epoch 606/1000
Loss: 1.12
Epoch 607/1000
Loss: 1.11
Epoch 608/1000
Loss: 1.10
Epoch 609/1000
Loss: 1.09
Epoch 610/1000
Loss: 1.09
Epoch 611/1000
Loss: 1.09
Epoch 612/1000
Loss: 1.09
Epoch 613/1000
Loss: 1.09
Epoch 614/1000
Loss: 1.06
Epoch 615/1000
Loss: 1.06
Epoch 616/1000
Loss: 1.07
Epoch 617/1000
Loss: 1.07
Epoch 618/1000
Loss: 1.07
Epoch 619/1000
Loss: 1.06
Epoch 620/1000
Loss: 1.06
Epoch 621/1000
Loss: 1.06
Epoch 622/1000
Loss: 1.06
Epoch 623/1000
Loss: 1.07
Epoch 624/1000
Loss: 1.08
Epoch 625/1000
Loss: 1.08
Epoch 626/1000
Loss: 1.09
Epoch 627/1000
Loss: 1.07
Epoch 628/1000
Loss: 1.06
Epoch 629/10

Epoch 787/1000
Loss: 1.11
Epoch 788/1000
Loss: 1.10
Epoch 789/1000
Loss: 1.09
Epoch 790/1000
Loss: 1.10
Epoch 791/1000
Loss: 1.12
Epoch 792/1000
Loss: 1.13
Epoch 793/1000
Loss: 1.16
Epoch 794/1000
Loss: 1.16
Epoch 795/1000
Loss: 1.15
Epoch 796/1000
Loss: 1.15
Epoch 797/1000
Loss: 1.15
Epoch 798/1000
Loss: 1.15
Epoch 799/1000
Loss: 1.15
Epoch 800/1000
Loss: 1.26
Epoch 801/1000
Loss: 1.22
Epoch 802/1000
Loss: 1.21
Epoch 803/1000
Loss: 1.18
Epoch 804/1000
Loss: 1.14
Epoch 805/1000
Loss: 1.12
Epoch 806/1000
Loss: 1.12
Epoch 807/1000
Loss: 1.10
Epoch 808/1000
Loss: 1.09
Epoch 809/1000
Loss: 1.07
Epoch 810/1000
Loss: 1.07
Epoch 811/1000
Loss: 1.08
Epoch 812/1000
Loss: 1.12
Epoch 813/1000
Loss: 1.13
Epoch 814/1000
Loss: 1.10
Epoch 815/1000
Loss: 1.10
Epoch 816/1000
Loss: 1.14
Epoch 817/1000
Loss: 1.17
Epoch 818/1000
Loss: 1.17
Epoch 819/1000
Loss: 1.16
Epoch 820/1000
Loss: 1.12
Epoch 821/1000
Loss: 1.10
Epoch 822/1000
Loss: 1.10
Epoch 823/1000
Loss: 1.09
Epoch 824/1000
Loss: 1.11
Epoch 825/10

Epoch 983/1000
Loss: 1.18
Epoch 984/1000
Loss: 1.19
Epoch 985/1000
Loss: 1.21
Epoch 986/1000
Loss: 1.19
Epoch 987/1000
Loss: 1.17
Epoch 988/1000
Loss: 1.14
Epoch 989/1000
Loss: 1.14
Epoch 990/1000
Loss: 1.13
Epoch 991/1000
Loss: 1.12
Epoch 992/1000
Loss: 1.14
Epoch 993/1000
Loss: 1.13
Epoch 994/1000
Loss: 1.13
Epoch 995/1000
Loss: 1.10
Epoch 996/1000
Loss: 1.08
Epoch 997/1000
Loss: 1.10
Epoch 998/1000
Loss: 1.14
Epoch 999/1000
Loss: 1.18
Epoch 1000/1000
Loss: 1.16


In [14]:
string = text[0]  


rnn.train(False)

for batch in loader:
    
    _ ,idxs = rnn.gen_text(batch)
    
#for i in range(100):
    
    #out = rnn(out)
    #soft_out = F.softmax(out/t,dim=1)

for c in idxs:
    string += ix_to_char[c]

print(string)           



#print(string,file=open('texts/output.txt','w'))

Ts,yy tss  Asmi way not The sley prrsonaaho hver fisitid toe sracd and tes sisit  whre tegameng sreerdand totere Aeen whhool oleueetoe srrdners aere borteally tottwf  toom the sorld  and tome hantaeatarmi aacshe r txratte on thw   Ther sere taineng totiossl aedh peraicilly p d pon icly  and totmne ihs ntreeised thir Ihetoa  of taea Iorknerss senness ttane sneund  
It wevpenes tn aogs anout the srcive ssty tf the seteoris trrl  and the srlrltoranataeamed anout theng  tn the snr thich htowwould not botpribed I Li  seving the e was now f ttngle wtocimic botng aut wn y aerysand tremeunsi Thesgs ioved O d toarted tnd toethered  and tvrs tt ues so stptste' which hhre no  ohi ly auunds  Iome hing whs nh;oetnay t soiwaas tetng teewn d tn ttme hing f tomething hhs nrr rtng tn elf wn tir chet tdtht tot to be s Go etse aast bade hn ib laf  t sot ing whs nven strll in the seght t ahe sotls and thgsn :aeepeer. Tohum sos not tuet tea sh she sountr fn lum  aut wreiaaesasded.tnout the souse ts tiogrtn

In [13]:
# pre-process txt file a bit
# hyperparameters
# refactor