<a href="https://colab.research.google.com/github/haddybhaiya/sem-i-con/blob/main/train_convNeXt_tiny.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install timm
!pip install scikit-learn




In [None]:
# Mount Google Drive at /content/drive; DATASET_PATH in the config cell points
# under this mount point, so this must run before any data is read.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
import timm
import os
import cv2
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report
from tqdm import tqdm


In [None]:
# ---- Experiment configuration ----
DATASET_PATH = "/content/drive/MyDrive/synthetic_dataset"   # change if needed
IMG_SIZE = 224       # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 16
EPOCHS = 20
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42

# Class names; the position of a name in this list is its integer label.
CLASSES = [
    "clean","bridge","cmp","crack","open","ler","via","other"
]

# Derived from CLASSES so the two can never drift apart.
NUM_CLASSES = len(CLASSES)

# Seed torch's RNG so weight init and DataLoader shuffling are repeatable.
torch.manual_seed(SEED)


In [None]:
class SEMDataset(Dataset):
    """Dataset over ``(image_path, class_name)`` pairs of grayscale images.

    Each item is ``(image, label)``: ``image`` is a float32 tensor of shape
    ``(1, IMG_SIZE, IMG_SIZE)`` holding pixel values divided by 255, and
    ``label`` is the index of ``class_name`` in ``CLASSES``.
    """

    def __init__(self, samples):
        """Store the sequence of ``(image_path, class_name)`` pairs."""
        self.samples = samples

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, resize and normalise the sample at ``idx``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
            ValueError: if the class name is not present in ``CLASSES``.
        """
        path, cls = self.samples[idx]

        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread reports failure by returning None rather than raising
            raise FileNotFoundError(f"Could not read image: {path}")
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
        img = img.astype(np.float32) / 255.0

        # add the channel axis: (H, W) -> (1, H, W)
        img = np.expand_dims(img, axis=0)

        label = CLASSES.index(cls)

        return torch.from_numpy(img), torch.tensor(label, dtype=torch.long)


In [None]:
from sklearn.model_selection import train_test_split

# Collect (image_path, class_name) pairs from one sub-folder per class.
all_samples = []

for cls in CLASSES:
    folder = os.path.join(DATASET_PATH, cls)

    # os.listdir returns entries in arbitrary order; sorting makes the split
    # below reproducible for a fixed random_state.
    for fname in sorted(os.listdir(folder)):
        path = os.path.join(folder, fname)
        if os.path.isfile(path):   # skip stray sub-directories
            all_samples.append((path, cls))

# 80/20 split, stratified on the class name so both parts keep class balance.
train_samples, val_samples = train_test_split(
    all_samples,
    test_size=0.2,
    stratify=[s[1] for s in all_samples],   # keeps class balance
    random_state=42
)

print("Train:", len(train_samples))
print("Val:", len(val_samples))


Train: 1920
Val: 480


In [None]:
# Wrap each split in a Dataset, then batch it.
train_dataset, val_dataset = SEMDataset(train_samples), SEMDataset(val_samples)

# Only training batches are shuffled; validation keeps dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
def build_convnext(num_classes=NUM_CLASSES, pretrained=True):
    """Create a single-channel ConvNeXt-Tiny classifier.

    Args:
        num_classes: size of the classification head.
        pretrained: load timm's pretrained weights (the download is visible in
            this cell's output); timm adapts them to ``in_chans=1``.

    Returns:
        A ``timm`` ConvNeXt-Tiny ``nn.Module`` taking ``(N, 1, H, W)`` input.
    """
    return timm.create_model(
        "convnext_tiny",
        pretrained=pretrained,
        in_chans=1,
        num_classes=num_classes,
    )


model = build_convnext().to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

In [None]:
# One full pass over train_loader per epoch; prints the mean per-batch loss.
for epoch in range(EPOCHS):
    model.train()
    batch_losses = []

    for images, targets in tqdm(train_loader):
        images = images.to(DEVICE)
        targets = targets.to(DEVICE)

        optimizer.zero_grad()
        loss = criterion(model(images), targets)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    print(f"Epoch {epoch+1} Loss:", sum(batch_losses) / len(train_loader))


100%|██████████| 120/120 [21:58<00:00, 10.99s/it]


Epoch 1 Loss: 0.25449785136200564


100%|██████████| 120/120 [00:41<00:00,  2.89it/s]


Epoch 2 Loss: 0.03702565324122891


100%|██████████| 120/120 [00:40<00:00,  2.95it/s]


Epoch 3 Loss: 0.013663354355321644


100%|██████████| 120/120 [00:41<00:00,  2.92it/s]


Epoch 4 Loss: 0.0025174609759233135


100%|██████████| 120/120 [00:41<00:00,  2.92it/s]


Epoch 5 Loss: 7.263057671783221e-05


100%|██████████| 120/120 [00:41<00:00,  2.93it/s]


Epoch 6 Loss: 3.920390437694247e-05


100%|██████████| 120/120 [00:40<00:00,  2.93it/s]


Epoch 7 Loss: 2.9905097881055555e-05


100%|██████████| 120/120 [00:40<00:00,  2.94it/s]


Epoch 8 Loss: 2.381059314302547e-05


100%|██████████| 120/120 [00:41<00:00,  2.92it/s]


Epoch 9 Loss: 1.9478914578030525e-05


100%|██████████| 120/120 [00:41<00:00,  2.92it/s]


Epoch 10 Loss: 1.6277008057841158e-05


100%|██████████| 120/120 [00:40<00:00,  2.93it/s]


Epoch 11 Loss: 1.3809070946990686e-05


100%|██████████| 120/120 [00:40<00:00,  2.95it/s]


Epoch 12 Loss: 1.1877049848862952e-05


100%|██████████| 120/120 [00:41<00:00,  2.93it/s]


Epoch 13 Loss: 1.0320555869232825e-05


100%|██████████| 120/120 [00:40<00:00,  2.94it/s]


Epoch 14 Loss: 9.044484075578415e-06


100%|██████████| 120/120 [00:41<00:00,  2.92it/s]


Epoch 15 Loss: 7.981408517328721e-06


100%|██████████| 120/120 [00:40<00:00,  2.93it/s]


Epoch 16 Loss: 7.096141278376914e-06


100%|██████████| 120/120 [00:41<00:00,  2.92it/s]


Epoch 17 Loss: 6.337951964496824e-06


100%|██████████| 120/120 [00:40<00:00,  2.93it/s]


Epoch 18 Loss: 5.684935202528626e-06


100%|██████████| 120/120 [00:40<00:00,  2.93it/s]


Epoch 19 Loss: 5.126411087985616e-06


100%|██████████| 120/120 [00:41<00:00,  2.92it/s]

Epoch 20 Loss: 4.630343515070005e-06





In [None]:
# Evaluate on the validation split: inference mode, no gradient tracking.
model.eval()

all_preds = []
all_labels = []

with torch.no_grad():
    for x, y in val_loader:
        x = x.to(DEVICE)

        out = model(x)
        preds = torch.argmax(out, dim=1).cpu()

        all_preds.extend(preds.tolist())
        all_labels.extend(y.tolist())

# Pass `labels` explicitly so target_names always lines up with the label ids,
# even if some class happens to be absent from the validation data.
print(classification_report(
    all_labels,
    all_preds,
    labels=list(range(len(CLASSES))),
    target_names=CLASSES,
))


              precision    recall  f1-score   support

       clean       1.00      1.00      1.00        60
      bridge       1.00      1.00      1.00        60
         cmp       1.00      1.00      1.00        60
       crack       1.00      1.00      1.00        60
        open       1.00      1.00      1.00        60
         ler       1.00      1.00      1.00        60
         via       1.00      1.00      1.00        60
       other       1.00      1.00      1.00        60

    accuracy                           1.00       480
   macro avg       1.00      1.00      1.00       480
weighted avg       1.00      1.00      1.00       480



In [None]:
# Persist only the learned parameters (state_dict), not the module object.
trained_weights = model.state_dict()
torch.save(trained_weights, "convnext_sem.pth")
print("Model saved")


Model saved


In [None]:
!pip install onnxruntime
!pip install onnx
!pip install onnxscript

Collecting onnxruntime
  Downloading onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (17.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.4/17.4 MB[0m [31m69.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected pac

In [None]:
import torch
import timm
import torch.nn as nn
import os

# Same relative path the training section saved the weights to.
MODEL_PATH = "convnext_sem.pth"
ONNX_PATH = "sem.onnx"

NUM_CLASSES = 8
IMG_SIZE = 224
DEVICE = "cpu"   # export is done on CPU

CLASSES = ["clean","bridge","cmp","crack","open","ler","via","other"]

# ---- Build Model ----
# Same architecture as training: 1-channel ConvNeXt-Tiny with an
# NUM_CLASSES-way head (no pretrained download; weights come from MODEL_PATH).
model = timm.create_model(
    "convnext_tiny",
    pretrained=False,
    num_classes=NUM_CLASSES,
    in_chans=1,
)

model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# Example input used to capture the graph: (batch, channel, height, width).
dummy_input = torch.randn(1, 1, IMG_SIZE, IMG_SIZE)

torch.onnx.export(
    model,
    dummy_input,
    ONNX_PATH,
    input_names=["input"],
    output_names=["output"],
    # The exporter implements opset >= 18; requesting 17 forces a version
    # down-conversion that fails (see the warning/traceback this cell logged).
    opset_version=18,
    # Only the batch axis is dynamic; height/width stay fixed at IMG_SIZE.
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)

print("✅ FP32 ONNX exported:", ONNX_PATH)


  torch.onnx.export(
W0203 10:55:42.174000 800 torch/onnx/_internal/exporter/_compat.py:114] Setting ONNX exporter to use operator set version 18 because the requested opset_version 17 is a lower version than we have implementations for. Automatic version conversion will be performed, which may not be successful at converting to the requested version. If version conversion is unsuccessful, the opset version of the exported model will be kept at 18. Please consider setting opset_version >=18 to leverage latest ONNX features
W0203 10:55:43.365000 800 torch/onnx/_internal/exporter/_schemas.py:455] Missing annotation for parameter 'input' from (input, boxes, output_size: 'Sequence[int]', spatial_scale: 'float' = 1.0, sampling_ratio: 'int' = -1, aligned: 'bool' = False). Treating as an Input.
W0203 10:55:43.367000 800 torch/onnx/_internal/exporter/_schemas.py:455] Missing annotation for parameter 'boxes' from (input, boxes, output_size: 'Sequence[int]', spatial_scale: 'float' = 1.0, samplin

[torch.onnx] Obtain model graph for `ConvNeXt([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ConvNeXt([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...


Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/__init__.py", line 127, in call
    converted_proto = _c_api_utils.call_onnx_api(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/_c_api_utils.py", line 65, in call_onnx_api
    result = func(proto)
             ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/__init__.py", line 122, in _partial_convert_version
    return onnx.version_converter.convert_version(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnx/version_converter.py", line 39, in convert_version
    converted_model_str = C.convert_version(model_str, target_version)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: /github/workspace/onnx/version_converter/adapters/axes_input_to_attribute.h:65: adapt: Asserti

[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅




✅ FP32 ONNX exported: sem.onnx


In [None]:
import torch
import timm
import torch.nn as nn

MODEL_PATH = "convnext_sem.pth"
OUT_PATH = "sem_int8.pth"
NUM_CLASSES = 8

device = "cpu"

# Build model: recreate the float architecture, then restore trained weights.
model = timm.create_model("convnext_tiny", pretrained=False, num_classes=NUM_CLASSES, in_chans=1)
float_state = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(float_state)
model.eval()

# ---- Dynamic quantization ----
# Only nn.Linear modules are replaced by int8 dynamically-quantized versions.
quantized_model = torch.quantization.quantize_dynamic(
    model,
    qconfig_spec={nn.Linear},
    dtype=torch.qint8,
)

# NOTE: this state_dict carries quantized keys (scale / zero_point /
# _packed_params); it can only be loaded into a model quantized the same way.
quantized_state = quantized_model.state_dict()
torch.save(quantized_state, OUT_PATH)

print("✅ Torch INT8 model saved:", OUT_PATH)


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  quantized_model = torch.quantization.quantize_dynamic(


✅ Torch INT8 model saved: sem_int8.pth


In [None]:
import torch
import timm

MODEL_PATH = "sem_int8.pth"
OUT_PATH = "sem_int8.pt"
NUM_CLASSES = 8

model = timm.create_model(
    "convnext_tiny",
    pretrained=False,
    num_classes=NUM_CLASSES,
    in_chans=1
)
model.eval()

# MODEL_PATH holds *quantized* parameters (scale / zero_point / packed params).
# A float model rejects those keys, so first convert the Linear layers the
# same way they were converted when the checkpoint was written.
model = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},
    dtype=torch.qint8
)

model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()

# Example input for tracing: (batch, channel, height, width).
example = torch.randn(1,1,224,224)

scripted = torch.jit.trace(model, example)
scripted.save(OUT_PATH)

print("✅ TorchScript INT8 exported:", OUT_PATH)


RuntimeError: Error(s) in loading state_dict for ConvNeXt:
	Missing key(s) in state_dict: "stages.0.blocks.0.mlp.fc1.weight", "stages.0.blocks.0.mlp.fc1.bias", "stages.0.blocks.0.mlp.fc2.weight", "stages.0.blocks.0.mlp.fc2.bias", "stages.0.blocks.1.mlp.fc1.weight", "stages.0.blocks.1.mlp.fc1.bias", "stages.0.blocks.1.mlp.fc2.weight", "stages.0.blocks.1.mlp.fc2.bias", "stages.0.blocks.2.mlp.fc1.weight", "stages.0.blocks.2.mlp.fc1.bias", "stages.0.blocks.2.mlp.fc2.weight", "stages.0.blocks.2.mlp.fc2.bias", "stages.1.blocks.0.mlp.fc1.weight", "stages.1.blocks.0.mlp.fc1.bias", "stages.1.blocks.0.mlp.fc2.weight", "stages.1.blocks.0.mlp.fc2.bias", "stages.1.blocks.1.mlp.fc1.weight", "stages.1.blocks.1.mlp.fc1.bias", "stages.1.blocks.1.mlp.fc2.weight", "stages.1.blocks.1.mlp.fc2.bias", "stages.1.blocks.2.mlp.fc1.weight", "stages.1.blocks.2.mlp.fc1.bias", "stages.1.blocks.2.mlp.fc2.weight", "stages.1.blocks.2.mlp.fc2.bias", "stages.2.blocks.0.mlp.fc1.weight", "stages.2.blocks.0.mlp.fc1.bias", "stages.2.blocks.0.mlp.fc2.weight", "stages.2.blocks.0.mlp.fc2.bias", "stages.2.blocks.1.mlp.fc1.weight", "stages.2.blocks.1.mlp.fc1.bias", "stages.2.blocks.1.mlp.fc2.weight", "stages.2.blocks.1.mlp.fc2.bias", "stages.2.blocks.2.mlp.fc1.weight", "stages.2.blocks.2.mlp.fc1.bias", "stages.2.blocks.2.mlp.fc2.weight", "stages.2.blocks.2.mlp.fc2.bias", "stages.2.blocks.3.mlp.fc1.weight", "stages.2.blocks.3.mlp.fc1.bias", "stages.2.blocks.3.mlp.fc2.weight", "stages.2.blocks.3.mlp.fc2.bias", "stages.2.blocks.4.mlp.fc1.weight", "stages.2.blocks.4.mlp.fc1.bias", "stages.2.blocks.4.mlp.fc2.weight", "stages.2.blocks.4.mlp.fc2.bias", "stages.2.blocks.5.mlp.fc1.weight", "stages.2.blocks.5.mlp.fc1.bias", "stages.2.blocks.5.mlp.fc2.weight", "stages.2.blocks.5.mlp.fc2.bias", "stages.2.blocks.6.mlp.fc1.weight", "stages.2.blocks.6.mlp.fc1.bias", "stages.2.blocks.6.mlp.fc2.weight", "stages.2.blocks.6.mlp.fc2.bias", "stages.2.blocks.7.mlp.fc1.weight", "stages.2.blocks.7.mlp.fc1.bias", "stages.2.blocks.7.mlp.fc2.weight", "stages.2.blocks.7.mlp.fc2.bias", 
"stages.2.blocks.8.mlp.fc1.weight", "stages.2.blocks.8.mlp.fc1.bias", "stages.2.blocks.8.mlp.fc2.weight", "stages.2.blocks.8.mlp.fc2.bias", "stages.3.blocks.0.mlp.fc1.weight", "stages.3.blocks.0.mlp.fc1.bias", "stages.3.blocks.0.mlp.fc2.weight", "stages.3.blocks.0.mlp.fc2.bias", "stages.3.blocks.1.mlp.fc1.weight", "stages.3.blocks.1.mlp.fc1.bias", "stages.3.blocks.1.mlp.fc2.weight", "stages.3.blocks.1.mlp.fc2.bias", "stages.3.blocks.2.mlp.fc1.weight", "stages.3.blocks.2.mlp.fc1.bias", "stages.3.blocks.2.mlp.fc2.weight", "stages.3.blocks.2.mlp.fc2.bias", "head.fc.weight", "head.fc.bias". 
	Unexpected key(s) in state_dict: "stages.0.blocks.0.mlp.fc1.scale", "stages.0.blocks.0.mlp.fc1.zero_point", "stages.0.blocks.0.mlp.fc1._packed_params.dtype", "stages.0.blocks.0.mlp.fc1._packed_params._packed_params", "stages.0.blocks.0.mlp.fc2.scale", "stages.0.blocks.0.mlp.fc2.zero_point", "stages.0.blocks.0.mlp.fc2._packed_params.dtype", "stages.0.blocks.0.mlp.fc2._packed_params._packed_params", "stages.0.blocks.1.mlp.fc1.scale", "stages.0.blocks.1.mlp.fc1.zero_point", "stages.0.blocks.1.mlp.fc1._packed_params.dtype", "stages.0.blocks.1.mlp.fc1._packed_params._packed_params", "stages.0.blocks.1.mlp.fc2.scale", "stages.0.blocks.1.mlp.fc2.zero_point", "stages.0.blocks.1.mlp.fc2._packed_params.dtype", "stages.0.blocks.1.mlp.fc2._packed_params._packed_params", "stages.0.blocks.2.mlp.fc1.scale", "stages.0.blocks.2.mlp.fc1.zero_point", "stages.0.blocks.2.mlp.fc1._packed_params.dtype", "stages.0.blocks.2.mlp.fc1._packed_params._packed_params", "stages.0.blocks.2.mlp.fc2.scale", "stages.0.blocks.2.mlp.fc2.zero_point", "stages.0.blocks.2.mlp.fc2._packed_params.dtype", "stages.0.blocks.2.mlp.fc2._packed_params._packed_params", "stages.1.blocks.0.mlp.fc1.scale", "stages.1.blocks.0.mlp.fc1.zero_point", "stages.1.blocks.0.mlp.fc1._packed_params.dtype", "stages.1.blocks.0.mlp.fc1._packed_params._packed_params", "stages.1.blocks.0.mlp.fc2.scale", "stages.1.blocks.0.mlp.fc2.zero_point", "stages.1.blocks.0.mlp.fc2._packed_params.dtype", "stages.1.blocks.0.mlp.fc2._packed_params._packed_params", "stages.1.blocks.1.mlp.fc1.scale", "stages.1.blocks.1.mlp.fc1.zero_point", "stages.1.blocks.1.mlp.fc1._packed_params.dtype", "stages.1.blocks.1.mlp.fc1._packed_params._packed_params", "stages.1.blocks.1.mlp.fc2.scale", "stages.1.blocks.1.mlp.fc2.zero_point", "stages.1.blocks.1.mlp.fc2._packed_params.dtype", "stages.1.blocks.1.mlp.fc2._packed_params._packed_params", "stages.1.blocks.2.mlp.fc1.scale", "stages.1.blocks.2.mlp.fc1.zero_point", "stages.1.blocks.2.mlp.fc1._packed_params.dtype", 
"stages.1.blocks.2.mlp.fc1._packed_params._packed_params", "stages.1.blocks.2.mlp.fc2.scale", "stages.1.blocks.2.mlp.fc2.zero_point", "stages.1.blocks.2.mlp.fc2._packed_params.dtype", "stages.1.blocks.2.mlp.fc2._packed_params._packed_params", "stages.2.blocks.0.mlp.fc1.scale", "stages.2.blocks.0.mlp.fc1.zero_point", "stages.2.blocks.0.mlp.fc1._packed_params.dtype", "stages.2.blocks.0.mlp.fc1._packed_params._packed_params", "stages.2.blocks.0.mlp.fc2.scale", "stages.2.blocks.0.mlp.fc2.zero_point", "stages.2.blocks.0.mlp.fc2._packed_params.dtype", "stages.2.blocks.0.mlp.fc2._packed_params._packed_params", "stages.2.blocks.1.mlp.fc1.scale", "stages.2.blocks.1.mlp.fc1.zero_point", "stages.2.blocks.1.mlp.fc1._packed_params.dtype", "stages.2.blocks.1.mlp.fc1._packed_params._packed_params", "stages.2.blocks.1.mlp.fc2.scale", "stages.2.blocks.1.mlp.fc2.zero_point", "stages.2.blocks.1.mlp.fc2._packed_params.dtype", "stages.2.blocks.1.mlp.fc2._packed_params._packed_params", "stages.2.blocks.2.mlp.fc1.scale", "stages.2.blocks.2.mlp.fc1.zero_point", "stages.2.blocks.2.mlp.fc1._packed_params.dtype", "stages.2.blocks.2.mlp.fc1._packed_params._packed_params", "stages.2.blocks.2.mlp.fc2.scale", "stages.2.blocks.2.mlp.fc2.zero_point", "stages.2.blocks.2.mlp.fc2._packed_params.dtype", "stages.2.blocks.2.mlp.fc2._packed_params._packed_params", "stages.2.blocks.3.mlp.fc1.scale", "stages.2.blocks.3.mlp.fc1.zero_point", "stages.2.blocks.3.mlp.fc1._packed_params.dtype", "stages.2.blocks.3.mlp.fc1._packed_params._packed_params", "stages.2.blocks.3.mlp.fc2.scale", "stages.2.blocks.3.mlp.fc2.zero_point", "stages.2.blocks.3.mlp.fc2._packed_params.dtype", "stages.2.blocks.3.mlp.fc2._packed_params._packed_params", "stages.2.blocks.4.mlp.fc1.scale", "stages.2.blocks.4.mlp.fc1.zero_point", "stages.2.blocks.4.mlp.fc1._packed_params.dtype", "stages.2.blocks.4.mlp.fc1._packed_params._packed_params", "stages.2.blocks.4.mlp.fc2.scale", "stages.2.blocks.4.mlp.fc2.zero_point", 
"stages.2.blocks.4.mlp.fc2._packed_params.dtype", "stages.2.blocks.4.mlp.fc2._packed_params._packed_params", "stages.2.blocks.5.mlp.fc1.scale", "stages.2.blocks.5.mlp.fc1.zero_point", "stages.2.blocks.5.mlp.fc1._packed_params.dtype", "stages.2.blocks.5.mlp.fc1._packed_params._packed_params", "stages.2.blocks.5.mlp.fc2.scale", "stages.2.blocks.5.mlp.fc2.zero_point", "stages.2.blocks.5.mlp.fc2._packed_params.dtype", "stages.2.blocks.5.mlp.fc2._packed_params._packed_params", "stages.2.blocks.6.mlp.fc1.scale", "stages.2.blocks.6.mlp.fc1.zero_point", "stages.2.blocks.6.mlp.fc1._packed_params.dtype", "stages.2.blocks.6.mlp.fc1._packed_params._packed_params", "stages.2.blocks.6.mlp.fc2.scale", "stages.2.blocks.6.mlp.fc2.zero_point", "stages.2.blocks.6.mlp.fc2._packed_params.dtype", "stages.2.blocks.6.mlp.fc2._packed_params._packed_params", "stages.2.blocks.7.mlp.fc1.scale", "stages.2.blocks.7.mlp.fc1.zero_point", "stages.2.blocks.7.mlp.fc1._packed_params.dtype", "stages.2.blocks.7.mlp.fc1._packed_params._packed_params", "stages.2.blocks.7.mlp.fc2.scale", "stages.2.blocks.7.mlp.fc2.zero_point", "stages.2.blocks.7.mlp.fc2._packed_params.dtype", "stages.2.blocks.7.mlp.fc2._packed_params._packed_params", "stages.2.blocks.8.mlp.fc1.scale", "stages.2.blocks.8.mlp.fc1.zero_point", "stages.2.blocks.8.mlp.fc1._packed_params.dtype", "stages.2.blocks.8.mlp.fc1._packed_params._packed_params", "stages.2.blocks.8.mlp.fc2.scale", "stages.2.blocks.8.mlp.fc2.zero_point", "stages.2.blocks.8.mlp.fc2._packed_params.dtype", "stages.2.blocks.8.mlp.fc2._packed_params._packed_params", "stages.3.blocks.0.mlp.fc1.scale", "stages.3.blocks.0.mlp.fc1.zero_point", "stages.3.blocks.0.mlp.fc1._packed_params.dtype", "stages.3.blocks.0.mlp.fc1._packed_params._packed_params", "stages.3.blocks.0.mlp.fc2.scale", "stages.3.blocks.0.mlp.fc2.zero_point", "stages.3.blocks.0.mlp.fc2._packed_params.dtype", "stages.3.blocks.0.mlp.fc2._packed_params._packed_params", "stages.3.blocks.1.mlp.fc1.scale", 
"stages.3.blocks.1.mlp.fc1.zero_point", "stages.3.blocks.1.mlp.fc1._packed_params.dtype", "stages.3.blocks.1.mlp.fc1._packed_params._packed_params", "stages.3.blocks.1.mlp.fc2.scale", "stages.3.blocks.1.mlp.fc2.zero_point", "stages.3.blocks.1.mlp.fc2._packed_params.dtype", "stages.3.blocks.1.mlp.fc2._packed_params._packed_params", "stages.3.blocks.2.mlp.fc1.scale", "stages.3.blocks.2.mlp.fc1.zero_point", "stages.3.blocks.2.mlp.fc1._packed_params.dtype", "stages.3.blocks.2.mlp.fc1._packed_params._packed_params", "stages.3.blocks.2.mlp.fc2.scale", "stages.3.blocks.2.mlp.fc2.zero_point", "stages.3.blocks.2.mlp.fc2._packed_params.dtype", "stages.3.blocks.2.mlp.fc2._packed_params._packed_params", "head.fc.scale", "head.fc.zero_point", "head.fc._packed_params.dtype", "head.fc._packed_params._packed_params". 

In [None]:
import onnxruntime as ort
import cv2
import numpy as np

# NOTE(review): no cell in this notebook appears to write this file — confirm
# where models/sem_int8.onnx is produced.
MODEL_PATH = "models/sem_int8.onnx"

CLASSES = [
    "clean", "bridge", "cmp", "crack",
    "open", "ler", "via", "other",
]

IMG_SIZE = 224
OTHER_THRESHOLD = 0.65   # predictions below this confidence are reported as "other"

# CPU-only ONNX Runtime session; images are fed to the graph's first input.
session = ort.InferenceSession(
    MODEL_PATH,
    providers=["CPUExecutionProvider"],
)
input_name = session.get_inputs()[0].name


def preprocess(img_path):
    """Load an image as a model-ready float32 array.

    The image is read as grayscale, resized to ``IMG_SIZE x IMG_SIZE``,
    divided by 255 and given leading batch and channel axes.

    Args:
        img_path: path of the image file to read.

    Returns:
        ``numpy.ndarray`` of dtype float32 and shape ``(1, 1, IMG_SIZE, IMG_SIZE)``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    img = img.astype(np.float32) / 255.0

    # (H, W) -> (1, 1, H, W): channel axis, then batch axis
    img = np.expand_dims(img, axis=0)
    img = np.expand_dims(img, axis=0)

    return img


def softmax(x):
    """Return the softmax of ``x``, normalised over all of its elements.

    The maximum is subtracted before exponentiating so large logits do not
    overflow.
    """
    shifted = x - np.max(x)
    exps = np.exp(shifted)
    return exps / np.sum(exps)


def infer(img_path):
    """Classify one image with the module-level ONNX Runtime ``session``.

    Args:
        img_path: path of the image to classify.

    Returns:
        ``(pred, conf)`` where ``conf`` is the highest softmax probability and
        ``pred`` is the matching entry of ``CLASSES``, or ``"other"`` when
        ``conf`` is below ``OTHER_THRESHOLD``.
    """
    batch = preprocess(img_path)
    outputs = session.run(None, {input_name: batch})

    # first graph output, first (only) row of the batch
    probs = softmax(outputs[0][0])

    best = int(np.argmax(probs))
    conf = float(probs[best])
    label = CLASSES[best]

    # low-confidence predictions fall back to the catch-all class
    return ("other" if conf < OTHER_THRESHOLD else label), conf


# Quick smoke test on a single sample image.
if __name__ == "__main__":
    img = "dataset/sample/test.png"
    result = infer(img)
    print(result)


In [None]:
import psutil
import time
import onnxruntime as ort
import numpy as np
import cv2

# NOTE(review): the export cell in this notebook writes "sem.onnx" to the
# working directory and no visible cell writes an INT8 ONNX file — confirm
# how these two files get into models/.
FP32_MODEL = "models/sem.onnx"
INT8_MODEL = "models/sem_int8.onnx"

CLASSES = ["clean","bridge","cmp","crack","open","ler","via","other"]

IMG_SIZE_HIGH = 224   # input size used with the FP32 model
IMG_SIZE_LOW = 160    # input size used with the INT8 model

CPU_THRESHOLD = 70    # CPU utilisation (percent) above which INT8 is chosen

# Name the execution provider explicitly, as the single-model inference cell
# does, so behaviour does not depend on the installed onnxruntime build.
session_fp32 = ort.InferenceSession(FP32_MODEL, providers=["CPUExecutionProvider"])
session_int8 = ort.InferenceSession(INT8_MODEL, providers=["CPUExecutionProvider"])

input_fp32 = session_fp32.get_inputs()[0].name
input_int8 = session_int8.get_inputs()[0].name


def preprocess(img_path, size=None):
    """Load an image as a model-ready float32 array.

    The image is read as grayscale, resized to ``size x size``, divided by
    255 and given leading batch and channel axes.

    Args:
        img_path: path of the image file to read.
        size: target side length in pixels. Defaults to ``IMG_SIZE_HIGH`` so
            that one-argument callers (``infer`` from the earlier cell) keep
            working after this definition replaces the earlier ``preprocess``.

    Returns:
        ``numpy.ndarray`` of dtype float32 and shape ``(1, 1, size, size)``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    if size is None:
        size = IMG_SIZE_HIGH

    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img = cv2.resize(img, (size, size))
    img = img.astype(np.float32) / 255.0

    # (H, W) -> (1, 1, H, W): channel axis, then batch axis
    img = np.expand_dims(img, axis=0)
    img = np.expand_dims(img, axis=0)

    return img


def softmax(x):
    """Softmax over every element of ``x``; the max is subtracted first to
    keep the exponentials from overflowing."""
    peak = np.max(x)
    weights = np.exp(x - peak)
    total = np.sum(weights)
    return weights / total


def auto_edge(img_path):
    """Classify an image, picking the model from current CPU utilisation.

    CPU usage is sampled over 0.1 s. Above ``CPU_THRESHOLD`` percent the INT8
    session is used with ``IMG_SIZE_LOW`` input; otherwise the FP32 session is
    used with ``IMG_SIZE_HIGH`` input.

    Args:
        img_path: path of the image to classify.

    Returns:
        dict with keys ``"class"`` (entry of ``CLASSES``), ``"confidence"``
        (highest softmax probability), ``"mode"`` (``"INT8_LOW"`` or
        ``"FP32_HIGH"``), ``"latency"`` (seconds spent on preprocessing plus
        inference) and ``"cpu"`` (the sampled CPU percentage).
    """
    cpu = psutil.cpu_percent(interval=0.1)

    # NOTE(review): the FP32 export in this notebook marks only the batch axis
    # as dynamic; confirm the INT8 model accepts IMG_SIZE_LOW-sized input.
    if cpu > CPU_THRESHOLD:
        session = session_int8
        input_name = input_int8
        size = IMG_SIZE_LOW
        mode = "INT8_LOW"
    else:
        session = session_fp32
        input_name = input_fp32
        size = IMG_SIZE_HIGH
        mode = "FP32_HIGH"

    # perf_counter is monotonic and high-resolution; wall-clock time.time()
    # can jump if the system clock is adjusted mid-measurement.
    start = time.perf_counter()

    x = preprocess(img_path, size)
    out = session.run(None, {input_name: x})

    latency = time.perf_counter() - start

    # first graph output, first (only) row of the batch
    logits = out[0][0]
    probs = softmax(logits)

    cls_id = int(np.argmax(probs))
    conf = float(probs[cls_id])

    return {
        "class": CLASSES[cls_id],
        "confidence": conf,
        "mode": mode,
        "latency": latency,
        "cpu": cpu
    }


# Quick smoke test on a single sample image.
if __name__ == "__main__":
    report = auto_edge("dataset/sample/test.png")
    print(report)


In [None]:
from sklearn.metrics import classification_report
import os

# Run the ONNX classifier over a held-out folder tree (one sub-folder per
# class) and compare predicted class names with the folder names.
y_true = []
y_pred = []

TEST_PATH = "dataset/test"

for cls in CLASSES:
    folder = os.path.join(TEST_PATH, cls)

    for fname in sorted(os.listdir(folder)):
        path = os.path.join(folder, fname)
        if not os.path.isfile(path):   # skip stray sub-directories
            continue

        pred, _conf = infer(path)

        y_true.append(cls)
        y_pred.append(pred)

# labels=CLASSES keeps the rows in training order and lists every class, even
# one that never shows up in y_true / y_pred.
print(classification_report(y_true, y_pred, labels=CLASSES))
