In [1]:
import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline
import re
import pandas as pd
import seaborn as sns
import datetime

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

In [2]:
ds = pd.read_csv('./data/trump_tw.csv')
ds.text.head()

  interactivity=interactivity, compiler=compiler, result=result)


0    Read a great interview with Donald Trump that ...
1    Congratulations to Evan Lysacek for being nomi...
2    I was on The View this morning. We talked abou...
3    Tomorrow night's episode of The Apprentice del...
4    Donald Trump Partners with TV1 on New Reality ...
Name: text, dtype: object

In [3]:
# Take all the text together

data = ' '.join([ix for ix in ds.text])
print data[:1000]

Read a great interview with Donald Trump that appeared in The New York Times Magazine: http://tinyurl.com/qsx4o6 Congratulations to Evan Lysacek for being nominated SI sportsman of the year. He's a great guy, and he has my vote!  #EvanForSI I was on The View this morning. We talked about The Apprentice. Tonight's episode is a great one--tough, exciting and surprising. 10 pm/NBC Tomorrow night's episode of The Apprentice delivers excitement at QVC along with appearances by Isaac Mizrahi and Cathie Black. 10 pm on NBC Donald Trump Partners with TV1 on New Reality Series Entitled, Omarosa's Ultimate Merger: http://tinyurl.com/yk5m3lc I'll be appearing on Larry King Live for his final show, Thursday night at 9 p.m., CNN. Larry's been on TV for 25 years... I'll be on The Late Show with David Letterman tonight--be sure to tune in for a great show. 11:30 pm on CBS. Watch the Miss Universe competition LIVE from the Bahamas - Sunday, 8/23 @ 9pm (ET) on NBC: http://tinyurl.com/mrzad9 Watch video

In [4]:
print set(data)
print len(set(data))

set(['\x83', '\x87', '\x8b', '\x8f', '\x93', '\x97', '\x9b', '\x9f', ' ', '\xa3', '$', '\xa7', '(', '\xab', ',', '\xaf', '0', '\xb3', '4', '\xb7', '8', '\xbb', '\xbf', '@', '\xc3', 'D', 'H', 'L', 'P', 'T', '\xd7', 'X', '\\', '`', '\xe3', 'd', '\xe7', 'h', 'l', '\xef', 'p', 't', 'x', '|', '\x80', '\x84', '\x88', '\x8c', '\x90', '\x94', '\x98', '\x9c', '\xa0', '#', '\xa4', "'", '\xa8', '+', '\xac', '/', '\xb0', '3', '\xb4', '7', '\xb8', ';', '\xbc', '?', 'C', '\xc4', 'G', 'K', 'O', 'S', 'W', '[', '_', 'c', '\xe4', 'g', '\xe8', 'k', 'o', '\xf0', 's', '\xf4', 'w', '{', '\x81', '\x85', '\x89', '\n', '\x8d', '\x91', '\x95', '\x99', '\x9d', '\xa1', '"', '\xa5', '&', '\xa9', '*', '\xad', '.', '\xb1', '2', '\xb5', '6', '\xb9', ':', '\xbd', 'B', '\xc5', 'F', 'J', 'N', 'R', 'V', 'Z', '\xe1', 'b', '\xe5', 'f', '\xe9', 'j', 'n', 'r', 'v', 'z', '~', '\x82', '\x86', '\x8a', '\r', '\x8e', '\x92', '\x96', '\x9a', '\x9e', '!', '\xa2', '%', '\xa6', ')', '\xaa', '-', '\xae', '1', '\xb2', '5', '9', '\xba',

In [5]:
# Create Vocab
vocab = list(set(data))

i2c, c2i = {}, {}

for idx, chx in enumerate(vocab):
    i2c[idx] = chx
    c2i[chx] = idx

In [6]:
def get_onehot(x):
    # Take input a string and convert to one-hot encoding
    vec_size = len(c2i.keys())
    n_seq = len(x)
    data = np.zeros((1, n_seq, vec_size))
    
    # For each element in the list
    for ix in range(n_seq):
        curr_char = x[ix]
        oh_index = c2i[curr_char]
        data[:, ix, oh_index] = 1
    return data

print get_onehot('this is my string').shape

(1, 17, 174)


In [7]:
for ix in ds.text[:10]:
    print get_onehot(ix).shape

(1, 112, 174)
(1, 127, 174)
(1, 139, 174)
(1, 140, 174)
(1, 116, 174)
(1, 122, 174)
(1, 108, 174)
(1, 117, 174)
(1, 115, 174)
(1, 102, 174)


In [8]:
class CharNN(nn.Module):
    def __init__(self, in_shape=None, out_shape=None, hidden_shape=None):
        super(CharNN, self).__init__()
        self.in_shape = in_shape
        self.out_shape = out_shape
        self.hidden_shape = hidden_shape
        self.n_layers = 1
        
        self.rnn = nn.LSTM(
            input_size=self.in_shape,
            hidden_size=self.hidden_shape,
            num_layers=self.n_layers,
            batch_first=True
        )
        self.out = nn.Linear(self.hidden_shape, self.out_shape)
    
    def forward(self, x, h):
        r_out, h_state = self.rnn(x, h)
        
        outs = []
        for ix in range(r_out.size(1)):
            current_out = F.softmax(self.out(r_out[:, ix, :]))
            outs.append(current_out)
        return torch.stack(outs, dim=1), h_state
    
    def predict(self, char, h=None, top_k=None):
        if h is None:
            h = self.init_hidden(1, gpu=False)
        
        x = get_onehot(char)
        out, h = self.forward(torch.FloatTensor(x), h)
        
        p = out.data
        if top_k is None:
            top_ch = np.arange(self.out_shape)
        else:
            p, top_ch = p.topk(top_k)
            top_ch = top_ch.numpy().squeeze()
        
        p = p.numpy().squeeze()
        char = np.random.choice(top_ch, p=p/p.sum())
        return i2c[char], h
    
    def init_hidden(self, batch_size, gpu=True):
        if gpu:
            return (Variable(torch.zeros(self.n_layers, batch_size, self.hidden_shape).cuda()),
                    Variable(torch.zeros(self.n_layers, batch_size, self.hidden_shape)).cuda())
        return (Variable(torch.zeros(self.n_layers, batch_size, self.hidden_shape)),
                Variable(torch.zeros(self.n_layers, batch_size, self.hidden_shape)))

In [9]:
model = CharNN(in_shape=174, out_shape=174, hidden_shape=256)
model.cuda()
# print model

CharNN(
  (rnn): LSTM(174, 128, batch_first=True)
  (out): Linear(in_features=128, out_features=174, bias=True)
)

In [10]:
# model.predict('a', top_k=20)[0]

In [11]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [12]:
# Set to train mode
model.cuda()
model.train()
N = 5000

for epoch in range(50):
    total_loss = 0
    # For each sequence
    for qx in range(N):
        seqx = ds.text[qx]
        h_state = model.init_hidden(1)
        input_seq = seqx[:-1]
        target_seq = seqx[1:]
        
        x = Variable(torch.FloatTensor(get_onehot(input_seq)), requires_grad=True).cuda()
        y = Variable(torch.LongTensor(get_onehot(target_seq).argmax(2))).cuda()
        
        model.zero_grad()
        pred, h_state = model.forward(x, h_state)
        # print pred.squeeze().shape, y.shape
        loss = criterion(pred.squeeze(), y.squeeze())
        
        # optimizer.zero_grad()
        loss.backward()
        
        # gradient clipping to solve exploding/vanishing grads
        # clip = 5.0
        # nn.utils.clip_grad_norm(net.parameters(), clip)
        
        optimizer.step()
        total_loss += loss
        if qx%(N/5) == 0:
            print 'Loss: {} at Epoch: {} | Seq: {}'.format(loss, epoch, qx)
        
    print "Overall Average Loss: {} at Epoch: {}".format(total_loss / float(N), epoch)
    
    # Save model checkpoints
    if epoch % 10 == 0:
        torch.save(model.state_dict(), "./data/checkpoints/text_gen/model_256h_epoch_{}.ckpt".format(epoch))



Loss: 5.15909719467 at Epoch: 0 | Seq: 0
Loss: 4.97706508636 at Epoch: 0 | Seq: 1000
Loss: 4.99024248123 at Epoch: 0 | Seq: 2000
Loss: 5.02281427383 at Epoch: 0 | Seq: 3000
Loss: 4.96628522873 at Epoch: 0 | Seq: 4000
Overall Average Loss: 5.01461648941 at Epoch: 0
Loss: 5.00101280212 at Epoch: 1 | Seq: 0
Loss: 4.91032266617 at Epoch: 1 | Seq: 1000
Loss: 4.90067529678 at Epoch: 1 | Seq: 2000
Loss: 4.94082164764 at Epoch: 1 | Seq: 3000
Loss: 4.94284772873 at Epoch: 1 | Seq: 4000
Overall Average Loss: 4.9568529129 at Epoch: 1
Loss: 4.94118452072 at Epoch: 2 | Seq: 0
Loss: 4.93541955948 at Epoch: 2 | Seq: 1000
Loss: 4.89382219315 at Epoch: 2 | Seq: 2000
Loss: 4.93845129013 at Epoch: 2 | Seq: 3000
Loss: 4.93059587479 at Epoch: 2 | Seq: 4000
Overall Average Loss: 4.94334697723 at Epoch: 2
Loss: 4.90147352219 at Epoch: 3 | Seq: 0
Loss: 4.89795398712 at Epoch: 3 | Seq: 1000
Loss: 4.84971761703 at Epoch: 3 | Seq: 2000
Loss: 4.89255046844 at Epoch: 3 | Seq: 3000
Loss: 4.8633184433 at Epoch: 3 | 

KeyboardInterrupt: 

In [193]:
sentence = 'p'
model.cpu()
for ix in range(1000):
    ctx = sentence[-1]
    out = model.predict(ctx, top_k=10)[0]
    
    sentence += out
print sentence



pe the the the the the the the he the the the there the the he the he the the the he the he the the the the the the the the the the the he the the the the the the the the the the the the the the the the the the he the the the he the the the the the the the the the the the the the the the the he the the the the the the the the the the the there the the the the the there the the the the the the the the the the the the the he the he the the the the the the the the the he the the the the the the the the the the the the the the the the the the the the he he the he the the the thanonononononone the the the he the thanononononone the the he the the the the he he the the the the the the the the the the the the the there the the thanonononononononononononone he he the the the the the the the the the the there the the the the the the the the thanononononononononone the the the the the the the the the the the the the the the thanonononononononone the the the the the the the the the the the he the