In [1]:
import torch
import torch.nn as nn
import numpy as np

## Data

In [2]:
with open('./data/aclImdb/imdb.vocab', 'r') as f:
    vocab = f.read().splitlines()

In [3]:
VOCAB_LEN = len(vocab)
NUM_CLASSES = 10
NUM_EXAMPLES = 10000

In [4]:
def one_hot_encode(n):
    arr = np.zeros(NUM_CLASSES)
    arr[int(n)-1] = 1
    return arr

In [5]:
def glove2dict(src_filename):
    data = {}
    with open(src_filename) as f:
        while True:
            try:
                line = next(f)
                line = line.strip().split()
                data[line[0]] = np.array(line[1: ], dtype=np.float)
            except StopIteration:
                break
            except UnicodeDecodeError:
                pass
    return data

In [6]:
import random

glove_dim = 300
GLOVE = glove2dict('./data/glove.6B/glove.6B.{}d.txt'.format(glove_dim))

def randvec(n=50, lower=-1.0, upper=1.0):
    """Returns a random vector of length `n`. `w` is ignored."""
    return np.array([random.uniform(lower, upper) for i in range(n)])

def glove_vec(w):
    """Return `w`'s GloVe representation if available, else return 
    a random vector."""
    return GLOVE.get(w, randvec(glove_dim))

def vec_average(u, v):
    """Averages np.array instances `u` and `v` into a new np.array"""
    return np.add(u, v) / 2

In [7]:
def bow_to_vec(features):
    all_word_vecs = []
    for f in features:
        i, c = f.split(':')    # index, count
        w = vocab[int(i)]      # get vocab word by index
        g = glove_vec(w)       # glove word embedding
        all_word_vecs.append(g)
    output = np.mean(all_word_vecs, axis=0)    
    return output

In [8]:
from collections import defaultdict

def get_data(filename, num_examples):
    with open(filename, 'r') as f:
        imdb = f.readlines()
    
    x_train, y_train = [], []
    label_count = defaultdict(int) # used to balance dataset
    for line in imdb:
        label, *features = line.split(' ')
        if label_count[label] >= NUM_EXAMPLES / NUM_CLASSES:
            continue
        x_train.append(bow_to_vec(features))
        y_train.append(int(label) - 1)
        label_count[label] += 1
    
    x_train = torch.tensor(x_train, dtype=torch.float)
    y_train = torch.tensor(y_train)
    return x_train, y_train

In [9]:
x_train_smol, y_train_smol = get_data('./data/aclImdb/train/labeledBow.feat', NUM_EXAMPLES)

## Model

In [16]:
n_in, n_h, n_out = glove_dim, 500, NUM_CLASSES

In [11]:
model = nn.Sequential(nn.Linear(n_in, n_h),
                     nn.ReLU(),
                     nn.Linear(n_h, n_out),
                     nn.Sigmoid())

In [12]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

In [13]:
for epoch in range(1000):
    y_pred = model(x_train_smol)
    loss = loss_fn(y_pred, y_train_smol)
    print('epoch: ', epoch, ' loss: ', loss.item())
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

epoch:  0  loss:  2.30368709564209
epoch:  1  loss:  2.2481884956359863
epoch:  2  loss:  2.2100918292999268
epoch:  3  loss:  2.1869142055511475
epoch:  4  loss:  2.1751456260681152
epoch:  5  loss:  2.1700966358184814
epoch:  6  loss:  2.1681807041168213
epoch:  7  loss:  2.1675074100494385
epoch:  8  loss:  2.1673009395599365
epoch:  9  loss:  2.1672537326812744
epoch:  10  loss:  2.1672582626342773
epoch:  11  loss:  2.1672534942626953
epoch:  12  loss:  2.167214870452881
epoch:  13  loss:  2.167201042175293
epoch:  14  loss:  2.1671903133392334
epoch:  15  loss:  2.1671650409698486
epoch:  16  loss:  2.167114734649658
epoch:  17  loss:  2.1670045852661133
epoch:  18  loss:  2.166792392730713
epoch:  19  loss:  2.166329860687256
epoch:  20  loss:  2.1653075218200684
epoch:  21  loss:  2.1632444858551025
epoch:  22  loss:  2.1628458499908447
epoch:  23  loss:  2.1624181270599365
epoch:  24  loss:  2.160593271255493
epoch:  25  loss:  2.160588026046753
epoch:  26  loss:  2.1600055694

epoch:  216  loss:  2.0214898586273193
epoch:  217  loss:  2.021256685256958
epoch:  218  loss:  2.021023988723755
epoch:  219  loss:  2.0207743644714355
epoch:  220  loss:  2.020538806915283
epoch:  221  loss:  2.0203049182891846
epoch:  222  loss:  2.02006459236145
epoch:  223  loss:  2.0198240280151367
epoch:  224  loss:  2.0195884704589844
epoch:  225  loss:  2.019348621368408
epoch:  226  loss:  2.01910662651062
epoch:  227  loss:  2.018873453140259
epoch:  228  loss:  2.0186309814453125
epoch:  229  loss:  2.0183963775634766
epoch:  230  loss:  2.018155336380005
epoch:  231  loss:  2.017915725708008
epoch:  232  loss:  2.017677068710327
epoch:  233  loss:  2.0174384117126465
epoch:  234  loss:  2.017197847366333
epoch:  235  loss:  2.0169618129730225
epoch:  236  loss:  2.016720771789551
epoch:  237  loss:  2.016474485397339
epoch:  238  loss:  2.0162360668182373
epoch:  239  loss:  2.015995502471924
epoch:  240  loss:  2.0157523155212402
epoch:  241  loss:  2.0155134201049805
ep

epoch:  429  loss:  1.9004926681518555
epoch:  430  loss:  1.8991914987564087
epoch:  431  loss:  1.8979350328445435
epoch:  432  loss:  1.8962249755859375
epoch:  433  loss:  1.8979884386062622
epoch:  434  loss:  1.895485520362854
epoch:  435  loss:  1.8933027982711792
epoch:  436  loss:  1.8925955295562744
epoch:  437  loss:  1.893183708190918
epoch:  438  loss:  1.8927505016326904
epoch:  439  loss:  1.8905471563339233
epoch:  440  loss:  1.8895410299301147
epoch:  441  loss:  1.890012502670288
epoch:  442  loss:  1.889082670211792
epoch:  443  loss:  1.887732744216919
epoch:  444  loss:  1.887952446937561
epoch:  445  loss:  1.8873064517974854
epoch:  446  loss:  1.8866162300109863
epoch:  447  loss:  1.8863874673843384
epoch:  448  loss:  1.8857078552246094
epoch:  449  loss:  1.885438323020935
epoch:  450  loss:  1.8848899602890015
epoch:  451  loss:  1.884441614151001
epoch:  452  loss:  1.884110689163208
epoch:  453  loss:  1.8836166858673096
epoch:  454  loss:  1.883188009262

epoch:  641  loss:  1.8200483322143555
epoch:  642  loss:  1.8193798065185547
epoch:  643  loss:  1.8187612295150757
epoch:  644  loss:  1.8182543516159058
epoch:  645  loss:  1.8179162740707397
epoch:  646  loss:  1.8175729513168335
epoch:  647  loss:  1.8172913789749146
epoch:  648  loss:  1.8169550895690918
epoch:  649  loss:  1.8167387247085571
epoch:  650  loss:  1.8165844678878784
epoch:  651  loss:  1.8165454864501953
epoch:  652  loss:  1.8166381120681763
epoch:  653  loss:  1.8164399862289429
epoch:  654  loss:  1.81620192527771
epoch:  655  loss:  1.815547227859497
epoch:  656  loss:  1.814903974533081
epoch:  657  loss:  1.8143013715744019
epoch:  658  loss:  1.8138360977172852
epoch:  659  loss:  1.8135465383529663
epoch:  660  loss:  1.8132963180541992
epoch:  661  loss:  1.8130933046340942
epoch:  662  loss:  1.8128070831298828
epoch:  663  loss:  1.8125303983688354
epoch:  664  loss:  1.8122327327728271
epoch:  665  loss:  1.8120161294937134
epoch:  666  loss:  1.8118987

epoch:  853  loss:  1.7585536241531372
epoch:  854  loss:  1.7583470344543457
epoch:  855  loss:  1.7582428455352783
epoch:  856  loss:  1.7581157684326172
epoch:  857  loss:  1.7582075595855713
epoch:  858  loss:  1.757878065109253
epoch:  859  loss:  1.7578171491622925
epoch:  860  loss:  1.757344126701355
epoch:  861  loss:  1.7571004629135132
epoch:  862  loss:  1.7569248676300049
epoch:  863  loss:  1.7566639184951782
epoch:  864  loss:  1.7564842700958252
epoch:  865  loss:  1.755971908569336
epoch:  866  loss:  1.7554975748062134
epoch:  867  loss:  1.754995346069336
epoch:  868  loss:  1.7545708417892456
epoch:  869  loss:  1.754249930381775
epoch:  870  loss:  1.7540241479873657
epoch:  871  loss:  1.7538537979125977
epoch:  872  loss:  1.753692388534546
epoch:  873  loss:  1.753547191619873
epoch:  874  loss:  1.7533327341079712
epoch:  875  loss:  1.7531607151031494
epoch:  876  loss:  1.752809762954712
epoch:  877  loss:  1.7525330781936646
epoch:  878  loss:  1.75221478939

## Testing

In [14]:
x_test_smol, y_test_smol = get_data('./data/aclImdb/test/labeledBow.feat', 200)

In [15]:
y_pred = model(x_test_smol)
labels_pred = torch.argmax(y_pred, 1)
correct = (labels_pred == y_test_smol).sum().item()
print('Accuracy: ' + str(correct / len(y_test_smol)))

Accuracy: 0.278875
