In [1]:
import torch
import torch.nn as nn
import numpy as np

## Data

In [3]:
with open('./data/aclImdb/imdb.vocab', 'r') as f:
    VOCAB_LEN = len(f.readlines())
VOCAB_LEN

89527

In [58]:
NUM_CLASSES = 10
NUM_EXAMPLES = 1000

In [5]:
def one_hot_encode(n):
    arr = np.zeros(NUM_CLASSES)
    arr[int(n)-1] = 1
    return arr

In [6]:
def bow_to_vec(features):
    arr = np.zeros(VOCAB_LEN)
    for f in features:
        i, c = f.split(':') # index, count
        arr[int(i)] = int(c)
    return arr

In [43]:
from collections import defaultdict

def get_data(filename, num_examples):
    with open(filename, 'r') as f:
        imdb = f.readlines()
    
    x_train, y_train = [], []
    label_count = defaultdict(int) # used to balance dataset
    for line in imdb:
        label, *features = line.split(' ')
        if label_count[label] >= NUM_EXAMPLES / NUM_CLASSES:
            continue
        x_train.append(bow_to_vec(features))
        y_train.append(int(label) - 1)
        label_count[label] += 1
    
    x_train = torch.tensor(x_train, dtype=torch.float)
    y_train = torch.tensor(y_train)
    return x_train, y_train

In [59]:
x_train_smol, y_train_smol = get_data('./data/aclImdb/train/labeledBow.feat', NUM_EXAMPLES)

## Model

In [60]:
n_in, n_h, n_out = VOCAB_LEN, NUM_EXAMPLES, NUM_CLASSES

In [61]:
model = nn.Sequential(nn.Linear(n_in, n_h),
                     nn.ReLU(),
                     nn.Linear(n_h, n_out),
                     nn.Sigmoid())

In [62]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

In [63]:
for epoch in range(150):
    y_pred = model(x_train_smol)
    loss = loss_fn(y_pred, y_train_smol)
    print('epoch: ', epoch, ' loss: ', loss.item())
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

epoch:  0  loss:  2.301706552505493
epoch:  1  loss:  2.219895124435425
epoch:  2  loss:  2.1454620361328125
epoch:  3  loss:  2.0843961238861084
epoch:  4  loss:  2.0191826820373535
epoch:  5  loss:  1.9466955661773682
epoch:  6  loss:  1.8802120685577393
epoch:  7  loss:  1.8284648656845093
epoch:  8  loss:  1.7799153327941895
epoch:  9  loss:  1.7327296733856201
epoch:  10  loss:  1.6932847499847412
epoch:  11  loss:  1.660703182220459
epoch:  12  loss:  1.6323150396347046
epoch:  13  loss:  1.6071144342422485
epoch:  14  loss:  1.585159420967102
epoch:  15  loss:  1.5665827989578247
epoch:  16  loss:  1.551084041595459
epoch:  17  loss:  1.5380465984344482
epoch:  18  loss:  1.5269041061401367
epoch:  19  loss:  1.5173335075378418
epoch:  20  loss:  1.5092025995254517
epoch:  21  loss:  1.5023995637893677
epoch:  22  loss:  1.4967471361160278
epoch:  23  loss:  1.4920393228530884
epoch:  24  loss:  1.4880836009979248
epoch:  25  loss:  1.4847332239151
epoch:  26  loss:  1.481876373

## Testing

In [64]:
x_test_smol, y_test_smol = get_data('./data/aclImdb/test/labeledBow.feat', 200)

In [65]:
y_pred = model(x_test_smol)
labels_pred = torch.argmax(y_pred, 1)
correct = (labels_pred == y_test_smol).sum().item()
print('Accuracy: ' + str(correct / len(y_test_smol)))

Accuracy: 0.2375
