In [1]:
!pip -q install roboflow torch torchvision tqdm onnx onnxruntime

import os, json, glob, random
from pathlib import Path
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm.auto import tqdm


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/91.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.8/91.8 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.8/66.8 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.9/49.9 MB[0m [31m18.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.5/17.5 MB[0m [31m87.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.1/17.1 MB[0m [31m78.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m38.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m84.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
from roboflow import Roboflow

RF_API_KEY = "n7MyCMrV8InDdMIca94B"
rf = Roboflow(api_key=RF_API_KEY)

# IMPORTANT:
# On the Universe dataset page, click "Download Dataset" -> choose "Folder Structure" or "Image Classification"
# then copy the code snippet here (it will contain workspace/project/version).
project = rf.workspace("test-stage").project("rust-detection-t8vza")   # <-- may differ, use snippet
# Change "folder" to "yolov8"
dataset = project.version(8).download("yolov8")  # <-- use snippet's exact format/version
print("Dataset downloaded to:", dataset.location)


loading Roboflow workspace...
loading Roboflow project...


Downloading Dataset Version Zip in Rust-Detection-8 to yolov8:: 100%|██████████| 1217454/1217454 [01:03<00:00, 19179.66it/s]





Extracting Dataset Version Zip to Rust-Detection-8 in yolov8:: 100%|██████████| 83900/83900 [00:15<00:00, 5549.48it/s]


Dataset downloaded to: /content/Rust-Detection-8


Colab: Convert YOLO dataset into ImageFolder classification

In [5]:
import os, shutil, random
from pathlib import Path
from tqdm.auto import tqdm

YOLO_ROOT = Path("/content/Rust-Detection-8")   # dataset.location
OUT_ROOT  = Path("/content/rust_cls")           # output classification folder
random.seed(42)

splits = ["train", "valid", "test"]

def label_has_boxes(label_path: Path) -> bool:
    if not label_path.exists():
        return False
    txt = label_path.read_text().strip()
    return len(txt) > 0

def convert_split(split: str):
    img_dir = YOLO_ROOT / split / "images"
    lab_dir = YOLO_ROOT / split / "labels"
    assert img_dir.exists(), f"missing {img_dir}"
    assert lab_dir.exists(), f"missing {lab_dir}"

    out_rust = OUT_ROOT / split / "rust"
    out_nor  = OUT_ROOT / split / "no_rust"
    out_rust.mkdir(parents=True, exist_ok=True)
    out_nor.mkdir(parents=True, exist_ok=True)

    imgs = list(img_dir.glob("*.*"))
    rust_count = 0
    nor_count = 0

    for p in tqdm(imgs, desc=f"convert {split}"):
        label_path = lab_dir / (p.stem + ".txt")
        is_rust = label_has_boxes(label_path)

        if is_rust:
            shutil.copy2(p, out_rust / p.name)
            rust_count += 1
        else:
            shutil.copy2(p, out_nor / p.name)
            nor_count += 1

    print(f"{split}: rust={rust_count}, no_rust={nor_count}")

for s in splits:
    convert_split(s)

print("✅ classification dataset at:", OUT_ROOT)


convert train:   0%|          | 0/37188 [00:00<?, ?it/s]

train: rust=35020, no_rust=2168


convert valid:   0%|          | 0/3156 [00:00<?, ?it/s]

valid: rust=2968, no_rust=188


convert test:   0%|          | 0/1600 [00:00<?, ?it/s]

test: rust=1471, no_rust=129
✅ classification dataset at: /content/rust_cls


do we actually have “no_rust” images?

In [6]:
for s in ["train","valid","test"]:
    r = len(list((OUT_ROOT/s/"rust").glob("*.*")))
    n = len(list((OUT_ROOT/s/"no_rust").glob("*.*")))
    print(s, "rust", r, "no_rust", n)


train rust 35020 no_rust 2168
valid rust 2968 no_rust 188
test rust 1471 no_rust 129


# Train rust/no_rust classifier (with class-imbalance handling) + export ONNX

Cell 1 — Imports + Datasets

In [7]:
!pip -q install torch torchvision tqdm onnx onnxruntime

from pathlib import Path
import json, numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm.auto import tqdm

OUT_ROOT = Path("/content/rust_cls")   # where you created ImageFolder dataset
train_dir = OUT_ROOT/"train"
val_dir   = OUT_ROOT/"valid"
test_dir  = OUT_ROOT/"test"

img_size = 224

train_tfms = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.1,0.1,0.1,0.05),
    transforms.ToTensor(),
])

val_tfms = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
])

train_ds = datasets.ImageFolder(train_dir, transform=train_tfms)
val_ds   = datasets.ImageFolder(val_dir, transform=val_tfms)
test_ds  = datasets.ImageFolder(test_dir, transform=val_tfms)

print("classes:", train_ds.classes)

BATCH = 64
train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_ds, batch_size=BATCH, shuffle=False, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_ds, batch_size=BATCH, shuffle=False, num_workers=2, pin_memory=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)


classes: ['no_rust', 'rust']
device: cuda


Cell 2 — Class weights (fix imbalance) + model

In [8]:
# Count class frequencies
targets = np.array(train_ds.targets)
class_counts = np.bincount(targets, minlength=len(train_ds.classes))
print("class_counts:", dict(zip(train_ds.classes, class_counts)))

# weights ~ inverse frequency
weights = (class_counts.sum() / (len(class_counts) * class_counts)).astype(np.float32)
print("class_weights:", dict(zip(train_ds.classes, weights)))

class_weights = torch.tensor(weights, dtype=torch.float32).to(device)

model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
model.classifier[3] = nn.Linear(model.classifier[3].in_features, len(train_ds.classes))
model = model.to(device)

criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)


class_counts: {'no_rust': np.int64(2168), 'rust': np.int64(35020)}
class_weights: {'no_rust': np.float32(8.576569), 'rust': np.float32(0.53095376)}
Downloading: "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_small-047dcff4.pth


100%|██████████| 9.83M/9.83M [00:00<00:00, 135MB/s]


Cell 3 — Train + evaluate (accuracy + balanced accuracy)

In [9]:
@torch.no_grad()
def eval_metrics(loader):
    model.eval()
    all_y = []
    all_p = []
    for x,y in loader:
        x = x.to(device)
        logits = model(x)
        pred = logits.argmax(1).cpu().numpy()
        all_p.append(pred)
        all_y.append(y.numpy())
    y_true = np.concatenate(all_y)
    y_pred = np.concatenate(all_p)

    acc = (y_true == y_pred).mean()

    # balanced accuracy
    bal_acc = 0.0
    for c in range(len(train_ds.classes)):
        mask = (y_true == c)
        if mask.sum() > 0:
            bal_acc += (y_pred[mask] == c).mean()
    bal_acc /= len(train_ds.classes)

    return float(acc), float(bal_acc)

EPOCHS = 5
best_bal = 0.0
best_state = None

for ep in range(1, EPOCHS+1):
    model.train()
    pbar = tqdm(train_loader, desc=f"epoch {ep}/{EPOCHS}")
    for x,y in pbar:
        x,y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        pbar.set_postfix(loss=float(loss))

    val_acc, val_bal = eval_metrics(val_loader)
    print(f"VAL acc={val_acc:.4f}  bal_acc={val_bal:.4f}")

    if val_bal > best_bal:
        best_bal = val_bal
        best_state = {k:v.cpu().clone() for k,v in model.state_dict().items()}

print("best_val_bal_acc:", best_bal)

# restore best
if best_state is not None:
    model.load_state_dict(best_state)

test_acc, test_bal = eval_metrics(test_loader)
print(f"TEST acc={test_acc:.4f}  bal_acc={test_bal:.4f}")


epoch 1/5:   0%|          | 0/582 [00:00<?, ?it/s]

Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.
Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:836.)


VAL acc=0.8086  bal_acc=0.8360


epoch 2/5:   0%|          | 0/582 [00:00<?, ?it/s]

VAL acc=0.8821  bal_acc=0.8028


epoch 3/5:   0%|          | 0/582 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e5a9c00d940>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e5a9c00d940>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

VAL acc=0.8850  bal_acc=0.8118


epoch 4/5:   0%|          | 0/582 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e5a9c00d940>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    Exception ignored in: if w.is_alive():
<function _MultiProcessingDataLoaderIter.__del__ at 0x7e5a9c00d940>
 Traceback (most recent call last):
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
      self._shutdown_workers() 
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
     ^if w.is_alive():
^ ^ ^ ^  ^ ^ ^^^^^^^^^^
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^    ^assert self._parent_pid == os.getpid(), 'can only test a child process'^
^ ^ ^ 
   File "/usr/lib/py

VAL acc=0.8777  bal_acc=0.7780


epoch 5/5:   0%|          | 0/582 [00:00<?, ?it/s]

VAL acc=0.9176  bal_acc=0.7694
best_val_bal_acc: 0.835970207031026
TEST acc=0.8387  bal_acc=0.7992


Cell 4 — Export ONNX + labels

In [11]:
!pip install onnxscript
model.eval().cpu()

onnx_path = "rust_model.onnx"
dummy = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    dummy,
    onnx_path,
    input_names=["input"],
    output_names=["logits"],
    dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
    opset_version=17
)

with open("rust_labels.json", "w") as f:
    json.dump(train_ds.classes, f, indent=2)

print("✅ Saved:", onnx_path, "rust_labels.json")
print("labels:", train_ds.classes)

Collecting onnxscript
  Downloading onnxscript-0.6.2-py3-none-any.whl.metadata (13 kB)
Collecting onnx_ir<2,>=0.1.15 (from onnxscript)
  Downloading onnx_ir-0.1.16-py3-none-any.whl.metadata (3.2 kB)
Downloading onnxscript-0.6.2-py3-none-any.whl (689 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/689.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m689.1/689.1 kB[0m [31m24.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading onnx_ir-0.1.16-py3-none-any.whl (159 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/159.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m159.3/159.3 kB[0m [31m20.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: onnx_ir, onnxscript
Successfully installed onnx_ir-0.1.16 onnxscript-0.6.2


# 'dynamic_axes' is not recommended when dynamo=True, and may lead to 'torch._dynamo.exc.UserError: Constraints violated.' Supply the 'dynamic_shapes' argument instead if export is unsuccessful.
W0214 04:01:10.755000 2437 torch/onnx/_internal/exporter/_compat.py:114] Setting ONNX exporter to use operator set version 18 because the requested opset_version 17 is a lower version than we have implementations for. Automatic version conversion will be performed, which may not be successful at converting to the requested version. If version conversion is unsuccessful, the opset version of the exported model will be kept at 18. Please consider setting opset_version >=18 to leverage latest ONNX features
W0214 04:01:11.499000 2437 torch/onnx/_internal/exporter/_schemas.py:455] Missing annotation for parameter 'input' from (input, boxes, output_size: 'Sequence[int]', spatial_scale: 'float' = 1.0, sampling_ratio: 'int' = -1, aligned: 'bool' = False). Treating as an Input.
W0214 04:01:11.500000 243

[torch.onnx] Obtain model graph for `MobileNetV3([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `MobileNetV3([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...




[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅


Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/__init__.py", line 120, in call
    converted_proto = _c_api_utils.call_onnx_api(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/_c_api_utils.py", line 65, in call_onnx_api
    result = func(proto)
             ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/__init__.py", line 115, in _partial_convert_version
    return onnx.version_converter.convert_version(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnx/version_converter.py", line 39, in convert_version
    converted_model_str = C.convert_version(model_str, target_version)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: /github/workspace/onnx/version_converter/adapters/axes_input_to_attribute.h:65: adapt: Asserti

Applied 68 of general pattern rewrite rules.
✅ Saved: rust_model.onnx rust_labels.json
labels: ['no_rust', 'rust']


Cell 5 — Download to local

In [12]:
from google.colab import files
files.download("rust_model.onnx")
files.download("rust_labels.json")


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>