In [None]:
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext
from torch.utils.data import DataLoader
from torchtext import data
from torchtext import datasets
from torchtext.vocab import Vectors, GloVe, CharNGram, FastText

batch_size = 256
# Fall back to the CPU when no GPU is available instead of failing at the
# first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNGs used by this notebook so runs are reproducible:
# torch (weight init) and Python's `random` (torchtext's legacy iterators
# presumably shuffle through it — confirm for the installed torchtext version).
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
################################
# DataLoader
################################

# Fields: TEXT keeps the default (sequential) Field settings; LABEL holds a
# single non-sequential class token per example, stored as int64.
TEXT = data.Field()
LABEL = data.Field(sequential=False, dtype=torch.long)

# Load the SST splits.
# DO NOT MODIFY: fine_grained=True, train_subtrees=False
train, val, test = datasets.SST.splits(
    TEXT,
    LABEL,
    fine_grained=True,
    train_subtrees=False,
)

# Both vocabularies come from the training split only. TEXT also attaches the
# pretrained vectors read from vector.txt in the ./data cache directory.
# Other pretrained vectors: https://github.com/pytorch/text/blob/master/torchtext/vocab.py
TEXT.build_vocab(train, vectors=Vectors(name='vector.txt', cache='./data'))
LABEL.build_vocab(train)

# One bucketed iterator per split, all using the configured batch size.
train_iter, val_iter, test_iter = data.BucketIterator.splits(
    (train, val, test),
    batch_size=batch_size,
)

# A sample batch for quick inspection: batch.text is the input sequence and
# batch.label the ground truth.
# Attention: batch.label is in the range [1,5], not [0,4]! (presumably index 0
# is LABEL's <unk> token — check LABEL.vocab.stoi)
batch = next(iter(train_iter))

# Pretrained word-embedding matrix. The network keeps an nn.Embedding layer
# and copies this into it: model.embedding.weight.data.copy_(pretrained_embeddings)
pretrained_embeddings = TEXT.vocab.vectors
print(pretrained_embeddings.shape)

In [None]:
################################
# Data inspection
################################
# Fields, splits, vocabularies and iterators are built once in the DataLoader
# cell above. This cell only reports on them. Rebuilding them here would
# (a) silently replace the configured `batch_size` with a hard-coded value and
# (b) rebind TEXT while `pretrained_embeddings` still refers to the old vocab.

# print information about the data
print('train.fields', train.fields)
print('len(train)', len(train))
print('vars(train[0])', vars(train[0]))

# Vocabulary views: itos (int -> string), stoi (string -> int), raw frequencies.
print(TEXT.vocab.itos[:10])
print(LABEL.vocab.stoi)
print(TEXT.vocab.freqs.most_common(20))

# print vocab information
print('len(TEXT.vocab)', len(TEXT.vocab))
print('TEXT.vocab.vectors.size()', TEXT.vocab.vectors.size())

# print batch information
batch = next(iter(train_iter)) # for batch in train_iter
print(batch.text) # input sequence
print(batch.label) # ground truth

# Attention: batch.label in the range [1,5] not [0,4] !!!



In [None]:
# Training schedule and optimiser settings (read by RNNModel).
num_epoch = 5   # number of passes over the training set
lr = 0.001      # Adam learning rate
lmd = 0.0       # Adam weight_decay

# Print a progress line every N batches while training / validating / testing.
disp_train = 20
disp_eval = 10
disp_test = 10

In [None]:
def accuracy(out, label):
    """Return the fraction of rows of ``out`` whose argmax equals ``label``.

    ``out`` is an (N, C) score tensor (N * 5 in this notebook) and ``label`` a
    tensor of N class indices on the same device. The result is a Python float.
    """
    hits = (out.argmax(dim=1) == label).sum().item()
    return hits / out.shape[0]

In [None]:
hidden1 = 512   # LSTM hidden size per direction
hidden2 = 128   # width of the optional fc1 -> fc2 head (unused by forward)
layer = 5       # number of stacked LSTM layers
dropout = 0     # dropout between stacked LSTM layers


class RNNModel(nn.Module):
    """Bidirectional multi-layer LSTM classifier producing 5 class logits.

    Pipeline: token ids -> frozen pretrained embedding -> BiLSTM -> linear layer.
    The model owns its loss, optimiser and per-epoch history. It reads its data
    and settings from notebook globals: ``pretrained_embeddings``,
    ``train_iter``/``val_iter``/``test_iter``, ``device``, ``lr``, ``lmd``,
    ``num_epoch``, ``disp_train``/``disp_eval``/``disp_test`` and ``accuracy``.
    """

    def __init__(self):
        super().__init__()
        # Size the embedding from the pretrained matrix rather than hard-coding
        # (vocab_size, dim), so a different vocabulary or vector file still fits.
        vocab_size, embed_dim = pretrained_embeddings.shape
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.embedding.weight.data.copy_(pretrained_embeddings)
        for p in self.embedding.parameters():
            p.requires_grad = False
        self.rnn = nn.LSTM(input_size=embed_dim, hidden_size=hidden1, num_layers=layer,
                           dropout=dropout, bidirectional=True)
        # self.fc1 = nn.Linear(2 * hidden1, hidden2)
        self.fc1 = nn.Linear(2 * hidden1, 5)
        self.fc2 = nn.Linear(hidden2, 5)  # only used by the commented-out two-layer head
        self.acti = nn.ReLU()
        self.criterion = nn.CrossEntropyLoss()
        # Give Adam only the trainable parameters; the frozen embedding never gets gradients.
        self.optimizer = optim.Adam(
            (p for p in self.parameters() if p.requires_grad), lr=lr, weight_decay=lmd)
        self.train_loss = []
        self.valid_loss = []
        self.accuracy = []

    def forward(self, x):
        """Return (N, 5) class logits for the token-id tensor ``x``.

        ``x`` is assumed to be (seq_len, batch): the LSTM is not batch_first.
        """
        _, (h_n, _) = self.rnn(self.embedding(x))
        # h_n is (num_layers * 2, N, hidden1). Its last two rows are the top
        # layer's forward and backward final states. The backward half of
        # out[-1] would have seen only the final token.
        # NOTE(review): batches are padded but not packed, so both final states
        # still include <pad> steps; packing would need include_lengths=True.
        features = torch.cat((h_n[-2], h_n[-1]), dim=1)
        y = self.fc1(features)
        # y = self.fc2(self.acti(y))
        return y

    def train_epoch(self):
        """Run one optimisation pass over ``train_iter``; return the mean batch loss."""
        self.train()
        epoch_loss = 0
        for i, batch in enumerate(train_iter):
            x = batch.text.to(device)
            y = (batch.label - 1).to(device)  # labels arrive as 1..5; the loss wants 0..4
            self.optimizer.zero_grad()
            output = self(x)
            loss = self.criterion(output, y)
            loss.backward()
            self.optimizer.step()
            epoch_loss += loss.item()
            if i % disp_train == 0:
                print(f"train batch {i}: train loss = {loss.item()}")
                print(f"accu: {accuracy(output, y)}")
        return epoch_loss / len(train_iter)

    def _evaluate(self, iterator, disp, tag):
        """Score ``iterator`` without gradients; return (mean loss, accuracy).

        Both values are weighted by the number of examples, so a smaller final
        batch does not bias the result. Returns (nan, nan) for an empty iterator.
        """
        self.to(device)  # allows test() to run even before train_network()
        self.eval()
        total_loss = 0.0
        total_correct = 0
        total_seen = 0
        with torch.no_grad():
            for i, batch in enumerate(iterator):
                x = batch.text.to(device)
                y = (batch.label - 1).to(device)
                output = self(x)
                loss = self.criterion(output, y)
                n = y.shape[0]
                total_loss += loss.item() * n
                total_correct += (output.argmax(1) == y).sum().item()
                total_seen += n
                if i % disp == 0:
                    print(f"{tag} batch {i}: {tag} loss = {loss.item()}")
                    print(f"accuracy: {accuracy(output, y)}")
        if total_seen == 0:
            return float("nan"), float("nan")
        return total_loss / total_seen, total_correct / total_seen

    def eval_epoch(self):
        """Evaluate on ``val_iter``; return (validation loss, validation accuracy)."""
        return self._evaluate(val_iter, disp_eval, "valid")

    def train_network(self):
        """Train for ``num_epoch`` epochs, recording loss/accuracy history per epoch."""
        self.to(device)
        for e in range(num_epoch):
            tloss = self.train_epoch()
            vloss, accu = self.eval_epoch()
            # self.scheduler.step()
            print("-----")
            print(f"epoch: {e}, average training loss: {tloss}, average validation loss: {vloss}, valid accuracy: {accu}")
            print("-----")
            self.train_loss.append(tloss)
            self.accuracy.append(accu)
            self.valid_loss.append(vloss)

    def test(self):
        """Evaluate on ``test_iter``; return (test loss, test accuracy)."""
        return self._evaluate(test_iter, disp_test, "test")

    def plot(self):
        """Plot the recorded training loss, validation loss and validation accuracy."""
        # Use the recorded history length so repeated or interrupted training
        # does not mismatch the x-axis.
        epochs = range(len(self.train_loss))
        fig, ax = plt.subplots(3, 1)
        series = (
            (self.train_loss, "train loss"),
            (self.valid_loss, "valid loss"),
            (self.accuracy, "accuracy"),
        )
        for axis, (values, label) in zip(ax, series):
            axis.plot(epochs, values)
            axis.set_ylabel(label)
            axis.set_xlabel("epoch")
        fig.tight_layout()

In [None]:
# Build the model (its embedding is initialised from `pretrained_embeddings`),
# train it for `num_epoch` epochs, then draw the per-epoch loss/accuracy curves.
net = RNNModel()
net.train_network()
net.plot()

In [None]:
# Shape of the sample token-id batch (assumed (seq_len, batch): Field default, not batch_first).
batch.text.shape

In [None]:
# One label per example in the sample batch.
batch.label.shape

In [None]:
# Embedded batch: the trailing dimension is the embedding width.
net.embedding(batch.text.to(device)).shape

In [None]:
# Raw embedded values for the sample batch (the embedding is frozen in RNNModel).
aa = net.embedding(batch.text.to(device))
aa

In [None]:
# nn.LSTM returns (output, (h_n, c_n)); `a` holds that whole tuple.
a = net.rnn(net.embedding(batch.text.to(device)))
a

In [None]:
# Every LSTM weight/bias name and tensor, to check the layer/direction layout.
for p in net.rnn.named_parameters():
    print(p)

In [None]:
# Last-timestep slice of the LSTM output: (batch, 2 * hidden), forward and
# backward halves concatenated (per the PyTorch LSTM docs).
b = a[0][-1,:,:]
b

In [None]:
# squeeze(0) is a no-op here unless the batch size is 1 (b is already 2-D).
b.squeeze(0).shape

In [None]:
# h_n[-1]: per the PyTorch LSTM docs, the top layer's backward-direction final state.
c = a[1][0][-1,:,:]
c

In [None]:
# h_n[-2]: per the PyTorch LSTM docs, the top layer's forward-direction final state.
d = a[1][0][-2,:,:]
d

In [None]:
# Expect b to be twice as wide as c and d.
b.shape, c.shape, d.shape

In [None]:
# Backward half of `b` (512 == hidden1, hard-coded). Per the PyTorch docs this
# is the backward state after a single step, so it generally differs from `c`;
# the forward half b[:, :512] is the one expected to match `d`.
b[:, -512:]

In [None]:
# Final states concatenated as (backward `c`, forward `d`).
torch.cat((c, d), 1)

In [None]:
# Ground-truth labels of the sample batch (see the [1,5] note in the DataLoader cell).
batch.label

In [None]:
# Same embedding-shape check as the earlier cell.
net.embedding(batch.text.to(device)).shape

In [None]:
# Final held-out evaluation. Descriptive names avoid clobbering `a`, the LSTM
# output tuple inspected by the cells above.
test_loss, test_accu = net.test()
print(f"test loss: {test_loss}, test accuracy: {test_accu}")