Skip to content

Commit

Permalink
Merge pull request #55 from carriepl/lstm_tutorial
Browse files Browse the repository at this point in the history
Add basic LSTM tutorial
  • Loading branch information
nouiz committed Dec 20, 2014
2 parents 30f10a9 + 4bb23cf commit fd1d392
Show file tree
Hide file tree
Showing 5 changed files with 790 additions and 0 deletions.
92 changes: 92 additions & 0 deletions code/imdb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import cPickle
import gzip
import os
import sys
import time

import numpy

import theano
import theano.tensor as T


def prepare_data(seqs, labels, maxlen=None):
# x: a list of sentences
lengths = [len(s) for s in seqs]

if maxlen is not None:
new_seqs = []
new_labels = []
new_lengths = []
for l, s, y in zip(lengths, seqs, labels):
if l < maxlen:
new_seqs.append(s)
new_labels.append(y)
new_lengths.append(l)
lengths = new_lengths
labels = new_labels
seqs = new_seqs

if len(lengths) < 1:
return None, None, None

n_samples = len(seqs)
maxlen = numpy.max(lengths)

x = numpy.zeros((maxlen, n_samples)).astype('int64')
x_mask = numpy.zeros((maxlen, n_samples)).astype('float32')
for idx, s in enumerate(seqs):
x[:lengths[idx], idx] = s
x_mask[:lengths[idx], idx] = 1.

return x, x_mask, labels


def load_data(path="imdb.pkl", n_words=100000, valid_portion=0.1):
''' Loads the dataset
:type dataset: string
:param dataset: the path to the dataset (here IMDB)
'''

#############
# LOAD DATA #
#############

print '... loading data'

# Load the dataset
f = open(path, 'rb')
train_set = cPickle.load(f)
test_set = cPickle.load(f)
f.close()

# split training set into validation set
train_set_x, train_set_y = train_set
n_samples = len(train_set_x)
sidx = numpy.random.permutation(n_samples)
n_train = int(numpy.round(n_samples * (1. - valid_portion)))
valid_set_x = [train_set_x[s] for s in sidx[n_train:]]
valid_set_y = [train_set_y[s] for s in sidx[n_train:]]
train_set_x = [train_set_x[s] for s in sidx[:n_train]]
train_set_y = [train_set_y[s] for s in sidx[:n_train]]

train_set = (train_set_x, train_set_y)
valid_set = (valid_set_x, valid_set_y)

def remove_unk(x):
return [[1 if w >= n_words else w for w in sen] for sen in x]

test_set_x, test_set_y = test_set
valid_set_x, valid_set_y = valid_set
train_set_x, train_set_y = train_set

train_set_x = remove_unk(train_set_x)
valid_set_x = remove_unk(valid_set_x)
test_set_x = remove_unk(test_set_x)

train = (train_set_x, train_set_y)
valid = (valid_set_x, valid_set_y)
test = (test_set_x, test_set_y)

return train, valid, test
Loading

0 comments on commit fd1d392

Please sign in to comment.