<a href="https://colab.research.google.com/github/hshuai97/Colab20210803/blob/main/TextLevelGNN_v4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

1. Original source code by Lianzhe Huang: [source link](https://github.com/mojave-pku/TextLevelGCN)
2. v1: [source link](https://colab.research.google.com/drive/1i5ySiMxus-pV_fCbk6yxAof6JL9qyeIQ#scrollTo=ww9wY6vugDpl)
3. v2: [source link](https://colab.research.google.com/drive/1GVaPym3UMeBLSCGFJ5CQyc_t7Ra1wzgy#scrollTo=lwY3NTaygDor)

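4. Building on v2, this version uses a softmax to weight the importance of each semantic unit (sentence sub-graph): $\alpha_i=\frac{\exp(S_i)}{\sum_j \exp(S_j)}$, where $S_i$ is the attention score of sentence $i$. This gives GCN + GAT.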
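The weighting in item 4 can be illustrated with a minimal pure-Python sketch of what `Model.Sen_Attn` computes for one document. It assumes, as in `scaled_dot_product`, that the score is $S_i = q^\top \tanh(v_i)/\sqrt{d}$ for a learnable query $q$ and sentence vector $v_i$. The helper names `softmax` and `attention_pool` are hypothetical, and plain lists stand in for PyTorch tensors.

```python
import math


def softmax(scores):
    """alpha_i = exp(S_i) / sum_j exp(S_j); subtracting max(S) avoids overflow."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(query, sentence_vecs):
    """Pool one document's sentence vectors into a single document vector.

    Scores follow scaled dot-product attention with tanh-squashed keys:
    S_i = <q, tanh(v_i)> / sqrt(d); the output is sum_i alpha_i * v_i.
    """
    d = len(query)
    scores = [
        sum(q * math.tanh(x) for q, x in zip(query, v)) / math.sqrt(d)
        for v in sentence_vecs
    ]
    alpha = softmax(scores)
    pooled = [sum(a * v[k] for a, v in zip(alpha, sentence_vecs)) for k in range(d)]
    return alpha, pooled


# Two sentence vectors; the first is better aligned with the query, so it gets more weight.
alpha, doc_vec = attention_pool([1.0, 0.0], [[2.0, 0.0], [0.0, 2.0]])
print(alpha, doc_vec)
```

Subtracting the maximum score leaves every $\alpha_i$ unchanged, because the common factor $e^{-\max_j S_j}$ cancels in the ratio.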

# install libraries

In [None]:
# install the Deep Graph Library (DGL): https://www.dgl.ai/pages/start.html
import torch
try:
  import dgl
except ModuleNotFoundError:
  # install the DGL wheel that matches the runtime's CUDA version
  CUDA = 'cu' + torch.version.cuda.replace('.','')
  !pip install dgl-{CUDA} -f https://data.dgl.ai/wheels/repo.html

# install word2vec
try:
  import word2vec  # type: ignore
except ModuleNotFoundError:
  !pip install word2vec # type: ignore

# install torch_scatter; see: https://colab.research.google.com/drive/1PfoLSjr4TID_ogKEIV4Jusl2NromfgL3#scrollTo=mMoeNLSF38C-
try:
  import torch_scatter
except ModuleNotFoundError:
  TORCH = torch.__version__.split('+')[0]
  CUDA = 'cu' + torch.version.cuda.replace('.','')
  !pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html

Looking in links: https://data.dgl.ai/wheels/repo.html
Collecting dgl-cu111
  Downloading https://data.dgl.ai/wheels/dgl_cu111-0.8.0.post1-cp37-cp37m-manylinux1_x86_64.whl (252.7 MB)
Installing collected packages: dgl-cu111
Successfully installed dgl-cu111-0.8.0.post1
Collecting word2vec
  Downloading word2vec-0.11.1.tar.gz (42 kB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
    Preparing wheel metadata ... done
Building wheels for collected packages: word2vec
  Building wheel for word2vec (PEP 517) ... done
  Created wheel for word2vec: filename=word2vec-0.11.1-py2.py3-none-any.whl size=156420 sha256=574b2d03bddd4f28216fa604336128b1659d6f95e85c0ae23127564cdedfb944
  Stored in directory: /root/.cache/pip/wheels/c9/c0/d4/29d797817e268124a32b6cf8beb8b8fe87b86f099d5a049e61

# parsing
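The script below builds one small graph per sentence and runs one round of message passing over it. The following pure-Python sketch mirrors that logic under simplifying assumptions:

- Plain lists replace DGL graphs and tensors.
- Local node ids follow first appearance; the notebook uses `set` order.
- Edge weights are passed in directly rather than looked up through `edges_matrix`.

Each word links to the words within `ngram` positions plus an explicit self-loop, as in `add_seq_edges`. The reduce step takes an element-wise max of edge-weighted neighbour vectors to get $M_n$. Nodes are then updated as $r_n' = (1-\eta)M_n + \eta r_n$, and the sentence vector is the sum of node vectors, as `dgl.sum_nodes` computes. All function names are hypothetical.

```python
import math


def window_edges(doc_ids, ngram):
    """Return (word -> local node id, list of (src, dst) edges) for one sentence."""
    local = {w: i for i, w in enumerate(dict.fromkeys(doc_ids))}
    edges = []
    for idx, w in enumerate(doc_ids):
        lo, hi = max(0, idx - ngram), min(idx + ngram + 1, len(doc_ids))
        for j in range(lo, hi):  # neighbours within the window (includes idx itself)
            edges.append((local[w], local[doc_ids[j]]))
        edges.append((local[w], local[w]))  # explicit self-loop, as in add_seq_edges
    return local, edges


def max_aggregate_update(h, edges, weights, eta):
    """M_n = max over edges (a -> n) of w_an * r_a; then r_n' = (1 - eta) * M_n + eta * r_n."""
    dim = len(h[0])
    M = [[-math.inf] * dim for _ in h]
    for (src, dst), w in zip(edges, weights):
        for k in range(dim):
            M[dst][k] = max(M[dst][k], w * h[src][k])
    return [[(1 - eta) * m + eta * r for m, r in zip(M[n], h[n])] for n in range(len(h))]


def sum_readout(h):
    """Sentence vector = element-wise sum of the node vectors."""
    return [sum(col) for col in zip(*h)]


local, edges = window_edges([5, 7, 5, 9], ngram=1)  # word ids 5, 7, 9 -> nodes 0, 1, 2
h = [[1.0], [2.0], [3.0]]  # 1-dimensional node features, for readability
new_h = max_aggregate_update(h, edges, [1.0] * len(edges), eta=0.2)
print(new_h, sum_readout(new_h))
```

Because the window includes the word itself and a self-loop is also added, each word gets two self-loops. This does not affect a max reduction. It also guarantees that every node receives at least one message, so no entry of `M` stays at `-inf`.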

In [None]:
%%writefile parsing.py
# build graph
import torch
import dgl

# data_helper
import os
import csv
import re

# model
import torch.nn.functional as F
import numpy as np
import word2vec
from torch_scatter import scatter_sum, scatter_max, scatter_mean  # scatter ops for graph computation
import math

# train
import tqdm
import sys, random
import argparse
from time import time


# ------------------------------------------------------------------------buildGraph.py------------------------------------------------------------------------
class GraphBuilder(object):
    def __init__(self, words, hiddenSizeNode):
        self.graph = dgl.DGLGraph()
        self.word2id = dict(zip(words, range(len(words))))
        self.graph.add_nodes(len(words))

        # Learnable hidden features for every node.
        self.graph.ndata['h'] = torch.nn.Parameter(
            torch.Tensor(len(words), hiddenSizeNode)
        )

        # All nodes are fully connected to each other.
        # Warning: self-loops are included.
        for i in range(len(words)):
            self.graph.add_edges(i, range(0, len(words)))

        # Learnable hidden parameter for edges: a single scalar weight per edge.
        self.graph.edata['h'] = torch.nn.Parameter(
            torch.Tensor(self.graph.number_of_edges(), 1)
        )


# ------------------------------------------------------------------------data_helper.py------------------------------------------------------------------------
class DataHelper(object):
    def __init__(self, dataset, mode='train', vocab=None):
        allowed_data = ['r8', '20ng', 'r52', 'mr', 'oh', 'agnews']
        tmp_data = dataset.split('/')
        dataset = tmp_data[len(tmp_data)-1]

        if dataset not in allowed_data:
            raise ValueError('currently allowed data: %s' % ','.join(allowed_data))
        else:
            self.dataset = dataset

        self.mode = mode

        self.base = os.path.join('/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/data', self.dataset)  # 'data/r8'

        self.current_set = os.path.join(self.base, '%s-%s-stemmed.txt' % (self.dataset, self.mode))

        with open(os.path.join(self.base, 'label.txt')) as f:
            labels = f.read()
        self.labels_str = labels.split('\n')

        content, label = self.get_content()  # read the content and labels from the dataset
        self.label = self.label_to_onehot(label)  # convert label strings to class indices
        
        if vocab is None:
            self.vocab = []
            try:
                self.get_vocab()
            except FileNotFoundError:
                self.build_vocab(content, min_count=30)
        else:
            self.vocab = vocab

        self.d = dict(zip(self.vocab, range(len(self.vocab))))  # word -> index
        # ----------------------------- self.content -----------------------------
        temp = []
        for doc in content:  # one sample (document)
            doc_tmp = []
            sentences = doc.split('[sep]')  # split the document into sentences at the '[sep]' marker
            for sen in sentences:  # one sentence
                sen = sen.strip()  # remove surrounding whitespace
                words = [self.word2id(x) for x in sen.split()]
                if len(words) > 0:  # skip empty sentences
                    doc_tmp.append(words)
            temp.append(doc_tmp)

        self.content = temp  # content is now 3-D: document -> sentence -> word ids

        #self.content = [list(map(lambda x: self.word2id(x), doc.split(' '))) for doc in content]  # flat word-id content (one list per document)


    def label_to_onehot(self, label_str):
        return [self.labels_str.index(l) for l in label_str]


    def get_content(self):
        with open(self.current_set) as f:  # open the dataset file
            lines = f.read().split('\n')
        content = [line.split('\t') for line in lines]  # each line: (label, text)

        # Drop malformed lines (e.g. a trailing empty line) that lack a (label, text) pair.
        cleaned = [pair for pair in content if len(pair) >= 2]

        label, content = zip(*cleaned)  # unpack into separate label and text tuples

        return content, label


    def word2id(self, word):
        try:
            result = self.d[word]
        except KeyError:
            result = self.d['UNK']

        return result


    def get_vocab(self):
        with open(os.path.join(self.base, 'vocab-30.txt')) as f:
            vocab = f.read()
            self.vocab = vocab.split('\n')
            #print(f'self.vocab: {self.vocab}')


    def build_vocab(self, content, min_count=2):
        # Count word frequencies, ignoring the '[sep]' sentence marker.
        freq = {}
        for c in content:
            for word in c.split():
                if word != '[sep]':
                    freq[word] = freq.get(word, 0) + 1

        # Keep words seen at least min_count times (in first-seen order); index 0 is 'UNK'.
        results = ['UNK'] + [word for word, n in freq.items() if n >= min_count]
        with open(os.path.join(self.base, 'vocab-30.txt'), 'w') as f:
            f.write('\n'.join(results))

        self.vocab = results


    def batch_iter(self, batch_size, num_epoch):
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # Round up so the final, partial batch is not dropped.
        num_per_epoch = math.ceil(len(self.content) / batch_size)
        for i in range(num_epoch):
            for batch_id in range(num_per_epoch):
                start = batch_id * batch_size
                end = min((batch_id + 1) * batch_size, len(self.content))

                content = self.content[start:end]
                label = self.label[start:end]

                y = torch.tensor(label).to(device)  # labels on the GPU (if available)

                yield content, y, i


# ------------------------------------------------------------------------model.py------------------------------------------------------------------------
class Model(torch.nn.Module):
    def __init__(self,
                 class_num,
                 hidden_size_node,
                 vocab,
                 n_gram,
                 drop_out,
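                 edges_num,  # number of edges
                 edges_matrix,  # edge connectivity: (word i, word j) -> edge id
                 max_length=50,  # max words per graph (350 was the document-level limit, 100 the sentence-level one)
                 trainable_edges=True,
                 pmi=None,  # edge weights (PMI)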
                 ):
        super(Model, self).__init__()

        self.vocab = vocab
        # print(len(vocab))
        # self.seq_edge_w = torch.nn.Embedding(edges_num, 1)  # model parameter: edge weights
        print(f'edges_num: {edges_num}')
        print(f'pmi.shape: {pmi.shape}')

        self.node_hidden = torch.nn.Embedding(len(vocab), hidden_size_node)  # model parameter: hidden node (word) embeddings

        #self.seq_edge_w = torch.nn.Embedding.from_pretrained(pmi, freeze=True)  # load pretrained edge weights
            
        self.edges_num = edges_num
        if trainable_edges:
            self.seq_edge_w = torch.nn.Embedding.from_pretrained(torch.ones(edges_num, 1), freeze=False)
        else:
            self.seq_edge_w = torch.nn.Embedding.from_pretrained(pmi, freeze=True)

        self.hidden_size_node = hidden_size_node

        self.node_hidden.weight.data.copy_(torch.tensor(self.load_word2vec('/content/drive/MyDrive/Colab_Notebooks/DATA/glove.6B/glove.6B.300d.w2vformat.txt')))
        self.node_hidden.weight.requires_grad = True

        self.len_vocab = len(vocab)

        self.ngram = n_gram

        self.d = dict(zip(self.vocab, range(len(self.vocab))))

        self.max_length = max_length

        self.eta = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)  # trainable gate eta, initialized to 0.2

        self.edges_matrix = edges_matrix

        self.dropout = torch.nn.Dropout(p=drop_out)

        self.activation = torch.nn.ReLU()

        self.Linear = torch.nn.Linear(hidden_size_node, class_num, bias=True)  # a single linear (MLP) layer

#----------------------------------------------------start-------------------------------------------------------------
        self.max_L = 30  # maximum number of sentence sub-graphs per sample
        self.query = torch.nn.Parameter(torch.rand(32, hidden_size_node))  # one attention query per batch slot: (batch_size, 300)

    def size_splits(self, tensor, split_sizes, dim=0):
        """Split `tensor` along `dim` into consecutive chunks of the given sizes.

        Adapted from https://github.com/pytorch/pytorch/issues/3223; `torch.split`
        accepts a list of chunk sizes directly.

        Arguments:
            tensor (Tensor): tensor to split.
            split_sizes (list(int)): sizes of the chunks.
            dim (int): dimension along which to split the tensor.
        """
        if tensor.size(dim) != sum(split_sizes):
            raise ValueError("Sum of split sizes does not match the tensor size along dim")

        return torch.split(tensor, split_sizes, dim=dim)
        
    
    def scaled_dot_product(self, q, k, v):  # follows Tutorial 6 by Lippe (source link is in the paper)
        d_k = q.size()[-1]
        attn_logits = torch.matmul(q, k.transpose(-2, -1))
        attn_logits = attn_logits / math.sqrt(d_k)
        attention = F.softmax(attn_logits, dim=-1)  # alpha_i = exp(S_i) / sum_j exp(S_j)
        values = torch.matmul(attention, v)

        return values

    def Sen_Attn(self, v_s, graph_id):
        # v_s: feature vectors of one document's sentence-level nodes
        # graph_id: position of the document in the batch; selects its query vector
        q = self.query[graph_id].unsqueeze(0)  # (1, hidden_size_node), a 2-D tensor
        k = torch.tanh(v_s)
        h = self.scaled_dot_product(q, k, v_s)

        return h  # attention-weighted sum of sentence vectors, sum_i alpha_i * v_i

#----------------------------------------------------end-------------------------------------------------------------


    def load_word2vec(self, word2vec_file):
        model = word2vec.load(word2vec_file)

        embedding_matrix = []

        for word in self.vocab:
            try:
                embedding_matrix.append(model[word])
            except KeyError:
                # Words missing from the pretrained vectors get a zero vector.
                embedding_matrix.append(np.zeros(len(model['the'])))

        embedding_matrix = np.array(embedding_matrix)

        return embedding_matrix


    def add_seq_edges(self, doc_ids: list, old_to_new: dict):
        edges = []
        old_edge_id = []
        for index, src_word_old in enumerate(doc_ids):
            src = old_to_new[src_word_old]
            for i in range(max(0, index - self.ngram), min(index + self.ngram + 1, len(doc_ids))):
                dst_word_old = doc_ids[i]
                dst = old_to_new[dst_word_old]

                # - first connect the words in the new sub-graph
                edges.append([src, dst])
                # - then look up the edge's id in the global (parent) edge table
                old_edge_id.append(self.edges_matrix[src_word_old, dst_word_old])

            # self-loop
            edges.append([src, src])  # edges are stored as a list of [src, dst] pairs
            old_edge_id.append(self.edges_matrix[src_word_old, src_word_old])

        return edges, old_edge_id


    def seq_to_graph(self, doc_ids: list) -> dgl.DGLGraph:
        if len(doc_ids) > self.max_length:
            doc_ids = doc_ids[:self.max_length]  # truncate the sentence to max_length words

        local_vocab = set(doc_ids)

        old_to_new = dict(zip(local_vocab, range(len(local_vocab))))

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        local_vocab = torch.tensor(list(local_vocab)).to(device)

        sub_graph = dgl.DGLGraph()  # empty sub-graph for this sentence
        sub_graph = sub_graph.to(device)  # move the graph from CPU to the target device

        sub_graph.add_nodes(len(local_vocab))  # one node per distinct word
        local_node_hidden = self.node_hidden(local_vocab)  # word ids -> word embeddings

        sub_graph.ndata['h'] = local_node_hidden  # node features

        seq_edges, seq_old_edges_id = self.add_seq_edges(doc_ids, old_to_new)
        old_edge_id = torch.LongTensor(seq_old_edges_id).to(device)

        srcs, dsts = zip(*seq_edges)  # '*' unpacks the list of [src, dst] pairs
        sub_graph.add_edges(srcs, dsts)  # add the edges to the graph
        try:
            seq_edges_w = self.seq_edge_w(old_edge_id)
        except RuntimeError:
            print(f'old_edge_id: {old_edge_id}')
            raise
        sub_graph.edata['w'] = seq_edges_w  # edge weights

        return sub_graph  # one graph per text, with node features and edge weights initialized

    
    def forward(self, doc_ids):  # doc_ids: one batch of samples

        # One graph per sentence; at most the first max_L (30) sentences of each document are used.
        graphs = [self.seq_to_graph(doc_ids[i][j]) for i in range(len(doc_ids)) for j in range(len(doc_ids[i])) if j < self.max_L]
        # Document index of each sentence graph, e.g. [0, 0, 0, 1, 1, 1, 1, ...]
        graphs_id = [i for i in range(len(doc_ids)) for j in range(len(doc_ids[i])) if j < self.max_L]

#--------------------------------------------------------------------Start-------------------------------------------------------------------
        ids = torch.arange(len(graphs_id))
        split_size = [graphs_id.count(i) for i in range(len(doc_ids))]  # number of sentence graphs per sample
        id_mapping = self.size_splits(tensor=ids, split_sizes=split_size, dim=0)  # per-sample ids, e.g. [[0,1,2], [3,4,5,6], ...]

        id_split = torch.tensor([[tmp[0], tmp[len(tmp)-1]] for tmp in id_mapping])  # first/last id per sample, e.g. [[0,2], [3,6], ..., [266,269]]
        doc_id = torch.arange(len(doc_ids))  # (0, 1, 2, ..., 31)

#--------------------------------------------------------------------End-------------------------------------------------------------------

        batch_graph = dgl.batch(graphs)
        batch_graph.update_all(
            # message: source-node feature 'h' multiplied by edge weight 'w' (u_mul_e)
            message_func=dgl.function.u_mul_e('h', 'w', 'weighted_message'),
            # reduce: element-wise max over incoming messages, giving M_n
            reduce_func=dgl.function.max('weighted_message', 'M'),
        )

        # Node update: r_n' = (1 - eta) * M_n + eta * r_n
        eta = self.eta
        batch_graph.ndata['h'] = batch_graph.ndata['M']*(1-eta) + batch_graph.ndata['h']*eta
        out1 = dgl.sum_nodes(batch_graph, feat='h')  # one vector per sentence graph, e.g. (269, 300)
#--------------------------------------------------------------------start-------------------------------------------------------------------
        id_all = zip(id_split, doc_id)
        # Sentence-level attention within each document; id_split holds [first, last] graph indices, e.g. [[0,4], [5,11], ..., [266,269]]
        out2 = [self.Sen_Attn(out1[id_spl[0]: id_spl[1]+1], d_id) for id_spl, d_id in id_all]
        out2 = torch.cat(out2, dim=0)  # (batch_size, 300), e.g. (32, 300)

#----------------------------------------------------------------end----------------------------------------------------------------------
        
        drop1 = self.dropout(out2)
        act1 = self.activation(drop1)

        l = self.Linear(act1)

        return l



# ------------------------------------------------------------------------pmi.py------------------------------------------------------------------------
def cal_PMI(dataset: str, window_size=20):  # pointwise mutual information
    helper = DataHelper(dataset=dataset, mode="train")
    content, _ = helper.get_content()  # returns (content, label)
    pair_count_matrix = np.zeros((len(helper.vocab), len(helper.vocab)), dtype=int)  # #W(i, j): co-occurrence counts
    word_count = np.zeros(len(helper.vocab), dtype=int)  # #W(i) in the formula
    print(f'len of word_count: {len(helper.vocab)}')

    for sentence in content:  # each entry is one document's text ('[sep]' is not in the vocabulary, so it is skipped)
        sentence = sentence.split()  # get the words in a sentence
        for i, word in enumerate(sentence):
            try:
                word_count[helper.d[word]] += 1
            except KeyError:
                continue
            start_index = max(0, i - window_size)
            end_index = min(len(sentence), i + window_size)
            for j in range(start_index, end_index):
                if i == j:
                    continue
                else:
                    target_word = sentence[j]
                    try:
                        pair_count_matrix[helper.d[word], helper.d[target_word]] += 1  # p(i,j)
                    except KeyError:
                        continue
        
    total_count = np.sum(word_count)
    word_count = word_count / total_count  # p(i)
    pair_count_matrix = pair_count_matrix / total_count  # p(i, j)

    # PMI(i, j) = log(p(i, j) / (p(i) * p(j))). Zero counts give log(0) = -inf or 0/0 = NaN;
    # silence the resulting numpy warnings and clean up the values below.
    with np.errstate(divide='ignore', invalid='ignore'):
        pmi_matrix = np.log(pair_count_matrix / np.outer(word_count, word_count))

    pmi_matrix = np.nan_to_num(pmi_matrix)  # replace NaN with zero and infinity with large finite number

    pmi_matrix = np.maximum(pmi_matrix, 0.0)  # keep positive PMI, clip negative values to 0

    # Edge id 0 (weight 0.0) is reserved for word pairs without positive PMI;
    # positive pairs get ids 1, 2, ... in row-major order.
    rows, cols = np.nonzero(pmi_matrix)
    edges_mappings = np.zeros((len(helper.vocab), len(helper.vocab)), dtype=int)
    edges_mappings[rows, cols] = np.arange(1, len(rows) + 1)
    count = len(rows) + 1  # number of edges, including the reserved id 0

    edges_weights = np.concatenate([[0.0], pmi_matrix[rows, cols]]).reshape(-1, 1)
    edges_weights = torch.Tensor(edges_weights)

    return edges_weights, edges_mappings, count


# ------------------------------------------------------------------------train.py------------------------------------------------------------------------
NUM_ITER_EVAL = 100
EARLY_STOP_EPOCH = 10


def dev(model, dataset):
    data_helper = DataHelper(dataset, mode='dev')

    total_pred = 0
    correct = 0
    model.eval()
    with torch.no_grad():  # no gradients are needed for evaluation
        for content, label, _ in data_helper.batch_iter(batch_size=32, num_epoch=1):
            logits = model(content)
            pred = torch.argmax(logits, dim=1)

            correct += torch.sum(pred == label)
            total_pred += len(content)

    return correct.float() / float(total_pred)


def test(model_name, dataset):
    model = torch.load(os.path.join('/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/model', model_name + '.pkl'))

    data_helper = DataHelper(dataset, mode='test')

    total_pred = 0
    correct = 0
    model.eval()
    with torch.no_grad():  # no gradients are needed for evaluation
        for content, label, _ in data_helper.batch_iter(batch_size=32, num_epoch=1):
            logits = model(content)
            pred = torch.argmax(logits, dim=1)

            correct += torch.sum(pred == label)
            total_pred += len(content)

    return correct.float() / float(total_pred)


def train(ngram, model_name, drop_out, dataset, edges=True, total_epoch=1):
    print('load data helper.')
    data_helper = DataHelper(dataset, mode='train')

    tmp_data = dataset.split('/')
    dataset_name = tmp_data[len(tmp_data)-1]
    if os.path.exists(os.path.join('/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/model', model_name+'.pkl')) and model_name != 'temp_model':
        print('load model from file.')
        model = torch.load(os.path.join('/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/model', model_name+'.pkl'))
    else:
        print('new model.')
        if model_name == 'temp_model':
            model_name = f'temp_model_{dataset_name}'
        # edges_num, edges_matrix = edges_mapping(len(data_helper.vocab), data_helper.content, ngram)
        edges_weights, edges_mappings, count = cal_PMI(dataset=dataset)
        
        model = Model(class_num=len(data_helper.labels_str), hidden_size_node=300,
                      vocab=data_helper.vocab, n_gram=ngram, drop_out=drop_out, edges_matrix=edges_mappings, edges_num=count,
                      trainable_edges=edges, pmi=edges_weights)

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        model.to(device)
    loss_func = torch.nn.CrossEntropyLoss()
    optim = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)  # args: module-level argparse result (--lr)

    iter = 0

    best_acc = 0.0
    last_best_epoch = 0
    # start_time = time()
    total_loss = 0.0
    total_correct = 0
    total = 0
    for content, label, epoch in data_helper.batch_iter(batch_size=32, num_epoch=total_epoch):
        improved = ''
        model.train()

        logits = model(content)  # passed to forward() as doc_ids
        loss = loss_func(logits, label)

        pred = torch.argmax(logits, dim=1)

        correct = torch.sum(pred == label)

        total_correct += correct
        total += len(label)

        total_loss += loss.item()

        optim.zero_grad()
        loss.backward()
        optim.step()

        iter += 1
        if iter % NUM_ITER_EVAL == 0:

            val_acc = dev(model, dataset=dataset)
            if val_acc > best_acc:
                best_acc = val_acc
                last_best_epoch = epoch
                improved = '*'

                torch.save(model, f'/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/model/{model_name}.pkl')

            if epoch - last_best_epoch >= EARLY_STOP_EPOCH:
                print('Early stopping...')
                return model_name
            print(f'Epoch:{epoch}, iter:{iter}, train loss:{total_loss/ NUM_ITER_EVAL :.4f}, train acc: {float(total_correct) / float(total) :.4f}, val acc: {val_acc:.4f},{improved}')

            total_loss = 0.0
            total_correct = 0
            total = 0

    return model_name


parser = argparse.ArgumentParser()
parser.add_argument('--ngram', required=False, type=int, default=4, help='ngram number')
parser.add_argument('--name', required=False, type=str, default='temp_model', help='project name')
parser.add_argument('--dropout', required=False, type=float, default=0.5, help='dropout rate')
parser.add_argument('--dataset', required=True, type=str, help='dataset')
parser.add_argument('--edges', required=False, type=int, default=1, help='trainable edges')
parser.add_argument('--rand', required=False, type=int, default=42, help='rand_seed')
parser.add_argument('--epoch', required=False, type=int, default=1, help='training epoch')
parser.add_argument('--lr', default=1e-3, type=float, required=False, help='Initial learning rate')

args = parser.parse_args()

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
if torch.cuda.is_available():
  print(f'device: {device}')
  print(f'name: {torch.cuda.get_device_name(0)}')
  print(f'memory: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB')
print('*'*50)

print('ngram: %d' % args.ngram)
print('project_name: %s' % args.name)
print('dataset: %s' % args.dataset)
print('trainable_edges: %s' % args.edges)
print('*'*50)

SEED = args.rand
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

if args.edges == 1:
    edges = True
    print('trainable edges is True')
else:
    edges = False

model = train(args.ngram, args.name, args.dropout, dataset=args.dataset, edges=edges, total_epoch=args.epoch)
print(f'test acc: {test(model, args.dataset).cpu().numpy():.4f}')

Writing parsing.py
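The edge weights in the script come from `cal_PMI`. It computes positive pointwise mutual information, $\mathrm{PMI}(i,j)=\log\frac{p(i,j)}{p(i)\,p(j)}$, with both probabilities normalised by the total word count. Pairs without positive PMI are then mapped to edge id 0 (weight 0.0).

Below is a minimal pure-Python sketch of the same computation on a toy corpus. It rests on these assumptions:

- No vocabulary filtering is applied; `cal_PMI` skips out-of-vocabulary words.
- Edge ids are ordered by sorted word pair rather than by vocabulary index.
- `positive_pmi` and `edge_table` are hypothetical names.

It reproduces the notebook's window bounds `range(i - window_size, i + window_size)`. These reach `window_size` words to the left but only `window_size - 1` words to the right.

```python
import math
from collections import Counter


def positive_pmi(docs, window_size):
    """Positive PMI over a sliding window, normalised by the total word count as in cal_PMI."""
    word_count = Counter()
    pair_count = Counter()
    for doc in docs:
        words = doc.split()
        for i, w in enumerate(words):
            word_count[w] += 1
            # Same bounds as cal_PMI: the right end of range() is exclusive.
            for j in range(max(0, i - window_size), min(len(words), i + window_size)):
                if j != i:
                    pair_count[(w, words[j])] += 1

    total = sum(word_count.values())
    ppmi = {}
    for (a, b), n in pair_count.items():
        pmi = math.log((n / total) / ((word_count[a] / total) * (word_count[b] / total)))
        if pmi > 0:  # zero or negative PMI -> no edge
            ppmi[(a, b)] = pmi
    return ppmi


def edge_table(ppmi):
    """Assign edge ids 1, 2, ...; id 0 (weight 0.0) is reserved for pairs without an edge."""
    weights = [0.0]
    ids = {}
    for pair, value in sorted(ppmi.items()):
        ids[pair] = len(weights)
        weights.append(value)
    return ids, weights


ppmi = positive_pmi(["a b", "a c", "a a"], window_size=2)
ids, weights = edge_table(ppmi)
print(ppmi)
print(ids, weights)
```

In this toy corpus, `('a', 'a')` co-occurs less often than chance predicts (PMI = log 0.75 < 0), so it gets no edge. By contrast, `('a', 'b')` has PMI = log 1.5 and receives edge id 1.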


# run

In [None]:
!python parsing.py --dataset='/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/data/r52' --name='temp_model' --edges=1 --ngram=3 --dropout=0.5 --epoch=100

DGL backend not selected or invalid.  Assuming PyTorch for now.
Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)
**************************************************
ngram: 3
project_name: temp_model
dataset: /content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/data/r52
trainable_edges: 1
**************************************************
trainable edges is True
load data helper.
new model.
len of word_count: 1552
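  pair_count_matrix[i, j] / (word_count[i] * word_count[j])  # after normalization the values are tiny and round to 0, causing division by zero and NaN values
  pair_count_matrix[i, j] / (word_count[i] * word_count[j])  # after normalization the values are tiny and round to 0, causing division by zero and NaN values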
edges_num: 1025075
pmi.shape: torch.Size([1025075, 1])
Epoch:0, iter:100, train loss:1.8600, train acc: 0.6141, val acc: 0.6391,*
Epoch:1, iter:200, train loss:1.0356, train acc: 0.7816, val acc: 0.8078,*
Epoch:1, iter:300, train loss:0.6861, t

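# result

In the entries below:

- `edge` is the `--edges` flag: 1 means trainable edge weights and 0 means fixed PMI weights.
- `ngram` is the window size and `dp` is the dropout rate.
- `words` is the vocabulary size obtained with the given `min_count`.
- `acc` is the test accuracy, followed by the hardware used and the run time.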

oh (min_count=5, 8140 words):

1. edge=0, ngram=1, dp=0.5: acc 0.6778, CPU, 3 h 48 min

2. edge=0, ngram=4, dp=0.5: acc 0.662, GPU, 34 min

3. edge=1, ngram=1, dp=0.5: acc

20ng (min_count=15, 12095 words; min_count=5 gives 25782 words, min_count=10 gives 16016 words):

1. edge=0, ngram=2, dp=0.5: acc 0.8234, CPU, 6 h

2. edge=0, ngram=3, dp=0.5: acc

mr (min_count=5, 2934 words):

1. edge=0, ngram=2, dp=0.5: acc 0.7033, CPU, 34 min

2. edge=0, ngram=1, dp=0.5: acc 0.6948, CPU, 14 min

3. edge=0, ngram=2, dp=0.7: acc 0.7097, CPU, 14 min

4. edge=1, ngram=2, dp=0.7: acc 0.7365, CPU, 18 min

5. edge=1, ngram=3, dp=0.7: acc 0.7342, CPU, 18 min

r8 (min_count=5, 4380 words):

1. edge=1, ngram=2, dp=0.7: acc 0.9733, CPU, 17 min

2. edge=1, ngram=3, dp=0.7: acc 0.9747, CPU, 15 min

3. edge=1, ngram=3, dp=0.5: acc 0.9756, CPU, 18 min

4. edge=1, ngram=4, dp=0.5: acc 0.9756, CPU, 19 min

5. edge=1, ngram=3, dp=0.4: acc 0.9743, CPU, 18 min

r52 (min_count=5, 5019 words; min_count=30 gives 1552 words):

1. edge=1, ngram=2, dp=0.7: acc 0.9344, CPU, 33 min

2. edge=1, ngram=3, dp=0.7: acc 0.9355, CPU, 27 min

3. edge=1, ngram=3, dp=0.5: acc 0.9320, CPU, 20 min (min_count=30, 1552 words)

agnews (min_count=30, 9690 words; min_count=15 gives 14404 words, min_count=20 gives 12223 words):

1. edge=1, ngram=2, dp=0.7: acc