## Multi-document summarization

In [1]:
pip install transformers==4.18.0

In [9]:
import pandas as pd
import json
import torch
# Device used for every ``.to(DEVICE)`` call in this notebook. ``-1`` is only the
# Hugging Face ``pipeline`` convention for "CPU"; torch's ``Tensor.to`` /
# ``Module.to`` reject a negative device index, so use an explicit torch.device.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [31]:
# Load the topic index; its 'documents' list becomes the main DataFrame below.
# Decode explicitly as UTF-8 so the result does not depend on the platform's
# default locale encoding.
with open(
    '../input/ztm-topics/ztm_with_topics_50_sentence-transformers_allenai-specter_topic_index.json',
    encoding='utf-8',
) as f:
    index = json.load(f)

In [32]:
# One row per entry of index['documents']; later cells rely on its
# 'topic_number', 'keywords', 'text' and 'doi' columns.
df = pd.DataFrame(index['documents'])

In [33]:
# Preview the first rows of the document table.
df.head()

In [53]:
# Per-topic non-null counts of every column, sorted so the topics with the
# most non-null 'doi' values come first.
df.groupby('topic_number').count().sort_values('doi', ascending=False)

In [40]:
# Distinct topic ids, in order of first appearance in df.
topics = df['topic_number'].unique()

In [50]:
# Number of distinct topics.
len(topics)

In [43]:
# Show each topic id next to the keywords stored on its first document.
for topic in topics:
    first_keywords = df.loc[df['topic_number'] == topic, 'keywords'].iloc[0]
    print(topic, first_keywords)

In [22]:
def process_document(documents, tokenizer, docsep_token_id, pad_token_id, device=DEVICE,
                     max_input_length=4096):
    """Tokenize "|||||"-joined document clusters and pad them into one batch.

    Each element of ``documents`` is one example whose source documents are
    joined with ``"|||||"``. Every source document is whitespace-normalised,
    tokenized, truncated to an equal share of ``max_input_length`` and followed
    by ``docsep_token_id``. The whole sequence is wrapped in the tokenizer's
    BOS/EOS ids.

    Args:
        documents: iterable of ``"|||||"``-joined strings, one per example.
            A trailing separator is optional, and empty segments are ignored.
        tokenizer: Hugging Face tokenizer. Its ``encode`` is assumed to add
            exactly one leading and one trailing special token; the ``[1:-1]``
            slice below strips them.
        docsep_token_id: id appended after every source document.
        pad_token_id: id used to right-pad shorter examples.
        device: unused. The tensor is built with torch's default device and
            moving it is left to the caller. Kept for backward compatibility.
        max_input_length: maximum tokens per example, including BOS/EOS and
            separators.

    Returns:
        A ``LongTensor`` of shape ``(len(documents), longest_example)``.
    """
    input_ids_all = []
    for data in documents:
        # Normalise whitespace (str.split() also handles newlines) and drop empty
        # segments. Blindly discarding the last segment assumed a trailing
        # "|||||". Without one, the last real document was lost, and a
        # single-document cluster produced no content at all.
        all_docs = [" ".join(doc.split()) for doc in data.split("|||||")]
        all_docs = [doc for doc in all_docs if doc]

        input_ids = []
        if all_docs:
            # Per-document budget = content tokens + one <doc-sep>. BOS/EOS are
            # reserved up front so the total never exceeds max_input_length.
            # A plain max_input_length // n split overflowed by one token when
            # n == 1.
            per_doc = (max_input_length - 2) // len(all_docs)
            # encode() adds two special tokens that are sliced off below,
            # leaving per_doc - 1 content tokens per document.
            encode_max_length = per_doc + 1
            for doc in all_docs:
                input_ids.extend(
                    tokenizer.encode(
                        doc,
                        truncation=True,
                        max_length=encode_max_length,
                    )[1:-1]
                )
                input_ids.append(docsep_token_id)

        input_ids = [tokenizer.bos_token_id] + input_ids + [tokenizer.eos_token_id]
        input_ids_all.append(torch.tensor(input_ids))

    return torch.nn.utils.rnn.pad_sequence(
        input_ids_all, batch_first=True, padding_value=pad_token_id
    )


def batch_process(batch, model, tokenizer, docsep_token_id, pad_token_id, device=DEVICE,
                  max_length=1024, num_beams=5):
    """Generate one summary per multi-document example in ``batch``.

    Args:
        batch: mapping whose ``'document'`` entry is a list of ``"|||||"``-joined
            document clusters. They are tokenized by ``process_document``.
        model: seq2seq model whose ``generate`` accepts ``global_attention_mask``
            (e.g. an LED model).
        tokenizer: tokenizer matching ``model``.
        docsep_token_id: id of the document separator. Separators and the first
            token receive global attention.
        pad_token_id: padding id. Padded positions are excluded via
            ``attention_mask``.
        device: device the inputs are moved to; should match ``model``.
        max_length: maximum length of each generated summary.
        num_beams: beam width used by ``generate``.

    Returns:
        ``{'generated_summaries': [str, ...]}`` in batch order.
    """
    input_ids = process_document(batch['document'], tokenizer, docsep_token_id, pad_token_id)
    input_ids = input_ids.to(device)

    # Without an explicit mask, padding added to shorter examples in a
    # multi-example batch would be attended to like real tokens.
    attention_mask = (input_ids != pad_token_id).long()

    # Global attention on the first (<s>) token and on every document separator.
    global_attention_mask = torch.zeros_like(input_ids)
    global_attention_mask[:, 0] = 1
    global_attention_mask[input_ids == docsep_token_id] = 1

    generated_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        global_attention_mask=global_attention_mask,
        use_cache=True,
        max_length=max_length,
        num_beams=num_beams,
    )
    generated_str = tokenizer.batch_decode(
        generated_ids.tolist(), skip_special_tokens=True
    )
    return {'generated_summaries': generated_str}

In [54]:
from transformers import pipeline, AutoModel, AutoTokenizer, LEDForConditionalGeneration


import warnings

# Multi-document summarizers to compare. Only the 'model' key is read below;
# 'type' and 'tokenizer' are descriptive metadata.
mds_models = [
    {"model": "allenai/led-base-16384-ms2", "type": "multi-document", "tokenizer": None },
    {"model": "allenai/PRIMERA-multixscience", "type": "multi-document", "tokenizer": None },
    {"model": "allenai/led-base-16384-multi_lexsum-source-tiny", "type": "multi-document", "tokenizer": None },
    {"model": "allenai/led-base-16384-multi_lexsum-source-long", "type": "multi-document", "tokenizer": None },
]

topic_summaries = []
for model in mds_models:
    tok = AutoTokenizer.from_pretrained(model['model'])
    mdl = LEDForConditionalGeneration.from_pretrained(model['model'])
    mdl.to(DEVICE)
    # No gradient_checkpointing_enable() here: it only saves memory during
    # backprop and has no effect when generating with a model in eval mode.
    pad_token_id = tok.pad_token_id
    docsep_token_id = tok.convert_tokens_to_ids("<doc-sep>")
    if docsep_token_id == tok.unk_token_id:
        # NOTE(review): for checkpoints without a <doc-sep> token this resolves
        # to <unk>, so document boundaries are marked with <unk> (and receive
        # global attention). Confirm this is intended for non-PRIMERA models.
        warnings.warn(
            f"{model['model']}: '<doc-sep>' is not in the vocabulary; "
            f"using unk id {docsep_token_id} as separator"
        )
    for topic in topics:
        topic_rows = df[df['topic_number'] == topic]
        # Skip missing texts (str.join fails on NaN). Terminate with the
        # separator to match the "doc|||||doc|||||" format process_document was
        # written for; otherwise the last document can be mistaken for the
        # trailing empty segment and dropped.
        cluster = topic_rows['text'].dropna().astype(str)
        joined = "|||||".join(cluster) + "|||||"
        out = batch_process({'document': [joined]}, mdl, tok, docsep_token_id, pad_token_id)

        topic_summaries.append({
            "topic_number": topic,
            "topic_keywords": topic_rows['keywords'].iloc[0],
            "model": model['model'],
            "summary": out['generated_summaries'][0],
        })
    # Release the model before loading the next one.
    del mdl
    del tok
    torch.cuda.empty_cache()

In [56]:
from transformers import AutoTokenizer, AutoModel, T5ForConditionalGeneration


# Query-generation model used to phrase each topic as a few search questions.
# Only the 'model' key is read below; 'type' and 'tokenizer' are metadata.
models = [
    {"model": "BeIR/query-gen-msmarco-t5-large-v1", "type": "multi-document", "tokenizer": None },
]

# Cap on encoder input tokens. A topic's concatenated texts can be arbitrarily
# long, and feeding them whole makes memory use grow with topic size. Only the
# first QUERY_GEN_MAX_INPUT_TOKENS tokens of each topic are used.
QUERY_GEN_MAX_INPUT_TOKENS = 512

topic_queries = []
for model in models:
    tokenizer = AutoTokenizer.from_pretrained(model['model'])
    mdel = T5ForConditionalGeneration.from_pretrained(model['model'])
    mdel.to(DEVICE)

    for topic in topics:
        topic_rows = df[df['topic_number'] == topic]
        # Join with a space so the last word of one text is not fused with the
        # first word of the next. Missing texts are skipped.
        para = " ".join(topic_rows['text'].dropna().astype(str))
        input_ids = tokenizer.encode(
            para,
            return_tensors='pt',
            truncation=True,
            max_length=QUERY_GEN_MAX_INPUT_TOKENS,
        ).to(DEVICE)
        # Nucleus sampling: the questions differ between runs unless torch's
        # RNG is seeded beforehand.
        outputs = mdel.generate(
            input_ids=input_ids,
            max_length=256,
            do_sample=True,
            top_p=0.95,
            num_return_sequences=3)

        queries = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        topic_queries.append({
            "topic_number": topic,
            "topic_keywords": topic_rows['keywords'].iloc[0],
            "model": model['model'],
            "questions": queries
        })
    # Free the model's (possibly GPU) memory, as the summarization cell does.
    del mdel
    torch.cuda.empty_cache()
        

In [57]:
# One row per (model, topic) with its topic_keywords and list of generated questions.
df_queries = pd.DataFrame(topic_queries)

In [58]:
# One row per (model, topic) with its topic_keywords and generated summary.
df_summaries = pd.DataFrame(topic_summaries)

In [64]:
# Number of documents per topic that have a non-null 'doi' (groupby().count()
# counts non-null values), largest topics first.
# NOTE(review): the 'citations' column name is a misnomer; these are document
# counts, not citation counts. Consider renaming if nothing else reads it.
top_topics = df.groupby('topic_number').count().sort_values('doi', ascending=False)[[ 'doi']].rename(columns={'doi': 'citations'})
top_topics.head(5)

In [72]:
# Preview the collected summaries.
df_summaries.head()

In [78]:
# Print every model's summary for each of the five largest topics.
for topic in top_topics.head(5).reset_index().topic_number.unique():
    summary_rows = df_summaries[df_summaries['topic_number'] == topic]
    keywords = summary_rows.iloc[0]['topic_keywords']
    print(f"Topic {topic}: {keywords}")
    for _, row in summary_rows.iterrows():
        print(row['model'])
        print(row['summary'])
        print('\n')

In [80]:

# Print the generated questions for each of the five largest topics.
for topic in top_topics.head(5).reset_index().topic_number.unique():
    query_rows = df_queries[df_queries['topic_number'] == topic]
    # Take the keywords from the queries table itself, so this cell does not
    # depend on the summaries table (or fail for topics missing from it).
    print(f"Topic {topic}: {query_rows.iloc[0]['topic_keywords']}")
    for i, row in query_rows.iterrows():
        for question in row['questions']:
            # Append '?' only when the generated text does not already end with one.
            print(question if question.endswith('?') else question + '?')
        print('\n')