In [1]:
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import pandas as pd
import torch
import json
import numpy as np
from tqdm import tqdm_notebook as tqdm
import requests
from scipy.stats import kendalltau

In [2]:
FACTUAL_ERROR_TYPES = ['label_error', 'magnitude_error', 'ooc_error', 'trend_error','value_error','nonsense_error']

In [3]:
model_name = "khhuang/chartve"
model = VisionEncoderDecoderModel.from_pretrained(model_name).cuda()
processor = DonutProcessor.from_pretrained(model_name)

In [4]:
with open("data/chocolate.json", "r") as f:
    chocolate = json.load(f)

In [5]:
def format_query(sentence):
    return f"Does the image entails this statement: \"{sentence}\"?"

In [6]:
def proccess_samples(samples): 
    processed = []
    for sample in samples:
        img_id = '-'.join(sample['_id'].split('-')[1:]).replace('pew_','').replace('vistext_','') 
        caption_label = all([ label not in FACTUAL_ERROR_TYPES for sent_labels in sample["labels"] for label in sent_labels]) 
        caption_label = int(caption_label)

        for sentence, sent_labels in zip(sample["sentences"], sample["labels"]):
#             image_path = sample["image_path"]
            if sample['dataset'] == 'pew':
                try:
                    table_type, image_id = img_id.replace('.txt','').split('_col-')
                except:
                    table_type, image_id = img_id.replace('.txt','').split('_col_')
                if table_type.startswith('two'):
                    root_image_path = "/shared/nas/data/m1/khhuang3/mfec/data/Chart-to-text/pew_dataset/dataset/"
                else:
                    root_image_path = "/shared/nas/data/m1/khhuang3/mfec/data/Chart-to-text/pew_dataset/dataset/multiColumn"

                image_path=f"{root_image_path}/imgs/{image_id}.png"

            else:
                image_path = f"/shared/nas/data/m1/khhuang3/mfec/data/vistext/data/images/{img_id}.png"
            
            query = format_query(sentence)
            sent_label = 0 if any([l in FACTUAL_ERROR_TYPES for l in sent_labels]) else 1
            prompt =  "<chartqa>  " + query + " <s_answer>" 
            row = [sample['_id'], image_path, prompt, sent_label, caption_label]
            processed.append(row)
    processed = pd.DataFrame(processed, columns=['_id','image_path','prompt','sent_label','caption_label'])
    return processed

In [7]:
def get_prediction(processed_df):
    binary_positive_probs = []
    with torch.no_grad():
        for row in tqdm(processed_df.itertuples(), total=len(processed_df)):
#             img = Image.open(requests.get(row.image_path, stream=True).raw)
            img = Image.open(row.image_path)
            pixel_values = processor(img.convert("RGB"), random_padding=False, return_tensors="pt").pixel_values
            pixel_values = pixel_values.cuda()
            decoder_input_ids = processor.tokenizer(row.prompt, add_special_tokens=False, return_tensors="pt", max_length=510).input_ids.cuda()#.squeeze(0)
            
                    

            outputs = model(pixel_values,
                                     decoder_input_ids=decoder_input_ids)

            # positive_logit = outputs['logits'].squeeze()[-1,49922]
            # negative_logit = outputs['logits'].squeeze()[-1,2334] 
            
            binary_entail_prob = torch.nn.functional.softmax(outputs['logits'].squeeze()[-1,[2334, 49922]])[1]
            binary_positive_probs.append(binary_entail_prob.item())

    processed_df['binary_entailment_prob'] = binary_positive_probs
    return processed_df

In [8]:
def get_split(sample_id):
    if "bard" in sample_id or "gpt4v" in sample_id:
        return "LVLM"
    elif "deplot" in sample_id:
        return "LLM"
    else:
        return "FT"

In [9]:
processed_chocolate = proccess_samples(chocolate)

In [10]:
processed_chocolate = get_prediction(processed_chocolate)    

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for row in tqdm(processed_df.itertuples(), total=len(processed_df)):


  0%|          | 0/5323 [00:00<?, ?it/s]

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
  binary_entail_prob = torch.nn.functional.softmax(outputs['logits'].squeeze()[-1,[2334, 49922]])[1]


In [11]:
processed_chocolate["split"] = processed_chocolate["_id"].apply(get_split)

In [12]:
id2score = processed_chocolate.groupby('_id').binary_entailment_prob.min().to_dict()

In [13]:
processed_chocolate["chartve_score"] = processed_chocolate['_id'].map(id2score)
final_df = processed_chocolate.drop_duplicates('_id')

In [14]:
for split in ['LVLM','LLM','FT']:
    current_df = final_df.loc[final_df.split == split].dropna()
    tau = kendalltau(current_df.caption_label.values, current_df.chartve_score.values, variant='c').statistic
    print(f"Split {split}| Tau: {tau:.03f}")

Split LVLM| Tau: 0.178
Split LLM| Tau: 0.091
Split FT| Tau: 0.215
