In [32]:
import torch
import torch.cuda
import numpy as np

In [33]:
# setting device on GPU if available, else CPU
device = torch.device('cpu') # 
print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

Using device: cpu



In [34]:
def seed_everything(seed: int):
    import random, os
    import numpy as np
    import torch
    
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

import random
seed_everything(random.randint(1, 10000))

In [35]:
from torchtext.data import get_tokenizer
def load_vocab(path):
    import pickle
    with open(path,'rb') as file:
           vocab = pickle.load(file)
    return vocab

vocab = load_vocab('./vocab.pkl')

In [36]:
import torch.nn as nn
import torch.nn.functional as F

In [37]:
target_classes = ["World", "Sports", "Business", "Sci/Tech"]
len_vocab = 40708
max_words = 50

In [38]:
embed_len = 50
hidden_dim = 50
n_layers = 1

class RNN_1(nn.Module):
    def __init__(self):
        super(RNN_1, self).__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers

        self.embedding_layer = nn.Embedding(num_embeddings=len_vocab, embedding_dim=embed_len)
        self.rnn = nn.RNN(input_size=embed_len, hidden_size=hidden_dim, num_layers=n_layers, batch_first=True)
        self.linear = nn.Linear(hidden_dim, len(target_classes))
        self.dropout = nn.Dropout(0.25)
        #self.hidden = torch.nn.parameter.Parameter(torch.zeros(n_layers, batch_size, self.hidden_dim, device=device))

    def forward(self, X_batch):
        embeddings = self.embedding_layer(X_batch)

        output, hidden = self.rnn(embeddings, torch.randn(n_layers, len(X_batch), hidden_dim, device=device))
        output = self.linear(output)
        return output[:,-1]

In [39]:
embed_len = 50
hidden_dim_1 = 40
hidden_dim_2 = 50
hidden_dim_3 = 60
n_layers = 1

class RNN_2(nn.Module):
    def __init__(self):
        super(RNN_2, self).__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers

        self.embedding_layer = nn.Embedding(num_embeddings=len_vocab, embedding_dim=embed_len)
        self.rnn1 = nn.RNN(input_size=embed_len, hidden_size=hidden_dim_1, num_layers=1, batch_first=True)
        self.rnn2 = nn.RNN(input_size=hidden_dim_1, hidden_size=hidden_dim_2, num_layers=1, batch_first=True)
        self.rnn3 = nn.RNN(input_size=hidden_dim_2, hidden_size=hidden_dim_3, num_layers=1, batch_first=True)
        self.linear = nn.Linear(hidden_dim_3, len(target_classes))
        self.dropout = nn.Dropout(0.25)
        #self.hidden = torch.nn.parameter.Parameter(torch.zeros(n_layers, batch_size, self.hidden_dim, device=device))

    def forward(self, X_batch):
        embeddings = self.embedding_layer(X_batch)

        output, hidden = self.rnn1(embeddings, torch.randn(n_layers, len(X_batch), hidden_dim_1, device=device))
        output, hidden = self.rnn2(output, torch.randn(n_layers, len(X_batch), hidden_dim_2, device=device))
        output, hidden = self.rnn3(output, torch.randn(n_layers, len(X_batch), hidden_dim_3, device=device))

        output = self.linear(output)
        return output[:,-1]

In [40]:
embed_len = 50
hidden_size = 50

class LSTM_1(nn.Module):
    def __init__(self):
        super(LSTM_1, self).__init__()
        
        self.hidden_size = hidden_size
        self.embedding_layer = nn.Embedding(num_embeddings=len_vocab, embedding_dim=embed_len)
        self.lstm = nn.LSTM(max_words, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, len(target_classes))
        
    def forward(self, X_batch):
        embeddings = self.embedding_layer(X_batch)
        lstm_out, _ = self.lstm(embeddings)
        output = self.fc(lstm_out[:, -1])
        return output

In [41]:
embed_len = 50
hidden_size = 50
num_layers = 3

class LSTM_2(nn.Module):
    def __init__(self):
        super(LSTM_2, self).__init__()
        
        self.hidden_size = hidden_size
        self.embedding_layer = nn.Embedding(num_embeddings=len_vocab, embedding_dim=embed_len)
        self.lstm_layers = []
        self.lstm_layers.append(nn.LSTM(max_words, hidden_size, batch_first=True))
        for i in range(num_layers - 1):
            self.lstm_layers.append(nn.LSTM(hidden_size, hidden_size, batch_first=True))
        
        self.fc = nn.Linear(hidden_size, len(target_classes))
        
    def forward(self, X_batch):
        embeddings = self.embedding_layer(X_batch)
        lstm_out, _ = self.lstm_layers[0](embeddings)
        for i in range(1, num_layers):
            lstm_out, _ = self.lstm_layers[i](lstm_out)
            
        output = self.fc(lstm_out[:, -1])
        return output

In [42]:
nets = []
net_1 = RNN_1()
net_1.load_state_dict(torch.load('net_rnn_simple.pth'))
nets.append(net_1)
net_2 = RNN_2()
net_2.load_state_dict(torch.load('net_rnn_layers.pth'))
nets.append(net_2)
net_3 = LSTM_1()
net_3.load_state_dict(torch.load('net_lstm_simple.pth'))
nets.append(net_3)
net_4 = LSTM_2()
net_4.load_state_dict(torch.load('net_lstm_layers.pth'))
nets.append(net_4)

In [43]:
def pre_text(string, model, classes):
   with torch.no_grad():
      model.eval()  
      output = model(string)
      index = output.data.cpu().numpy().argmax()
      class_name = classes[index]
      return class_name

In [58]:
text = "Record Win for Team USA: New Olympic Gold Medal Count in Football"
tokenizer = get_tokenizer("basic_english")
tokens = tokenizer(text)
tokens = vocab(tokens)
tokens = torch.tensor(tokens, dtype=torch.int32)
tokens = torch.unsqueeze(tokens, dim=0)
for i, net in enumerate(nets):
    print(f"Net {i} says it's {pre_text(tokens, net, target_classes)}")

Net 0 says it's Sports
Net 1 says it's Sports
Net 2 says it's Sports
Net 3 says it's Business
