In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from load_data import *
import numpy as np
import time
from rnn_layers_torch import *

In [3]:
class RNN(nn.Module):

    def __init__(self, word_to_idx, wordvec_dim, hidden_dim, cell_type, seed):
        super(RNN, self).__init__()

        vocab_size = len(word_to_idx)
        self.start_token = word_to_idx["<START>"]
        self.null_token = word_to_idx["<NULL>"]
        self.end_token = word_to_idx["<END>"]
        self.cell_type = cell_type
        self.params = {}

        if(seed is not None):
            np.random.seed(seed)

        self.params["W_embed"] = np.random.randn(vocab_size, wordvec_dim)
        self.params["W_embed"] /= 100

        dim_mul = {"lstm": 4, "rnn": 1}[cell_type]
        self.params["Wx"] = np.random.randn(wordvec_dim, dim_mul * hidden_dim)
        self.params["Wx"] /= np.sqrt(wordvec_dim)
        self.params["Wh"] = np.random.randn(hidden_dim, dim_mul * hidden_dim)
        self.params["Wh"] /= np.sqrt(hidden_dim)
        self.params["b"] = np.zeros(dim_mul * hidden_dim)

        self.params["W_vocab"] = np.random.randn(hidden_dim, vocab_size)
        self.params["W_vocab"] /= np.sqrt(hidden_dim)
        self.params["b_vocab"] = np.zeros(vocab_size)

        self.params["h_init"] = np.random.randn(hidden_dim)

        for key in self.params.keys():
            self.params[key] = self.params[key].astype(np.float32)
            self.params[key] = torch.from_numpy(self.params[key])
            self.params[key].requires_grad = True


        # self.params["W_embed"] = torch.randn(vocab_size, wordvec_dim, requires_grad=True) / 100

        # dim_mul = {"lstm": 4, "rnn": 1}[cell_type]
        # self.params["Wx"] = torch.randn(wordvec_dim, dim_mul * hidden_dim, requires_grad=True)
        # self.params["Wh"] = torch.randn(hidden_dim, dim_mul * hidden_dim,requires_grad=True)
        # self.params["b"] = torch.zeros(dim_mul * hidden_dim,requires_grad=True)
        # self.params["W_vocab"] = torch.randn(hidden_dim, vocab_size,requires_grad=True)
        # self.params["b_vocab"] = torch.zeros(vocab_size, requires_grad=True)
        # self.params["h_init"] = torch.randn(hidden_dim, requires_grad=True)

    def forward(self, captions):

        captions_out = captions[:,1:].clone()
        captions_in = captions[:,:-1]
        captions_out[:,0:8] = self.null_token
        mask = captions_out != self.null_token
        # print(captions_out)
        # print(captions_in)
        # print(mask)
        N = captions.shape[0]
        h0 = torch.tile(self.params["h_init"], (N, 1))
        h = None

        inputs = word_embedding_forward(captions_in, self.params["W_embed"])
        if(self.cell_type == "rnn"):
            h = rnn_forward(inputs, h0, self.params["Wx"], self.params["Wh"], self.params["b"])
        elif(self.cell_type == "lstm"):
            h = lstm_forward(inputs, h0, self.params["Wx"], self.params["Wh"], self.params["b"])
        else:
            return None
        out = temporal_affine_forward(h, self.params["W_vocab"], self.params["b_vocab"])
        loss = temporal_softmax_loss(out, captions_out, mask)

        return loss
    
    def load(self, params):
        for key in self.params.keys():
            self.params[key] = params[key]

In [5]:
file_name = "data/2_digit_ops.txt"
lwflag = 1 # 0 for words, 1 for letter
word_to_idx = None

if(lwflag):
    word_to_idx = make_dict_letter(file_name)
else:
    word_to_idx = make_dict(file_name)
reverse_dict = {}

for keys, value in word_to_idx.items():
    reverse_dict[value] = keys
pprint(reverse_dict)

{0: 'B',
 1: 'N',
 2: '3',
 3: '0',
 4: 's',
 5: '4',
 6: 'E',
 7: 'P',
 8: '\n',
 9: 'a',
 10: '1',
 11: '5',
 12: '9',
 13: '2',
 14: '8',
 15: '6',
 16: '7',
 17: '<START>',
 18: '<NULL>',
 19: '<END>'}


In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)

rnn = RNN(word_to_idx, 128, 128, "lstm", seed=7)

epochs = 100
learning_rate = 0.001
data = None
# if(lwflag):
#     data = load_data_letter(word_to_idx, file_name, lines_count=2, max_train=100)
# else:
#     data = load_data(word_to_idx, file_name, lines_count=1, max_train=4)
data = load_data_letter_2_digit_ops(word_to_idx, file_name)
data = torch.from_numpy(data)


In [11]:
for i in range(len(data)):
    words = [reverse_dict[val.item()] for val in data[i]]
    print((words))
    # print("******")

['B', 'N', '1', '0', 's', 'N', '1', '3', 'E', 'B', 'P', '0', '3', 'E', '\n']
['B', 'N', '3', '3', 'a', 'N', '4', '8', 'E', 'B', 'N', '8', '1', 'E', '\n']
['B', 'N', '0', '8', 's', 'P', '3', '4', 'E', 'B', 'N', '4', '2', 'E', '\n']
['B', 'N', '3', '1', 's', 'N', '2', '7', 'E', 'B', 'N', '0', '4', 'E', '\n']
['B', 'N', '0', '5', 's', 'P', '2', '0', 'E', 'B', 'N', '2', '5', 'E', '\n']
['B', 'N', '3', '5', 's', 'N', '2', '6', 'E', 'B', 'N', '0', '9', 'E', '\n']
['B', 'N', '3', '0', 'a', 'P', '2', '4', 'E', 'B', 'N', '0', '6', 'E', '\n']
['B', 'N', '3', '3', 'a', 'N', '2', '4', 'E', 'B', 'N', '5', '7', 'E', '\n']
['B', 'N', '4', '5', 'a', 'N', '3', '7', 'E', 'B', 'N', '8', '2', 'E', '\n']
['B', 'N', '2', '1', 's', 'N', '2', '1', 'E', 'B', 'P', '0', '0', 'E', '\n']
['B', 'P', '0', '0', 'a', 'N', '0', '5', 'E', 'B', 'N', '0', '5', 'E', '\n']
['B', 'N', '1', '8', 'a', 'N', '3', '4', 'E', 'B', 'N', '5', '2', 'E', '\n']
['B', 'N', '2', '6', 's', 'P', '1', '5', 'E', 'B', 'N', '4', '1', 'E', '\n']

In [13]:
loaded = torch.load("check.pt", map_location=device)
rnn.load(loaded["params"])

In [14]:
parameters = []

for key in rnn.params.keys():
    parameters.append(rnn.params[key])

learning_rate = 1e-3
optimizer = optim.RMSprop(parameters, lr=learning_rate)
# optimizer.load_state_dict(loaded["optime"])

for i in range(epochs):
    
    # if(lwflag):
    #     data = load_data_letter(word_to_idx, file_name, lines_count=1, max_train=100)
    # else:
    #     data = load_data(word_to_idx, file_name, lines_count=1, max_train=4)

    data = load_data_letter_2_digit_ops(word_to_idx, file_name, max_train=3000)

    data = torch.from_numpy(data)
    # words = [reverse_dict[val] for val in data[:].tolist()[1]]
    # print(words)
    loss = rnn(data)
    print(loss)
    for keys in rnn.params.keys():
        rnn.params[keys].retain_grad()
    
    optimizer.zero_grad()  # Clear gradients
    loss.backward(retain_graph=True)
    optimizer.step()
    # with torch.no_grad():
    #     for key, value in rnn.params.items():
    #         rnn.params[key] -= learning_rate * rnn.params[key].grad
    #         rnn.params[key].grad.zero_()

torch.save({
    "params" : rnn.params,
    "optime" : optimizer.state_dict()
}, "check.pt")


tensor(5.1830, grad_fn=<DivBackward0>)
tensor(7.4691, grad_fn=<DivBackward0>)
tensor(14.8676, grad_fn=<DivBackward0>)
tensor(5.7286, grad_fn=<DivBackward0>)
tensor(5.2638, grad_fn=<DivBackward0>)
tensor(5.2113, grad_fn=<DivBackward0>)
tensor(5.1934, grad_fn=<DivBackward0>)
tensor(5.1833, grad_fn=<DivBackward0>)
tensor(5.1734, grad_fn=<DivBackward0>)
tensor(5.1671, grad_fn=<DivBackward0>)
tensor(5.1623, grad_fn=<DivBackward0>)
tensor(5.1476, grad_fn=<DivBackward0>)
tensor(5.1411, grad_fn=<DivBackward0>)
tensor(5.1372, grad_fn=<DivBackward0>)
tensor(5.1343, grad_fn=<DivBackward0>)
tensor(5.1589, grad_fn=<DivBackward0>)
tensor(5.1515, grad_fn=<DivBackward0>)
tensor(5.1485, grad_fn=<DivBackward0>)
tensor(5.1464, grad_fn=<DivBackward0>)


KeyboardInterrupt: 

In [329]:
# The constituent of the sentence (either words or letters should be a part of word_to_dict)
# Choose the words carefully!
# torch.save({
#     "params" : rnn.params,
#     "optime" : optimizer.state_dict()
# }, "check.pt")
start_string = "BN01sP04E"


In [330]:
words = None
num_len_start = None

if(lwflag):
    num_len_start = len(start_string)
else:
    words = start_string.split()
    num_len_start = len(words)

start_weights = rnn.params['W_embed'][rnn.start_token]
start_weights = torch.tile(start_weights, (1, 1))

prev_h = torch.tile(rnn.params["h_init"], (1, 1))
prev_c = torch.zeros((1, prev_h.shape[1]))
curr_x = start_weights

next_h, next_c = None, None
max_length = 6
captions = rnn.null_token * torch.ones((1, max_length + num_len_start), dtype=torch.int32)

In [331]:
letter_or_word = words
if(lwflag):
    letter_or_word = start_string
    
for i, letter in enumerate(letter_or_word):
    if(rnn.cell_type == "rnn"):
        next_h = rnn_step_forward(curr_x, prev_h, rnn.params["Wx"], rnn.params["Wh"], rnn.params["b"])
    else:
        next_h, next_c = lstm_step_forward(curr_x, prev_h, prev_c, rnn.params["Wx"], rnn.params["Wh"], rnn.params["b"])

    out = affine_forward(next_h, rnn.params["W_vocab"], rnn.params["b_vocab"])
    indices = torch.tensor(word_to_idx[letter])
    captions[:, i] = indices
    prev_h = next_h
    prev_c = next_c
    curr_x = rnn.params["W_embed"][indices]


In [338]:
for i in range(max_length):
    if(rnn.cell_type == "rnn"):
        next_h = rnn_step_forward(curr_x, prev_h, rnn.params["Wx"], rnn.params["Wh"], rnn.params["b"])
    else:
        next_h, next_c = lstm_step_forward(curr_x, prev_h, prev_c, rnn.params["Wx"], rnn.params["Wh"], rnn.params["b"])
        
    out = affine_forward(next_h, rnn.params["W_vocab"], rnn.params["b_vocab"])
    # indices = torch.argmax(out, dim=1)
    # introduce temperature
    # out = torch.softmax(out, dim = 1)
    T = 0.31
    out = torch.exp(out/T)
    out = out / torch.sum(out, dim = 1)
    print(out)
    indices = torch.multinomial(out, 1).squeeze(0)
    # print(indices)
    # print(torch.argmax(out))
    # print(out)
    captions[:, num_len_start + i] = indices
    prev_h = next_h
    prev_c = next_c
    curr_x = rnn.params["W_embed"][indices]
    # print(curr_x)

captions = captions.tolist()



tensor([[0., 0., 0., 0., 0., 0., 0., 0., nan, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
       grad_fn=<DivBackward0>)


RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

In [333]:
file = open("out.txt",'w')
str = ""
for i in range(len(captions)):
    words = [reverse_dict[val] for val in captions[i]]
    for letter in words:
        str += letter
    file.write(str)
    print(str)
    print("******")

BN01sP04EBN07E

******


In [None]:
import matplotlib.pyplot as plt

y = np.array([ 6.0534e-01,  4.3694e+00,  2.8921e+00, -1.2803e+00, -4.4965e+00,
         -1.2102e+00,  1.8877e+00,  4.4952e+00,  2.2088e+00, -4.9342e+00,
          1.6454e+00,  2.4964e+00,  4.2149e+00, -3.4247e+00,  3.2256e+00,
          3.7176e+00,  3.4167e+00, -1.7768e+00,  3.2877e+00, -3.4899e+00,
          6.4797e-02, -2.9072e+00, -1.7842e+00,  4.1396e+00,  2.9905e+00,
         -2.7464e+00, -1.3218e+00, -1.8224e+00, -3.8122e+00,  1.2878e+00,
         -1.8629e+01, -8.8312e-01, -1.3546e+00, -3.0432e+00, -2.0706e+00,
          4.2619e-01, -3.2271e+00, -2.8355e-01, -2.8540e+00,  1.6396e+00,
          4.9245e+00,  2.5294e+00,  2.1570e+00, -3.9914e+00,  1.1421e+00,
          2.2680e+00,  1.0391e+00, -1.9857e-02, -9.4496e+00, -1.4817e+01,
          7.2657e-01, -1.0843e+01, -1.0294e+00, -1.3077e+01, -1.1968e+01,
         -3.5006e+00, -2.9895e+00, -6.5954e-01, -1.5161e+01, -1.3209e+01,
         -2.2323e+00, -1.1670e+01, -1.1378e+01, -1.5503e+01, -1.3931e+01,
         -3.6289e+00, -9.6976e+00, -2.4147e+00,  1.2029e-01, -8.7026e+00,
         -5.8831e+00, -5.3691e+00, -4.9089e+00, -1.7769e+00, -9.7471e+00,
         -7.6364e+00, -1.7976e+01, -1.8692e+01, -7.2139e+00, -8.0258e+00,
         -1.8363e+01, -1.1233e+01, -2.0220e+01, -2.0179e+01, -1.1242e+01])

x = np.linspace(1,len(y), len(y))
print(len(y))
plt.plot(x,y)
plt.show()

In [None]:
y = np.array([ -0.9733,  10.8774,  10.0441,   4.4320,  -1.1357, -10.6907,  14.5207,
          -2.8769,   9.4872,  -5.6361,  -0.6976, -16.2822,  11.1946, -14.1597,
          -5.5444,   1.0754,  -5.3912,  -4.7569,   9.3502, -15.6364,  -2.4645,
         -10.3533,  -5.3346,   9.6924,  -9.5424,  -9.7240,  -7.8794,  -6.6811,
          -5.7003,   3.2312, -21.9670,  -6.1538,  -4.8924,  -9.5898,   0.1521,
          -4.1660,  -4.6820,  -7.4905, -10.5066,  -6.9099,  -9.1708,  -2.5660,
          -7.9966,  -5.9164,   8.7853,  -4.6198,  -9.4906,  -5.9581, -15.0839,
         -19.0691,  -8.0916, -13.8294,  -4.2585, -12.2205, -17.0471, -15.0635,
           4.9389,  -5.1173, -20.8669, -22.2936, -18.9036, -23.3794, -11.0086,
         -21.8711, -14.2440,  -7.4616, -15.8068,  -4.7838,  -4.9030, -11.7466,
          -3.5857,  -8.9297,  -5.9541, -10.3894, -12.4363, -10.7774, -21.1707,
         -21.6002, -39.1100, -14.1335, -28.5526, -29.0288, -23.3054, -21.4974,
         -25.8415])
x = np.linspace(1,len(y), len(y))
T = 100000
plt.plot(x,np.exp(y/T) / np.sum(np.exp(y/T)))
# plt.plot(x,y)
plt.show()

In [None]:
pprint(reverse_dict)

In [None]:
s_before = [ 6.0534e-01,  4.3694e+00,  2.8921e+00, -1.2803e+00, -4.4965e+00,
         -1.2102e+00,  1.8877e+00,  4.4952e+00,  2.2088e+00, -4.9342e+00,
          1.6454e+00,  2.4964e+00,  4.2149e+00, -3.4247e+00,  3.2256e+00,
          3.7176e+00,  3.4167e+00, -1.7768e+00,  3.2877e+00, -3.4899e+00,
          6.4797e-02, -2.9072e+00, -1.7842e+00,  4.1396e+00,  2.9905e+00,
         -2.7464e+00, -1.3218e+00, -1.8224e+00, -3.8122e+00,  1.2878e+00,
         -1.8629e+01, -8.8312e-01, -1.3546e+00, -3.0432e+00, -2.0706e+00,
          4.2619e-01, -3.2271e+00, -2.8355e-01, -2.8540e+00,  1.6396e+00,
          4.9245e+00,  2.5294e+00,  2.1570e+00, -3.9914e+00,  1.1421e+00,
          2.2680e+00,  1.0391e+00, -1.9857e-02, -9.4496e+00, -1.4817e+01,
          7.2657e-01, -1.0843e+01, -1.0294e+00, -1.3077e+01, -1.1968e+01,
         -3.5006e+00, -2.9895e+00, -6.5954e-01, -1.5161e+01, -1.3209e+01,
         -2.2323e+00, -1.1670e+01, -1.1378e+01, -1.5503e+01, -1.3931e+01,
         -3.6289e+00, -9.6976e+00, -2.4147e+00,  1.2029e-01, -8.7026e+00,
         -5.8831e+00, -5.3691e+00, -4.9089e+00, -1.7769e+00, -9.7471e+00,
         -7.6364e+00, -1.7976e+01, -1.8692e+01, -7.2139e+00, -8.0258e+00,
         -1.8363e+01, -1.1233e+01, -2.0220e+01, -2.0179e+01, -1.1242e+01]

t_before = [  0.8672,   4.2704,   1.6499,  -0.1441,  -5.4380,  -1.6326,   3.0245,
           4.9084,   2.5126,  -4.8974,   1.4154,   3.2224,   2.7437,  -1.8678,
           3.8421,   2.1870,   3.5555,  -1.7945,   3.5485,  -2.6484,   0.1289,
          -2.8479,  -1.5428,   2.6185,   2.2913,  -3.0837,  -2.7430,  -0.9212,
          -3.8082,   0.4127, -16.9907,  -1.0579,  -1.2042,  -1.7914,  -3.5872,
          -1.0042,  -2.5655,  -2.1203,  -2.0782,   0.3144,   3.7498,   2.7011,
           2.5335,  -4.2712,   2.8088,   2.2089,   0.9351,   0.5182,  -8.8479,
         -12.8158,   1.9641, -10.5990,  -0.6618, -13.1240, -11.6949,  -2.3679,
          -1.4294,  -2.0577, -14.6225, -13.6667,  -1.3513,  -8.3802,  -8.8191,
         -14.4441, -13.4774,  -3.4689, -11.0212,  -1.7646,   0.3403,  -6.7329,
          -6.3991,  -4.7024,  -4.1516,  -1.4665,  -8.4811,  -7.1091, -15.7866,
         -16.2358,  -6.7877,  -7.1301, -17.4053, -10.6621, -19.3296, -19.2079,
         -11.2567]

after_e = [-1.0934e+01, -1.1948e+00,  4.4822e-02, -6.5232e-01,  7.9335e+00,
         -8.2125e+00, -5.7796e+00,  1.6635e+00, -2.4267e+00, -2.9554e+00,
          3.8696e+00,  2.9119e+00,  3.3425e+00, -9.7573e+00, -3.0463e+00,
          2.1159e+00,  2.4975e+00, -1.6430e+01,  2.7373e+00, -2.0869e+01,
          5.1378e+00, -1.4223e+01, -1.1924e+01, -2.9010e+00, -2.0186e+00,
         -1.5791e+01, -1.1752e+01, -1.0855e+01, -2.0744e+01,  2.1554e+00,
         -2.1378e+01, -6.1450e+00, -1.1220e+01, -4.8905e+00, -1.0900e+01,
         -1.0639e+01, -9.7942e+00, -1.4185e+01, -1.4832e+01,  7.0777e-01,
          2.2998e+00,  2.3042e-01,  6.2308e-03,  5.6740e+00, -5.1872e+00,
          4.8951e+00, -5.6572e+00, -1.0699e+01, -1.6901e+01, -1.8379e+01,
          6.3958e-01, -1.5640e+01, -1.0631e+01, -1.2981e+01, -1.7336e+01,
         -1.3306e+01, -8.7787e-01,  1.2092e+00, -1.5528e+01, -1.8282e+01,
         -1.5849e+01, -2.1366e+01, -1.3913e+00, -1.5046e+01, -1.4995e+01,
         -1.1208e+01, -1.7130e+01, -1.6069e+00, -5.5352e+00, -3.9728e-01,
          2.9421e+00,  1.9356e+00,  3.6271e+00, -1.6579e+01, -1.1398e+00,
         -1.0887e+01, -2.5913e+01, -1.8709e+01, -3.9841e+01, -7.8933e+00,
         -1.5357e+01, -2.5786e+01, -2.0423e+01, -1.9744e+01, -1.8157e+01]