In [7]:
from datasets import load_dataset
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
import torch
from evaluate import load

torch.cuda.empty_cache()

dataset = load_dataset('eduvedras/Img_Desc_Templates',split='test',trust_remote_code=True)


Downloading data: 100%|██████████| 66.7k/66.7k [00:00<00:00, 155kB/s]
Downloading data: 100%|██████████| 2.20k/2.20k [00:00<00:00, 6.53kB/s]
Generating train split: 444 examples [00:00, 39723.36 examples/s]
Generating test split: 15 examples [00:00, 4597.34 examples/s]


In [8]:
device = "cuda" if torch.cuda.is_available() else "cpu"

filetag = "pix2struct-textcaps-base-desc-templates-final"

model = Pix2StructForConditionalGeneration.from_pretrained(f"eduvedras/{filetag}").to(device)
model

model.safetensors: 100%|██████████| 1.13G/1.13G [06:54<00:00, 2.73MB/s]


Pix2StructForConditionalGeneration(
  (encoder): Pix2StructVisionModel(
    (embeddings): Pix2StructVisionEmbeddings(
      (patch_projection): Linear(in_features=768, out_features=768, bias=True)
      (row_embedder): Embedding(4096, 768)
      (column_embedder): Embedding(4096, 768)
      (dropout): Dropout(p=0.06, inplace=False)
    )
    (encoder): Pix2StructVisionEncoder(
      (layer): ModuleList(
        (0-11): 12 x Pix2StructVisionLayer(
          (attention): Pix2StructVisionAttention(
            (query): Linear(in_features=768, out_features=768, bias=False)
            (key): Linear(in_features=768, out_features=768, bias=False)
            (value): Linear(in_features=768, out_features=768, bias=False)
            (output): Linear(in_features=768, out_features=768, bias=False)
          )
          (mlp): Pix2StructVisionMlp(
            (wi_0): Linear(in_features=768, out_features=2048, bias=False)
            (wi_1): Linear(in_features=768, out_features=2048, bias=False)


In [9]:
checkpoint = "google/pix2struct-textcaps-base"

processor = Pix2StructProcessor.from_pretrained(checkpoint)

predictions = []
references = []

dic = {}

from tqdm.auto import tqdm

for i in range(len(dataset)):
    dic[dataset[i]['Chart_name']] = []

for i in range(len(dataset)):
    dic[dataset[i]['Chart_name']].append(dataset[i]['Description'])
    
print(len(dic))

KeyError: 'Chart_name'

In [None]:
for i in tqdm(range(len(dataset))):
    inputs = processor(images=dataset[i]['Chart'], return_tensors="pt").to(device)
    flattened_patches = inputs.flattened_patches
    attention_mask = inputs.attention_mask
    generated_ids = model.generate(flattened_patches=flattened_patches, attention_mask=attention_mask, max_length=167)
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]  
    predictions.append(generated_caption)
    references.append(dic[dataset[i]['Chart_name']])

100%|██████████| 8/8 [00:06<00:00,  1.18it/s]


In [None]:
import evaluate


file = open(f"predictions-{filetag}.txt", "a")

bleu = evaluate.load("bleu")
results = bleu.compute(predictions=predictions, references=references)
print(results)
file.write(f"BLEU: {results}\n")

meteor = evaluate.load("meteor")
results = meteor.compute(predictions=predictions, references=references)
print(results)
file.write(f"METEOR: {results}\n")

rouge = evaluate.load("rouge")
results = rouge.compute(predictions=predictions, references=references)
print(results)
file.write(f"ROUGE: {results}\n")
file.close()

{'bleu': 0.8791842281420499, 'precisions': [0.8863636363636364, 0.8888888888888888, 0.8666666666666667, 0.875], 'brevity_penalty': 1.0, 'length_ratio': 1.0, 'translation_length': 44, 'reference_length': 44}


[nltk_data] Downloading package wordnet to
[nltk_data]     /home/eduvedras/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /home/eduvedras/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     /home/eduvedras/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


{'meteor': 0.6097657954935154}
{'rouge1': 0.7638888888888888, 'rouge2': 0.5416666666666666, 'rougeL': 0.7597222222222223, 'rougeLsum': 0.7555555555555555}


In [None]:
import pandas as pd

new_df = pd.DataFrame(columns=['Image','Prediction'])

i=0
while i < len(dataset):
    if ((new_df['Image'] == dataset[i]['Chart_name']) & (new_df['Prediction'] == predictions[i])).any():
        i += 1
        continue
    else:
        new_df.loc[len(new_df)] = {'Image': dataset[i]['Chart_name'], 'Prediction': predictions[i]}
        i += 1
    
new_df.to_csv(f'predictions-{filetag}.csv', index=False)