In [None]:
# Token-length budgets used when tokenizing model inputs and target explanations.
INPUT_MAX_LENGTH, PREDICTION_MAX_LENGTH = 256, 64

In [None]:
import os

import numpy as np
import pandas as pd
import torch
import evaluate
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from datasets import load_dataset

# Select a compute device: Apple-silicon MPS when available, otherwise CUDA, otherwise CPU.
# Every branch assigns `device`, so the cell also runs on machines without MPS.
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)  # smoke test: allocate a tensor on MPS
    print(x)
    device = mps_device
else:
    print("MPS device not found.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
BASE_PATH = './'
train_path = f'{BASE_PATH}Health-Fact-Checking/data/PUBHEALTH/formatted_train_most_similar.csv'
dev_path = f'{BASE_PATH}Health-Fact-Checking/data/PUBHEALTH/formatted_dev_most_similar.csv'
test_path = f'{BASE_PATH}Health-Fact-Checking/data/PUBHEALTH/formatted_test_most_similar.csv'
# Columns of interest in the formatted CSVs (appears unused by the visible cells).
FEATURES = ['claim', 'top_k', 'label']

# Each CSV is loaded on its own, so every split's rows sit under the default "train" key.
dataset, val_dataset, test_dataset = (
    load_dataset("csv", data_files=[csv_path])
    for csv_path in (train_path, dev_path, test_path)
)


In [None]:
model_name = "google/pegasus-xsum"
# Tokenizer and weights are loaded from the same Hugging Face Hub checkpoint id.
tokenizer, model = (
    PegasusTokenizer.from_pretrained(model_name),
    PegasusForConditionalGeneration.from_pretrained(model_name),
)


In [None]:
from transformers import AutoTokenizer
from transformers import DataCollatorForSeq2Seq
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Checkpoint id for the fine-tuning section; interpolated into the Trainer's output_dir.
SUMMARY_MODEL_NAME: str = model_name

In [None]:
def compute_metrics_summary(eval_pred):
    """Compute ROUGE scores and mean generation length for a Seq2SeqTrainer evaluation.

    Args:
        eval_pred: ``(predictions, labels)`` pair of token-id arrays, as passed by the
            Trainer when ``predict_with_generate=True``. ``predictions`` may also be a
            tuple whose first element holds the generated ids.

    Returns:
        dict mapping every key returned by ``rouge.compute`` plus ``"gen_len"`` (mean
        count of non-pad tokens per prediction) to a value rounded to 4 decimals.

    Relies on the notebook globals ``summary_tokenizer``, ``rouge`` and ``np``.
    """
    predictions, labels = eval_pred
    # Some models return extra outputs alongside the generated ids; keep the ids.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = summary_tokenizer.pad_token_id
    # -100 marks padding/ignored positions (the Trainer pads predictions gathered from
    # different-length batches with it, and labels use it for ignored positions).
    # It cannot be decoded and would inflate gen_len, so map it to the pad id.
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = summary_tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = summary_tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}


In [None]:
import numpy as np
import evaluate

# Separate copies of the checkpoint dedicated to fine-tuning, independent of the
# `tokenizer` / `model` objects loaded in an earlier cell.
summary_tokenizer = PegasusTokenizer.from_pretrained(model_name)
summary_model = PegasusForConditionalGeneration.from_pretrained(model_name)

# Collator that pads each batch at training/evaluation time.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=summary_tokenizer,
    model=summary_model,
)

# ROUGE metric used by compute_metrics_summary and the generation cells.
rouge = evaluate.load("rouge")

def preprocess_function_summary(examples):
    """Tokenize a batch of PUBHEALTH rows for Pegasus fine-tuning.

    Args:
        examples: batch dict from ``Dataset.map(batched=True)``; reads the
            ``"top_k"`` (model input) and ``"explanation"`` (target) columns.

    Returns:
        The tokenizer encoding of the prefixed inputs, truncated to
        ``INPUT_MAX_LENGTH`` tokens, with a ``"labels"`` entry holding the target
        ids truncated to ``PREDICTION_MAX_LENGTH`` tokens.
    """
    prefix = "summarize: "
    inputs = [prefix + doc for doc in examples["top_k"]]
    # No padding here: DataCollatorForSeq2Seq pads each training batch, so padding the
    # whole map() chunk to its longest row only wastes memory and compute.
    model_inputs = summary_tokenizer(inputs, max_length=INPUT_MAX_LENGTH, truncation=True)

    # `max_length` is the tokenizer argument that enforces the cap; `max_new_tokens` is a
    # generate() option that tokenizers do not recognise, so it never truncated labels.
    labels = summary_tokenizer(text_target=examples["explanation"],
                               max_length=PREDICTION_MAX_LENGTH, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs



# For each split: shuffle with a fixed seed, drop columns the summarizer does not use,
# then tokenize in batches with preprocess_function_summary.
tokenized_train_sum, tokenized_val_sum, tokenized_test_sum = [
    split.shuffle(seed=42)
         .remove_columns(["label", "subjects"])
         .map(preprocess_function_summary, batched=True)
    for split in (dataset, val_dataset, test_dataset)
]


# Fine-tuning configuration.
# NOTE(review): SUMMARY_MODEL_NAME is "google/pegasus-xsum" as set above, so output_dir
# resolves to the nested path "health_summary_model_true_false_google/pegasus-xsum".
training_args = Seq2SeqTrainingArguments(
    output_dir=f"health_summary_model_true_false_{SUMMARY_MODEL_NAME}",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    # Evaluate by generating text so compute_metrics_summary can score it with ROUGE.
    predict_with_generate=True,
    # Bound evaluation-time generation by the notebook's target-length budget instead of
    # relying on the checkpoint's default generation settings.
    generation_max_length=PREDICTION_MAX_LENGTH,
    # fp16=True,
    push_to_hub=False,
)

# For a quick smoke run, append `.select(range(100))` to the two datasets below.
trainer = Seq2SeqTrainer(
    model=summary_model,
    args=training_args,
    train_dataset=tokenized_train_sum["train"],
    eval_dataset=tokenized_val_sum["train"],
    tokenizer=summary_tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics_summary,
)

trainer.train()

In [None]:
# Persist the fine-tuned model (and the tokenizer passed to the Trainer) so the inference
# cells below can reload it from `save_path`; they fail on a fresh kernel without this.
save_path = f'/tmp/explanation-generation-peagusus-{INPUT_MAX_LENGTH}-{PREDICTION_MAX_LENGTH}'
trainer.save_model(save_path)
print(save_path)

### Compute ROUGE scores on the test set

In [None]:
# Generate on the held-out split; metrics come from compute_metrics_summary.
preds = trainer.predict(test_dataset=tokenized_test_sum["train"])
preds.metrics

In [None]:
from transformers import PegasusTokenizer, PegasusForConditionalGeneration

# Reload weights and tokenizer from `save_path` for the inference cells below.
model = PegasusForConditionalGeneration.from_pretrained(save_path)
tokenizer = PegasusTokenizer.from_pretrained(save_path)


In [None]:
# Inspect the `top_k` evidence text of the first test example.
tokenized_test_sum["train"][0]['top_k']

### Generate explanations with nucleus (top-p) sampling

In [None]:
# Sample one explanation per test example with nucleus sampling (top_k=0 turns off top-k
# filtering; top_p=0.95 is the nucleus threshold), then score all samples against the
# reference explanations with ROUGE.
torch.manual_seed(42)  # make the sampled outputs reproducible across runs
results = []
pred_texts = []
label_texts = []
for idx, example in enumerate(tokenized_test_sum["train"]):
    if idx % 100 == 0:
        print(idx)
    # Mirror the fine-tuning preprocessing: same prefix and same input token budget.
    input_text = "summarize: " + example['top_k']
    label_text = example["explanation"]

    inputs = tokenizer(input_text, max_length=INPUT_MAX_LENGTH, truncation=True,
                       return_tensors="pt").input_ids
    outputs = model.generate(inputs,
                             do_sample=True,
                             max_new_tokens=PREDICTION_MAX_LENGTH,
                             top_k=0,
                             top_p=0.95,
                             num_return_sequences=1)

    for sample_output in outputs:
        pred_texts.append(tokenizer.decode(sample_output, skip_special_tokens=True))
        label_texts.append(label_text)

result = rouge.compute(predictions=pred_texts, references=label_texts,
                       use_stemmer=True)
results.append(result)
result

### Generate explanations for the test set with default decoding

In [None]:
# Generate one explanation per test example with the model's default decoding settings
# and score the predictions against the reference explanations with ROUGE.
test_df = test_dataset["train"].to_pandas()  # rows of the test CSV loaded earlier
preds = []
labels = []

for idx, row in test_df.iterrows():
    if idx % 100 == 0:
        print(idx)
    # Mirror the fine-tuning preprocessing: same prefix and same input token budget.
    input_text = "summarize: " + row['top_k']
    gt_text = row['explanation']

    batch = tokenizer(input_text, max_length=INPUT_MAX_LENGTH, truncation=True,
                      return_tensors="pt")
    translated = model.generate(**batch, max_new_tokens=PREDICTION_MAX_LENGTH)
    pred_text = tokenizer.batch_decode(translated, skip_special_tokens=True)

    preds.append(pred_text[0])
    labels.append(gt_text)

metrics = rouge.compute(predictions=preds, references=labels, use_stemmer=True)
metrics

In [None]:
# Sanity check on a single passage.
# NOTE(review): `label` appears to be the expected summary from the Hugging Face Pegasus
# docs example, and `src_text` is that example's source passage — confirm. `model` here
# may be the fine-tuned checkpoint, so its output need not match `label` exactly.
rouge = evaluate.load("rouge")
src_text = [
    """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
]
label = "California's largest electricity provider has turned off power to hundreds of thousands of customers."

batch = tokenizer(src_text, truncation=True, padding="longest", return_tensors="pt")
tgt_text = tokenizer.batch_decode(model.generate(**batch), skip_special_tokens=True)
preds = tgt_text[0]
rouge.compute(predictions=[preds], references=[label], use_stemmer=True)

### Analyze manual scores

In [None]:
# Manually rated explanations. Set the SCORES_PATH environment variable to override the
# location; the default is the file in the current user's ~/Documents folder.
SCORES_PATH = os.environ.get(
    "SCORES_PATH",
    os.path.expanduser("~/Documents/PredictedExplanations_Scored_NG.csv"),
)
scores = pd.read_csv(SCORES_PATH)
# Number of ratings and mean rating for each "Best explanation" choice.
scores.groupby('Best explanation').agg({'Best explanation Rating': ['count', 'mean']})

In [None]:
# Preview a few rows instead of rendering the entire ratings table.
scores.head()