In [None]:
from PIL import Image
import requests
import torch
from transformers import AutoProcessor
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

from datasets import load_dataset

# Run on the GPU whenever PyTorch reports a usable CUDA device, else on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Only the 'test' split is loaded; it is the split every later cell evaluates on.
dataset = load_dataset(
    'eduvedras/Img_Desc',
    split='test',
    trust_remote_code=True,
)

In [None]:
from huggingface_hub import InferenceClient
from transformers import AutoModelForCausalLM

# Tag embedded in the names of the output files written by later cells.
filetag = "pix2struct-textcaps-base-desc-final"

# Two fine-tuned checkpoints are used together: one produces the caption
# "template", the other produces the "vars" text inserted into that template.
model_templates = Pix2StructForConditionalGeneration.from_pretrained(
    "eduvedras/pix2struct-textcaps-base-desc-templates-final"
)
model_templates = model_templates.to(device)

model_vars = Pix2StructForConditionalGeneration.from_pretrained(
    "eduvedras/pix2struct-textcaps-base-desc-vars-final"
)
model_vars = model_vars.to(device)

In [None]:
# The processor is loaded from the base checkpoint and is shared by both
# models, for preparing image inputs and for decoding generated ids.
checkpoint = "google/pix2struct-textcaps-base"
processor = Pix2StructProcessor.from_pretrained(checkpoint)

# Accumulators filled by the generation loop: one generated caption and one
# list of reference descriptions per dataset row.
predictions, references = [], []

# Maps a chart name to the list of all reference descriptions for that chart.
dic = dict()

from tqdm.auto import tqdm

# Group every reference description under its chart name, so a chart that
# appears in several rows ends up with all of its descriptions in one list.
#
# Only the two text columns are read here. Indexing whole rows (dataset[i])
# would also materialise the 'Chart' image for every row, which this step
# never uses, and the original did that twice over the full dataset.
for chart_name, description in zip(dataset['Chart_name'], dataset['Description']):
    # setdefault creates the list on the first sighting of a chart name, which
    # preserves first-appearance key order and row order within each list.
    dic.setdefault(chart_name, []).append(description)

# Number of distinct charts in the split.
print(len(dic))

In [None]:
# Start from empty lists so that re-running this cell does not append a second
# copy of every caption to results left over from a previous run (which would
# silently skew the metrics computed afterwards).
predictions = []
references = []

for row in tqdm(dataset, total=len(dataset)):
    # Each row is fetched exactly once per iteration and reused below, instead
    # of indexing the dataset separately for the image and for the chart name.
    inputs = processor(images=row['Chart'], return_tensors="pt").to(device)
    flattened_patches = inputs.flattened_patches
    attention_mask = inputs.attention_mask

    # Stage 1: the template model generates the caption template.
    # max_length values are kept as-is; presumably they match the target
    # lengths used when fine-tuning each model (TODO confirm).
    generated_template_ids = model_templates.generate(
        flattened_patches=flattened_patches,
        attention_mask=attention_mask,
        max_length=193,
    )
    generated_template = processor.batch_decode(generated_template_ids, skip_special_tokens=True)[0]

    # Stage 2: the vars model generates, from the same image inputs, the text
    # that goes inside the template's brackets.
    generated_vars_ids = model_vars.generate(
        flattened_patches=flattened_patches,
        attention_mask=attention_mask,
        max_length=167,
    )
    generated_vars = processor.batch_decode(generated_vars_ids, skip_special_tokens=True)[0]

    # Every literal "[]" in the template is replaced by the full vars string
    # wrapped in brackets.
    generated_caption = generated_template.replace("[]", "[" + generated_vars + "]")

    predictions.append(generated_caption)
    # The references for a row are all descriptions recorded for its chart.
    references.append(dic[row['Chart_name']])

In [None]:
import evaluate


# Score the generated captions against the grouped references with BLEU,
# METEOR and ROUGE. Each result is printed and appended to a text log.
#
# The log is opened in a `with` block so the handle is closed even if loading
# or computing one of the metrics raises part-way through. Mode "a" appends,
# so results from earlier runs are kept in the same file.
with open(f"predictions-{filetag}.txt", "a") as file:
    bleu = evaluate.load("bleu")
    results = bleu.compute(predictions=predictions, references=references)
    print(results)
    file.write(f"BLEU: {results}\n")

    meteor = evaluate.load("meteor")
    results = meteor.compute(predictions=predictions, references=references)
    print(results)
    file.write(f"METEOR: {results}\n")

    rouge = evaluate.load("rouge")
    results = rouge.compute(predictions=predictions, references=references)
    print(results)
    file.write(f"ROUGE: {results}\n")

In [None]:
import pandas as pd

# Export one row per distinct (chart name, predicted caption) pair.
#
# The table is built in a single step and de-duplicated with drop_duplicates,
# instead of growing the frame row by row and scanning the whole frame for
# every candidate. Only the chart-name column is read from the dataset, so the
# chart images are not touched here.
chart_names = list(dataset['Chart_name'])

new_df = (
    pd.DataFrame({
        'Image': chart_names,
        # Pair each row with the prediction at the same position; the slice
        # keeps only the first len(dataset) predictions, as the index loop did.
        'Prediction': predictions[:len(chart_names)],
    })
    # Default keep='first' retains the first occurrence of each pair in row
    # order, matching the "skip if already present" loop it replaces.
    .drop_duplicates()
    .reset_index(drop=True)
)

new_df.to_csv(f'predictions-{filetag}.csv', index=False)