In [1]:
import time
import torch
import os
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import pickle as pkl
from tqdm import tqdm
from datetime import timedelta
from sklearn import metrics
from tensorboardX import SummaryWriter
import random

  from .autonotebook import tqdm as notebook_tqdm


# （一）数据预处理

In [2]:
train_path = './data/train.txt'
dev_path = './data/dev.txt'

i = 0
label_dict = {}

with open(train_path, 'r', encoding='utf-8') as f:
    train_lines = f.readlines()
    f.close()

for line in train_lines:
    label = line.split('\t')[1].replace('\n', '')
    if label not in label_dict:
        label_dict[label] = i
        i += 1

with open('data/train.txt', 'w', encoding='utf-8') as f:
    for line in train_lines[0:int(len(train_lines)*0.8)]:
        title = line.split('\t')[0]
        label = line.split('\t')[1].replace('\n', '')
        f.write(title + '\t' + str(label_dict[label]) + '\n')
    f.close()

with open('data/test.txt', 'w', encoding='utf-8') as f:
    for line in train_lines[int(len(train_lines)*0.8):]:
        title = line.split('\t')[0]
        label = line.split('\t')[1].replace('\n', '')
        f.write(title + '\t' + str(label_dict[label]) + '\n')
    f.close()

with open(dev_path, 'r', encoding='utf-8') as f:
    dev_lines = f.readlines()
    f.close()

with open('data/dev.txt', 'w', encoding='utf-8') as f:
    for line in dev_lines:
        title = line.split('\t')[0]
        label = line.split('\t')[1].replace('\n', '')
        f.write(title + '\t' + str(label_dict[label]) + '\n')
    f.close()

label_list = list(label_dict)
with open('data/dict.txt', 'w', encoding='utf-8') as f:
    for label in label_list:
        f.write(label + '\n')
    f.close()

In [3]:
MAX_VOCAB_SIZE = 10000  # 词表长度限制
UNK, PAD = '<UNK>', '<PAD>'  # 未知字，padding符号
torch.cuda.set_device(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = lambda x: [y for y in x]  # char-level

def build_vocab(file_path, tokenizer, max_size, min_freq):
    vocab_dic = {}
    with open(file_path, 'r', encoding='UTF-8') as f:
        for line in tqdm(f):
            lin = line.strip()
            if not lin:
                continue
            content = lin.split('\t')[0]
            for word in tokenizer(content):
                vocab_dic[word] = vocab_dic.get(word, 0) + 1
        vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[:max_size]
        vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
        vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
    return vocab_dic

if os.path.exists('data/vocab.pkl'):
    vocab = pkl.load(open('data/vocab.pkl', 'rb'))
else:
    vocab = build_vocab('data/train.txt', tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
    pkl.dump(vocab, open('data/vocab.pkl', 'wb'))
class_list = [x.strip() for x in open('data/dict.txt', encoding='utf-8').readlines()]
num_class = len(class_list)

# （二）数据加载

In [4]:
def get_time_dif(start_time):
    """获取已使用时间"""
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))


def load_dataset(path, pad_size=32):
    contents = []
    with open(path, 'r', encoding='UTF-8') as f:
        for line in tqdm(f):
            lin = line.strip()
            if not lin:
                continue
            content, label = lin.split('\t')
            words_line = []
            token = tokenizer(content)
            seq_len = len(token)
            if pad_size:
                if len(token) < pad_size:
                    token.extend([PAD] * (pad_size - len(token)))
                else:
                    token = token[:pad_size]
                    seq_len = pad_size
            # word to id
            for word in token:
                words_line.append(vocab.get(word, vocab.get(UNK)))
            contents.append((words_line, int(label), seq_len))
    return contents  # [([...], 0), ([...], 1), ...]


train_data = load_dataset('data/train.txt', 32)
dev_data = load_dataset('data/dev.txt', 32)
test_data = load_dataset('data/test.txt', 32)

class DatasetIterater(object):
    def __init__(self, batches, batch_size, device):
        self.batch_size = batch_size
        self.batches = batches
        self.n_batches = len(batches) // batch_size
        self.residue = False  # 记录batch数量是否为整数
        if len(batches) % self.n_batches != 0:
            self.residue = True
        self.index = 0
        self.device = device

    def _to_tensor(self, datas):
        x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
        y = torch.LongTensor([_[1] for _ in datas]).to(self.device)

        # pad前的长度(超过pad_size的设为pad_size)
        seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
        return (x, seq_len), y

    def __next__(self):
        if self.residue and self.index == self.n_batches:
            batches = self.batches[self.index * self.batch_size: len(self.batches)]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches

        elif self.index >= self.n_batches:
            self.index = 0
            raise StopIteration
        else:
            batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches

    def __iter__(self):
        return self

    def __len__(self):
        if self.residue:
            return self.n_batches + 1
        else:
            return self.n_batches


train_iter = DatasetIterater(train_data, 128, device)
dev_iter = DatasetIterater(dev_data, 128, device)
test_iter = DatasetIterater(test_data, 128, device)

num_vocab = len(vocab)

385264it [00:04, 91452.19it/s] 
80000it [00:00, 91877.24it/s] 
96316it [00:01, 92855.92it/s] 


# （三）定义GAN网络

In [5]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.embedding = nn.Embedding(num_vocab, 300, padding_idx=num_vocab - 1)
        self.conv1 = nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, stride=2, padding=1)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(9600, num_class)

    def forward(self, x):
        if(isinstance(x, tuple)):
            x = self.embedding(x[0])
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.flatten(x)
        x = self.fc(x)
        return x


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc = nn.Linear(100, 32 * 300)  # Assuming input noise size is 100

    def forward(self, x):
        x = self.fc(x)
        x = x.view(-1, 32, 300)
        return x


discriminator = Discriminator().to(device)
generator = Generator().to(device)

# （四）模型训练

In [6]:
def train(discriminator, generator, train_iter, dev_iter, test_iter):
    start_time = time.time()
    discriminator.train()
    optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-3)

    # 学习率指数衰减，每次epoch：学习率 = gamma * 学习率
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    total_batch = 0  # 记录进行到多少batch
    dev_best_loss = float('inf')
    last_improve = 0  # 记录上次验证集loss下降的batch数
    flag = False  # 记录是否很久没有效果提升
    writer = SummaryWriter(log_dir='log/GAN' + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    for epoch in range(5):
        print('Epoch [{}/{}]'.format(epoch + 1, 5))
        # scheduler.step() # 学习率衰减
        for i, (trains, labels) in enumerate(train_iter):
            outputs = discriminator(trains)
            discriminator.zero_grad()
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()

            random_noise = torch.randn(128, 100).to(device)  # Batch size of 128, noise size of 100
            fake_data = generator(random_noise)

            # 假图像的标签（你可以根据需要定义）
            fake_labels = torch.randint(0, num_class, (128,)).to(device)

            # 鉴别器对假图像的输出
            fake_outputs = discriminator(fake_data.detach())
            fake_loss = F.cross_entropy(fake_outputs, fake_labels)
            fake_loss.backward()
            optimizer.step()


            if total_batch % 100 == 0:
                # 每多少轮输出在训练集和验证集上的效果
                true = labels.data.cpu()
                predic = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predic)
                dev_acc, dev_loss = evaluate(discriminator, generator, dev_iter)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    torch.save(discriminator.state_dict(), 'saved_dict/GAN.ckpt')
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6},  Train Loss: {1:>5.2},  Train Acc: {2:>6.2%},  Val Loss: {3:>5.2},  Val Acc: {4:>6.2%},  Time: {5} {6}'
                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
                writer.add_scalar("loss/train", loss.item(), total_batch)
                writer.add_scalar("loss/dev", dev_loss, total_batch)
                writer.add_scalar("acc/train", train_acc, total_batch)
                writer.add_scalar("acc/dev", dev_acc, total_batch)
                discriminator.train()
            total_batch += 1
            if total_batch - last_improve > 1000:
                # 验证集loss超过1000batch没下降，结束训练
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
    writer.close()
    model_test(discriminator, generator, test_iter)


def model_test(discriminator, generator, test_iter):
    # test
    discriminator.load_state_dict(torch.load('saved_dict/GAN.ckpt'))
    discriminator.eval()
    start_time = time.time()
    test_acc, test_loss, test_report, test_confusion = evaluate(discriminator, generator, test_iter, test=True)
    msg = 'Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}'
    print(msg.format(test_loss, test_acc))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)


def evaluate(discriminator, generator, data_iter, test=False):
    discriminator.eval()
    loss_total = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)
    with torch.no_grad():
        for texts, labels in data_iter:
            outputs = discriminator(texts)
            loss = F.cross_entropy(outputs, labels)
            loss_total += loss
            labels = labels.data.cpu().numpy()
            predic = torch.max(outputs.data, 1)[1].cpu().numpy()
            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predic)

    acc = metrics.accuracy_score(labels_all, predict_all)
    if test:
        report = metrics.classification_report(labels_all, predict_all, target_names=class_list, digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, loss_total / len(data_iter), report, confusion
    return acc, loss_total / len(data_iter)

In [7]:
train(discriminator, generator, train_iter, dev_iter, test_iter)

Epoch [1/5]
Iter:      0,  Train Loss:   2.6,  Train Acc:  7.81%,  Val Loss:   2.4,  Val Acc: 19.58%,  Time: 0:00:04 *
Iter:    100,  Train Loss:  0.76,  Train Acc: 78.12%,  Val Loss:   1.1,  Val Acc: 65.50%,  Time: 0:00:07 *
Iter:    200,  Train Loss:  0.54,  Train Acc: 83.59%,  Val Loss:  0.72,  Val Acc: 77.01%,  Time: 0:00:10 *
Iter:    300,  Train Loss:  0.64,  Train Acc: 78.12%,  Val Loss:  0.65,  Val Acc: 79.77%,  Time: 0:00:12 *
Iter:    400,  Train Loss:  0.53,  Train Acc: 82.81%,  Val Loss:  0.58,  Val Acc: 81.93%,  Time: 0:00:15 *
Iter:    500,  Train Loss:  0.53,  Train Acc: 81.25%,  Val Loss:  0.56,  Val Acc: 82.36%,  Time: 0:00:18 *
Iter:    600,  Train Loss:  0.39,  Train Acc: 87.50%,  Val Loss:  0.58,  Val Acc: 81.44%,  Time: 0:00:21 
Iter:    700,  Train Loss:  0.55,  Train Acc: 80.47%,  Val Loss:  0.51,  Val Acc: 83.79%,  Time: 0:00:23 *
Iter:    800,  Train Loss:  0.38,  Train Acc: 85.16%,  Val Loss:  0.51,  Val Acc: 83.85%,  Time: 0:00:26 *
Iter:    900,  Train Loss: