In [None]:
from PIL import Image
import requests
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

from datasets import load_dataset
import os

# Never hard-code or blank the OpenAI key: assigning '' here would override a key
# already exported in the environment and make every API call fail to authenticate.
# Use the environment value when present; otherwise prompt for it without echoing.
if not os.environ.get('OPENAI_API_KEY'):
    import getpass
    os.environ['OPENAI_API_KEY'] = getpass.getpass('OpenAI API key: ')

from openai import OpenAI
client = OpenAI()

device = "cuda" if torch.cuda.is_available() else "cpu"

# Alternative evaluation dataset (eduvedras/Img_Desc), kept for switching back:
#dataset = load_dataset('eduvedras/Img_Desc',split='test',trust_remote_code=True)
dataset = load_dataset('eduvedras/Desc_Questions',split='test',trust_remote_code=True)

In [42]:
def join_predictions(template, vars):
    """Fill a predicted caption template with the predicted variable names.

    The template is first normalised: its first character is upper-cased,
    ``' - '`` becomes ``'-'`` and ``' : '`` becomes ``': '``. The
    comma-separated ``vars`` string is then substituted into the template's
    ``'[ ]'`` placeholder(s):

    * Templates containing ``'decision tree with depth = 2'`` have two
      placeholders. The first and second comma-separated variables (stripped
      of surrounding whitespace) replace them, without brackets. If the
      template lacks a second ``'['`` or ``vars`` has fewer than two entries,
      the generic substitution below is used instead.
    * Explained-variance-ratio and target-distribution bar-chart templates get
      ``vars`` inserted verbatim, without brackets.
    * Every other template gets ``'[' + vars + ']'`` in place of ``'[ ]'``.

    Args:
        template: caption template decoded from the template model.
        vars: comma-separated variable string decoded from the variable model.

    Returns:
        The completed caption string. An empty template is returned unchanged.
    """
    # Slice instead of indexing so an empty prediction does not raise IndexError.
    template = template[:1].upper() + template[1:]
    template = template.replace(' - ', '-').replace(' : ', ': ')

    if 'decision tree with depth = 2' in template:
        # strip() tolerates both "a, b" and "a,b" from the variable model.
        variables = [v.strip() for v in vars.split(',')]
        first = template.find('[')
        second = template.find('[', first + 1) if first != -1 else -1
        if second != -1 and len(variables) >= 2:
            # '+ 3' skips over the three-character '[ ]' placeholder.
            return (template[:first] + variables[0]
                    + template[first + 3:second] + variables[1]
                    + template[second + 3:])
        # Malformed decision-tree prediction: fall through to the generic form.
    elif ('bar chart showing the explained variance ratio' in template
          or 'bar chart showing the distribution of the target variable' in template):
        return template.replace('[ ]', vars)

    return template.replace('[ ]', '[' + vars + ']')
                

In [None]:
from huggingface_hub import InferenceClient
from transformers import AutoModelForCausalLM

# Suffix used to name this run's output files (predictions-<filetag>.txt / .csv).
filetag = "final_model"

# GiT model whose generated text is used as the caption template.
model_templates = AutoModelForCausalLM.from_pretrained(
    "eduvedras/ChartClassificationModel_GiT"
).to(device)

# Pix2Struct model whose generated text is used as the chart's variable names.
model_vars = Pix2StructForConditionalGeneration.from_pretrained(
    "eduvedras/VariableIdentificationModel_Pix2Struct"
).to(device)

In [44]:
# Processors are loaded from the base checkpoints (presumably the ones the
# fine-tuned models above were trained from -- confirm).
checkpoint_vars = "google/pix2struct-textcaps-base"
checkpoint_templates = "microsoft/git-base"

processor_templates = AutoProcessor.from_pretrained(checkpoint_templates)
processor_vars = Pix2StructProcessor.from_pretrained(checkpoint_vars)

predictions = []
references = []

from tqdm.auto import tqdm

# Map each chart name to all of its descriptions, keeping dataset order.
# Reading whole text columns avoids fetching every full row (including its
# chart image) just to get two string fields, and needs only a single pass.
dic = {}
for chart_name, description in zip(dataset['Chart_name'], dataset['Description']):
    dic.setdefault(chart_name, []).append(description)

print(len(dic))

94


In [None]:
import ast


def _parse_questions(raw):
    """Turn one dataset ``Questions`` value into a list of reference questions.

    NOTE(review): assumes the value is either already a list of strings or a
    string holding a Python/JSON list literal of strings -- TODO confirm against
    the dataset. Anything else falls back to the previous positional parsing
    (drop two characters at each end, split on '", "'), which only works for
    double-quoted literals.
    """
    if isinstance(raw, list):
        return list(raw)
    try:
        parsed = ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError):
        parsed = None
    if isinstance(parsed, (list, tuple)) and parsed and all(isinstance(q, str) for q in parsed):
        return list(parsed)
    return raw[2:-2].split('", "')


# Start from empty lists so re-running this cell does not append a second copy
# of every prediction/reference.
predictions = []
references = []

for i in tqdm(range(len(dataset))):
    # Fetch the row once; every dataset[i] access re-reads the whole row,
    # including the chart image.
    row = dataset[i]
    chart = row['Chart']

    # Variable names shown in the chart (Pix2Struct).
    inputs_vars = processor_vars(images=chart, return_tensors="pt").to(device)
    flattened_patches = inputs_vars.flattened_patches
    attention_mask = inputs_vars.attention_mask
    generated_vars_ids = model_vars.generate(flattened_patches=flattened_patches, attention_mask=attention_mask, max_length=167)
    generated_vars = processor_vars.batch_decode(generated_vars_ids, skip_special_tokens=True)[0]

    # Caption template for the chart (GiT).
    inputs_templates = processor_templates(images=chart, return_tensors="pt").to(device)
    pixel_values = inputs_templates.pixel_values
    generated_template_ids = model_templates.generate(pixel_values=pixel_values, max_length=200)
    generated_template = processor_templates.batch_decode(generated_template_ids, skip_special_tokens=True)[0]

    generated_caption = join_predictions(generated_template, generated_vars)

    # Ask the fine-tuned chat model for one true/false sentence about the caption.
    inputs = [
        {"role": "system", "content": "You are a data science teacher creating exam questions."},
        {"role": "user", "content": "Consider the following description of a data chart \"" + generated_caption + "\"."},
        {"role": "assistant", "content": "I understand, the data chart is \"" + generated_caption + "\"."},
        {"role": "user", "content": "Generate a true or false sentence based on this description, in your answer generate only the sentence."},
    ]

    completion = client.chat.completions.create(
        model="ft:gpt-3.5-turbo-0125:personal::95zeB7CK",
        messages=inputs
    )

    predictions.append(completion.choices[0].message.content)
    references.append(_parse_questions(row["Questions"]))
    # Caption-evaluation alternative: score the captions against `dic` descriptions instead.
    #predictions.append(generated_caption)
    #references.append(dic[dataset[i]['Chart_name']])
    

  0%|          | 0/94 [00:00<?, ?it/s]

100%|██████████| 94/94 [06:11<00:00,  3.95s/it]


In [46]:
import evaluate


# Append this run's scores to the per-model log. The context manager closes the
# file even if loading or computing one of the metrics raises.
with open(f"predictions-{filetag}.txt", "a", encoding="utf-8") as file:
    bleu = evaluate.load("bleu")
    results = bleu.compute(predictions=predictions, references=references)
    print(results)
    file.write(f"BLEU: {results}\n")

    meteor = evaluate.load("meteor")
    results = meteor.compute(predictions=predictions, references=references)
    print(results)
    file.write(f"METEOR: {results}\n")

    rouge = evaluate.load("rouge")
    results = rouge.compute(predictions=predictions, references=references)
    print(results)
    file.write(f"ROUGE: {results}\n")

{'bleu': 0.002575442335131855, 'precisions': [0.8908496732026144, 0.802924791086351, 0.7563338301043219, 0.7203525641025641], 'brevity_penalty': 0.003259622816199413, 'length_ratio': 0.14867359828976776, 'translation_length': 1530, 'reference_length': 10291}


[nltk_data] Downloading package wordnet to
[nltk_data]     /home/eduvedras/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /home/eduvedras/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     /home/eduvedras/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


{'meteor': 0.36450483132705586}
{'rouge1': 0.4575711049615443, 'rouge2': 0.4153702872808983, 'rougeL': 0.446040260167763, 'rougeLsum': 0.44776727656717574}


In [47]:
import pandas as pd

# Every chart needs a prediction; the loop that produced `predictions` appends one per row.
if len(predictions) < len(dataset):
    raise ValueError(
        f"expected at least {len(dataset)} predictions, got {len(predictions)}"
    )

# One row per distinct (chart, prediction) pair, keeping first occurrences in
# dataset order. Building the frame once replaces the row-by-row growth and the
# O(n) membership scan per row; reading the 'Chart_name' column avoids fetching
# each full row (with its image) again.
records = [
    {'Image': chart_name, 'Prediction': prediction}
    for chart_name, prediction in zip(dataset['Chart_name'], predictions)
]
new_df = (
    pd.DataFrame(records, columns=['Image', 'Prediction'])
    .drop_duplicates()
    .reset_index(drop=True)
)

new_df.to_csv(f'predictions-{filetag}.csv', index=False)