<a href="https://colab.research.google.com/github/hshuai97/Colab20210803/blob/main/HieGNN(v5_4).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

word-level: GAT or GCN

sentence-level: GAT or GCN

Reference:

1. [GAT and GCN implementation in Deep Graph Library](https://github.com/dmlc/dgl/blob/master/examples/pytorch/gat/gat.py)

2.  [Lianzhe Huang's source code (TextLevelGCN)](https://github.com/mojave-pku/TextLevelGCN)

# Install libraries

In [None]:
# Install Deep Graph Library (DGL): https://www.dgl.ai/pages/start.html
import torch
try:
  import dgl
except ModuleNotFoundError:
  # Installing dgl package with specific CUDA
  CUDA = 'cu' + torch.version.cuda.replace('.','')
  !pip install dgl-{CUDA} -f https://data.dgl.ai/wheels/repo.html

# Install word2vec
try:
  import word2vec  # type: ignore
except ModuleNotFoundError:
  !pip install word2vec # type: ignore

try:
  import torch_scatter
except ModuleNotFoundError:
  TORCH = torch.__version__.split('+')[0]
  CUDA = 'cu' + torch.version.cuda.replace('.','')
  !pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html

# Parsing

In [None]:
%%writefile parsing.py
# build graph
import torch
import dgl

# data_helper
import os
import re

# model
import torch.nn.functional as F
import numpy as np
import word2vec
import math
from torch_scatter import scatter_add, scatter_max, scatter_mean

# train
import random
import argparse
from time import time


class DataHelper(object):
    """Load one split of a text-classification dataset and map its words to ids.

    Almost the same code as Huang's data helper (see the notebook references),
    extended with a sentence-level view obtained by splitting on '[sep]'.

    Attributes set by ``__init__``:
        dataset:      dataset name (last '/'-separated component of the argument).
        min_count:    minimum word frequency used when a vocabulary must be built
                      (20 for '20ng', 35 for 'agnews', 2 for 'mr', otherwise 5).
        mode:         split name used in the data file name.
        base:         directory holding 'label.txt', the split file and the vocab file.
        current_set:  path of '<dataset>-<mode>-stemmed.txt' inside ``base``.
        labels_str:   label names read from 'label.txt', one per line.
        label:        integer label index of every kept sample.
        vocab, d:     vocabulary list and its word -> id dictionary.
        content:      per document, a list of sentences, each a list of word ids (3-d).
        content_doc:  per document, a flat list of word ids (2-d).
    """

    # Datasets this helper accepts.
    ALLOWED_DATA = ('20ng', 'r8', 'r52', 'oh', 'mr', 'agnews')
    # Folder holding one sub-folder per dataset; used when ``data_root`` is not given.
    DEFAULT_DATA_ROOT = '/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/data'

    def __init__(self, dataset, mode='train', vocab=None, data_root=None):
        """Read the split from disk and convert every document to word ids.

        Args:
            dataset:   dataset name, optionally prefixed by a path ('data/r8' -> 'r8').
            mode:      split name; selects '<dataset>-<mode>-stemmed.txt'.
            vocab:     vocabulary to reuse. When None, 'vocab-<min_count>.txt' is
                       read from the dataset folder, or built (and written there)
                       if that file does not exist.
            data_root: folder containing the dataset sub-folders. Defaults to
                       ``DEFAULT_DATA_ROOT`` (the original hard-coded location).

        Raises:
            ValueError: if the dataset name is not in ``ALLOWED_DATA`` or the
                split contains no usable sample.
        """
        dataset = dataset.split('/')[-1]
        if dataset not in self.ALLOWED_DATA:
            raise ValueError('currently allowed data: %s' % ','.join(self.ALLOWED_DATA))
        self.dataset = dataset

        if self.dataset == '20ng':  # The 20ng dataset is large.
            self.min_count = 20
        elif self.dataset == 'agnews':  # The agnews dataset is large.
            self.min_count = 35
        elif self.dataset == 'mr':  # The 'mr' dataset is small.
            self.min_count = 2
        else:
            self.min_count = 5

        self.mode = mode
        root = self.DEFAULT_DATA_ROOT if data_root is None else data_root
        self.base = os.path.join(root, self.dataset)  # e.g. 'data/r8'
        self.current_set = os.path.join(self.base, '%s-%s-stemmed.txt' % (self.dataset, self.mode))

        with open(os.path.join(self.base, 'label.txt')) as f:
            self.labels_str = f.read().split('\n')

        content, label = self.get_content()
        self.label = self.label_to_onehot(label)

        if vocab is None:
            self.vocab = []
            try:
                self.get_vocab()
            except FileNotFoundError:
                # No cached vocabulary for this min_count yet: build and cache it.
                self.build_vocab(content, self.min_count)
        else:
            self.vocab = vocab

        self.d = dict(zip(self.vocab, range(len(self.vocab))))

        # Word-level view: split each document into sentences on '[sep]',
        # dropping sentences that are empty after stripping.
        sentence_level = []
        for doc in content:
            doc_sentences = []
            for sen in doc.split('[sep]'):
                sen = sen.strip()
                if len(sen) > 0:
                    doc_sentences.append([self.word2id(word) for word in sen.split()])
            sentence_level.append(doc_sentences)
        self.content = sentence_level

        # Doc-level view: whole document as one id sequence. The '[sep]' marker
        # is not removed here, so it is looked up like any other token.
        self.content_doc = [[self.word2id(word) for word in doc.split()] for doc in content]

    def label_to_onehot(self, label_str):
        """Return the index in ``labels_str`` of every label name.

        Despite the name, the result is a list of integer class indices, not
        one-hot vectors. ``list.index`` raises ValueError for an unknown label.
        """
        return [self.labels_str.index(l) for l in label_str]

    def get_content(self):
        """Read the current split and return ``(content, label)`` as two tuples.

        Each line is split on tabs into ``(label, text)``. Lines with fewer than
        two fields, or whose text is shorter than 5 characters, are dropped.

        Raises:
            ValueError: if no line survives the filtering.
        """
        with open(self.current_set) as f:
            lines = f.read().split('\n')
        pairs = [line.split('\t') for line in lines]  # 2 elements: (label, str_text)

        if self.dataset in self.ALLOWED_DATA:
            # Remove samples lacking a label or text, and very short texts.
            cleaned = [pair for pair in pairs if len(pair) >= 2 and len(pair[1]) >= 5]
        else:
            cleaned = pairs

        if not cleaned:
            raise ValueError('no usable sample found in %s' % self.current_set)

        label, content = zip(*cleaned)  # '*' unpacks the list of pairs
        return content, label

    def word2id(self, word):
        """Return the id of ``word``, falling back to the id of 'UNK'."""
        if word in self.d:
            return self.d[word]
        return self.d['UNK']

    def get_vocab(self):
        """Load the vocabulary from 'vocab-<min_count>.txt' (one word per line)."""
        with open(os.path.join(self.base, 'vocab-' + str(self.min_count) + '.txt')) as f:  # e.g. vocab-5.txt
            self.vocab = f.read().split('\n')

    def build_vocab(self, content, min_count=10):
        """Build the vocabulary from ``content`` and cache it on disk.

        Words are kept in order of first appearance when they occur at least
        ``min_count`` times; '[sep]' is never included and 'UNK' is placed at
        index 0. The result is written to 'vocab-<self.min_count>.txt' and
        stored in ``self.vocab``.
        """
        # A dict keeps first-occurrence order and gives O(1) membership checks,
        # unlike the list scan this replaces.
        freq = {}
        for doc in content:
            for word in doc.split():
                if word != '[sep]' and word != '':
                    freq[word] = freq.get(word, 0) + 1

        results = ['UNK'] + [word for word, count in freq.items() if count >= min_count]

        with open(os.path.join(self.base, 'vocab-' + str(self.min_count) + '.txt'), 'w') as f:
            f.write('\n'.join(results))

        self.vocab = results

    def batch_iter(self, batch_size, num_epoch, drop_last=True):
        """Yield ``(content, content_doc, y, epoch)`` batches for ``num_epoch`` epochs.

        Args:
            batch_size: number of documents per batch.
            num_epoch:  number of passes over the data.
            drop_last:  when True (default, the original behaviour) a trailing
                        batch smaller than ``batch_size`` is skipped; pass False
                        to also yield that final partial batch.

        Yields:
            content:     slice of ``self.content`` (sentence-level ids).
            content_doc: slice of ``self.content_doc`` (document-level ids).
            y:           label tensor for the slice, moved to CUDA when available.
            epoch:       zero-based index of the current epoch.
        """
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        total = len(self.content)
        num_per_epoch = total // batch_size
        if not drop_last and total % batch_size:
            num_per_epoch += 1

        for epoch in range(num_epoch):
            for batch_id in range(num_per_epoch):
                start = batch_id * batch_size
                end = min(start + batch_size, total)

                content = self.content[start:end]
                content_doc = self.content_doc[start:end]
                y = torch.tensor(self.label[start:end]).to(device)

                yield content, content_doc, y, epoch


class WordLevelGCN(torch.nn.Module):
    """GCN-style message passing over one word graph per sentence.

    Every sentence becomes a small graph whose nodes are its distinct word ids
    and whose edges link words at most ``n_gram`` positions apart. Node
    features come from a pretrained embedding table, edge weights from a
    trainable table indexed through ``edges_matrix``.
    """

    # Embedding file used when ``word2vec_file`` is not given.
    DEFAULT_WORD2VEC_FILE = '/content/drive/MyDrive/Colab_Notebooks/DATA/glove.6B/glove.6B.100d.w2vformat.txt'

    def __init__(self, hidden_size_node, vocab, n_gram, edges_matrix, edges_num, max_length=300, word2vec_file=None):
        """Create the embedding, eta and edge-weight tables.

        Args:
            hidden_size_node: width of the node embedding; the loaded vectors
                are copied into it, so it has to match their dimensionality.
            vocab:         list of words; position in the list is the word id.
            n_gram:        window radius used to connect words.
            edges_matrix:  object indexed as ``edges_matrix[src_id, dst_id]``
                to obtain the global edge id of a word pair.
            edges_num:     number of rows of the edge-weight table.
            max_length:    sentences longer than this are truncated.
            word2vec_file: path of the embedding file read by ``word2vec.load``;
                defaults to ``DEFAULT_WORD2VEC_FILE``.
        """
        super(WordLevelGCN, self).__init__()

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        self.vocab = vocab  # Needed by load_word2vec.

        if word2vec_file is None:
            word2vec_file = self.DEFAULT_WORD2VEC_FILE

        # Word-level node hidden state, initialised from pretrained vectors and fine-tuned.
        self.node_hidden = torch.nn.Embedding(len(vocab), hidden_size_node)
        self.node_hidden.weight.data.copy_(torch.tensor(self.load_word2vec(word2vec_file)))
        self.node_hidden.weight.requires_grad = True

        # Per-word gate: share of its own feature a node keeps after aggregation.
        self.node_eta = torch.nn.Embedding.from_pretrained(torch.rand(len(vocab), 1), freeze=False)

        self.edges_matrix = edges_matrix  # Global word-pair -> edge id lookup.
        self.edges_num = edges_num  # Number of word-level edges.
        # One trainable scalar weight per global edge, initialised to 1.
        self.node_edge_w = torch.nn.Embedding.from_pretrained(torch.ones(edges_num, 1), freeze=False)

        self.len_vocab = len(vocab)
        self.d = dict(zip(vocab, range(len(vocab))))  # word -> id

        self.ngram = n_gram
        self.max_length = max_length  # Maximum number of words kept per sentence.

    def load_word2vec(self, word2vec_file):
        """Return an array with one embedding row per word of ``self.vocab``.

        Words missing from the loaded model get a zero vector whose length is
        taken from the model's vector for 'the'.
        """
        model = word2vec.load(word2vec_file)

        zero_vector = None  # Built lazily, only if some word is missing.
        embedding_matrix = []
        for word in self.vocab:
            try:
                vector = model[word]
            except KeyError:
                if zero_vector is None:
                    zero_vector = np.zeros(len(model['the']))
                vector = zero_vector  # Unknown words use the zero vector.
            embedding_matrix.append(vector)

        return np.array(embedding_matrix)

    def build_graph(self, word_ids):
        """Build the DGL graph of one sentence given as a list of word ids.

        Node data: 'h' (embedding) and 'eta' (sigmoid gate). Edge data: 'w'
        (trainable weight looked up by global edge id).
        """
        if len(word_ids) > self.max_length:
            word_ids = word_ids[: self.max_length]

        # Distinct words of the sentence become nodes 0..k-1.
        local_vocab = set(word_ids)
        old_to_new = dict(zip(local_vocab, range(len(local_vocab))))

        # The graph lives on self.device, so the id tensor has to as well.
        local_vocab = torch.tensor(list(local_vocab)).to(self.device)

        graph = dgl.DGLGraph()
        graph = graph.to(self.device)

        graph.add_nodes(len(local_vocab))
        graph.ndata['h'] = self.node_hidden(local_vocab)
        graph.ndata['eta'] = torch.sigmoid(self.node_eta(local_vocab))  # Limit to [0, 1].

        edges, edges_id = self.add_edges(word_ids, old_to_new)
        edges_id = torch.LongTensor(edges_id).to(self.device)  # 1-d list of global edge ids.
        srcs, dsts = zip(*edges)  # Local source / destination node ids.

        graph.add_edges(srcs, dsts)
        graph.edata['w'] = self.node_edge_w(edges_id)

        return graph

    def add_edges(self, word_ids: list, old_to_new: dict):
        """List the edges of one sentence graph.

        For every position, an edge is emitted to each position within
        ``self.ngram`` of it (the position itself included), followed by an
        explicit self loop.

        Returns:
            edges:       list of ``[src, dst]`` local node ids.
            old_edge_id: matching list of ``edges_matrix[src_word, dst_word]`` values.
        """
        edges = []
        old_edge_id = []
        for index, src_word_old in enumerate(word_ids):
            src = old_to_new[src_word_old]
            window = range(max(0, index - self.ngram), min(index + self.ngram + 1, len(word_ids)))
            for i in window:
                dst_word_old = word_ids[i]
                # Connect the nodes in the local sub-graph ...
                edges.append([src, old_to_new[dst_word_old]])
                # ... and remember which global edge supplies the weight.
                old_edge_id.append(self.edges_matrix[src_word_old, dst_word_old])

            # Self loop.
            edges.append([src, src])
            old_edge_id.append(self.edges_matrix[src_word_old, src_word_old])

        return edges, old_edge_id

    def forward(self, word_id):
        """Run one round of message passing over all sentence graphs of a batch.

        Args:
            word_id: list of documents, each a list of sentences, each a list
                of word ids.

        Returns:
            out_w:     summed node features, one row per sentence graph.
            graphs_id: dict mapping document index to ``[start, end]`` row
                range of its sentences in ``out_w``, e.g. ``{0: [0, 3]}``.

        Note: a document without sentences prints an error and stops the loop,
        so later documents of the batch are not processed.
        """
        graphs = []
        graphs_id = {}

        count = 0
        for i in range(len(word_id)):
            t_index = [count]
            if len(word_id[i]) > 0:
                for j in range(len(word_id[i])):
                    graphs.append(self.build_graph(word_id[i][j]))
                    count += 1
                t_index.append(count)  # [start, end]
                graphs_id[i] = t_index
            else:
                print('Error! Sentence length must be greater than 0.')
                break

        batch_graph = dgl.batch(graphs)  # Update all sentence graphs at once.
        batch_graph.update_all(
            # u_mul_e is the current name of the deprecated src_mul_edge builtin.
            message_func=dgl.function.u_mul_e('h', 'w', 'weighted_message'),
            reduce_func=dgl.function.max('weighted_message', 'M')
        )
        # Gate between a node's own feature and the aggregated message.
        batch_graph.ndata['h'] = batch_graph.ndata['eta'] * batch_graph.ndata['h'] + batch_graph.ndata['M'] * (1 - batch_graph.ndata['eta'])
        out_w = dgl.sum_nodes(batch_graph, feat='h')  # One vector per graph.

        return out_w, graphs_id


class SentenceLevelGCN(torch.nn.Module):
    """Placeholder for a sentence-level GCN.

    Only records the compute device; the remaining constructor arguments are
    accepted but not used, and no forward pass is defined.
    """

    def __init__(self, hidden_size_node, vocab, n_gram, edges_matrix, edges_num):
        super().__init__()
        device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_name)


class WordGATLayer(torch.nn.Module):
    """Single-head graph attention over one word graph per sentence.

    Nodes are the distinct word ids of a sentence, edges link words at most
    ``n_gram`` positions apart. Equation numbers in the comments follow the
    DGL GAT example listed in the notebook references.
    """

    # Embedding file used when ``word2vec_file`` is not given.
    DEFAULT_WORD2VEC_FILE = '/content/drive/MyDrive/Colab_Notebooks/DATA/glove.6B/glove.6B.100d.w2vformat.txt'

    def __init__(self, in_dim, out_dim, hidden_size_node, vocab, n_gram, edges_matrix, edges_num, max_length=300, word2vec_file=None):
        """Create the projection, attention and embedding parameters.

        Args:
            in_dim:  input width of the node projection. The projection is
                applied to the embedding output, so it has to equal
                ``hidden_size_node``.
            out_dim: width of the projected node features.
            hidden_size_node: width of the embedding table; has to match the
                dimensionality of the loaded vectors.
            vocab:        list of words; position in the list is the word id.
            n_gram:       window radius used to connect words.
            edges_matrix: object indexed as ``edges_matrix[src_id, dst_id]``.
            edges_num:    accepted for signature parity; not used by this layer.
            max_length:   sentences longer than this are truncated.
            word2vec_file: path of the embedding file read by ``word2vec.load``;
                defaults to ``DEFAULT_WORD2VEC_FILE``.
        """
        super(WordGATLayer, self).__init__()

        self.fc = torch.nn.Linear(in_dim, out_dim, bias=False)  # equation (1)
        self.attn_fc = torch.nn.Linear(2 * out_dim, 1, bias=False)  # equation (2)
        self.reset_parameters()

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.vocab = vocab

        self.word_ngram = n_gram  # n-gram window for the word level
        self.edges_matrix = edges_matrix
        self.max_length = max_length

        if word2vec_file is None:
            word2vec_file = self.DEFAULT_WORD2VEC_FILE

        # Word-level node embedding, initialised from pretrained vectors and fine-tuned.
        self.node_hidden = torch.nn.Embedding(len(vocab), hidden_size_node)
        self.node_hidden.weight.data.copy_(torch.tensor(self.load_word2vec(word2vec_file)))
        self.node_hidden.weight.requires_grad = True

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = torch.nn.init.calculate_gain('relu')
        torch.nn.init.xavier_normal_(self.fc.weight, gain=gain)
        torch.nn.init.xavier_normal_(self.attn_fc.weight, gain=gain)

    def load_word2vec(self, word2vec_file):
        """Return an array with one embedding row per word of ``self.vocab``.

        Words missing from the loaded model get a zero vector whose length is
        taken from the model's vector for 'the'.
        """
        model = word2vec.load(word2vec_file)

        zero_vector = None  # Built lazily, only if some word is missing.
        embedding_matrix = []
        for word in self.vocab:
            try:
                vector = model[word]
            except KeyError:
                if zero_vector is None:
                    zero_vector = np.zeros(len(model['the']))
                vector = zero_vector  # Unknown words use the zero vector.
            embedding_matrix.append(vector)

        return np.array(embedding_matrix)

    def edge_attention(self, edges):
        """Edge UDF for equation (2): un-normalised attention score per edge."""
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}

    def message_func(self, edges):
        """Message UDF for equations (3) & (4): send source 'z' and edge score 'e'."""
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        """Reduce UDF for equations (3) & (4).

        Normalises the incoming scores with a softmax and returns the
        attention-weighted sum of the neighbour embeddings as 'h'.
        """
        alpha = F.softmax(nodes.mailbox['e'], dim=1)  # equation (3)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)  # equation (4)
        return {'h': h}

    def a_edges(self, word_ids: list, old_to_new: dict):
        """List the edges of one sentence graph.

        For every position, an edge is emitted to each position within
        ``self.word_ngram`` of it (the position itself included), followed by
        an explicit self loop.

        Returns:
            edges:       list of ``[src, dst]`` local node ids.
            old_edge_id: matching list of ``edges_matrix[src_word, dst_word]`` values.
        """
        edges = []
        old_edge_id = []
        for index, src_word_old in enumerate(word_ids):
            src = old_to_new[src_word_old]
            window = range(max(0, index - self.word_ngram), min(index + self.word_ngram + 1, len(word_ids)))
            for i in window:
                dst_word_old = word_ids[i]
                # Connect the nodes in the local sub-graph ...
                edges.append([src, old_to_new[dst_word_old]])
                # ... and record the global edge id of the pair.
                old_edge_id.append(self.edges_matrix[src_word_old, dst_word_old])

            # Self loop.
            edges.append([src, src])
            old_edge_id.append(self.edges_matrix[src_word_old, src_word_old])

        return edges, old_edge_id

    def build_graph(self, word_ids):
        """Build the attention graph of one sentence given as a list of word ids.

        Node data 'z' holds the projected embeddings, edge data 'e' the
        un-normalised attention scores.
        """
        if len(word_ids) > self.max_length:
            word_ids = word_ids[: self.max_length]

        # Distinct words of the sentence become nodes 0..k-1.
        local_vocab = set(word_ids)
        old_to_new = dict(zip(local_vocab, range(len(local_vocab))))
        # The graph lives on self.device, so the id tensor has to as well.
        local_vocab = torch.tensor(list(local_vocab)).to(self.device)

        graph = dgl.DGLGraph()
        graph = graph.to(self.device)

        graph.add_nodes(len(local_vocab))
        graph.ndata['z'] = self.fc(self.node_hidden(local_vocab))  # equation (1)

        # The global edge ids are not needed here; attention replaces edge weights.
        edges, _ = self.a_edges(word_ids, old_to_new)
        srcs, dsts = zip(*edges)  # Local source / destination node ids.

        graph.add_edges(srcs, dsts)
        graph.apply_edges(self.edge_attention)  # Sets graph.edata['e'].

        return graph

    def forward(self, word_id):
        """Encode every sentence of a batch and pool sentences per document.

        Args:
            word_id: list of documents, each a list of sentences, each a list
                of word ids (3-d).

        Returns:
            out_w:      summed node features, one row per sentence graph.
            graphs_id:  dict mapping document index to the ``[start, end]`` row
                range of its sentences in ``out_w``, e.g. ``{0: [0, 3], 1: [3, 6]}``.
            output_sum: ``out_w`` rows added up per document index.

        Note: a document without sentences prints an error and stops the loop,
        so later documents of the batch are not processed.
        """
        graphs = []  # Word-level graphs, one per sentence.
        graphs_id = {}
        g_id = []  # Document index of every sentence, e.g. [0, 0, 0, 1, 1, 1, ...]

        count = 0
        for i in range(len(word_id)):
            t_index = [count]
            if len(word_id[i]) > 0:
                for j in range(len(word_id[i])):
                    g_id.append(i)
                    graphs.append(self.build_graph(word_id[i][j]))
                    count += 1
                t_index.append(count)  # [start, end]
                graphs_id[i] = t_index
            else:
                print('Error! Sentence length must be greater than 0.')
                break

        batch_graph = dgl.batch(graphs)  # Update all sentence graphs at once.
        batch_graph.update_all(self.message_func, self.reduce_func)

        out_w = dgl.sum_nodes(batch_graph, feat='h')  # One vector per sentence graph.

        g_id = torch.tensor(g_id, device=self.device)
        output_sum = scatter_add(out_w, g_id, dim=0)  # Sum sentence vectors per document.

        return out_w, graphs_id, output_sum


class SentenceGATLayer(torch.nn.Module):
    """Single-head graph attention over the sentences of each document.

    Sentence vectors are produced by an internal ``WordGATLayer``; each
    document then becomes a graph with one node per sentence, where sentences
    at most ``sen_ngram`` positions apart are connected.
    """

    def __init__(self, in_dim, out_dim, hidden_size_node, vocab, n_gram, edges_matrix, edges_num, sen_ngram=2):
        """Create the projection/attention parameters and the word-level layer.

        Args:
            in_dim:  input width of the sentence projection. It is applied to
                the word-level output, whose width is ``hidden_size_node``.
            out_dim: width of the projected sentence features.
            hidden_size_node: embedding / hidden width of the word-level layer.
            vocab, n_gram, edges_matrix, edges_num: forwarded to ``WordGATLayer``.
            sen_ngram: window radius between connected sentences (default 2).
                Author's note: for R8 and R52 use 1 because they only have one
                sentence; for 20NG, Ohsumed and MR use 2 or 3.
        """
        super(SentenceGATLayer, self).__init__()

        self.fc = torch.nn.Linear(in_dim, out_dim, bias=False)  # equation (1)
        self.attn_fc = torch.nn.Linear(2 * out_dim, 1, bias=False)  # equation (2)
        self.reset_parameters()

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.vocab = vocab
        self.dim = hidden_size_node

        self.sen_ngram = sen_ngram

        # GAT for the word level; supplies the sentence vectors.
        self.word_gat = WordGATLayer(self.dim, self.dim, self.dim, vocab, n_gram, edges_matrix, edges_num)

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = torch.nn.init.calculate_gain('relu')
        torch.nn.init.xavier_normal_(self.fc.weight, gain=gain)
        torch.nn.init.xavier_normal_(self.attn_fc.weight, gain=gain)

    def edge_attention(self, edges):
        """Edge UDF for equation (2): un-normalised attention score per edge."""
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}

    def message_func(self, edges):
        """Message UDF for equations (3) & (4): send source 'z' and edge score 'e'."""
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        """Reduce UDF for equations (3) & (4).

        Normalises the incoming scores with a softmax and returns the
        attention-weighted sum of the neighbour embeddings as 'h'.
        """
        alpha = F.softmax(nodes.mailbox['e'], dim=1)  # equation (3)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)  # equation (4)
        return {'h': h}

    def a_edges(self, num_sen):
        """Return the edges of a document graph with ``num_sen`` sentence nodes.

        Sentence ``i`` is connected to every sentence ``j`` with
        ``|i - j| <= self.sen_ngram`` (itself included). The result is a tensor
        of ``[src, dst]`` rows on ``self.device``, ordered by ``src`` then ``dst``.
        """
        radius = self.sen_ngram
        # Only visit the window around i instead of testing all num_sen**2 pairs.
        edges = [
            [i, j]
            for i in range(num_sen)
            for j in range(max(0, i - radius), min(num_sen, i + radius + 1))
        ]
        # On GPU the edge tensor has to live on the same device as the graph.
        return torch.tensor(edges).to(self.device)

    def build_graph(self, word_hidden):
        """Build the attention graph of one document.

        Args:
            word_hidden: sentence vectors of the document, one row per sentence.
        """
        n_sen = len(word_hidden)  # Number of sentences in this sample.

        graph = dgl.DGLGraph()
        graph = graph.to(self.device)

        graph.add_nodes(n_sen)
        graph.ndata['z'] = self.fc(word_hidden)  # equation (1)

        edges = self.a_edges(n_sen)
        # Slice the columns directly; unzipping the tensor row by row would
        # create one 0-d tensor per endpoint.
        graph.add_edges(edges[:, 0], edges[:, 1])
        graph.apply_edges(self.edge_attention)  # Sets graph.edata['e'].

        return graph

    def forward(self, word_id):
        """Encode a batch of documents at sentence level.

        Args:
            word_id: list of documents, each a list of sentences, each a list
                of word ids.

        Returns:
            out_sen: summed sentence-node features, one row per document graph.
            out_sum: per-document sum of the word-level sentence vectors, as
                returned by the word-level layer.
        """
        out_sen, g_id, out_sum = self.word_gat(word_id)

        # g_id maps each document index to the row range of its sentences.
        graphs = [self.build_graph(out_sen[g_id[i][0]: g_id[i][1]]) for i in list(g_id.keys())]

        batch_graph = dgl.batch(graphs)  # Update all document graphs at once.
        batch_graph.update_all(self.message_func, self.reduce_func)

        out_sen = dgl.sum_nodes(batch_graph, feat='h')  # One vector per document graph.

        return out_sen, out_sum  # Sentence-level and word-level representations.


class MultiHeadGATLayer(torch.nn.Module):
    """Run several ``SentenceGATLayer`` heads and merge their outputs.

    ``merge='cat'`` concatenates the head outputs along dim 1; any other value
    averages them.
    """

    def __init__(self, in_dim, out_dim, num_heads, hidden_size_node, vocab, n_gram, edges_matrix, edges_num, merge='mean'):
        """Create ``num_heads`` independent heads; the other arguments are passed to each head."""
        super(MultiHeadGATLayer, self).__init__()
        self.heads = torch.nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(SentenceGATLayer(in_dim, out_dim, hidden_size_node, vocab, n_gram, edges_matrix, edges_num))
        self.merge = merge

    def _merge_heads(self, tensors):
        """Combine one tensor per head according to ``self.merge``."""
        if self.merge == 'cat':  # Hidden layer: widen the feature dimension.
            return torch.cat(tensors, dim=1)
        return torch.mean(torch.stack(tensors, dim=1), dim=1)

    def forward(self, doc_id):
        """Apply every head to ``doc_id`` and merge the results.

        Each head returns a ``(sentence_level, word_level)`` pair, so the two
        outputs are merged separately and returned as a pair as well. Passing
        the list of pairs straight to ``torch.cat`` / ``torch.stack`` (as the
        previous code did) raises a TypeError. Heads that return a single
        tensor are merged directly.
        """
        head_outs = [attn_head(doc_id) for attn_head in self.heads]

        if all(isinstance(out, torch.Tensor) for out in head_outs):
            return self._merge_heads(head_outs)

        sen_outs, word_outs = zip(*head_outs)
        return self._merge_heads(list(sen_outs)), self._merge_heads(list(word_outs))


class DocLevelGCN(torch.nn.Module):
    """GCN-style message passing over one word graph per document.

    The same as the 'Model' part of Huang's paper code (see the notebook
    references): nodes are the distinct word ids of a document, edges link
    words at most ``n_gram`` positions apart.
    """

    # Embedding file used when ``word2vec_file`` is not given.
    DEFAULT_WORD2VEC_FILE = '/content/drive/MyDrive/Colab_Notebooks/DATA/glove.6B/glove.6B.100d.w2vformat.txt'

    def __init__(self, hidden_size_node, vocab, n_gram, edges_matrix, edges_num, max_length=400, word2vec_file=None):
        """Create the embedding, eta and edge-weight tables.

        Args:
            hidden_size_node: width of the node embedding; the loaded vectors
                are copied into it, so it has to match their dimensionality.
            vocab:         list of words; position in the list is the word id.
            n_gram:        window radius used to connect words.
            edges_matrix:  object indexed as ``edges_matrix[src_id, dst_id]``
                to obtain the global edge id of a word pair.
            edges_num:     number of rows of the edge-weight table.
            max_length:    documents longer than this are truncated.
            word2vec_file: path of the embedding file read by ``word2vec.load``;
                defaults to ``DEFAULT_WORD2VEC_FILE``.
        """
        super(DocLevelGCN, self).__init__()

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        self.vocab = vocab  # Needed by load_word2vec.

        if word2vec_file is None:
            word2vec_file = self.DEFAULT_WORD2VEC_FILE

        # Node hidden state, initialised from pretrained vectors and fine-tuned.
        self.node_hidden = torch.nn.Embedding(len(vocab), hidden_size_node)
        self.node_hidden.weight.data.copy_(torch.tensor(self.load_word2vec(word2vec_file)))
        self.node_hidden.weight.requires_grad = True

        # Per-word gate: share of its own feature a node keeps after aggregation.
        self.node_eta = torch.nn.Embedding.from_pretrained(torch.rand(len(vocab), 1), freeze=False)

        self.edges_matrix = edges_matrix  # Global word-pair -> edge id lookup.
        self.edges_num = edges_num  # Number of doc-level edges.
        # One trainable scalar weight per global edge, initialised to 1.
        self.node_edge_w = torch.nn.Embedding.from_pretrained(torch.ones(edges_num, 1), freeze=False)

        self.len_vocab = len(vocab)
        self.d = dict(zip(vocab, range(len(vocab))))  # word -> id

        self.ngram = n_gram
        self.max_length = max_length  # Maximum number of words kept per document.

    def load_word2vec(self, word2vec_file):
        """Return an array with one embedding row per word of ``self.vocab``.

        Words missing from the loaded model get a zero vector whose length is
        taken from the model's vector for 'the'.
        """
        model = word2vec.load(word2vec_file)

        zero_vector = None  # Built lazily, only if some word is missing.
        embedding_matrix = []
        for word in self.vocab:
            try:
                vector = model[word]
            except KeyError:
                if zero_vector is None:
                    zero_vector = np.zeros(len(model['the']))
                vector = zero_vector  # Unknown words use the zero vector.
            embedding_matrix.append(vector)

        return np.array(embedding_matrix)

    def doc_level_graph(self, word_ids):
        """Build the DGL graph of one document given as a list of word ids.

        Node data: 'h' (embedding) and 'eta' (sigmoid gate). Edge data: 'w'
        (trainable weight looked up by global edge id).
        """
        if len(word_ids) > self.max_length:
            word_ids = word_ids[: self.max_length]

        # Distinct words of the document become nodes 0..k-1.
        local_vocab = set(word_ids)
        old_to_new = dict(zip(local_vocab, range(len(local_vocab))))

        # The graph lives on self.device, so the id tensor has to as well.
        local_vocab = torch.tensor(list(local_vocab)).to(self.device)

        graph = dgl.DGLGraph()
        graph = graph.to(self.device)

        graph.add_nodes(len(local_vocab))
        graph.ndata['h'] = self.node_hidden(local_vocab)
        graph.ndata['eta'] = torch.sigmoid(self.node_eta(local_vocab))  # Limit to [0, 1].

        edges, edges_id = self.doc_level_edges(word_ids, old_to_new)
        edges_id = torch.LongTensor(edges_id).to(self.device)  # 1-d list of global edge ids.
        srcs, dsts = zip(*edges)  # Local source / destination node ids.

        graph.add_edges(srcs, dsts)
        graph.edata['w'] = self.node_edge_w(edges_id)

        return graph

    # NOTE: the class used to define doc_level_edges twice. The later
    # definition (this one, with explicit self loops) silently replaced the
    # earlier one, so only it is kept.
    def doc_level_edges(self, word_ids: list, old_to_new: dict):
        """List the edges of one document graph.

        For every position, an edge is emitted to each position within
        ``self.ngram`` of it (the position itself included), followed by an
        explicit self loop.

        Returns:
            edges:       list of ``[src, dst]`` local node ids.
            old_edge_id: matching list of ``edges_matrix[src_word, dst_word]`` values.
        """
        edges = []
        old_edge_id = []
        for index, src_word_old in enumerate(word_ids):
            src = old_to_new[src_word_old]
            window = range(max(0, index - self.ngram), min(index + self.ngram + 1, len(word_ids)))
            for i in window:
                dst_word_old = word_ids[i]
                # Connect the nodes in the local sub-graph ...
                edges.append([src, old_to_new[dst_word_old]])
                # ... and remember which global edge supplies the weight.
                old_edge_id.append(self.edges_matrix[src_word_old, dst_word_old])

            # Self loop.
            edges.append([src, src])
            old_edge_id.append(self.edges_matrix[src_word_old, src_word_old])

        return edges, old_edge_id

    def forward(self, doc_id):
        """Run one round of message passing over all document graphs of a batch.

        Args:
            doc_id: list of documents, each a flat list of word ids.

        Returns:
            Summed node features, one row per document graph.
        """
        graphs = [self.doc_level_graph(doc) for doc in doc_id]

        batch_graph = dgl.batch(graphs)  # Update all document graphs at once.
        batch_graph.update_all(
            # u_mul_e is the current name of the deprecated src_mul_edge builtin.
            message_func=dgl.function.u_mul_e('h', 'w', 'weighted_message'),
            reduce_func=dgl.function.max('weighted_message', 'M')
        )
        # Gate between a node's own feature and the aggregated message.
        batch_graph.ndata['h'] = batch_graph.ndata['eta'] * batch_graph.ndata['h'] + batch_graph.ndata['M'] * (1 - batch_graph.ndata['eta'])
        out_doc = dgl.sum_nodes(batch_graph, feat='h')  # One vector per graph.

        return out_doc


class DocLevelGAT(torch.nn.Module):
    """Single-head graph attention over one word graph per document.

    Nodes are the distinct word ids of a document, edges link words at most
    ``n_gram`` positions apart.
    """

    # Embedding file used when ``word2vec_file`` is not given.
    DEFAULT_WORD2VEC_FILE = '/content/drive/MyDrive/Colab_Notebooks/DATA/glove.6B/glove.6B.100d.w2vformat.txt'

    def __init__(self, in_dim, out_dim, hidden_size_node, vocab, n_gram, edges_matrix, max_length=400, word2vec_file=None):
        """Create the projection, attention and embedding parameters.

        Args:
            in_dim:  input width of the node projection. The projection is
                applied to the embedding output, so it has to equal
                ``hidden_size_node``.
            out_dim: width of the projected node features.
            hidden_size_node: width of the embedding table; has to match the
                dimensionality of the loaded vectors.
            vocab:        list of words; position in the list is the word id.
            n_gram:       window radius used to connect words.
            edges_matrix: object indexed as ``edges_matrix[src_id, dst_id]``.
            max_length:   documents longer than this are truncated.
            word2vec_file: path of the embedding file read by ``word2vec.load``;
                defaults to ``DEFAULT_WORD2VEC_FILE``.
        """
        super(DocLevelGAT, self).__init__()

        self.fc = torch.nn.Linear(in_dim, out_dim, bias=False)  # equation (1)
        self.attn_fc = torch.nn.Linear(out_dim * 2, 1, bias=False)  # equation (2)
        self.reset_parameters()

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.vocab = vocab
        self.word_ngram = n_gram
        self.edges_matrix = edges_matrix
        self.max_length = max_length

        if word2vec_file is None:
            word2vec_file = self.DEFAULT_WORD2VEC_FILE

        # Node embedding, initialised from pretrained vectors and fine-tuned.
        self.node_hidden = torch.nn.Embedding(len(vocab), hidden_size_node)
        self.node_hidden.weight.data.copy_(torch.tensor(self.load_word2vec(word2vec_file)))
        self.node_hidden.weight.requires_grad = True

    def load_word2vec(self, w2v_file):
        """Return an array with one embedding row per word of ``self.vocab``.

        Words missing from the loaded model get a zero vector whose length is
        taken from the model's vector for 'the'.
        """
        model = word2vec.load(w2v_file)

        zero_vector = None  # Built lazily, only if some word is missing.
        embedding_matrix = []
        for word in self.vocab:
            try:
                vector = model[word]
            except KeyError:
                if zero_vector is None:
                    zero_vector = np.zeros(len(model['the']))
                vector = zero_vector
            embedding_matrix.append(vector)

        return np.array(embedding_matrix)

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = torch.nn.init.calculate_gain('relu')
        torch.nn.init.xavier_normal_(self.fc.weight, gain=gain)
        torch.nn.init.xavier_normal_(self.attn_fc.weight, gain=gain)

    def edge_attention(self, edges):
        """Edge UDF for equation (2): un-normalised attention score per edge."""
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}

    def message_func(self, edges):
        """Message UDF for equations (3) & (4): send source 'z' and edge score 'e'."""
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        """Softmax-normalise the incoming scores and return the weighted sum of 'z' as 'h'."""
        alpha = F.softmax(nodes.mailbox['e'], dim=1)  # Normalise the attention score e.
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}

    def a_edges(self, word_ids: list, old_to_new: dict):
        """List the edges of one document graph.

        For every position, an edge is emitted to each position within
        ``self.word_ngram`` of it (the position itself included), followed by
        an explicit self loop.

        Returns:
            edges:       list of ``[src, dst]`` local node ids.
            old_edge_id: matching list of ``edges_matrix[src_word, dst_word]`` values.
        """
        edges = []
        old_edge_id = []
        for index, src_word_old in enumerate(word_ids):
            src = old_to_new[src_word_old]
            window = range(max(0, index - self.word_ngram), min(index + self.word_ngram + 1, len(word_ids)))
            for i in window:
                dst_word_old = word_ids[i]
                # Connect the nodes in the local sub-graph ...
                edges.append([src, old_to_new[dst_word_old]])
                # ... and record the global edge id of the pair.
                old_edge_id.append(self.edges_matrix[src_word_old, dst_word_old])

            # Self loop.
            edges.append([src, src])
            old_edge_id.append(self.edges_matrix[src_word_old, src_word_old])

        return edges, old_edge_id

    def build_graph(self, word_ids):
        """Build the attention graph of one document given as a list of word ids.

        Node data 'z' holds the projected embeddings, edge data 'e' the
        un-normalised attention scores.
        """
        if len(word_ids) > self.max_length:
            word_ids = word_ids[: self.max_length]

        # Distinct words of the document become nodes 0..k-1.
        local_vocab = set(word_ids)
        old_to_new = dict(zip(local_vocab, range(len(local_vocab))))
        # The graph lives on self.device, so the id tensor has to as well.
        local_vocab = torch.tensor(list(local_vocab)).to(self.device)

        graph = dgl.DGLGraph()
        graph = graph.to(self.device)

        graph.add_nodes(len(local_vocab))
        graph.ndata['z'] = self.fc(self.node_hidden(local_vocab))  # equation (1)

        # The global edge ids are not needed here; attention replaces edge weights.
        edges, _ = self.a_edges(word_ids, old_to_new)
        srcs, dsts = zip(*edges)  # Local source / destination node ids.

        graph.add_edges(srcs, dsts)
        graph.apply_edges(self.edge_attention)  # Sets graph.edata['e'].

        return graph

    def forward(self, doc_id):
        """Encode a batch of documents.

        Args:
            doc_id: list of documents, each a flat list of word ids (2-d).

        Returns:
            Summed node features, one row per document graph.
        """
        graphs = [self.build_graph(doc) for doc in doc_id]

        batch_graph = dgl.batch(graphs)  # Update all document graphs at once.
        batch_graph.update_all(self.message_func, self.reduce_func)

        out_w = dgl.sum_nodes(batch_graph, feat='h')  # One vector per document graph.

        return out_w


class HieGNN(torch.nn.Module):
  """Paper model: fuses sentence-, word- and document-level graph outputs.

  The three representations are combined with learnable scalar weights
  (``alpha``, ``beta``, ``gamma``) and fed through dropout, ReLU and a linear
  classification layer.
  """

  def __init__(self, class_num, hidden_size_node,vocab, n_gram, drop_out, edges_matrix, edges_num):
    """Create the sub-networks and the output head.

    Args:
      class_num: number of target classes (output size of the linear layer).
      hidden_size_node: feature size used by every sub-network.
      vocab, n_gram, edges_matrix, edges_num: forwarded unchanged to the
        sentence-level and document-level layers.
      drop_out: dropout probability applied before the classifier.
    """
    super().__init__()

    self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    self.dim = hidden_size_node

    def new_gate():
      # Scalar mixing weight, initialised to sigmoid of a standard normal draw.
      return torch.nn.Parameter(torch.sigmoid(torch.randn(1)))

    # Creation order matters for seeded initialisation: alpha, beta, gamma.
    self.alpha = new_gate()
    self.beta = new_gate()
    self.gamma = new_gate()

    # Sentence-level GAT (also yields the word-level output), then doc-level GAT.
    # A DocLevelGCN was previously tried in place of the doc-level GAT.
    self.sengat = SentenceGATLayer(self.dim, self.dim, self.dim, vocab, n_gram, edges_matrix, edges_num)
    self.doclevelgat = DocLevelGAT(self.dim, self.dim, self.dim, vocab, n_gram, edges_matrix)

    # Output head.
    self.dropout = torch.nn.Dropout(p=drop_out)
    self.activation = torch.nn.ReLU()
    self.Linear = torch.nn.Linear(hidden_size_node, class_num)  # y = Wx + b

  def forward(self, content, content_doc):
    """Classify a batch.

    Args:
      content: 3-d input consumed by the sentence-level layer.
      content_doc: 2-d input consumed by the document-level layer.

    Returns:
      Class logits; per the original notes, (batch_size, class_num).
    """
    sentence_repr, word_repr = self.sengat(content)
    doc_repr = self.doclevelgat(content_doc)

    # Weighted sum of the three views; (batch_size, hidden_size_node) per the original note.
    fused = self.gamma * doc_repr + self.beta * sentence_repr + self.alpha * word_repr

    return self.Linear(self.activation(self.dropout(fused)))


def _pmi_edge_mappings(pair_count_matrix, word_count):
    """Assign a unique id to every word pair with positive PMI.

    Args:
        pair_count_matrix: (V, V) integer co-occurrence counts.
        word_count: (V,) integer occurrence counts.

    Returns:
        ``(edges_mappings, count)``: a (V, V) int matrix holding ids
        ``1..count-1`` (row-major order) where PMI > 0 and 0 elsewhere, and
        ``count``, i.e. one more than the number of positive-PMI pairs.
    """
    vocab_size = len(word_count)
    edges_mappings = np.zeros((vocab_size, vocab_size), dtype=int)

    total_count = np.sum(word_count)
    if total_count == 0:
        # Nothing was counted: no pair can have positive PMI.
        return edges_mappings, 1

    word_prob = word_count / total_count
    pair_prob = pair_count_matrix / total_count
    expected = np.outer(word_prob, word_prob)  # p(i) * p(j)

    # PMI is only defined where the pair was observed; every other entry is
    # treated as 0 (which is what NaN / -inf were clamped to before) without
    # ever evaluating log(0) or 0/0.
    defined = (pair_prob > 0) & (expected > 0)
    pmi_matrix = np.zeros((vocab_size, vocab_size), dtype=float)
    np.divide(pair_prob, expected, out=pmi_matrix, where=defined)
    np.log(pmi_matrix, out=pmi_matrix, where=defined)
    np.maximum(pmi_matrix, 0.0, out=pmi_matrix)  # negative PMI -> 0

    # Number the positively correlated pairs 1, 2, ... in row-major order.
    positive = pmi_matrix > 0
    edge_total = int(np.count_nonzero(positive))
    edges_mappings[positive] = np.arange(1, edge_total + 1)

    return edges_mappings, edge_total + 1


def cal_PMI(dataset: str, window_size=20):
    """Build the edge-id matrix from point-wise mutual information (PMI).

    Word and word-pair occurrences are counted over the training split of
    ``dataset``; every ordered pair with PMI > 0 receives its own edge id.

    Args:
        dataset: dataset path/name handed to ``DataHelper``.
        window_size: half-width of the co-occurrence window. For position
            ``i`` the partners are positions ``i - window_size`` up to
            ``i + window_size - 1`` (clipped to the text), excluding ``i``.
            NOTE(review): the right edge is one token shorter than the left;
            kept as-is to preserve the resulting edge set.

    Returns:
        ``(edges_mappings, count)``: (V, V) int matrix of edge ids (0 = no
        edge) and ``count`` = number of edges + 1.
    """
    helper = DataHelper(dataset=dataset, mode="train")
    content, _ = helper.get_content()  # returns (content, label)
    vocab_size = len(helper.vocab)
    pair_count_matrix = np.zeros((vocab_size, vocab_size), dtype=int)  # #W(i, j)
    word_count = np.zeros(vocab_size, dtype=int)  # #W(i) in the PMI formula
    print(f'Vocab of dataset: {vocab_size}')

    for sentence in content:  # one text per document
        # Resolve each token once; None marks out-of-vocabulary tokens.
        token_ids = []
        for token in sentence.split():
            try:
                token_ids.append(helper.d[token])
            except KeyError:
                token_ids.append(None)

        for i, word_id in enumerate(token_ids):
            if word_id is None:
                continue
            word_count[word_id] += 1

            start_index = max(0, i - window_size)
            end_index = min(len(token_ids), i + window_size)
            for j in range(start_index, end_index):
                target_id = token_ids[j]
                if j == i or target_id is None:
                    continue
                pair_count_matrix[word_id, target_id] += 1

    return _pmi_edge_mappings(pair_count_matrix, word_count)

# train() runs a validation pass (and prints the running training metrics)
# every NUM_ITER_EVAL training iterations.
NUM_ITER_EVAL = 100
# train() stops early when, at a validation pass, at least this many epochs
# have gone by since the last improvement of the validation accuracy.
EARLY_STOP_EPOCH = 10

def train(ngram, model_name, drop_out, dataset, total_epoch=1):
    """Train a HieGNN model and return the name of its checkpoint.

    Args:
        ngram: n-gram window handed to the model.
        model_name: checkpoint name (without ``.pkl``). An existing checkpoint
            of that name is loaded and trained further, except for the
            placeholder ``'temp_model'`` which always builds a fresh model
            named ``temp_model_<dataset>``.
        drop_out: dropout probability of the model.
        dataset: dataset path/name; its last path component is the dataset name.
        total_epoch: number of epochs to iterate over the training data.

    Returns:
        The (possibly renamed) model name under which the checkpoint is stored.

    Every ``NUM_ITER_EVAL`` iterations the model is validated with ``dev``; a
    new best validation accuracy writes the checkpoint. Training stops early
    after ``EARLY_STOP_EPOCH`` epochs without improvement. A freshly built
    model that never produced a best checkpoint is saved once at the end, so
    the returned name never refers to a missing file or to a stale checkpoint
    left behind by an earlier run.
    """
    model_dir = '/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/model'
    data_helper = DataHelper(dataset, mode='train')
    dataset_name = dataset.split('/')[-1]

    resume = model_name != 'temp_model' and os.path.exists(os.path.join(model_dir, model_name + '.pkl'))
    if resume:
        print('load model from file.')
        # NOTE: torch.load unpickles a whole module; only load trusted checkpoints.
        model = torch.load(os.path.join(model_dir, model_name + '.pkl'))
    else:
        print('Build new model.')
        if model_name == 'temp_model':
            model_name = f'temp_model_{dataset_name}'
        edges_mappings, count = cal_PMI(dataset=dataset)

        model = HieGNN(class_num=len(data_helper.labels_str), hidden_size_node=100,
                       vocab=data_helper.vocab, n_gram=ngram, drop_out=drop_out,
                       edges_matrix=edges_mappings, edges_num=count)

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        model.to(device)

    checkpoint_path = os.path.join(model_dir, model_name + '.pkl')
    loss_func = torch.nn.CrossEntropyLoss()
    optim = torch.optim.Adam(model.parameters(), weight_decay=1e-4)

    step = 0  # number of optimisation steps taken so far
    best_acc = 0.0
    last_best_epoch = 0
    checkpoint_written = False

    # Running statistics since the last validation pass.
    total_loss = 0.0
    total_correct = 0
    total = 0

    for content, content_doc, label, epoch in data_helper.batch_iter(batch_size=64, num_epoch=total_epoch):
        model.train()  # dev() switches the model to eval mode

        logits = model(content, content_doc)
        loss = loss_func(logits, label)

        pred = torch.argmax(logits, dim=1)
        total_correct += torch.sum(pred == label).item()
        total += len(label)
        total_loss += loss.item()

        optim.zero_grad()
        loss.backward()
        optim.step()

        step += 1
        if step % NUM_ITER_EVAL != 0:
            continue

        val_acc = dev(model, dataset=dataset)
        improved = ''
        if val_acc > best_acc:
            best_acc = val_acc
            last_best_epoch = epoch
            improved = '*'

            torch.save(model, checkpoint_path)
            checkpoint_written = True

        if epoch - last_best_epoch >= EARLY_STOP_EPOCH:
            print('Early stopping...')
            break
        print(f'Epoch:{epoch}, iter:{step}, train loss:{total_loss / NUM_ITER_EVAL:.4f}, train acc: {float(total_correct) / float(total):.4f}, val acc: {val_acc:.4f},{improved}')

        total_loss = 0.0
        total_correct = 0
        total = 0

    if not resume and not checkpoint_written:
        # No validation pass ever improved (or none ran at all): still persist
        # the freshly built model so callers can load what was just trained.
        torch.save(model, checkpoint_path)

    return model_name

def dev(model, dataset):
  """Compute the accuracy of ``model`` on the 'dev' split of ``dataset``.

  The model is switched to eval mode (the caller must switch back to train
  mode) and the forward passes run without autograd bookkeeping.

  Args:
    model: callable module taking ``(content, content_doc)`` and returning logits.
    dataset: dataset path/name handed to ``DataHelper``.

  Returns:
    0-dim float tensor: correct predictions / number of evaluated samples.

  Raises:
    ValueError: if the split yields no batch (accuracy is undefined).
  """
  data_helper = DataHelper(dataset, mode='dev')

  total_pred = 0
  correct = None  # becomes a tensor with the first batch

  model.eval()
  with torch.no_grad():  # evaluation only: no gradients needed
    for content, content_doc, label, _ in data_helper.batch_iter(batch_size=64, num_epoch=1):
      logits = model(content, content_doc)
      pred = torch.argmax(logits, dim=1)

      batch_correct = torch.sum(pred == label)
      correct = batch_correct if correct is None else correct + batch_correct
      total_pred += len(content)

  if correct is None:
    raise ValueError('dev split yielded no batches; cannot compute accuracy')

  return torch.div(correct.float(), float(total_pred))

def test(model, dataset):
  """Load a saved model and compute its accuracy on the 'test' split.

  Args:
    model: name of the checkpoint (without ``.pkl``) inside the model directory.
    dataset: dataset path/name handed to ``DataHelper``.

  Returns:
    0-dim float tensor: correct predictions / number of evaluated samples.

  Raises:
    ValueError: if the split yields no batch (accuracy is undefined).
  """
  # NOTE: torch.load unpickles a whole module; only load trusted checkpoints.
  net = torch.load(os.path.join('/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/model', model+'.pkl'))
  data_helper = DataHelper(dataset, mode='test')

  total_pred = 0
  correct = None  # becomes a tensor with the first batch

  net.eval()
  with torch.no_grad():  # evaluation only: no gradients needed
    for content, content_doc, label, _ in data_helper.batch_iter(batch_size=64, num_epoch=1):
      logits = net(content, content_doc)
      pred = torch.argmax(logits, dim=1)

      batch_correct = torch.sum(pred == label)
      correct = batch_correct if correct is None else correct + batch_correct
      total_pred += len(content)

  if correct is None:
    raise ValueError('test split yielded no batches; cannot compute accuracy')

  return torch.div(correct.float(), float(total_pred))

# Command-line entry point: parse options, report the device, seed the random
# number generators, then train and evaluate on the test split.
parser = argparse.ArgumentParser()
parser.add_argument('--ngram', required=False, type=int, default=3, help='word level and doc level n-gram')
parser.add_argument('--name', required=False, type=str, default='temp_model', help='model name')
parser.add_argument('--dropout', required=False, type=float, default=0.5, help='drop out rate')
parser.add_argument('--dataset', required=True, type=str, help='dataset')
parser.add_argument('--rand', required=False, type=int, default=42, help='rand seed')
parser.add_argument('--epoch', required=False, type=int, default=50, help='training epoch')

args = parser.parse_args()

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():  # Print running machine
  print(f'device: {device}')
  print(f'name: {torch.cuda.get_device_name(0)}')
  print(f'memory: {torch.cuda.get_device_properties(0).total_memory/1e9}')
  print('*' * 50)

# Seed every random number generator this file imports. The stdlib `random`
# module was previously left unseeded, so any use of it (e.g. shuffling in the
# data helper - not visible here) ignored --rand.
SEED = args.rand
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

model = train(args.ngram, args.name, args.dropout, dataset=args.dataset, total_epoch=args.epoch)
print(f'test acc: {test(model, args.dataset).item():.4f}')

Overwriting parsing.py


# Run

In [None]:
!python parsing.py --dataset='/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/data/oh' --ngram=2 --dropout=0.5 --epoch=100

Build new model.
Vocab of dataset: 7748
  pair_count_matrix[i, j] / (word_count[i] * word_count[j])  # The dividend is very small, divide by 0, get the NaN value
  pair_count_matrix[i, j] / (word_count[i] * word_count[j])  # The dividend is very small, divide by 0, get the NaN value
Epoch:2, iter:100, train loss:2.7733, train acc: 0.2848, val acc: 0.4688,*
Epoch:4, iter:200, train loss:1.4126, train acc: 0.5933, val acc: 0.5813,*
Epoch:6, iter:300, train loss:0.6553, train acc: 0.8100, val acc: 0.6125,*
Epoch:8, iter:400, train loss:0.3004, train acc: 0.9166, val acc: 0.6500,*
Epoch:10, iter:500, train loss:0.1397, train acc: 0.9644, val acc: 0.6500,
Epoch:12, iter:600, train loss:0.0905, train acc: 0.9791, val acc: 0.6562,*
Epoch:14, iter:700, train loss:0.0625, train acc: 0.9831, val acc: 0.6562,
Epoch:17, iter:800, train loss:0.0427, train acc: 0.9897, val acc: 0.6656,*
Epoch:19, iter:900, train loss:0.0333, train acc: 0.9928, val acc: 0.6469,
Epoch:21, iter:1000, train loss:0.0297,

In [None]:
# Scratch check of broadcasting: a Parameter of shape (10,) multiplied with a
# (6, 10) tensor is applied to every row, giving a (6, 10) result.
ga = torch.nn.Parameter(torch.sigmoid(torch.randn(10)))

a = torch.rand(6, 10)
print(f'ga:\n{ga}')

b = ga * a

print(f'b:\n{b}')