In [1]:
import torch
import re
import torch.nn as nn
import random
import time
from operator import itemgetter
from tqdm import tqdm
import numpy as np

In [2]:
def load_glove():
    glove_home = './'
    words_to_load = 50000

    with open(glove_home + 'glove.6B.50d.txt',encoding='utf-8') as f:
        loaded_embeddings = np.zeros((words_to_load, 50))
        words = {}
        idx2words = {}
        ordered_words = []
        for i, line in enumerate(f):
            if i >= words_to_load: 
                break
            s = line.split()
            loaded_embeddings[i, :] = np.asarray(s[1:])
            words[s[0]] = i
            idx2words[i] = s[0]
            ordered_words.append(s[0])
    return loaded_embeddings, words, idx2words, ordered_words

In [3]:
def load_data(path):
    with open(path, 'r', encoding='utf-8') as f:
        text = f.read()
        
        text = re.sub(r"\.{3}", " <DASH>", text)
        text = re.sub(r"\.{1}", " <PERIOD>", text)
        text = re.sub(r"\,", " <COMMA>", text)
        text = re.sub(r"\?", " <QUSTIONMARK>", text)
        text = re.sub(r"\!", " <EXCLAMATIONMARK>", text)
        
        paragraphs = text.split('\n')
        lines = []
        
        for paragraph in paragraphs:
            lines_in_paragraph = [line.strip() + " <PERIOD>" for line in paragraph.split('<PERIOD>') if line.strip()!=""]
            lines += lines_in_paragraph
        return lines
    return None
        
def build_dictionary(lines):
    word_set = set()
    for i, line in enumerate(lines):
        if line == "":
            continue
        tokens = tokenize(line)
        word_set = word_set | set(tokens)
        lines[i] = tokens
    word_set = word_set - set([''])
    vocabulary = list(word_set)
    vocabulary = ["<0>", "<ENDOFLINE>", "<UNKNOWN>", "<PADDING>"] + vocabulary
    dictionary = dict(zip(vocabulary, range(len(vocabulary))))
    idx_lines = []
    for line in lines:
        idx_lines.append([dictionary[word] for word in line if word!=""]+[1])
    lines = []
    for line in idx_lines:
        if len(line) > 0:
            lines.append(line)
    return lines, vocabulary, dictionary

def tokenize(line):
    tokens = line.split(' ')
    for i in range(len(tokens)):
        tokens[i] = tokens[i].strip('\'').lower()
    return tokens

def get_rand_line(lines):
    return lines[random.randint(0, len(lines)-1)]

def cosine_similarity(vec_one, vec_two):
    return torch.dot(vec_one, vec_two)/torch.norm(vec_one)/torch.norm(vec_two)

def build_preloadembed(vocabulary, loaded_embeddings,embedding_dim, words, ordered_words):
    preloadembed = torch.zeros((len(vocabulary), embedding_dim)).normal_(mean=0,std=0.01)
    for i in range(len(vocabulary)):
        word = vocabulary[i]
        if word in ordered_words:
            preloadembed[i, :] = torch.from_numpy(loaded_embeddings[words[word]])
    return preloadembed

class Dataloader():
    def __init__(self, minibatch_size, minibatch_num, lines):
        super(Dataloader, self).__init__()
        self.minibatch_size = minibatch_size
        self.minibatch_num = minibatch_num
        self.lines = lines
    
    def get_minibatches(self):
        chosen_idxs = np.random.randint(len(lines), size = self.minibatch_size*self.minibatch_num)
        chosen_lines = [lines[chosen_idx] for chosen_idx in chosen_idxs]
        chosen_lines.sort(key=lambda x:len(x))
        minibatches = []
        for i in range(self.minibatch_num):
            mini_batch_content = chosen_lines[i*self.minibatch_size:(i+1)*self.minibatch_size]
            max_length = len(mini_batch_content[-1])
            for i in range(len(mini_batch_content)):
                while len(mini_batch_content[i])<max_length:
                    mini_batch_content[i].append(3)
            minibatches.append(torch.LongTensor(mini_batch_content))
        return minibatches

class LSTM3(nn.Module):
    def __init__(self, preloadembed, vocabulary, dictionary, embedding_dim, hidden_size, batch_size):
        super(LSTM3, self).__init__()
        
        self.embed = nn.Embedding(len(vocabulary), embedding_dim, padding_idx=0)
        self.embedding_size = embedding_dim
        self.hidden_size = hidden_size
        self.vocabulary = vocabulary
        self.dictionary = dictionary
        self.batch_size = batch_size
        #self.embed.weight.data.copy_(torch.rand(self.embed.weight.size()))
        self.embed.weight.data.copy_(preloadembed)
        del preloadembed
        
        self.linear_f0 = nn.Linear(embedding_dim + hidden_size, hidden_size)
        self.linear_i0 = nn.Linear(embedding_dim + hidden_size, hidden_size)
        self.linear_ctilde0 = nn.Linear(embedding_dim + hidden_size, hidden_size)
        self.linear_o0 = nn.Linear(embedding_dim + hidden_size, hidden_size)
        
        self.linear_f1 = nn.Linear(embedding_dim + hidden_size, hidden_size)
        self.linear_i1 = nn.Linear(embedding_dim + hidden_size, hidden_size)
        self.linear_ctilde1 = nn.Linear(embedding_dim + hidden_size, hidden_size)
        self.linear_o1 = nn.Linear(embedding_dim + hidden_size, hidden_size)
        
        self.linear_f2 = nn.Linear(embedding_dim + hidden_size, hidden_size)
        self.linear_i2 = nn.Linear(embedding_dim + hidden_size, hidden_size)
        self.linear_ctilde2 = nn.Linear(embedding_dim + hidden_size, hidden_size)
        self.linear_o2 = nn.Linear(embedding_dim + hidden_size, hidden_size)
        
        self.decoder0 = nn.Linear(hidden_size, int((hidden_size+embedding_dim)*0.5))
        self.decoder1 = nn.Linear(int((hidden_size+embedding_dim)*0.5), embedding_dim)
        self.init_weights()
    
    def find_nearest_word(self, output, k=1):
        scored_words = [(word, cosine_similarity(output, self.embed(torch.LongTensor([word])).squeeze())) for word in range(len(self.vocabulary))]
        sorted_words = sorted(scored_words, key=itemgetter(1), reverse=True)
        return sorted_words[:k]
    
    def forward(self, line, hidden, c0, c1, c2):
        line_emb = self.embed(line)
        line_embs = torch.chunk(line_emb, line_emb.size()[1], 1)
        output_line = []
        loss = 0
        
        def step(emb, hid, c_t0, c_t1, c_t2):
            combined = torch.cat((hid, emb), 1)
            f0 = torch.sigmoid(self.linear_f0(combined))
            i0 = torch.sigmoid(self.linear_i0(combined))
            c_tilde0 = torch.tanh(self.linear_ctilde0(combined))
            c_t0 = f0 * c_t0 + i0 * c_tilde0
            o0 = torch.sigmoid(self.linear_o0(combined))
            hid0 = o0 * torch.tanh(c_t0)
            
            combined = torch.cat((hid0, emb), 1)
            f1 = torch.sigmoid(self.linear_f1(combined))
            i1 = torch.sigmoid(self.linear_i1(combined))
            c_tilde1 = torch.tanh(self.linear_ctilde1(combined))
            c_t1 = f1 * c_t1 + i1 * c_tilde1
            o1 = torch.sigmoid(self.linear_o1(combined))
            hid1 = o1 * torch.tanh(c_t1)
            
            combined = torch.cat((hid1, emb), 1)
            f2 = torch.sigmoid(self.linear_f2(combined))
            i2 = torch.sigmoid(self.linear_i2(combined))
            c_tilde2 = torch.tanh(self.linear_ctilde2(combined))
            c_t2 = f2 * c_t2 + i2 * c_tilde2
            o2 = torch.sigmoid(self.linear_o2(combined))
            hid2 = o2 * torch.tanh(c_t2)
            return hid2, c_t0, c_t1, c_t2
              
        for i in range(len(line_embs)-1):
            hidden, c0, c1, c2 = step(line_embs[i].squeeze(), hidden, c0, c1, c2)
            output = self.decoder0(hidden)
            output = self.decoder1(output)
            for b in range(2, self.batch_size):
                #loss += cosine_similarity(output[b,:].squeeze(), line_embs[i+1].squeeze()[b,:])
                loss += 1-cosine_similarity(output[b,:].squeeze(), line_embs[i+1].squeeze()[b,:])
            output_line.append(output)
            """
            if i == 2:
                nearest_word = self.find_nearest_word(output[2,:].squeeze())
                nearest_word1 = self.find_nearest_word(line_embs[i+1].squeeze()[2,:])
                print(self.vocabulary[nearest_word[0][0]], self.vocabulary[nearest_word1[0][0]])
            """                   
            
        return output_line, loss
    
    def babble_mode(self, max_length):
        output_line = []
        hidden, c0, c1, c2 = torch.zeros(self.hidden_size).normal_(mean=0,std=.1), torch.zeros(self.hidden_size).normal_(mean=0,std=.1), torch.zeros(self.hidden_size).normal_(mean=0,std=.1), torch.zeros(self.hidden_size).normal_(mean=0,std=.1)
        output = self.decoder0(hidden)
        output = self.decoder1(output)
            
        def step(emb, hid, c_t0, c_t1, c_t2):
            combined = torch.cat((hid, emb), 0)
            f0 = torch.sigmoid(self.linear_f0(combined))
            i0 = torch.sigmoid(self.linear_i0(combined))
            c_tilde0 = torch.tanh(self.linear_ctilde0(combined))
            c_t0 = f0 * c_t0 + i0 * c_tilde0
            o0 = torch.sigmoid(self.linear_o0(combined))
            hid0 = o0 * torch.tanh(c_t0)
            
            combined = torch.cat((hid0, emb), 0)
            f1 = torch.sigmoid(self.linear_f1(combined))
            i1 = torch.sigmoid(self.linear_i1(combined))
            c_tilde1 = torch.tanh(self.linear_ctilde1(combined))
            c_t1 = f1 * c_t1 + i1 * c_tilde1
            o1 = torch.sigmoid(self.linear_o1(combined))
            hid1 = o1 * torch.tanh(c_t1)
            
            combined = torch.cat((hid1, emb), 0)
            f2 = torch.sigmoid(self.linear_f2(combined))
            i2 = torch.sigmoid(self.linear_i2(combined))
            c_tilde2 = torch.tanh(self.linear_ctilde2(combined))
            c_t2 = f2 * c_t2 + i2 * c_tilde2
            o2 = torch.sigmoid(self.linear_o2(combined))
            hid2 = o2 * torch.tanh(c_t2)
            return hid2, c_t0, c_t1, c_t2
              
        for i in range(max_length):
            hidden, c0, c1, c2 = step(output, hidden, c0, c1, c2)
            output = self.decoder0(hidden)
            output = self.decoder1(output)
            nearest_word = self.find_nearest_word(output)
            word = self.vocabulary[nearest_word[0][0]]
            output_line.append(word)
            if word == "<ENDOFLINE>":
                break
            line = ""
            for word in output_line:
                line += word + " "
        return line
            

    def init_hidden(self):
        h0 = torch.ones(self.batch_size, self.hidden_size).normal_(mean=0,std=.1)
        c0 = torch.ones(self.batch_size, self.hidden_size).normal_(mean=0,std=.1)
        c1 = torch.ones(self.batch_size, self.hidden_size).normal_(mean=0,std=.1)
        c2 = torch.ones(self.batch_size, self.hidden_size).normal_(mean=0,std=.1)
        return h0, c0, c1, c2
    
    def init_weights(self):
        #initrange = 0.1
        lin_layers = [self.linear_f0, self.linear_i0, self.linear_ctilde0, self.linear_o0,
                      self.linear_f1, self.linear_i1, self.linear_ctilde1, self.linear_o1,
                      self.linear_f2, self.linear_i2, self.linear_ctilde2, self.linear_o2,
                      self.decoder0, self.decoder1]
        #em_layer = [self.embed]
        em_layer = []
     
        for layer in lin_layers+em_layer:
            layer.weight.data.normal_(0, 0.1)
            if layer in lin_layers:
                layer.bias.data.fill_(0)


In [4]:
minibatch_size = 10
minibatch_num = 1000
epoch = 10
embedding_dim = 50
hidden_size = 20

time1 = time.time()
loaded_embeddings, words, idx2words, ordered_words = load_glove()
lines = load_data("Life is Magic.txt")
lines, vocabulary, dictionary = build_dictionary(lines)
loader = Dataloader(minibatch_size, minibatch_num, lines)
preloadembed = build_preloadembed(vocabulary, loaded_embeddings,embedding_dim, words, ordered_words)


rnn = LSTM3(preloadembed, vocabulary, dictionary, embedding_dim, hidden_size, minibatch_size)
rnn.init_weights()

optimizer = torch.optim.SGD(rnn.parameters(), lr= 0.001, momentum=0.9)

time2 = time.time()
print("prepare time:",time2-time1)

for e in range(epoch):
    print("epoch",e,"start!")
    minibatches = loader.get_minibatches()
    loss_sum = 0
    for b in tqdm(range(minibatch_num)):
        hidden, c1, c2, c3 = rnn.init_hidden()
        output_line, loss = rnn(minibatches[b],  hidden, c1, c2, c3)
        loss.backward()
        loss_sum += float(loss.item())
        torch.nn.utils.clip_grad_norm_(rnn.parameters(), 5.0)
        optimizer.step()
    print(loss_sum/minibatch_num/minibatch_size)

time3 = time.time()
print("training time",time3-time2)

prepare time: 4.68490743637085
epoch 0 start!


100%|██████████████████████████████████████| 1000/1000 [01:56<00:00,  2.64it/s]


2.9810074776649476
epoch 1 start!


100%|██████████████████████████████████████| 1000/1000 [01:54<00:00,  2.55it/s]


2.3060655040144917
epoch 2 start!


100%|██████████████████████████████████████| 1000/1000 [01:55<00:00,  2.52it/s]


2.0142407584428788
epoch 3 start!


100%|██████████████████████████████████████| 1000/1000 [01:57<00:00,  2.50it/s]


1.823183552622795
epoch 4 start!


100%|██████████████████████████████████████| 1000/1000 [02:01<00:00,  2.42it/s]


1.6853950708031653
epoch 5 start!


100%|██████████████████████████████████████| 1000/1000 [02:08<00:00,  2.49it/s]


1.5474228750050067
epoch 6 start!


100%|██████████████████████████████████████| 1000/1000 [02:06<00:00,  2.46it/s]


1.3961868312954901
epoch 7 start!


100%|██████████████████████████████████████| 1000/1000 [02:06<00:00,  2.39it/s]


1.2873744143664836
epoch 8 start!


100%|██████████████████████████████████████| 1000/1000 [02:04<00:00,  2.45it/s]


1.208361787635088
epoch 9 start!


100%|██████████████████████████████████████| 1000/1000 [02:06<00:00,  2.38it/s]


1.1034245105326177
training time 1218.9023954868317


In [5]:
babble = rnn.babble_mode(20)
print(babble)
time4 = time.time()
print("babble time",time4-time3)

facial fluttershy—both chest grimace trembling grimace cheerfulness glared "each "sorry glared tower's glared "each grimace glared tower's glared "each grimace 
babble time 8.899682760238647
