In [1]:
import torch
import torch.nn as nn
import numpy as np

## Data

In [5]:
with open('./data/aclImdb/imdb.vocab', 'r') as f:
    vocab = f.read().splitlines()

'the'

In [58]:
VOCAB_LEN = len(vocab)
NUM_CLASSES = 10
NUM_EXAMPLES = 2000

In [5]:
def one_hot_encode(n):
    arr = np.zeros(NUM_CLASSES)
    arr[int(n)-1] = 1
    return arr

In [18]:
def glove2dict(src_filename):
    data = {}
    with open(src_filename) as f:
        while True:
            try:
                line = next(f)
                line = line.strip().split()
                data[line[0]] = np.array(line[1: ], dtype=np.float)
            except StopIteration:
                break
            except UnicodeDecodeError:
                pass
    return data

In [68]:
import random

glove_dim = 300
GLOVE = glove2dict('./data/glove.6B/glove.6B.{}d.txt'.format(glove_dim))

def randvec(n=50, lower=-1.0, upper=1.0):
    """Returns a random vector of length `n`. `w` is ignored."""
    return np.array([random.uniform(lower, upper) for i in range(n)])

def glove_vec(w):
    """Return `w`'s GloVe representation if available, else return 
    a random vector."""
    return GLOVE.get(w, randvec(glove_dim))

def vec_average(u, v):
    """Averages np.array instances `u` and `v` into a new np.array"""
    return np.add(u, v) / 2

In [69]:
def bow_to_vec(features):
    all_word_vecs = []
    for f in features:
        i, c = f.split(':')    # index, count
        w = vocab[int(i)]      # get vocab word by index
        g = glove_vec(w)       # glove word embedding
        all_word_vecs.append(g)
    output = np.mean(all_word_vecs, axis=0)    
#     arr = np.zeros(VOCAB_LEN)
#     for f in features:
#         i, c = f.split(':') # index, count
#         arr[int(i)] = int(c)
    return output

In [70]:
from collections import defaultdict

def get_data(filename, num_examples):
    with open(filename, 'r') as f:
        imdb = f.readlines()
    
    x_train, y_train = [], []
    label_count = defaultdict(int) # used to balance dataset
    for line in imdb:
        label, *features = line.split(' ')
        if label_count[label] >= NUM_EXAMPLES / NUM_CLASSES:
            continue
        x_train.append(bow_to_vec(features))
        y_train.append(int(label) - 1)
        label_count[label] += 1
    
    x_train = torch.tensor(x_train, dtype=torch.float)
    y_train = torch.tensor(y_train)
    return x_train, y_train

In [71]:
x_train_smol, y_train_smol = get_data('./data/aclImdb/train/labeledBow.feat', NUM_EXAMPLES)

## Model

In [72]:
n_in, n_h, n_out = glove_dim, NUM_EXAMPLES, NUM_CLASSES

In [73]:
model = nn.Sequential(nn.Linear(n_in, n_h),
                     nn.ReLU(),
                     nn.Linear(n_h, n_out),
                     nn.Sigmoid())

In [74]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

In [None]:
for epoch in range(1000):
    y_pred = model(x_train_smol)
    loss = loss_fn(y_pred, y_train_smol)
    print('epoch: ', epoch, ' loss: ', loss.item())
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

epoch:  0  loss:  2.302448034286499
epoch:  1  loss:  2.2842350006103516
epoch:  2  loss:  2.2670977115631104
epoch:  3  loss:  2.250572919845581
epoch:  4  loss:  2.2347233295440674
epoch:  5  loss:  2.2199411392211914
epoch:  6  loss:  2.206761360168457
epoch:  7  loss:  2.195650100708008
epoch:  8  loss:  2.186852216720581
epoch:  9  loss:  2.1802971363067627
epoch:  10  loss:  2.175677537918091
epoch:  11  loss:  2.172576665878296
epoch:  12  loss:  2.1705641746520996
epoch:  13  loss:  2.1692888736724854
epoch:  14  loss:  2.1684906482696533
epoch:  15  loss:  2.1679940223693848
epoch:  16  loss:  2.1676816940307617
epoch:  17  loss:  2.167489528656006
epoch:  18  loss:  2.1673688888549805
epoch:  19  loss:  2.167290687561035
epoch:  20  loss:  2.1672394275665283
epoch:  21  loss:  2.1672046184539795
epoch:  22  loss:  2.167182683944702
epoch:  23  loss:  2.1671621799468994
epoch:  24  loss:  2.167142391204834
epoch:  25  loss:  2.1671228408813477
epoch:  26  loss:  2.167099714279

epoch:  218  loss:  1.9038013219833374
epoch:  219  loss:  1.9028489589691162
epoch:  220  loss:  1.9019027948379517
epoch:  221  loss:  1.9009655714035034
epoch:  222  loss:  1.9000545740127563
epoch:  223  loss:  1.899138331413269
epoch:  224  loss:  1.898209810256958
epoch:  225  loss:  1.8972816467285156
epoch:  226  loss:  1.896366000175476
epoch:  227  loss:  1.8954493999481201
epoch:  228  loss:  1.8945398330688477
epoch:  229  loss:  1.89363694190979
epoch:  230  loss:  1.8927361965179443
epoch:  231  loss:  1.8918300867080688
epoch:  232  loss:  1.8909207582473755
epoch:  233  loss:  1.8900200128555298
epoch:  234  loss:  1.8891302347183228
epoch:  235  loss:  1.8882352113723755
epoch:  236  loss:  1.8873395919799805
epoch:  237  loss:  1.8864452838897705
epoch:  238  loss:  1.8855538368225098
epoch:  239  loss:  1.8846606016159058
epoch:  240  loss:  1.8837765455245972
epoch:  241  loss:  1.8828963041305542
epoch:  242  loss:  1.8820136785507202
epoch:  243  loss:  1.88112795

epoch:  433  loss:  1.7594785690307617
epoch:  434  loss:  1.7590011358261108
epoch:  435  loss:  1.7585376501083374
epoch:  436  loss:  1.7580877542495728
epoch:  437  loss:  1.7576411962509155
epoch:  438  loss:  1.7572035789489746
epoch:  439  loss:  1.7567652463912964
epoch:  440  loss:  1.756325125694275
epoch:  441  loss:  1.7558822631835938
epoch:  442  loss:  1.755433201789856
epoch:  443  loss:  1.7549787759780884
epoch:  444  loss:  1.7545243501663208
epoch:  445  loss:  1.7540676593780518
epoch:  446  loss:  1.7536147832870483
epoch:  447  loss:  1.7531635761260986
epoch:  448  loss:  1.7527189254760742
epoch:  449  loss:  1.7522735595703125
epoch:  450  loss:  1.7518330812454224
epoch:  451  loss:  1.7513967752456665
epoch:  452  loss:  1.75095534324646
epoch:  453  loss:  1.7505203485488892
epoch:  454  loss:  1.750085473060608
epoch:  455  loss:  1.7496495246887207
epoch:  456  loss:  1.7492165565490723
epoch:  457  loss:  1.7487902641296387
epoch:  458  loss:  1.74836421

epoch:  645  loss:  1.6847797632217407
epoch:  646  loss:  1.6844990253448486
epoch:  647  loss:  1.684220790863037
epoch:  648  loss:  1.6839568614959717
epoch:  649  loss:  1.6837029457092285
epoch:  650  loss:  1.6834557056427002
epoch:  651  loss:  1.6832166910171509
epoch:  652  loss:  1.682981252670288
epoch:  653  loss:  1.6827497482299805
epoch:  654  loss:  1.6825120449066162
epoch:  655  loss:  1.6822677850723267
epoch:  656  loss:  1.6820124387741089
epoch:  657  loss:  1.6817502975463867
epoch:  658  loss:  1.681488037109375
epoch:  659  loss:  1.6812282800674438
epoch:  660  loss:  1.680972695350647
epoch:  661  loss:  1.68072509765625
epoch:  662  loss:  1.680483102798462
epoch:  663  loss:  1.680240511894226
epoch:  664  loss:  1.6800003051757812
epoch:  665  loss:  1.6797668933868408
epoch:  666  loss:  1.6795331239700317
epoch:  667  loss:  1.6793023347854614
epoch:  668  loss:  1.679073691368103
epoch:  669  loss:  1.6788409948349
epoch:  670  loss:  1.678603649139404

## Testing

In [66]:
x_test_smol, y_test_smol = get_data('./data/aclImdb/test/labeledBow.feat', 200)

In [67]:
y_pred = model(x_test_smol)
labels_pred = torch.argmax(y_pred, 1)
correct = (labels_pred == y_test_smol).sum().item()
print('Accuracy: ' + str(correct / len(y_test_smol)))

Accuracy: 0.23875
