In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import random
import numpy as np

from tqdm import tqdm

device = torch.device("cuda")

random.seed(1)
np.random.seed(1)
torch.manual_seed(1)

<torch._C.Generator at 0x2ac15d34690>

In [42]:
word2idx = {"<unk>": 0}
label2idx = {}
idx2word = ["<unk>"]
idx2label = []

train_data = []
with open('train.txt', 'rt', encoding='UTF8') as f:
    for line in f:
        text, author = line.strip().split()
        for c in text:
            if c not in word2idx:
                word2idx[c] = len(idx2word)
                idx2word.append(c)
        if author not in label2idx:
            label2idx[author] = len(idx2label)
            idx2label.append(author)
        train_data.append((text, author))

valid_data = []
with open("valid.txt", 'rt', encoding='UTF8') as f:
    for line in f:
        text, author = line.strip().split()
        valid_data.append((text, author))

test_data = []
with open("test.txt", 'rt', encoding='UTF8') as f:
    for line in f:
        text, author = line.strip().split()
        test_data.append((text, author))

In [43]:
print(len(word2idx), len(idx2word), len(label2idx), len(idx2label))
print(len(train_data), len(valid_data), len(test_data))

4941 4941 5 5
11271 1408 1410


In [44]:
def make_data(text, author):
    """
    input
        text: str
        author: str
    output
        x: LongTensor, shape = (1, text_length)
        y: LongTensor, shape = (1,)
    """
    
    x = torch.LongTensor([word2idx[text[i]] if text[i] in word2idx else 0 for i in range(len(text))])
    y = torch.LongTensor([label2idx[author]])
    
    return x, y

In [124]:
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.f = nn.Linear(input_size + hidden_size, hidden_size)
        self.i = nn.Linear(input_size + hidden_size, hidden_size)
        self.o = nn.Linear(input_size + hidden_size, hidden_size)
        self.g = nn.Linear(input_size + hidden_size, hidden_size)
    
    def forward(self, ht, ct, xt):
        # ht: 1 * hidden_size
        # ct: 1 * hidden_size
        # xt: 1 * input_size
        input_combined = torch.cat((xt, ht), 1)
        ft = torch.sigmoid(self.f(input_combined))
        it = torch.sigmoid(self.i(input_combined))
        ot = torch.sigmoid(self.o(input_combined))
        gt = torch.tanh(self.g(input_combined))
        ct = ft * ct + it * gt
        ht = ot * torch.tanh(ct)
        return ht, ct

In [143]:
class BiLSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(BiLSTM, self).__init__()
        self.hidden_size = hidden_size // 2
        self.lstm_forward = LSTM(input_size, hidden_size // 2)
        self.lstm_backward = LSTM(input_size, hidden_size // 2)
        self.register_buffer("_float", torch.zeros(1, hidden_size))
    
    def init_h_and_c(self):
        h = torch.zeros_like(self._float)
        c = torch.zeros_like(self._float)
        return h, c
    
    def forward(self, x):
        """
        input
            x: 1 * length * input_size
        output
            hiddens
        """
        hf, cf = self.init_h_and_c()
        hf, cf = hf[0:1, :self.hidden_size], cf[0:1, :self.hidden_size]
        hiddens = []
        for i in range(x.shape[1]):
            hf, cf = self.lstm_forward(hf, cf, x[0, i:i+1,:])
            hb, cb = self.lstm_backward(hf, cf, x[0, i:i+1,:])
            hiddens.append(torch.cat((hf, hb), 1))
        hiddens = torch.stack(hiddens, dim=1)
        return hiddens

In [159]:
class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.lin = nn.Linear(hidden_size, hidden_size, bias=False)
    
    def forward(self, hiddens):
        """
        input
            hiddens: 1 * length * hidden_size
        output
            attn_outputs: 1 * hidden_size
        """
        q = self.lin(hiddens[0:1,-1,:]).unsqueeze(-1)
        attn_scores = torch.bmm(hiddens, q)
        weights = F.softmax(attn_scores, dim=1)
        attn_outputs = (weights * hiddens).sum(1)
        return attn_outputs

Model Structure：Embedding – BiLSTM – Attention – Linear – LogSoftmax

In [178]:
class EncoderRNN(nn.Module):
    def __init__(self, num_vocab, embedding_dim, hidden_size, num_classes):
        super(EncoderRNN, self).__init__()
        self.num_vocab = num_vocab
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        
        self.embed = nn.Embedding(num_vocab, embedding_dim)
        self.bilstm = BiLSTM(embedding_dim, hidden_size)
        self.attn = Attention(hidden_size)
        
        self.h2o = nn.Linear(hidden_size + hidden_size, num_classes)
        self.softmax = nn.LogSoftmax(dim=-1)
    
    def forward(self, x):
        """
        input
            x: 1 * length, LongTensor
        output
            outputs
        """
        embeddings = []
        for i in range(len(x)):
            embedding = self.embed(x[i])
            embeddings.append(embedding)
        embeddings = torch.stack(embeddings, dim=0) 
        hiddens = self.bilstm(embeddings)
        attn_output = self.attn(hiddens)
        outputs = self.h2o(torch.cat((attn_output, hiddens[0:1,-1,:]), dim=-1))
        outputs = self.softmax(outputs)
        return outputs

In [49]:
def collate(data_list):
    src = [_[0] for _ in data_list]
    tgt = [_[1] for _ in data_list]
    return src, tgt

batch_size = 16
trainloader = torch.utils.data.DataLoader([
    make_data(text, label) for text, label in train_data
], batch_size=batch_size, shuffle=True, collate_fn=collate)
validloader = torch.utils.data.DataLoader([
    make_data(text, label) for text, label in valid_data
], batch_size=batch_size, shuffle=False, collate_fn=collate)
testloader = torch.utils.data.DataLoader([
    make_data(text, label) for text, label in test_data
], batch_size=batch_size, shuffle=False, collate_fn=collate)

In [50]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [196]:
def train_loop(model, optimizer, criterion, loader):
    model.train()
    epoch_loss = 0.0
    for src, tgt in tqdm(loader):
        B = len(src)
        loss = 0.0
        for _ in range(B):
            _src = src[_].unsqueeze(0).to(device)
            output = model(_src)
            loss += criterion(output, tgt[_])
        loss /= B
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        epoch_loss += loss.item()
    epoch_loss /= len(loader)
    return epoch_loss

def test_loop(model, loader):
    total_acc = 0.0
    total_cnt = 0
    outputs = []
    for src, tgt in tqdm(loader):
        B = len(src)
        for _ in range(B):
            _src = src[_].unsqueeze(0).to(device)
            output = model(_src).argmax(1)
            outputs.append(output.item())
            if output == tgt[_]:
                total_acc += 1
            total_cnt += 1
    return total_acc / total_cnt, outputs

In [197]:
num_vocab = len(word2idx) 
num_classes = len(label2idx)
hidden_size = 128
embedding_dim = 256
model = EncoderRNN(num_vocab, embedding_dim, hidden_size, num_classes)
model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1)
criterion = nn.CrossEntropyLoss()

best_score = 0.0
for epoch in range(3):
    loss = train_loop(model, optimizer, criterion, trainloader)
    acc, _ = test_loop(model, validloader)
    if acc > best_score:
        torch.save(model.state_dict(), "model_best.pt")
        best_score = acc
    print(f"Epoch {epoch}: loss = {loss}, valid accuracy = {acc}")

100%|████████████████████████████████████████████████████████████████████████████████| 705/705 [19:20<00:00,  1.65s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 88/88 [00:47<00:00,  1.86it/s]


Epoch 0: loss = 1.3386044781258766, valid accuracy = 0.5490056818181818


100%|████████████████████████████████████████████████████████████████████████████████| 705/705 [16:25<00:00,  1.40s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 88/88 [00:36<00:00,  2.41it/s]


Epoch 1: loss = 1.0391498630773937, valid accuracy = 0.5838068181818182


100%|████████████████████████████████████████████████████████████████████████████████| 705/705 [16:59<00:00,  1.45s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 88/88 [00:43<00:00,  2.03it/s]

Epoch 2: loss = 0.8146614053359268, valid accuracy = 0.5951704545454546





In [198]:
model.load_state_dict(torch.load("model_best.pt"))
acc, predictions = test_loop(model, testloader)
print(f"Test accuracy = {acc}")
test = []
for batch in testloader:
    for i in batch[1]:
        test.append(i[0].item())
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_fscore_support
print(confusion_matrix(test, predictions))

100%|██████████████████████████████████████████████████████████████████████████████████| 89/89 [00:45<00:00,  1.94it/s]

Test accuracy = 0.642290780141844
[[ 42  32  39  20  27]
 [  8 339  26  21  20]
 [ 21  56 341  34  16]
 [ 14  42  46 111  24]
 [ 10  22  29  28  42]]



