In [2]:
%matplotlib inline
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import torch.utils.data as data
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
import random

# Only use the GPU when CUDA is actually available, so the notebook also runs
# on CPU-only machines instead of failing at the first `.cuda(...)` call.
use_cuda = torch.cuda.is_available()
# Preferred GPU index; fall back to device 0 when this machine exposes fewer
# than four GPUs (the value is ignored when use_cuda is False).
device_id = 3 if torch.cuda.device_count() > 3 else 0
from tensorflow.contrib.keras.python.keras.datasets.imdb import load_data, get_word_index

# Passed to load_data(num_words=...) and used as the classifier's embedding
# table size (first constructor argument of IMDB_Classifier).
max_features = 5000
# Batch size handed to IMDB_Classifier when the model is constructed.
batch_size = 32
# Epoch count; the training cell divides the current epoch by this value when
# estimating progress / remaining time.
epochs = 15
# Learning-rate setting for the optimizer.
# NOTE(review): confirm the training cell reads this value rather than a literal.
learning_rate = 0.001

In [3]:
import time
import math


def asMinutes(s):
    """Format a duration in seconds as a ``'<minutes>m <seconds>s'`` string.

    Args:
        s: Duration in seconds (int or float).

    Returns:
        str: Whole minutes and the leftover seconds, both rendered with ``%d``.
    """
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, leftover_seconds)


def timeSince(since, percent):
    """Report elapsed time and a linear estimate of the time remaining.

    Args:
        since: Start timestamp as returned by ``time.time()``.
        percent: Fraction of the work completed so far; must be non-zero
            because the elapsed time is divided by it.

    Returns:
        str: ``'<elapsed> (- <remaining>)'`` with both parts formatted by
        ``asMinutes``.
    """
    elapsed = time.time() - since
    # Extrapolate the total run time from the fraction already completed.
    projected_total = elapsed / percent
    remaining = projected_total - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))

def showPlot(points):
    """Plot a sequence of values as a line chart with y-axis ticks every 0.2.

    Args:
        points: Sequence of numbers to plot against their index.
    """
    # Imported locally because the visible notebook never brings `plt` or
    # `ticker` into scope, so the function previously failed with a NameError.
    import matplotlib.pyplot as plt
    import matplotlib.ticker as ticker

    # plt.subplots() already creates the figure; a preceding plt.figure()
    # would only leave an extra, empty figure behind.
    fig, ax = plt.subplots()
    # This locator puts ticks at regular intervals.
    ax.yaxis.set_major_locator(ticker.MultipleLocator(base=0.2))
    ax.plot(points)

In [4]:
# Load the IMDB train/test splits. `max_features` is forwarded as `num_words`
# (presumably capping the vocabulary to the most frequent words -- confirm
# against the Keras `imdb.load_data` documentation).
(x_train, y_train), (x_test, y_test) = load_data(num_words=max_features)


In [5]:
def pad(tensor, length):
    """Zero-pad ``tensor`` along its first dimension up to ``length`` entries.

    Args:
        tensor: Tensor to extend; its trailing dimensions are kept as they are.
        length: Target size of the first dimension.

    Returns:
        A new tensor of the same type: ``tensor`` followed by zero rows.
    """
    missing_rows = length - tensor.size(0)
    trailing_dims = tensor.size()[1:]
    # `new` allocates an uninitialised tensor of the same type, hence zero_().
    filler = tensor.new(missing_rows, *trailing_dims).zero_()
    return torch.cat([tensor, filler])

In [6]:
def sortedText(idx, xs, ys):
    """Select a batch and order it from the longest sequence to the shortest.

    Args:
        idx: Indices of the samples to select (used to index ``xs`` and ``ys``).
        xs: Indexable collection of sequences (numpy-style fancy indexing).
        ys: Labels aligned with ``xs``.

    Returns:
        tuple: ``(sequences, lengths, labels)``, all reordered by decreasing
        sequence length.
    """
    selected_xs, selected_ys = xs[idx], ys[idx]
    lengths = np.array([len(seq) for seq in selected_xs])
    # argsort is ascending; reversing it yields longest-first order.
    longest_first = np.argsort(lengths)[::-1]
    return selected_xs[longest_first], lengths[longest_first], selected_ys[longest_first]


def textTensor(idx, xs, ys):
    """Build a zero-padded token tensor for a batch, longest sequence first.

    Args:
        idx: Indices of the samples that make up the batch.
        xs: Collection of token-id sequences.
        ys: Labels aligned with ``xs``.

    Returns:
        tuple: ``(tokens, lengths, labels)`` where ``tokens`` is a LongTensor
        of shape ``(max_length, batch)`` with one column per sequence,
        ``lengths`` is the list of sequence lengths in decreasing order and
        ``labels`` is a FloatTensor in the same order.
    """
    batch_xs, lengths, batch_ys = sortedText(idx, xs, ys)
    longest = lengths[0]  # the batch is sorted, so the first entry is the max
    columns = []
    for seq in batch_xs:
        # Each sequence becomes one (longest, 1) column of the batch matrix.
        columns.append(pad(torch.Tensor(seq), longest).view(longest, 1))
    tokens = torch.cat(columns, 1).long()
    return tokens, list(lengths), torch.FloatTensor(batch_ys)


In [67]:
class IMDB_Classifier(nn.Module):
    """Bidirectional-LSTM binary classifier over padded batches of token ids.

    Token ids are embedded and fed through a bidirectional LSTM. The final
    forward and backward hidden states of the last layer are concatenated,
    passed through dropout and a linear layer, and squashed with a sigmoid.

    Args:
        input_size: Number of rows in the embedding table.
        hidden_size: Embedding width; each LSTM direction uses
            ``hidden_size // 2`` units.
        batch_size: Batch size used by ``initHidden``.
        n_layers: Number of stacked LSTM layers (default 1).
    """

    def __init__(self, input_size, hidden_size, batch_size, n_layers=1):
        super(IMDB_Classifier, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        # Floor division keeps the size an int on Python 3 as well; plain `/`
        # yields a float there, which nn.LSTM rejects.
        self.direction_size = hidden_size // 2

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.dropout = nn.Dropout(0.5)
        # n_layers is now forwarded; it used to be stored but ignored.
        self.lstm = nn.LSTM(hidden_size, self.direction_size,
                            num_layers=n_layers, bidirectional=True)
        # Sized from the two concatenated directions so that an odd
        # hidden_size cannot produce a shape mismatch.
        self.dense = nn.Linear(2 * self.direction_size, 1)

    def forward(self, word_input, input_length, hidden=None):
        """Return one sigmoid score per sequence in the batch.

        Args:
            word_input: Token ids laid out as ``(seq_len, batch)``.
            input_length: Sequence lengths in decreasing order, as required
                by ``pack_padded_sequence``.
            hidden: Optional ``(h_0, c_0)`` initial LSTM state.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in ``(0, 1)``.
        """
        output = self.embedding(word_input)
        output = pack_padded_sequence(output, input_length)
        output, hidden = self.lstm(output, hidden)
        # hidden[0] is h_n, ordered layer by layer with forward then backward
        # direction, so the last two entries belong to the top layer.
        final_states = hidden[0]
        output = torch.cat([final_states[-2], final_states[-1]], 1)
        output = self.dropout(output)
        output = F.sigmoid(self.dense(output))

        return output

    def initHidden(self):
        """Return a zero ``(hidden, cell)`` state accepted by ``forward``.

        Both tensors have shape ``(2 * n_layers, batch_size, hidden_size // 2)``
        and are created with the same type and device as the model's
        parameters.
        """
        # One state per (layer, direction) pair of the bidirectional LSTM.
        num_states = 2 * self.n_layers
        # `new` allocates with the parameters' type/device, so no global
        # CUDA flags are needed here.
        reference = next(self.parameters()).data
        shape = (num_states, self.batch_size, self.direction_size)
        hidden = Variable(reference.new(*shape).zero_())
        cell = Variable(reference.new(*shape).zero_())
        return (hidden, cell)

In [68]:
def get_last_step_indices(lengths):
    """Locate the last time step of every sequence inside packed data.

    A packed sequence stores its rows time step by time step: first step 0 of
    every sequence, then step 1 of every sequence that is still running, and
    so on. The row holding the last step of sequence ``j`` is therefore the
    number of rows emitted before step ``lengths[j] - 1`` plus ``j``.

    The previous closed-form expression was off by one for every later
    sequence of equal length (e.g. ``[2, 2]`` gave ``[3, 3]`` instead of
    ``[2, 3]``).

    Args:
        lengths: Sequence lengths in decreasing order (the order required by
            ``pack_padded_sequence``), each at least 1.

    Returns:
        torch.LongTensor: One row index per sequence, in the order of
        ``lengths``.

    Raises:
        ValueError: If a length is smaller than 1 or the lengths are not
            sorted in decreasing order.
    """
    lengths = [int(length) for length in lengths]
    if not lengths:
        return torch.LongTensor([])
    if lengths[-1] < 1 or any(a < b for a, b in zip(lengths, lengths[1:])):
        raise ValueError(
            'lengths must be positive and sorted in decreasing order')

    # step_offsets[t] = number of packed rows that precede time step t.
    step_offsets = []
    rows_so_far = 0
    active = len(lengths)  # sequences still running at the current step
    for step in range(lengths[0]):
        # Drop sequences that have already ended; lengths[0] > step, so at
        # least one sequence always remains active.
        while lengths[active - 1] <= step:
            active -= 1
        step_offsets.append(rows_so_far)
        rows_so_far += active

    return torch.LongTensor([step_offsets[length - 1] + position
                             for position, length in enumerate(lengths)])


def get_last_step_tensor(packed_sequence, lengths):
    """Gather the last-time-step row of every sequence from a packed batch.

    Args:
        packed_sequence: Packed batch whose ``data`` attribute holds the rows.
        lengths: Sequence lengths, forwarded to ``get_last_step_indices``.

    Returns:
        The rows of ``packed_sequence.data`` selected along dimension 0, one
        per sequence.
    """
    positions = Variable(torch.LongTensor(get_last_step_indices(lengths)))
    packed_rows = packed_sequence.data
    raw_rows = packed_rows.data
    if raw_rows.is_cuda:
        # The index tensor has to live on the same GPU as the packed data.
        positions = positions.cuda(raw_rows.get_device())
    return packed_rows.index_select(0, positions)

In [69]:
def test():
    """Smoke-test the classifier on one random training batch.

    Builds a tiny model, runs a single forward pass and prints the input
    size, the output size and the binary cross-entropy loss. Batch indices
    are sampled directly with numpy, so no separately defined data loader is
    needed.
    """
    # The previous version pulled indices from `train_loader`, which is not
    # defined in the visible notebook, via the Python-2-only `.next()`.
    idx = np.random.permutation(len(x_train))[:batch_size]
    x, lengths, y = textTensor(idx, x_train, y_train)
    x, y = Variable(x), Variable(y)
    print('input_batches', x.size())
    model = IMDB_Classifier(max_features, 2, batch_size)
    if use_cuda:
        # Pin model and data to the configured GPU rather than whatever the
        # current default device happens to be.
        model = model.cuda(device_id)
        x, y = x.cuda(device_id), y.cuda(device_id)
    output = model(x, lengths)
    print('output_batches', output.size())
    criterion = nn.BCELoss()

    # The model returns (batch, 1) while the labels are (batch,); flatten so
    # both arguments of the loss have the same shape.
    print(criterion(output.view(-1), y))

In [70]:
# Train the classifier and report held-out accuracy after every epoch.
clf = IMDB_Classifier(max_features, 128, batch_size)
clf = clf.cuda(device_id) if use_cuda else clf
# Use the configured learning rate instead of a repeated literal.
optimizer = optim.Adam(clf.parameters(), lr=learning_rate)
criterion = nn.BCELoss()

n_train = len(x_train)
n_test = len(x_test)
# Floor division keeps the batch count an int on Python 3; the trailing
# partial batch is dropped, exactly as before.
n_batches = n_train // batch_size
eval_batch_size = 100

start = time.time()
# Loop over the configured number of epochs so the ETA printed below (which
# divides by `epochs`) matches the actual length of the run.
for epoch in range(1, epochs + 1):
    clf.train()  # re-enable dropout after the previous evaluation pass
    losses = 0
    indices = np.random.permutation(n_train)
    for i in range(1, n_batches + 1):
        x, lengths, y = textTensor(indices[(i-1)*batch_size:i*batch_size], x_train, y_train)
        x, y = Variable(x), Variable(y)

        (x, y) = (x.cuda(device_id), y.cuda(device_id)) if use_cuda else (x, y)

        # The model returns (batch, 1); flatten to (batch,) to match the labels.
        output = clf(x, lengths).view(-1)
        loss = criterion(output, y)
        # float(...sum()) reads the scalar loss whether it is stored as a
        # one-element or a zero-dimensional tensor.
        losses += float(loss.data.sum())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 10 == 0:
            print("batch: {}, batch_loss: {}".format(i, losses / i))

    print("Epoch: {}, time: {}, loss: {}".format(epoch, timeSince(start, float(epoch) / epochs), losses / n_batches))

    # Evaluate on the held-out test split with dropout disabled. The earlier
    # code scored the training data, kept dropout active and still reported
    # the test-set size.
    clf.eval()
    total = 0
    correct = 0
    for first in range(0, n_test, eval_batch_size):
        last = min(first + eval_batch_size, n_test)
        x, lengths, y = textTensor(np.arange(first, last), x_test, y_test)
        x, y = Variable(x, volatile=True), Variable(y)
        (x, y) = (x.cuda(device_id), y.cuda(device_id)) if use_cuda else (x, y)
        # Flatten before comparing: (batch, 1) against (batch,) would
        # broadcast to a (batch, batch) matrix and corrupt the count.
        predicted = clf(x, lengths).view(-1) > 0.5
        correct += int((predicted.float() == y).data.sum())
        total += y.size(0)
    print('Accuracy of the network on the {} texts: {} %'.format(total, 100. * correct / total))

batch: 10, batch_loss: 0.695220530033
batch: 20, batch_loss: 0.692368993163
batch: 30, batch_loss: 0.690876714389
batch: 40, batch_loss: 0.690659186244
batch: 50, batch_loss: 0.689180973768
batch: 60, batch_loss: 0.6895878613
batch: 70, batch_loss: 0.689470046759
batch: 80, batch_loss: 0.688695345074
batch: 90, batch_loss: 0.688032517831
batch: 100, batch_loss: 0.687166379094
batch: 110, batch_loss: 0.686698469791
batch: 120, batch_loss: 0.685625120004
batch: 130, batch_loss: 0.684387684785
batch: 140, batch_loss: 0.683821835262
batch: 150, batch_loss: 0.683342698018
batch: 160, batch_loss: 0.682021187246
batch: 170, batch_loss: 0.68045007902
batch: 180, batch_loss: 0.679582681921
batch: 190, batch_loss: 0.679153075657
batch: 200, batch_loss: 0.677671817839
batch: 210, batch_loss: 0.676283865316
batch: 220, batch_loss: 0.675144485994
batch: 230, batch_loss: 0.674321341774
batch: 240, batch_loss: 0.671996882061
batch: 250, batch_loss: 0.669673966169
batch: 260, batch_loss: 0.66771438523

KeyboardInterrupt: 