In [1]:
### Configs ###
# Directory containing the test data; replace with your own location.
TEST_DATA_DIR = "../data/2025_Dataset"

## Checkpoint file for each model
POSE_CLASSIFICATION_MODEL_PATH = "../model/resnet34_pose_model_single_label_augmented.pth"
DEPTH_REGRESSION_MODEL_PATH = "../model/resnet34_depth_augmented.pth"

## Architecture name for each model (looked up in microrobot_dl.model)
POSE_CLASSIFICATION_MODEL_ARCH = "resnet34"
DEPTH_REGRESSION_MODEL_ARCH = "resnet34"

## Task identifier for each model (see microrobot_dl.task). Don't change.
POSE_CLASSIFICATION_MODEL_TASK = "pose_single"
DEPTH_REGRESSION_MODEL_TASK = "depth"

In [None]:
### Setup if using colab ###
import os
import sys
import subprocess
import shutil

# Repository to clone on Colab and the directory it is cloned into.
GITHUB_REPO = "https://github.com/chihuangliu/microrobot-dl.git"
REPO_PATH = "/content/microrobot-dl"

# Where your data is stored in Google Drive.
DRIVE_DATA_PATH = "/content/drive/MyDrive/microrobot-dl-data/data"
# How the Drive data is exposed inside the repo: 'mount' (symlink) or 'copy'.
COLAB_DATA_MODE = "mount"


def in_colab() -> bool:
    """Report whether the `google.colab` package can be imported.

    Returns:
        True when the import succeeds, False when it raises any exception.
    """
    try:
        import google.colab  # type: ignore
    except Exception:
        detected = False
    else:
        detected = True
    return detected


# Colab bootstrap: mount Drive, clone + install the repo, expose the Drive
# data inside the repo as `data/`, then switch to the notebooks directory.
if in_colab():
    from google.colab import drive  # type: ignore

    # Fail fast on a typo in the mode instead of silently ending up without data.
    if COLAB_DATA_MODE not in ("mount", "copy"):
        raise ValueError(
            f"COLAB_DATA_MODE must be 'mount' or 'copy', got {COLAB_DATA_MODE!r}"
        )

    print("Detected Colab. Mounting Drive...")
    drive.mount("/content/drive", force_remount=False)

    if not os.path.exists(REPO_PATH):
        print("Cloning repository...")
        subprocess.check_call(["git", "clone", GITHUB_REPO, REPO_PATH])
    else:
        print("Repository already cloned:", REPO_PATH)

    os.chdir(REPO_PATH)
    print("Installing package from", REPO_PATH)
    subprocess.check_call([sys.executable, "-m", "pip", "install", "uv"])
    subprocess.check_call([sys.executable, "-m", "uv", "pip", "install", "-e", "."])

    repo_data_path = os.path.join(REPO_PATH, "data")

    if not os.path.exists(DRIVE_DATA_PATH):
        # Leave whatever is currently at repo_data_path untouched: there is
        # nothing to replace it with.
        print("Drive data path not found:", DRIVE_DATA_PATH)
    else:
        # lexists (not exists) so a dangling symlink left by an earlier run is
        # also cleared; otherwise os.symlink below raises FileExistsError.
        if os.path.lexists(repo_data_path):
            if os.path.islink(repo_data_path):
                os.unlink(repo_data_path)
            elif os.path.isdir(repo_data_path):
                shutil.rmtree(repo_data_path)
            else:
                os.remove(repo_data_path)

        if COLAB_DATA_MODE == "copy":
            shutil.copytree(DRIVE_DATA_PATH, repo_data_path)
        else:  # "mount" (validated above)
            os.symlink(DRIVE_DATA_PATH, repo_data_path)

    # Run from notebooks/ so the relative "../data" and "../model" paths resolve.
    notebooks_dir = os.path.join(REPO_PATH, "notebooks")
    if os.path.isdir(notebooks_dir):
        os.chdir(notebooks_dir)

    if REPO_PATH not in sys.path:
        sys.path.insert(0, REPO_PATH)
else:
    print("Not running on colab.")

Not running on colab.


In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from microrobot_dl.data_loader import ImageDataset2025
from microrobot_dl.model import get_model
from microrobot_dl.task import Task
from microrobot_dl.inference import evaluate_model

# Setup device: prefer CUDA, then MPS (Metal), and fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif (
    # getattr guard: do not assume `torch.backends.mps` exists on every build.
    getattr(torch.backends, "mps", None) is not None
    and torch.backends.mps.is_available()
    and torch.backends.mps.is_built()
):
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

# Shared test-time preprocessing used by both evaluation cells:
# resize to 224x224, convert to a tensor, then normalize with mean 0.5 / std 0.5.
transform_test = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

Using device: mps


# Pose Classification

In [4]:
print("=== Pose Classification Evaluation ===")

task = POSE_CLASSIFICATION_MODEL_TASK
# The multi-label path is only taken when the configured task is Task.pose_multi.
multi_label = task == Task.pose_multi

# Load Dataset
dataset_pose = ImageDataset2025(
    base_dir=TEST_DATA_DIR,
    mode="pose",
    multi_label=multi_label,
    transform=transform_test,
)

# shuffle=False keeps the evaluation order deterministic.
test_loader_pose = DataLoader(dataset_pose, batch_size=64, shuffle=False)

# Determine num_outputs
if task == Task.pose_multi:
    # Multi-label: one output per P label plus one per R label.
    num_classes_p = len(dataset_pose.idx_to_label_p)
    num_classes_r = len(dataset_pose.idx_to_label_r)
    num_outputs = num_classes_p + num_classes_r
else:
    # Single-label: one output per label; the P/R counts are unused.
    num_outputs = len(dataset_pose.idx_to_label)
    num_classes_p = 0
    num_classes_r = 0

# Load Model (pose classification) - use configured arch variable
model_pose = get_model(
    POSE_CLASSIFICATION_MODEL_ARCH, num_outputs=num_outputs, in_channels=1
)
model_pose = model_pose.to(device)

# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True on torch versions that support it).
checkpoint = torch.load(POSE_CLASSIFICATION_MODEL_PATH, map_location=device)
model_pose.load_state_dict(checkpoint["model_state_dict"])
# Put normalization/dropout layers into inference behaviour before scoring.
# evaluate_model may already do this - TODO confirm; the call is idempotent.
model_pose.eval()
print(f"Loaded pose model from {POSE_CLASSIFICATION_MODEL_PATH}")

# Evaluate
criterion = nn.CrossEntropyLoss()

results = evaluate_model(
    model_pose,
    test_loader_pose,
    device,
    task,
    criterion=criterion,
    num_classes_p=num_classes_p,
    num_classes_r=num_classes_r,
    multi_label=multi_label,
)

print(f"Accuracy: {results['accuracy']:.4f}")
if task == Task.pose_multi:
    print(f"Accuracy P: {results['accuracy_p']:.4f}")
    print(f"Accuracy R: {results['accuracy_r']:.4f}")
print(f"Loss: {results['loss']:.4f}")

=== Pose Classification Evaluation ===
Loaded pose model from ../model/resnet34_pose_model_single_label_augmented.pth
Loaded pose model from ../model/resnet34_pose_model_single_label_augmented.pth
Accuracy: 0.9795
Loss: 0.0713
Accuracy: 0.9795
Loss: 0.0713


# Depth Regression

In [5]:
print("=== Depth Regression Evaluation ===")

task = DEPTH_REGRESSION_MODEL_TASK

# Load Dataset
dataset_depth = ImageDataset2025(
    base_dir=TEST_DATA_DIR,
    mode="depth",
    transform=transform_test,
)

# shuffle=False keeps the evaluation order deterministic.
test_loader_depth = DataLoader(dataset_depth, batch_size=64, shuffle=False)

# Load Model (single regression output, single input channel)
model_depth = get_model(DEPTH_REGRESSION_MODEL_ARCH, num_outputs=1, in_channels=1)
model_depth = model_depth.to(device)

# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True on torch versions that support it).
checkpoint = torch.load(DEPTH_REGRESSION_MODEL_PATH, map_location=device)
model_depth.load_state_dict(checkpoint["model_state_dict"])
# Put normalization/dropout layers into inference behaviour before scoring.
# evaluate_model may already do this - TODO confirm; the call is idempotent.
model_depth.eval()
print(f"Loaded depth model from {DEPTH_REGRESSION_MODEL_PATH}")

# Evaluate
criterion = nn.MSELoss()

results = evaluate_model(
    model_depth, test_loader_depth, device, task, criterion=criterion
)

print(f"RMSE: {results['rmse']:.4f}")

=== Depth Regression Evaluation ===
Loaded depth model from ../model/resnet34_depth_augmented.pth
Loaded depth model from ../model/resnet34_depth_augmented.pth
RMSE: 0.0515
RMSE: 0.0515
