In [None]:
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
from datetime import date, timedelta
from scipy.sparse import vstack
from tqdm.notebook import tqdm

from PDSum_components import *
from PDSum_evaluation import *

In [2]:
# Index of the CUDA device to use when a GPU is available.
GPU_NUM = 0
# Fall back to the CPU when no CUDA device is visible.
device = torch.device(f'cuda:{GPU_NUM}' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # torch.cuda.set_device only accepts CUDA devices; calling it with the CPU
    # fallback raises, so the call is limited to the GPU case.
    torch.cuda.set_device(device) # change allocation of current GPU

## Loading dataset

In [3]:
# WCEP: the article collection and the reference summaries are read as a pair.
df_org, stories = (
    pd.read_json("datasets/WCEP_EMDS_articles.json"),
    pd.read_json("datasets/WCEP_EMDS_reference_summaries.json"),
)

In [23]:
# W2E: binds the same two names (df_org, stories) as the WCEP cell, so whichever
# loading cell runs last decides the dataset used downstream.
# NOTE(review): the `dataset` flag in the hyperparameter cell presumably has to be
# 'Custom' for this data — confirm before running.
df_org, stories = (
    pd.read_json("datasets/W2E_EMDS_articles.json"),
    pd.read_json("datasets/W2E_EMDS_reference_summaries.json"),
)

### Initialization

In [None]:
# Run the initialisation on the full article frame; pass df_org[:100] instead to
# try the pipeline on the first 100 articles only.
# initialize returns the working frame plus the masked tensors, masks and vocabulary.
target_df, masked_tensors, masks, all_vocab = initialize(df_org)

In [8]:
# Mean embedding per article: sum of masked_tensors over axis 1 divided by the
# per-article count of positions where (1 - masks) is non-zero.
mean_embds = torch.div(masked_tensors.sum(1),(1-masks).sum(1).reshape(-1,1)).cpu().detach().numpy()

# All derived columns are written to target_df — the frame returned by initialize
# and the one the simulation loop reads mean_embd / article_TF / sentence_TFs from.
# Writing to df_org only worked while both names referred to the same object and
# broke (length mismatch / missing columns) as soon as target_df was a subset.
target_df['mean_embd'] = list(mean_embds)

# Documents are already tokenised (one token list per sentence), hence the identity
# tokenizer; the sentence lists are concatenated into one token list per article.
tfidf_vectorizer = TfidfVectorizer(ngram_range=(1,2), tokenizer=lambda x: x, lowercase=False, norm=None)
tfidf_vectorizer.fit([sum(k, []) for k in target_df['sentence_tokens']])
# get_feature_names() was removed in scikit-learn 1.2; prefer the replacement when
# it exists and keep the result a plain list as before.
if hasattr(tfidf_vectorizer, 'get_feature_names_out'):
    all_vocab = list(tfidf_vectorizer.get_feature_names_out())
else:
    all_vocab = tfidf_vectorizer.get_feature_names()

# Uni-/bi-gram counts over the fixed vocabulary: one sparse matrix per article
# (rows = sentences) and their row-sum as the article-level term frequencies.
count_vectorizer = CountVectorizer(tokenizer=lambda x: x, ngram_range = (1,2), vocabulary = list(all_vocab), lowercase=False)
target_df['sentence_TFs'] = [count_vectorizer.transform(y) for y in target_df['sentence_tokens'].values]
target_df['article_TF'] = [sum(a) for a in target_df['sentence_TFs'].values]


# Setting Hyperparameters 

In [9]:
# 'Default': summaries are returned at every unique date (e.g., WCEP).
# 'Custom' : summaries are returned at every true summary date (e.g., W2E).
dataset = 'Default'
max_sentences = 1   # number of sentences picked per summary
max_tokens = 40     # not referenced by the cells visible in this notebook

batch = 64          # articles sampled per training iteration
epoch= 5            # training passes per target window
temp = 0.2          # passed to get_loss

# Model(D_in, D_hidden, head, dropout) arguments.
D_in = 1024
D_hidden = 1024
head = 2
dropout = 0
lr = 1e-5           # Adam learning rate
N = 10              # passed to get_cluster_theme
distill_ratio = 0.5 # weight of the accumulated terms vs. the current-window terms

# Place the model on the device selected in the setup cell instead of
# unconditionally calling .cuda(), so the two cells agree on the target device.
model = Model(D_in, D_hidden, head, dropout).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Simulating EMDS

In [10]:
# Output table: one row per query; the simulation loop adds one column per date.
Queries = list(target_df.Query.unique())
tuned_summary = pd.DataFrame(index=Queries)

# Per-date records filled in by the loop.
theme_dic, weights_dic = {}, {}
losses = []

# Per-query state; the accum_* containers persist across dates.
weights, accum_weights = {}, {}
centers, accum_centers = {}, {}
accum_cluster_topN_indices, accum_cluster_topN_scores = {}, {}
prev_summaries = {}

# Window bookkeeping used by the 'Custom' branch of the loop.
to_drop_indices, target_window_indices = [], []
WE2_concurrent_queries_df = pd.DataFrame(columns = ['queries'])

# Simulation loop. For every unique date of target_df (in order of appearance):
#   1. build the target window of articles,
#   2. extract set phrases and phrase-based article weights per query,
#   3. initialise the set prototypes (current and accumulated),
#   4. fine-tune the model against the prototypes,
#   5. score sentences and keep the top `max_sentences` per summarised query.
# Results are written to tuned_summary (one column per date); theme_dic,
# weights_dic, losses and prev_summaries are filled as side products.

# Per-article sentence cap used when scoring (previously a repeated literal 50).
# NOTE(review): presumably matches the sentence axis of the model outputs — confirm.
max_doc_sentences = 50

# The loop variable is named curr_date so it does not shadow datetime.date,
# which is imported at the top of the notebook.
for curr_date in tqdm(target_df.date.unique()):

    ### Get the current context ###
    if dataset == 'Default': #WCEP
        ### Retrieve target window (by each date) ###
        target_window_indices = target_df[target_df.date==curr_date].index
        # Explicit copy: the 'tuned_embd' column written below belongs to the window only.
        target_window = target_df.loc[target_window_indices].copy()
        target_window_queries = list(target_window['Query'].unique())
        summary_basis_queries = target_window_queries
    elif dataset == 'Custom': #W2E
        ### Retrieve target window (by each true summary)###
        # Carry the previous window forward minus the articles of queries that were
        # just summarised, then add the articles of the current date.
        target_window_indices = list(set(target_window_indices) - set(to_drop_indices))
        target_window_indices = target_window_indices + list(target_df[target_df.date==curr_date].index)
        target_window = target_df.loc[target_window_indices].copy()
        target_window_queries = list(target_window['Query'].unique())

        curr_stories = stories[stories.date==curr_date].Query.unique()
        summary_basis_queries = curr_stories
        if len(curr_stories) < 1:
            # No reference summary on this date: keep accumulating articles.
            to_drop_indices = []
            continue
        to_drop_indices = target_window[target_window.Query.isin(curr_stories)].index
        WE2_concurrent_queries_df.loc[curr_date,'queries'] = target_window_queries
        if len(target_window[target_window.Query.isin(summary_basis_queries)]) < 1: continue
    else:
        # Previously an unknown value silently reused a stale (or undefined) window.
        raise ValueError(f"dataset must be 'Default' or 'Custom', got {dataset!r}")

    ### Get set phrases ###
    cluster_topN_indices, cluster_topN_scores, cluster_topN_words = get_cluster_theme(all_vocab, target_window, N)
    theme_dic[curr_date] = (cluster_topN_indices, cluster_topN_scores, cluster_topN_words)
    # Article weight = phrase-score-weighted count of the query's top phrases,
    # normalised to sum to one within the query.
    weights = {}
    for query in target_window.Query.unique():
        weights_raw = np.array(vstack([x[0, cluster_topN_indices[query]] for x in target_window[target_window.Query==query].article_TF.values]).multiply(cluster_topN_scores[query]).sum(1)).squeeze(1)
        weights[query] = weights_raw/np.sum(weights_raw)
    weights_dic[curr_date] = weights


    ### Set previous set phrases ###
    for query in summary_basis_queries:
        if query not in accum_cluster_topN_indices:
            accum_cluster_topN_indices[query] =  cluster_topN_indices[query]
            accum_cluster_topN_scores[query] = cluster_topN_scores[query]
        else:
            accum_cluster_topN_indices[query] = np.append(accum_cluster_topN_indices[query] , cluster_topN_indices[query])
            accum_cluster_topN_scores[query] = np.append(accum_cluster_topN_scores[query] , cluster_topN_scores[query])


    ### Initialize set prototypes ###
    for query in target_window_queries:
        ## Initialize set prototype to mean embedding + set phrases ##
        centers[query] = np.sum(target_window[target_window.Query==query].mean_embd.values * weights[query])

        ## Initialize set prototype to tuned embedding + accum set phrases ##
        if query in summary_basis_queries:
            model.eval()
            query_idices = target_window[target_window['Query']==query].index
            # Inference only: no autograd graph is needed for the embeddings.
            with torch.no_grad():
                outputs = model(masked_tensors[query_idices], masks[query_idices])
            target_window.loc[query_idices, 'tuned_embd'] = pd.Series(list(outputs[0].squeeze(1).cpu().detach().numpy()), index=query_idices)

            accum_weights_raw = vstack(target_window[target_window.Query==query].article_TF.values)[:,accum_cluster_topN_indices[query]].multiply(accum_cluster_topN_scores[query]).sum(1).ravel().tolist()[0]
            accum_weights[query] = accum_weights_raw/np.sum(accum_weights_raw)
            accum_centers[query] = np.sum(target_window[target_window.Query==query].tuned_embd.values * accum_weights[query])
        else:
            accum_centers[query] = centers[query]



    ### Train a model ###
    num_itr = int(len(target_window)/batch)+1
    # Class index of each query = its position in target_window_queries, which is
    # the same order used to stack target_class_embds below.
    query_to_class = {q: i for i, q in enumerate(target_window_queries)}
    for e in range(epoch):
        ## Model training ##
        # Target per class: blend of the current prototype and the accumulated prototype.
        target_class_embds = torch.tensor(np.array([centers[q] * (1-distill_ratio) + distill_ratio *  accum_centers[q] for q in target_window.Query.unique()])).to(device)

        for itr in range(num_itr):
            model.train()
            samples = np.random.choice(target_window.index, batch) #window.index
            class_indices = [query_to_class[q] for q in target_window.loc[samples,'Query']]
            sample_outputs = model(masked_tensors[samples], masks[samples])[0].squeeze(1)
            loss = get_loss(sample_outputs, class_indices, target_class_embds, temp)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Record the last iteration's loss of the epoch as a plain float rather
        # than keeping the tensor (and its device memory) alive in the list.
        losses.append(loss.item())

        ## Center update ##
        model.eval()
        for query in summary_basis_queries:
            query_idices = target_window[target_window['Query']==query].index
            with torch.no_grad():
                outputs = model(masked_tensors[query_idices], masks[query_idices])
            target_window.loc[query_idices, 'tuned_embd'] = pd.Series(list(outputs[0].squeeze(1).cpu().detach().numpy()), index=query_idices)
            accum_centers[query] = np.sum(target_window[target_window.Query==query].tuned_embd.values * accum_weights[query])

    ### Summarize sets ###
    ## Sentence score computation ##
    model.eval()
    tuned_summaries = []
    # Queries that actually received a summary; keeps the final Series aligned
    # when a query is skipped below.
    summarized_queries = []

    for query in summary_basis_queries:
        query_window = target_window[target_window.Query==query]
        if len(query_window) < 1: continue
        with torch.no_grad():
            outputs = model(masked_tensors[query_window.index], masks[query_window.index])
        sentences_ap_weights = outputs[2].cpu().detach().numpy().squeeze(2)
        sentences_tuned_embds = outputs[3].cpu().detach().numpy()

        # Current-window signals: document similarity to the prototype and
        # per-sentence phrase scores normalised over the whole query.
        docs_sims = cosine_similarity([centers[query]], np.array(list(query_window.tuned_embd.values)))[0]
        doc_phrase_scores_raw = [x[:max_doc_sentences, cluster_topN_indices[query]].multiply(cluster_topN_scores[query]) for x in query_window.sentence_TFs.values]
        doc_phrase_scores_sum = vstack(doc_phrase_scores_raw).sum()
        doc_phrase_scores = [np.array(x.sum(1)/doc_phrase_scores_sum).ravel() for x in doc_phrase_scores_raw]

        # Same signals computed against the accumulated prototype and phrases.
        accum_doc_sims = cosine_similarity([accum_centers[query]], np.array(list(query_window.tuned_embd.values)))[0]
        accum_doc_phrase_scores_raw = [x[:max_doc_sentences, accum_cluster_topN_indices[query]].multiply(accum_cluster_topN_scores[query]) for x in query_window.sentence_TFs.values]
        accum_sentence_phrase_scores_sum = vstack(accum_doc_phrase_scores_raw).sum()
        accum_doc_phrase_scores = [np.array(x.sum(1)/accum_sentence_phrase_scores_sum).ravel() for x in accum_doc_phrase_scores_raw]

        summary_scores = []
        for doc_id in range(len(query_window)):
            # Fetch the row once per document instead of once per sentence.
            doc_row = query_window.iloc[doc_id]
            doc_score = np.exp(docs_sims[doc_id]) * (1-distill_ratio) + distill_ratio * np.exp(accum_doc_sims[doc_id])
            for sen_id in range(min(max_doc_sentences, doc_row.sentence_counts)):
                sentence_phrase_score = doc_phrase_scores[doc_id][sen_id] * (1-distill_ratio) + distill_ratio * accum_doc_phrase_scores[doc_id][sen_id]
                composite_score = doc_score  * sentences_ap_weights[doc_id][sen_id]  * sentence_phrase_score
                summary_scores.append((composite_score, doc_row.sentences[sen_id], sentences_tuned_embds[doc_id][sen_id]))

        ## Pick top sentences ##
        summary_scores.sort(reverse=True, key=lambda scored: scored[0])

        # Slicing (instead of popping until max_sentences is reached) also copes
        # with queries that have fewer candidate sentences than max_sentences.
        top_sentences = summary_scores[:max(max_sentences, 0)]
        all_sentences = [sentence.replace("\n"," ") for (_, sentence, _) in top_sentences]
        all_sentences_embds = [tuned_embd for (_, _, tuned_embd) in top_sentences]
        summary = ' '.join(all_sentences)
        tuned_summaries.append(summary)
        summarized_queries.append(query)

        # Keep the embeddings of every sentence selected so far for this query.
        prev_summaries[query] = prev_summaries.get(query, []) + all_sentences_embds

    # One column per date (YYYY-MM-DD); queries without a summary stay NaN.
    tuned_summary[str(curr_date)[:10]] = pd.Series(tuned_summaries, index = summarized_queries)

  0%|          | 0/235 [00:00<?, ?it/s]

# Evaluation

In [11]:
# Example output summaries: first three date columns, first three queries.
first_three_dates = tuned_summary.columns[:3]
tuned_summary[first_three_dates].head(3)

Unnamed: 0,2019-01-01,2019-01-02,2019-01-03
68967,• Watch New Horizons probe ring in the New Ye...,Scientists are already learning more about Ult...,Around 10 hours after reaching the icy world o...
68982,,Robow's arrest has been seen as a high-profile...,Somalia’s U.N. Ambassador Abukar Dahir Osman...
68968,,Cook said Apple has lowered its revenue guidan...,"Apple stocks have tumbled, after the company r..."


In [12]:
# Tokenise the generated summaries and every reference summary so both sides of
# the evaluation work on token lists.
t_summary = get_tokenized_summary(tuned_summary)
t_reference_summaries = [get_tokens(reference_text) for reference_text in stories['summary']]
stories['tokenized_summary'] = t_reference_summaries

In [13]:
output_relevance_df = get_daily_score_df(t_summary, stories, 'ROUGE', 'L', WE2_concurrent_queries_df)
# Pool every non-missing score across all columns, average along axis 0 and
# report the result as a percentage rounded to two decimals.
pooled_relevance = [
    score
    for column in output_relevance_df
    for score in output_relevance_df[column].dropna().values
]
relevance_score = np.round(np.mean(pooled_relevance, 0)*100, 2)

In [14]:
output_novel_df, output_overlap_df, novel_ratio_df = get_novel_overlap_score_df(t_summary, stories, 'ROUGE', 'L')
# Pool every non-missing novelty score across all columns, average along axis 0
# and report the result as a percentage rounded to two decimals.
pooled_novelty = [
    score
    for column in output_novel_df
    for score in output_novel_df[column].dropna().values
]
novelty_score = np.round(np.mean(pooled_novelty, 0)*100, 2)

In [15]:
output_contrast_df = get_daily_contrast_df(t_summary, stories, 'ROUGE', 'L', WE2_concurrent_queries_df)
# Every non-missing cell holds a list of scores: flatten cell by cell and column
# by column, then average along axis 0 and convert to a rounded percentage.
pooled_contrast = [
    score
    for column in output_contrast_df.columns
    for cell in output_contrast_df[column].dropna().values
    for score in cell
]
mean_contrast_pct = np.round(np.mean(pooled_contrast, 0)*100, 2)
# Contrast is reported as the complement to 100.
contrast_score = [100-x for x in mean_contrast_pct]
# Distinctiveness: contrast relative to the complement of the relevance score.
contrast_to_relevance = np.divide(contrast_score, (100-relevance_score))
distinctive_score = [np.round(x,2) for x in contrast_to_relevance]

In [16]:
# Printed in the order relevance, novelty, distinctiveness; each entry holds
# three values: Precision, Recall, F1 Score.
final_scores = (relevance_score, novelty_score, distinctive_score)
print(*final_scores)

[22.4  23.41 21.17] [16.15 12.43 12.57] [1.29, 1.31, 1.27]
