In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import torch.nn.functional as F

In [2]:
pretrained_embedding_path = "./data/pretrained_wordvector/sgns.sogou.char"
embed = []
word2idx = dict()
idx2word = dict()

size = None
with open(pretrained_embedding_path, "r") as f:
    idx = 0
    f.readline()
    for line in tqdm(f):
        x = line.strip().split(' ')
        word = x[0]
        vector = np.asarray(x[1:], dtype=np.float32)
        size = vector.shape
        embed.append(vector)
        word2idx[word] = idx
        idx2word[idx] = word
        idx += 1

365076it [00:12, 28242.38it/s]


In [3]:
avg = sum(embed) / len(embed)
embed.append(avg)
word2idx['<UNK>'] = idx
idx2word[idx] = '<UNK>'
idx += 1

embed.append(np.random.normal(size=size))
word2idx['<PAD>'] = idx
idx2word[idx] = '<PAD>'
idx += 1



In [4]:
embed = torch.from_numpy(np.array(embed)).float()
embed.shape, embed.dtype

(torch.Size([365078, 300]), torch.float32)

In [5]:
class MyDataset(Dataset):
    def __init__(self, path):
        self.data = []
        with open(path, "r") as f:
            for line in tqdm(f):
                x = line.split('\t')
                sen, label = x[0], x[1]
                self.data.append((sen, int(label)))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[index]
                
    def collate(self, batchs):
        sen_out = []
        tot_sen = [pair[0] for pair in batchs]
        tot_label = [pair[1] for pair in batchs]
        max_len = max([len(sen) for sen in tot_sen])
        for sen in tot_sen:
            temp = []
            for ch in sen:
                if (ch not in word2idx):
                    temp.append(word2idx['<UNK>'])
                else:
                    temp.append(word2idx[ch])
            temp += [word2idx['<PAD>']] * (max_len - len(sen))
            sen_out.append(temp)
        return torch.from_numpy(np.array(sen_out)), torch.from_numpy(np.array(tot_label))
                
            
            
            
        

In [6]:
train_dataset = MyDataset("./data/train.txt")
test_dataset = MyDataset("./data/test.txt")
valid_dataset = MyDataset("./data/valid.txt")

668852it [00:00, 1666496.55it/s]
83607it [00:00, 1504583.06it/s]
83606it [00:00, 1716861.59it/s]


In [19]:
batch_size = 256

In [21]:
train_dataloder = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=train_dataset.collate)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=test_dataset.collate)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, collate_fn=valid_dataset.collate)

In [8]:
for batch in test_dataloader:
    print(batch[0])
    print(batch[1])
    print(batch[0].shape)
    print(batch[1].shape)
    break

tensor([[    48,   1319,   7301,  ..., 365077, 365077, 365077],
        [  2110,    920,   1992,  ..., 365077, 365077, 365077],
        [ 55874,   9032,  20354,  ..., 365077, 365077, 365077],
        ...,
        [  8154,  26841,   8360,  ..., 365077, 365077, 365077],
        [   220,   1122,     34,  ..., 365077, 365077, 365077],
        [ 64214,  44835,   5651,  ..., 365077, 365077, 365077]])
tensor([ 0, 11, 11, 11, 11, 11,  7, 11,  3,  3,  7,  3,  3, 11,  3, 13,  8,  7,
         8, 13,  2,  2,  8,  7,  3,  3, 10,  7, 12,  8,  1,  0,  0,  7,  8, 11,
         3, 11,  5,  8,  3, 11,  7,  3,  3, 11,  7, 10,  7,  7,  2,  3, 11,  8,
         7,  3,  7, 11,  3, 12, 11, 11,  6,  2,  1, 11,  7,  7,  2,  8,  7, 11,
        11,  7,  7,  7,  1, 11,  3,  3, 13,  7, 11,  8,  8,  3,  3, 10, 10, 11,
         8,  3, 13, 12,  7,  8,  1, 11, 11,  2,  3, 13, 11,  8,  3,  2,  0,  3,
         8,  3,  6, 11,  0,  0,  0,  8,  7, 11,  8,  3,  2,  2,  7, 10,  8,  8,
         1,  7])
torch.Size([128, 26])
tor

In [22]:
class TextCNN(nn.Module):
    def __init__(self, out_channels, embed_size, num_class, dropout):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(embeddings=embed, freeze=True)
        self.kernel1 = nn.Conv2d(in_channels=1, out_channels=out_channels, kernel_size=(2, embed_size))
        self.kernel2 = nn.Conv2d(in_channels=1, out_channels=out_channels, kernel_size=(3, embed_size))
        self.kernel3 = nn.Conv2d(in_channels=1, out_channels=out_channels, kernel_size=(4, embed_size))
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(3 * out_channels, num_class)
        
    def forward(self, inputs):
        # inputs: (b, s)
        x = self.embedding(inputs)
        # (b, s, ebs)
        x = x.unsqueeze(1)
        # (b, 1, s, ebs)
        out1 = self.kernel1(x)
        out2 = self.kernel2(x)
        out3 = self.kernel3(x)
        # (b, out, s-1, 1)
        # (b, out, s-2, 1)
        # (b, out, s-3, 1)
        out1 = torch.squeeze(out1, dim=-1)
        out2 = torch.squeeze(out2, dim=-1)
        out3 = torch.squeeze(out3, dim=-1)
        # (b, out, s-1)
        # (b, out, s-2)
        # (b, out, s-3)
        out1 = F.max_pool1d(out1, out1.shape[-1])
        out2 = F.max_pool1d(out2, out2.shape[-1])
        out3 = F.max_pool1d(out3, out3.shape[-1])
        # (b, out, 1)
        # (b, out, 1)
        # (b, out, 1)
        out1 = torch.squeeze(out1, dim=-1)
        out2 = torch.squeeze(out2, dim=-1)
        out3 = torch.squeeze(out3, dim=-1)
        # (b, out)
        # (b, out)
        # (b, out)
        out = torch.cat([out1, out2, out3], dim=-1)
        # (b, 3 * out)
        out = F.relu(out)
        out = self.dropout(out)
        out = self.fc(out)
        
        return out        

In [23]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [24]:
model = TextCNN(out_channels=100, embed_size=300, num_class=14, dropout=0.5)
model = model.to(device)


In [25]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [26]:
eval_every = 2

In [27]:
def eval(model, dataloader, device):
    model.eval()
    acc = 0
    tot = 0
    with torch.no_grad():
        for batch in tqdm(dataloader):
            data, label = batch
            data = data.to(device)
            out = model(data)
            predicted = torch.argmax(out, dim=-1)
            predicted = predicted.cpu()
            acc += (predicted == label).sum()
            tot += predicted.shape[0]
        return acc / tot

In [28]:
Epoch = 100
best_acc = 0.0
best_epoch = -1
early_stop = 5
for epoch in range(Epoch):
    model.train()
    epoch_loss = 0.0
    for batch in tqdm(train_dataloder):
        data, label = batch
        data = data.to(device)
        label = label.to(device)
        predicted = model(data)
        loss = loss_fn(predicted, label)
        epoch_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"epoch:{epoch}, loss:{epoch_loss / len(train_dataloder):.4f}")
    if (epoch % eval_every == 0):
        acc = eval(model, valid_dataloader, device)
        if (acc > best_acc):
            best_acc = acc
            best_epoch = epoch
            torch.save(model.state_dict(), "./pt/TextCNN.pt")
        else:
            early_stop -= 1
        print(f"epoch:{epoch}, valid acc: {acc * 100:.4f}%, best valid acc: {best_acc * 100:.4f}, at epoch {best_epoch}")
        if (early_stop == 0):
            print(f"early stop!")
            break
            

100%|██████████| 2613/2613 [00:16<00:00, 156.82it/s]


epoch:0, loss:0.6092


100%|██████████| 327/327 [00:01<00:00, 277.57it/s]


epoch:0, valid acc: 88.7400%, best valid acc: 88.7400, at epoch 0


100%|██████████| 2613/2613 [00:16<00:00, 158.74it/s]


epoch:1, loss:0.4233


100%|██████████| 2613/2613 [00:16<00:00, 160.66it/s]


epoch:2, loss:0.3873


100%|██████████| 327/327 [00:01<00:00, 278.69it/s]


epoch:2, valid acc: 90.4062%, best valid acc: 90.4062, at epoch 2


100%|██████████| 2613/2613 [00:16<00:00, 160.70it/s]


epoch:3, loss:0.3663


100%|██████████| 2613/2613 [00:16<00:00, 160.17it/s]


epoch:4, loss:0.3533


100%|██████████| 327/327 [00:01<00:00, 276.38it/s]


epoch:4, valid acc: 91.0210%, best valid acc: 91.0210, at epoch 4


100%|██████████| 2613/2613 [00:16<00:00, 159.93it/s]


epoch:5, loss:0.3419


100%|██████████| 2613/2613 [00:16<00:00, 160.44it/s]


epoch:6, loss:0.3335


100%|██████████| 327/327 [00:01<00:00, 277.33it/s]


epoch:6, valid acc: 91.1944%, best valid acc: 91.1944, at epoch 6


100%|██████████| 2613/2613 [00:16<00:00, 159.92it/s]


epoch:7, loss:0.3259


100%|██████████| 2613/2613 [00:16<00:00, 159.92it/s]


epoch:8, loss:0.3201


100%|██████████| 327/327 [00:01<00:00, 276.01it/s]


epoch:8, valid acc: 91.5114%, best valid acc: 91.5114, at epoch 8


100%|██████████| 2613/2613 [00:16<00:00, 158.20it/s]


epoch:9, loss:0.3143


100%|██████████| 2613/2613 [00:16<00:00, 160.06it/s]


epoch:10, loss:0.3097


100%|██████████| 327/327 [00:01<00:00, 276.66it/s]


epoch:10, valid acc: 91.5999%, best valid acc: 91.5999, at epoch 10


100%|██████████| 2613/2613 [00:16<00:00, 158.08it/s]


epoch:11, loss:0.3058


100%|██████████| 2613/2613 [00:16<00:00, 159.44it/s]


epoch:12, loss:0.3013


100%|██████████| 327/327 [00:01<00:00, 269.37it/s]


epoch:12, valid acc: 91.7147%, best valid acc: 91.7147, at epoch 12


100%|██████████| 2613/2613 [00:16<00:00, 159.57it/s]


epoch:13, loss:0.2988


100%|██████████| 2613/2613 [00:16<00:00, 160.38it/s]


epoch:14, loss:0.2944


100%|██████████| 327/327 [00:01<00:00, 275.36it/s]


epoch:14, valid acc: 91.7171%, best valid acc: 91.7171, at epoch 14


100%|██████████| 2613/2613 [00:16<00:00, 160.38it/s]


epoch:15, loss:0.2933


100%|██████████| 2613/2613 [00:16<00:00, 160.30it/s]


epoch:16, loss:0.2910


100%|██████████| 327/327 [00:01<00:00, 274.73it/s]


epoch:16, valid acc: 91.7745%, best valid acc: 91.7745, at epoch 16


100%|██████████| 2613/2613 [00:16<00:00, 160.41it/s]


epoch:17, loss:0.2882


100%|██████████| 2613/2613 [00:16<00:00, 160.20it/s]


epoch:18, loss:0.2849


100%|██████████| 327/327 [00:01<00:00, 273.98it/s]


epoch:18, valid acc: 91.9743%, best valid acc: 91.9743, at epoch 18


100%|██████████| 2613/2613 [00:16<00:00, 161.23it/s]


epoch:19, loss:0.2838


100%|██████████| 2613/2613 [00:16<00:00, 162.51it/s]


epoch:20, loss:0.2821


100%|██████████| 327/327 [00:01<00:00, 278.34it/s]


epoch:20, valid acc: 91.9348%, best valid acc: 91.9743, at epoch 18


100%|██████████| 2613/2613 [00:16<00:00, 162.80it/s]


epoch:21, loss:0.2804


100%|██████████| 2613/2613 [00:16<00:00, 163.08it/s]


epoch:22, loss:0.2774


100%|██████████| 327/327 [00:01<00:00, 278.30it/s]


epoch:22, valid acc: 91.9635%, best valid acc: 91.9743, at epoch 18


100%|██████████| 2613/2613 [00:16<00:00, 162.43it/s]


epoch:23, loss:0.2783


100%|██████████| 2613/2613 [00:16<00:00, 162.32it/s]


epoch:24, loss:0.2747


100%|██████████| 327/327 [00:01<00:00, 276.37it/s]


epoch:24, valid acc: 91.9910%, best valid acc: 91.9910, at epoch 24


100%|██████████| 2613/2613 [00:16<00:00, 159.38it/s]


epoch:25, loss:0.2738


100%|██████████| 2613/2613 [00:16<00:00, 162.12it/s]


epoch:26, loss:0.2724


100%|██████████| 327/327 [00:01<00:00, 279.54it/s]


epoch:26, valid acc: 92.0496%, best valid acc: 92.0496, at epoch 26


100%|██████████| 2613/2613 [00:16<00:00, 159.98it/s]


epoch:27, loss:0.2709


100%|██████████| 2613/2613 [00:16<00:00, 160.71it/s]


epoch:28, loss:0.2699


100%|██████████| 327/327 [00:01<00:00, 277.51it/s]


epoch:28, valid acc: 91.9456%, best valid acc: 92.0496, at epoch 26


100%|██████████| 2613/2613 [00:16<00:00, 161.36it/s]


epoch:29, loss:0.2691


100%|██████████| 2613/2613 [00:16<00:00, 161.81it/s]


epoch:30, loss:0.2680


100%|██████████| 327/327 [00:01<00:00, 275.47it/s]


epoch:30, valid acc: 91.9838%, best valid acc: 92.0496, at epoch 26


100%|██████████| 2613/2613 [00:16<00:00, 161.25it/s]


epoch:31, loss:0.2668


100%|██████████| 2613/2613 [00:16<00:00, 161.98it/s]


epoch:32, loss:0.2668


100%|██████████| 327/327 [00:01<00:00, 278.92it/s]

epoch:32, valid acc: 92.0137%, best valid acc: 92.0496, at epoch 26
early stop!





In [30]:
test_model = TextCNN(out_channels=100, embed_size=300, num_class=14, dropout=0.5)
test_model.load_state_dict(torch.load("./pt/TextCNN.pt"))

test_model = test_model.to(device)

test_model.eval()
with torch.no_grad():
    acc = eval(test_model, test_dataloader, device)
    print(f"test acc: {acc * 100:.4f}%")

100%|██████████| 327/327 [00:01<00:00, 276.15it/s]

test acc: 92.0940%



