In [None]:
!pip install git+https://github.com/huggingface/transformers accelerate
!pip install qwen-vl-utils
!pip install datasets

Collecting git+https://github.com/huggingface/transformers
  Cloning https://github.com/huggingface/transformers to /tmp/pip-req-build-nzoxgvvb
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers /tmp/pip-req-build-nzoxgvvb
  Resolved https://github.com/huggingface/transformers to commit 6daa3eeba582facb57cd71db8efb66998b12942f
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image
from tqdm import tqdm
import pandas as pd
import os

from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Caption every MNIST training image with Qwen2-VL-2B-Instruct.
# Side effects: writes one JPEG per image to tmp_dir as mnist_{batch_idx}_{i}.jpg
# (later cells rebuild these paths from `index`, assuming batch_size == 64) and
# writes index/label/caption rows to mnist_qwen25vl_captions_batched.csv.
# NOTE: the CSV name says "qwen25vl" but the model loaded here is Qwen2-VL; the
# name is kept because downstream cells read that exact file.

device = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 is meant for GPU inference; fall back to fp32 when running on CPU.
dtype = torch.float16 if device == "cuda" else torch.float32
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct", torch_dtype=dtype, device_map="auto"
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
# Batched generation with a decoder-only model must pad on the left so every prompt
# ends exactly where generation begins (the prompts here are equal length, but this
# keeps the loop correct if the prompt ever varies per image).
processor.tokenizer.padding_side = "left"

# Each literal ends with a separator so the concatenated prompt reads as sentences
# (previously "...depicts a '5'Use clear..." was sent to the model).
system_prompt = (
    "You are an expert visual assistant trained to describe handwritten digit images from the MNIST dataset. "
    "Simply mention what digit is in the image. For example, The handwritten digit image depicts a '3' or The handwritten digit image depicts a '5'. "
    "Use clear and concise descriptions suitable for training AI models. No need to mention the digit/background colors."
)
user_prompt = "Describe this handwritten digit image."

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
mnist = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
batch_size = 64  # also baked into the saved image filenames; downstream cells assume 64
# shuffle=False keeps batch_idx * batch_size + i equal to the MNIST dataset index.
loader = DataLoader(mnist, batch_size=batch_size, shuffle=False)

tmp_dir = "./tmp_mnist_images"
os.makedirs(tmp_dir, exist_ok=True)


def _save_batch_images(img_tensors, batch_idx):
    """Save each image tensor of one batch as a JPEG in tmp_dir and return the paths."""
    to_pil = transforms.ToPILImage()
    paths = []
    for i, img in enumerate(img_tensors):
        path = os.path.join(tmp_dir, f"mnist_{batch_idx}_{i}.jpg")
        to_pil(img).convert("RGB").save(path)
        paths.append(path)
    return paths


def _build_messages(img_paths):
    """Build one single-image chat (system turn + user turn) per image path."""
    return [
        [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": path},
                    {"type": "text", "text": user_prompt},
                ],
            },
        ]
        for path in img_paths
    ]


captions = []
model.eval()

for batch_idx, (img_tensors, labels) in tqdm(enumerate(loader), total=len(loader), desc="Generating captions"):
    img_paths = _save_batch_images(img_tensors, batch_idx)
    messages = _build_messages(img_paths)

    batch_texts = [processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True) for m in messages]
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = processor(
        text=batch_texts,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt"
    ).to(model.device)  # device_map="auto" decides placement; follow the model

    with torch.inference_mode():
        generated_ids = model.generate(**inputs, max_new_tokens=50)
    # generate() returns prompt + completion; keep only the newly generated tokens.
    trimmed_ids = generated_ids[:, inputs.input_ids.shape[1]:]
    decoded = processor.batch_decode(trimmed_ids, skip_special_tokens=True)

    start = batch_idx * batch_size
    captions.extend(
        {"index": start + i, "label": label.item(), "caption": caption.strip()}
        for i, (label, caption) in enumerate(zip(labels, decoded))
    )

df = pd.DataFrame(captions)
df.to_csv("mnist_qwen25vl_captions_batched.csv", index=False)
print("✅ Saved captions to mnist_qwen25vl_captions_batched.csv")

# Sanity check: share of captions whose quoted digit matches the MNIST label
# (captions without a quoted digit count as mismatches).
predicted_digit = df["caption"].str.extract(r"['\"](\d)['\"]")[0]
agreement = (predicted_digit == df["label"].astype(str)).mean()
print(f"Caption/label agreement: {agreement:.2%}")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Generating captions: 100%|██████████| 938/938 [36:08<00:00,  2.31s/it]


✅ Saved captions to mnist_qwen25vl_captions_batched.csv


In [None]:
# Reload the captions from disk rather than trusting in-memory state, then preview the top rows.
df = pd.read_csv(filepath_or_buffer="mnist_qwen25vl_captions_batched.csv")
df.head(n=5)

Unnamed: 0,index,label,caption
0,0,5,The handwritten digit image depicts a '5'.
1,1,0,The handwritten digit image depicts a '3'.
2,2,4,The handwritten digit image depicts a '4'.
3,3,1,The handwritten digit image depicts a '3'.
4,4,9,The handwritten digit image depicts a '9'.


In [None]:
# Spot-check a single caption; explicit .loc avoids chained indexing.
df.loc[115, "caption"]

"The handwritten digit image depicts a '4'."

In [None]:
import os
from getpass import getpass

from huggingface_hub import login

# Never hard-code the write token in the notebook (it ends up in version control
# and shared copies). Prefer the HF_TOKEN environment variable; otherwise prompt
# for it without echoing.
hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face write token: ")
login(token=hf_token)

In [None]:
from datasets import Dataset, DatasetDict, Features, Value, ClassLabel
# Alias the feature type: a bare `Image` would shadow PIL's Image imported in the captioning cell.
from datasets.features import Image as HFImage
import os
import pandas as pd

IMAGE_DIR = "./tmp_mnist_images"
# Must equal the batch_size the captioning cell used when saving mnist_{batch_idx}_{i}.jpg.
CAPTION_BATCH_SIZE = 64

df = pd.read_csv("mnist_qwen25vl_captions_batched.csv")

# Rebuild each row's image path from its global index (index = batch_idx * batch_size + i).
df["image"] = [
    f"{IMAGE_DIR}/mnist_{idx // CAPTION_BATCH_SIZE}_{idx % CAPTION_BATCH_SIZE}.jpg"
    for idx in df["index"]
]
missing = [path for path in df["image"] if not os.path.exists(path)]
assert not missing, f"{len(missing)} image files not found (e.g. {missing[:3]}); check CAPTION_BATCH_SIZE/IMAGE_DIR"

features = Features({
    "index": Value("int32"),
    "label": ClassLabel(names=[str(i) for i in range(10)]),
    "caption": Value("string"),
    "image": HFImage(),  # column holds local file paths; already typed, so no cast_column needed
})

dataset = Dataset.from_pandas(df, features=features)

dataset.push_to_hub("kishore-s-15/mnist-qwen2.5-2B-captions", private=True)

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Map:   0%|          | 0/60000 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/600 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/kishore-s-15/mnist-qwen2.5-2B-captions/commit/32ff15b10423fe09fcf841059f75679ca45e6c48', commit_message='Upload dataset', commit_description='', oid='32ff15b10423fe09fcf841059f75679ca45e6c48', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/kishore-s-15/mnist-qwen2.5-2B-captions', endpoint='https://huggingface.co', repo_type='dataset', repo_id='kishore-s-15/mnist-qwen2.5-2B-captions'), pr_revision=None, pr_num=None)

In [None]:
from datasets import Dataset, Features, Value, Image as HFImage
import pandas as pd

# Rebuild the metadata here so this cell does not rely on `df` having been
# mutated (image-path column added) by an earlier cell.
CAPTION_BATCH_SIZE = 64  # batch size used when images were saved as mnist_{batch_idx}_{i}.jpg
df = pd.read_csv("mnist_qwen25vl_captions_batched.csv")
df["image"] = [
    f"./tmp_mnist_images/mnist_{idx // CAPTION_BATCH_SIZE}_{idx % CAPTION_BATCH_SIZE}.jpg"
    for idx in df["index"]
]

# Define features
features = Features({
    "image": HFImage(),       # local file paths to images
    "label": Value("int64"),
    "caption": Value("string")
})

# Create and push (only the columns declared in `features`)
dataset = Dataset.from_pandas(df[["image", "label", "caption"]], features=features)
dataset.push_to_hub("kishore-s-15/mnist-image-captioned")
print("🚀 Hugging Face dataset pushed!")

Uploading the dataset shards:   0%|          | 0/2 [00:00<?, ?it/s]

Map:   0%|          | 0/30000 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/300 [00:00<?, ?ba/s]

Map:   0%|          | 0/30000 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/300 [00:00<?, ?ba/s]

🚀 Hugging Face dataset pushed!


In [None]:
from datasets import load_dataset

# Load the "mnist-with-captions" Hub dataset (a different repo from the two pushed above).
# load_dataset returns a DatasetDict, so rows are read per split, e.g. ds["train"][0].
ds = load_dataset("kishore-s-15/mnist-with-captions")

README.md:   0%|          | 0.00/354 [00:00<?, ?B/s]

train-00000-of-00002.parquet:   0%|          | 0.00/248M [00:00<?, ?B/s]

train-00001-of-00002.parquet:   0%|          | 0.00/250M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/60000 [00:00<?, ? examples/s]

In [None]:
# Inspect the first row of the train split (image, label and caption fields).
train_split = ds["train"]
train_split[0]

{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=224x224>,
 'label': 5,
 'caption': 'The handwritten digit image depicts a "5". The stroke is relatively thick and straight, with a slight curve at the bottom. The background is black, and the digit is white.'}

In [None]:
from datasets import load_dataset

# Load the private Qwen-captioned repo pushed earlier in this notebook.
# The result is a DatasetDict; index a split first, e.g. ds["train"][0].
ds = load_dataset("kishore-s-15/mnist-qwen2.5-2B-captions")

README.md:   0%|          | 0.00/605 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/141M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/60000 [00:00<?, ? examples/s]

In [None]:
# Fetch row 0 first, then its caption: indexing the column first
# (ds["train"]["caption"]) can materialize all 60k captions just to read one.
ds["train"][0]["caption"]

"The handwritten digit image depicts a '5'."