<a href="https://colab.research.google.com/github/hshuai97/Colab20210803/blob/main/LianzheHuang_TextLevelGNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

1. 将Lianzhe Huang原始github上多个.py文件整合为一个.ipynb文件
2. 源代码地址: [source link](https://github.com/mojave-pku/TextLevelGCN)

In [1]:
# 安装colab中没有的库
# 安装deep graph labrary: https://www.dgl.ai/pages/start.html
import torch
try:
  import dgl
except ModuleNotFoundError:
  # Installing torch geometric packages with specific CUDA
  CUDA = 'cu' + torch.version.cuda.replace('.','')
  !pip install dgl-{CUDA} -f https://data.dgl.ai/wheels/repo.html

# 安装word2vec
try:
  import word2vec  # type: ignore
except ModuleNotFoundError:
  !pip install word2vec # type: ignore
  import word2vec  # type: ignore

Looking in links: https://data.dgl.ai/wheels/repo.html
Collecting dgl-cu111
  Downloading https://data.dgl.ai/wheels/dgl_cu111-0.7.2-cp37-cp37m-manylinux1_x86_64.whl (165.0 MB)
[K     |████████████████████████████████| 165.0 MB 39 kB/s 
Installing collected packages: dgl-cu111
Successfully installed dgl-cu111-0.7.2
Collecting word2vec
  Downloading word2vec-0.11.1.tar.gz (42 kB)
[K     |████████████████████████████████| 42 kB 557 kB/s 
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Building wheels for collected packages: word2vec
  Building wheel for word2vec (PEP 517) ... [?25l[?25hdone
  Created wheel for word2vec: filename=word2vec-0.11.1-py2.py3-none-any.whl size=156420 sha256=1454d09ac75b8cd542ee6f7c67e4131b95dc96cd5fa166052efebe55e93f2361
  Stored in directory: /root/.cache/pip/wheels/c9/c0/d4/29d797817e268124a32b6cf8beb8b8fe87b86f099d5a049e61
Successfully 

# Parsing.py

In [4]:
%%writefile parsing.py
# build graph
import torch
import dgl

# dada_helper
import os
import csv

# model
import torch.nn.functional as F
import numpy as np

import word2vec

# train
import tqdm
import sys, random
import argparse
import time, datetime

# ------------------------------------------------------------------------buildGraph.py------------------------------------------------------------------------
class GraphBuilder(object):
    def __init__(self, words, hiddenSizeNode):
        self.graph = dgl.DGLGraph()
        self.word2id = dict(zip(words, range(len(words))))
        self.graph.add_nodes(len(words))

        # add hidden para for nodes.
        self.graph.ndata['h'] = torch.nn.Parameter(
            torch.Tensor(len(words), hiddenSizeNode)
        )

        # all node are supposed to connected.
        # warning: self-connected enabled.
        for i in range(len(words)):
            self.graph.add_edges(i, range(0, len(words)))

        # add hidden para for edges. Only edge weight (size = 1 )
        self.graph.edata['h'] = torch.nn.Parameter(
            torch.Tensor(self.graph.number_of_edges(), 1)

        )


# ------------------------------------------------------------------------data_helper.py------------------------------------------------------------------------
class DataHelper(object):
    def __init__(self, dataset, mode='train', vocab=None):
        allowed_data = ['r8', '20ng', 'r52', 'mr', 'oh']
        tmp_data = dataset.split('/')
        dataset = tmp_data[len(tmp_data)-1]

        if dataset not in allowed_data:
            raise ValueError('currently allowed data: %s' % ','.join(allowed_data))
        else:
            self.dataset = dataset

        self.mode = mode

        self.base = os.path.join('/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/data', self.dataset)  # 'data/r8'

        self.current_set = os.path.join(self.base, '%s-%s-stemmed.txt' % (self.dataset, self.mode))

        with open(os.path.join(self.base, 'label.txt')) as f:
            labels = f.read()
        self.labels_str = labels.split('\n')

        content, label = self.get_content()

        self.label = self.label_to_onehot(label)
        if vocab is None:
            self.vocab = []

            try:
                self.get_vocab()
            except FileNotFoundError:
                self.build_vocab(content, min_count=5)
        else:
            self.vocab = vocab

        self.d = dict(zip(self.vocab, range(len(self.vocab))))

        self.content = [list(map(lambda x: self.word2id(x), doc.split(' '))) for doc in content]

    def label_to_onehot(self, label_str):

        return [self.labels_str.index(l) for l in label_str]

    def get_content(self):
        with open(self.current_set) as f:
            all = f.read()
            content = [line.split('\t') for line in all.split('\n')]
        if self.dataset == '20ng' or 'r52':
            cleaned = []
            for i, pair in enumerate(content):
                if len(pair) < 2:
                    # print(i, pair)
                    pass
                else:
                    cleaned.append(pair)
        else:
            cleaned = content

        label, content = zip(*cleaned)

        return content, label

    def word2id(self, word):
        try:
            result = self.d[word]
        except KeyError:
            result = self.d['UNK']

        return result

    def get_vocab(self):
        with open(os.path.join(self.base, 'vocab-5.txt')) as f:
            vocab = f.read()
            self.vocab = vocab.split('\n')

    def build_vocab(self, content, min_count=5):
        vocab = []

        for c in content:
            words = c.split(' ')
            for word in words:
                if word not in vocab:
                    vocab.append(word)

        freq = dict(zip(vocab, [0 for i in range(len(vocab))]))

        for c in content:
            words = c.split(' ')
            for word in words:
                freq[word] += 1

        results = []
        for word in freq.keys():
            if freq[word] < min_count:
                continue
            else:
                results.append(word)

        results.insert(0, 'UNK')
        with open(os.path.join(self.base, 'vocab-5.txt'), 'w') as f:
            f.write('\n'.join(results))

        self.vocab = results

    def count_word_freq(self, content):  # 函数未使用
        freq = dict(zip(self.vocab, [0 for i in range(len(self.vocab))]))

        for c in content:
            words = c.split(' ')
            for word in words:
                freq[word] += 1

        with open(os.path.join(self.base, 'freq.csv'), 'w') as f:
            writer = csv.writer(f)
            results = list(zip(freq.keys(), freq.values()))
            writer.writerows(results)

    def batch_iter(self, batch_size, num_epoch):
        for i in range(num_epoch):
            num_per_epoch = int(len(self.content) / batch_size)
            for batch_id in range(num_per_epoch):
                start = batch_id * batch_size
                end = min((batch_id + 1) * batch_size, len(self.content))

                content = self.content[start:end]
                label = self.label[start:end]

                device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
                y = torch.tensor(label).to(device)  # gup内存上的label

                yield content, y, i 


# ------------------------------------------------------------------------model.py------------------------------------------------------------------------
def gcn_msg(edge):
    print('Line 181 has been run.')
    return {'m': edge.src['h'], 'w': edge.data['w']}


def gcn_reduce(node):
    w = node.mailbox['w']

    new_hidden = torch.mul(w, node.mailbox['m'])

    new_hidden,_ = torch.max(new_hidden, 1)

    node_eta = torch.sigmoid(node.data['eta'])
    # node_eta = F.leaky_relu(node.data['eta'])

    # new_hidden = node_eta * node.data['h'] + (1 - node_eta) * new_hidden
    # print(new_hidden.shape)
    print('Line 197 has been run.')

    return {'h': new_hidden}


class Model(torch.nn.Module):
    def __init__(self,
                 class_num,
                 hidden_size_node,
                 vocab,
                 n_gram,
                 drop_out,
                 edges_num,
                 edges_matrix,
                 max_length=350,
                 trainable_edges=True,
                 pmi=None,
                 ):
        super(Model, self).__init__()

        self.vocab = vocab
        # print(len(vocab))
        self.seq_edge_w = torch.nn.Embedding(edges_num, 1)
        print(f'edges_num: {edges_num}')
        print(f'pmi.shape: {pmi.shape}')

        self.node_hidden = torch.nn.Embedding(len(vocab), hidden_size_node)
        
        self.seq_edge_w = torch.nn.Embedding.from_pretrained(pmi, freeze=True)
            
        self.edges_num = edges_num
        if trainable_edges:
            self.seq_edge_w = torch.nn.Embedding.from_pretrained(torch.ones(edges_num, 1), freeze=False)
        else:
            self.seq_edge_w = torch.nn.Embedding.from_pretrained(pmi, freeze=True)

        self.hidden_size_node = hidden_size_node

        self.node_hidden.weight.data.copy_(torch.tensor(self.load_word2vec('/content/drive/MyDrive/Colab_Notebooks/DATA/glove.6B/glove.6B.300d.w2vformat.txt')))
        self.node_hidden.weight.requires_grad = True

        self.len_vocab = len(vocab)

        self.ngram = n_gram

        self.d = dict(zip(self.vocab, range(len(self.vocab))))

        self.max_length = max_length

        self.edges_matrix = edges_matrix

        self.dropout = torch.nn.Dropout(p=drop_out)

        self.activation = torch.nn.ReLU()

        self.Linear = torch.nn.Linear(hidden_size_node, class_num, bias=True)

    def word2id(self, word):
        try:
            result = self.d[word]
        except KeyError:
            result = self.d['UNK']

        return result

    def load_word2vec(self, word2vec_file):
        model = word2vec.load(word2vec_file)

        embedding_matrix = []

        for word in self.vocab:
            try:
                embedding_matrix.append(model[word])
            except KeyError:
                # print(f'Line 269. The word not in vocab:{word}')
                #unk_word=np.zeros(300)
                #print(f"line 271, model['the].shape: {len(model['the'])}")  # [,300]
                embedding_matrix.append(np.zeros(len(model['the'])))  # 其他词用0向量代替

        embedding_matrix = np.array(embedding_matrix)

        return embedding_matrix

    def add_all_edges(self, doc_ids: list, old_to_new: dict):
        edges = []
        old_edge_id = []

        local_vocab = list(set(doc_ids))

        for i, src_word_old in enumerate(local_vocab):
            src = old_to_new[src_word_old]
            for dst_word_old in local_vocab[i:]:
                dst = old_to_new[dst_word_old]
                edges.append([src, dst])
                old_edge_id.append(self.edges_matrix[src_word_old, dst_word_old])

            # self circle
            edges.append([src, src])
            old_edge_id.append(self.edges_matrix[src_word_old, src_word_old])

        return edges, old_edge_id

    def add_seq_edges(self, doc_ids: list, old_to_new: dict):
        edges = []
        old_edge_id = []
        for index, src_word_old in enumerate(doc_ids):
            src = old_to_new[src_word_old]
            for i in range(max(0, index - self.ngram), min(index + self.ngram + 1, len(doc_ids))):
                dst_word_old = doc_ids[i]
                dst = old_to_new[dst_word_old]

                # - first connect the new sub_graph
                edges.append([src, dst])
                # - then get the hidden from parent_graph
                old_edge_id.append(self.edges_matrix[src_word_old, dst_word_old])

            # self circle
            edges.append([src, src])
            old_edge_id.append(self.edges_matrix[src_word_old, src_word_old])

        return edges, old_edge_id

    def seq_to_graph(self, doc_ids: list) -> dgl.DGLGraph():  # function annotation
        if len(doc_ids) > self.max_length:
            doc_ids = doc_ids[:self.max_length]

        local_vocab = set(doc_ids)

        old_to_new = dict(zip(local_vocab, range(len(local_vocab))))

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        local_vocab = torch.tensor(list(local_vocab)).to(device)

        sub_graph = dgl.DGLGraph()  # 定义子图, sub_graph = dgl.DGLGraph()
        sub_graph = sub_graph.to(device)  # 将cpu上的图也放到device上

        sub_graph.add_nodes(len(local_vocab))  # 添加节点
        local_node_hidden = self.node_hidden(local_vocab)

        sub_graph.ndata['h'] = local_node_hidden  # Return a node data view for setting/getting node features

        seq_edges, seq_old_edges_id = self.add_seq_edges(doc_ids, old_to_new)

        edges, old_edge_id = [], []
        # edges = []

        edges.extend(seq_edges)

        old_edge_id.extend(seq_old_edges_id)

        old_edge_id = torch.LongTensor(old_edge_id).to(device)

        srcs, dsts = zip(*edges)
        sub_graph.add_edges(srcs, dsts)
        try:
            seq_edges_w = self.seq_edge_w(old_edge_id)
        except RuntimeError:
            print(old_edge_id)
        sub_graph.edata['w'] = seq_edges_w

        return sub_graph

    def forward(self, doc_ids, is_20ng=None):
        sub_graphs = [self.seq_to_graph(doc) for doc in doc_ids]

        batch_graph = dgl.batch(sub_graphs)

        batch_graph.update_all(
            message_func=dgl.function.src_mul_edge('h', 'w', 'weighted_message'),
            reduce_func=dgl.function.max('weighted_message', 'h')
        )

        h1 = dgl.sum_nodes(batch_graph, feat='h')

        drop1 = self.dropout(h1)
        act1 = self.activation(drop1)

        l = self.Linear(act1)

        return l


# ------------------------------------------------------------------------pmi.py------------------------------------------------------------------------
def cal_PMI(dataset: str, window_size=20):
    helper = DataHelper(dataset=dataset, mode="train")
    content, _ = helper.get_content()
    pair_count_matrix = np.zeros((len(helper.vocab), len(helper.vocab)), dtype=int)
    word_count =np.zeros(len(helper.vocab), dtype=int)
    
    for sentence in content:
        sentence = sentence.split(' ')
        for i, word in enumerate(sentence):
            try:
                word_count[helper.d[word]] += 1
            except KeyError:
                continue
            start_index = max(0, i - window_size)
            end_index = min(len(sentence), i + window_size)
            for j in range(start_index, end_index):
                if i == j:
                    continue
                else:
                    target_word = sentence[j]
                    try:
                        pair_count_matrix[helper.d[word], helper.d[target_word]] += 1
                    except KeyError:
                        continue
        
    total_count = np.sum(word_count)
    word_count = word_count / total_count
    pair_count_matrix = pair_count_matrix / total_count
    
    pmi_matrix = np.zeros((len(helper.vocab), len(helper.vocab)), dtype=float)
    for i in range(len(helper.vocab)):
        for j in range(len(helper.vocab)):
            pmi_matrix[i, j] = np.log(
                pair_count_matrix[i, j] / (word_count[i] * word_count[j]) 
            )
    
    pmi_matrix = np.nan_to_num(pmi_matrix)
    
    pmi_matrix = np.maximum(pmi_matrix, 0.0)

    edges_weights = [0.0]
    count = 1
    edges_mappings = np.zeros((len(helper.vocab), len(helper.vocab)), dtype=int)
    for i in range(len(helper.vocab)):
        for j in range(len(helper.vocab)):
            if pmi_matrix[i, j] != 0:
                edges_weights.append(pmi_matrix[i, j])
                edges_mappings[i, j] = count
                count += 1

    edges_weights = np.array(edges_weights)

    edges_weights = edges_weights.reshape(-1, 1)
    # print(edges_weights.shape)
    edges_weights = torch.Tensor(edges_weights)
    
    return edges_weights, edges_mappings, count


# ------------------------------------------------------------------------train.py------------------------------------------------------------------------
NUM_ITER_EVAL = 100
EARLY_STOP_EPOCH = 25


def edges_mapping(vocab_len, content, ngram):
    count = 1
    mapping = np.zeros(shape=(vocab_len, vocab_len), dtype=np.int32)
    for doc in content:
        for i, src in enumerate(doc):
            for dst_id in range(max(0, i-ngram), min(len(doc), i+ngram+1)):
                dst = doc[dst_id]

                if mapping[src, dst] == 0:
                    mapping[src, dst] = count
                    count += 1

    for word in range(vocab_len):
        mapping[word, word] = count
        count += 1

    return count, mapping


def get_time_dif(start_time):
    """获取已使用时间"""
    end_time = time.time()
    time_dif = end_time - start_time
    return datetime.timedelta(seconds=int(round(time_dif)))


def dev(model, dataset):
    data_helper = DataHelper(dataset, mode='dev')

    total_pred = 0
    correct = 0
    iter = 0
    for content, label, _ in data_helper.batch_iter(batch_size=64, num_epoch=1):
        iter += 1
        model.eval()

        logits = model(content)
        pred = torch.argmax(logits, dim=1)

        correct_pred = torch.sum(pred == label)

        correct += correct_pred
        total_pred += len(content)

    total_pred = float(total_pred)
    correct = correct.float()
    # print(torch.div(correct, total_pred))
    return torch.div(correct, total_pred)


def test(model_name, dataset):
    model = torch.load(os.path.join('.', model_name + '.pkl'))

    data_helper = DataHelper(dataset, mode='test')

    total_pred = 0
    correct = 0
    iter = 0
    for content, label, _ in data_helper.batch_iter(batch_size=64, num_epoch=1):
        iter += 1
        model.eval()

        logits = model(content)
        pred = torch.argmax(logits, dim=1)

        correct_pred = torch.sum(pred == label)

        correct += correct_pred
        total_pred += len(content)

    total_pred = float(total_pred)
    correct = correct.float()
    # print(torch.div(correct, total_pred))
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    x = torch.div(correct, total_pred)
    x = x.to(device)
    return x


def train(ngram, model_name, bar, drop_out, dataset, edges=True, total_epoch=1):
    print('load data helper.')
    data_helper = DataHelper(dataset, mode='train')

    tmp_data = dataset.split('/')
    dataset_name = tmp_data[len(tmp_data)-1]
    if os.path.exists(os.path.join('/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/temp_model', model_name+'.pkl')) and model_name != 'temp_model':
        print('load model from file.')
        model = torch.load(os.path.join('/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/temp_model', model_name+'.pkl'))
    else:
        print('new model.')
        if model_name == 'temp_model':
            model_name = f'/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/model/temp_model_{dataset_name}'
        # edges_num, edges_matrix = edges_mapping(len(data_helper.vocab), data_helper.cont ent, ngram)
        edges_weights, edges_mappings, count = cal_PMI(dataset=dataset)
        
        model = Model(class_num=len(data_helper.labels_str), hidden_size_node=300,
                      vocab=data_helper.vocab, n_gram=ngram, drop_out=drop_out, edges_matrix=edges_mappings, edges_num=count,
                      trainable_edges=edges, pmi=edges_weights)

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        model.to(device)
    loss_func = torch.nn.CrossEntropyLoss()
    optim = torch.optim.Adam(model.parameters(), weight_decay=1e-6)

    iter = 0
    if bar:
        pbar = tqdm.tqdm(total=NUM_ITER_EVAL)
    best_acc = 0.0
    last_best_epoch = 0
    start_time = time.time()
    total_loss = 0.0
    total_correct = 0
    total = 0
    for content, label, epoch in data_helper.batch_iter(batch_size=32, num_epoch=total_epoch):
        improved = ''
        model.train()

        logits = model(content)
        loss = loss_func(logits, label)

        pred = torch.argmax(logits, dim=1)

        correct = torch.sum(pred == label)

        total_correct += correct
        total += len(label)

        total_loss += loss.item()

        optim.zero_grad()
        loss.backward()
        optim.step()

        iter += 1
        if bar:
            pbar.update()
        if iter % NUM_ITER_EVAL == 0:
            if bar:
                pbar.close()

            val_acc = dev(model, dataset=dataset)
            if val_acc > best_acc:
                best_acc = val_acc
                last_best_epoch = epoch
                improved = '*'

                torch.save(model, model_name + '.pkl')

            if epoch - last_best_epoch >= EARLY_STOP_EPOCH:
                print('Early stopping...')
                return model_name
            print(f'Epoch:{epoch}, iter:{iter}, train loss:{total_loss/ NUM_ITER_EVAL :.4f}, train acc: {float(total_correct) / float(total) :.4f}, val acc: {val_acc:.4f},{improved}')

            total_loss = 0.0
            total_correct = 0
            total = 0
            if bar:
                pbar = tqdm.tqdm(total=NUM_ITER_EVAL)

    return model_name


def word_eval():
    print('load model from file.')
    data_helper = DataHelper('r8')
    edges_num, edges_matrix = edges_mapping(len(data_helper.vocab), data_helper.content, 1)
    model = torch.load(os.path.join('word_eval_1.pkl'))

    edges_weights = model.seq_edge_w.weight.to('cpu').detach().numpy()

    core_word = 'billion'
    core_index = data_helper.vocab.index(core_word)

    results = {}
    for i in range(len(data_helper.vocab)):
        word = data_helper.vocab[i]
        n_word = edges_matrix[i, core_index]
        # n_word = edges_matrix[i, i]
        if n_word != 0:
            results[word] = edges_weights[n_word][0]

    sort_results = sorted(results.items(), key=lambda d: d[1])

    print(f'sort_results: {sort_results}')



parser = argparse.ArgumentParser()
parser.add_argument('--ngram', required=False, type=int, default=4, help='ngram number')
parser.add_argument('--name', required=False, type=str, default='temp_model', help='project name')
parser.add_argument('--bar', required=False, type=int, default=0, help='1 denotes show bar, 0 denotes off')
parser.add_argument('--dropout', required=False, type=float, default=0.5, help='dropout rate')
parser.add_argument('--dataset', required=True, type=str, help='dataset')
parser.add_argument('--edges', required=False, type=int, default=1, help='trainable edges')
parser.add_argument('--rand', required=False, type=int, default=7, help='rand_seed')
parser.add_argument('--epoch', required=False, type=int, default=1, help='training epoch')

args = parser.parse_args()

device = torch.device(f'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
if torch.cuda.is_available():
  print(f'device: {device}')
  print(f'name:{torch.cuda.get_device_name(0)}')
  print(f'memory:{torch.cuda.get_device_properties(0).total_memory/1e9}')
print(f'*'*50)

print('ngram: %d' % args.ngram)
print('project_name: %s' % args.name)
print('dataset: %s' % args.dataset)
print('trainable_edges: %s' % args.edges)
print(f'*'*50)
# #
SEED = args.rand
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

if args.bar == 1:
    bar = True
else:
    bar = False

if args.edges == 1:
    edges = True
    print('trainable edges is True')
else:
    edges = False

model = train(args.ngram, args.name, bar, args.dropout, dataset=args.dataset, edges=edges, total_epoch=args.epoch)
print(f'test acc: {test(model, args.dataset).cpu().numpy():.4f}')

Overwriting parsing.py


# run

In [None]:
!python parsing.py --dataset='/content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/data/oh' --ngram=3 --dropout=0.25 --epoch=100

DGL backend not selected or invalid.  Assuming PyTorch for now.
Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)
Using backend: pytorch
device: cuda:0
name:Tesla K80
memory:11.996954624
**************************************************
ngram: 3
project_name: temp_model
dataset: /content/drive/MyDrive/Colab_Notebooks/CODE/TextLevelGNN/data/oh
trainable_edges: 1
**************************************************
trainable edges is True
load data helper.
new model.
  pair_count_matrix[i, j] / (word_count[i] * word_count[j])
  pair_count_matrix[i, j] / (word_count[i] * word_count[j])
