# Setup


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
params = {
    "epochs": 8,
    "batch_size": 16,
    "learning_rate": 0.001,
    "class_power": 0.9,  # between 0 and 1 (inclusive), scales positive class weight (0 removes class weighting, 1 leaves class ratio unchanged)
    "focal_power": 2,  # focusing parameter (gamma) in focal loss (default 2)
    "threshold": 0.5,  # probability threshold for positive classification
    "seed": 42,  # rng seed for reproducibility
}

epochs = params["epochs"]
batch_size = params["batch_size"]
lr = params["learning_rate"]
class_power = params["class_power"]
focal_power = params["focal_power"]

threshold = params["threshold"]
seed = params["seed"]

In [3]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device

device(type='cuda')

In [4]:
import torch

torch.manual_seed(seed)
generator = torch.Generator().manual_seed(seed)

In [5]:
import torch

torch.set_float32_matmul_precision("high")

# Dataset


In [6]:
from datasets import load_dataset
from isic.dataset import MetadataTextFormatter, MessagesFormatter

ds = load_dataset("mrbrobot/isic-2024", split="train")
ds = ds.select_columns(["image", "age_approx", "sex", "anatom_site_general", "target"])

# format metadata as text
ds = ds.with_format("arrow")
ds = ds.map(MetadataTextFormatter(), batched=True, desc="Formatting metadata")

# format metadata & image as prompt
ds = ds.select_columns(["image", "text", "target"])
ds = ds.with_format(None)
ds = ds.with_transform(MessagesFormatter())

len(ds)

401059

# Model Definition


In [7]:
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor

model = Qwen3VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen3-VL-4B-Instruct", dtype="auto", device_map="auto"
)

processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-4B-Instruct")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [8]:
messages = ds[0]["messages"]

messages

[{'role': 'system',
  'content': [{'type': 'text',
    'text': 'Classify provided example as benign or malignant.'}]},
 {'role': 'user',
  'content': [{'type': 'image',
    'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=139x139>},
   {'type': 'text',
    'text': '| Field | Value |\n|-------|-------|\n| Age | 60 years |\n| Sex | male |\n| Lesion Site | lower extremity |'}]}]

In [9]:
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
)
inputs = inputs.to(model.device)

for k, v in inputs.items():
    print(f"{k} -> {type(v)}")

input_ids -> <class 'torch.Tensor'>
attention_mask -> <class 'torch.Tensor'>
pixel_values -> <class 'torch.Tensor'>
image_grid_thw -> <class 'torch.Tensor'>


In [10]:
generated_ids = model.generate(**inputs, max_new_tokens=4)
generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)

output_text

['Benign']

In [11]:
model(**inputs, max_new_tokens=4)

Qwen3VLCausalLMOutputWithPast(loss=None, logits=tensor([[[ 3.6250,  2.7188,  4.9688,  ..., -2.2344, -2.2344, -2.2344],
         [ 2.5312,  2.6406,  2.8594,  ...,  1.2969,  1.2969,  1.2969],
         [ 7.7812,  8.0625,  7.8438,  ...,  2.4219,  2.4219,  2.4219],
         ...,
         [ 4.1875,  4.6562,  9.3750,  ...,  0.1406,  0.1406,  0.1406],
         [ 1.8984,  6.0312,  4.2812,  ..., -3.0625, -3.0625, -3.0625],
         [ 8.4375, 12.5000, 15.5000,  ..., -3.1250, -3.1250, -3.1250]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>), past_key_values=DynamicCache(layers=[DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, DynamicLayer, Dyn

In [12]:
print(model)

Qwen3VLForConditionalGeneration(
  (model): Qwen3VLModel(
    (visual): Qwen3VLVisionModel(
      (patch_embed): Qwen3VLVisionPatchEmbed(
        (proj): Conv3d(3, 1024, kernel_size=(2, 16, 16), stride=(2, 16, 16))
      )
      (pos_embed): Embedding(2304, 1024)
      (rotary_pos_emb): Qwen3VLVisionRotaryEmbedding()
      (blocks): ModuleList(
        (0-23): 24 x Qwen3VLVisionBlock(
          (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (attn): Qwen3VLVisionAttention(
            (qkv): Linear(in_features=1024, out_features=3072, bias=True)
            (proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (mlp): Qwen3VLVisionMLP(
            (linear_fc1): Linear(in_features=1024, out_features=4096, bias=True)
            (linear_fc2): Linear(in_features=4096, out_features=1024, bias=True)
            (act_fn): GELUTanh()
          )
        )
      )
 

In [13]:
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

total_params

4437815808

# Training


Class imbalance measurement & handling


In [14]:
class_counts = [400666, 393]  # [benign, malignant] from EDA

print("Class Distribution:")
print(f"Benign: {class_counts[0]:,} samples")
print(f"Malignant: {class_counts[1]:,} samples")
print(f"Imbalance ratio: {class_counts[0] / class_counts[1]:.1f}:1")

Class Distribution:
Benign: 400,666 samples
Malignant: 393 samples
Imbalance ratio: 1019.5:1


In [15]:
df = ds.to_pandas()
neg_count = (df["target"] == 0).sum()
pos_count = (df["target"] == 1).sum()
pos_weight = neg_count / pos_count

print(f"Positive weight: {pos_weight:.1f}")
print(f"Scaled positive class weight: {pos_weight**class_power:.1f}")

Positive weight: 1019.5
Scaled positive class weight: 510.0


In [16]:
from isic.loss import WeightedFocalLoss

scaled_pos_weight = torch.tensor([pos_weight**class_power], device=device)
criterion = WeightedFocalLoss(pos_weight=scaled_pos_weight, gamma=focal_power)

In [17]:
split = ds.train_test_split(test_size=0.2, seed=seed)
train_ds, val_ds = split["train"], split["test"]

print(f"Train samples: {len(train_ds):,}")
print(f"Val samples: {len(val_ds):,}")

Train samples: 320,847
Val samples: 80,212


In [18]:
model.config

Qwen3VLConfig {
  "architectures": [
    "Qwen3VLForConditionalGeneration"
  ],
  "image_token_id": 151655,
  "model_type": "qwen3_vl",
  "text_config": {
    "attention_bias": false,
    "attention_dropout": 0.0,
    "bos_token_id": 151643,
    "dtype": "bfloat16",
    "eos_token_id": 151645,
    "head_dim": 128,
    "hidden_act": "silu",
    "hidden_size": 2560,
    "initializer_range": 0.02,
    "intermediate_size": 9728,
    "max_position_embeddings": 262144,
    "model_type": "qwen3_vl_text",
    "num_attention_heads": 32,
    "num_hidden_layers": 36,
    "num_key_value_heads": 8,
    "rms_norm_eps": 1e-06,
    "rope_scaling": {
      "mrope_interleaved": true,
      "mrope_section": [
        24,
        20,
        20
      ],
      "rope_type": "default"
    },
    "rope_theta": 5000000,
    "tie_word_embeddings": true,
    "use_cache": true,
    "vocab_size": 151936
  },
  "tie_word_embeddings": true,
  "transformers_version": "4.57.1",
  "video_token_id": 151656,
  "vision_co

In [19]:
import torch.nn as nn

# Replace lm_head with binary classifier
hidden_size = model.config.text_config.hidden_size
model.lm_head = nn.Linear(hidden_size, 1).to(model.device)

# Freeze all parameters
for param in model.parameters():
    param.requires_grad = False

# Unfreeze only the classification head
for param in model.lm_head.parameters():
    param.requires_grad = True

# Count trainable parameters
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,}")
print(f"Total parameters: {total:,}")

Trainable parameters: 2,561
Total parameters: 4,437,818,369


In [20]:
from isic.dataset import VLMCollator

collator = VLMCollator(processor)

In [None]:
from isic.loss import WeightedFocalLoss, VLMLoss
from isic.metrics import BinaryMetricsComputer

criterion = WeightedFocalLoss(pos_weight=scaled_pos_weight, gamma=focal_power)

compute_loss = VLMLoss(criterion)

# Initialize metrics computer for HuggingFace Trainer
metrics_computer = BinaryMetricsComputer(
    threshold=threshold,
    device=device,
)

In [None]:
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    # output_dir="./results/qwen-phase1",
    num_train_epochs=2,  # Phase 1: short training with frozen backbone
    per_device_train_batch_size=batch_size,  # Adjust based on GPU memory
    per_device_eval_batch_size=batch_size,
    learning_rate=lr,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="roc_auc",  # Use AUROC for model selection
    bf16=True,  # Mixed precision for efficiency
    remove_unused_columns=False,  # Important: keep all columns for collator
    data_seed=seed,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collator,
    compute_loss_func=compute_loss,
    compute_metrics=metrics_computer,  # Pass metrics computer instance directly
)

In [23]:
trainer.train()

* Trackio project initialized: huggingface
* Trackio metrics will be synced to Hugging Face Dataset: mrbrobot/trackio-dataset
* Found existing space: https://huggingface.co/spaces/mrbrobot/trackio
* View dashboard by going to: https://mrbrobot-trackio.hf.space/


* Created new run: mrbrobot-1761314816


Epoch,Training Loss,Validation Loss
1,9.3918,8.850162
2,2.4948,1.561652


Exception in thread Thread-12 (_init_client_background):
Traceback (most recent call last):
  File "/workspaces/isic/.venv/lib/python3.12/site-packages/httpx/_transports/default.py", line 101, in map_httpcore_exceptions
    yield
  File "/workspaces/isic/.venv/lib/python3.12/site-packages/httpx/_transports/default.py", line 250, in handle_request
    resp = self._pool.handle_request(req)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspaces/isic/.venv/lib/python3.12/site-packages/httpcore/_sync/connection_pool.py", line 256, in handle_request
    raise exc from None
  File "/workspaces/isic/.venv/lib/python3.12/site-packages/httpcore/_sync/connection_pool.py", line 236, in handle_request
    response = connection.handle_request(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspaces/isic/.venv/lib/python3.12/site-packages/httpcore/_sync/connection.py", line 101, in handle_request
    raise exc
  File "/workspaces/isic/.venv/lib/python3.12/site-packages/httpcore/_sync/con

* Run finished. Uploading logs to Trackio (please wait...)


TrainOutput(global_step=40106, training_loss=6.260735238941329, metrics={'train_runtime': 24249.4262, 'train_samples_per_second': 26.462, 'train_steps_per_second': 1.654, 'total_flos': 1.8230358175625964e+18, 'train_loss': 6.260735238941329, 'epoch': 2.0})