In [1]:
import logging
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [2]:
# Name of the architecture under training; it is interpolated into the
# log-file path and the checkpoint path used by later cells.
model_name = "TextRNN_attn"

In [3]:
def init_logging(path):
    """Return the shared ``'my_logger'`` logger with a file handler on ``path``.

    The logger level is set to INFO and records are formatted as
    ``time - logger name - level - message``.

    Calling this function again with the same ``path`` (for example when the
    notebook cell is re-run) returns the logger unchanged instead of stacking
    a second handler, which would write every record twice. The parent
    directory of ``path`` is created when it does not exist yet.

    Args:
        path: Location of the log file; relative paths are resolved against
            the current working directory.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger('my_logger')
    logger.setLevel(logging.INFO)

    # FileHandler stores its target as an absolute path, so compare like with like.
    target = os.path.abspath(path)
    for handler in logger.handlers:
        if isinstance(handler, logging.FileHandler) and handler.baseFilename == target:
            # Already wired to this file: nothing to add.
            return logger

    # FileHandler opens the file immediately and fails if the directory is missing.
    os.makedirs(os.path.dirname(target), exist_ok=True)

    file_handler = logging.FileHandler(target)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    return logger

In [4]:
# Create the run logger; records are written to ./log/<model_name>.txt.
logger = init_logging(f"./log/{model_name}.txt")

In [5]:
# Load the pretrained vectors: one token per line, followed by its
# space-separated float components.
pretrained_embedding_path = "./data/pretrained_wordvector/sgns.sogou.char"
embed = []          # one float32 vector per token, in file order
word2idx = dict()   # token -> row index in `embed`
idx2word = dict()   # row index -> token

size = None         # shape of the most recently parsed vector
# Decode explicitly as UTF-8 instead of relying on the platform's locale
# default, which is not UTF-8 everywhere.
# NOTE(review): assumes the vector file is UTF-8 encoded - confirm.
with open(pretrained_embedding_path, "r", encoding="utf-8") as f:
    idx = 0
    # The first line is skipped (presumably a "count dim" header - confirm).
    f.readline()
    for line in tqdm(f):
        x = line.strip().split(' ')
        word = x[0]
        vector = np.asarray(x[1:], dtype=np.float32)
        size = vector.shape
        embed.append(vector)
        word2idx[word] = idx
        idx2word[idx] = word
        # `idx` ends up as the next free row; the following cell reuses it
        # when appending the special tokens.
        idx += 1

365076it [00:13, 26880.21it/s]


In [6]:
# Append the two special tokens after the pretrained vocabulary:
#   '<UNK>' is represented by the mean of all pretrained vectors,
#   '<PAD>' by a random Gaussian vector of the same shape.
# No random seed is set in this cell, so the '<PAD>' vector is drawn from
# whatever state NumPy's global generator is in.
avg = sum(embed) / len(embed)
for special_token, special_vector in (
    ('<UNK>', avg),
    ('<PAD>', np.random.normal(size=size)),
):
    embed.append(special_vector)
    word2idx[special_token] = idx
    idx2word[idx] = special_token
    idx += 1



In [7]:
# Stack the per-token vectors into a single array, then hand it to torch as a
# float32 tensor.
embed = np.array(embed)
embed = torch.from_numpy(embed).float()
# Displayed as the cell result to sanity-check the table's shape and dtype.
embed.shape, embed.dtype

(torch.Size([365078, 300]), torch.float32)

In [8]:
class MyDataset(Dataset):
    """Sentence-classification dataset read from a tab-separated text file.

    Every non-blank line supplies the sentence in its first column and an
    integer class label in its second column; further columns are ignored.
    Items are ``(sentence, label)`` tuples. ``collate`` converts a list of
    items into padded index tensors using the module-level ``word2idx``
    mapping, which must contain the ``'<UNK>'`` and ``'<PAD>'`` keys.
    """

    def __init__(self, path):
        """Read every ``sentence<TAB>label`` line of ``path`` into memory.

        Args:
            path: Location of the tab-separated data file.

        Raises:
            IndexError: If a non-blank line has no tab-separated label column.
            ValueError: If the label column is not an integer.
        """
        self.data = []
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale default.
        # NOTE(review): assumes the data files are UTF-8 encoded - confirm.
        with open(path, "r", encoding="utf-8") as f:
            for line in tqdm(f):
                if not line.strip():
                    # Skip blank lines (e.g. a trailing newline at end of file).
                    continue
                x = line.split('\t')
                sen, label = x[0], x[1]
                # int() tolerates the trailing newline left on the label.
                self.data.append((sen, int(label)))

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(sentence, label)`` tuple stored at ``index``."""
        return self.data[index]

    def collate(self, batchs):
        """Turn a list of ``(sentence, label)`` items into two int64 tensors.

        Each sentence is mapped character by character through ``word2idx``
        (unknown characters become ``'<UNK>'``) and right-padded with
        ``'<PAD>'`` to the length of the longest sentence in the batch.

        Args:
            batchs: Non-empty list of ``(sentence, label)`` tuples.

        Returns:
            A pair ``(indices, labels)`` where ``indices`` has one row per
            sentence and one column per character position, and ``labels``
            has one entry per sentence. Both are ``torch.int64``.
        """
        unk_id = word2idx['<UNK>']
        pad_id = word2idx['<PAD>']
        tot_sen = [pair[0] for pair in batchs]
        tot_label = [pair[1] for pair in batchs]
        max_len = max(len(sen) for sen in tot_sen)
        sen_out = []
        for sen in tot_sen:
            # Single dictionary lookup per character, falling back to <UNK>.
            ids = [word2idx.get(ch, unk_id) for ch in sen]
            ids.extend([pad_id] * (max_len - len(ids)))
            sen_out.append(ids)
        # Request int64 explicitly: NumPy's default integer width is
        # platform-dependent, and this pins the tensors to torch's long dtype.
        indices = torch.from_numpy(np.array(sen_out, dtype=np.int64))
        labels = torch.from_numpy(np.array(tot_label, dtype=np.int64))
        return indices, labels
                
            
            
            
        

In [9]:
# Read the three splits into memory, in the order train, test, valid.
train_dataset, test_dataset, valid_dataset = (
    MyDataset(split_path)
    for split_path in ("./data/train.txt", "./data/test.txt", "./data/valid.txt")
)

668852it [00:00, 1693797.24it/s]
83607it [00:00, 1573132.13it/s]
83606it [00:00, 1695822.60it/s]


In [10]:
# Number of samples per mini-batch; passed to every DataLoader in the next cell.
batch_size = 256

In [11]:
# One loader per split. Only the training loader shuffles; each loader pads
# its batches through its own dataset's collate method.
# `train_dataloder` (sic) is the spelling the training loop refers to.
train_dataloder = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=train_dataset.collate,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=test_dataset.collate,
)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=valid_dataset.collate,
)

In [12]:
# Sanity check: print the first test batch (index tensor, label tensor, then
# their shapes) and stop after that single batch.
for batch in test_dataloader:
    token_ids, labels = batch
    print(token_ids, labels, sep="\n")
    print(token_ids.shape, labels.shape, sep="\n")
    break

tensor([[    48,   1319,   7301,  ..., 365077, 365077, 365077],
        [  2110,    920,   1992,  ..., 365077, 365077, 365077],
        [ 55874,   9032,  20354,  ..., 365077, 365077, 365077],
        ...,
        [  5899,  21257,  10550,  ..., 365077, 365077, 365077],
        [  8883,  12078,  13042,  ..., 365077, 365077, 365077],
        [    48,   1263,   6437,  ..., 365077, 365077, 365077]])
tensor([ 0, 11, 11, 11, 11, 11,  7, 11,  3,  3,  7,  3,  3, 11,  3, 13,  8,  7,
         8, 13,  2,  2,  8,  7,  3,  3, 10,  7, 12,  8,  1,  0,  0,  7,  8, 11,
         3, 11,  5,  8,  3, 11,  7,  3,  3, 11,  7, 10,  7,  7,  2,  3, 11,  8,
         7,  3,  7, 11,  3, 12, 11, 11,  6,  2,  1, 11,  7,  7,  2,  8,  7, 11,
        11,  7,  7,  7,  1, 11,  3,  3, 13,  7, 11,  8,  8,  3,  3, 10, 10, 11,
         8,  3, 13, 12,  7,  8,  1, 11, 11,  2,  3, 13, 11,  8,  3,  2,  0,  3,
         8,  3,  6, 11,  0,  0,  0,  8,  7, 11,  8,  3,  2,  2,  7, 10,  8,  8,
         1,  7, 11, 10, 10,  7,  1, 11,  2

In [13]:
from models.TextCNN import TextCNN
from models.TextRNN import TextRNN
from models.TextRNN_attn import TextRNN_attn


In [14]:
# Train on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [15]:
# Build the attention RNN on top of the embedding table and move it to the
# selected device in one step.
model = TextRNN_attn(
    embed=embed,
    embed_size=300,
    hidden_size=64,
    num_class=14,
    dropout=0.5,
).to(device)




In [16]:
# Cross-entropy objective for the classifier.
loss_fn = nn.CrossEntropyLoss()
# AdamW over all model parameters with a learning rate of 1e-3.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [17]:
# Validation runs on epochs where epoch % eval_every == 0.
eval_every = 2

In [18]:
def eval(model, dataloader, device):
    """Compute the classification accuracy of ``model`` over ``dataloader``.

    The model is switched to evaluation mode and gradient tracking is
    disabled for the whole pass. Predictions are the arg-max over the last
    dimension of the model output and are compared with the labels on the CPU.

    Note that this function shadows the built-in ``eval`` in this notebook.

    Args:
        model: Module called as ``model(data)`` for each batch.
        dataloader: Iterable yielding ``(data, label)`` pairs.
        device: Device the input batch is moved to before the forward pass.

    Returns:
        Correct predictions divided by the number of samples seen; a
        zero-dimensional tensor when the loader yields at least one batch.
    """
    model.eval()
    correct, seen = 0, 0
    with torch.no_grad():
        for data, label in tqdm(dataloader):
            logits = model(data.to(device))
            # Labels stay on the CPU, so bring the predictions back to compare.
            guess = logits.argmax(dim=-1).cpu()
            correct += (guess == label).sum()
            seen += guess.shape[0]
        return correct / seen

In [19]:
# Training loop with periodic validation, best-checkpoint saving and early
# stopping.
Epoch = 100        # upper bound on the number of training epochs
best_acc = 0.0     # best validation accuracy seen so far
best_epoch = -1    # epoch at which best_acc was reached
early_stop = 5     # patience: consecutive non-improving evaluations tolerated
stale_evals = 0    # evaluations since the last improvement

# torch.save does not create missing directories, so make sure the
# checkpoint folder exists before the first save.
os.makedirs("./pt", exist_ok=True)

for epoch in range(Epoch):
    model.train()
    epoch_loss = 0.0
    for batch in tqdm(train_dataloder):
        data, label = batch
        data = data.to(device)
        label = label.to(device)
        predicted = model(data)
        loss = loss_fn(predicted, label)
        epoch_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Mean of the per-batch losses over the epoch.
    logger.info(f"epoch:{epoch}, loss:{epoch_loss / len(train_dataloder):.4f}")
    if (epoch % eval_every == 0):
        # Convert to a plain float so best_acc never holds on to a tensor.
        acc = float(eval(model, valid_dataloader, device))
        if (acc > best_acc):
            best_acc = acc
            best_epoch = epoch
            # An improvement resets the patience counter, so training only
            # stops after `early_stop` consecutive evaluations without
            # progress rather than after that many in total.
            stale_evals = 0
            torch.save(model.state_dict(), f"./pt/{model_name}.pt")
        else:
            stale_evals += 1
        logger.info(f"epoch:{epoch}, valid acc: {acc * 100:.4f}%, best valid acc: {best_acc * 100:.4f}, at epoch {best_epoch}")
        if (stale_evals >= early_stop):
            logger.info("early stop!")
            break
            

100%|██████████| 2613/2613 [00:11<00:00, 233.70it/s]
100%|██████████| 327/327 [00:00<00:00, 406.02it/s]
100%|██████████| 2613/2613 [00:11<00:00, 225.81it/s]
100%|██████████| 2613/2613 [00:11<00:00, 231.50it/s]
100%|██████████| 327/327 [00:00<00:00, 379.28it/s]
100%|██████████| 2613/2613 [00:11<00:00, 219.79it/s]
100%|██████████| 2613/2613 [00:11<00:00, 229.25it/s]
100%|██████████| 327/327 [00:00<00:00, 335.24it/s]
100%|██████████| 2613/2613 [00:11<00:00, 234.14it/s]
100%|██████████| 2613/2613 [00:10<00:00, 238.43it/s]
100%|██████████| 327/327 [00:00<00:00, 396.57it/s]
100%|██████████| 2613/2613 [00:11<00:00, 233.33it/s]
100%|██████████| 2613/2613 [00:11<00:00, 237.29it/s]
100%|██████████| 327/327 [00:00<00:00, 395.39it/s]
100%|██████████| 2613/2613 [00:11<00:00, 237.39it/s]
100%|██████████| 2613/2613 [00:10<00:00, 241.41it/s]
100%|██████████| 327/327 [00:00<00:00, 367.70it/s]
100%|██████████| 2613/2613 [00:12<00:00, 217.05it/s]
100%|██████████| 2613/2613 [00:11<00:00, 221.89it/s]
100%|

In [21]:
# Rebuild the architecture with the same hyper-parameters used for training,
# then restore the best checkpoint written by the training loop.
test_model = TextRNN_attn(embed=embed, embed_size=300, hidden_size=64, num_class=14, dropout=0.5)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written from a CUDA run still loads on a CPU-only machine.
test_model.load_state_dict(torch.load(f"./pt/{model_name}.pt", map_location=device))

test_model = test_model.to(device)

# eval() already switches the model to evaluation mode and disables gradient
# tracking, so no extra wrapper is needed here.
acc = eval(test_model, test_dataloader, device)
logger.info(f"test acc: {acc * 100:.4f}%")

100%|██████████| 327/327 [00:00<00:00, 386.47it/s]
