# LLM-Based Rust Detection Model Training
Train a vision-language model (CLIP) for rust detection and export to ONNX

In [1]:
!pip -q install roboflow torch torchvision tqdm onnx onnxruntime transformers pillow

import os, json, shutil
from pathlib import Path
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel, CLIPVisionModel
from tqdm.auto import tqdm

print("✅ Imports complete")

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/91.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.8/91.8 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.8/66.8 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.9/49.9 MB[0m [31m20.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.5/17.5 MB[0m [31m76.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.1/17.1 MB[0m [31m66.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m70.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m74.3 MB/s[0m eta [36m0:00:00[0m
[?25h✅ Imports complete


## 1. Download Dataset from Roboflow

In [2]:
from roboflow import Roboflow

RF_API_KEY = "n7MyCMrV8InDdMIca94B"
rf = Roboflow(api_key=RF_API_KEY)

project = rf.workspace("test-stage").project("rust-detection-t8vza")
dataset = project.version(8).download("yolov8")
print("Dataset downloaded to:", dataset.location)

loading Roboflow workspace...
loading Roboflow project...


Downloading Dataset Version Zip in Rust-Detection-8 to yolov8:: 100%|██████████| 1217454/1217454 [00:30<00:00, 40041.62it/s]





Extracting Dataset Version Zip to Rust-Detection-8 in yolov8:: 100%|██████████| 83900/83900 [00:10<00:00, 7697.17it/s]


Dataset downloaded to: /content/Rust-Detection-8


## 2. Convert YOLO to Classification Dataset

In [3]:
from pathlib import Path
from tqdm.auto import tqdm

YOLO_ROOT = Path(dataset.location)
OUT_ROOT = Path("/content/rust_cls")
OUT_ROOT.mkdir(exist_ok=True)

def label_has_boxes(label_path: Path) -> bool:
    if not label_path.exists():
        return False
    txt = label_path.read_text().strip()
    return len(txt) > 0

def convert_split(split: str):
    img_dir = YOLO_ROOT / split / "images"
    lab_dir = YOLO_ROOT / split / "labels"

    if not img_dir.exists():
        print(f"⚠️  {split} images not found, skipping")
        return

    out_rust = OUT_ROOT / split / "rust"
    out_no_rust = OUT_ROOT / split / "no_rust"
    out_rust.mkdir(parents=True, exist_ok=True)
    out_no_rust.mkdir(parents=True, exist_ok=True)

    imgs = list(img_dir.glob("*.*"))
    rust_count = no_rust_count = 0

    for img_path in tqdm(imgs, desc=f"Converting {split}"):
        label_path = lab_dir / (img_path.stem + ".txt")
        has_rust = label_has_boxes(label_path)

        if has_rust:
            shutil.copy2(img_path, out_rust / img_path.name)
            rust_count += 1
        else:
            shutil.copy2(img_path, out_no_rust / img_path.name)
            no_rust_count += 1

    print(f"{split}: rust={rust_count}, no_rust={no_rust_count}")

for split in ["train", "valid", "test"]:
    convert_split(split)

print("✅ Classification dataset created at:", OUT_ROOT)

Converting train:   0%|          | 0/37188 [00:00<?, ?it/s]

train: rust=35020, no_rust=2168


Converting valid:   0%|          | 0/3156 [00:00<?, ?it/s]

valid: rust=2968, no_rust=188


Converting test:   0%|          | 0/1600 [00:00<?, ?it/s]

test: rust=1471, no_rust=129
✅ Classification dataset created at: /content/rust_cls


## 3. Load CLIP Model (Vision-Language Model)
CLIP uses both vision and language understanding for classification

In [4]:
MODEL_NAME = "openai/clip-vit-base-patch32"  # 151M parameters, good balance

# Load CLIP
clip_model = CLIPModel.from_pretrained(MODEL_NAME)
clip_processor = CLIPProcessor.from_pretrained(MODEL_NAME)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# Extract vision encoder for fine-tuning
vision_model = clip_model.vision_model
vision_model = vision_model.to(device)

print(f"✅ CLIP model loaded: {MODEL_NAME}")
print(f"Vision model parameters: {sum(p.numel() for p in vision_model.parameters()) / 1e6:.1f}M")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json: 0.00B [00:00, ?B/s]



pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/398 [00:00<?, ?it/s]

CLIPModel LOAD REPORT from: openai/clip-vit-base-patch32
Key                                  | Status     |  | 
-------------------------------------+------------+--+-
text_model.embeddings.position_ids   | UNEXPECTED |  | 
vision_model.embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

The image processor of type `CLIPImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. 


tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

Device: cuda
✅ CLIP model loaded: openai/clip-vit-base-patch32
Vision model parameters: 87.5M


## 4. Create Custom Classification Head

In [5]:
class CLIPRustClassifier(nn.Module):
    def __init__(self, clip_vision_model, num_classes=2):
        super().__init__()
        self.vision_model = clip_vision_model
        self.vision_model.eval()  # Freeze vision model initially

        # Get output dimension
        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224).to(device)
            output = self.vision_model(dummy).last_hidden_state[:, 0, :]  # Use CLS token
            hidden_dim = output.shape[-1]

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, pixel_values):
        # Extract vision features
        vision_outputs = self.vision_model(pixel_values)
        pooled_output = vision_outputs.last_hidden_state[:, 0, :]  # CLS token

        # Classify
        logits = self.classifier(pooled_output)
        return logits

# Create model
model = CLIPRustClassifier(vision_model, num_classes=2)
model = model.to(device)

print("✅ Classification model created")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

✅ Classification model created
Total parameters: 87.9M


## 5. Create Dataset and DataLoader

In [6]:
from torch.utils.data import Dataset
from torchvision import transforms as T

class RustDataset(Dataset):
    def __init__(self, root_dir, split, processor):
        self.root_dir = Path(root_dir) / split
        self.processor = processor

        # Collect all images
        self.samples = []
        for label_idx, label_name in enumerate(["no_rust", "rust"]):
            label_dir = self.root_dir / label_name
            if label_dir.exists():
                for img_path in label_dir.glob("*.*"):
                    self.samples.append((str(img_path), label_idx))

        print(f"{split}: {len(self.samples)} images")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert("RGB")

        # Process with CLIP processor
        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs["pixel_values"].squeeze(0)

        return pixel_values, label

# Create datasets
train_dataset = RustDataset(OUT_ROOT, "train", clip_processor)
val_dataset = RustDataset(OUT_ROOT, "valid", clip_processor)
test_dataset = RustDataset(OUT_ROOT, "test", clip_processor)

# Create data loaders
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print("✅ Dataloaders ready")

train: 37188 images
valid: 3156 images
test: 1600 images
✅ Dataloaders ready


## 6. Training Setup with Class Balancing

In [7]:
# Calculate class weights
labels = [label for _, label in train_dataset.samples]
class_counts = np.bincount(labels)
class_weights = len(labels) / (len(class_counts) * class_counts)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)

print(f"Class distribution: no_rust={class_counts[0]}, rust={class_counts[1]}")
print(f"Class weights: {class_weights.cpu().numpy()}")

# Loss and optimizer
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

print("✅ Training setup complete")

Class distribution: no_rust=2168, rust=35020
Class weights: [8.576569   0.53095376]
✅ Training setup complete


## 7. Training Loop

In [8]:
@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    all_preds = []
    all_labels = []
    total_loss = 0

    for pixel_values, labels in loader:
        pixel_values = pixel_values.to(device)
        labels = labels.to(device)

        logits = model(pixel_values)
        loss = criterion(logits, labels)
        total_loss += loss.item()

        preds = logits.argmax(dim=1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.cpu().numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    # Calculate metrics
    accuracy = (all_preds == all_labels).mean()

    # Balanced accuracy
    balanced_acc = 0
    for c in range(2):
        mask = all_labels == c
        if mask.sum() > 0:
            balanced_acc += (all_preds[mask] == c).mean()
    balanced_acc /= 2

    avg_loss = total_loss / len(loader)

    return accuracy, balanced_acc, avg_loss

# Training
EPOCHS = 10
best_val_acc = 0
best_state = None

print("Starting training...")
for epoch in range(1, EPOCHS + 1):
    model.train()
    train_loss = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
    for pixel_values, labels in pbar:
        pixel_values = pixel_values.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(pixel_values)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    scheduler.step()

    # Validation
    val_acc, val_bal_acc, val_loss = evaluate(model, val_loader)
    print(f"Epoch {epoch}: val_acc={val_acc:.4f}, val_bal_acc={val_bal_acc:.4f}, val_loss={val_loss:.4f}")

    # Save best model
    if val_bal_acc > best_val_acc:
        best_val_acc = val_bal_acc
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        print(f"✅ New best model! val_bal_acc={val_bal_acc:.4f}")

print(f"\n🎯 Best validation balanced accuracy: {best_val_acc:.4f}")

# Load best model
if best_state is not None:
    model.load_state_dict(best_state)
    model = model.to(device)

# Test evaluation
test_acc, test_bal_acc, test_loss = evaluate(model, test_loader)
print(f"🧪 Test Results: acc={test_acc:.4f}, bal_acc={test_bal_acc:.4f}")

Starting training...


Epoch 1/10:   0%|          | 0/1163 [00:00<?, ?it/s]

Epoch 1: val_acc=0.6663, val_bal_acc=0.7404, val_loss=0.6006
✅ New best model! val_bal_acc=0.7404


Epoch 2/10:   0%|          | 0/1163 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 2: val_acc=0.7015, val_bal_acc=0.7417, val_loss=0.5616
✅ New best model! val_bal_acc=0.7417


Epoch 3/10:   0%|          | 0/1163 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>^^
^Traceback (most recent call last):
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
^    ^self._shutdown_workers()Exception ignored in: ^
<function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers

  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
        Traceback (most recent call last):
if w.is

Epoch 3: val_acc=0.7551, val_bal_acc=0.7577, val_loss=0.4827
✅ New best model! val_bal_acc=0.7577


Epoch 4/10:   0%|          | 0/1163 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionErrorException ignored in: : <function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>
can only test a child processTraceback (most recent call last):

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
Exception ignored in:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/d

Epoch 4: val_acc=0.7170, val_bal_acc=0.7474, val_loss=0.4300


Epoch 5/10:   0%|          | 0/1163 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 5: val_acc=0.7738, val_bal_acc=0.7527, val_loss=0.3679


Epoch 6/10:   0%|          | 0/1163 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 6: val_acc=0.7985, val_bal_acc=0.7633, val_loss=0.3932
✅ New best model! val_bal_acc=0.7633


Epoch 7/10:   0%|          | 0/1163 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 7: val_acc=0.8140, val_bal_acc=0.7292, val_loss=0.3567


Epoch 8/10:   0%|          | 0/1163 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 8: val_acc=0.8432, val_bal_acc=0.7373, val_loss=0.3832


Epoch 9/10:   0%|          | 0/1163 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 9: val_acc=0.8910, val_bal_acc=0.6755, val_loss=0.4211


Epoch 10/10:   0%|          | 0/1163 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^Exception ignored in: ^^^^<function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>^
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
^^    ^^self._shutdown_workers()^^
^^^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
^^    ^if w.is_alive():^^^^
^   ^^ ^^

Epoch 10: val_acc=0.9015, val_bal_acc=0.6636, val_loss=0.4745

🎯 Best validation balanced accuracy: 0.7633


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^Exception ignored in: ^^<function _MultiProcessingDataLoaderIter.__del__ at 0x7a8c859c18a0>^^
^Traceback (most recent call last):
^^^^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
^^    ^self._shutdown_workers()
^^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
^^^    if w.is_alive():^^^
^ ^ ^ 

🧪 Test Results: acc=0.7919, bal_acc=0.7206


## 8. Export Model for Inference
We'll export a simplified version compatible with your existing code

In [9]:
# Create a standalone inference model
class RustDetectorONNX(nn.Module):
    def __init__(self, vision_model, classifier):
        super().__init__()
        self.vision_model = vision_model
        self.classifier = classifier

    def forward(self, pixel_values):
        vision_outputs = self.vision_model(pixel_values)
        pooled_output = vision_outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(pooled_output)
        return logits

# Create exportable model
export_model = RustDetectorONNX(model.vision_model, model.classifier)
export_model.eval().cpu()

print("✅ Export model created")

✅ Export model created


## 9. Export to ONNX

In [10]:
!pip install -q onnxscript

import torch.onnx

# Prepare dummy input (CLIP expects 224x224 images)
dummy_input = torch.randn(1, 3, 224, 224)

# Export to ONNX
onnx_path = "rust_model_llm.onnx"

try:
    torch.onnx.export(
        export_model,
        dummy_input,
        onnx_path,
        input_names=["input"],
        output_names=["logits"],
        dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
        opset_version=17,
        do_constant_folding=True
    )
    print(f"✅ ONNX model exported: {onnx_path}")
except Exception as e:
    print(f"⚠️  ONNX export failed: {e}")
    print("Saving PyTorch model instead...")
    torch.save(export_model.state_dict(), "rust_model_llm.pth")
    print("✅ PyTorch model saved: rust_model_llm.pth")

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/689.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m686.1/689.1 kB[0m [31m21.0 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m689.1/689.1 kB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/159.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m159.3/159.3 kB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
[?25h

# 'dynamic_axes' is not recommended when dynamo=True, and may lead to 'torch._dynamo.exc.UserError: Constraints violated.' Supply the 'dynamic_shapes' argument instead if export is unsuccessful.
W0215 21:24:38.383000 1871 torch/onnx/_internal/exporter/_compat.py:114] Setting ONNX exporter to use operator set version 18 because the requested opset_version 17 is a lower version than we have implementations for. Automatic version conversion will be performed, which may not be successful at converting to the requested version. If version conversion is unsuccessful, the opset version of the exported model will be kept at 18. Please consider setting opset_version >=18 to leverage latest ONNX features
W0215 21:24:39.732000 1871 torch/onnx/_internal/exporter/_schemas.py:455] Missing annotation for parameter 'input' from (input, boxes, output_size: 'Sequence[int]', spatial_scale: 'float' = 1.0, sampling_ratio: 'int' = -1, aligned: 'bool' = False). Treating as an Input.
W0215 21:24:39.733000 187

[torch.onnx] Obtain model graph for `RustDetectorONNX([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `RustDetectorONNX([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...




[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 52 of general pattern rewrite rules.
✅ ONNX model exported: rust_model_llm.onnx


## 10. Save Labels and Metadata

In [11]:
# Save labels
labels = ["no_rust", "rust"]
with open("rust_labels_llm.json", "w") as f:
    json.dump(labels, f, indent=2)

# Save metadata with model info
metadata = {
    "model_type": "CLIP-based classifier",
    "base_model": MODEL_NAME,
    "labels": labels,
    "input_size": [224, 224],
    "preprocessing": {
        "mean": [0.48145466, 0.4578275, 0.40821073],
        "std": [0.26862954, 0.26130258, 0.27577711]
    },
    "test_accuracy": float(test_acc),
    "test_balanced_accuracy": float(test_bal_acc),
    "training_samples": len(train_dataset),
    "validation_samples": len(val_dataset),
    "test_samples": len(test_dataset)
}

with open("rust_model_metadata.json", "w") as f:
    json.dump(metadata, f, indent=2)

print("✅ Labels and metadata saved")
print(f"Labels: {labels}")
print(f"Test accuracy: {test_acc:.4f}")
print(f"Test balanced accuracy: {test_bal_acc:.4f}")

✅ Labels and metadata saved
Labels: ['no_rust', 'rust']
Test accuracy: 0.7919
Test balanced accuracy: 0.7206


## 11. Test ONNX Model

In [12]:
import onnxruntime as ort

try:
    # Load ONNX model
    ort_session = ort.InferenceSession(onnx_path)

    # Test inference
    test_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
    ort_inputs = {ort_session.get_inputs()[0].name: test_input}
    ort_outputs = ort_session.run(None, ort_inputs)

    print("✅ ONNX model works!")
    print(f"Output shape: {ort_outputs[0].shape}")

    # Test on a real image
    if len(test_dataset) > 0:
        img, label = test_dataset[0]
        img_batch = img.unsqueeze(0).numpy()

        ort_inputs = {ort_session.get_inputs()[0].name: img_batch}
        ort_outputs = ort_session.run(None, ort_inputs)
        pred_class = ort_outputs[0].argmax(axis=1)[0]

        print(f"\nTest prediction: {labels[pred_class]} (true: {labels[label]})")
        print(f"Logits: {ort_outputs[0][0]}")

except Exception as e:
    print(f"⚠️  ONNX test failed: {e}")

✅ ONNX model works!
Output shape: (1, 2)

Test prediction: rust (true: no_rust)
Logits: [-5.710206   1.3972186]


## 12. Download Model Files

In [13]:
from google.colab import files
import os

print("Downloading model files...")

# Download all model files
if os.path.exists("rust_model_llm.onnx"):
    files.download("rust_model_llm.onnx")
    print(f"✅ Downloaded: rust_model_llm.onnx ({os.path.getsize('rust_model_llm.onnx') / 1e6:.1f} MB)")

if os.path.exists("rust_model_llm.pth"):
    files.download("rust_model_llm.pth")
    print(f"✅ Downloaded: rust_model_llm.pth ({os.path.getsize('rust_model_llm.pth') / 1e6:.1f} MB)")

files.download("rust_labels_llm.json")
files.download("rust_model_metadata.json")

print("\n✅ All files downloaded!")
print("\nFiles to use in your project:")
print("  - rust_model_llm.onnx (or .pth)")
print("  - rust_labels_llm.json")
print("  - rust_model_metadata.json")

Downloading model files...


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

✅ Downloaded: rust_model_llm.onnx (0.1 MB)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


✅ All files downloaded!

Files to use in your project:
  - rust_model_llm.onnx (or .pth)
  - rust_labels_llm.json
  - rust_model_metadata.json


## 13. Generate Integration Code

In [None]:
integration_code = '''
# Integration code for your existing project

import onnxruntime as ort
import numpy as np
from PIL import Image

# Load LLM-based rust detection model
session = ort.InferenceSession("rust_model_llm.onnx")
with open("rust_labels_llm.json", "r") as f:
    labels = json.load(f)

def preprocess_image(image_path):
    """Preprocess image for LLM model"""
    img = Image.open(image_path).convert("RGB")
    img = img.resize((224, 224))

    # CLIP preprocessing
    img_array = np.array(img).astype(np.float32) / 255.0
    mean = np.array([0.48145466, 0.4578275, 0.40821073])
    std = np.array([0.26862954, 0.26130258, 0.27577711])
    img_array = (img_array - mean) / std

    # CHW format
    img_array = np.transpose(img_array, (2, 0, 1))
    img_array = np.expand_dims(img_array, 0)

    return img_array

def predict_rust_llm(image_path):
    """Predict rust using LLM-based model"""
    img = preprocess_image(image_path)

    inputs = {session.get_inputs()[0].name: img}
    logits = session.run(None, inputs)[0]

    pred_class = logits.argmax(axis=1)[0]
    confidence = np.exp(logits[0]) / np.exp(logits[0]).sum()

    return {
        "prediction": labels[pred_class],
        "confidence": float(confidence[pred_class]),
        "probabilities": {
            labels[i]: float(confidence[i]) for i in range(len(labels))
        }
    }

# Example usage
result = predict_rust_llm("path/to/image.jpg")
print(f"Prediction: {result['prediction']}")
print(f"Confidence: {result['confidence']:.2%}")
'''

with open("integration_code.py", "w") as f:
    f.write(integration_code)

print("✅ Integration code saved to integration_code.py")
print("\n" + "="*60)
print(integration_code)
print("="*60)

files.download("integration_code.py")

## Summary

### What you got:
1. **rust_model_llm.onnx** - CLIP-based vision-language model for rust detection
2. **rust_labels_llm.json** - Class labels (no_rust, rust)
3. **rust_model_metadata.json** - Model info and preprocessing params
4. **integration_code.py** - Code to use in your project

### Key Differences from CNN Model:
- Uses CLIP vision encoder (language-aware features)
- Better semantic understanding of "rust" concept
- Pre-trained on large image-text datasets
- More robust to variations

### Next Steps:
1. Download the files from Colab
2. Place them in your `artifacts/` folder
3. Update your image modality code to support both models
4. Compare performance: CNN vs LLM-based

### Performance:
- Test Accuracy: {test_acc:.2%}
- Balanced Accuracy: {test_bal_acc:.2%}
- Model Size: ~150MB (CLIP-based)