<a href="https://colab.research.google.com/github/hshuai97/Colab20210803/blob/main/HieGNN(v5_2).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Use GAT+GAT for text classification.

References:

1.   [Lianzhe Huang source code](https://github.com/mojave-pku/TextLevelGCN)

2. [Lippe tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial7/GNN_overview.html)

3. [GAT: labml.ai Annotated Paper Implementations](https://nn.labml.ai/)

4. [GAT inplementation in Deep Graph library](https://github.com/dmlc/dgl/blob/master/examples/pytorch/gat/gat.py):

Equation 1:

\begin{equation}
z_i^{(l)}=W^{(l)}h_i^{(l)}
\end{equation}

Equation 2:
\begin{equation}
e_{ij}^{(l)}=LeakyReLU(\vec{a}^{(l)^T} (z_i^{(l)} || z_j^{(l)}))
\end{equation}

Equation 3:
\begin{equation}
\alpha_{ij}^{(l)} = \frac{exp(e_{ij}^{(l)})}{\sum_{k \in \mathcal{N(i)}}exp(e_{ik}^{(l)})}
\end{equation}

Equation 4:
\begin{equation}
h_i^{(l+1)} = \sigma{(\sum_{j \in \mathcal{N(i)}} \alpha_{ij}^{(l)} z_j^{(l)})}
\end{equation}


# Install labraries

In [None]:
# Install deep graph labrary: https://www.dgl.ai/pages/start.html
import torch
try:
  import dgl
except ModuleNotFoundError:
  # Installing dgl package with specific CUDA
  CUDA = 'cu' + torch.version.cuda.replace('.','')
  !pip install dgl-{CUDA} -f https://data.dgl.ai/wheels/repo.html

# Install word2vec
try:
  import word2vec  # type: ignore
except ModuleNotFoundError:
  !pip install word2vec # type: ignore

# Install torch_scatter
try:
  import torch_scatter
except ModuleNotFoundError:
  TORCH = torch.__version__.split('+')[0]
  CUDA = 'cu' + torch.version.cuda.replace('.','')
  !pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html

# Install LabML: https://nn.labml.ai/
try:
  from labml_helpers.module import Module
except ModuleNotFoundError:
  !pip install labml-helpers

# Parsing

In [6]:
%%writefile parsing.py
# build graph
import torch
import dgl

# dada_helper
import os
import re

# model
import torch.nn.functional as F
import numpy as np
import word2vec
import math
from torch_scatter import scatter_add, scatter_max, scatter_mean
import multiprocessing

# train
import tqdm
import sys, random
import argparse
from time import time

# GCN in dgl
from dgl.nn import GATConv

# GAT In labmlai
from labml_helpers.module import Module  # Different with torch.nn.Module.


class DataHelper(object):  # Preprocess dataset.
    def __init__(self, dataset, mode='train', vocab=None):
        allowed_data = ['r8', '20ng', 'r52', 'mr', 'oh','agnews']  # six datasets
        tmp_data = dataset.split('/')
        dataset = tmp_data[len(tmp_data)-1]

        if dataset not in allowed_data:
            raise ValueError('currently allowed data: %s' % ','.join(allowed_data))
        else:
            self.dataset = dataset

        if self.dataset =='20ng':  # The 20ng  dataset is large.
          self.min_count = 25
        elif self.dataset == 'agnews':  # The agnews  dataset is large.
          self.min_count = 35
        elif dataset == 'mr':
          self.min_count =2  # The 'mr' dataset is small.
        else:
          self.min_count = 5

        self.mode = mode
        self.base = os.path.join('/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/data', self.dataset)  # 'data/r8'
        self.current_set = os.path.join(self.base, '%s-%s-stemmed.txt' % (self.dataset, self.mode))

        with open(os.path.join(self.base, 'label.txt')) as f:
            labels = f.read()
        self.labels_str = labels.split('\n')

        content, label = self.get_content()  # Get content and label from dataset.
        self.label = self.label_to_onehot(label)  # Convert to 1-hot encode.
        
        if vocab is None:
            self.vocab = []
            try:
                self.get_vocab()
            except FileNotFoundError:
                self.build_vocab(content, self.min_count)
        else:
            self.vocab = vocab

        self.d = dict(zip(self.vocab, range(len(self.vocab))))

        temp = []  # For word-level and sentence-level split.
        for doc in content:  # Take one sample.
            doc_tmp = []
            sentence = doc.split('[sep]')  # Split document according to '[sep]' symbol.
            for sen in sentence:  # Take sentence from document
                sen = sen.strip()  # Remove space.
                if len(sen)>0:  # Remove blank exception.
                    words = list(map(lambda x: self.word2id(x), sen.split()))
                    doc_tmp.append(words)
            temp.append(doc_tmp)
        
        self.content = temp  # Word-level content split, 3-d.

        self.content_doc = [list(map(lambda x: self.word2id(x), doc.split())) for doc in content]  # Doc-level content split, 2-d.

    def label_to_onehot(self, label_str):
        return [self.labels_str.index(l) for l in label_str]

    def get_content(self):
        with open(self.current_set) as f:  # open dataset
            all = f.read()
            all = all.split('\n')
            content = [line.split('\t') for line in all]  # 2 elements:(label, str_text)
        if self.dataset in ['r8', '20ng', 'r52', 'mr', 'oh','agnews']:
            cleaned = []
            for i, pair in enumerate(content):
                if len(pair) < 2 or len(pair[1]) <5:  # # remove the sample that lack of 'str_text' or 'label'; or remove short text
                    # print(i, pair)
                    pass
                else:
                    cleaned.append(pair)
        else:
            cleaned = content

        label, content = zip(*cleaned)  # '*' means unpack a list

        return content, label

    def word2id(self, word):
        try:
            result = self.d[word]
        except KeyError:
            result = self.d['UNK']

        return result

    def get_vocab(self):
        with open(os.path.join(self.base, 'vocab-'+str(self.min_count)+'.txt')) as f:  # For example, vocab-5.txt
            vocab = f.read()
            self.vocab = vocab.split('\n')
            #print(f'self.vocab: {self.vocab}')

    def build_vocab(self, content, min_count=10):
        vocab = []

        for c in content:
            words = c.split()
            for word in words:
                if word not in vocab and word !='[sep]' and word!='':
                    vocab.append(word)

        freq = dict(zip(vocab, [0 for i in range(len(vocab))]))

        for c in content:
            words = c.split()
            for word in words:
              if  word !='[sep]' and word!='':
                freq[word] += 1

        results = []
        for word in freq.keys():
            if freq[word] < min_count:
                continue
            else:
                results.append(word)

        results.insert(0, 'UNK')
        with open(os.path.join(self.base, 'vocab-'+str(self.min_count)+'.txt'), 'w') as f:
            f.write('\n'.join(results))

        self.vocab = results

    def batch_iter(self, batch_size, num_epoch):
        for i in range(num_epoch):
            num_per_epoch = int(len(self.content) / batch_size)  # Split content by batch size.
            for batch_id in range(num_per_epoch):
                start = batch_id * batch_size
                end = min((batch_id + 1) * batch_size, len(self.content))

                content = self.content[start:end]  # Word-level graph content, 3-d.
                label = self.label[start:end]

                content_doc = self.content_doc[start:end]  # Document-level graph content, 2-d.

                device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
                
                y = torch.tensor(label).to(device)

                yield content, content_doc, y, i   # Return content, label and epoch.

class GraphAttentionLayer(Module): # GAT module
    def __init__(self, in_features, out_features, n_heads,  is_concat=True, dropout=0.6, leaky_relu_negative_slope=0.2):
      super(SentenceLevel, self).__init__()
      
      self.is_concat = is_concat
      self.n_heads = n_heads

      if is_concat:  # Concatenation
        assert out_features %  n_heads == 0
        self.n_hidden = out_features // n_heads
      else:  # Average
        self.n_hidden = out_features
      
      self.linear = torch.nn.Linear(in_features, self.n_hidden * n_heads, bias=False)  # Transform the node embeddings before self-attentin
      self.attn = torch.nn.Linear(self.n_hidden * 2, 1, bias=False)  # Linear layer to compute attention score e_ij
      self.activation = torch.nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)   # Activation for e_ij
      self.softmax = torch.nn.Softmax(dim=1)  # Softmax to compute attention alpha_ij
      self.dropout = torch.nn.Dropout(dropout)
    
    def forward(self, h, adj_mat):
      '''
        h: input node embeddings of shape [n_nodes, in_features]
        adj_mat: the adjacency matrix of shape [n_nodes, n_nodes, n_heads=1]. Different heads using same adjacency matrix.
      '''
      n_nodes = h.shape[0]
      g = self.linear(h).view(n_nodes, self.n_heads, self.n_hidden)  # For each head: g_i^k = W^k h_i

      g_repeat = g.repeat(n_nodes, 1, 1)
      g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)
      g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1)
      g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)
      e = self.activation(self.attn(g_concat))
      e = e.squeeze(-1)

      assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
      assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
      assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads

      e = e.masked_fill(adj_mat == 0, float('-inf'))
      a = self.softmax(e)
      a = self.dropout(a)
      attn_res = torch.einsum('ijh,jhf->ihf', a, g)

      if self.is_concat:
        return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
      else:
        return attn_res.mean(dim=1)

class WordLevel(torch.nn.Module): # Word level graph building. Using message passing mechanism.
    def __init__(self, hidden_size_node, vocab, n_gram, edges_matrix, edges_num, max_length=50):
        super(WordLevel, self).__init__()

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        self.vocab = vocab  # Prepare for other methods in the same class.

        self.node_hidden = torch.nn.Embedding(len(vocab), hidden_size_node)  # Word-level node hidden (embedding).
        self.node_hidden.weight.data.copy_(torch.tensor(self.load_word2vec('/content/drive/MyDrive/Colab_Notebooks/DATA/glove.6B/glove.6B.300d.w2vformat.txt')))
        self.node_hidden.weight.requires_grad = True

        self.node_eta = torch.nn.Embedding.from_pretrained(torch.rand(len(vocab), 1), freeze=False)  # Word-level node save itself feature according to eta percent.

        self.edges_matrix = edges_matrix  # All global word-level edges are stored in the matrix.
        self.edges_num = edges_num  # The number of word-level edges.
        self.node_edge_w = torch.nn.Embedding.from_pretrained(torch.ones(edges_num, 1), freeze=False)  # Setting trainable parameters for word-level edges.

        self.len_vocab = len(vocab)
        self.d = dict(zip(vocab, range(len(vocab))))  # Set new index for vocab.

        self.ngram = n_gram
        self.max_length = max_length  # Maximum number of word-level nodes.

    def load_word2vec(self, word2vec_file):  # Loading embeddings representation from w2v.
      model = word2vec.load(word2vec_file)

      embedding_matrix = []

      for word in self.vocab:
          try:
              embedding_matrix.append(model[word])
          except KeyError:
              # print(f'Line 269. The word not in vocab:{word}')
              #unk_word=np.zeros(300)
              #print(f"line 271, model['the].shape: {len(model['the'])}")  # [,300]
              embedding_matrix.append(np.zeros(len(model['the'])))  # other unknow words use zeros vector

      embedding_matrix = np.array(embedding_matrix)

      return embedding_matrix

    def word_level_graph(self, word_ids):  # Build word level graph.
        if len(word_ids) > self.max_length:
            word_ids = word_ids[ : self.max_length]
        
        local_vocab = set(word_ids)
        old_to_new = dict(zip(local_vocab, range(len(local_vocab))))

        local_vocab = torch.tensor(list(local_vocab)).to(self.device)  # Graph on device, so need local_vocab on same device.

        graph = dgl.DGLGraph()
        graph = graph.to(self.device)

        graph.add_nodes(len(local_vocab)) # Add nodes for graph.
        graph.ndata['h'] = self.node_hidden(local_vocab)  # Add node features for graph.
        
        eta = torch.sigmoid(self.node_eta(local_vocab))  # Limit to [0-1].
        graph.ndata['eta'] = eta  # Add node eta for graph

        edges, edges_id = self.word_level_edges(word_ids, old_to_new)
        edges_id = torch.LongTensor(edges_id).to(self.device)  # The edges_id is a 1-d list.
        srcs, dsts = zip(*edges)  # get source and destination nodes id (local id).
        
        graph.add_edges(srcs, dsts)  # Add edges for graph.
        graph.edata['w'] = self.node_edge_w(edges_id) # Add weight for edges.

        return graph
    
    def word_level_edges(self, word_ids: list, old_to_new:dict):  # Add word-level edges for word-level graph.
        edges = []
        old_edge_id = []
        for index, src_word_old in enumerate(word_ids):
            src = old_to_new[src_word_old]
            for i in range(max(0, index - self.ngram), min(index + self.ngram + 1, len(word_ids))):
                dst_word_old = word_ids[i]
                dst = old_to_new[dst_word_old]

                # - first connect the new sub_graph
                edges.append([src, dst])
                # - then get the hidden from parent_graph
                old_edge_id.append(self.edges_matrix[src_word_old, dst_word_old])

            # self circle
            #edges.append([src, src])  # All edges.
            #old_edge_id.append(self.edges_matrix[src_word_old, src_word_old])

        return edges, old_edge_id
    
    def forward(self, word_id):  # Update word-level graph

      graphs = [self.word_level_graph(word_id[i][j]) for i in range(len(word_id)) for j in range(len(word_id[i])) if j<30]  # Word-level graphs.
      graphs_id = [i for i in range(len(word_id)) for j in range(len(word_id[i])) if j<30]  # Word-level graph id indicates it belongs to which sample.

      batch_graph = dgl.batch(graphs)  # batching update
      batch_graph.update_all(
          message_func = dgl.function.src_mul_edge('h', 'w', 'weighted_message'),
          reduce_func = dgl.function.max('weighted_message','M')
      )
      batch_graph.ndata['h'] = batch_graph.ndata['eta']*batch_graph.ndata['h'] + batch_graph.ndata['M'] * (1 - batch_graph.ndata['eta'])  # Update word-level node feature.
      out_w = dgl.sum_nodes(batch_graph, feat='h')  # Convert word-level graph to one vector (num_graph, 300)

      return out_w, graphs_id  # Return word-level graph representation and its id (indicate the vector belongs to which sample).

'''
class SentenceLevel(torch.nn.Module):
  def __init__(self, hidden_size_node, vocab, n_gram, edges_matrix, edges_num, max_length=300):
    super()
'''

class GATLayer(torch.nn.Module):
  def __init__(self, hidden_size_node, vocab, n_gram, edges_matrix, edges_num, max_length=300):
    super(GATLayer, self).__init__()

    self.fc = torch.nn.Linear(hidden_size_node, hidden_size_node, bias=False)  # equation (1)
    self.attn_fc = torch.nn.Linear(2 * hidden_size_node, 1, bias=False)  # equation (2)
    self.reset_parameters()

    self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    self.vocab = vocab  # Prepare for other methods in this class.
    self.node_hidden = torch.nn.Embedding(len(vocab), hidden_size_node)  # Word-level node hidden (embedding).
    self.node_hidden.weight.data.copy_(torch.tensor(self.load_word2vec('/content/drive/MyDrive/Colab_Notebooks/DATA/glove.6B/glove.6B.300d.w2vformat.txt')))
    self.node_hidden.weight.requires_grad = True

    self.ngram = n_gram
    self.max_length = max_length  # Maximum number of doc-level nodes.
    self.edges_matrix = edges_matrix  # All global doc-level edges are stored in the matrix.

  def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = torch.nn.init.calculate_gain('relu')
        torch.nn.init.xavier_normal_(self.fc.weight, gain=gain)
        torch.nn.init.xavier_normal_(self.attn_fc.weight, gain=gain)
  
  def load_word2vec(self, word2vec_file):  # Loading embeddings representation from w2v.
      model = word2vec.load(word2vec_file)

      embedding_matrix = []

      for word in self.vocab:
          try:
              embedding_matrix.append(model[word])
          except KeyError:
              embedding_matrix.append(np.zeros(len(model['the'])))  # other unknow words use zeros vector

      embedding_matrix = np.array(embedding_matrix)

      return embedding_matrix
  
  def edge_attention(self, edges):  # func = lambda edges : F.leaky_relu(self.attn_fc(z2))
        # edge UDF for equation (2)
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}

  def message_func(self, edges):  # Passing two tensors: the transformed z embedding of source node and un-normalized attention score e
        # message UDF for equation (3) & (4)
        return {'z': edges.src['z'], 'e': edges.data['e']}

  def reduce_func(self, nodes):  # Performing two tasks: 1 normalize the attention score e, 2 Aggregate neighbor embeddings weighted by the attention scores
        # reduce UDF for equation (3) & (4)
        # equation (3)
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # equation (4)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}  # New embeddings of nodes h'
  
  def gat_edges(self, word_ids: list, old_to_new:dict):  # Add edges for graph.
        edges = []
        old_edge_id = []
        for index, src_word_old in enumerate(word_ids):
            src = old_to_new[src_word_old]
            for i in range(max(0, index - self.ngram), min(index + self.ngram + 1, len(word_ids))):
                dst_word_old = word_ids[i]
                dst = old_to_new[dst_word_old]

                # - first connect the new sub_graph
                edges.append([src, dst])
                # - then get the hidden from parent_graph
                old_edge_id.append(self.edges_matrix[src_word_old, dst_word_old])

            # self circle
            #edges.append([src, src])  # All edges.
            #old_edge_id.append(self.edges_matrix[src_word_old, src_word_old])

        return edges, old_edge_id
  
  def gat_graph(self, word_ids):  # Build GAT graph
    if len(word_ids) > self.max_length:
        word_ids = word_ids[ : self.max_length]
    
    local_vocab = set(word_ids)
    old_to_new = dict(zip(local_vocab, range(len(local_vocab))))

    local_vocab = torch.tensor(list(local_vocab)).to(self.device)  #Graph on device, so need local_vocab on same device.

    graph = dgl.DGLGraph()
    graph = graph.to(self.device)

    graph.add_nodes(len(local_vocab)) # Add nodes for graph.
    z = self.fc(self.node_hidden(local_vocab))  # equation (1)
    graph.ndata['z'] = z  # Add node embeddings for graph
    

    edges, edges_id = self.gat_edges(word_ids, old_to_new)
    srcs, dsts = zip(*edges)  # get source and destination nodes id (local id).
    
    graph.add_edges(srcs, dsts)  # Add edges for graph.
    graph.apply_edges(self.edge_attention)  # Update the features of the specified edges, graph.edata['e] = F.leaky_relu(a)

    return graph

  def forward(self, doc_id):
    graphs = [self.gat_graph(doc) for doc in doc_id ]  # Doc-level graphs.

    batch_graph = dgl.batch(graphs)  # batching update

    batch_graph.update_all(self.message_func, self.reduce_func)

    out_doc = dgl.sum_nodes(batch_graph, feat='h')

    return out_doc


class DocLevel(torch.nn.Module): # The same as Huang lianzhe 'Model' parts of paper code, see References.
    def __init__(self, hidden_size_node, vocab, n_gram, edges_matrix, edges_num, max_length=300):
      super(DocLevel, self).__init__()

      self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

      self.vocab = vocab  # Prepare for other methods in this class.

      self.node_hidden = torch.nn.Embedding(len(vocab), hidden_size_node)  # Word-level node hidden (embedding).
      self.node_hidden.weight.data.copy_(torch.tensor(self.load_word2vec('/content/drive/MyDrive/Colab_Notebooks/DATA/glove.6B/glove.6B.300d.w2vformat.txt')))
      self.node_hidden.weight.requires_grad = True

      self.node_eta = torch.nn.Embedding.from_pretrained(torch.rand(len(vocab), 1), freeze=False)  # Word-level node save itself feature according to eta percent.

      self.edges_matrix = edges_matrix  # All global doc-level edges are stored in the matrix.
      self.edges_num = edges_num  # The number of doc-level edges.
      self.node_edge_w = torch.nn.Embedding.from_pretrained(torch.ones(edges_num, 1), freeze=False)  # Setting trainable parameters for word-level edges.

      self.len_vocab = len(vocab)
      self.d = dict(zip(vocab, range(len(vocab))))  # Set new index for vocab.

      self.ngram = n_gram
      self.max_length = max_length  # Maximum number of doc-level nodes.

    def doc_level_edges(self, word_ids: list, old_to_new:dict):  # Add doc-level edges for word-level graph.
        edges = []
        old_edge_id = []
        for index, src_word_old in enumerate(word_ids):
            src = old_to_new[src_word_old]
            for i in range(max(0, index - self.ngram), min(index + self.ngram + 1, len(word_ids))):
                dst_word_old = word_ids[i]
                dst = old_to_new[dst_word_old]

                # - first connect the new sub_graph
                edges.append([src, dst])
                # - then get the hidden from parent_graph
                old_edge_id.append(self.edges_matrix[src_word_old, dst_word_old])

            # self circle
            edges.append([src, src])  # All edges.
            old_edge_id.append(self.edges_matrix[src_word_old, src_word_old])

        return edges, old_edge_id
    
    def load_word2vec(self, word2vec_file):  # Loading embeddings representation from w2v.
      model = word2vec.load(word2vec_file)

      embedding_matrix = []

      for word in self.vocab:
          try:
              embedding_matrix.append(model[word])
          except KeyError:
              # print(f'Line 269. The word not in vocab:{word}')
              #unk_word=np.zeros(300)
              #print(f"line 271, model['the].shape: {len(model['the'])}")  # [,300]
              embedding_matrix.append(np.zeros(len(model['the'])))  # other unknow words use zeros vector

      embedding_matrix = np.array(embedding_matrix)

      return embedding_matrix

    def doc_level_graph(self, word_ids):  # Build doc level graph.
        if len(word_ids) > self.max_length:
            word_ids = word_ids[ : self.max_length]
        
        local_vocab = set(word_ids)
        old_to_new = dict(zip(local_vocab, range(len(local_vocab))))

        local_vocab = torch.tensor(list(local_vocab)).to(self.device)  #Graph on device, so need local_vocab on same device.

        graph = dgl.DGLGraph()
        graph = graph.to(self.device)

        graph.add_nodes(len(local_vocab)) # Add nodes for graph.
        graph.ndata['h'] = self.node_hidden(local_vocab)  # Add node features for graph.
        
        eta = torch.sigmoid(self.node_eta(local_vocab))  # Limit to [0-1].
        graph.ndata['eta'] = eta  # Add node eta for graph

        edges, edges_id = self.doc_level_edges(word_ids, old_to_new)
        edges_id = torch.LongTensor(edges_id).to(self.device)  # The edges_id is a 1-d list.
        srcs, dsts = zip(*edges)  # get source and destination nodes id (local id).
        
        graph.add_edges(srcs, dsts)  # Add edges for graph.
        graph.edata['w'] = self.node_edge_w(edges_id) # Add weight for edges.

        return graph

    def doc_level_edges(self, word_ids: list, old_to_new:dict):  # Add doc-level edges for word-level graph.
        edges = []
        old_edge_id = []
        for index, src_word_old in enumerate(word_ids):
            src = old_to_new[src_word_old]
            for i in range(max(0, index - self.ngram), min(index + self.ngram + 1, len(word_ids))):
                dst_word_old = word_ids[i]
                dst = old_to_new[dst_word_old]

                # - first connect the new sub_graph
                edges.append([src, dst])
                # - then get the hidden from parent_graph
                old_edge_id.append(self.edges_matrix[src_word_old, dst_word_old])

            # self circle
            edges.append([src, src])  # All edges.
            old_edge_id.append(self.edges_matrix[src_word_old, src_word_old])

        return edges, old_edge_id

    def forward(self, doc_id):  # Update doc-level graph
      graphs = [self.doc_level_graph(doc) for doc in doc_id ]  # Doc-level graphs.

      batch_graph = dgl.batch(graphs)  # batching update
      batch_graph.update_all(
          message_func = dgl.function.src_mul_edge('h', 'w', 'weighted_message'),
          reduce_func = dgl.function.max('weighted_message','M')
      )
      batch_graph.ndata['h'] = batch_graph.ndata['eta']*batch_graph.ndata['h'] + batch_graph.ndata['M'] * (1 - batch_graph.ndata['eta'])  # Update word-level node feature.
      out_doc = dgl.sum_nodes(batch_graph, feat='h')  # Convert doc-level graph to one vector (batch_size, 300)

      return out_doc  # Return doc-level graph representation.


class HieGNN(torch.nn.Module):  # Paper model.
  def __init__(self, class_num, hidden_size_node,vocab, n_gram, drop_out, edges_matrix, edges_num):
    super(HieGNN, self).__init__()

    self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    #self.alpha = torch.nn.Parameter(torch.sigmoid(torch.rand(1, 1)), requires_grad=True)

    self.gat= GATLayer(hidden_size_node, vocab, n_gram, edges_matrix, edges_num)

    self.dropout = torch.nn.Dropout(p=drop_out)  # Output layer.
    self.activation = torch.nn.ReLU()
    self.Linear = torch.nn.Linear(hidden_size_node, class_num, bias=True)  # y = Wx + b
  

  def forward(self, content, content_doc):

    output = self.gat(content_doc)  # Doc-level output hidden state and graph id, 2-d: (batch_size, 300). 
    
    drop = self.dropout(output)
    act = self.activation(drop)

    I = self.Linear(act)

    return I


def cal_PMI(dataset: str, window_size=20):  # The point wise mutual information.
    helper = DataHelper(dataset=dataset, mode="train")
    content, _ = helper.get_content()  # function returns (content, label)
    pair_count_matrix = np.zeros((len(helper.vocab), len(helper.vocab)), dtype=int)  # p(i,j)
    word_count =np.zeros(len(helper.vocab), dtype=int)  # 公式中的#W(i)
    print(f'Len of word_count: {len(helper.vocab)}')
    
    for sentence in content:  # one  sentence per document
        sentence = sentence.split()  # get the words in a sentence
        for i, word in enumerate(sentence):
            try:
                word_count[helper.d[word]] += 1
            except KeyError:
                continue
            start_index = max(0, i - window_size)
            end_index = min(len(sentence), i + window_size)
            for j in range(start_index, end_index):
                if i == j:
                    continue
                else:
                    target_word = sentence[j]
                    try:
                        pair_count_matrix[helper.d[word], helper.d[target_word]] += 1  # p(i,j)
                    except KeyError:
                        continue
        
    total_count = np.sum(word_count)
    # print(f'line 408:total_count: {total_count}')
    word_count = word_count / total_count
    # print(f'word_count: {word_count}')
    pair_count_matrix = pair_count_matrix / total_count
    
    pmi_matrix = np.zeros((len(helper.vocab), len(helper.vocab)), dtype=float)
    for i in range(len(helper.vocab)):
        for j in range(len(helper.vocab)):
            pmi_matrix[i, j] = np.log(
                pair_count_matrix[i, j] / (word_count[i] * word_count[j])  # 正则化后值很小，被取整为0，发生除0错误, 得到NaN值
            )
    
    pmi_matrix = np.nan_to_num(pmi_matrix)  # replace NaN with zero and infinity with large finite number
    
    pmi_matrix = np.maximum(pmi_matrix, 0.0)  # Less than 0, set 0.

  
    count = 1  # The number of edges.
    edges_mappings = np.zeros((len(helper.vocab), len(helper.vocab)), dtype=int)
    for i in range(len(helper.vocab)):
        for j in range(len(helper.vocab)):
            if pmi_matrix[i, j] > 0:  # The value of PMI, more than 0 is positive correlation.
                edges_mappings[i, j] = count
                count += 1
   
    return edges_mappings, count

NUM_ITER_EVAL = 100
EARLY_STOP_EPOCH = 10

def train(ngram, model_name, drop_out, dataset, total_epoch=1):  # Training function.
    data_helper = DataHelper(dataset, mode='train')

    tmp_data = dataset.split('/')
    dataset_name = tmp_data[len(tmp_data)-1]
    if os.path.exists(os.path.join('/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/model', model_name+'.pkl')) and model_name != 'temp_model':
        print('load model from file.')
        model = torch.load(os.path.join('/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/model', model_name+'.pkl'))
    else:
        print('Build new model.')
        if model_name == 'temp_model':
            model_name = f'temp_model_{dataset_name}'
        edges_mappings, count = cal_PMI(dataset=dataset)
        
        model = HieGNN(class_num=len(data_helper.labels_str), hidden_size_node=300,
                      vocab=data_helper.vocab, n_gram=ngram, drop_out=drop_out, edges_matrix=edges_mappings, edges_num=count)

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        model.to(device)
    loss_func = torch.nn.CrossEntropyLoss()
    optim = torch.optim.Adam(model.parameters(), weight_decay=1e-4)

    iter = 0

    best_acc = 0.0
    last_best_epoch = 0
    total_loss = 0.0
    total_correct = 0
    total = 0
    for content, content_doc, label, epoch in data_helper.batch_iter(batch_size=64, num_epoch=total_epoch):
        improved = ''
        model.train()

        logits = model(content, content_doc)
        loss = loss_func(logits, label)

        pred = torch.argmax(logits, dim=1)

        correct = torch.sum(pred == label)

        total_correct += correct
        total += len(label)

        total_loss += loss.item()

        optim.zero_grad()
        loss.backward()
        optim.step()

        iter += 1
        if iter % NUM_ITER_EVAL == 0:

            val_acc = dev(model, dataset=dataset)
            if val_acc > best_acc:
                best_acc = val_acc
                last_best_epoch = epoch
                improved = '*'

                torch.save(model, f'/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/model/{model_name}.pkl')

            if epoch - last_best_epoch >= EARLY_STOP_EPOCH:
                print('Early stopping...')
                return model_name
            print(f'Epoch:{epoch}, iter:{iter}, train loss:{total_loss/ NUM_ITER_EVAL :.4f}, train acc: {float(total_correct) / float(total) :.4f}, val acc: {val_acc:.4f},{improved}')

            total_loss = 0.0
            total_correct = 0
            total = 0

    return model_name

def dev(model, dataset):
  data_helper = DataHelper(dataset, mode='dev')

  total_pred =0
  correct = 0
  iter = 0

  for content, content_doc, label, _ in data_helper.batch_iter(batch_size=64, num_epoch=1):
    iter += 1
    model.eval()

    logits = model(content, content_doc)
    pred = torch.argmax(logits, dim=1)

    correct_pred = torch.sum(pred==label)
    correct += correct_pred
    total_pred += len(content)

  total_pred = float(total_pred)
  correct = correct.float()

  return torch.div(correct, total_pred)

def test(model, dataset):
  model = torch.load(os.path.join('/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/model', model+'.pkl'))
  data_helper = DataHelper(dataset, mode='test')

  total_pred = 0
  correct = 0
  iter = 0
  for content, content_doc, label, _ in data_helper.batch_iter(batch_size=64, num_epoch=1):
    iter += 1
    model.eval()

    logits = model(content, content_doc)
    pred = torch.argmax(logits, dim=1)

    correct_pred = torch.sum(pred==label)
    correct += correct_pred
    total_pred += len(content)
  
  total_pred = float(total_pred)
  correct = correct.float()

  return torch.div(correct, total_pred)

parser = argparse.ArgumentParser()
parser.add_argument('--ngram', required=False, type=int, default=3, help='word level and doc level n-gram')
parser.add_argument('--name', required=False, type=str, default='temp_model', help='model name')
parser.add_argument('--dropout', required=False, type=float, default=0.5, help='drop out rate')
parser.add_argument('--dataset', required=True, type=str, help='dataset')
parser.add_argument('--rand', required=False, type=int, default=42, help='rand seed')
parser.add_argument('--epoch', required=False, type=int, default=50, help='training epoch')

args = parser.parse_args()

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # Print running machine
if torch.cuda.is_available():
  print(f'device: {device}')
  print(f'name: {torch.cuda.get_device_name(0)}')
  print(f'memory: {torch.cuda.get_device_properties(0).total_memory/1e9}')
  print(f'*'*50)

SEED = args.rand
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)

model = train(args.ngram, args.name, args.dropout, dataset=args.dataset, total_epoch=args.epoch)
print(f'test acc: {test(model, args.dataset).cpu().numpy():.4f}')


Overwriting parsing.py


# Run

In [None]:
!python parsing.py --dataset='/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/data/r52' --ngram=2 --dropout=0.5 --epoch=100

Build new model.
Len of word_count: 5019
  pair_count_matrix[i, j] / (word_count[i] * word_count[j])  # 正则化后值很小，被取整为0，发生除0错误, 得到NaN值
  pair_count_matrix[i, j] / (word_count[i] * word_count[j])  # 正则化后值很小，被取整为0，发生除0错误, 得到NaN值
Epoch:1, iter:100, train loss:1.2262, train acc: 0.7617, val acc: 0.8406,*
Epoch:2, iter:200, train loss:0.3490, train acc: 0.9297, val acc: 0.9156,*
Epoch:3, iter:300, train loss:0.1531, train acc: 0.9695, val acc: 0.9187,*
Epoch:4, iter:400, train loss:0.0776, train acc: 0.9842, val acc: 0.9141,
Epoch:5, iter:500, train loss:0.0454, train acc: 0.9906, val acc: 0.9187,


# Test