In [None]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import sys
from datasets import load_dataset, load_metric
import torch
import re
import json
import gzip
import pandas as pd
import numpy as np
from tqdm import tqdm
import nltk
import os
import wandb
import random
import time

import wandb
import os

import tempfile
import zipfile

from datasets import load_dataset, load_metric
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    PreTrainedTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    set_seed,
)

In [None]:
# Read the W&B API key from the environment instead of hard-coding it in the
# notebook (keys typed into cells end up in version control and shared copies).
# If WANDB_API_KEY is unset, key=None (wandb.login's default) lets wandb fall
# back to previously stored credentials or an interactive prompt.
wandb_key = os.environ.get("WANDB_API_KEY") or None
wandb.login(key=wandb_key)

In [None]:
def get_model(model_name):
    """Load a seq2seq model and its tokenizer from a hub id or local directory.

    Args:
        model_name: Anything accepted by ``from_pretrained``; in this notebook it
            is the local directory returned by ``get_wandb_model``.

    Returns:
        ``(model, tokenizer)``. The model's token-embedding matrix is resized to
        ``len(tokenizer)``, so tokens added to the tokenizer get embedding rows.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # The model config is loaded by from_pretrained itself; a separate
    # AutoConfig load was unused and only cost an extra read.
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    # Keep the embedding table in sync with the tokenizer vocabulary size.
    model.resize_token_embeddings(len(tokenizer))
    return model, tokenizer

def get_wandb_model(path):
    """Download a W&B model artifact and return the local directory it was saved to.

    The currently active W&B run is reused when there is one, so loading several
    checkpoints in one session does not start a new run per download. A run is
    started with ``wandb.init()`` only when none is active.

    Args:
        path: Artifact reference, e.g. ``'entity/project/name:version'``.

    Returns:
        Path of the directory the artifact was downloaded to.
    """
    run = wandb.run if wandb.run is not None else wandb.init()
    artifact = run.use_artifact(path, type="model")
    return artifact.download()

def count_parameters(model):
    """Count parameter elements of ``model``, split by ``requires_grad``.

    Returns:
        ``(trainable, non_trainable)`` element counts.
    """
    totals = {True: 0, False: 0}
    for param in model.parameters():
        totals[param.requires_grad] += param.numel()
    return totals[True], totals[False]

# Base model

In [None]:
# Base model checkpoint, stored as a W&B artifact.
base_artifact_path = 'flavorfusion-team/FlavorFusion/model-w10g07vv:v0'
model_name = get_wandb_model(base_artifact_path)  # local directory of the downloaded artifact
base_model, base_tokenizer = get_model(model_name)

In [None]:
# model_name = get_wandb_model('flavorfusion-team/FlavorFusion/pruned_model:v0')
# pruned_model, pruned_tokenizer = get_model(model_name)

In [None]:
def split_docs(text: str, doc_sep_token: str):
    """Split a string into multiple reviews based on separator token.

    One trailing separator is ignored, and each review is stripped of surrounding
    whitespace. The separator is matched literally: it is regex-escaped before
    building the trailing-separator pattern. Separators such as ``'|||||'``
    contain regex metacharacters, and unescaped they match the empty string.

    Args:
        text: Reviews joined by ``doc_sep_token``.
        doc_sep_token: Literal separator string.

    Returns:
        List of stripped review strings. It can still contain empty strings when
        two separators are adjacent.
    """
    text = re.sub(rf"{re.escape(doc_sep_token)}$", "", text.strip())
    return [doc.strip() for doc in text.split(doc_sep_token)]

def get_num_docs(text: str, doc_sep_token: str) -> int:
    """Count the non-empty reviews in a separator-joined string."""
    reviews = split_docs(text, doc_sep_token=doc_sep_token)
    return sum(1 for review in reviews if review)

def sample_reviews(examples, max_docs_per_review=5, k_top_longest=20,
                   text_column='review_str', summary_column='summary', doc_sep='|||||'):
    """Perform data augmentation by sampling the k longest reviews in each data point,
    then splitting them into new data points of at most ``max_docs_per_review`` reviews.

    Meant for ``Dataset.map(..., batched=True)``: ``examples`` is a dict of columns,
    and the returned dict may contain a different number of rows than the input.

    Args:
        examples: Batch dict with ``text_column`` (reviews joined by ``doc_sep``)
            and ``summary_column``.
        max_docs_per_review: Maximum number of reviews per new data point.
        k_top_longest: Number of longest reviews kept from each data point.
        text_column: Input column holding the joined reviews.
        summary_column: Input column holding the summary.
        doc_sep: Review separator.
        The defaults for the last three match the values previously hard-coded.

    Returns:
        ``{'augmented_review_str': [...], 'new_summary': [...]}``. Each new data
        point carries the summary of the data point it was sampled from. Data
        points without any non-empty review produce no rows.

    Note:
        ``random.shuffle`` uses Python's global ``random`` state, so output is
        reproducible only when ``random`` is seeded. With ``num_proc > 1``, each
        worker process has its own state.
    """
    new_reviews = []
    new_summaries = []
    for joined_docs, summary in zip(examples[text_column], examples[summary_column]):
        # Drop empty reviews (from missing values or adjacent/trailing separators)
        # so they neither take a top-k slot nor become empty documents downstream.
        docs = [doc for doc in split_docs(joined_docs or "", doc_sep) if doc]
        longest_docs = sorted(docs, key=len, reverse=True)[:k_top_longest]
        random.shuffle(longest_docs)
        chunks = [longest_docs[start:start + max_docs_per_review]
                  for start in range(0, len(longest_docs), max_docs_per_review)]
        joined_chunks = [doc_sep.join(chunk) for chunk in chunks]
        new_reviews.extend(joined_chunks)
        new_summaries.extend([summary] * len(joined_chunks))
    return {'augmented_review_str': new_reviews, 'new_summary': new_summaries}


def process_document(documents, doc_sep, max_source_length, tokenizer, DOCSEP_TOKEN_ID, PAD_TOKEN_ID):
    """Encode multi-document inputs as ``BOS doc1 SEP doc2 SEP ... EOS`` and pad them.

    Each input string is split on ``doc_sep`` and empty pieces are dropped. This
    handles inputs with or without a trailing separator; ``doc_sep.join(...)``
    output has none. Each document is whitespace-normalised and encoded with a
    budget of ``max_source_length // n_docs`` tokens. The per-document BOS/EOS
    are stripped and ``DOCSEP_TOKEN_ID`` is appended after each document. The
    whole sequence is then wrapped in the tokenizer's BOS/EOS ids.

    Args:
        documents: List of strings, each holding several documents joined by ``doc_sep``.
        doc_sep: Literal document separator.
        max_source_length: Token budget shared evenly across an input's documents.
        tokenizer: Tokenizer providing ``encode``, ``bos_token_id`` and ``eos_token_id``.
        DOCSEP_TOKEN_ID: Id appended after every document.
        PAD_TOKEN_ID: Id used to right-pad shorter sequences.

    Returns:
        Integer tensor of shape ``(len(documents), longest_sequence)``, right-padded
        with ``PAD_TOKEN_ID``.
    """
    input_ids_all = []
    for data in documents:
        # Collapse all whitespace (newlines included) and drop empty documents.
        # Slicing off the last split piece instead would lose the final document
        # whenever the input does not end with a separator.
        all_docs = [" ".join(doc.split()) for doc in data.split(doc_sep)]
        all_docs = [doc for doc in all_docs if doc]

        #### concat with global attention on doc-sep
        input_ids = [tokenizer.bos_token_id]
        if all_docs:
            per_doc_max_length = max_source_length // len(all_docs)
            for doc in all_docs:
                # [1:-1] strips the BOS/EOS that encode() adds. This assumes it
                # adds exactly one of each (BART/LED-style) — confirm for others.
                input_ids.extend(
                    tokenizer.encode(doc, truncation=True, max_length=per_doc_max_length)[1:-1]
                )
                input_ids.append(DOCSEP_TOKEN_ID)
        input_ids.append(tokenizer.eos_token_id)
        input_ids_all.append(torch.tensor(input_ids))
    return torch.nn.utils.rnn.pad_sequence(
        input_ids_all, batch_first=True, padding_value=PAD_TOKEN_ID
    )


def preprocess_function(examples, tokenizer, text_column, summary_column, max_source_length, max_target_length, 
                        padding="max_length", ignore_pad_token_for_loss=True, prefix=""):
    """Tokenize a batch of multi-document sources and their target summaries.

    Meant for ``Dataset.map(..., batched=True)``. Rows whose source or summary is
    empty/None are skipped. Sources are prefixed with ``prefix`` and encoded by
    ``process_document`` (documents separated by ``'|||||'``, joined with the id
    of the ``<doc-sep>`` token). Targets are tokenized inside the tokenizer's
    target-tokenizer context.

    Args:
        examples: Batch dict of columns.
        tokenizer: Hugging Face tokenizer.
        text_column: Source column name.
        summary_column: Target column name.
        max_source_length: Token budget for each source.
        max_target_length: Token budget for each target.
        padding: Padding strategy for the labels only. Sources are padded to the
            longest sequence in the batch by ``process_document``.
        ignore_pad_token_for_loss: With ``padding="max_length"``, replace label
            padding ids with -100 so the loss ignores them.
        prefix: String prepended to every source.

    Returns:
        Dict with ``input_ids`` (padded tensor), ``labels`` (list of id lists) and
        ``global_attention_mask`` (1 on the first token and on every ``<doc-sep>``).
    """
    pad_id = tokenizer.pad_token_id
    docsep_id = tokenizer.convert_tokens_to_ids("<doc-sep>")

    sources = examples[text_column]
    summaries = examples[summary_column]
    # Only rows where both the source and the summary are non-empty are kept.
    kept = [idx for idx in range(len(sources)) if sources[idx] and summaries[idx]]
    prefixed_sources = [prefix + sources[idx] for idx in kept]
    targets = [summaries[idx] for idx in kept]

    input_ids = process_document(
        prefixed_sources,
        doc_sep='|||||',
        max_source_length=max_source_length,
        tokenizer=tokenizer,
        DOCSEP_TOKEN_ID=docsep_id,
        PAD_TOKEN_ID=pad_id,
    )

    with tokenizer.as_target_tokenizer():
        tokenized_targets = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)

    label_ids = tokenized_targets["input_ids"]
    if padding == "max_length" and ignore_pad_token_for_loss:
        # -100 marks positions the loss should ignore (here: label padding).
        label_ids = [
            [-100 if tok_id == tokenizer.pad_token_id else tok_id for tok_id in seq]
            for seq in label_ids
        ]

    # Global attention on the first (<s>) token and on every document separator.
    global_attention_mask = torch.zeros_like(input_ids)
    global_attention_mask[:, 0] = 1
    global_attention_mask[input_ids == docsep_id] = 1

    return {
        "input_ids": input_ids,
        "labels": label_ids,
        "global_attention_mask": global_attention_mask,
    }


def postprocess_text(preds, labels):
    """Prepare decoded predictions and references for ROUGE scoring.

    Each string is stripped, then split into sentences with ``nltk.sent_tokenize``,
    one sentence per line, because ROUGE-Lsum expects a newline after each sentence.

    Returns:
        ``(preds, labels)`` as new lists of strings.
    """
    def _one_sentence_per_line(text):
        return "\n".join(nltk.sent_tokenize(text.strip()))

    return (
        [_one_sentence_per_line(pred) for pred in preds],
        [_one_sentence_per_line(label) for label in labels],
    )

In [None]:
# Every CSV matching the relative pattern '*.csv' is loaded as one 'train' split.
data_files = {'train': '*.csv'}
raw_dataset = load_dataset(
    'csv',
    data_files=data_files,
    split='train',
    streaming=False,
)
# Original columns; they are dropped by the augmentation map below.
column_names = raw_dataset.column_names

In [None]:
# Keep a reproducible 5% sample as the evaluation set. train_test_split shuffles
# again by default, so it needs its own seed; without one a different test subset
# is drawn on every run despite the seeded shuffle.
# NOTE: this cell rebinds raw_dataset, so re-running it without re-running the
# loading cell samples 5% of the already-sampled data.
raw_dataset = raw_dataset.shuffle(seed=42)
raw_dataset = raw_dataset.train_test_split(test_size=0.05, seed=42)
raw_dataset = raw_dataset['test']

In [None]:
# sample_reviews shuffles each example's reviews with Python's global `random`
# state. Seed it (transformers.set_seed also seeds numpy and torch) so the
# augmented evaluation set is the same on every run.
# NOTE(review): with num_proc > 1 the seed only reaches worker processes forked
# from this kernel — confirm on platforms that spawn workers instead.
set_seed(42)

aug_kwargs = {'max_docs_per_review':5, 'k_top_longest': 5}

aug_dataset = raw_dataset.map(
            sample_reviews,
            fn_kwargs=aug_kwargs,
            batched=True,
            num_proc=4,
            remove_columns=column_names,
            desc="Augmenting train and test datasets")

token_kwargs = {'text_column': 'augmented_review_str', 
                'tokenizer': base_tokenizer,
                'summary_column': 'new_summary', 
                'max_source_length': 700, 
                'max_target_length': 100}

full_dataset = aug_dataset.map(
            preprocess_function,
            fn_kwargs=token_kwargs,
            batched=True,
            num_proc=4,
            desc="Running tokenizer on train and test datasets")

In [None]:
# ROUGE metric used by inference_batch to score generated summaries.
# NOTE(review): datasets.load_metric is deprecated (removed in datasets 3.x);
# evaluate.load("rouge") is the maintained replacement — confirm the installed version.
metric = load_metric("rouge")

In [None]:
def inference_batch(batch, model, tokenizer, max_length=100, num_beams=1, summary_column='new_summary'):
    """Generate summaries for a batch and score each one against its reference with ROUGE.

    Meant for ``Dataset.map(..., batched=True)``. Every returned column has one
    entry per example, so the output lines up with the dataset for any
    ``batch_size``.

    Args:
        batch: Batch dict with ``input_ids``, ``global_attention_mask`` and
            ``summary_column`` (reference summaries).
        model: Seq2seq model whose ``generate`` accepts ``global_attention_mask``.
        tokenizer: Tokenizer used for decoding. When it defines ``pad_token_id``,
            that id is masked out of the attention.
        max_length: Generation length limit (default matches the original value).
        num_beams: Beam count for generation (default matches the original value).
        summary_column: Column holding the reference summaries.

    Returns:
        Dict with:
          - ``generated_summaries`` and ``gt_summaries``;
          - per-example ROUGE-1/2/L/Lsum F-measures (x100);
          - ``inference_time``: wall-clock seconds for tensor preparation,
            generation and decoding of the batch, split evenly across its examples.
    """
    start = time.perf_counter()

    device = model.device
    input_ids = torch.as_tensor(batch['input_ids']).to(device)
    global_attention_mask = torch.as_tensor(batch['global_attention_mask']).to(device)

    generate_kwargs = {}
    if tokenizer.pad_token_id is not None:
        # input_ids were padded during preprocessing; mask the padding explicitly
        # instead of relying on generate() to infer it.
        generate_kwargs['attention_mask'] = input_ids.ne(tokenizer.pad_token_id).long()

    generated_ids = model.generate(
        input_ids=input_ids,
        global_attention_mask=global_attention_mask,
        use_cache=True,
        max_length=max_length,
        num_beams=num_beams,
        **generate_kwargs)

    generated_str = tokenizer.batch_decode(generated_ids.tolist(), skip_special_tokens=True)
    elapsed = time.perf_counter() - start

    references = batch[summary_column]
    # ROUGE-Lsum splits summaries on newlines, so put one sentence per line first.
    # postprocess_text needs NLTK's sentence-tokenizer data to be installed.
    rouge_preds, rouge_refs = postprocess_text(generated_str, references)
    # use_aggregator=False gives one score per example instead of a single bootstrap
    # aggregate. The datasets rouge metric returns rouge_score Score tuples here.
    score = metric.compute(predictions=rouge_preds, references=rouge_refs, use_aggregator=False)

    n_examples = len(generated_str)
    result = {
        'generated_summaries': generated_str,
        'gt_summaries': references,
    }
    for key in ('rouge1', 'rouge2', 'rougeL', 'rougeLsum'):
        result[key] = [s.fmeasure * 100 for s in score[key]]
    result['inference_time'] = [elapsed / n_examples] * n_examples if n_examples else []

    return result

In [None]:
# Score the base model one example at a time.
eval_kwargs = dict(model=base_model, tokenizer=base_tokenizer)
res = full_dataset.map(
    inference_batch,
    fn_kwargs=eval_kwargs,
    batched=True,
    batch_size=1,
)

In [None]:
def average_metrics(result, metrics_list=('rouge1', 'rouge2', 'rougeL', 'rougeLsum', 'inference_time')):
    """Average each metric column of an evaluation result.

    Args:
        result: Mapping or ``Dataset`` where ``result[name]`` is a sequence of
            numbers, e.g. the output of ``Dataset.map(inference_batch, ...)``.
        metrics_list: Columns to average, in output order (default matches the
            original fixed list).

    Returns:
        List of column means in ``metrics_list`` order; ``nan`` for an empty column.
    """
    averages = []
    for name in metrics_list:
        values = list(result[name])
        # Divide by this column's own length, not by a global variable, so the
        # function averages exactly what it was given.
        averages.append(sum(values) / len(values) if values else float('nan'))
    return averages

In [None]:
# Base model: mean ROUGE-1, ROUGE-2, ROUGE-L, ROUGE-Lsum F-measures (x100)
# and mean inference time in seconds, in that order.
print(average_metrics(res))

# Pruned model

In [None]:
# Pruned checkpoint, stored as a W&B artifact.
pruned_artifact_path = 'flavorfusion-team/FlavorFusion/pruned_model:v0'
model_name = get_wandb_model(pruned_artifact_path)  # local directory of the downloaded artifact
pruned_model, pruned_tokenizer = get_model(model_name)

In [None]:
# Evaluate the pruned model on exactly the same examples as the base model
# (all of full_dataset) so the averaged metrics are directly comparable.
# For a quick smoke test only, use full_dataset.select(range(5)) instead.
eval_kwargs = {'model': pruned_model, 
                'tokenizer': pruned_tokenizer}

res = full_dataset.map(inference_batch, fn_kwargs=eval_kwargs, batched=True, batch_size=1)

In [None]:
# Pruned model: mean ROUGE-1, ROUGE-2, ROUGE-L, ROUGE-Lsum F-measures (x100)
# and mean inference time in seconds, in that order.
print(average_metrics(res))