In [14]:
from transformers import (
    AutoTokenizer,
    LEDForConditionalGeneration,
)
from transformers import LEDConfig
import sys
sys.path.append('/n/data1/hms/dbmi/zitnik/lab/users/vau974/apcomp/PRIMER/script')
from datasets import load_dataset, load_metric
import torch

import json
import gzip
import pandas as pd
import numpy as np
from tqdm import tqdm

In [2]:
RESTS = np.array(['American restaurant', 'Angler fish restaurant',
       'Armenian restaurant', 'Asian fusion restaurant',
       'Asian restaurant', 'Australian restaurant', 'Austrian restaurant',
       'Barbecue restaurant', 'Breakfast restaurant', 'Brunch restaurant',
       'Buffet restaurant', 'Burrito restaurant',
       'Cheesesteak restaurant', 'Chicken restaurant',
       'Chicken wings restaurant', 'Chinese noodle restaurant',
       'Chinese restaurant', 'Chophouse restaurant',
       'Continental restaurant', 'Delivery Chinese restaurant',
       'Delivery Restaurant', 'Dessert restaurant',
       'Down home cooking restaurant', 'European restaurant',
       'Family restaurant', 'Fast food restaurant', 'Filipino restaurant',
       'Fine dining restaurant', 'Fish & chips restaurant',
       'German restaurant', 'Gluten-free restaurant', 'Greek restaurant',
       'Hamburger restaurant', 'Hawaiian restaurant',
       'Health food restaurant', 'Hoagie restaurant',
       'Hot dog restaurant', 'Indian restaurant', 'Irish restaurant',
       'Israeli restaurant', 'Italian restaurant', 'Japanese restaurant',
       'Korean restaurant', 'Latin American restaurant',
       'Lebanese restaurant', 'Lunch restaurant', 'Meat dish restaurant',
       'Mediterranean restaurant', 'Mexican restaurant',
       'Mexican torta restaurant', 'Middle Eastern restaurant',
       'Mongolian barbecue restaurant', 'New American restaurant',
       'Organic restaurant', 'Pan-Asian restaurant',
       'Peruvian restaurant', 'Pho restaurant', 'Pizza restaurant',
       'Ramen restaurant', 'Restaurant', 'Restaurant or cafe',
       'Restaurant supply store', 'Rice restaurant', 'Seafood restaurant',
       'Small plates restaurant', 'Soul food restaurant',
       'Soup restaurant', 'Southeast Asian restaurant',
       'Southern restaurant (US)', 'Southwestern restaurant (US)',
       'Spanish restaurant', 'Sushi restaurant', 'Taco restaurant',
       'Taiwanese restaurant', 'Takeout Restaurant', 'Takeout restaurant',
       'Tex-Mex restaurant', 'Thai restaurant',
       'Traditional American restaurant', 'Traditional restaurant',
       'Vegan restaurant', 'Vegetarian restaurant',
       'Venezuelan restaurant', 'Vietnamese restaurant',
       'Western restaurant'], dtype='<U31')

First, we load the **Multi-news** dataset from huggingface dataset hub

In [3]:
dataset=load_dataset('multi_news')

Then we load the fine-tuned PRIMERA model, please download [it](https://storage.googleapis.com/primer_summ/PRIMER_multinews.tar.gz) to your local computer.

In [4]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

TOKENIZER = AutoTokenizer.from_pretrained('allenai/PRIMERA')

config=LEDConfig.from_pretrained('allenai/PRIMERA')

MODEL = LEDForConditionalGeneration.from_pretrained('allenai/PRIMERA').to(device)

MODEL.gradient_checkpointing_enable()
PAD_TOKEN_ID = TOKENIZER.pad_token_id
DOCSEP_TOKEN_ID = TOKENIZER.convert_tokens_to_ids("<doc-sep>")

In [5]:
def parse(path):
    g = open(path, 'r')
    for l in g:
        yield json.loads(l)

def make_metadata_df(fl):
    parser = parse(fl)
    rest_records = []
    for record in parser:
        if record['category'] != None:
            if not set(record['category']).isdisjoint(RESTS):
                rest_records.append([record['name'],
                                     record['gmap_id'],
                                     record['address'],
                                     record['avg_rating'],
                                     record['relative_results'],
                                     record['num_of_reviews']])
    
    df = pd.DataFrame(rest_records, columns=['Name', 'gmap_id', 'address', 'avg_rating', 
                                             'relative_results', 'num_of_reviews'])
    return df

def make_reviews_df(fl, min_char=0, max_char=10000):
    parser = parse(fl)
    reviews = []
    for review in parser:
        if review['text'] != None:
            if len(review['text']) >= min_char and len(review['text']) < max_char:
                reviews.append([review['name'],
                                review['rating'],
                                review['text'],
                                review['gmap_id']
                               ])
    df = pd.DataFrame(reviews, columns=['name', 'rating', 'text', 'gmap_id'])
    return df

In [6]:
reviews_df = make_reviews_df('review-Wyoming_10.json')
meta_df = make_metadata_df('meta-Wyoming.json')

In [7]:
meta_df = meta_df[meta_df['num_of_reviews'] < 100]

In [8]:
combined_df = reviews_df.merge(meta_df, on="gmap_id", how="inner")

In [9]:
sub_df = combined_df.loc[:,['text', 'Name']]
sub_df = sub_df.groupby(["Name"]).agg({"text": "|||||".join}).reset_index()

In [10]:
sub_df

Unnamed: 0,Name,text
0,225 BBQ etc,Delicious BBQ in the heart of Star Valley. We ...
1,4th on Main,Nice venue. Good food. Could have better serv...
2,8 Bytes Game Cafe,Was extremely impressed all around with this p...
3,9 Iron Italian Grill,The food was hot and fresh. Tasted great! Our...
4,A&W Restaurant,I decided to give A&W another chance yesterday...
...,...,...
229,Windy Peaks Brewery & Steakhouse,"Good beer , friendly staff and good food. Had..."
230,Wing Street,Service was slow. Order got mixed up. Manager ...
231,Wy Thai,"We drove up on impulse to get some dinner, but..."
232,Wycolo Lodge,Very excellent experience! Clean restaurant. ...


In [35]:
def process_document_review(documents):
    input_ids_all=[]
    for data in documents:
        all_docs = data.split("|||||")[:-1]
        for i, doc in enumerate(all_docs):
            doc = doc.replace("\n", " ")
            doc = " ".join(doc.split())
            all_docs[i] = doc

        #### concat with global attention on doc-sep
        input_ids = []
        for doc in all_docs:
            input_ids.extend(
                TOKENIZER.encode(
                    doc,
                    truncation=True,
                    max_length=4096 // len(all_docs),
                )[1:-1]
            )
            input_ids.append(DOCSEP_TOKEN_ID)
        input_ids = (
            [TOKENIZER.bos_token_id]
            + input_ids
            + [TOKENIZER.eos_token_id]
        )
        input_ids_all.append(torch.tensor(input_ids))
    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids_all, batch_first=True, padding_value=PAD_TOKEN_ID
    )
    return input_ids

def batch_process_review(batch):
    input_ids=process_document_review(batch).to(device)
    # get the input ids and attention masks together
    global_attention_mask = torch.zeros_like(input_ids).to(input_ids.device)
    # put global attention on <s> token

    global_attention_mask[:, 0] = 1
    global_attention_mask[input_ids == DOCSEP_TOKEN_ID] = 1
    generated_ids = MODEL.generate(
        input_ids=input_ids,
        global_attention_mask=global_attention_mask,
        use_cache=True,
        max_length=2048,
        num_beams=5,
    )
    generated_str = TOKENIZER.batch_decode(
            generated_ids.tolist(), skip_special_tokens=True
        )
    result={}
    result['generated_summaries'] = generated_str
    #result['gt_summaries']=batch['summary']
    return result

In [36]:
reviews = sub_df['text'].values
batch_size = 1
end = 10

results = []

for batch_start in tqdm(range(0,end,batch_size)):
    batch = reviews[batch_start:batch_start+batch_size]
    result = batch_process_review(batch)
    results.append(result)

100%|██████████| 10/10 [00:46<00:00,  4.62s/it]


In [71]:
reviews[0]

'Delicious BBQ in the heart of Star Valley. We are huge fans of the brisket, we get it every time. We also love their smoked mac and cheese, and their bbq sauces are all amazing, we don\'t have one favorite, there\'s just too many good ones to choose from. Highly recommend this place. 👍|||||Absolutely the best BBQ in town. Get there early enough to get the Smoked Mac and cheese. It\'s to die for.\nThe owners have created an amazing business and the food is even better.\nPrices, sizes, and taste will not disappoint.|||||Small portions for price. Chicken thigh and maybe 6 oz of chopped brisket with 2 sides for $15! All the brisket I\'ve eaten in Texas is sliced and they offer lean or moist. The coleslaw was bland and the smoked macaroni and cheese was okay.|||||Bought brisket and pulled pork. Did not want sandwich or meat and 2 or 3 asked for # of each with sides. Here overnight with NO plates or flatware. Asked for this and was told " not the way we usually do things" grudgingly gave us

In [72]:
results[0]

{'generated_summaries': ["Delicious BBQ in the heart of Star Valley.We are huge fans of the brisket, we get it every time. We also love their smoked mac and cheese, and their bbq sauces are all amazing, we don't have one favorite, there's just too many good ones to choose from. Highly recommend this place. 👍Absolutely the best BBQ in town. Get there early enough to get the Smoked Mac and cheese.."]}

We then define the functions to pre-process the data, as well as the function to generate summaries.

In [42]:
def process_document(documents):
    input_ids_all=[]
    for data in documents:
        all_docs = data.split("|||||")[:-1]
        for i, doc in enumerate(all_docs):
            doc = doc.replace("\n", " ")
            doc = " ".join(doc.split())
            all_docs[i] = doc

        #### concat with global attention on doc-sep
        input_ids = []
        for doc in all_docs:
            input_ids.extend(
                TOKENIZER.encode(
                    doc,
                    truncation=True,
                    max_length=4096 // len(all_docs),
                )[1:-1]
            )
            input_ids.append(DOCSEP_TOKEN_ID)
        input_ids = (
            [TOKENIZER.bos_token_id]
            + input_ids
            + [TOKENIZER.eos_token_id]
        )
        input_ids_all.append(torch.tensor(input_ids))
    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids_all, batch_first=True, padding_value=PAD_TOKEN_ID
    )
    return input_ids




def batch_process(batch):
    input_ids=process_document(batch['document']).to(device)
    print(input_ids.shape)
    # get the input ids and attention masks together
    global_attention_mask = torch.zeros_like(input_ids).to(input_ids.device)
    # put global attention on <s> token

    global_attention_mask[:, 0] = 1
    global_attention_mask[input_ids == DOCSEP_TOKEN_ID] = 1
    generated_ids = MODEL.generate(
        input_ids=input_ids,
        global_attention_mask=global_attention_mask,
        use_cache=True,
        max_length=1024,
        num_beams=5,
    )
    generated_str = TOKENIZER.batch_decode(
            generated_ids.tolist(), skip_special_tokens=True
        )
    result={}
    result['generated_summaries'] = generated_str
    result['gt_summaries']=batch['summary']
    return result

Next, we simply run the model on 10 data examples (or any number of examples you want)

In [43]:
import random
data_idx = random.choices(range(len(dataset['test'])),k=10)
dataset_small = dataset['test'].select(data_idx)
result_small = dataset_small.map(batch_process, batched=True, batch_size=2)

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

torch.Size([2, 658])


Map:  20%|██        | 2/10 [00:04<00:18,  2.30s/ examples]

torch.Size([2, 721])


Map:  40%|████      | 4/10 [00:10<00:15,  2.59s/ examples]

torch.Size([2, 2318])


Map:  60%|██████    | 6/10 [00:18<00:13,  3.30s/ examples]

torch.Size([2, 1469])


Map:  80%|████████  | 8/10 [00:26<00:07,  3.54s/ examples]

torch.Size([2, 3207])


Map: 100%|██████████| 10/10 [00:44<00:00,  4.48s/ examples]


In [65]:
result_small[2]['gt_summaries']

'– She wasn\'t a Target employee, but a woman who sure seemed like one allegedly made off with around $40,000 worth of iPhones from a Virginia store. NBC Washington reports Fairfax County cops are looking for the retail impostor, who they say donned attire resembling a worker\'s getup, waltzed into the stockroom of the Alexandria location with a box, and loaded the box with dozens of iPhones before taking off. WTOP reports the woman, whose image was caught on tape, seemed to be familiar with how things worked at the store, including employee hours and where the iPhones were stored. Police say the theft occurred March 15, but posted about it on Facebook Monday with a call to "help us nab an iPhone thief." (Target recently had a Boston problem.)'

In [64]:
result_small[2]['document']

"Fairfax County police are searching for a woman suspected of impersonating a Target employee and stealing more than $40,000 worth of iPhones. See video. \n \n WASHINGTON — Fairfax County police are searching for a woman suspected of impersonating a Target employee and stealing more than $40,000 worth of iPhones earlier this month. \n \n Police released surveillance footage Tuesday of the suspect leaving the store. \n \n On March 15, an unidentified woman impersonated a Target employee at the 6600 Richmond Highway location in Alexandria, Virginia, police said. \n \n She gained access to the stockroom and from there, police said she took the iPhones and put them in a box before leaving the store. \n \n Surveillance footage shows the woman leaving the store and getting into a Volvo station wagon. \n \n The suspect was familiar with store procedures, employee hours and where iPhones were kept in the stockroom. \n \n Anyone with more information about this case can call Fairfax County poli

In [63]:
result_small[2]['generated_summaries']

'WASHINGTON — Fairfax County police are searching for a woman suspected of impersonating a Target employee and stealing more than $40,000 worth of iPhones earlier this month. Police released surveillance footage Tuesday of the suspect leaving the store..'

After getting all the results, we load the evaluation metric. 


(Note in the original code, we didn't use the default aggregators, instead, we simply take average over all the scores.
We simply use 'mid' in this notebook)

In [None]:
rouge = load_metric("rouge")

In [None]:
result_small['generated_summaries']

In [None]:
score=rouge.compute(predictions=result_small["generated_summaries"], references=result_small["gt_summaries"])
print(score['rouge1'].mid)
print(score['rouge2'].mid)
print(score['rougeL'].mid)