In [1]:
import tensorflow_datasets as tfds
import tensorflow as tf
import os
import struct
import hashlib
import os
import re
import json
import string
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizer, TFBertModel, BertConfig
import tensorflow_hub as hub
import tokenization
from sklearn.manifold import TSNE
from sentence_transformers import SentenceTransformer
from rouge_score import rouge_scorer
from sklearn.cluster import KMeans
import pickle
from sknetwork.ranking import PageRank, HITS
from nltk.stem.snowball import EnglishStemmer
from nltk.tokenize import sent_tokenize, word_tokenize
import torch
import sys
import networkx as nx

In [2]:
def tensor_to_string(x):
    return x.numpy().decode('UTF-8')

In [3]:
def get_sent_list(text,stem=None):
    sents = sent_tokenize(text)
    if stem == "None":
        return sents
    if stem == "EnglishStemmer":
        stemmer = EnglishStemmer()
    ans = []
    for sent in sents:
        words = word_tokenize(sent)
        word_stem = [stemmer.stem(w) for w in words]
        ans.append(str(" ".join(word_stem)))
    return ans

In [4]:
## Model features include an encode function -> takes a list of sentences. Returns a list of embeddings (all same dim)
transformers = ["multi-qa-mpnet-base-dot-v1"]

models = dict()
device = torch.device("cpu")
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

for trans in transformers:
    models[trans] = SentenceTransformer(trans,cache_folder='/mnt/disks/disk-1/data/models')
    models[trans]._target_device = device

In [5]:
## V = list of embeddings. k = target size of summary
## Returns a list of sentence indices

def generate_summary(V, k):
    if k >= len(V):
        return list(range(len(V)))
    k -= 1
    centers = []
    cities = list(range(len(V)))
    centers.append(0)
    cities.remove(0)
    while k!= 0:
        city_dict = {}
        for cty in cities:
            min_dist = float("inf")
            for c in centers:
                min_dist = min(min_dist,np.linalg.norm(V[cty] - V[c]))
            city_dict[cty] = min_dist
        new_center = max(city_dict, key = lambda i: city_dict[i])
        centers.append(new_center)
        cities.remove(new_center)
        k -= 1
    return centers

In [12]:
## Pagerank version

def sim(a, b):
    return ((np.dot(a, b) / np.sqrt(np.dot(a, a) * np.dot(b, b))) + 1) / 2

def generate_summary(V,k):
    if k >= len(V):
        return list(range(len(V)))
    n = V.shape[0]
    adj = np.zeros((n, n))
    for i in range(n):
        adj[i][i] = sim(V[i],V[i])
        for j in range(i+1,n):
            s = sim(V[i], V[j])
            adj[i][j] = s
            adj[j][i] = s
            
    g = nx.complete_graph(len(lst))
    adj = np.zeros((len(lst), len(lst)))
    for i in range(len(lst)):
        for j in range(i, len(lst)):
            s = sim(lst[i], lst[j])
            g.edges[i,j]['weight'] = s

    pr_pair = sorted(
        nx.pagerank(g).items(),
        key=lambda x: x[1],
        reverse=True
    )[:k]

    return np.array([k for k, _ in pr_pair])

In [None]:
## HITS version

def sim(a, b):
    return ((np.dot(a, b) / np.sqrt(np.dot(a, a) * np.dot(b, b))) + 1) / 2

def generate_summary_hits(V,k,threshold=0.5):
    if k >= len(V):
        return list(range(len(V)))
    n = V.shape[0]
    adj = np.zeros((n, n))
    for i in range(n):
        adj[i][i] = sim(V[i],V[i])
        for j in range(i+1,n):
            s = 0 if sim(V[i], V[j]) < threshold else 1
            adj[i][j] = s
            adj[j][i] = s
            
    pr = HITS()
    scores = pr.fit_transform(adj)
    ind = np.argpartition(scores, -k)[-k:]
    return np.sort(ind)

In [13]:
def uml_summary(x,index,kind="cnn_dailymail",model="multi-qa-mpnet-base-dot-v1"):
    if kind == "cnn_dailymail":
        key1 = 'article'
        key2 = 'highlights'
    elif kind == "scientific_papers/arxiv" or kind == "scientific_papers/pubmed":
        key1 = 'article'
        key2 = 'abstract'
        
    stemmer = "EnglishStemmer"
    text = tensor_to_string(x[key1])
    text = get_sent_list(text,stemmer)
    summary = tensor_to_string(x[key2])
    summary = get_sent_list(summary,stemmer)
    
    filename = str(index) + "_" + model + "_" + stemmer + ".pickle"
    folderpath = os.path.join("/mnt/disks/disk-1/data/pickle",kind)
    filepath = os.path.join("/mnt/disks/disk-1/data/pickle",kind,filename)
    
    if os.path.exists(filepath):
        with open(filepath, 'rb') as handle:
            text_emb = pickle.load(handle)
    else:
        text_emb = models[model].encode(text)
        with open(filepath, 'wb') as handle:
            pickle.dump(text_emb, handle, protocol=pickle.HIGHEST_PROTOCOL)

    gen_sum = [text[x] for x in generate_summary(text_emb,len(summary))]
    scores = scorer.score(" ".join(summary)," ".join(gen_sum))
    return scores["rouge1"].fmeasure, scores["rouge2"].fmeasure, scores["rougeL"].fmeasure

In [14]:
%%time

# datasets = ["cnn_dailymail","scientific_papers/arxiv","scientific_papers/pubmed"]
datasets = ["cnn_dailymail"]
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2','rougeL'], use_stemmer=True)
for ds in datasets:
    for trans in transformers:
        train, val, test = tfds.load(name=ds, 
                              split=["train", "validation", "test"], 
                              data_dir="/mnt/disks/disk-1/data")
        
#         model = SentenceTransformer(trans,cache_folder='/mnt/disks/disk-1/data/models')
#         model._target_device = device
        r1 = []
        r2 = []
        rl = []
        index = 0
        for x in list(test):
            r1_val,r2_val,rl_val = uml_summary(x,index,kind=ds,model=trans)
            index += 1
            r1.append(r1_val)
            r2.append(r2_val)
            rl.append(rl_val)
            if index % 1000 == 0 and index > 0:
                print(index)
        print(ds,trans)
        print(index)
        print("Rouge 1 : ",np.round(np.mean(np.asarray(r1))*100,2))
        print("Rouge 2 : ",np.round(np.mean(np.asarray(r2))*100,2))
        print("Rouge L : ",np.round(np.mean(np.asarray(rl))*100,2))
        print("___")

INFO:absl:Load dataset info from /mnt/disks/disk-1/data/cnn_dailymail/3.1.0
INFO:absl:Reusing dataset cnn_dailymail (/mnt/disks/disk-1/data/cnn_dailymail/3.1.0)
INFO:absl:Constructing tf.data.Dataset cnn_dailymail for split ['train', 'validation', 'test'], from /mnt/disks/disk-1/data/cnn_dailymail/3.1.0


1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
cnn_dailymail multi-qa-mpnet-base-dot-v1
11490
Rouge 1 :  33.01
Rouge 2 :  12.01
Rouge L :  20.87
___
CPU times: user 4min 52s, sys: 1.11 s, total: 4min 53s
Wall time: 4min 51s
