In [None]:
# List the GPUs visible to this runtime (one line per device).
!nvidia-smi -L

In [None]:
import torch
import torch.nn as nn
from torch.optim import AdamW
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from datasets import load_dataset
from transformers import (
    SiglipModel,
    SiglipConfig,
    PreTrainedModel,
    AutoModel,
    AutoProcessor,
    Trainer,
    TrainingArguments,
    TrainerCallback
)
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
import random
import os
from typing import Optional

In [None]:
# Only the "train" split is requested from the Hub; the validation set is
# carved out of it below.
dataset = load_dataset("justacoderwhocodes/dental_binary_treatment_classification", split="train")
print(f"Total dataset size: {len(dataset)}")

# Fixed seeds keep the 90/10 split reproducible across kernel restarts.
dataset = dataset.shuffle(seed=42)
split = dataset.train_test_split(test_size=0.1, seed=42)
train_ds = split["train"]
val_ds = split["test"]

print(f"Train: {len(train_ds)}, Val: {len(val_ds)}")

label_map = {"no_treatment": 0, "treatment": 1}
inv_label_map = {0: "no_treatment", 1: "treatment"}


def _label_counts(ds):
    """Return ``{label_name: count}`` for the ``label`` column of ``ds``.

    Works whether the column is stored as class names (plain strings) or as
    integer ids of a feature that offers ``int2str`` (e.g. ``ClassLabel``).
    """
    label_feature = ds.features["label"]
    raw_labels = list(ds["label"])
    if hasattr(label_feature, "int2str"):
        # Integer-encoded labels: translate ids back to their class names.
        raw_labels = [label_feature.int2str(int(value)) for value in raw_labels]
    names, counts = np.unique([str(value) for value in raw_labels], return_counts=True)
    return dict(zip(names.tolist(), counts.tolist()))


# Print per-class counts rather than the full per-sample label list.
print("Label distribution:")
print(f"Train: {_label_counts(train_ds)}")
print(f"Val: {_label_counts(val_ds)}")

In [None]:
# Preprocessor and pretrained backbone come from the same checkpoint.
processor = AutoProcessor.from_pretrained("google/medsiglip-448")
medsiglip = SiglipModel.from_pretrained("google/medsiglip-448", torch_dtype=torch.bfloat16)

# Inference mode + no gradients: the backbone is used as a frozen feature extractor.
medsiglip.eval()
medsiglip.requires_grad_(False)

total_param_count = sum(p.numel() for p in medsiglip.parameters())
trainable_param_count = sum(p.numel() for p in medsiglip.parameters() if p.requires_grad)
print(f"MedSigLip parameters: {total_param_count:,}")
print(f"Trainable parameters: {trainable_param_count:,}")

In [None]:
class MedSigLipBinaryClassifierConfig(SiglipConfig):
    """SigLIP configuration extended with the settings of a classification head.

    Args:
        num_classes: Number of output classes produced by the head.
        hidden_dim: Input width of the head, i.e. the size of the pooled
            feature vector it receives.
        **kwargs: Forwarded unchanged to ``SiglipConfig``.
    """

    model_type = "medsiglip_binary_classifier"

    def __init__(self, num_classes: int = 2, hidden_dim: int = 1152, **kwargs) -> None:
        super().__init__(**kwargs)
        # Head-specific settings stored next to the inherited SigLIP fields.
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes

class MedSigLipBinaryClassifier(PreTrainedModel):
    """Frozen SigLIP image encoder followed by a trainable linear head.

    Every parameter of the wrapped ``SiglipModel`` has ``requires_grad=False``
    and the encoder runs under ``torch.no_grad()``, so only ``classifier``
    is updated during training.
    """

    config_class = MedSigLipBinaryClassifierConfig
    base_model_prefix = "vision_model"

    def __init__(self, config: MedSigLipBinaryClassifierConfig):
        super().__init__(config)
        self.num_classes = config.num_classes
        self.hidden_dim = config.hidden_dim

        # Built from the config only: weights are randomly initialised until a
        # checkpoint is loaded into this module.
        self.vision_model = SiglipModel(config)
        self.vision_model.eval()
        for param in self.vision_model.parameters():
            param.requires_grad = False

        self.classifier = nn.Linear(config.hidden_dim, config.num_classes)
        self.post_init()

    def train(self, mode: bool = True):
        """Set train/eval mode on the head while keeping the frozen backbone in eval mode."""
        super().train(mode)
        self.vision_model.eval()
        return self

    def forward(self, pixel_values: torch.Tensor, labels: Optional[torch.Tensor] = None):
        """Classify a batch of preprocessed images.

        Args:
            pixel_values: Batch of image tensors as produced by the processor.
            labels: Optional integer class ids; when given, the cross-entropy
                loss is computed.

        Returns:
            Dict with ``"loss"`` (``None`` when ``labels`` is not given) and
            ``"logits"``.
        """
        with torch.no_grad():
            # Call the vision tower directly. ``SiglipModel.forward`` also runs
            # the text tower (which requires ``input_ids``) and its output
            # object has no ``pooler_output`` field.
            vision_outputs = self.vision_model.vision_model(pixel_values=pixel_values)
            pooled = vision_outputs.pooler_output

        # Align dtypes in case backbone and head were loaded with different precision.
        logits = self.classifier(pooled.to(self.classifier.weight.dtype))

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        return {"loss": loss, "logits": logits}

print("Custom HuggingFace-compatible model class defined")

In [None]:
config = MedSigLipBinaryClassifierConfig.from_pretrained("google/medsiglip-448")
# The linear head is sized from config.hidden_dim, so it must match the
# width of the vision tower's pooled output.
assert config.vision_config.hidden_size == config.hidden_dim, (
    f"hidden_dim={config.hidden_dim} does not match the vision tower width "
    f"{config.vision_config.hidden_size}"
)

model = MedSigLipBinaryClassifier(config)

# Instantiating from a config leaves the backbone randomly initialised, so
# copy the pretrained MedSigLIP weights in. load_state_dict casts the bf16
# checkpoint values to the dtype of the target parameters and leaves the
# requires_grad flags (frozen backbone) untouched.
model.vision_model.load_state_dict(medsiglip.state_dict())

print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

In [None]:
class DentalClassDataset(torch.utils.data.Dataset):
    """Torch dataset pairing preprocessed images with integer class ids.

    Args:
        hf_dataset: Indexable collection of records with ``"image"`` and
            ``"label"`` entries.
        processor: Callable invoked as ``processor(images=..., return_tensors="pt")``
            that returns a mapping containing ``"pixel_values"``.
        label_map: Mapping from class name to integer class id.
    """

    def __init__(self, hf_dataset, processor, label_map):
        self.dataset = hf_dataset
        self.processor = processor
        self.label_map = label_map

        # If the label column is integer-encoded by a feature that offers
        # ``int2str`` (e.g. ``ClassLabel``), keep that converter so ids can be
        # routed through ``label_map`` by name.
        features = getattr(hf_dataset, "features", None)
        label_feature = features.get("label") if isinstance(features, dict) else None
        self._int2str = getattr(label_feature, "int2str", None)
        self._valid_ids = set(label_map.values())

    def __len__(self):
        return len(self.dataset)

    def _encode_label(self, raw_label):
        """Convert a raw record label (class name or integer id) to a class id."""
        if isinstance(raw_label, str):
            label_id = self.label_map[raw_label]
        elif self._int2str is not None:
            label_id = self.label_map[self._int2str(int(raw_label))]
        else:
            # NOTE(review): plain integers are assumed to already follow the
            # numbering of ``label_map`` -- confirm against the dataset card.
            label_id = int(raw_label)
        if label_id not in self._valid_ids:
            raise ValueError(f"Label {raw_label!r} does not map to a known class id")
        return label_id

    def __getitem__(self, idx):
        item = self.dataset[idx]
        inputs = self.processor(images=item["image"], return_tensors="pt")
        return {
            # The processor returns a batch of one; drop the batch axis.
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "labels": self._encode_label(item["label"])
        }

# Wrap both splits so each item yields a "pixel_values" tensor and a "labels" id.
train_dataset = DentalClassDataset(train_ds, processor, label_map)
val_dataset = DentalClassDataset(val_ds, processor, label_map)

In [None]:
def visualize_predictions(model, dataset, inv_label_map, num_samples=4):
    """Plot random samples with their predicted and actual labels.

    Args:
        model: Classifier called as ``model(pixel_values=...)``; must return a
            mapping with ``"logits"`` and expose ``.device`` and ``.training``.
        dataset: Dataset whose items provide ``"pixel_values"`` and ``"labels"``
            and whose ``.dataset`` attribute holds the raw records (with an
            ``"image"`` entry) at the same indices.
        inv_label_map: Mapping from class id to display name.
        num_samples: Maximum number of samples to draw; capped at ``len(dataset)``.

    The model's train/eval mode is restored to whatever it was on entry.
    """
    sample_count = min(num_samples, len(dataset))
    if sample_count <= 0:
        return
    indices = random.sample(range(len(dataset)), sample_count)

    # Size the grid from the sample count (2x2 for the default of four) so
    # larger requests do not index past the available axes.
    ncols = 2 if sample_count <= 4 else 4
    nrows = -(-sample_count // ncols)  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 6 * nrows), squeeze=False)
    axes = axes.flatten()

    was_training = model.training
    model.eval()
    try:
        for ax, sample_idx in zip(axes, indices):
            image = dataset.dataset[sample_idx]["image"]
            # Reuse the dataset's own preprocessing: one pass gives both the
            # model input and the ground-truth id.
            sample = dataset[sample_idx]
            pixel_values = sample["pixel_values"].unsqueeze(0).to(model.device)

            with torch.no_grad():
                logits = model(pixel_values=pixel_values)["logits"].squeeze(0)
            pred = torch.argmax(logits).item()

            actual = sample["labels"]
            pred_label = inv_label_map[pred]
            actual_label = inv_label_map[actual]
            status = "\u2705" if pred == actual else "\u274c"

            ax.imshow(image)
            ax.set_title(f"Pred: {pred_label}\nActual: {actual_label} {status}")
            ax.axis('off')
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)

    # Hide grid cells that received no image.
    for ax in axes[sample_count:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

print("Visualization function defined")

In [None]:
class VizCallback(TrainerCallback):
    """Trainer callback that plots sample predictions while training runs.

    A visualization is produced every ``every_n_steps`` optimizer steps and at
    the end of every epoch.

    Args:
        model: Fallback model used when the trainer does not pass one to the hook.
        dataset: Dataset handed to ``visualize_predictions``.
        inv_label_map: Mapping from class id to display name.
        every_n_steps: Step interval between visualizations; a value of zero or
            less disables the step-based visualizations.
    """

    def __init__(self, model, dataset, inv_label_map, every_n_steps=200):
        self.model = model
        self.dataset = dataset
        self.inv_label_map = inv_label_map
        self.every_n_steps = every_n_steps

    def _visualize(self, header, model):
        """Print a banner with ``header`` and plot predictions."""
        banner = "=" * 50
        print(f"\n{banner}")
        print(header)
        print(banner)
        # Prefer the model supplied by the trainer; fall back to the stored one.
        active_model = model if model is not None else self.model
        visualize_predictions(active_model, self.dataset, self.inv_label_map)

    def on_step_end(self, args, state, control, model=None, **kwargs):
        if self.every_n_steps <= 0 or state.global_step <= 0:
            return
        if state.global_step % self.every_n_steps == 0:
            self._visualize(f"Step {state.global_step} - Visualization", model)

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        self._visualize(f"Epoch {state.epoch} - Visualization", model)

In [None]:
def compute_metrics(eval_pred):
    """Compute classification accuracy from evaluation outputs.

    Args:
        eval_pred: Pair ``(logits, labels)``. ``logits`` is an array of
            per-class scores with classes on the last axis, or a tuple whose
            first element is that array; ``labels`` holds the integer class ids.

    Returns:
        ``{"accuracy": float}`` -- the fraction of samples whose arg-max class
        equals the label.
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        # Models with several outputs deliver a tuple; logits come first.
        logits = logits[0]
    predictions = np.argmax(np.asarray(logits), axis=-1)
    return {"accuracy": float(np.mean(predictions == np.asarray(labels)))}

args = TrainingArguments(
    output_dir="./medsiglip-dental-classifier",
    num_train_epochs=10,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=1e-3,
    warmup_steps=50,
    lr_scheduler_type="cosine",
    logging_steps=50,
    # load_best_model_at_end requires evaluation and checkpointing to follow
    # the same schedule. Evaluation is off by default, so enable step-based
    # evaluation explicitly; otherwise eval_steps has no effect.
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    push_to_hub=True,
    hub_model_id="justacoderwhocodes/medsiglip-dental-classifier",
    bf16=True,
    report_to="none"
)

print("Training config set up")

In [None]:
# Plots validation predictions every 200 steps and at each epoch end.
viz_callback = VizCallback(model, val_dataset, inv_label_map, every_n_steps=200)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[viz_callback],
)

print("Trainer configured. Starting training...")

In [None]:
# Run training with the Trainer configured in the previous cell.
trainer.train()

In [None]:
# Ship the model class code with the checkpoint so the repo can be loaded via
# AutoModel.from_pretrained(..., trust_remote_code=True). Three things are needed:
#   1. the modeling file must live in the folder that is uploaded (output_dir),
#   2. config.json must carry an ``auto_map`` pointing at the classes in that file,
#   3. the processor files must be uploaded too, for AutoProcessor.from_pretrained.
MODELING_MODULE = "modeling_medsiglip_classifier"

model_code = '''
import torch
import torch.nn as nn
from typing import Optional
from transformers import SiglipModel, SiglipConfig, PreTrainedModel

class MedSigLipBinaryClassifierConfig(SiglipConfig):
    model_type = "medsiglip_binary_classifier"
    
    def __init__(self, num_classes: int = 2, hidden_dim: int = 1152, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim

class MedSigLipBinaryClassifier(PreTrainedModel):
    config_class = MedSigLipBinaryClassifierConfig
    base_model_prefix = "vision_model"
    
    def __init__(self, config: MedSigLipBinaryClassifierConfig):
        super().__init__(config)
        self.num_classes = config.num_classes
        self.hidden_dim = config.hidden_dim
        
        self.vision_model = SiglipModel(config)
        self.vision_model.eval()
        
        for param in self.vision_model.parameters():
            param.requires_grad = False
        
        self.classifier = nn.Linear(config.hidden_dim, config.num_classes)
        self.post_init()
    
    def train(self, mode: bool = True):
        # Keep the frozen backbone in eval mode regardless of the requested mode.
        super().train(mode)
        self.vision_model.eval()
        return self
    
    def forward(self, pixel_values: torch.Tensor, labels: Optional[torch.Tensor] = None):
        with torch.no_grad():
            # Use the vision tower only: SiglipModel.forward also needs input_ids
            # and its output has no pooler_output.
            outputs = self.vision_model.vision_model(pixel_values=pixel_values)
            pooled = outputs.pooler_output
        logits = self.classifier(pooled.to(self.classifier.weight.dtype))
        
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
        
        return {"loss": loss, "logits": logits}
'''

# (1) Write the code next to the checkpoint files that push_to_hub uploads.
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, f"{MODELING_MODULE}.py"), "w", encoding="utf-8") as f:
    f.write(model_code)

# (2) Tell the Auto* classes where the custom classes live; this is written
# into config.json when the trainer saves the model before uploading.
model.config.auto_map = {
    "AutoConfig": f"{MODELING_MODULE}.MedSigLipBinaryClassifierConfig",
    "AutoModel": f"{MODELING_MODULE}.MedSigLipBinaryClassifier",
}

# (3) The Trainer was not given the processor, so save it explicitly.
processor.save_pretrained(args.output_dir)

trainer.push_to_hub()

print("Model pushed to HuggingFace!")

In [None]:
from transformers import AutoModel

hub_repo_id = "justacoderwhocodes/medsiglip-dental-classifier"

# trust_remote_code=True allows the custom model class stored in the repo to be used.
loaded_model = AutoModel.from_pretrained(hub_repo_id, trust_remote_code=True)
processor_loaded = AutoProcessor.from_pretrained(hub_repo_id)
loaded_model.eval()

print("Model loaded with AutoModel! Ready for inference.")

In [None]:
# Plot predictions of the reloaded model on random validation samples.
visualize_predictions(loaded_model, val_dataset, inv_label_map, num_samples=8)

In [None]:
def inference_example(image):
    """Classify one image with the reloaded model.

    Args:
        image: Image accepted by ``processor_loaded``.

    Returns:
        Tuple ``(label, confidence)``: the predicted class name and its
        softmax probability.
    """
    batch = processor_loaded(images=image, return_tensors="pt")
    batch = {name: tensor.to(loaded_model.device) for name, tensor in batch.items()}

    with torch.no_grad():
        result = loaded_model(pixel_values=batch["pixel_values"])
        # Single-image batch: drop the batch axis before normalising.
        probabilities = torch.softmax(result["logits"].squeeze(0), dim=-1)
        predicted_index = torch.argmax(probabilities).item()

    return inv_label_map[predicted_index], probabilities[predicted_index].item()

# Run the helper on the first validation image and show the result.
sample_image = val_ds[0]["image"]
label, conf = inference_example(sample_image)
print(f"Predicted: {label} (confidence: {conf:.2%})")

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(sample_image)
ax.set_title(f"Prediction: {label}\nConfidence: {conf:.2%}")
ax.axis('off')
plt.show()

In [None]:
REPO_ID = "justacoderwhocodes/medsiglip-dental-classifier"

# Copy-paste snippet showing how to load the published model and run it.
usage_lines = [
    "\nUsage for loading the model:",
    "-" * 50,
    "from transformers import AutoModel, AutoProcessor",
    f"model = AutoModel.from_pretrained('{REPO_ID}', trust_remote_code=True)",
    f"processor = AutoProcessor.from_pretrained('{REPO_ID}')",
    "\nTo run inference:",
    "inputs = processor(images=image, return_tensors='pt')",
    "outputs = model(pixel_values=inputs['pixel_values'])",
    "logits = outputs['logits']",
    "pred = torch.argmax(logits, dim=-1)",
]
print("\n".join(usage_lines))