In [None]:
import torch

# Project-local modules: dataset/vocabulary objects, the shared training loop,
# and the collate_fn used by the DataLoaders below.
from tensor import (
    classes,
    test_dataset,
    train_dataset,
    vocab,
)
from bow import train_epoch
from embedding import padify

# Each model's embedding table needs one row per vocabulary entry.
vocab_size = len(vocab)

In [None]:
# Pick the compute device: CUDA if present, otherwise Apple-Silicon MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Seed the global torch RNG so weight initialisation and DataLoader shuffling
# are repeatable under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

In [None]:
class RNNClassifier(torch.nn.Module):
    """Text classifier: embedding -> torch.nn.RNN -> mean over time steps -> linear head.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: size of each token embedding (RNN input size).
        hidden_dim: size of the RNN hidden state (linear head input size).
        num_class: number of output logits.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_class):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Sub-modules are created in this order on purpose: it fixes the order in which
        # their parameters are randomly initialised and the state_dict key layout.
        self.embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.rnn = torch.nn.RNN(input_size=embed_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc = torch.nn.Linear(in_features=hidden_dim, out_features=num_class)  # fully connected head

    def forward(self, x):
        """Map token ids of shape (batch, seq_len) to logits of shape (batch, num_class).

        seq_len is the longest sequence in the (padded) batch.
        """
        embedded = self.embedding(x)                  # (batch, seq_len, embed_dim)
        per_step, _final_hidden = self.rnn(embedded)  # (batch, seq_len, hidden_dim)
        # NOTE(review): the mean runs over every position; if padify right-pads the
        # batch, padded steps are averaged in as well — confirm this is intended.
        pooled = torch.mean(per_step, dim=1)          # (batch, hidden_dim)
        return self.fc(pooled)                        # (batch, num_class)

# Shuffled training batches of 16, padded to a common length by padify.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=16,
    collate_fn=padify,
    shuffle=True,
)
network = RNNClassifier(
    vocab_size,
    embed_dim=64,
    hidden_dim=32,
    num_class=len(classes),
).to(device)

In [None]:
# Train the RNN classifier.
# NOTE(review): epoch_size presumably caps how much of the training set this call
# consumes — confirm against bow.train_epoch.
train_epoch(
    network,
    train_loader,
    learning_rate=0.001,
    epoch_size=1000,
)

In [None]:
# Test batches use the same batch size and collate_fn as training. Because
# shuffle=True, the first batch inspected below is drawn at random.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=16,
    collate_fn=padify,
    shuffle=True,
)

In [None]:
# Qualitative sanity check: decode one test example and compare the model's
# prediction with its label.
network.eval()

SAMPLE_IDX = 0  # position of the example to inspect within the first test batch
UNKNOWN_TOKENS = {'<unk>'}

with torch.no_grad():
    # Only the first batch is needed; batches unpack as (labels, token_ids).
    target, data = next(iter(test_loader))

    # NOTE(review): dropping '<unk>' presumably hides padding positions as well as
    # genuinely unknown words — confirm which index padify pads with.
    word_lookup = [vocab.itos[token_id] for token_id in data[SAMPLE_IDX].tolist()]
    word_lookup = [word for word in word_lookup if word not in UNKNOWN_TOKENS]
    print(f'Input text:\n {word_lookup}\n')

    data, target = data.to(device), target.to(device)
    pred = network(data)

    # One index selects input, label and prediction, so they always refer to the same example.
    actual = target[SAMPLE_IDX].item()
    predicted = pred[SAMPLE_IDX].argmax(0).item()
    print(f"Actual:\nvalue={actual}, class_name= {classes[actual]}\n")
    print(f"Predicted:\nvalue={predicted}, class_name= {classes[predicted]}\n")

In [None]:
class LSTMClassifier(torch.nn.Module):
    """Text classifier: embedding -> torch.nn.LSTM -> last hidden state -> linear head.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: size of each token embedding (LSTM input size).
        hidden_dim: size of the LSTM hidden state (linear head input size).
        num_class: number of output logits.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_class):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = torch.nn.Embedding(vocab_size, embed_dim)
        # Re-initialise the embeddings as standard-normal samples shifted by -0.5.
        # Done in place under no_grad so the registered Parameter is updated
        # directly instead of swapping its storage through the legacy `.data` attribute.
        with torch.no_grad():
            self.embedding.weight.copy_(torch.randn_like(self.embedding.weight) - 0.5)
        self.rnn = torch.nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = torch.nn.Linear(hidden_dim, num_class)

    def forward(self, x):
        """Map token ids of shape (batch, seq_len) to logits of shape (batch, num_class)."""
        x = self.embedding(x)
        _, (h, _) = self.rnn(x)
        # h is (num_layers, batch, hidden_dim); h[-1] is the top layer's state after
        # the final time step.
        # NOTE(review): if padify right-pads, the final step may be padding for short
        # sequences — consider pack_padded_sequence if true lengths are available.
        return self.fc(h[-1])

# Fresh LSTM model with the same sizes as the RNN above; rebinds `network`.
network = LSTMClassifier(
    vocab_size=vocab_size,
    embed_dim=64,
    hidden_dim=32,
    num_class=len(classes),
).to(device)

In [None]:
# Train the LSTM. The keyword matches the RNN training call (`learning_rate`);
# the previous `lr=` spelling presumably raises TypeError unless train_epoch
# also accepts it — confirm against bow.train_epoch.
train_epoch(network, train_loader, learning_rate=0.001)