## 1. Introduction: SAMIRO (Plug-and-play Regularization)

This notebook demonstrates how to integrate SAMIRO into your existing lane detection pipeline.

SAMIRO is a plug-and-play module that enhances performance by aligning spatial information between two networks:

- **Teacher features:** intermediate feature maps from a self-supervised pre-trained backbone.
- **Student features:** intermediate feature maps from your target lane detection backbone.
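SAMIRO consumes lists of intermediate feature maps, so you need a way to pull them out of an existing backbone. In PyTorch, one common option that leaves the model's `forward` untouched is `nn.Module.register_forward_hook`. The sketch below is one possible way to wire this up, not part of SAMIRO: `ToyBackbone`, `collect_features`, the layer names and the 368×640 input size are all hypothetical, chosen so the output shapes match the student entries of `size_cfg` in section 2.

```python
import torch
import torch.nn as nn


class ToyBackbone(nn.Module):
    """Hypothetical stand-in for a real lane detection backbone."""

    def __init__(self):
        super().__init__()
        self.stage1 = nn.Conv2d(3, 64, kernel_size=3, stride=4, padding=1)
        self.stage2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.stage3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        return self.stage3(self.stage2(self.stage1(x)))


def collect_features(model, layer_names, x):
    """Run `model` on `x` and return the outputs of the named submodules, in order."""
    captured = {}
    handles = []
    modules = dict(model.named_modules())
    for name in layer_names:
        # Bind `name` as a default argument so each hook stores under its own key.
        def hook(module, inputs, output, name=name):
            captured[name] = output
        handles.append(modules[name].register_forward_hook(hook))
    try:
        model(x)
    finally:
        for h in handles:  # always detach hooks, even if forward fails
            h.remove()
    return [captured[name] for name in layer_names]


student = ToyBackbone()
images = torch.randn(4, 3, 368, 640)  # hypothetical input resolution
student_feats = collect_features(student, ["stage1", "stage2", "stage3"], images)
print([tuple(f.shape) for f in student_feats])
# [(4, 64, 92, 160), (4, 128, 46, 80), (4, 256, 46, 80)]
```

The same helper can be pointed at the teacher backbone; for a frozen teacher you would typically call it under `torch.no_grad()`.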

## 2. Environment Configuration
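First, define the feature map dimensions for both the student and the teacher. The two backbones do not need matching channel counts (here the student uses 64/128/256 channels and the teacher 256/512/1024); this configuration lets SAMIRO handle the mismatch automatically. In this example, each student entry's `K` equals the student-to-teacher spatial ratio at that level: 1, 1 and 2 (46×80 versus 23×40 at the last level).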

```python
import torch
from utils import SAMIROLossDemo

# Define feature map dimensions
# Student: your lane detection backbone
# Teacher: pre-trained self-supervised backbone
size_cfg = {
    "batch_size": 4,
    "student_feature_maps": [
        {"C": 64,  "H": 92, "W": 160, "K": 1},
        {"C": 128, "H": 46, "W": 80,  "K": 1},
        {"C": 256, "H": 46, "W": 80,  "K": 2},
    ],
    "teacher_feature_maps": [
        {"C": 256,  "H": 92, "W": 160},
        {"C": 512,  "H": 46, "W": 80},
        {"C": 1024, "H": 23, "W": 40},
    ],
}

# Initialize the SAMIRO loss module
model = SAMIROLossDemo(size_cfg)
```
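Before wiring in real backbones, it can help to sanity-check the configuration. The hypothetical helpers below (not part of `utils`) derive the `(N, C, H, W)` shape implied by each entry and the integer student-to-teacher spatial ratio per level. Treat the comparison with `K` as a check on this example's numbers, not a documented requirement of SAMIRO.

```python
# Hypothetical sanity-check helpers for a SAMIRO-style size_cfg (not part of utils).
size_cfg = {
    "batch_size": 4,
    "student_feature_maps": [
        {"C": 64, "H": 92, "W": 160, "K": 1},
        {"C": 128, "H": 46, "W": 80, "K": 1},
        {"C": 256, "H": 46, "W": 80, "K": 2},
    ],
    "teacher_feature_maps": [
        {"C": 256, "H": 92, "W": 160},
        {"C": 512, "H": 46, "W": 80},
        {"C": 1024, "H": 23, "W": 40},
    ],
}


def feature_shapes(cfg, key):
    """Return the (N, C, H, W) tensor shape for each level listed under cfg[key]."""
    n = cfg["batch_size"]
    return [(n, m["C"], m["H"], m["W"]) for m in cfg[key]]


def spatial_ratios(cfg):
    """Return the integer student-to-teacher spatial ratio at each level."""
    ratios = []
    for s, t in zip(cfg["student_feature_maps"], cfg["teacher_feature_maps"]):
        rh, rem_h = divmod(s["H"], t["H"])
        rw, rem_w = divmod(s["W"], t["W"])
        if rem_h or rem_w or rh != rw:
            raise ValueError(f"non-integer or non-uniform ratio: student {s} vs teacher {t}")
        ratios.append(rh)
    return ratios


print(feature_shapes(size_cfg, "teacher_feature_maps"))
# [(4, 256, 92, 160), (4, 512, 46, 80), (4, 1024, 23, 40)]
print(spatial_ratios(size_cfg))                              # [1, 1, 2]
print([m["K"] for m in size_cfg["student_feature_maps"]])   # [1, 1, 2]
```

Each shape tuple can be passed straight to `torch.randn(*shape)` to build dummy features.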

## 3. Knowledge Transfer via Mutual Information
To compute the regularization loss, pass the student and teacher feature lists, ordered as in `size_cfg`, to the SAMIRO module. SAMIRO then computes a spatial attention-based mutual information loss between the corresponding feature maps.

```python
# Dummy features standing in for backbone outputs (shapes match size_cfg)
student_feats = [
    torch.randn(4, 64, 92, 160),
    torch.randn(4, 128, 46, 80),
    torch.randn(4, 256, 46, 80),
]

teacher_feats = [
    torch.randn(4, 256, 92, 160),
    torch.randn(4, 512, 46, 80),
    torch.randn(4, 1024, 23, 40),
]

# Compute the regularization loss
reg_loss = model(student_feats, teacher_feats)
print("SAMIRO Reg Loss:", reg_loss.item())
```

Output (the features are random, so the exact value differs between runs):

```text
SAMIRO Reg Loss: 5.309111595153809
```
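The actual attention and mutual-information computation lives in `utils.SAMIROLossDemo` and is not shown in this notebook. To illustrate why a spatial attention map lets a 256-channel student be compared with a 1024-channel teacher at a different resolution, the sketch below uses a different, well-known technique: the attention-transfer map of Zagoruyko and Komodakis (2017), i.e. squared activations pooled over channels, with a plain L2 distance standing in for mutual information. It is a stand-in for intuition only, not SAMIRO's loss; `spatial_attention` and `attention_distance` are hypothetical names.

```python
import torch
import torch.nn.functional as F


def spatial_attention(feat, out_hw=None):
    """Attention-transfer style map: (N, C, H, W) -> (N, H*W), unit L2 norm per sample."""
    att = feat.pow(2).mean(dim=1, keepdim=True)  # pool over channels -> (N, 1, H, W)
    if out_hw is not None and tuple(att.shape[-2:]) != tuple(out_hw):
        att = F.interpolate(att, size=tuple(out_hw), mode="bilinear", align_corners=False)
    return F.normalize(att.flatten(1), dim=1)  # flatten, then L2-normalize each sample


def attention_distance(s_feat, t_feat):
    """Mean squared L2 distance between attention maps (a stand-in, not mutual information)."""
    t_att = spatial_attention(t_feat)
    s_att = spatial_attention(s_feat, out_hw=t_feat.shape[-2:])  # resize student to teacher grid
    return (s_att - t_att).pow(2).sum(dim=1).mean()


# Last level of size_cfg: channel counts and resolutions both differ.
student_map = torch.randn(4, 256, 46, 80)
teacher_map = torch.randn(4, 1024, 23, 40)
loss = attention_distance(student_map, teacher_map)
print(loss.item())
```

Because channels are pooled away before comparison, only the spatial layout of activations is matched; the resize step handles the 46×80 versus 23×40 mismatch.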


## 4. Final Training Objective
To train your model with SAMIRO, add `reg_loss` to your original lane detection loss, weighted by a coefficient $\lambda$:

$$\text{Loss}_{\text{total}} = \text{Loss}_{\text{lane}} + \lambda \cdot \text{Loss}_{\text{SAMIRO}}$$
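A single optimization step under this objective might look like the sketch below. Every component is a hypothetical stand-in (`student`, `teacher`, `lane_loss_fn`, `samiro_loss_fn`, and `lam` for $\lambda$); in practice you would plug in your own backbones, your lane detection loss and the `SAMIROLossDemo` module from section 2. Keeping the pre-trained teacher frozen is an assumption, not something this notebook specifies.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-ins: swap in your real backbones, lane loss and SAMIRO module.
student = nn.Conv2d(3, 8, kernel_size=3, padding=1)
teacher = nn.Conv2d(3, 8, kernel_size=3, padding=1)
teacher.requires_grad_(False)  # assumption: the pre-trained teacher stays frozen


def lane_loss_fn(feat, target):
    """Placeholder for your lane detection loss."""
    return F.mse_loss(feat.mean(dim=1), target)


def samiro_loss_fn(s_feat, t_feat):
    """Placeholder for the SAMIRO module (e.g. SAMIROLossDemo)."""
    return F.mse_loss(s_feat, t_feat)


lam = 0.1  # lambda: hypothetical weight of the SAMIRO term; tune for your setup
optimizer = torch.optim.SGD(student.parameters(), lr=0.01)

images = torch.randn(4, 3, 32, 32)
target = torch.randn(4, 32, 32)

s_feat = student(images)
with torch.no_grad():  # no gradients flow into the frozen teacher
    t_feat = teacher(images)

loss_lane = lane_loss_fn(s_feat, target)
loss_samiro = samiro_loss_fn(s_feat, t_feat)
total_loss = loss_lane + lam * loss_samiro  # Loss_total = Loss_lane + lambda * Loss_SAMIRO

optimizer.zero_grad()
total_loss.backward()
optimizer.step()
print(f"lane={loss_lane.item():.4f}  samiro={loss_samiro.item():.4f}  total={total_loss.item():.4f}")
```

If the SAMIRO module holds trainable parameters of its own (for example, layers that reconcile channel sizes), those would also need to be passed to the optimizer.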