In [28]:
import torch
from torch import nn
from torch import optim
from data_loader import get_loader
import time

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [39]:
class Model(nn.Module):
    """Bag-of-embeddings classifier.

    Token ids are embedded, the embeddings are averaged over the sequence
    dimension, and a linear layer followed by log-softmax turns the averaged
    vector into per-class log-probabilities.

    Args:
        embedding_size: dimension of the embedding space.
        vocab_size: number of rows in the embedding table; every token id fed
            to ``forward`` must be smaller than this.
        number_classes: number of output classes.
    """

    def __init__(self, embedding_size, vocab_size, number_classes):
        super().__init__()
        self.embedding_size = embedding_size # embedding space dimension
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.linear = nn.Linear(embedding_size, number_classes)
        # LogSoftmax output is meant to be paired with a negative log-likelihood loss.
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, input_tensor):
        """Return log-probabilities for a batch of token-id sequences.

        Args:
            input_tensor: 2-D integer tensor of shape batch_size * seq_len.

        Returns:
            Tensor of shape batch_size * number_classes holding log-probabilities.

        Note:
            The mean is taken over every position of the sequence without any
            mask, so all token ids present in the input contribute equally.
        """
        input_tensor = input_tensor.permute(1, 0) # seq_len * batch_size
        x = self.embedding(input_tensor) # seq_len * batch_size * embedding_size
        x = torch.mean(x, 0) # batch_size * embedding_size
        logits = self.linear(x) # batch_size * number_classes
        return self.softmax(logits)

In [32]:
def trainOneBatch(model, batch_input, optimizer, criterion):
    """Run one optimisation step on a single batch and return its loss.

    Args:
        model: the ``nn.Module`` being trained.
        batch_input: indexable batch; element 0 holds the input sequences
            (batch_size * sequence_len) and element 1 the targets (batch_size).
        optimizer: optimizer bound to ``model``'s parameters.
        criterion: callable computing the loss from (model output, targets).

    Returns:
        The loss of this batch as a Python float.
    """
    optimizer.zero_grad() # drop gradients left over from the previous step
    sequences = batch_input[0] # input sequences of shape: batch_size * sequence_len
    targets = batch_input[1] # targets of shape: batch_size
    # Call the module itself rather than .forward() so that registered hooks run.
    out = model(sequences) # shape: batch_size * number_classes
    loss = criterion(out, targets)
    loss.backward() # compute the gradient
    optimizer.step() # update network parameters
    return loss.item() # return loss value

In [33]:
def evaluate(model, data_loader):
    """Compute the classification accuracy of ``model`` over ``data_loader``.

    Args:
        model: an ``nn.Module`` mapping a batch of inputs to per-class scores
            of shape batch_size * number_classes.
        data_loader: iterable of batches; element 0 of each batch holds the
            inputs and element 1 the target class indices.

    Returns:
        Fraction of correctly classified samples over the whole loader
        (correct / total, so a shorter final batch is not over-weighted).
        Returns 0.0 when the loader yields no samples.
    """
    correct = 0
    total = 0
    was_training = model.training
    model.eval() # evaluation mode for the duration of the pass
    try:
        with torch.no_grad(): # no gradients are needed to score predictions
            for batch in data_loader:
                sequences = batch[0]
                target = batch[1]
                # The model returns a single tensor, so it must not be tuple-unpacked.
                out = model(sequences)
                predicted = torch.argmax(out, -1)
                correct += torch.sum(predicted == target).item()
                total += sequences.shape[0]
    finally:
        # Leave the module in whatever mode the caller had set.
        model.train(was_training)
    if total == 0:
        return 0.0
    return correct / total

In [30]:
def trainModel(model, path_documents, path_labels, word2ind, n_epochs=5, batch_size=16, printEvery=20,
               path_documents_valid=None, path_labels_valid=None):
    """Train ``model`` with Adam and NLL loss, reporting accuracy after each epoch.

    Args:
        model: the ``nn.Module`` to train (expected to output log-probabilities,
            since the loss is ``nn.NLLLoss``).
        path_documents: path of the training documents, forwarded to ``get_loader``.
        path_labels: path of the training labels, forwarded to ``get_loader``.
        word2ind: word-to-index mapping, forwarded to ``get_loader``.
        n_epochs: number of passes over the training data.
        batch_size: batch size forwarded to ``get_loader``.
        printEvery: print the mean training loss every this many iterations;
            a value of 0 disables the periodic log.
        path_documents_valid: optional path of held-out documents.
        path_labels_valid: optional path of held-out labels. Validation accuracy
            is only computed when both validation paths are given.

    Returns:
        ``(training_accuracy_epochs, validation_accuracy_epochs)``: two lists with
        one entry per epoch (the second is empty when no validation data is given).
    """
    # A new loader is requested for every pass over the data.
    # NOTE(review): presumably get_loader returns a single-pass iterable -- confirm.
    train_loader_params = (path_documents, path_labels, word2ind, str(device), batch_size)
    has_validation = path_documents_valid is not None and path_labels_valid is not None
    if has_validation:
        valid_loader_params = (path_documents_valid, path_labels_valid, word2ind, str(device), batch_size)

    running_loss = 0 # summed loss since the last progress print
    count_iter = 0
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    #negative log likelihood
    criterion = nn.NLLLoss()
    time1 = time.time()
    training_accuracy_epochs = [] # training accuracy for each epoch
    validation_accuracy_epochs = [] # validation accuracy for each epoch
    for i in range(n_epochs):
        model.train() # make sure the module is in training mode for this epoch
        loader = get_loader(*train_loader_params)
        for batch in loader:
            running_loss += trainOneBatch(model, batch, optimizer, criterion)
            count_iter += 1
            if printEvery and count_iter % printEvery == 0:
                time2 = time.time()
                print("Iteration: {0}, Time: {1:.4f} s, training loss: {2:.4f}".format(count_iter,
                                                                          time2 - time1, running_loss/printEvery))
                running_loss = 0
        training_accuracy = evaluate(model, get_loader(*train_loader_params))
        training_accuracy_epochs.append(training_accuracy)
        if has_validation:
            validation_accuracy = evaluate(model, get_loader(*valid_loader_params))
            validation_accuracy_epochs.append(validation_accuracy)
            print('Epoch {0} done: training_accuracy = {1:.3f}, validation_accuracy = {2:.3f}'.format(i+1, training_accuracy, validation_accuracy))
        else:
            # Without held-out data there is no validation figure to report;
            # re-scoring the training set under that label would be misleading.
            print('Epoch {0} done: training_accuracy = {1:.3f}'.format(i+1, training_accuracy))
    return training_accuracy_epochs, validation_accuracy_epochs

In [19]:
path_cat2ind = 'data/cat2ind.csv'
path_word_count = 'data/word2count.txt'

# Load the index -> category mapping.
# Each non-blank line is split on ','; field 0 is the category name and
# field 1 its integer index.
ind2category = {}
with open(path_cat2ind, encoding='utf-8') as f:
    for line in f:
        line = line.rstrip('\r\n')
        if not line.strip():
            continue # tolerate blank lines (e.g. a trailing empty line)
        mapping = line.split(',')
        ind2category[int(mapping[1])] = mapping[0]

# Load the word -> index mapping.
# Indices 0 and 1 are reserved for the 'PAD' and 'OOV' entries; the words read
# from the file (first tab-separated field of each line) are numbered from 2.
word2ind = {'PAD':0, 'OOV':1}
count = 2 # next free index; always equal to len(word2ind)
with open(path_word_count, encoding='utf-8') as f:
    for line in f:
        line = line.rstrip('\r\n')
        if not line:
            continue # skip blank lines instead of registering an empty word
        word = line.split('\t')[0]
        if word in word2ind:
            # Keep the first index of a repeated (or reserved) word so indices stay
            # contiguous and len(word2ind) remains a valid embedding-table size.
            continue
        word2ind[word] = count
        count += 1

In [23]:
# Build the classifier with 50-dimensional embeddings, one embedding row per
# vocabulary entry and one output per category, then move it to the chosen device.
my_model = Model(
    embedding_size=50,
    vocab_size=len(word2ind),
    number_classes=len(ind2category),
)
my_model = my_model.to(device)

In [40]:
# Locations of the training split.
path_labels_train = 'data/train_labels.txt'
path_documents_train = 'data/train_documents.txt'

# Train the model with the default number of epochs, batch size and logging interval.
trainModel(
    model=my_model,
    path_documents=path_documents_train,
    path_labels=path_labels_train,
    word2ind=word2ind,
)

NameError: name 'M' is not defined