# Captioning Goya's work with small language models (SLM)



In [1]:
import gc
import GPUtil
import pandas as pd
from pathlib import Path
import PIL
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch

# Paths
DIRS = {
    0: Path("data/save"), # downloaded images
    1: Path("data/clean") # processed images
}

# General purpose functions
def flush_cuda():
    """Run the garbage collector, then empty the CUDA cache when CUDA is available."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

## Retrieving images of Goya's work from Wikidata

I retrieved the metadata and images of Goya's work from Wikidata/Wikimedia Commons. If you use this code, please follow Wikidata's User-Agent policy: https://www.wikidata.org/wiki/Wikidata:Data_access


In [None]:
import csv
import requests
from SPARQLWrapper import SPARQLWrapper, JSON
import time

endpoint_url = "https://query.wikidata.org/sparql"
# Follow Wikidata's User-Agent policy https://www.wikidata.org/wiki/Wikidata:Data_access
# The value is used as the SPARQL client agent and as an HTTP header, so it has
# to be a string: read the file content (which also closes the file) instead of
# keeping the open file handle, and drop the trailing newline.
USER_AGENT = Path("user_agent.txt").read_text(encoding="utf-8").strip()
DOWNLOAD = False  # Download images

## Query images from Wikidata with 'creator' (wdt:P170) 'Francisco Goya' (wd:Q5432)
query = """#Works by Goya
#title: Works by Goya
#defaultView:ImageGrid
SELECT ?item ?itemLabel ?pic ?title ?inception WHERE {
  ?item wdt:P170 wd:Q5432;
  wdt:P18 ?pic.
  OPTIONAL { ?item wdt:P1476 ?title. }
  OPTIONAL { ?item wdt:P571 ?inception. }
  SERVICE wikibase:label { bd:serviceParam wikibase:language "[AUTO_LANGUAGE],mul,en". }
}"""


def get_results(endpoint_url, query, user_agent):
    """Run a SPARQL query and return the converted JSON response.

    Args:
        endpoint_url: URL of the SPARQL endpoint.
        query: SPARQL query text.
        user_agent: Value handed to ``SPARQLWrapper`` as ``agent``.

    Returns:
        The result of ``sparql.query().convert()`` with the return format
        set to JSON.
    """
    sparql = SPARQLWrapper(endpoint_url, agent=user_agent)
    sparql.setQuery(query)
    sparql.setReturnFormat(JSON)
    return sparql.query().convert()


results = get_results(endpoint_url, query, USER_AGENT)

## Download images
headers = {'User-Agent': USER_AGENT}
session = requests.Session()
session.headers.update(headers)

# Make sure the target folder exists before images or metadata are written
DIRS[0].mkdir(parents=True, exist_ok=True)

# Cases with unprocessable images
skip_ids = {"Q1988253", "Q17519472", "Q20181269"}

# Column order of metadata.csv, spelled out so that the file can be written
# even when the query returns no rows
fieldnames = [
    "id",
    "uri",
    "file",
    "inception",
    "title_en",
    "title_es",
    "label_en",
    "label_es",
    "caption",
]

res = dict()
downloaded = set()  # ids whose image was fetched successfully during this run
for painting in results["results"]["bindings"]:
    uri = painting["item"]["value"]
    jpg_url = painting["pic"]["value"]
    item_id = uri.split('/')[-1] # e.g. 'Q6172289'

    if item_id in skip_ids:
        continue

    file = f'{item_id}.jpg'
    file_path = DIRS[0] / file

    # Cases with missing inception data; otherwise keep the part before the first "-"
    if "inception" in painting:
        inception = painting["inception"]["value"].split("T")[0].split("-")[0]
    else:
        inception = ""

    # Cases with missing title or different title languages
    title_info = painting.get("title", {})
    title = title_info.get("value", "")
    title_lang = title_info.get("xml:lang")
    title_en = title if title_lang == "en" else ""
    title_es = title if title_lang == "es" else ""

    # Labels without a language tag are treated as English
    label_info = painting["itemLabel"]
    label_lang = label_info.get("xml:lang", "en")

    # Cases where itemLabel is the id, e.g. 'Q5849545'
    label = label_info["value"]
    if label == item_id:
        label = ""
    label_en = label if label_lang == "en" else ""
    label_es = label if label_lang == "es" else ""

    updates = {
        "id": item_id,
        "uri": uri,
        "file": file,
        "inception": inception,
        "title_en": title_en,
        "title_es": title_es,
        "label_en": label_en,
        "label_es": label_es,
        "caption": "",
    }

    # Some paintings appear multiple times in 'results', mainly due to several inception dates.
    # The first occurrence is stored as is; later ones only fill fields that are still empty.
    if item_id not in res:
        res[item_id] = updates.copy()
    else:
        for k, v in updates.items():
            if res[item_id].get(k) in ("", None):
                res[item_id][k] = v

    # Fetch each image once; a failed download is retried if the id shows up again
    if DOWNLOAD and item_id not in downloaded:
        try:
            r = session.get(jpg_url, timeout=30)
            r.raise_for_status()
            file_path.write_bytes(r.content)
        except (requests.RequestException, OSError) as e:
            print(f"Failed {item_id}", e)
        else:
            downloaded.add(item_id)
            time.sleep(0.2)  # short pause between successful downloads

# Store
res = list(res.values())
with open(DIRS[0]/"metadata.csv", "w", newline="", encoding="utf-8") as f:
    w = csv.DictWriter(f, fieldnames=fieldnames)
    w.writeheader()
    w.writerows(res)

# Additional cleaning
df = pd.read_csv(DIRS[0]/"metadata.csv", dtype="str")
# A missing English label falls back to the English title, then to "unknown"
df['label_en'] = df['label_en'].fillna(df["title_en"])
df['label_en'] = df['label_en'].fillna("unknown")
df['inception'] = df['inception'].fillna("unknown")
# Write the cleaned table back to metadata.csv, the file the later steps read
# (the previous target name "medata.data.csv" was a typo and was never read)
df.to_csv(DIRS[0]/"metadata.csv", index=False)

## Resize images

The downloaded images varied greatly in size. I downscaled every image whose longer side exceeded 512 px so that the longer side measures 512 px (smaller images were left unchanged), matching the input size used to train the image captioning model.

In [None]:
from transformers.image_utils import load_image

# Maximum length (in pixels) passed to load_and_resize_image as max_size.
# size={"longest_edge": N*512} to fit caption model config
MAX_SIZE = 512
# Quality setting passed to image.save() when the resized JPEGs are written
JPEG_QUALITY = 90

def load_and_resize_image(path: Path, max_size: int) -> PIL.Image.Image:
    """Load an image and shrink it so that its longer side is at most ``max_size``.

    Images whose longer side already fits are returned unchanged; images are
    never enlarged. The aspect ratio is kept, and each side is at least 1 px.

    Args:
        path: Location of the image file.
        max_size: Upper bound, in pixels, for the longer side.

    Returns:
        The loaded image, resized with the LANCZOS filter when it was too large.
    """
    image = load_image(str(path))

    width, height = image.size
    longest = max(width, height)
    if longest <= max_size:
        return image

    # Resize large images
    scale = max_size / longest
    new_size = (
        max(1, int(round(width * scale))),
        max(1, int(round(height * scale))),
    )
    return image.resize(new_size, PIL.Image.LANCZOS)


data = pd.read_csv(DIRS[0]/"metadata.csv")

# Make sure the output folder exists before the first image is saved
DIRS[1].mkdir(parents=True, exist_ok=True)

for file_name in data["file"]:
    # Get image
    image_path = DIRS[0]/file_name
    if not image_path.is_file():
        # A listed image may be absent (e.g. its download failed); report it
        # and carry on instead of aborting the whole run
        print(f"Missing image, skipped: {image_path}")
        continue
    image = load_and_resize_image(image_path, max_size=MAX_SIZE)

    # Save
    path = DIRS[1]/image_path.name
    image.save(
        path,
        format="JPEG",
        quality=JPEG_QUALITY,
        optimize=True,
        progressive=True,
        subsampling="4:2:0",
    )


## Generate image captions

I produced image captions using Hugging Face's SmolVLM-500M-Instruct model: https://huggingface.co/HuggingFaceTB/SmolVLM-500M-Instruct

In [None]:
from datetime import datetime
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

# Free GPU memory
# Drop a `model` object left over from an earlier cell run, if there is one;
# a NameError only means that no model has been created yet.
try:
    del model
except NameError:
    pass
flush_cuda()
GPUtil.showUtilization()  # GPU status before loading

# Initialize processor and model
# The model is loaded with an 8-bit bitsandbytes quantization config.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
# `device` is not used to place the model here (the .to(device) call below is
# commented out); the captioning cell moves the inputs to it.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-500M-Instruct")
model = AutoModelForImageTextToText.from_pretrained(
    "HuggingFaceTB/SmolVLM-500M-Instruct",
    quantization_config=quantization_config,
    attn_implementation="sdpa",
)#.to(device)

GPUtil.showUtilization()  # GPU status after loading

In [None]:
# When False, only images whose caption is missing in metadata.csv are captioned
REPLACE = False # replace existing captions in metadata.csv

## Set prompt
# The backslash at the end of the first sentence line joins it with the next
# line, so the instruction reaches the model as a single line.
prompt_raw = """
Describe the objects, people, and background you see \
in the painting or drawing in a factual manner. 
"""

# A single user turn: one image placeholder followed by the text instruction
message = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": prompt_raw},
        ]
    },
]

# Render the chat message into the prompt string, with the generation prompt appended
prompt = processor.apply_chat_template(message, add_generation_prompt=True)

def make_input(image: PIL.Image.Image, prompt: str):
    """Run the processor on one image and its prompt, returning PyTorch tensors."""
    images = [image]
    return processor(text=prompt, images=images, return_tensors="pt")

paths = sorted(p for p in DIRS[1].rglob("*") if p.suffix.lower() in {".jpg", ".jpeg"})
df = pd.read_csv(DIRS[0]/"metadata.csv", dtype=str)

# Run inference
with torch.inference_mode():
    for i, p in enumerate(tqdm(paths, total=len(paths))):
        img_id = p.stem
        row_mask = df["id"] == img_id
        caption_missing = df.loc[row_mask, "caption"].isna().any()

        # Images without a metadata row are skipped: there is nowhere to store
        # their caption, so generating one would be wasted work.
        if row_mask.any() and (REPLACE or caption_missing):
            with PIL.Image.open(p) as image:
                inputs = make_input(image, prompt).to(device)

            gen_ids = model.generate(
                **inputs,
                max_new_tokens=200,
                num_beams=4,
                repetition_penalty=1.15,
                do_sample=False,
                #use_cache=False,
                output_attentions=False,
                output_hidden_states=False,
                return_dict_in_generate=False,
            )

            decoded = processor.batch_decode(
                gen_ids,
                skip_special_tokens=True,
            )[0]
            # Keep everything after the first "Assistant:" marker. Splitting
            # once keeps captions that contain the marker themselves intact,
            # and taking the last piece falls back to the whole text instead
            # of raising IndexError when the marker is absent.
            gen_text = decoded.split("Assistant:", 1)[-1].strip()

            df.loc[row_mask, "caption"] = gen_text
            del inputs, gen_ids
            torch.cuda.empty_cache()

        # Checkpoint every fifth image so that an interrupted run keeps its captions
        if i % 5 == 0:
            df.to_csv(DIRS[0]/"metadata.csv", index=False)

# Save
df.to_csv(DIRS[0]/"metadata.csv", index=False)
timestamp = datetime.now().strftime("%Y%m%d_%H%M")
df.to_csv(DIRS[0]/f"metadata_{timestamp}.csv", index=False) # to avoid overwriting

## Example output

In [2]:
from helpers import show_images

# Display a hand-picked sample of works
ids = ["Q59260209", "Q64956388", "Q64956396"]
df = pd.read_csv(DIRS[0] / "metadata.csv")
in_sample = df["id"].isin(ids)
df_sample = df[in_sample]

show_images(df_sample, DIRS[1], img_height=400)

## Semantic search with embeddings

I embedded the produced descriptions into a vector space for semantic search using the multi-qa-mpnet-base-dot-v1 model: https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-dot-v1 

In [None]:
from datasets import Dataset
from transformers import AutoTokenizer, AutoModel

# Prepare database
metadata_path = DIRS[0] / "metadata.csv"
df = pd.read_csv(metadata_path, dtype="str")
# A missing English label falls back to the English title, then to "unknown"
df["label_en"] = df["label_en"].fillna(df["title_en"]).fillna("unknown")
df["inception"] = df["inception"].fillna("unknown")

data = Dataset.from_pandas(df)
def concatenate_text(row):
    """Build the text that is embedded for one metadata row.

    Only the caption is used, so the search runs on the caption alone.
    A missing caption (``None`` or NaN rather than a string) is treated as an
    empty caption instead of raising ``TypeError`` on concatenation.

    Args:
        row: Mapping with a ``"caption"`` entry.

    Returns:
        A dict with the single key ``"text"``.
    """
    caption = row["caption"]
    if not isinstance(caption, str):
        caption = ""
    return {"text": "Caption: " + caption}
# Add the "text" column that gets embedded
data = data.map(concatenate_text)

# Free GPU memory: drop a model left over from an earlier cell, if any
if "model" in globals():
    del model
flush_cuda()

# Initialize tokenizer and model
MODEL_ID = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModel.from_pretrained(MODEL_ID)
model = model.to(device)

# Pool embedding dimensions
def cls_pooling(model_output):
    """Return index 0 along the second axis of ``model_output.last_hidden_state``.

    Presumably this is the first ([CLS]) token of each sequence, assuming a
    (batch, tokens, hidden) layout -- confirm against the model's output.
    """
    hidden_states = model_output.last_hidden_state
    return hidden_states[:, 0]

def get_embeddings(text_list):
    """Embed text with the module-level ``tokenizer`` and ``model``.

    Args:
        text_list: Text handed to the tokenizer (the callers in this notebook
            pass either a single string or a list of strings).

    Returns:
        The tensor produced by ``cls_pooling`` on the model output; it lives
        on ``device``.
    """
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="pt"
    )
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    # Inference only: skip gradient tracking for the forward pass
    with torch.no_grad():
        model_output = model(**encoded_input)
    return cls_pooling(model_output)


# Calculate embeddings
def _embed_row(row):
    """Return the embedding of one row's text as a NumPy vector."""
    vectors = get_embeddings(row["text"])
    return {"embeddings": vectors.detach().cpu().numpy()[0]}


embeddings = data.map(_embed_row)

# Add FAISS index
_ = embeddings.add_faiss_index(column="embeddings")

## Query the database

In [4]:
from helpers import show_images

n_results = 6 # number of results
query = "Scenes of a bull fight"

# Embed the query the same way as the captions, then look up its nearest neighbours
query_embd = get_embeddings([query]).detach().cpu().numpy()
scores, samples = embeddings.get_nearest_examples(
    "embeddings", query_embd, k=n_results
)

# One row per retrieved work, ordered by ascending score
samples_df = pd.DataFrame.from_dict(samples).assign(scores=scores)
samples_df = samples_df.sort_values("scores", ascending=True)

show_images(samples_df, DIRS[1])