# Qwen finetune

## Information

Ignore augmentations with these names:
- _colorpatch
- _sam

In [20]:
from pathlib import Path
import shutil

# STAGE_DESCRIPTIONS = {
# 	 "stage_0": "I can see the base plate, therefore the object must be in stage_0",
# 	 "stage_1": "I can see the cylinder being on top of the base plate, therefore the object must be in stage_1",
# 	 "stage_2": "I can see the big metal piece being on top of the cylinder, therefore the object must be in stage_2",
# 	 "stage_3": "I can see the small metal piece being on placed in the middle of the big metal piece, therefore the object must be in stage_3",
# 	 "stage_4": "I can see the small metal ring being placed in the center of the small metal piece, therefore the object must be in stage_4",
# 	 "stage_5": "I can see the 3 screws being screwed on the metal plate, therefore the object must be in stage_5",
# 	 "stage_6": "I can see the daker metal plate being placed on top of the object, therefore the object must be in stage_6",
# 	 "stage_7": "I can see 5 screws being screwed on the metal plate, therefore the object must be in stage_7",
# }

# Annotation text written next to every image of a stage, keyed by the stage
# folder name. These texts are used as training targets, so whenever a text
# names a stage it must be the same stage as its key.
STAGE_DESCRIPTIONS = {
	"stage_0": "I can see the base plate, which is the main piece of stage_0",
	"stage_1": "I can see the cylinder the main piece of stage_1",
	"stage_2": "I can see the big metal piece the main piece of stage_2",
	"stage_3": "I can see the smaller metal thinner piece the main piece of stage_3",
	"stage_4": "I can see the small metal ring the main piece of stage_4",
	"stage_5": "I can see 3 screws",
	"stage_6": "I can see the darker metal plate the main piece of stage_6",
	"stage_7": "I can see 5 screws",
}


def process_stage(stage: str, src_root, dst_root, descriptions=None) -> None:
	"""Copy the images of one stage and write one annotation file per image.

	Every regular entry of ``{src_root}/images/{stage}`` whose name contains
	neither ``_colorpatch`` nor ``_sam`` is copied to
	``{dst_root}/images/{stage}`` with ``_solo`` appended to its stem (the
	original extension is kept). For each copied image a ``.txt`` file with
	the same ``_solo`` stem is written to ``{dst_root}/anno/{stage}`` and
	holds the description text of the stage.

	Args:
		stage: Stage folder name (e.g. ``"stage_3"``); also the key used to
			look up the description text.
		src_root: Root folder containing ``images/<stage>``; ``Path`` or ``str``.
		dst_root: Root folder receiving ``images/<stage>`` and ``anno/<stage>``;
			``Path`` or ``str``.
		descriptions: Optional mapping ``stage -> annotation text``. When
			omitted, the module-level ``STAGE_DESCRIPTIONS`` is used.

	Raises:
		KeyError: If the source folder exists but ``stage`` has no description.

	A missing source folder is not an error: a warning is printed and the
	function returns without creating anything.
	"""
	src_dir = Path(src_root) / "images" / stage
	if not src_dir.is_dir():
		print(f"Warning: {src_dir} is missing, skipping.")
		return

	# Resolve the text before touching the destination so that an unknown
	# stage does not leave empty folders behind.
	if descriptions is None:
		descriptions = STAGE_DESCRIPTIONS
	desc_text = descriptions[stage]

	dst_root = Path(dst_root)
	dst_img_dir = dst_root / "images" / stage
	dst_ann_dir = dst_root / "anno" / stage
	dst_img_dir.mkdir(parents=True, exist_ok=True)
	dst_ann_dir.mkdir(parents=True, exist_ok=True)

	ignored_tags = ("_colorpatch", "_sam")
	for img_path in src_dir.iterdir():
		if img_path.is_dir():
			continue
		if any(tag in img_path.name for tag in ignored_tags):
			continue

		# Derive both file names from the same stem so image and annotation
		# always share a basename, whatever the image extension is.
		solo_stem = f"{img_path.stem}_solo"
		shutil.copy2(img_path, dst_img_dir / f"{solo_stem}{img_path.suffix}")
		(dst_ann_dir / f"{solo_stem}.txt").write_text(desc_text, encoding="utf-8")


# Manual run: for every known stage, export first from the `rendered_single`
# asset folder and then from the `rendered` one. Both exports write into the
# same destination folder. Plain string paths are accepted.
for stage in STAGE_DESCRIPTIONS:
	for render_root in (
		"/Users/georgye/Documents/repos/ethz/dslab25/assets/vacuum_pump/rendered_single",
		"/Users/georgye/Documents/repos/ethz/dslab25/assets/vacuum_pump/rendered",
	):
		process_stage(
			stage,
			render_root,
			"/Users/georgye/Documents/repos/ethz/dslab25/training/qwen",
		)

# Converting whole folder

In [26]:
#!/usr/bin/env python3
"""
prepare_qwen25vl_data.py
------------------------
Build the folder hierarchy and JSONL annotation files required to fine-tune
Qwen-2.5-VL on your assembly-stage images.

Images:  {images_root}/stage_k/XXXX.jpg
Labels:  {labels_root}/stage_k/XXXX.txt   (same relative path & basename)
"""

import json
import random
import shutil
from pathlib import Path
from typing import List, Tuple

# ────────────────────────────────────────────────────────────────────────────────
# System prompt; written verbatim to system_message.txt at the end of this
# script.
# NOTE(review): the entries are keyed 'state_k' while the descriptions inside
# them refer to 'stage_k', and the STAGE_DESCRIPTIONS label texts earlier in
# this file also say 'stage_k'. Presumably one spelling is intended
# throughout — TODO confirm against the label files before aligning, since
# changing this text changes the training data.
SYSTEM_MESSAGE = (
	"You job is it tell me what you see in the image, and if possible what stage "
	"the object is currently in.\nHere are the possible states:\n"
	"\t'state_0': 'First part of the object: Base block metal piece',\n"
	"\t'state_1': 'Second part of the object: Cylinder metal piece which gets stick "
	"on the base block stage_0',\n"
	"\t'state_2': 'Third part of the object: A Big metal piece which gets stick on "
	"the cylinder piece of stage_1',\n"
	"\t'state_3': 'Fourth part of the object: A smaller thin metal piece which gets "
	"put onto the center of the big metal piece of stage_2',\n"
	"\t'state_4': 'Fifth part of the object: A tiny metal ring which gets placed "
	"onto the center of the thing metal piece of stage_3',\n"
	"\t'state_5': 'Sixth part of the object: 3 screws now get screwed onto the piece',\n"
	"\t'state_6': 'Seventh part of the object: A darker metal plate now gets placed "
	"on top of the piece',\n"
	"\t'state_7': 'Eighth part of the object: 5 screws now get screwed onto the piece'"
)

# User-turn text; stored as the "prefix" field of every JSONL record.
USER_PREFIX = (
	"Describe the object and, if you can, tell me which stage (state_0 … state_7) "
	"it is currently in."
)
# ────────────────────────────────────────────────────────────────────────────────


def collect_pairs(img_root: Path, lbl_root: Path) -> List[Tuple[Path, str]]:
	"""Pair every ``*.jpg`` below ``img_root`` with the text of its label file.

	The label of ``{img_root}/<rel>.jpg`` is read from ``{lbl_root}/<rel>.txt``
	(same relative path, ``.txt`` suffix) and stripped of surrounding
	whitespace.

	Args:
		img_root: Folder searched recursively for ``*.jpg``; ``Path`` or ``str``.
		lbl_root: Folder holding the mirrored ``.txt`` files; ``Path`` or ``str``.

	Returns:
		A list of ``(image_path, label_text)`` tuples, sorted by image path.

	Raises:
		FileNotFoundError: If an image has no label file.
		RuntimeError: If no image was found at all.
	"""
	img_root = Path(img_root)
	lbl_root = Path(lbl_root)

	pairs: List[Tuple[Path, str]] = []
	# rglob yields entries in filesystem order, which differs between machines
	# and runs. Sorting gives a stable list, so a seeded shuffle of the result
	# is reproducible.
	for img_path in sorted(img_root.rglob("*.jpg")):
		lbl_path = lbl_root / img_path.relative_to(img_root).with_suffix(".txt")
		if not lbl_path.exists():
			raise FileNotFoundError(f"Missing label file for image: {img_path}")
		pairs.append((img_path, lbl_path.read_text(encoding="utf-8").strip()))

	if not pairs:
		raise RuntimeError("No image/label pairs found - check your paths.")
	return pairs


def split_pairs(
	pairs: List[Tuple[Path, str]], train_split: float, seed: int
) -> Tuple[List[Tuple[Path, str]], List[Tuple[Path, str]]]:
	"""Shuffle ``pairs`` deterministically and cut them into train and val.

	Args:
		pairs: The ``(image_path, label_text)`` tuples to split. The list
			itself is left untouched.
		train_split: Fraction of the pairs that goes into the training part,
			between 0 and 1 inclusive. The cut index is rounded down.
		seed: Seed of the private random generator used for shuffling.

	Returns:
		``(train_pairs, val_pairs)``.

	Raises:
		ValueError: If ``train_split`` is outside ``[0, 1]``.
	"""
	if not 0.0 <= train_split <= 1.0:
		raise ValueError(f"train_split must be within [0, 1], got {train_split!r}")

	# Shuffle a copy: the caller's list must not be reordered as a side effect.
	shuffled = list(pairs)
	random.Random(seed).shuffle(shuffled)
	cut = int(len(shuffled) * train_split)
	return shuffled[:cut], shuffled[cut:]


def write_jsonl(
	pairs: List[Tuple[Path, str]], out_dir: Path, img_root: Path
) -> None:
	"""Copy the images into ``out_dir`` and write ``out_dir/annotations.jsonl``.

	Each image keeps its path relative to ``img_root``, so the sub-folder
	layout below ``img_root`` is recreated below ``out_dir``. The annotation
	file gets one JSON object per line with the keys

	* ``image``  - path of the copied image relative to ``out_dir``, always
	  written with forward slashes,
	* ``prefix`` - the module-level ``USER_PREFIX``,
	* ``suffix`` - the label text of the pair.

	Args:
		pairs: ``(image_path, label_text)`` tuples; every image path must lie
			below ``img_root``.
		out_dir: Destination folder, created if necessary; ``Path`` or ``str``.
		img_root: Folder the image paths are made relative to; ``Path`` or ``str``.

	Raises:
		ValueError: If an image path is not located below ``img_root``
			(raised by ``Path.relative_to``).
	"""
	out_dir = Path(out_dir)
	img_root = Path(img_root)
	out_dir.mkdir(parents=True, exist_ok=True)

	with (out_dir / "annotations.jsonl").open("w", encoding="utf-8") as f:
		for img_path, label in pairs:
			rel_img = img_path.relative_to(img_root)
			dest_img = out_dir / rel_img
			dest_img.parent.mkdir(parents=True, exist_ok=True)
			shutil.copy2(img_path, dest_img)

			record = {
				# as_posix() gives forward slashes on every platform without
				# touching backslashes that are part of a POSIX file name.
				"image": rel_img.as_posix(),
				"prefix": USER_PREFIX,
				"suffix": label,
			}
			f.write(json.dumps(record, ensure_ascii=False) + "\n")


# ─── USER-SPECIFIC PATHS ────────────────────────────────────────────────────────
IMG_ROOT = Path("/Users/georgye/Documents/repos/ethz/dslab25/training/qwen/images/augmented")
LBL_ROOT = Path("/Users/georgye/Documents/repos/ethz/dslab25/training/qwen/annotation/augmented")
DATA_ROOT = Path("/Users/georgye/Documents/repos/ethz/dslab25/training/qwen/data")
TRAIN_SPLIT = 0.95
SEED = 32
# ────────────────────────────────────────────────────────────────────────────────

# Step 1 - match every image with its label text.
pairs = collect_pairs(IMG_ROOT, LBL_ROOT)

# Step 2 - seeded train / validation split.
train_pairs, val_pairs = split_pairs(pairs, TRAIN_SPLIT, SEED)

# Step 3 - rebuild both output folders from scratch and fill them.
train_dir = DATA_ROOT / "train"
val_dir = DATA_ROOT / "val"

for dirpath in (train_dir, val_dir):
	if dirpath.exists():
		shutil.rmtree(dirpath)  # drop results of a previous run

write_jsonl(train_pairs, train_dir, IMG_ROOT)
write_jsonl(val_pairs, val_dir, IMG_ROOT)

# Step 4 - store the system prompt next to the two splits.
(DATA_ROOT / "system_message.txt").write_text(SYSTEM_MESSAGE, encoding="utf-8")

# Step 5 - summary for the notebook output.
summary_lines = [
	"✅  Done.",
	f"• Train images: {len(train_pairs)} ➜ {train_dir}",
	f"• Val   images: {len(val_pairs)} ➜ {val_dir}",
	f"• JSONL files:  {train_dir/'annotations.jsonl'}, {val_dir/'annotations.jsonl'}",
]
print("\n".join(summary_lines))



✅  Done.
• Train images: 2675 ➜ /Users/georgye/Documents/repos/ethz/dslab25/training/qwen/data/train
• Val   images: 669 ➜ /Users/georgye/Documents/repos/ethz/dslab25/training/qwen/data/val
• JSONL files:  /Users/georgye/Documents/repos/ethz/dslab25/training/qwen/data/train/annotations.jsonl, /Users/georgye/Documents/repos/ethz/dslab25/training/qwen/data/val/annotations.jsonl


# Train

In [None]:
#!/usr/bin/env python3
"""
finetune_qwen25vl7b.py
──────────────────────
Fine-tune Qwen-2.5-VL-7B-Instruct (Hugging Face) on a local JSONL+image dataset.

Requires:
  pip install -q "git+https://github.com/huggingface/transformers" \
				 accelerate peft bitsandbytes qwen-vl-utils[decord]==0.0.8 \
				 lightning nltk
"""

import argparse
import json
import os
import random
from pathlib import Path
from typing import Any, List, Tuple

import lightning as L
import torch
from nltk import edit_distance
from peft import LoraConfig, get_peft_model
from qwen_vl_utils import process_vision_info
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import (
	BitsAndBytesConfig,
	Qwen2_5_VLForConditionalGeneration,
	Qwen2_5_VLProcessor,
)

# ────────────────────────────────────────────────────────────────────────────────
# System prompt shared by every conversation. It is read from a
# system_message.txt located next to this script. `__file__` does not exist
# when the code runs as a notebook cell, so the current working directory is
# used as the fallback location instead of failing with a NameError.
# NOTE(review): the data-preparation code earlier in this file writes
# system_message.txt into its DATA_ROOT folder — make sure a copy is
# available at the location used here.
try:
	_SYSTEM_MESSAGE_PATH = Path(__file__).with_name("system_message.txt")
except NameError:
	_SYSTEM_MESSAGE_PATH = Path.cwd() / "system_message.txt"
SYSTEM_MESSAGE = _SYSTEM_MESSAGE_PATH.read_text(encoding="utf-8")
# ────────────────────────────────────────────────────────────────────────────────


def format_chat(image_dir: Path, entry: dict) -> List[dict]:
	"""Build the system / user / assistant message list for one dataset entry.

	``entry`` must provide the keys ``image`` (path relative to ``image_dir``),
	``prefix`` (user text) and ``suffix`` (assistant text).
	"""
	system_turn = {
		"role": "system",
		"content": [{"type": "text", "text": SYSTEM_MESSAGE}],
	}

	# The user turn carries the image first and the question second.
	image_part = {"type": "image", "image": str(image_dir / entry["image"])}
	question_part = {"type": "text", "text": entry["prefix"]}
	user_turn = {"role": "user", "content": [image_part, question_part]}

	assistant_turn = {
		"role": "assistant",
		"content": [{"type": "text", "text": entry["suffix"]}],
	}
	return [system_turn, user_turn, assistant_turn]


class JSONLDataset(Dataset):
	"""Dataset over an ``annotations.jsonl`` file and its image folder.

	Every non-blank line of the file is parsed as one JSON object. Items are
	returned as ``(None, entry, conversation)`` where ``entry`` is the parsed
	object and ``conversation`` is the result of ``format_chat`` for it.
	"""

	def __init__(self, jsonl_path: Path, image_dir: Path):
		"""Load all entries of ``jsonl_path``; both arguments may be ``str``."""
		self.image_dir = Path(image_dir)
		# The file is UTF-8 with non-ASCII characters kept as-is, so decode it
		# explicitly instead of relying on the platform default. Iterating
		# the file splits on real line endings only, and blank lines (e.g. a
		# trailing one) are skipped rather than handed to json.loads.
		with Path(jsonl_path).open(encoding="utf-8") as fh:
			self.entries = [json.loads(line) for line in fh if line.strip()]

	def __len__(self):  # noqa: D401
		return len(self.entries)

	def __getitem__(self, idx: int) -> Tuple[Any, dict, List[dict]]:  # noqa: D401
		entry = self.entries[idx]
		# The first slot is always None; the collate functions ignore it.
		return None, entry, format_chat(self.image_dir, entry)


# ─── Collate fns ────────────────────────────────────────────────────────────────
def make_collate(processor):
	"""Create the training and evaluation collate functions for ``processor``.

	Both functions receive a batch of ``(_, entry, conversation)`` items as
	produced by ``JSONLDataset`` and return a 5-tuple whose first four
	elements are ``input_ids``, ``attention_mask``, ``pixel_values`` and
	``image_grid_thw`` from the processor output.

	* ``train_collate`` encodes the full conversation and returns ``labels``
	  as the fifth element: a copy of ``input_ids`` with padding and the ids
	  in ``ignored_label_ids`` replaced by -100.
	* ``eval_collate`` encodes only the first two turns (system + user), ends
	  the prompt with the generation prompt, and returns the reference
	  ``suffix`` strings as the fifth element.

	Returns:
		``(train_collate, eval_collate)``.
	"""
	# Ids masked out of the loss in addition to padding.
	# NOTE(review): presumably the vision-start / vision-end / image-pad token
	# ids of the Qwen2.5-VL tokenizer — confirm against the tokenizer in use.
	ignored_label_ids = (151652, 151653, 151655)

	def _encode(texts, conversations):
		# Extract the image inputs of every conversation and run the processor
		# over the whole batch with padding.
		imgs = [process_vision_info(conv)[0] for conv in conversations]
		return processor(text=texts, images=imgs, return_tensors="pt", padding=True)

	def train_collate(batch):
		_, _, examples = zip(*batch)
		texts = [processor.apply_chat_template(e, tokenize=False) for e in examples]
		model_in = _encode(texts, examples)

		labels = model_in["input_ids"].clone()
		labels[labels == processor.tokenizer.pad_token_id] = -100
		for tkn in ignored_label_ids:
			labels[labels == tkn] = -100
		return (
			model_in["input_ids"],
			model_in["attention_mask"],
			model_in["pixel_values"],
			model_in["image_grid_thw"],
			labels,
		)

	def eval_collate(batch):
		_, data, examples = zip(*batch)
		suffixes = [d["suffix"] for d in data]
		# Drop the assistant turn: the model has to generate it.
		prompt_convs = [e[:2] for e in examples]
		# add_generation_prompt=True ends the prompt with the assistant header,
		# so generation starts directly with the answer. Without it the model
		# has to emit that header itself and it leaks into the decoded text
		# that is compared with the reference.
		prompts = [
			processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)
			for conv in prompt_convs
		]
		model_in = _encode(prompts, prompt_convs)
		return (
			model_in["input_ids"],
			model_in["attention_mask"],
			model_in["pixel_values"],
			model_in["image_grid_thw"],
			suffixes,
		)

	return train_collate, eval_collate


# ─── Lightning module ───────────────────────────────────────────────────────────
class QwenTrainer(L.LightningModule):
	"""Lightning wrapper that fine-tunes ``model`` on ``train_set``.

	Training minimises the loss returned by the model. Validation lets the
	model generate an answer for every sample and logs the mean normalised
	edit distance to the reference text as ``val_edit_dist`` (0 = identical).

	Args:
		cfg: Mapping with the keys ``batch_size`` and ``lr``.
		model: The model to train; it is called with ``input_ids``,
			``attention_mask``, ``pixel_values``, ``image_grid_thw`` and
			``labels``, and must provide ``generate``.
		processor: Processor used by the collate functions and for decoding.
		train_set: Dataset served by ``train_dataloader``.
		val_set: Dataset served by ``val_dataloader``.
	"""

	def __init__(self, cfg, model, processor, train_set, val_set):
		super().__init__()
		# Record only the plain config as hyper-parameters. Without `ignore`
		# the model, the processor and both datasets would be captured as
		# hyper-parameters too.
		self.save_hyperparameters(ignore=["model", "processor", "train_set", "val_set"])
		self.cfg, self.model, self.processor = cfg, model, processor
		self.train_set, self.val_set = train_set, val_set
		self.train_collate, self.eval_collate = make_collate(processor)

	@staticmethod
	def _normalized_edit_distance(pred: str, ref: str) -> float:
		"""Edit distance divided by the longer length; 0.0 if both are empty."""
		longest = max(len(pred), len(ref))
		if longest == 0:
			return 0.0
		return edit_distance(pred, ref) / longest

	# ╭─ training ╮
	def training_step(self, batch, _):
		ids, msk, pix, thw, lbl = batch
		loss = self.model(
			input_ids=ids,
			attention_mask=msk,
			pixel_values=pix,
			image_grid_thw=thw,
			labels=lbl,
		).loss
		self.log("train_loss", loss, prog_bar=True)
		return loss

	# ╭─ validation ╮
	def validation_step(self, batch, _):
		ids, msk, pix, thw, refs = batch
		gen_ids = self.model.generate(
			input_ids=ids,
			attention_mask=msk,
			pixel_values=pix,
			image_grid_thw=thw,
			max_new_tokens=256,
		)
		# Cut the prompt off every generated sequence before decoding.
		outs = self.processor.batch_decode(
			[o[len(i) :] for i, o in zip(ids, gen_ids)],
			skip_special_tokens=True,
		)
		score = sum(self._normalized_edit_distance(o, r) for o, r in zip(outs, refs))
		self.log("val_edit_dist", score / len(refs), prog_bar=True)

	# ╭─ loaders ╮
	def train_dataloader(self):
		return DataLoader(
			self.train_set,
			batch_size=self.cfg["batch_size"],
			shuffle=True,
			num_workers=4,
			collate_fn=self.train_collate,
		)

	def val_dataloader(self):
		return DataLoader(
			self.val_set,
			batch_size=1,
			shuffle=False,
			num_workers=2,
			collate_fn=self.eval_collate,
		)

	# ╭─ optim ╮
	def configure_optimizers(self):
		return AdamW(self.model.parameters(), lr=self.cfg["lr"])


In [None]:

# ─── Entry-point ────────────────────────────────────────────────────────────────
# Command-line options: dataset location, number of epochs, checkpoint folder.
ap = argparse.ArgumentParser()
ap.add_argument("--data_root", required=True, type=Path)
ap.add_argument("--epochs", default=10, type=int)
ap.add_argument("--save_dir", default="qwen_2_5_vl_7b_ft", type=Path)
args = ap.parse_args()

# Each split lives in its own folder holding the images and annotations.jsonl.
train_jsonl = args.data_root / "train" / "annotations.jsonl"
val_jsonl = args.data_root / "val" / "annotations.jsonl"

train_set = JSONLDataset(train_jsonl, args.data_root / "train")
val_set = JSONLDataset(val_jsonl, args.data_root / "val")

# ─── Model + processor ────────────────────────────────────────────────────
MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
lora_cfg = LoraConfig(
	r=8, lora_alpha=16, lora_dropout=0.05, bias="none", target_modules=["q_proj", "v_proj"]
)
bnb_cfg = BitsAndBytesConfig(
	load_in_4bit=True,
	bnb_4bit_use_double_quant=True,
	bnb_4bit_quant_type="nf4",
	# The option is named `bnb_4bit_compute_dtype`; under the previous
	# spelling (`..._compute_type`) the bfloat16 setting was not applied.
	bnb_4bit_compute_dtype=torch.bfloat16,
)

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
	MODEL_ID, device_map="auto", quantization_config=bnb_cfg, torch_dtype=torch.bfloat16
)
model = get_peft_model(model, lora_cfg)
processor = Qwen2_5_VLProcessor.from_pretrained(MODEL_ID, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28)

cfg = dict(batch_size=1, lr=2e-4)
lit = QwenTrainer(cfg, model, processor, train_set, val_set)

# ─── Checkpoint callback ──────────────────────────────────────────────────
class SaveBoth(L.Callback):
	"""Callback that stores the processor and the model after every epoch.

	The files of epoch ``n`` are written to ``<out>/epoch_<n>``.
	"""

	def __init__(self, out: Path):  # noqa: D401
		self.out = out

	def on_train_epoch_end(self, trainer, pl_module):
		epoch_dir = self.out / f"epoch_{trainer.current_epoch}"
		epoch_dir.mkdir(parents=True, exist_ok=True)
		# Processor first, then the model - both into the same folder.
		for artefact in (pl_module.processor, pl_module.model):
			artefact.save_pretrained(epoch_dir)
		print(f"[ckpt] {epoch_dir}")

# Single-GPU run in bf16 mixed precision. Gradients are accumulated over 8
# batches and clipped at 1.0, and a checkpoint is written after each epoch.
checkpoint_cb = SaveBoth(args.save_dir)
trainer = L.Trainer(
	max_epochs=args.epochs,
	accelerator="gpu",
	devices=1,
	precision="bf16-mixed",
	accumulate_grad_batches=8,
	gradient_clip_val=1.0,
	log_every_n_steps=10,
	callbacks=[checkpoint_cb],
)

trainer.fit(lit)