##Baselines:

Unimodal:
- Text: Questions --> Qwen LLM --> Answer
- Chart to Table: Image --> Table --> Qwen LLM --> Answer

Simple:
- Image --> CLIP --> projection --> + question --> GPT-2/Qwen3B LLM --> Answer

SOTA:
- UniChart
- AskChart

In [None]:
import torch
from tqdm import tqdm
import re
import pandas as pd

from huggingface_hub import login
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
# Load data
import polars as pl
import pandas as pd

from huggingface_hub import login
login(token="hf token")

from datasets import load_dataset

ds = load_dataset("HuggingFaceM4/ChartQA")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/852 [00:00<?, ?B/s]

data/train-00000-of-00003-49492f364babfa(…):   0%|          | 0.00/219M [00:00<?, ?B/s]

data/train-00001-of-00003-7302bae5e425bb(…):   0%|          | 0.00/311M [00:00<?, ?B/s]

data/train-00002-of-00003-194c9400785577(…):   0%|          | 0.00/315M [00:00<?, ?B/s]

data/val-00000-of-00001-0f11003c77497969(…):   0%|          | 0.00/50.2M [00:00<?, ?B/s]

data/test-00000-of-00001-e2cd0b7a0f9eb20(…):   0%|          | 0.00/68.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/28299 [00:00<?, ? examples/s]

Generating val split:   0%|          | 0/1920 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2500 [00:00<?, ? examples/s]

## Unimodal Baseline: Text


In [None]:
model_name = "Qwen/Qwen3-4B-Instruct-2507"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)

model.eval()

Loading weights:   0%|          | 0/398 [00:00<?, ?it/s]

Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 2560)
    (layers): ModuleList(
      (0-35): 36 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=2560, out_features=4096, bias=False)
          (k_proj): Linear(in_features=2560, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2560, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=2560, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=2560, out_features=9728, bias=False)
          (up_proj): Linear(in_features=2560, out_features=9728, bias=False)
          (down_proj): Linear(in_features=9728, out_features=2560, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen3RMSNorm((2560,), eps=1e-06)
        (post_attention_layer

In [None]:
def clean_prediction(text):
    """
    Extract first short answer from model output.
    """
    text = text.strip()
    text = text.split("\n")[0]
    text = text.split("Answer:")[-1]
    return text.strip().strip(".")


def normalize_text(text):
    return text.lower().strip()


def extract_number(text):
    """
    Extract first numeric value from string if present.
    """
    match = re.search(r"-?\d+\.?\d*", text.replace(",", ""))
    if match:
        return float(match.group())
    return None


def relaxed_numeric_match(pred, gt, tol=0.05):
    """
    ±5% tolerance numeric accuracy
    """
    pred_num = extract_number(pred)
    gt_num = extract_number(gt)

    if pred_num is None or gt_num is None:
        return False

    return abs(pred_num - gt_num) / (abs(gt_num) + 1e-8) <= tol


In [None]:
def generate_answer(question):
    prompt = (
        "Answer the following chart question with a single short answer.\n\n"
        f"Question: {question}\n"
        "Answer:"
    )

    messages = [
        {"role": "user", "content": prompt},
    ]

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=40,
            do_sample=False,      # deterministic baseline
            temperature=0.0
        )

    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True
    )

    return clean_prediction(response)

In [None]:
import torch

print("CUDA available:", torch.cuda.is_available())
print("Device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
print("Model device:", next(model.parameters()).device)


split = ds["val"]
batch_size = 8   # T4 sweet spot (try 8 or 16)

exact_correct = 0
relaxed_correct = 0
total = 0

results = []

def build_prompt(question):
    return (
        "Answer the following chart question with a single short answer.\n\n"
        f"Question: {question}\n"
        "Answer:"
    )

for i in tqdm(range(0, len(split), batch_size)):

    batch = split[i:i+batch_size]

    questions = batch["query"]
    gt_lists = batch["label"]

    prompts = [build_prompt(q) for q in questions]

    messages_batch = [
        [{"role": "user", "content": p}]
        for p in prompts
    ]

    inputs = tokenizer.apply_chat_template(
        messages_batch,
        add_generation_prompt=True,
        tokenize=True,
        padding=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=40,
            do_sample=False
        )

    generated_tokens = outputs[:, inputs["input_ids"].shape[-1]:]
    preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

    for question, pred, gt_list in zip(questions, preds, gt_lists):

        pred = clean_prediction(pred)
        pred_norm = normalize_text(pred)

        exact = False
        relaxed = False

        for gt in gt_list:
            gt_norm = normalize_text(str(gt))

            if pred_norm == gt_norm:
                exact = True
                relaxed = True
                break

            if relaxed_numeric_match(pred, str(gt)):
                relaxed = True

        if exact:
            exact_correct += 1
        if relaxed:
            relaxed_correct += 1

        total += 1

        results.append({
            "question": question,
            "ground_truth": gt_list,
            "prediction": pred,
            "exact_match": exact,
            "relaxed_match": relaxed
        })

# Final metrics
exact_acc = exact_correct / total
relaxed_acc = relaxed_correct / total

print("\n==============================")
print("TEXT-ONLY BASELINE RESULTS")
print("==============================")
print(f"Total samples: {total}")
print(f"Exact Accuracy: {exact_acc:.4f}")
print(f"Relaxed Accuracy (±5%): {relaxed_acc:.4f}")


for r in results[:5]:
    print(r["question"])
    print("GT:", r["ground_truth"])
    print("Pred:", r["prediction"])
    print()
df = pd.DataFrame(results)
df.to_csv("chartqa_text_only_results_full.csv", index=False)

CUDA available: True
Device: NVIDIA RTX PRO 6000 Blackwell Server Edition
Model device: cuda:0


  4%|▍         | 10/240 [00:06<02:28,  1.55it/s]


KeyboardInterrupt: 

## Unimodal Baseline: Chart to Table

In [None]:
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Use a Donut model finetuned on structured documents (CORD / tables)
donut_model_name = "naver-clova-ix/donut-base-finetuned-cord-v2"

processor = DonutProcessor.from_pretrained(donut_model_name)
ocr_model = VisionEncoderDecoderModel.from_pretrained(donut_model_name).to(device)


preprocessor_config.json:   0%|          | 0.00/362 [00:00<?, ?B/s]

The image processor of type `DonutImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. 


config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json:   0%|          | 0.00/536 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/335 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/806M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/484 [00:00<?, ?it/s]



In [None]:
# from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# ocr_model_name = "microsoft/trocr-base-printed"
# processor = TrOCRProcessor.from_pretrained(ocr_model_name)
# ocr_model = VisionEncoderDecoderModel.from_pretrained(ocr_model_name).to("cuda")

# from paddleocr import PaddleOCR
# ocr_model = PaddleOCR(use_angle_cls=True, lang='en')

[33mChecking connectivity to the model hosters, this may take a while. To bypass this check, set `PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK` to `True`.[0m


KeyboardInterrupt: 

In [None]:
qwen_model_name = "Qwen/Qwen3-4B-Instruct-2507"
tokenizer = AutoTokenizer.from_pretrained(qwen_model_name)
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
qwen_model = AutoModelForCausalLM.from_pretrained(
    qwen_model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)

model.safetensors:   0%|          | 0.00/806M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/727 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Downloading (incomplete total...): 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

Loading weights:   0%|          | 0/398 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/238 [00:00<?, ?B/s]

In [None]:
from PIL import Image
def ocr_extract_text(image_or_path):
    """Extract text from PIL Image or path using TrOCR"""
    if isinstance(image_or_path, str):
        image = Image.open(image_or_path).convert("RGB")
    else:
        image = image_or_path.convert("RGB")

    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad():
        generated_ids = ocr_model.generate(pixel_values)
    text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return text.strip()

def ocr_to_table_text(raw_text):
    """Convert OCR text to Markdown-like table"""
    lines = raw_text.split("\n")
    lines = [l.strip() for l in lines if l.strip()]
    # replace multiple spaces with pipe separators
    lines = [" | ".join(l.split()) for l in lines]
    table_text = "\n".join(lines)
    return table_text



def generate_answer_from_table(table_text, question):
    """Pass table text + question to Qwen"""
    prompt = f"Given the following table extracted from a chart:\n\n{table_text}, Question: {question}, \n\nAnswer:"
    messages = [{"role": "user", "content": prompt}]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        padding=True,
        return_dict=True,
        return_tensors="pt",
    ).to(qwen_model.device)
    with torch.no_grad():
        outputs = qwen_model.generate(**inputs, max_new_tokens=60, do_sample=False)
    answer = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
    return answer.strip()

In [None]:
import torch
import pandas as pd
from tqdm import tqdm
from PIL import Image

# ------------------------------
# 1. Setup
# ------------------------------
print("CUDA available:", torch.cuda.is_available())
print("Device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
print("Model device:", next(qwen_model.parameters()).device)

# Use first 50 validation examples
split = ds["val"]#.select(range(25))
split_list = [dict(ex) for ex in split]
batch_size = 4  # safe for Qwen3-4B on your GPU

exact_correct = 0
relaxed_correct = 0
total = 0
results = []

# ------------------------------
# 2. Helper functions
# ------------------------------
def build_prompt(table_text, question):
    return (
        f"Given the following table extracted from a chart:\n\n{table_text}\n\n"
        f"Answer the following question:\n{question}\nAnswer:"
    )

# ------------------------------
# 3. Batched loop
# ------------------------------
for i in tqdm(range(0, len(split_list), batch_size)):
    batch = split_list[i:i+batch_size]

    # ------------------------------
    # 3a. OCR each chart → table
    # ------------------------------
    raw_texts = []
    table_texts = []
    for ex in batch:
        img = ex["image"]  # PIL Image
        raw_text = ocr_extract_text(img)
        table_text = ocr_to_table_text(raw_text)
        raw_texts.append(raw_text)
        table_texts.append(table_text)

    # ------------------------------
    # 3b. Build batch prompts
    # ------------------------------
    messages_batch = []
    for ex, table_text in zip(batch, table_texts):
        prompt = build_prompt(table_text, ex["query"])
        messages_batch.append([{"role": "user", "content": prompt}])

    # ------------------------------
    # 3c. Tokenize batch
    # ------------------------------
    inputs = tokenizer.apply_chat_template(
        messages_batch,
        add_generation_prompt=True,
        tokenize=True,
        padding=True,
        return_dict=True,
        return_tensors="pt",
    ).to(qwen_model.device)

    # ------------------------------
    # 3d. Generate batch answers
    # ------------------------------
    with torch.no_grad():
        outputs = qwen_model.generate(
            **inputs,
            max_new_tokens=60,
            do_sample=False
        )

    # ------------------------------
    # 3e. Decode batch
    # ------------------------------
    preds = tokenizer.batch_decode(
        outputs[:, inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True
    )

    # ------------------------------
    # 3f. Evaluate batch
    # ------------------------------
    for ex, raw_text, table_text, pred in zip(batch, raw_texts, table_texts, preds):
        question = ex["query"]
        gt_list = ex["label"]

        pred = clean_prediction(pred)
        pred_norm = normalize_text(pred)

        exact = False
        relaxed = False

        for gt in gt_list:
            gt_norm = normalize_text(str(gt))
            if pred_norm == gt_norm:
                exact = True
                relaxed = True
                break
            if relaxed_numeric_match(pred_norm, gt_norm):
                relaxed = True

        if exact:
            exact_correct += 1
        if relaxed:
            relaxed_correct += 1
        total += 1

        results.append({
            "question": question,
            "ground_truth": gt_list,
            "ocr_text": raw_text,
            "table_text": table_text,
            "prediction": pred,
            "exact_match": exact,
            "relaxed_match": relaxed
        })

# ------------------------------
# 4. Final metrics
# ------------------------------
exact_acc = exact_correct / total
relaxed_acc = relaxed_correct / total

print("\n==============================")
print("CHART-TO-TABLE BASELINE RESULTS (50 samples)")
print("==============================")
print(f"Total samples: {total}")
print(f"Exact Accuracy: {exact_acc:.4f}")
print(f"Relaxed Accuracy (±5%): {relaxed_acc:.4f}")

# ------------------------------
# 5. Preview first 5 results
# ------------------------------
for r in results[:5]:
    print(r["question"])
    print("GT:", r["ground_truth"])
    print("Pred:", r["prediction"])
    print("OCR Text (truncated):", r["ocr_text"][:100])
    print("Table Text (truncated):", r["table_text"][:100])
    print()

# ------------------------------
# 6. Save results
# ------------------------------
df = pd.DataFrame(results)
df.to_csv("chartqa_chart_to_table_results_50.csv", index=False)


CUDA available: True
Device: NVIDIA L4
Model device: cuda:0


100%|██████████| 480/480 [39:47<00:00,  4.97s/it]


CHART-TO-TABLE BASELINE RESULTS (50 samples)
Total samples: 1920
Exact Accuracy: 0.0000
Relaxed Accuracy (±5%): 0.0156
What's the color of graph with 56 as the highest value?
GT: ['Blue']
Pred: The question asks: *"What's the color of the graph with 56 as the highest value?"*
OCR Text (truncated): 34 26 29 28 Germanyityityis U.S. 56<sep/> 5.6 천 천
Table Text (truncated): 34 | 26 | 29 | 28 | Germanyityityis | U.S. | 56<sep/> | 5.6 | 천 | 천

In which year the difference between blue and green graph 1?
GT: ['2018']
Pred: The question asks: *"In which year the difference between blue and green graph 1?"*
OCR Text (truncated): 34 26 29 28 Germanyityityis U.S. 56<sep/> 5.6 천 천
Table Text (truncated): 34 | 26 | 29 | 28 | Germanyityityis | U.S. | 56<sep/> | 5.6 | 천 | 천

What does the blue line represent?
GT: ['Not too much/not at all']
Pred: The provided table contains only the repeated label "총" (which means "total" in Korean) and does not include any data points, colors, or descriptions of li




## SOTAs

In [None]:

# ============================
# 2️⃣ Imports
# ============================
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch
from datasets import load_dataset

# ============================
# 3️⃣ Load pretrained UniChart ChartQA
# ============================
model_name = "ahmed-masry/unichart-chartqa-960"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = VisionEncoderDecoderModel.from_pretrained(model_name).to(device)
processor = DonutProcessor.from_pretrained(model_name)

# ============================
# 4️⃣ Load a small subset of ChartQA (first 25 samples)
# ============================
dataset = load_dataset("HuggingFaceM4/ChartQA")["val"]#.select(range(25))
dataset = [dict(ex) for ex in dataset]


# ============================
# 5️⃣ Helper function: run inference
# ============================
def run_chartqa_inference(image_or_path, question):
    input_prompt = f"<chartqa> {question} <s_answer>"

    # Accept either a path or a PIL.Image
    if isinstance(image_or_path, str):
        image = Image.open(image_or_path).convert("RGB")
    else:
        image = image_or_path.convert("RGB")

    decoder_input_ids = processor.tokenizer(
        input_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids
    pixel_values = processor(image, return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=4,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    sequence = processor.batch_decode(outputs.sequences)[0]
    # clean up the token artifacts
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    if "<s_answer>" in sequence:
        sequence = sequence.split("<s_answer>")[1].strip()
    return sequence


# ============================
# 6️⃣ Run baseline on first 25 samples
# ============================
exact_correct = 0
relaxed_correct = 0

for i, example in enumerate(dataset):
    image_path = example["image"]  # adjust if using local files
    question = example["query"]
    gt_answer = example["label"][0]  # assume single-answer GT

    try:
        pred_answer = run_chartqa_inference(image_path, question)
    except Exception as e:
        pred_answer = ""
        print(f"[WARN] Error on sample {i}: {e}")

    # exact match
    if pred_answer.lower() == gt_answer.lower():
        exact_correct += 1

    # relaxed match (numerical ±5% if float, else substring match)
    try:
        gt_val = float(gt_answer)
        pred_val = float(pred_answer)
        if abs(pred_val - gt_val) / max(gt_val, 1e-6) <= 0.05:
            relaxed_correct += 1
    except:
        # fallback: substring match
        if gt_answer.lower() in pred_answer.lower():
            relaxed_correct += 1

    print(f"Sample {i+1}: Q: {question} | GT: {gt_answer} | Pred: {pred_answer}")

total = len(dataset)
print("==============================")
print("CHART-TO-CHARTQA BASELINE RESULTS (first 25 samples)")
print("==============================")
print(f"Total samples: {total}")
print(f"Exact Accuracy: {exact_correct/total:.4f}")
print(f"Relaxed Accuracy: {relaxed_correct/total:.4f}")


Loading weights:   0%|          | 0/484 [00:00<?, ?it/s]



Sample 1: Q: What's the color of graph with 56 as the highest value? | GT: Blue | Pred: Red
Sample 2: Q: In which year the difference between blue and green graph 1? | GT: 2018 | Pred: 2018
Sample 3: Q: What does the blue line represent? | GT: Not too much/not at all | Pred: Great deal/fair amount
Sample 4: Q: What is the max value of blue line? | GT: 0.72 | Pred: 71
Sample 5: Q: What's the percentage of respondents who say Job is a top priority for the president and Congress in 2016? | GT: 68 | Pred: 68
Sample 6: Q: Which line has the lowest value of 71%? | GT: Economy | Pred: jobs
Sample 7: Q: What is the unfavourable value in 2014? | GT: 64 | Pred: 64
Sample 8: Q: What is the median value of favourable line in the graph? | GT: 40 | Pred: 40
Sample 9: Q: Which answer response has the highest value on this graph? | GT: Disapprove | Pred: Disapprove
Sample 10: Q: How many data points on the disapprove line are above 50? | GT: 2 | Pred: 2
Sample 11: Q: Which indicator remains all time l

## Simple Baseline: CLIP --> LLM

In [None]:
import torch
from torch import nn
from PIL import Image
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm

# 1️⃣ Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2️⃣ Load CLIP
clip_model_name = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(clip_model_name).to(device)
clip_processor = CLIPProcessor.from_pretrained(clip_model_name)

# 3️⃣ Load Qwen (frozen)
qwen_model_name = "Qwen/Qwen3-4B-Instruct-2507"
qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name)
qwen_model = AutoModelForCausalLM.from_pretrained(qwen_model_name, device_map="auto").to(device)
qwen_model.eval()
for param in qwen_model.parameters():
    param.requires_grad = False

# 4️⃣ Projection layer: CLIP embedding -> Qwen embedding
clip_embed_dim = clip_model.config.projection_dim
qwen_embed_dim = qwen_model.config.hidden_size
image_proj = nn.Linear(clip_embed_dim, qwen_embed_dim, device=device, dtype=qwen_model.dtype)
optimizer = torch.optim.Adam(image_proj.parameters(), lr=1e-4)

# 5️⃣ Load ChartQA val subset (first 25 for example)
dataset = load_dataset("HuggingFaceM4/ChartQA")["val"]#.select(range(25))
dataset = [dict(ex) for ex in dataset]

dataset_train = load_dataset("HuggingFaceM4/ChartQA")["train"]#.select(range(3000))
#dataset_train = [dict(ex) for ex in dataset_train]

# 6️⃣ Training: regress projected CLIP embedding to Qwen embedding of question
loss_fn = nn.MSELoss()
num_epochs = 5

for epoch in range(num_epochs):
    total_loss = 0
    for ex in tqdm(dataset_train, desc=f"Epoch {epoch+1}/{num_epochs}"):
        image_obj = ex["image"]
        question = ex["query"]

        # ---- CLIP image embedding ----
        if isinstance(image_obj, str):
            image = Image.open(image_obj).convert("RGB")
        else:
            image = image_obj.convert("RGB")

        clip_inputs = clip_processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            clip_output = clip_model.get_image_features(**clip_inputs)
            if isinstance(clip_output, torch.Tensor):
                clip_emb = clip_output
            elif hasattr(clip_output, "pooler_output"):
                clip_emb = clip_output.pooler_output
            elif hasattr(clip_output, "image_embeds"):
                clip_emb = clip_output.image_embeds
            else:
                raise ValueError("Cannot extract tensor from CLIP output")
        clip_emb = clip_emb.to(dtype=image_proj.weight.dtype, device=device)
        proj_emb = image_proj(clip_emb)  # [1, qwen_embed_dim]

        # ---- Qwen embedding of question ----
        token_ids = qwen_tokenizer(question, return_tensors="pt").input_ids.to(device)
        with torch.no_grad():
            qwen_embeddings = qwen_model.get_input_embeddings()(token_ids).mean(dim=1)  # average over tokens

        # ---- Compute loss & update projection ----
        loss = loss_fn(proj_emb, qwen_embeddings)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1} avg loss: {total_loss/len(dataset):.4f}")

# 7️⃣ Inference: use projected image embedding as prefix for question
def run_clip_qwen(image_obj, question):
    if isinstance(image_obj, str):
        image = Image.open(image_obj).convert("RGB")
    else:
        image = image_obj.convert("RGB")

    clip_inputs = clip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        clip_output = clip_model.get_image_features(**clip_inputs)
        if isinstance(clip_output, torch.Tensor):
            clip_emb = clip_output
        elif hasattr(clip_output, "pooler_output"):
            clip_emb = clip_output.pooler_output
        elif hasattr(clip_output, "image_embeds"):
            clip_emb = clip_output.image_embeds
        clip_emb = clip_emb.to(dtype=image_proj.weight.dtype, device=device)
        proj_emb = image_proj(clip_emb).unsqueeze(1)  # [1,1,qwen_embed_dim]

    # Text embeddings
    token_ids = qwen_tokenizer(f"Question: {question}, You are a chart question-answering assistant. Use the chart to respond to the question with the final answer only, no explanations.", return_tensors="pt").input_ids.to(device)
    text_emb = qwen_model.get_input_embeddings()(token_ids).to(dtype=image_proj.weight.dtype, device=device)

    # Concatenate image + text
    inputs_embeds = torch.cat([proj_emb, text_emb], dim=1)

    # Generate answer
    with torch.no_grad():
        outputs = qwen_model.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=32,
            num_beams=4,
            eos_token_id=qwen_tokenizer.eos_token_id,
        )
    answer = qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer

# 8️⃣ Test baseline
exact_correct = 0
relaxed_correct = 0

for i, ex in enumerate(dataset):
    image_obj = ex["image"]
    question = ex["query"]
    gt_answer = ex["label"][0]

    try:
        pred_answer = run_clip_qwen(image_obj, question)
    except Exception as e:
        pred_answer = ""
        print(f"[WARN] Error on sample {i}: {e}")

    # Exact match
    if pred_answer.lower() == gt_answer.lower():
        exact_correct += 1

    # Relaxed match: numeric closeness (5%) or substring
    try:
        gt_val = float(gt_answer)
        pred_val = float(pred_answer)
        if abs(pred_val - gt_val) / max(gt_val, 1e-6) <= 0.05:
            relaxed_correct += 1
    except:
        if gt_answer.lower() in pred_answer.lower():
            relaxed_correct += 1

    print(f"Sample {i+1}: Q: {question} | GT: {gt_answer} | Pred: {pred_answer}")

total = len(dataset)
print("\n==============================")
print("CLIP → Qwen PROJECTION RESULTS")
print("==============================")
print(f"Total samples: {total}")
print(f"Exact Accuracy: {exact_correct/total:.4f}")
print(f"Relaxed Accuracy: {relaxed_correct/total:.4f}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/398 [00:00<?, ?it/s]

CLIPModel LOAD REPORT from: openai/clip-vit-base-patch32
Key                                  | Status     |  | 
-------------------------------------+------------+--+-
vision_model.embeddings.position_ids | UNEXPECTED |  | 
text_model.embeddings.position_ids   | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

The image processor of type `CLIPImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. 


tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/727 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Downloading (incomplete total...): 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

Loading weights:   0%|          | 0/398 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/238 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/852 [00:00<?, ?B/s]

data/train-00000-of-00003-49492f364babfa(…):   0%|          | 0.00/219M [00:00<?, ?B/s]

data/train-00001-of-00003-7302bae5e425bb(…):   0%|          | 0.00/311M [00:00<?, ?B/s]

data/train-00002-of-00003-194c9400785577(…):   0%|          | 0.00/315M [00:00<?, ?B/s]

data/val-00000-of-00001-0f11003c77497969(…):   0%|          | 0.00/50.2M [00:00<?, ?B/s]

data/test-00000-of-00001-e2cd0b7a0f9eb20(…):   0%|          | 0.00/68.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/28299 [00:00<?, ? examples/s]

Generating val split:   0%|          | 0/1920 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2500 [00:00<?, ? examples/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch 1/5:  26%|██▌       | 7237/28299 [02:46<08:33, 41.04it/s][A
Epoch 1/5:  26%|██▌       | 7242/28299 [02:46<08:30, 41.27it/s][A
Epoch 1/5:  26%|██▌       | 7247/28299 [02:46<08:32, 41.09it/s][A
Epoch 1/5:  26%|██▌       | 7252/28299 [02:47<08:36, 40.77it/s][A
Epoch 1/5:  26%|██▌       | 7257/28299 [02:47<08:31, 41.10it/s][A
Epoch 1/5:  26%|██▌       | 7262/28299 [02:47<08:30, 41.19it/s][A
Epoch 1/5:  26%|██▌       | 7267/28299 [02:47<08:40, 40.39it/s][A
Epoch 1/5:  26%|██▌       | 7272/28299 [02:47<08:50, 39.64it/s][A
Epoch 1/5:  26%|██▌       | 7276/28299 [02:47<08:56, 39.21it/s][A
Epoch 1/5:  26%|██▌       | 7280/28299 [02:47<09:06, 38.46it/s][A
Epoch 1/5:  26%|██▌       | 7284/28299 [02:47<09:09, 38.27it/s][A
Epoch 1/5:  26%|██▌       | 7289/28299 [02:48<08:56, 39.15it/s][A
Epoch 1/5:  26%|██▌       | 7294/28299 [02:48<08:49, 39.69it/s][A
Epoch 1/5:  26%|██▌       | 7299/28299 [02:48<08:42, 40.16it/s]

Epoch 1 avg loss: 0.0516


Epoch 2/5: 100%|██████████| 28299/28299 [11:44<00:00, 40.20it/s]


Epoch 2 avg loss: 0.0412


Epoch 3/5: 100%|██████████| 28299/28299 [11:41<00:00, 40.33it/s]


Epoch 3 avg loss: 0.0396


Epoch 4/5: 100%|██████████| 28299/28299 [11:52<00:00, 39.69it/s]


Epoch 4 avg loss: 0.0389


Epoch 5/5: 100%|██████████| 28299/28299 [11:46<00:00, 40.07it/s]


Epoch 5 avg loss: 0.0386
Sample 1: Q: What's the color of graph with 56 as the highest value? | GT: Blue | Pred:  Do not make up data. Do not say "Insufficient data". Do not say "I can't answer that". Do not say "Not enough information".
Sample 2: Q: In which year the difference between blue and green graph 1? | GT: 2018 | Pred:  Do not make up data. Do not say "Based on the chart...". Do not say "The difference is...". Do not say "In the year
Sample 3: Q: What does the blue line represent? | GT: Not too much/not at all | Pred:  Do not make up information. Do not say "Based on the chart...". Just answer the question. Question: What does the blue line represent? Answer:
Sample 4: Q: What is the max value of blue line? | GT: 0.72 | Pred:  Do not make up data. Do not say "Based on the chart...". Table cells in the same row are separated by '|', and each row is in
Sample 5: Q: What's the percentage of respondents who say Job is a top priority for the president and Congress in 2016? | GT: 6

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPProcessor, CLIPModel
device = "cuda" if torch.cuda.is_available() else "cpu"

# CLIP model
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

Loading weights:   0%|          | 0/398 [00:00<?, ?it/s]

CLIPModel LOAD REPORT from: openai/clip-vit-base-patch32
Key                                  | Status     |  | 
-------------------------------------+------------+--+-
vision_model.embeddings.position_ids | UNEXPECTED |  | 
text_model.embeddings.position_ids   | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


In [None]:
# Qwen model
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
qwen_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-4B-Instruct-2507").to(device)

# Qwen
qwen_model.eval().to(device)
for p in qwen_model.parameters():
    p.requires_grad = False  # freeze Qwen


Loading weights:   0%|          | 0/398 [00:00<?, ?it/s]

In [None]:
import torch.nn as nn
clip_dim = 768      # CLIP-ViT-B/32 image embedding size
qwen_dim = 2560     # Qwen3B/4B input embedding size
projection = nn.Linear(clip_dim, qwen_dim).to(device)

In [None]:
from torch import nn, optim

optimizer = optim.Adam(projection.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)


In [None]:
def extract_clip_embedding(img):
    """Get CLIP embedding for a single PIL image"""
    inputs = clip_processor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        # This returns a BaseModelOutputWithPooling sometimes
        output = clip_model.get_image_features(**inputs)

        # Convert to a tensor if it's not already
        if hasattr(output, "last_hidden_state"):  # old CLIP versions
            img_emb = output.last_hidden_state[:, 0, :]
        elif isinstance(output, (tuple, list)):
            img_emb = output[0]
        else:
            img_emb = output  # should already be a tensor

        # Make sure it's float32 on the right device
        img_emb = img_emb.float().to(device)

    return img_emb  # [1, clip_dim]



def generate_answer_from_clip(img_emb, question):
    """Project image embedding, prepend to question, and generate answer"""
    # Project image embedding
    projected_emb = projection(img_emb)  # [1, qwen_hidden_dim]

    # Build prompt with question
    prompt = f"Question: {question}\nAnswer:"
    messages = [{"role": "user", "content": prompt}]

    # Tokenize prompt
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        padding=True,
        return_dict=True,
        return_tensors="pt"
    ).to(device)

    # Prepend projected image embedding to input embeddings
    qwen_inputs = qwen_model.get_input_embeddings()(inputs["input_ids"])
    projected_emb = projected_emb.to(qwen_inputs.dtype)
    combined_emb = torch.cat([projected_emb.unsqueeze(1), qwen_inputs], dim=1)

    # Generate
    with torch.no_grad():
        outputs = qwen_model.generate(
            inputs_embeds=combined_emb,
            max_new_tokens=60,
            do_sample=False
        )

    answer = tokenizer.decode(outputs[0][combined_emb.shape[1]:], skip_special_tokens=True)
    return answer.strip()

def normalize_text(text):
    return text.lower().strip()

def relaxed_numeric_match(pred, gt):
    try:
        pred_val = float(pred)
        gt_val = float(gt)
        return abs(pred_val - gt_val) / max(1.0, abs(gt_val)) <= 0.05
    except:
        return False

In [None]:
split_list = [dict(ex) for ex in ds["val"].select(range(50))]
results = []
exact_correct = 0
relaxed_correct = 0
total = 0

for ex in tqdm(split_list):
    img = ex["image"]
    question = ex["query"]
    gt_list = ex["label"]

    pred = generate_answer_from_clip(extract_clip_embedding(img), question)
    pred_norm = normalize_text(pred)

    exact = False
    relaxed = False
    for gt in gt_list:
        gt_norm = normalize_text(str(gt))
        if pred_norm == gt_norm:
            exact = True
            relaxed = True
            break
        if relaxed_numeric_match(pred_norm, gt_norm):
            relaxed = True

    if exact:
        exact_correct += 1
    if relaxed:
        relaxed_correct += 1
    total += 1

    results.append({
        "question": question,
        "ground_truth": gt_list,
        "prediction": pred,
        "exact_match": exact,
        "relaxed_match": relaxed
    })

# ------------------------------
# 5. Metrics
# ------------------------------
exact_acc = exact_correct / total
relaxed_acc = relaxed_correct / total
print("\n==============================")
print("SIMPLE CLIP-TO-QWEN BASELINE RESULTS")
print("==============================")
print(f"Total samples: {total}")
print(f"Exact Accuracy: {exact_acc:.4f}")
print(f"Relaxed Accuracy (±5%): {relaxed_acc:.4f}")

df = pd.DataFrame(results)
df.to_csv("chartqa_clip_qwen_results_50.csv", index=False)

 12%|█▏        | 6/50 [00:05<00:39,  1.12it/s]


KeyboardInterrupt: 

In [None]:
import torch.nn as nn
clip_dim = 512      # CLIP-ViT-B/32 image embedding size
qwen_dim = 2560     # Qwen3B/4B input embedding size
proj_tokens = 4  # number of "virtual tokens" projected from CLIP
projection = nn.Sequential(
    nn.Linear(clip_dim, 1024),
    nn.ReLU(),
    nn.Linear(1024, qwen_dim)
).to(device)

from torch import nn, optim

optimizer = optim.Adam(projection.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

def extract_clip_embedding(img):
    """Get CLIP embedding for a single PIL image"""
    inputs = clip_processor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        output = clip_model.get_image_features(**inputs)
        # Handle BaseModelOutputWithPooling
        if hasattr(output, "pooler_output"):
            img_emb = output.pooler_output
        elif hasattr(output, "last_hidden_state"):
            img_emb = output.last_hidden_state[:, 0, :]
        elif isinstance(output, (tuple, list)):
            img_emb = output[0]
        else:
            img_emb = output  # already tensor

        img_emb = img_emb.float().to(device)

    return img_emb  # [1, clip_dim]



def prepare_inputs_for_training(img_emb, question):
    projected_emb = projection(img_emb)  # [1, qwen_dim]

    prompt = f"Question: {question}\nAnswer:"
    messages = [{"role": "user", "content": prompt}]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        padding=True,
        return_dict=True,
        return_tensors="pt"
    ).to(device)

    qwen_inputs = qwen_model.get_input_embeddings()(inputs["input_ids"])
    projected_emb = projected_emb.to(qwen_inputs.dtype)
    combined_emb = torch.cat([projected_emb.unsqueeze(1), qwen_inputs], dim=1)

    # Prepend padding token for labels to match new seq_len
    pad_label = torch.full((inputs["input_ids"].shape[0], 1),
                           -100,
                           #tokenizer.pad_token_id,
                           device=inputs["input_ids"].device,
                           dtype=inputs["input_ids"].dtype)
    labels = torch.cat([pad_label, inputs["input_ids"]], dim=1)

    return combined_emb, labels



def generate_answer_from_clip(img_emb, question, max_new_tokens=60):
    """Inference: generate answer from image + question"""
    projected_emb = projection(img_emb)
    prompt = f"Question: {question}\nAnswer:"
    messages = [{"role": "user", "content": prompt}]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        padding=True,
        return_dict=True,
        return_tensors="pt"
    ).to(device)

    qwen_inputs = qwen_model.get_input_embeddings()(inputs["input_ids"])
    projected_emb = projected_emb.to(qwen_inputs.dtype)
    combined_emb = torch.cat([projected_emb.unsqueeze(1), qwen_inputs], dim=1)

    with torch.no_grad():
        outputs = qwen_model.generate(
            inputs_embeds=combined_emb,
            max_new_tokens=max_new_tokens,
            num_beams=3,      # beam search
            do_sample=True,
        )

    answer = tokenizer.decode(outputs[0][combined_emb.shape[1]:], skip_special_tokens=True)
    return answer.strip()



def relaxed_numeric_match(pred, gt):
    try:
        pred_val = float(pred)
        gt_val = float(gt)
        return abs(pred_val - gt_val) / max(1.0, abs(gt_val)) <= 0.05
    except:
        return False

In [None]:
split_list = [dict(ex) for ex in ds["val"].select(range(50))]
epochs = 5
batch_size = 4

for epoch in range(epochs):
    total_loss = 0.0
    for i in tqdm(range(0, len(split_list), batch_size)):
        batch = split_list[i:i+batch_size]
        optimizer.zero_grad()
        batch_loss = 0.0

        for ex in batch:
            img = ex["image"]
            question = ex["query"]
            combined_emb, input_ids = prepare_inputs_for_training(extract_clip_embedding(img), question)

            outputs = qwen_model(inputs_embeds=combined_emb, labels=input_ids)
            loss = outputs.loss
            batch_loss += loss

        batch_loss = batch_loss / len(batch)
        batch_loss.backward()
        optimizer.step()
        total_loss += batch_loss.item()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(split_list):.4f}")


# ------------------------------
# 4️⃣ Evaluation
# ------------------------------
results = []
exact_correct = 0
relaxed_correct = 0
total = 0

for ex in tqdm(split_list):
    img = ex["image"]
    question = ex["query"]
    gt_list = ex["label"]

    pred = generate_answer_from_clip(extract_clip_embedding(img), question)
    pred_norm = pred.lower().strip()

    exact = False
    relaxed = False
    for gt in gt_list:
        gt_norm = str(gt).lower().strip()
        if pred_norm == gt_norm:
            exact = True
            relaxed = True
            break
        if relaxed_numeric_match(pred_norm, gt_norm):
            relaxed = True

    if exact: exact_correct += 1
    if relaxed: relaxed_correct += 1
    total += 1

    results.append({
        "question": question,
        "ground_truth": gt_list,
        "prediction": pred,
        "exact_match": exact,
        "relaxed_match": relaxed
    })

exact_acc = exact_correct / total
relaxed_acc = relaxed_correct / total

print("\n==============================")
print("TRAINED CLIP-TO-QWEN RESULTS")
print("==============================")
print(f"Total samples: {total}")
print(f"Exact Accuracy: {exact_acc:.4f}")
print(f"Relaxed Accuracy (±5%): {relaxed_acc:.4f}")

df = pd.DataFrame(results)
df.to_csv("chartqa_clip_qwen_trained_results_50.csv", index=False)

100%|██████████| 13/13 [00:02<00:00,  5.29it/s]


Epoch 1/5, Loss: 1.2546


100%|██████████| 13/13 [00:02<00:00,  5.27it/s]


Epoch 2/5, Loss: 1.1980


100%|██████████| 13/13 [00:02<00:00,  5.27it/s]


Epoch 3/5, Loss: 1.1238


100%|██████████| 13/13 [00:02<00:00,  5.30it/s]


Epoch 4/5, Loss: 1.0532


100%|██████████| 13/13 [00:02<00:00,  5.31it/s]


Epoch 5/5, Loss: 0.9752


100%|██████████| 50/50 [00:45<00:00,  1.09it/s]


TRAINED CLIP-TO-QWEN RESULTS
Total samples: 50
Exact Accuracy: 0.0000
Relaxed Accuracy (±5%): 0.0000



