In [3]:
from process_tweets import get_data
from rnn import MyModel
from train_validate import (train, validate)

import copy
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

In [8]:
# hyperparameters / configuration
SEED = 0
train_batch_size = 250
test_batch_size = 250
embedding_size = 12
hidden_size = 64
num_layers = 1
dropout = 0.1
num_classes = 2
bidirectional = True
EPOCH = 20
train_file = "../data/train_en.tsv"
test_file = "../data/dev_en.tsv"
device = "cpu"

# Seed torch so weight initialisation and DataLoader shuffling are
# reproducible under Restart & Run All.
torch.manual_seed(SEED)

# get_data (project helper) yields the train/dev datasets plus the alphabet and
# longest sentence length used below to size the model; train_x / test_x are
# not used in this cell.
train_dataset, test_dataset, alphabet, longest_sent, train_x, test_x = get_data(train_file, test_file)
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
# Evaluation order does not change accuracy, so keep validation deterministic.
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

model = MyModel(
    len(alphabet),
    longest_sent,
    embedding_size,
    hidden_size,
    num_layers,
    dropout,
    num_classes,
    bidirectional,
    device,
)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters())

## training loop
# Keep the best validation accuracy seen so far, its confusion matrix, and a
# snapshot of the model from that epoch.
best = 0.0
best_cm = None
best_model = None
train_acc_epoch = []
valid_acc_epoch = []
for epoch in range(EPOCH):
    # train loop
    train_acc, train_cm = train(epoch, train_loader, model, optimizer, criterion)
    train_acc_epoch.append(train_acc.detach().cpu().numpy())

    # validation loop
    valid_acc, valid_cm = validate(epoch, test_loader, model, criterion)
    valid_acc_epoch.append(valid_acc.detach().cpu().numpy())

    if valid_acc > best:
        best = valid_acc
        best_cm = valid_cm
        # deepcopy so later training epochs don't mutate the saved snapshot
        best_model = copy.deepcopy(model)

print('Best Prec @1 Accuracy: {:.4f}'.format(best))
if best_cm is None:
    # No epoch beat the initial 0.0 (e.g. EPOCH == 0), so there is no matrix to report.
    print("No validation epoch improved on the initial best; no per-class accuracy available.")
else:
    # NOTE(review): reading the diagonal as per-class accuracy assumes validate()
    # returns a row-normalised confusion matrix -- confirm in train_validate.
    per_cls_acc = best_cm.diag().detach().cpu().numpy().tolist()
    for i, acc_i in enumerate(per_cls_acc):
        print("Accuracy of Class {}: {:.4f}".format(i, acc_i))

# Learning curves; the x-axis follows the number of epochs actually recorded.
epochs_run = range(len(train_acc_epoch))
fig, ax = plt.subplots()
ax.plot(epochs_run, train_acc_epoch, label='train')
ax.plot(epochs_run, valid_acc_epoch, label='validation')
ax.legend()
ax.set(title="accuracy curve", xlabel='epoch', ylabel='accuracy')
plt.show()

first 3 raw tweets: ['Hurray, saving us $$$ in so many ways @potus @realDonaldTrump #LockThemUp #BuildTheWall #EndDACA #BoycottNFL #BoycottNike', "Why would young fighting age men be the vast majority of the ones escaping a war &amp; not those who cannot fight like women, children, and the elderly?It's because the majority of the refugees are not actually refugees they are economic migrants trying to get into Europe.... https://t.co/Ks0SHbtYqn", '@KamalaHarris Illegals Dump their Kids at the border like Road Kill and Refuse to Unite! They Hope they get Amnesty, Free Education and Welfare Illegal #FamilesBelongTogether in their Country not on the Taxpayer Dime Its a SCAM #NoDACA #NoAmnesty #SendThe']
first 3 processed tweets: ['hurray saving us many ways', 'would young fighting age men vast majority ones escaping war amp cannot fight like women children elderlyits majority refugees actually refugees economic migrants trying get europe', 'illegals dump kids border like road kill refuse u



Epoch: [0][0/36]	Time 0.510 (0.510)	Loss 0.7154 (0.7154)	Prec @1 0.4200 (0.4200)	
Epoch: [0][10/36]	Time 0.412 (0.400)	Loss 0.7800 (1.2034)	Prec @1 0.6480 (0.5164)	


KeyboardInterrupt: 

In [9]:
# Sanity-check the encoded inputs: take the first batch from a fresh iterator
# over train_loader (no loop/break needed just to grab one item).
data, target = next(iter(train_loader))
data

tensor([[  459.,  6321.,  6322.,  ...,     0.,     0.,     0.],
        [ 1950.,  3790.,  6438.,  ...,     0.,     0.,     0.],
        [   79.,    80.,    81.,  ...,     0.,     0.,     0.],
        ...,
        [  333.,   137.,  2231.,  ...,     0.,     0.,     0.],
        [ 2437.,  2472., 14844.,  ...,     0.,     0.,     0.],
        [ 2244.,   192.,    42.,  ...,     0.,     0.,     0.]])


In [11]:
# Size of the label tensor from the batch sampled in the previous cell
# (presumably one label per example -- confirm against the dataset).
target.size()

torch.Size([250])