## Imports & Data (run this at start)

In [7]:
import re
import os
from tqdm import tqdm
from datasets import load_dataset

from PIL import Image
from transformers.image_utils import load_image
import matplotlib.pyplot as plt
import numpy as np

import torch
from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig, Trainer, TrainingArguments
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model, PeftModel

In [8]:
data_files = {"train": "train.jsonl", "dev": "dev.jsonl", "test": "test.jsonl"}
dataset = load_dataset("cambridgeltl/vsr_random", data_files=data_files)

train_ds = dataset["train"]
test_ds = dataset["test"]
val_ds = dataset["dev"]

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "HuggingFaceTB/SmolVLM-Instruct"

## Testing out performance pre fine-tuning

First we load a sample image from the data and plot it.

In [9]:
img_path = os.path.join("visual-spatial-reasoning/images", train_ds[2]["image"])
img = load_image(img_path)
plt.imshow(img)

ValueError: Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got visual-spatial-reasoning/images\000000262118.jpg. Failed with cannot identify image file <_io.BytesIO object at 0x0000018628E68220>

We can also see what the caption and label for this is

In [16]:
caption = train_ds[2]["caption"]
label = train_ds[2]["label"]
relation = train_ds[2]["relation"]

print(caption)
print(label)
print(relation)

The bed is right of the bench.
1
right of
bed bench.


Now let's load the SmolVLM-Instruct model, which is the Base model finetuned for handling structured prompts/questions.

In [None]:
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(model_id,
                                                torch_dtype=torch.bfloat16,
                                                _attn_implementation="flash_attention_2" if DEVICE == "cuda" else "eager").to(DEVICE)

model.eval()

In [None]:
def run_inference(img_path, caption):
  # Load images
  img = load_image(img_path)
 
  # Create input messages
  messages = [
      {
          "role": "user",
          "content": [
              {"type": "image"},
              {"type": "text", "text": f"{caption.rstrip('.')}, true or false?"}
          ]
      },
  ]

  # Prepare inputs
  prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
  inputs = processor(text=prompt, images=[img], return_tensors="pt")
  inputs = inputs.to(DEVICE)

  # Generate outputs
  with torch.no_grad():
    generated_ids = model.generate(**inputs, max_new_tokens=500)
  generated_texts = processor.batch_decode(
      generated_ids,
      skip_special_tokens=True,
  )

  return generated_texts[0]

output = run_inference(img_path, caption)

print(output)

In this case, the model predicts correctly, since the bed is on the right side of the bench (from the perspective of the bench). Checkout line 3 of visual-spatial-reasoning/data/split/random/train.jsonl for the ground truth label (its 1 for true).

## Prelim Run on Test Set

In [None]:
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(model_id,
                                                torch_dtype=torch.bfloat16,
                                                _attn_implementation="flash_attention_2" if DEVICE == "cuda" else "eager").to(DEVICE)
model.eval()

In [None]:
# Regular expression pattern to extract assistant response
answer_pattern = re.compile(r"Assistant:\s*(\w+)")

correct = 0
preds = []
total = len(test_ds)

# Function to process each entry
def evaluate(entry):
    global correct
    global preds
    img_path = os.path.join("visual-spatial-reasoning/images", entry["image"])

    output = run_inference(img_path, entry["caption"])

    match = answer_pattern.search(output)
    answer = match.group(1) if match else None

    if answer == "True":
      preds.append(1)
    else:
      preds.append(0)

    # Validate answer
    if answer in {"True", "False"} and (answer == "True") == (entry["label"] == 1):
        correct += 1

# Process dataset with tqdm for progress tracking
for entry in tqdm(test_ds, desc="Processing images"):
    evaluate(entry)

# save preds
with open("preds.txt", "w") as f:
    for i in range(len(preds)):
        f.write(str(preds[i])+"\n")

# Print results
print(f"Total images: {total}")
print(f"Correct: {correct} ({correct / total:.2%} accuracy)")

You should get an accuracy of 66%. Not terrible, but def below human performance of 95%

## Training

A lot of this is taken straight from the following notebook: https://github.com/huggingface/smollm/blob/main/vision/finetuning/Smol_VLM_FT.ipynb

In this case we are fine-tuning the instruct model rather than base since we want it to be able to answer spatial queries in a Q&A format which Instruct has already been made for.

In [None]:
# If you want to deploy the model, you should not quantize to 4bits since converting to onnx not supported

USE_QLORA = False


processor = AutoProcessor.from_pretrained(model_id)

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=['down_proj','o_proj','k_proj','q_proj','gate_proj','up_proj','v_proj'],
    use_dora=False if USE_QLORA else True,
    init_lora_weights="gaussian"
)

lora_config.inference_mode = False

if USE_QLORA:

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        _attn_implementation="flash_attention_2"
    ).to(DEVICE)

else:

    model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        torch_dtype = torch.bfloat16,
        _attn_implementation="flash_attention_2"
    ).to(DEVICE)

model.add_adapter(lora_config)
model.enable_adapters()

if USE_QLORA:
    model = prepare_model_for_kbit_training(model)
    
model = get_peft_model(model, lora_config)

print(model.get_nb_trainable_parameters())

The collate function is pretty important when you work with batch training. It defines the structure of your prompts and how to combine them for batch processing.

In [None]:
image_token_id = processor.tokenizer.additional_special_tokens_ids[
            processor.tokenizer.additional_special_tokens.index("<image>")]

def collate_fn(examples):
    
  texts = []
  images = []

  for example in examples:

      image_path = os.path.join("visual-spatial-reasoning/images", example["image"])

      if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

      image = load_image(image_path)

      if image.mode != 'RGB':
        image = image.convert('RGB')

      caption = example["caption"]
      label = "True." if example["label"] == 1 else "False."

      messages = [
          {
              "role": "user",
              "content": [
                  {"type": "image"},
                  {"type": "text", "text": f"{caption.rstrip('.')}, true or false?"}
              ]
          },
          {
              "role": "assistant",
              "content": [
                  {"type": "text", "text": label}
              ]
          }
      ]
      text = processor.apply_chat_template(messages, add_generation_prompt=False)

      texts.append(text.strip())
      images.append([image])

  batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
  labels = batch["input_ids"].clone()
  labels[labels == processor.tokenizer.pad_token_id] = -100
  labels[labels == image_token_id] = -100
  batch["labels"] = labels

  return batch

In [None]:
model_name = model_id.split("/")[-1]

training_args = TrainingArguments(
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=50,
    learning_rate=1e-4,
    weight_decay=0.01,
    logging_steps=25,
    save_strategy="steps",
    save_steps=5,
    save_total_limit=1,
    optim="paged_adamw_8bit" if USE_QLORA else "adamw_hf", # for 8-bit, keep this, else adamw_hf
    bf16=True, # underlying precision for 8bit
    output_dir=f"./{model_name}-vsr",
    hub_model_id=f"{model_name}-vsr",
    report_to="tensorboard",
    remove_unused_columns=False,
    gradient_checkpointing=True
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_ds,
)

Now we can train! On a single A100, this took me about 3.5 hours

In [None]:
trainer.train()

The final output model params should be saved in the directory SmolVLM-Instruct-vsr/checkpoint-480

## Load & Re-Test

In [None]:
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    _attn_implementation="flash_attention_2" if DEVICE == "cuda" else "eager"
)

model = PeftModel.from_pretrained(model, "SmolVLM-Instruct-vsr/checkpoint-5")
model = model.merge_and_unload()
model = model.to(torch.bfloat16).to(DEVICE)
model.eval()

processor = AutoProcessor.from_pretrained(model_id)

In [None]:
correct = 0
preds = []
total = len(test_ds)

# Process dataset with tqdm for progress tracking
for entry in tqdm(test_ds, desc="Processing images"):
    evaluate(entry)

# save preds
with open("preds_post.txt", "w") as f:
    for i in range(len(preds)):
        f.write(str(preds[i])+"\n")

# Print results
print(f"Total images: {total}")
print(f"Correct: {correct} ({correct / total:.2%} accuracy)")

Improved the accuracy from 55% to 70% on the test set after fine-tuning! Nice.
We trained for about 2hrs, so this is decent. We have to keep in mind that spatial awareness is not an easy task and in some ways can be subjective depending on perspective. 

## Export for Deployment

In [None]:
# Load images
image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
image2 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")


# Initialize processor and model
processor = AutoProcessor.from_pretrained(model_id)

# Create input messages
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "image"},
            {"type": "text", "text": "Can you describe the images?"}
        ]
    },
]

# Prepare inputs
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, images=[image1, image2], return_tensors="pt")

for k,v in inputs.items():
    print(k, v.shape)

In [None]:
import torch.onnx

DEVICE = "cpu"

# Load Model
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    _attn_implementation="flash_attention_2" if DEVICE == "cuda" else "eager"
)
# model = PeftModel.from_pretrained(model, "SmolVLM-Instruct-vsr/checkpoint-5")
# model = model.merge_and_unload()
# model = model.to(torch.bfloat16)  # Keep dtype in bfloat16

# Convert to eval mode
model.eval()

# Move model to CPU before export
model.to(DEVICE)
inputs = inputs.to(DEVICE)

tensor_inputs = {
    "pixel_values": inputs["pixel_values"].to(torch.float32),
    "pixel_attention_mask": inputs["pixel_attention_mask"].to(torch.float32),
    "input_ids": inputs["input_ids"],
    "attention_mask": inputs["attention_mask"]
}


# Dynamic axes for variable batch size and sequence length
dynamic_axes = {
    "pixel_values": {0: "batch_size", 1: "patches"},
    "pixel_attention_mask": {0: "batch_size", 1: "patches"},
    "input_ids": {0: "batch_size", 1: "sequence_length"},
    "attention_mask": {0: "batch_size", 1: "sequence_length"},
    "output": {0: "batch_size", 1: "sequence_length"}
}


with torch.no_grad():
    # Pass inputs as a tuple of tensors
    torch.onnx.export(
        model,
        tensor_inputs,  # Proper unpacking
        "smolvlm.onnx",
        input_names=["pixel_values", "pixel_attention_mask", "input_ids", "attention_mask"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
        opset_version=13,
        do_constant_folding=True,
        export_params=True
    )



print("SmolVLM model exported to smolvlm.onnx")