<a href="https://colab.research.google.com/github/overlyhandsome/Colab20210803/blob/main/Cynwell_Lau_Text_Level_GNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

参考Cynwell Lau对[Text-Level-GNN](https://arxiv.org/pdf/1910.02356.pdf)论文思想的代码复现：[code source](https://github.com/Cynwell/Text-Level-GNN/blob/main/train.py)

---
使用谷歌GPU资源

---

In [1]:
%%writefile parsing.py

# Data libraries (pandas, NumPy) and standard-library helpers
import pandas as pd
import numpy as np
from time import time
import argparse

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

class GloveTokenizer:
    """Word-level tokenizer backed by a GloVe-format text embedding file.

    Every non-blank line of the file is read as ``word v1 v2 ... vd``. Words
    receive consecutive ids in file order; the optional ``unk`` and ``pad``
    tokens are appended after them.
    """

    def __init__(self, filename, unk='<unk>', pad='<pad>'):
        """Load the embedding file and register the special tokens.

        Args:
            filename: Path of the embedding text file (read as UTF-8).
            unk: Token standing in for unknown words, or None to disable it.
            pad: Token used for padding, or None to disable it.

        Raises:
            ValueError: If a special token is requested but the file contains
                no vectors, so the embedding dimension cannot be determined.
        """
        self.filename = filename  # path of the embedding file
        self.unk = unk  # token used for out-of-vocabulary words
        self.pad = pad  # token used to pad texts to a common length
        self.stoi = {}  # {word: id}
        self.itos = {}  # {id: word}
        self.embedding_matrix = []  # one vector per id, converted to an array below
        with open(filename, 'r', encoding='utf8') as f:
            # GloVe files have no header line (unlike word2vec text files, whose
            # first line holds statistics), so every line is parsed as a vector.
            for line in f:
                values = line.split()
                if not values:  # tolerate blank lines instead of crashing on values[0]
                    continue
                self._register(values[0], [float(v) for v in values[1:]])
        if self.unk is not None:
            # Unknown words get a random vector with components in [0, 1).
            self._register(self.unk, np.random.rand(self._dimension()))
        if self.pad is not None:
            # Padding is represented by the all-zero vector.
            self._register(self.pad, np.zeros(self._dimension()))
        # Convert from double to float32 for efficiency.
        self.embedding_matrix = np.array(self.embedding_matrix).astype(np.float32)

    def _register(self, token, vector):
        """Append ``vector`` and map ``token`` to its row index (and back)."""
        index = len(self.embedding_matrix)
        self.stoi[token] = index
        self.itos[index] = token
        self.embedding_matrix.append(vector)

    def _dimension(self):
        """Return the length of the first loaded vector."""
        if not self.embedding_matrix:
            raise ValueError('embedding file contains no vectors; cannot size the special tokens')
        return len(self.embedding_matrix[0])

    def encode(self, sentence):
        """Convert a sentence into a list of word ids.

        Args:
            sentence: A str (split on single spaces) or an iterable of words.
                An empty iterable yields an empty list.

        Returns:
            A list of ids; words missing from the vocabulary map to the id of
            the ``unk`` token.

        Raises:
            TypeError: If ``sentence`` is neither a str nor iterable.
            KeyError: If a word is unknown and no ``unk`` token is registered.
        """
        if isinstance(sentence, str):
            words = sentence.split(' ')
        else:
            try:
                words = list(sentence)
            except TypeError as err:
                raise TypeError('sentence should be either a str or a list of str!') from err

        encoded_sentence = []
        for word in words:
            index = self.stoi.get(word)
            if index is None:
                # Resolve the fallback lazily so known words never depend on unk.
                if self.unk is None:
                    raise KeyError(f'unknown word {word!r} and no unk token is configured')
                index = self.stoi[self.unk]
            encoded_sentence.append(index)
        return encoded_sentence

    def decode(self, encoded_sentence):
        """Convert an iterable of word ids back into a list of words.

        Raises:
            TypeError: If ``encoded_sentence`` cannot be converted to a list.
            KeyError: If an id is not part of the vocabulary.
        """
        try:
            encoded_sentence = list(encoded_sentence)
        except Exception as e:
            print(e)
            raise TypeError('encoded_sentence should be either a str or a data type that is convertible to list type!') from e
        return [self.itos[encoded_word] for encoded_word in encoded_sentence]

    def embedding(self, encoded_sentence):
        """Return the embedding rows selected by the ids in ``encoded_sentence``."""
        return self.embedding_matrix[np.array(encoded_sentence)]


class TextLevelGNNDataset(Dataset):
    """Map-style dataset used to instantiate the train, validation and test sets.

    Each item describes one document as a graph: its node ids, the padded
    neighbor ids of every node, the matching slice of the public-edge mask and
    the label vector.
    """

    def __init__(self, node_sets, neighbor_sets, public_edge_mask, labels):
        """Store the per-document graph data.

        Args:
            node_sets: Per document, the list of node (word) ids.
            neighbor_sets: Per document, one list of neighbor ids per node.
            public_edge_mask: 2-D boolean tensor indexed as [node id, neighbor id].
            labels: Per document, the label vector.
        """
        super().__init__()  # the previous one-argument super() never reached the base class
        self.node_sets = node_sets
        self.neighbor_sets = neighbor_sets
        self.public_edge_mask = public_edge_mask
        self.labels = labels

    def __getitem__(self, i):
        """Return (nodes, padded neighbors, edge mask, label) for document ``i``."""
        nodes = torch.LongTensor(self.node_sets[i])
        # Pad every node's neighbor list to the same length with id 1; build this
        # tensor once and reuse it for both the return value and the mask lookup.
        neighbors = torch.nn.utils.rnn.pad_sequence(
            [torch.LongTensor(neighbor) for neighbor in self.neighbor_sets[i]],
            batch_first=True,
            padding_value=1,
        )
        # Repeat each node id across its neighbor slots so the mask can be read
        # with a single advanced-indexing lookup.
        source_nodes = nodes.unsqueeze(-1).repeat(1, neighbors.shape[-1])
        edge_mask = self.public_edge_mask[source_nodes, neighbors]
        return nodes, neighbors, edge_mask, torch.FloatTensor(self.labels[i])

    def __len__(self):
        """Return the number of documents (one label per document)."""
        return len(self.labels)


class TextLevelGNNDatasetClass:
    """Build train/validation/test graph datasets that share one vocabulary.

    The train and test files are read as tab-separated text without a header
    row; column 0 holds the label and column 1 the space-separated document.
    The vocabulary, the embedding matrix and the public-edge mask are derived
    from the training split only and shared by all three datasets.
    """

    def __init__(self, train_filename, test_filename, tokenizer, MAX_LENGTH=10, p=2, min_freq=2, train_validation_split=0.9):
        """Read the data files and prepare every dataset.

        Args:
            train_filename: Path of the training file.
            test_filename: Path of the test file.
            tokenizer: Object exposing ``stoi``, ``encode`` and ``embedding``.
            MAX_LENGTH: Number of leading words kept from each document.
            p: Context window radius used to collect neighbors.
            min_freq: Edges counted fewer than this many times are marked public.
            train_validation_split: Fraction of the training file kept for
                training; the remainder becomes the validation set.
        """
        self.train_filename = train_filename
        self.test_filename = test_filename
        self.tokenizer = tokenizer
        self.MAX_LENGTH = MAX_LENGTH
        self.p = p
        self.min_freq = min_freq
        # The original statement here was truncated and raised AttributeError.
        self.train_validation_split = train_validation_split

        self.train_data = pd.read_csv(self.train_filename, sep='\t', header=None)
        self.test_data = pd.read_csv(self.test_filename, sep='\t', header=None)

        # Local (re-indexed) vocabulary; ids 0 and 1 are reserved.
        self.stoi = {'<unk>': 0, '<pad>': 1}
        self.itos = {0: '<unk>', 1: '<pad>'}
        self.vocab_count = len(self.stoi)
        self.embedding_matrix = None
        # Map every training label to its one-hot vector.
        self.label_dict = dict(zip(self.train_data[0].unique(), pd.get_dummies(self.train_data[0].unique()).values.tolist()))

        # Carve the validation set out of the training file.
        train_size = int(len(self.train_data) * train_validation_split)
        validation_size = len(self.train_data) - train_size
        self.train_dataset, self.validation_dataset = random_split(self.train_data.to_numpy(), [train_size, validation_size])
        self.test_dataset = self.test_data.to_numpy()

        # Based on train_dataset only. Updates self.stoi, self.itos,
        # self.vocab_count and self.embedding_matrix.
        self.build_vocab()

        self.train_dataset, self.validation_dataset, self.test_dataset, self.edge_stat, self.public_edge_mask = self.prepare_dataset()

    def build_vocab(self):
        """Extend the local vocabulary with training words known to the tokenizer."""
        unique_vocab = set()
        for _, sentence in self.train_dataset:  # rows are (label, sentence)
            unique_vocab.update(sentence.split(' '))

        # Sort so that id assignment does not depend on set iteration order.
        for vocab in sorted(unique_vocab):
            if vocab in self.tokenizer.stoi:
                self.stoi[vocab] = self.vocab_count
                self.itos[self.vocab_count] = vocab
                self.vocab_count += 1
        # Rows follow the local id order because dicts preserve insertion order.
        self.embedding_matrix = self.tokenizer.embedding(self.tokenizer.encode(list(self.stoi.keys())))

    def _encode_documents(self, documents):
        """Turn (label, sentence) rows into node sets, neighbor sets and labels."""
        # Only the first MAX_LENGTH words of each document are kept; words
        # outside the local vocabulary map to id 0 ('<unk>').
        node_sets = [[self.stoi.get(vocab, 0) for vocab in sentence.strip().split(' ')][:self.MAX_LENGTH] for _, sentence in documents]
        neighbor_sets = [create_neighbor_set(node_set, p=self.p) for node_set in node_sets]
        labels = [self.label_dict[label] for label, _ in documents]
        return node_sets, neighbor_sets, labels

    def prepare_dataset(self):
        """Build the three graph datasets plus the shared edge statistics.

        Returns:
            (train_dataset, validation_dataset, test_dataset, edge_stat,
            public_edge_mask); the edge statistics come from the training
            split only.
        """
        node_sets, neighbor_sets, labels = self._encode_documents(self.train_dataset)
        edge_stat, public_edge_mask = self.build_public_edge_mask(node_sets, neighbor_sets, min_freq=self.min_freq)
        train_dataset = TextLevelGNNDataset(node_sets, neighbor_sets, public_edge_mask, labels)

        node_sets, neighbor_sets, labels = self._encode_documents(self.validation_dataset)
        validation_dataset = TextLevelGNNDataset(node_sets, neighbor_sets, public_edge_mask, labels)

        node_sets, neighbor_sets, labels = self._encode_documents(self.test_dataset)
        test_dataset = TextLevelGNNDataset(node_sets, neighbor_sets, public_edge_mask, labels)

        return train_dataset, validation_dataset, test_dataset, edge_stat, public_edge_mask

    def build_public_edge_mask(self, node_sets, neighbor_sets, min_freq=2):
        """Count edges and flag the uncommon ones as public.

        Returns:
            edge_stat: (vocab_count, vocab_count) tensor of counts.
            public_edge_mask: Boolean tensor, True where the count is below
                ``min_freq``.
        """
        edge_stat = torch.zeros(self.vocab_count, self.vocab_count)
        for node_set, neighbor_set in zip(node_sets, neighbor_sets):
            for neighbor in neighbor_set:
                for to_node in neighbor:
                    # NOTE(review): this increments the entry of *every* node of the
                    # document (once per distinct node) for each neighbor occurrence,
                    # not only the node owning this window - confirm this is intended.
                    edge_stat[node_set, to_node] += 1
        public_edge_mask = edge_stat < min_freq  # mark True at uncommon edges
        return edge_stat, public_edge_mask


# Given the node ids of one document, collect the neighbors of every node.
def create_neighbor_set(node_set, p=2):
    """Return, for each position, the node ids inside its context window.

    The window of position ``i`` spans positions ``i - p`` to ``i + p``
    (clipped to the document) and includes the node itself.

    Args:
        node_set: 1-D list of integer node ids; may be empty.
        p: Window radius, an integer >= 0.

    Returns:
        A list with one neighbor list per position; ``[]`` for an empty
        ``node_set``.

    Raises:
        ValueError: If the first element is not an int or ``p`` is negative.
    """
    # Guard the element check so an empty document no longer raises IndexError.
    if len(node_set) > 0 and not isinstance(node_set[0], int):
        raise ValueError('node_set should be a 1D list!')
    if p < 0:
        raise ValueError('p should be an integer >= 0!')
    sequence_length = len(node_set)
    # Slicing clips at the end of the list; max() clips at the start.
    return [list(node_set[max(0, i - p):i + p + 1]) for i in range(sequence_length)]


def pad_custom_sequence(sequences):
    '''
    Collate a batch of samples into four padded tensors. Each component of the samples is gathered into its own list and padded in the way that suits it.
    Input:
        sequences <list>: One entry per sample in the batch. Every entry unpacks into four items:
                          (node set, neighbor sets, public edge mask, label).
    Return:
        node_sets_sequence: The node sets padded with 1 to a common length (batch first).
        neighbor_sets_sequence: The neighbor sets padded by padding_tensor to a common 2-D shape.
        public_edge_mask_sequence: The public edge masks padded by padding_tensor to a common 2-D shape.
        label_sequence: The labels padded with 1 to a common length (batch first).
    '''
    batch_nodes, batch_neighbors, batch_masks, batch_labels = [], [], [], []
    for nodes, neighbors, edge_mask, label in sequences:
        batch_nodes.append(nodes)
        batch_neighbors.append(neighbors)
        batch_masks.append(edge_mask)
        batch_labels.append(label)

    # 1-D components go through pad_sequence; 2-D components need padding_tensor,
    # whose second return value (the padding mask) is not needed here.
    pad_1d = torch.nn.utils.rnn.pad_sequence
    padded_nodes = pad_1d(batch_nodes, batch_first=True, padding_value=1)
    padded_neighbors = padding_tensor(batch_neighbors)[0]
    padded_masks = padding_tensor(batch_masks)[0]
    padded_labels = pad_1d(batch_labels, batch_first=True, padding_value=1)

    return padded_nodes, padded_neighbors, padded_masks, padded_labels


# When batching graphs, pad every 2-D tensor up to the largest shape in the batch.
def padding_tensor(sequences, padding_idx=1):
    '''
    Stack 2-D tensors of different shapes into one 3-D tensor, e.g. tensors of shape (2, 3) and (3, 5) become a tensor of shape (2, 3, 5): dimension 0 is the batch, dimensions 1 and 2 are padded.
    Input:
        sequences <list>: A non-empty list of 2-D tensors
        padding_idx <int>: The value written into the padded positions
    Return:
        out_tensor <torch.tensor>: The padded tensor, with the dtype and device of the first input
        mask <torch.tensor>: A boolean tensor that is True wherever out_tensor equals padding_idx
    '''
    batch_size = len(sequences)
    rows = max(item.shape[0] for item in sequences)
    cols = max(item.shape[1] for item in sequences)
    # Start from a tensor filled with the padding value, then copy each input
    # into the top-left corner of its slot.
    padded = sequences[0].new_full((batch_size, rows, cols), padding_idx)
    for position, item in enumerate(sequences):
        padded[position, :item.size(0), :item.size(1)] = item
    # Marks every entry equal to padding_idx, including genuine values that
    # happen to match it.
    pad_mask = padded.eq(padding_idx)
    return padded, pad_mask


class MessagePassing(nn.Module):
    """Message passing over a batch of text graphs, followed by classification.

    Every node takes the element-wise maximum of its weighted neighbor
    embeddings, blends it with its own embedding through a learnable per-node
    information rate, and the summed node vectors are classified by one linear
    layer with ReLU and softmax.
    """

    def __init__(self, vertice_count, input_size, out_size, dropout_rate=0, padding_idx=1):
        super().__init__()
        self.vertice_count = vertice_count  # |V|: number of rows in the information-rate table
        self.input_size = input_size  # d: embedding dimension
        self.out_size = out_size  # c: number of output classes
        self.dropout_rate = dropout_rate
        self.padding_idx = padding_idx  # node id treated as padding
        # (|V|, 1) column vector: how much of its own embedding each node keeps.
        self.information_rate = nn.Parameter(torch.rand(self.vertice_count, 1))
        self.linear = nn.Linear(self.input_size, self.out_size)  # d -> c
        self.dropout = nn.Dropout(self.dropout_rate)

    def forward(self, node_sets, embedded_node, edge_weight, embedded_neighbor_node):
        """Return class probabilities of shape (batch_size, c).

        Args:
            node_sets: (batch_size, l) node ids.
            embedded_node: (batch_size, l, d) node embeddings.
            edge_weight: (batch_size, l, max_neighbor_count) edge weights.
            embedded_neighbor_node: (batch_size, l, max_neighbor_count, d)
                neighbor embeddings.
        """
        # Scale every neighbor embedding by its edge weight.
        flat_weight = edge_weight.view(-1, 1)
        flat_neighbors = embedded_neighbor_node.view(-1, self.input_size)
        weighted = (flat_weight * flat_neighbors).view(embedded_neighbor_node.shape)
        # Entries that are exactly zero are pushed to -1e18 so they cannot win
        # the max below.
        weighted = weighted.masked_fill(weighted == 0, -1e18)
        weighted = self.dropout(weighted)
        # Element-wise max over the neighbor axis: (batch_size, l, d).
        M, _ = weighted.max(dim=2)

        rate = self.information_rate[node_sets]  # (batch_size, l, 1)
        # Padding nodes get rate 1, so their embedding is left untouched:
        # (1 - 1) * M + 1 * e_n = e_n.
        is_padding = (node_sets == self.padding_idx).unsqueeze(-1)
        rate = rate.masked_fill(is_padding, 1)
        updated_node = (1 - rate) * M + rate * embedded_node  # (batch_size, l, d)

        # The document representation is the sum of its node vectors.
        document_vector = updated_node.sum(dim=1)  # (batch_size, d)
        activated = F.relu(self.linear(document_vector))  # (batch_size, c)
        # No dropout after the linear layer on purpose: dropping class scores
        # directly (p=0.5) was reported not to converge; p=0 worked best.
        return F.softmax(activated, dim=1)


class TextLevelGNN(nn.Module):
    """Text-level graph neural network classifier.

    Holds a trainable (|V|, |V|) edge-weight matrix, trainable node embeddings
    initialised from pretrained vectors, one weight shared by all public
    edges, and the MessagePassing module that produces the prediction.
    """

    def __init__(self, pretrained_embeddings, out_size=8, dropout_rate=0, padding_idx=1):
        """Create the model.

        Args:
            pretrained_embeddings: (|V|, d) tensor of initial node embeddings.
            out_size: Number of classes; depends on the dataset.
            dropout_rate: Dropout rate forwarded to MessagePassing.
            padding_idx: Node id treated as padding.
        """
        super().__init__()
        self.out_size = out_size  # c
        self.padding_idx = padding_idx
        self.weight_matrix = nn.Parameter(torch.randn(pretrained_embeddings.shape[0], pretrained_embeddings.shape[0]))  # (|V|, |V|)
        self.embedding = nn.Embedding.from_pretrained(pretrained_embeddings, freeze=False, padding_idx=self.padding_idx)  # (|V|, d)
        self.message_passing = MessagePassing(vertice_count=pretrained_embeddings.shape[0], input_size=pretrained_embeddings.shape[1], out_size=self.out_size, dropout_rate=dropout_rate, padding_idx=self.padding_idx)
        self.public_edge_weight = nn.Parameter(torch.randn(1, 1))  # (1, 1)

    def forward(self, node_sets, neighbor_sets, public_edge_mask):
        """Return the prediction of ``message_passing`` for a batch of graphs.

        Args:
            node_sets: (batch_size, l) node ids.
            neighbor_sets: (batch_size, l, max_neighbor_count) neighbor ids.
            public_edge_mask: (batch_size, l, max_neighbor_count) boolean mask,
                True where the edge uses the shared public weight.
        """
        embedded_node = self.embedding(node_sets)  # (batch_size, l, d)
        # Give node_sets the shape of neighbor_sets so that all edge weights are
        # fetched with a single indexing operation instead of one per edge.
        source_nodes = node_sets.unsqueeze(2).repeat(1, 1, neighbor_sets.shape[-1])
        # Read this instance's own matrix; the previous code reached for a global
        # named `model`, which fails or is wrong outside the training script.
        edge_weight = self.weight_matrix[source_nodes, neighbor_sets]
        private_part = edge_weight * ~public_edge_mask  # keep learned weights on private edges
        public_part = self.public_edge_weight * public_edge_mask  # (1, 1) broadcasts over the batch
        edge_weight = private_part + public_part  # (batch_size, l, max_neighbor_count)
        embedded_neighbor_node = self.embedding(neighbor_sets)  # (batch_size, l, max_neighbor_count, d)

        # Cut off every edge that touches a padding node by zeroing its weight.
        touches_padding = (source_nodes == self.padding_idx) | (neighbor_sets == self.padding_idx)
        edge_weight = edge_weight.masked_fill(touches_padding, 0)
        return self.message_passing(node_sets, embedded_node, edge_weight, embedded_neighbor_node)


# Command-line options of the training script. Every option is optional and
# falls back to the default given here.
parser = argparse.ArgumentParser()
parser.add_argument('--cuda', default='0', type=str, required=False,
                    help='Choosing which cuda to use')
# Also selects which GloVe file is loaded below (glove.6B.<embedding_size>d.txt).
parser.add_argument('--embedding_size', default=100, type=int, required=False,
                    help='Number of hidden units in each layer of the graph embedding part')
parser.add_argument('--p', default=3, type=int, required=False,
                    help='The window size')
parser.add_argument('--min_freq', default=2, type=int, required=False,
                    help='The minimum no. of occurrence for a word to be considered as a meaningful word. Words with less than this occurrence will be mapped to a globally shared embedding weight (to the <unk> token). It corresponds to the parameter k in the original paper.')
parser.add_argument('--max_length', default=70, type=int, required=False,
                    help='The max length of each document to be processed')
parser.add_argument('--dropout', default=0, type=float, required=False,
                    help='Dropout rate')
parser.add_argument('--lr', default=1e-3, type=float, required=False,
                    help='Initial learning rate')
parser.add_argument('--lr_decay_factor', default=0.9, type=float, required=False,
                    help='Multiplicative factor of learning rate decays')  # multiplier applied to the learning rate at each decay step
parser.add_argument('--lr_decay_every', default=5, type=int, required=False,
                    help='Decaying learning rate every ? epochs')
parser.add_argument('--weight_decay', default=1e-4, type=float, required=False,
                    help='Weight decay (L2 penalty)')
parser.add_argument('--warm_up_epoch', default=0, type=int, required=False,
                    help='Pretraining for ? epochs before early stopping to be in effect')
parser.add_argument('--early_stopping_patience', default=10, type=int, required=False,
                    help='Waiting for ? more epochs after the best epoch to see any further improvements')
parser.add_argument('--early_stopping_criteria', default='loss', type=str, required=False,
                    choices=['accuracy', 'loss'],
                    help='Early stopping according to validation accuracy or validation loss')
parser.add_argument("--epoch", default=100, type=int, required=False,
                    help='Number of epochs to train')
args = parser.parse_args()

tokenizer = GloveTokenizer(f'/content/drive/MyDrive/Colab_Notebooks/DATA/glove.6B/glove.6B.{args.embedding_size}d.txt')

dataset = TextLevelGNNDatasetClass(train_filename='/content/drive/MyDrive/Colab_Notebooks/DATA/r8-train-all-terms.txt',
                                   test_filename='/content/drive/MyDrive/Colab_Notebooks/DATA/r8-test-all-terms.txt',
                                   train_validation_split=0.9,  # 训练集中10%划出作为验证集, train = train*0.x
                                   tokenizer=tokenizer,
                                   p=args.p,
                                   min_freq=args.min_freq,
                                   MAX_LENGTH=args.max_length)

train_loader = DataLoader(dataset.train_dataset, batch_size=256, shuffle=True, collate_fn=pad_custom_sequence)
validation_loader = DataLoader(dataset.validation_dataset, batch_size=256, shuffle=True, collate_fn=pad_custom_sequence)
test_loader = DataLoader(dataset.test_dataset, batch_size=256, shuffle=True, collate_fn=pad_custom_sequence)

# Use the GPU selected with --cuda when CUDA is available, otherwise the CPU.
# (To force CPU execution, assign torch.device('cpu') here instead.)
device = torch.device(f'cuda:{args.cuda}') if torch.cuda.is_available() else torch.device('cpu')
if device.type == 'cuda':
    # Report the device that was actually selected, not always index 0.
    print(f'device: {device}')
    print(f'name:{torch.cuda.get_device_name(device)}')
    print(f'memory:{torch.cuda.get_device_properties(device).total_memory/1e9}')

# The model is initialised from the dataset's embedding matrix and moved to the device.
model = TextLevelGNN(pretrained_embeddings=torch.tensor(dataset.embedding_matrix), dropout_rate=args.dropout).to(device)
# NOTE(review): BCELoss expects probabilities in [0, 1]; this assumes the model
# output is already normalised and the labels are one-hot -- confirm.
criterion = nn.BCELoss()

# Optimizer settings, copied out of the parsed arguments so the training loop
# can mutate `lr` without touching `args`.
lr, lr_decay_factor, lr_decay_every = args.lr, args.lr_decay_factor, args.lr_decay_every
weight_decay = args.weight_decay

# Early-stopping settings.
warm_up_epoch = args.warm_up_epoch
early_stopping_patience = args.early_stopping_patience
early_stopping_criteria = args.early_stopping_criteria
best_epoch = 0  # index of the best epoch seen so far (initial value)

# Per-split metric histories; each list grows by one entry per finished epoch.
# Key order ('accuracy' before 'loss') fixes the column order of the CSV log.
training, validation, testing = ({'accuracy': [], 'loss': []} for _ in range(3))

def _run_batches(loader, optimizer=None):
    """Feed every batch of `loader` through the model once.

    When `optimizer` is given, each batch is followed by a backward pass and
    an optimizer step; otherwise the pass only measures loss and accuracy.
    The caller is responsible for `model.train()` / `model.eval()` and for
    wrapping evaluation passes in `torch.no_grad()`.

    Returns:
        (loss_sum, correct): the sum of the per-batch criterion values and the
        number of examples whose arg-max prediction equals the arg-max label.
    """
    loss_sum = 0
    correct = 0
    for node_sets, neighbor_sets, public_edge_masks, labels in loader:
        node_sets = node_sets.to(device)
        neighbor_sets = neighbor_sets.to(device)
        public_edge_masks = public_edge_masks.to(device)
        labels = labels.to(device)
        # One prediction row per graph in the batch.
        prediction = model(node_sets, neighbor_sets, public_edge_masks)
        loss = criterion(prediction, labels)
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        loss_sum += loss.item()
        correct += (prediction.argmax(dim=1) == labels.argmax(dim=1)).sum().item()
    return loss_sum, correct


# Early stopping starts comparing at the first post-warm-up epoch. This must be
# set once, outside the loop: resetting it every epoch would forget the best
# epoch and measure patience from the warm-up epoch instead.
best_epoch = warm_up_epoch

for epoch in range(args.epoch):
    previous_epoch_timestamp = time()

    if epoch % lr_decay_every == 0:  # Update optimizer for every lr_decay_every epochs
        if epoch != 0:  # The very first epoch runs at the undecayed learning rate
            lr *= lr_decay_factor
        # NOTE(review): building a fresh Adam instance also discards its
        # internal state; kept as-is to preserve the training schedule.
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # --- training pass ---
    model.train()
    train_loss, train_correct_items = _run_batches(train_loader, optimizer)
    train_accuracy = train_correct_items / len(dataset.train_dataset)

    # --- evaluation passes (no gradients needed, so skip graph construction) ---
    model.eval()
    with torch.no_grad():
        validation_loss, validation_correct_items = _run_batches(validation_loader)
        test_loss, test_correct_items = _run_batches(test_loader)
    validation_accuracy = validation_correct_items / len(dataset.validation_dataset)
    test_accuracy = test_correct_items / len(dataset.test_dataset)

    print(f'Epoch: {epoch+1}, Training Loss: {train_loss:.4f}, Validation Loss: {validation_loss:.4f}, Testing Loss: {test_loss:.4f}, Training Accuracy: {train_accuracy:.4f}, Validation Accuracy: {validation_accuracy:.4f}, Testing Accuracy: {test_accuracy:.4f}, Time Used: {time()-previous_epoch_timestamp:.2f}s')
    training['accuracy'].append(train_accuracy)
    training['loss'].append(train_loss)
    validation['accuracy'].append(validation_accuracy)
    validation['loss'].append(validation_loss)
    testing['accuracy'].append(test_accuracy)
    testing['loss'].append(test_loss)

    # Early stopping only takes effect once the warm-up epochs are over.
    if epoch >= warm_up_epoch:
        if early_stopping_criteria == 'accuracy':
            if validation['accuracy'][epoch] > validation['accuracy'][best_epoch]:
                best_epoch = epoch
            elif epoch >= best_epoch + early_stopping_patience:
                print(f'Early stopping... (No further increase in validation accuracy) for consecutive {early_stopping_patience} epochs.')
                break
        elif early_stopping_criteria == 'loss':
            if validation['loss'][epoch] < validation['loss'][best_epoch]:
                best_epoch = epoch
            elif epoch >= best_epoch + early_stopping_patience:
                print(f'Early stopping... (No further decrease in validation loss) for consecutive {early_stopping_patience} epochs.')
                break
    elif epoch + 1 == warm_up_epoch:
        print('--- Warm up finished ---')

# Save the per-epoch metrics as a CSV log: one row per completed epoch, one
# column per (split, metric) pair. The file name encodes the run's settings.
csv_path = f'/content/drive/MyDrive/Colab_Notebooks/DATA/r8-eb_size{args.embedding_size}-p{args.p}-k{args.min_freq}-max_len{args.max_length}-dp{args.dropout}-epoch{args.epoch}.csv'
df = pd.DataFrame({
    'Training Accuracy': training['accuracy'],
    'Training Loss': training['loss'],
    'Validation Accuracy': validation['accuracy'],
    'Validation Loss': validation['loss'],
    'Testing Accuracy': testing['accuracy'],
    'Testing Loss': testing['loss'],
})
df.to_csv(csv_path)


# Save the learning curves as an image: loss on the left panel, accuracy on
# the right, one line per data split.
import matplotlib.pyplot as plt

curves = (('Training', training), ('Validation', validation), ('Testing', testing))
fig = plt.figure(figsize=(8, 4))

# Left panel: loss per epoch.
ax1 = fig.add_subplot(121)
for split_name, history in curves:
    ax1.plot(history['loss'], label=f'{split_name} Loss')
ax1.legend()
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')

# Right panel: accuracy per epoch (this panel carries no axis labels).
ax2 = fig.add_subplot(122)
for split_name, history in curves:
    ax2.plot(history['accuracy'], label=f'{split_name} Accuracy')
ax2.legend()

fig.savefig(f'/content/drive/MyDrive/Colab_Notebooks/DATA/r8-eb_size{args.embedding_size}-p{args.p}-k{args.min_freq}-max_len{args.max_length}-dp{args.dropout}-epoch{args.epoch}.png', dpi=400, bbox_inches='tight')

Writing parsing.py


调参

In [None]:
!python parsing.py --cuda=0 --embedding_size=300 --p=3 --min_freq=2 --max_length=70 --dropout=0 --epoch=2

device: cuda:0
name:Tesla K80
memory:11.996954624
Epoch: 1, Training Loss: 46.5930, Validation Loss: 6.7239, Testing Loss: 21.4874, Training Accuracy: 0.0427, Validation Accuracy: 0.0492, Testing Accuracy: 0.0329, Time Used: 12.83s
Epoch: 2, Training Loss: 47.9986, Validation Loss: 6.9831, Testing Loss: 21.8662, Training Accuracy: 0.0733, Validation Accuracy: 0.0965, Testing Accuracy: 0.0900, Time Used: 12.73s
