<a href="https://colab.research.google.com/github/dejanbatanjac/pytorch-learning-101/blob/master/Char_RNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.utils.data as utils
import torch.utils.data
from torch.utils.data import DataLoader, Dataset

torch.manual_seed(0)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
np.random.seed(0)


In [15]:
! wget https://raw.githubusercontent.com/dejanbatanjac/pytorch-learning-101/master/rij.txt

--2019-03-14 12:08:05--  https://raw.githubusercontent.com/dejanbatanjac/pytorch-learning-101/master/rij.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 159216 (155K) [text/plain]
Saving to: ‘rij.txt.2’


2019-03-14 12:08:06 (5.01 MB/s) - ‘rij.txt.2’ saved [159216/159216]



In [16]:
! head rij.txt

ACT I
PROLOGUE

    Two households, both alike in dignity,
    In fair Verona, where we lay our scene,
    From ancient grudge break to new mutiny,
    Where civil blood makes civil hands unclean.
    From forth the fatal loins of these two foes
    A pair of star-cross'd lovers take their life;
    Whose misadventured piteous overthrows


In [17]:
text = open('rij.txt', 'r').read() # should be simple plain text file
chars = sorted(list(set(text)))
chars.insert(0, "\0") #cannot imagine world w/o this character
chars.insert(1, "\t") #cannot imagine world w/o this character
print(len(chars))
print(chars)

66
['\x00', '\t', '\n', ' ', '!', '&', "'", ',', '-', '.', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'Y', 'Z', '[', ']', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [18]:
data_size,vocab_size = len(text),len(chars)
print('data has %d characters, %d unique.' % (data_size, vocab_size))

char_indices = { ch:i for i,ch in enumerate(chars) }
indices_char = { i:ch for i,ch in enumerate(chars) }

print(char_indices)
print(indices_char)

data has 153168 characters, 66 unique.
{'\x00': 0, '\t': 1, '\n': 2, ' ': 3, '!': 4, '&': 5, "'": 6, ',': 7, '-': 8, '.': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'Y': 36, 'Z': 37, '[': 38, ']': 39, 'a': 40, 'b': 41, 'c': 42, 'd': 43, 'e': 44, 'f': 45, 'g': 46, 'h': 47, 'i': 48, 'j': 49, 'k': 50, 'l': 51, 'm': 52, 'n': 53, 'o': 54, 'p': 55, 'q': 56, 'r': 57, 's': 58, 't': 59, 'u': 60, 'v': 61, 'w': 62, 'x': 63, 'y': 64, 'z': 65}
{0: '\x00', 1: '\t', 2: '\n', 3: ' ', 4: '!', 5: '&', 6: "'", 7: ',', 8: '-', 9: '.', 10: ':', 11: ';', 12: '?', 13: 'A', 14: 'B', 15: 'C', 16: 'D', 17: 'E', 18: 'F', 19: 'G', 20: 'H', 21: 'I', 22: 'J', 23: 'K', 24: 'L', 25: 'M', 26: 'N', 27: 'O', 28: 'P', 29: 'Q', 30: 'R', 31: 'S', 32: 'T', 33: 'U', 34: 'V', 35: 'W', 36: 'Y', 37: 'Z', 38: '[', 39: ']', 40: 'a', 41: 'b

In [19]:
# little data preview
idx = [char_indices[c] for c in text]    
print(idx[:100])

''.join(indices_char[i] for i in idx[:100])

[13, 15, 32, 3, 21, 2, 28, 30, 27, 24, 27, 19, 33, 17, 2, 2, 3, 3, 3, 3, 32, 62, 54, 3, 47, 54, 60, 58, 44, 47, 54, 51, 43, 58, 7, 3, 41, 54, 59, 47, 3, 40, 51, 48, 50, 44, 3, 48, 53, 3, 43, 48, 46, 53, 48, 59, 64, 7, 2, 3, 3, 3, 3, 21, 53, 3, 45, 40, 48, 57, 3, 34, 44, 57, 54, 53, 40, 7, 3, 62, 47, 44, 57, 44, 3, 62, 44, 3, 51, 40, 64, 3, 54, 60, 57, 3, 58, 42, 44, 53]


'ACT I\nPROLOGUE\n\n    Two households, both alike in dignity,\n    In fair Verona, where we lay our scen'

In [20]:
cs = 8
c_in_dat = [[] for _ in range(cs)]
for i in range(cs):#0..7    
    c_in_dat[i] = np.stack([idx[j] for j in range(i, len(idx)-cs-1, cs)])
    

# the prediction
c_out_dat = [idx[j] for j in range(cs, len(idx)-cs-1, cs)]

# transform list to torch tensors, and don't take the last few characters
y  = np.stack(c_out_dat)

print("input")    
for i in range(cs): print(c_in_dat[i],"len:" ,len(c_in_dat[i]))
print("prediction:")
print(y, "len:", len(y))

# we need to have the same length for all 
for i in range(cs): c_in_dat[i]= c_in_dat[i][:len(y)]
print("input again after cut off")    
for i in range(cs): print(c_in_dat[i], "len:" , len(c_in_dat[i]))
    



input
[13 27  3 ... 51  3 44] len: 19145
[15 24  3 ... 48 47 54] len: 19145
[32 27  3 ... 44 44  9] len: 19145
[ 3 19  3 ... 59 57  2] len: 19145
[21 33 32 ...  3  3  2] len: 19145
[ 2 17 62 ... 40 30  3] len: 19145
[28  2 54 ... 53 54  3] len: 19145
[30  2  3 ... 60 43 52] len: 19144
prediction:
[27  3 47 ... 51  3 44] len: 19144
input again after cut off
[13 27  3 ... 48 51  3] len: 19144
[15 24  3 ... 58 48 47] len: 19144
[32 27  3 ...  3 44 44] len: 19144
[ 3 19  3 ... 54 59 57] len: 19144
[21 33 32 ... 45  3  3] len: 19144
[ 2 17 62 ...  3 40 30] len: 19144
[28  2 54 ... 22 53 54] len: 19144
[30  2  3 ... 60 43 52] len: 19144


In [21]:
# these are the data again
l = list()
for i in range(cs):
    l.append(c_in_dat[i])
    

x = np.stack(l, axis=1)
print(len(x))
print(x.shape)
print(x)

19144
(19144, 8)
[[13 15 32 ...  2 28 30]
 [27 24 27 ... 17  2  2]
 [ 3  3  3 ... 62 54  3]
 ...
 [48 58  3 ...  3 22 60]
 [51 48 44 ... 40 53 43]
 [ 3 47 44 ... 30 54 52]]


In [22]:
# converting to tensor
X = torch.from_numpy(x).cuda()
Y = torch.from_numpy(y).cuda()

print(X.shape)
print(Y.shape)

# batch size
bs = 512

#now we create Dataset and DataLoader
ds = utils.TensorDataset(X, Y) 
dl = utils.DataLoader(ds, batch_size=bs, shuffle=False)


mb, yt = next(iter(dl))
print("mini-batch shape", mb.shape)
print("prediction shape", yt.shape)




torch.Size([19144, 8])
torch.Size([19144])
mini-batch shape torch.Size([512, 8])
prediction shape torch.Size([512])


In [0]:
# hidden activation states
# this is something we define 
# also called hidden features
n_hidden = 256

# latent factors, the size of the input embedding
n_fac = 33 

class CharRNN(nn.Module):
    def __init__(self, vocab_size, n_fac):
        super().__init__()
        
        self.e = nn.Embedding(vocab_size, n_fac)
        self.rnn = nn.RNN(n_fac, n_hidden)        
        self.l_out = nn.Linear(n_hidden, vocab_size)

        
        
    def forward(self, *cs): 
               
        bs = cs[0].size(0)
        h = torch.zeros(1, bs, n_hidden).to("cuda")
        inp = self.e(torch.stack(cs))
        outp, h = self.rnn(inp, h)
        
        # just return the last state (-1)
        return torch.log_softmax(self.l_out(outp[-1]),1) 



In [24]:
# create the model as m
torch.manual_seed(0)
np.random.seed(0)

m = CharRNN(vocab_size, n_fac).cuda()

t = [o.numel() for o in m.parameters() ]
print(t, sum(t))
# print(list(m.parameters()))
print(m)



[2772, 10752, 65536, 256, 256, 16896, 66] 96534
CharRNN(
  (e): Embedding(66, 42)
  (rnn): RNN(42, 256)
  (l_out): Linear(in_features=256, out_features=66, bias=True)
)


In [25]:
# train phase
m.train()

# set Adam optimizer
opt = optim.Adam(m.parameters(), lr = 0.001)

# set the loss function
loss_fn = nn.NLLLoss()

# would be nice to fine tune this number
num_epochs = 40
# once = False

for epoch in range(num_epochs):
    
    for mb,yt in dl: 
                
        if(mb.size(0)!=bs):
            # next epoch
            continue
        
        tup = torch.unbind(mb, dim=1) 
        y_hat = m(*tup)
        
        # calculate loss
        loss = loss_fn(y_hat, yt)

        # go backward()
        opt.zero_grad()
        loss.backward()

        # update params
        opt.step()
    
    print ('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))
        
        
        
        

Epoch [1/40], Loss: 2.8263
Epoch [2/40], Loss: 2.4734
Epoch [3/40], Loss: 2.3161
Epoch [4/40], Loss: 2.2191
Epoch [5/40], Loss: 2.1500
Epoch [6/40], Loss: 2.0782
Epoch [7/40], Loss: 2.0194
Epoch [8/40], Loss: 1.9661
Epoch [9/40], Loss: 1.9167
Epoch [10/40], Loss: 1.8710
Epoch [11/40], Loss: 1.8285
Epoch [12/40], Loss: 1.7890
Epoch [13/40], Loss: 1.7541
Epoch [14/40], Loss: 1.7299
Epoch [15/40], Loss: 1.6853
Epoch [16/40], Loss: 1.6514
Epoch [17/40], Loss: 1.6203
Epoch [18/40], Loss: 1.5895
Epoch [19/40], Loss: 1.5592
Epoch [20/40], Loss: 1.5290
Epoch [21/40], Loss: 1.4982
Epoch [22/40], Loss: 1.4662
Epoch [23/40], Loss: 1.4328
Epoch [24/40], Loss: 1.3989
Epoch [25/40], Loss: 1.3665
Epoch [26/40], Loss: 1.3377
Epoch [27/40], Loss: 1.3128
Epoch [28/40], Loss: 1.2891
Epoch [29/40], Loss: 1.2663
Epoch [30/40], Loss: 1.2569
Epoch [31/40], Loss: 1.2575
Epoch [32/40], Loss: 1.2457
Epoch [33/40], Loss: 1.2290
Epoch [34/40], Loss: 1.1476
Epoch [35/40], Loss: 1.0907
Epoch [36/40], Loss: 1.0500
E

In [26]:
m.eval()
torch.manual_seed(0)
np.random.seed(0)

# how many characters we predict
bptt = 8

# predict next bptt-th character
def get_next(inp): 
    inp = inp[-bptt:] 
    idxs = np.array([char_indices[c] for c in inp ])
        
    t = torch.from_numpy(idxs).cuda() #tensor([30, 27, 24, 27, 19, 33, 17,  3], device='cuda:0')
    unb = torch.unbind(t, dim=-1)

    # set e dimension to 1 (was 0)
    for e in unb: 
        if(e.dim()==0): e.unsqueeze_(0)
    p = m(*unb)
  
    
    # grab the index of the max element
    max,ind = p.max(1)
    ind = ind.item() #detach().cpu().numpy()  
    return chars[ind]
    


inp = "PROLOGUE"
get_next(inp)



while(len(inp)<1000):
    nc = get_next(inp)
    inp = inp+nc
    
print(inp)

    

PROLOGUE

    That with thee be that vill thee slow should breads,
    By heart foome, what not me where with thee be that vill thee slow should breads,
    By heart foome, what not me where with thee be that vill thee slow should breads,
    By heart foome, what not me where with thee be that vill thee slow should breads,
    By heart foome, what not me where with thee be that vill thee slow should breads,
    By heart foome, what not me where with thee be that vill thee slow should breads,
    By heart foome, what not me where with thee be that vill thee slow should breads,
    By heart foome, what not me where with thee be that vill thee slow should breads,
    By heart foome, what not me where with thee be that vill thee slow should breads,
    By heart foome, what not me where with thee be that vill thee slow should breads,
    By heart foome, what not me where with thee be that vill thee slow should breads,
    By heart foome, what not me where with thee be that vill thee slow sh