In [1]:
import re
import os
from nltk.corpus import stopwords
import glob
import copy
import random
import time
import json
import pickle
import nltk
import collections
from collections import Counter
from itertools import combinations
import numpy as np
from random import shuffle
import torch
import argparse
import time
from transformers import BertTokenizer, BertModel
import networkx as nx
import pickle as pkl
import dgl 
from module.vocabulary import Vocab

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data_dir = '../cnndm'
cache_dir = '../cache/CNNDM'
DATA_FILE = os.path.join(data_dir, "index_to_file_mapping_train.json")
VALID_FILE = os.path.join(data_dir, "index_to_file_mapping_val.json")
VOCAL_FILE = os.path.join(cache_dir, "vocab")
FILTER_WORD = os.path.join(cache_dir, "filter_word.txt")
train_w2s_path = os.path.join(cache_dir, "index_to_file_mapping_train.json")
val_w2s_path = os.path.join(cache_dir, "index_to_file_mapping_val.json")
    
# defaults
vocab_size = 50000
doc_max_timesteps = 50
sent_max_len = 100

# vocab
vocab = Vocab(VOCAL_FILE, vocab_size)

# filterwords
FILTERWORD = stopwords.words('english')
punctuations = [',', '.', ':', ';', '?', '(', ')', '[', ']', '&', '!', '*', '@', '#', '$', '%', '\'\'', '\'', '`', '``',
                '-', '--', '|', '\/']
FILTERWORD.extend(punctuations)

In [3]:
# utils
def readJson(fname):
    data = []
    with open(fname, encoding="utf-8") as f:
        for line in f:
            data.append(json.loads(line))
    return data


def readText(fname):
    data = []
    with open(fname, encoding="utf-8") as f:
        for line in f:
            data.append(line.strip())
    return data

In [4]:
class Example(object):
    def __init__(self, article_sents, abstract_sents, vocab, sent_max_len, label):
        self.sent_max_len = sent_max_len
        self.enc_sent_len = []
        self.enc_sent_input = []
        self.enc_sent_input_pad = []

        # Store the original strings
        self.original_article_sents = article_sents
        self.original_abstract = "\n".join(abstract_sents)

        # Process the article
        if isinstance(article_sents, list) and isinstance(article_sents[0], list):  # multi document
            self.original_article_sents = []
            for doc in article_sents:
                self.original_article_sents.extend(doc)
        for sent in self.original_article_sents:
            article_words = sent.split()
            self.enc_sent_len.append(len(article_words))  # store the length before padding
            self.enc_sent_input.append([vocab.word2id(w.lower()) for w in article_words])  # list of word ids; OOVs are represented by the id for UNK token
        self._pad_encoder_input(vocab.word2id('[PAD]'))

        # Store the label
        self.labels = np.array([1 if i in label else 0 for i in range(len(self.original_article_sents))])
        
    def _pad_encoder_input(self, pad_id):
        """
        :param pad_id: int; token pad id
        :return: 
        """
        max_len = self.sent_max_len
        for i in range(len(self.enc_sent_input)):
            article_words = self.enc_sent_input[i].copy()
            if len(article_words) > max_len:
                article_words = article_words[:max_len]
            if len(article_words) < max_len:
                article_words.extend([pad_id] * (max_len - len(article_words)))
            self.enc_sent_input_pad.append(article_words)

In [5]:
class ExampleSet(torch.utils.data.Dataset):
    def __init__(self, data_path, vocab, doc_max_timesteps, sent_max_len, filter_word_path, w2s_path, num_topics):

        self.vocab = vocab
        self.sent_max_len = sent_max_len
        self.doc_max_timesteps = doc_max_timesteps
        self.num_topics = num_topics
        with open(data_path, "r", encoding="utf-8") as f:
            self.example_list = json.load(f)
        self.size = len(self.example_list)

        tfidf_w = readText(filter_word_path)
        self.filterwords = FILTERWORD
        self.filterids = [vocab.word2id(w.lower()) for w in FILTERWORD]
        self.filterids.append(vocab.word2id("[PAD]"))   # keep "[UNK]" but remove "[PAD]"
        lowtfidf_num = 0
        for w in tfidf_w:
            if vocab.word2id(w) != vocab.word2id('[UNK]'):
                self.filterwords.append(w)
                self.filterids.append(vocab.word2id(w))
                lowtfidf_num += 1
            if lowtfidf_num > 5000:
                break
        self.filterids = list(set(self.filterids))
        self.filterwords = list(set(self.filterwords))
        with open(w2s_path, "r") as f:
            self.w2s_tfidf = json.load(f)  
            
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')      
            
    def get_example(self, index):
        file_name, new_index  = self.example_list[str(index)]
        e = readJson(file_name)
        e = e[new_index]
        e["summary"] = e.setdefault("summary", [])
        self.text = e['text']
        example = Example(e["text"], e["summary"], self.vocab, self.sent_max_len, e["label"])
        return example
    
    def get_bow_rep(self, input_pad):
        vocab_len = self.vocab.size()
        num_sents = input_pad.shape[0]
        bow_rep = np.zeros((num_sents, vocab_len))
        
        for i, sent in enumerate(input_pad):
            for word in sent:
                if word not in self.filterids:
                    bow_rep[i][word]+=1
        row_norms = np.linalg.norm(bow_rep, axis=1, keepdims=True)
        return bow_rep/row_norms

    def get_bert_tokenizer(self):
        inputs = self.tokenizer(self.text, padding=True, truncation=True, return_tensors='pt')
        return inputs
    
    def AddTopicNode(self, G, num_topics):
        wid2nid = {}
        nid2wid = {}
        nid = 0
        for k in range(num_topics):
            node_id = k 
            wid2nid[node_id] = nid
            nid2wid[nid] = node_id
            nid += 1

        G.add_nodes(nid)
        G.ndata["unit"] = torch.zeros(nid)
        G.ndata["dtype"] = torch.zeros(nid)
        G.ndata['id']= torch.LongTensor(list(nid2wid.values()))
        return wid2nid, nid2wid 
    
    def create_graph(self, input_pad, bow_rep, inputs, labels, num_topics):
        G = dgl.graph(([], []))
        _, _ = self.AddTopicNode(G, num_topics)
        N = len(input_pad)
        G.add_nodes(N)
        G.ndata["unit"][num_topics:] = torch.ones(N)
        G.ndata["dtype"][num_topics:] = torch.ones(N)
        sentids = [i+num_topics for i in range(N)]
        G.nodes[sentids].data['bert_input_ids'] = inputs['input_ids']
        G.nodes[sentids].data['bert_attention_mask'] = inputs['attention_mask']
        G.nodes[sentids].data['bert_token_type_ids'] = inputs['token_type_ids']
        G.nodes[sentids].data['bow'] = bow_rep
        G.nodes[sentids].data['label'] = labels
        for i in range(N): 
            G.nodes[i+num_topics].data['id'] = torch.LongTensor([i])
            
        for i in range(num_topics):
            for j in range(N):
                G.add_edge(i, j+num_topics, data={'tfidfembed': torch.tensor(1.0), 'dtype': torch.tensor(0.0)})
                G.add_edge(j+num_topics, i, data={'tfidfembed': torch.tensor(1.0), 'dtype': torch.tensor(0.0)})
        return G
    
    def checker(self, index):
        item = self.get_example(index)
        input_pad = np.array(item.enc_sent_input_pad[:self.doc_max_timesteps])
        bow_rep = torch.tensor(self.get_bow_rep(input_pad), dtype=torch.float32)
        labels = torch.tensor(item.labels, dtype=torch.int32)
        inputs = self.get_bert_tokenizer()
        G = self.create_graph(input_pad, bow_rep, inputs, labels, self.num_topics)
        return G
        
    def __getitem__(self, index):
        item = self.get_example(index)
        input_pad = np.array(item.enc_sent_input_pad[:self.doc_max_timesteps])
        bow_rep = torch.tensor(self.get_bow_rep(input_pad), dtype=torch.float32)
        labels = torch.tensor(item.labels, dtype=torch.int32)
        inputs = self.get_bert_tokenizer()
        G = self.create_graph(input_pad, bow_rep, inputs, labels, self.num_topics)
        return G
    
    def __len__(self):
        return self.size

In [6]:
dataset = ExampleSet(data_path=DATA_FILE, vocab=vocab, doc_max_timesteps=doc_max_timesteps, sent_max_len=sent_max_len, filter_word_path=FILTER_WORD, w2s_path=train_w2s_path, num_topics=5)
G = dataset.checker(0)
G

Graph(num_nodes=21, num_edges=0,
      ndata_schemes={'unit': Scheme(shape=(), dtype=torch.float32), 'dtype': Scheme(shape=(), dtype=torch.float32), 'id': Scheme(shape=(), dtype=torch.int64), 'bert_input_ids': Scheme(shape=(36,), dtype=torch.int64), 'bert_attention_mask': Scheme(shape=(36,), dtype=torch.int64), 'bert_token_type_ids': Scheme(shape=(36,), dtype=torch.int64), 'bow': Scheme(shape=(50000,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int32)}
      edata_schemes={})

In [8]:
with open('sample_graph.pkl', 'wb') as f:
    pkl.dump(G, f)

In [None]:
    G.add_edges([0, 1, 2, 3, 4], [5, 6, 7, 8, 9])