In [None]:
from statistics import mean

import evaluate
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

In [None]:
# Where the weights and the tokenizer come from. The weights are read from a
# local checkpoint directory (presumably a fine-tuning run -- confirm), while
# the tokenizer is loaded by its "facebook/bart-base" identifier.
CHECKPOINT_DIR = "./logs/checkpoint-600"
TOKENIZER_NAME = "facebook/bart-base"

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT_DIR)

In [None]:
# Metrics used to score the generated summary. The order matters: the
# reporting cell reads the computed scores by position (0 = ROUGE,
# 1 = BERTScore, 2 = BLEU, 3 = METEOR).
metric_names = ["rouge", "bertscore", "bleu", "meteor"]
metrics = list(map(evaluate.load, metric_names))

In [None]:
# Source documents: eight free-text customer reviews. They are joined into a
# single string and fed to the model as one input in the tokenization cell.
# The text is kept verbatim (including typos) because it is model input.
reviews = [
    "Best breakfast in Akron. They care about quality and it shows in the food. The bar is designed like an island in your kitchen, makes you feel like you're at home. Wait staff is very professional and treat you like family. Owner is local and is very hands on; which shows in the food. Great place for lunch too!",
    "Very busy on weekends, yet there is never too long of a wait. Service is great and the portions are very generous, especially the pancakes! (My favorite). The new interior is very nice and adds great atmosphere!",
    "What a great place! Food is amazing and it's not just your ordinary breakfast or lunch spot. The food is unique and delicious. One of my favorite is the red eye hash! Fresh orange juice or a bloody mary, either way both delicious. The service is always great and the atmosphere is good.",
    "One of our favorite breakfast spots. There's often a wait on the weekends but we've never waited more than 10 minutes. The host and wait staff and always friendly and accommodating and the food is consistently wonderful. I recommend the eggs Benedict with crabmeat or the red eye hash!",
    "Great restaurant. More thank likely you'll find a wait, but it's worth it. Breakfast is amazing no matter what you get. If you get a lunch try the fries, double fried, dipped in white French dressing. Or order it as a side. You'll thank me later.",
    "Fantastic Omlets wonderful friendly service. Killer spicy Brown potatoes and the special crooked river jams. One of my top breakfast stops. Relaxing with a cup of coffee and surfing the web. great way to start the day. And they have great sandwiches too for lunch.",
    "A very good breakfast spot that has it's own take on popular dishes. The potatoes are especially tasty, though everyone at our table enjoyed their meals which ranged from pancakes, eggs and french toast. I also like the Akron themed pictures on the wall.",
    "Great breakfast!!! Come early because lines start early. Bloody Mary's and mimosas. Fast and affordable. Great experience. Staff was friendly and the portions are huge!!! $- $$. Great price and an amazing brunch stop in Akron / Fairlawn my eggs and sausage were cooked perfectly!!!",
]

# Reference ("gold") summaries: three alternative summaries that the generated
# summary is compared against in the scoring cell.
gold_summ = [
    "This restaurant has consistently good food and service. It is an especially popular place for breakfast, though they serve a tasty lunch as well. The atmosphere inside is positive and the staff are always friendly. Expect a short wait on the weekends, as it can become overcrowded.",
    "Really great restaurant for a nice breakfast! Fantastic and unique dishes that never fails to amaze customers, friendly and efficient staff, generous portions and great atmosphere. Excellent menu with a wide variety. Management is quality-minded. Overall a highly recommended place.",
    "This restaurant is often very busy on weekends, but even so there usually isn't much of a wait. The staff is very friendly and provide great service. The food is a bit unique, but all of it is very good, particularly the eggs and pancakes. They specialize in breakfast, but also offer sandwiches for lunch. The portions are large for the price they charge. This place is highly recommended.",
]

# A pre-written summary string; the name suggests it was produced by PaLM
# (NOTE(review): confirm). It is not referenced by any of the cells visible here.
palm_generated = "This restaurant is a local favorite for breakfast and lunch. The food is delicious and unique, and the portions are generous. The service is friendly and attentive, and the atmosphere is casual and inviting. There is often a wait on the weekends, but it is worth it. The prices are reasonable."

In [None]:
# Collapse all reviews into one string, separated by ",</s>", and encode it
# as a single example: truncated/padded to exactly 1024 tokens and returned
# as PyTorch tensors.
joined_reviews = ",</s>".join(reviews)

inputs = tokenizer(
    joined_reviews,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=1024,
)

In [None]:
# Decoding settings for the summary: 3-beam search capped at 128 tokens.
beam_search_options = {
    "num_beams": 3,
    "max_length": 128,
    "early_stopping": True,
    "length_penalty": 0.6,
}

# The attention mask is passed alongside the ids so the padding added during
# tokenization is not attended to.
summaries = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    **beam_search_options,
)

In [None]:
# Turn the generated token ids back into text, dropping special tokens and
# tidying tokenization spacing. The result is a list of strings, one per
# generated sequence, so no tensor-format argument applies here
# (`return_tensors` is an encoding option, not a decoding one).
decoded_summaries = tokenizer.batch_decode(
    summaries,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)

In [None]:
# Show the decoded text so it can be inspected before it is scored.
decoded_summaries

In [None]:
# Score the single generated summary against every gold summary. Each
# (prediction, reference) pair is added as its own example, so `compute()`
# aggregates over one pair per reference. Looping over `gold_summ` keeps this
# correct if references are added or removed.
prediction = decoded_summaries[0]
for metric in metrics:
    for reference in gold_summ:
        metric.add(predictions=prediction, references=reference)

# `scores` is a list aligned with `metrics`; the reporting cell indexes it by
# position. BERTScore is the only metric given an extra argument at compute
# time (`lang`). It is recognised by `metric.name`, which is compared against
# "bert_score" -- note this differs from the "bertscore" id used to load it
# (NOTE(review): confirm the reported name if the `evaluate` version changes).
scores = []
for metric in metrics:
    compute_kwargs = {"lang": "en"} if metric.name == "bert_score" else {}
    scores.append(metric.compute(**compute_kwargs))

In [None]:
# Pull the per-metric result dicts out of `scores` by position
# (0 = ROUGE, 1 = BERTScore, 2 = BLEU, 3 = METEOR).
rouge_scores = scores[0]
bert_scores = scores[1]
bleu_scores = scores[2]
meteor_scores = scores[3]

# ROUGE values are read directly and rounded to 4 decimals.
rouge1, rouge2, rougeL, rougeLsum = (
    round(rouge_scores[key], 4)
    for key in ("rouge1", "rouge2", "rougeL", "rougeLsum")
)

# BERTScore values are averaged with `mean` before rounding.
precision, recall, f1 = (
    round(mean(bert_scores[key]), 4) for key in ("precision", "recall", "f1")
)

bleu = round(bleu_scores["bleu"], 4)
meteor = round(meteor_scores["meteor"], 4)

# Emit the report in a single call; one line per entry, newline-separated.
report_lines = (
    "----------",
    f"  ROUGE    : rouge1: {rouge1} rouge2: {rouge2} rougeL: {rougeL} rougeLsum: {rougeLsum}",
    f"  BERTScore: precision: {precision} recall: {recall} f1: {f1}",
    f"  BLEU     : {bleu}",
    f"  METEOR   : {meteor}",
    "----------\n\n",
)
print(*report_lines, sep="\n")