# Set Up the Environment

### Importing Libraries

In [None]:

%pip install torchmetrics
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from datasets import load_dataset
import torchvision
from torchvision import transforms
from torchmetrics.image import StructuralSimilarityIndexMeasure
from transformers import AutoModelForDepthEstimation
from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights
import time
import PIL
from PIL import Image
import requests
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

### Mount Drive

In [2]:
from google.colab import drive

drive.mount('/content/drive')



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Define Needed Classes & Functions

### Class for Depth Model

In [3]:
class DepthModel(nn.Module):
    """
    Wrapper class for the Depth-Anything model.
    """
    def __init__(self, mod):
        super().__init__()
        self.model = mod
        self.eval()

    def forward(self, x):
        original_size = x.shape[2:]
        outputs = self.model(x)
        predicted_depth = outputs.predicted_depth
        predicted_depth = F.interpolate(predicted_depth.unsqueeze(1), size=original_size, mode='bilinear', align_corners=False)
        return predicted_depth

### Class for Student Model

In [4]:
class StudentModel(nn.Module):
    """
    Student model with MobileNetV3 encoder and a custom decoder.
    Designed for real-time inference on edge devices.
    """
    def __init__(self, output_channels=1):
        super(StudentModel, self).__init__()
        # Load MobileNetV3 Large features as the encoder
        self.encoder = mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.DEFAULT).features

        # Freeze encoder parameters
        for param in self.encoder.parameters():
            param.requires_grad = False

        # --- Enhancement: 1x1 Convolutions for Skip Connections ---
        # These layers refine the features from the encoder before fusion.
        # Channel dimensions correspond to MobileNetV3 skip connection outputs.
        self.skip_s4_conv = nn.Sequential(
            nn.Conv2d(24, 24, kernel_size=1, bias=False),
            nn.BatchNorm2d(24),
            nn.ReLU(inplace=True)
        )
        self.skip_s8_conv = nn.Sequential(
            nn.Conv2d(40, 40, kernel_size=1, bias=False),
            nn.BatchNorm2d(40),
            nn.ReLU(inplace=True)
        )
        self.skip_s16_conv = nn.Sequential(
            nn.Conv2d(80, 80, kernel_size=1, bias=False),
            nn.BatchNorm2d(80),
            nn.ReLU(inplace=True)
        )
        self.skip_s32_conv = nn.Sequential(
            nn.Conv2d(112, 112, kernel_size=1, bias=False),
            nn.BatchNorm2d(112),
            nn.ReLU(inplace=True)
        )

        # Decoder blocks (input channels remain the same as the 1x1 convs don't change channel dims)
        self.decoder_block1 = nn.Sequential(
            nn.Conv2d(960 + 112, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )
        self.decoder_block2 = nn.Sequential(
            nn.Conv2d(512 + 80, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )
        self.decoder_block3 = nn.Sequential(
            nn.Conv2d(256 + 40, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )
        self.decoder_block4 = nn.Sequential(
            nn.Conv2d(128 + 24, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )

        self.final_conv = nn.Conv2d(64, output_channels, kernel_size=1)
        self.final_activation = nn.ReLU()

    def forward(self, x):
        input_shape = x.shape[2:]
        skip_features = {}

        # --- Encoder Path & Feature Extraction ---
        # Iterate through encoder layers to get skip connections
        # Apply 1x1 convs immediately after extraction.
        for i, layer in enumerate(self.encoder):
            x = layer(x)
            if i == 2:  # s4 (H/4, W/4 resolution, 24 channels)
                skip_features['s4'] = self.skip_s4_conv(x)
            elif i == 4:  # s8 (H/8, W/8 resolution, 40 channels)
                skip_features['s8'] = self.skip_s8_conv(x)
            elif i == 7:  # s16 (H/16, W/16 resolution, 80 channels)
                skip_features['s16'] = self.skip_s16_conv(x)
            elif i == 11:  # s32 (H/16, W/16 resolution, 112 channels)
                skip_features['s32'] = self.skip_s32_conv(x)

        # Ensure s32 matches the spatial dimension of the encoder's final output
        if skip_features['s32'].shape[2:] != x.shape[2:]:
            skip_features['s32'] = F.interpolate(skip_features['s32'], size=x.shape[2:], mode='bilinear', align_corners=False)

        # --- Decoder Path with Enhanced Skip Connections ---
        # Concatenate final encoder output (x) with the refined s32 skip connection
        x = torch.cat([x, skip_features['s32']], dim=1)
        x = self.decoder_block1(x)

        # Concatenate with refined s16
        # Ensure spatial dimensions match before concatenation
        if skip_features['s16'].shape[2:] != x.shape[2:]:
             skip_features['s16'] = F.interpolate(skip_features['s16'], size=x.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, skip_features['s16']], dim=1)
        x = self.decoder_block2(x)

        # Concatenate with refined s8
        # Ensure spatial dimensions match before concatenation
        if skip_features['s8'].shape[2:] != x.shape[2:]:
             skip_features['s8'] = F.interpolate(skip_features['s8'], size=x.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, skip_features['s8']], dim=1)
        x = self.decoder_block3(x)

        # Concatenate with refined s4
        # Ensure spatial dimensions match before concatenation
        if skip_features['s4'].shape[2:] != x.shape[2:]:
             skip_features['s4'] = F.interpolate(skip_features['s4'], size=x.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, skip_features['s4']], dim=1)
        x = self.decoder_block4(x)

        # Final layers
        x = self.final_conv(x)
        x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
        x = self.final_activation(x)

        return x

### Class for Dataset Loading & Preprocessing

In [16]:
class UnlabeledImageDataset(Dataset):
    """
    Custom dataset for unlabeled images.
    """
    def __init__(self, root_dir, transform=None, resize_size=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith(('png', 'jpg', 'JPG'))]
        self.resize_size = resize_size
        print(f"Found {len(self.image_paths)} images in {root_dir}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')

        if self.resize_size:
            image = image.resize(self.resize_size)

        if self.transform:
            image = self.transform(image)


        return image

### Distillation Loss

In [6]:
class DepthCombinedDistillationLoss(nn.Module):
    """
    Custom loss function for depth distillation.
    Using Pixel wise MSE Loss, SSIM, Gradient MAE Loss, and Scale-Invariant MSE Loss.
    """
    def __init__(self, lambda_depth=1.0, lambda_si=1.0, lambda_grad=1.0, lambda_ssim=1.0, lambda_smooth=0.2, window_size=11):
        super().__init__()
        self.lambda_depth = lambda_depth   # Weight for depth map MSE loss
        self.lambda_si = lambda_si         # Weight for Scale-Invariant MSE loss
        self.lambda_grad = lambda_grad     # Weight for Gradient loss
        self.lambda_ssim = lambda_ssim     # Weight for SSIM loss
        self.lambda_smooth = lambda_smooth # Weight for smoothness regularizer

        self.mse_depth_loss = nn.MSELoss()  # Mean Squared Error for depth maps
        self.l1_loss = nn.L1Loss()          # L1 Loss for gradients

        # Initialize the SSIM calculation from torchmetrics
        if self.lambda_ssim > 0:
            self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0, kernel_size=window_size)

    def forward(self, student_outputs, teacher_outputs):
        student_depth = student_outputs
        teacher_depth = teacher_outputs

        # Ensure tensors have a batch dimension
        if student_depth.dim() == 3:
            student_depth = student_depth.unsqueeze(0)
        if teacher_depth.dim() == 3:
            teacher_depth = teacher_depth.unsqueeze(0)

        # Ensure tensors have a channel dimension for SSIM
        if student_depth.dim() == 3: # Assuming (Batch, Height, Width)
            student_depth = student_depth.unsqueeze(1) # Becomes (Batch, Channel, Height, Width)
        if teacher_depth.dim() == 3: # Assuming (Batch, Height, Width)
            teacher_depth = teacher_depth.unsqueeze(1) # Becomes (Batch, Channel, Height, Width)


        total_loss = torch.tensor(0.0, device=student_depth.device)

        # 1. MSE Depth Loss
        if self.lambda_depth > 0:
            loss_depth = self.mse_depth_loss(student_depth, teacher_depth)
            total_loss += self.lambda_depth * loss_depth

        # 2. Scale-Invariant MSE Loss
        if self.lambda_si > 0:
            diff = student_depth - teacher_depth
            loss_si = torch.mean(diff**2) - torch.mean(diff)**2
            total_loss += self.lambda_si * loss_si

        # 3. Gradient Loss (using L1 on gradients)
        if self.lambda_grad > 0:
            student_grad_x = torch.abs(student_depth[:, :, :, :-1] - student_depth[:, :, :, 1:])
            student_grad_y = torch.abs(student_depth[:, :, :-1, :] - student_depth[:, :, 1:, :])
            teacher_grad_x = torch.abs(teacher_depth[:, :, :, :-1] - teacher_depth[:, :, :, 1:])
            teacher_grad_y = torch.abs(teacher_depth[:, :, :-1, :] - teacher_depth[:, :, 1:, :])

            loss_grad = self.l1_loss(student_grad_x, teacher_grad_x) + self.l1_loss(student_grad_y, teacher_grad_y)
            total_loss += self.lambda_grad * loss_grad

        # 4.Smoothness Loss (Regularizer)
        if self.lambda_smooth > 0:
            # Penalizes the L1 norm of the student's depth gradients
            loss_smooth = torch.mean(student_grad_x) + torch.mean(student_grad_y)
            total_loss += self.lambda_smooth * loss_smooth

        # 5. SSIM Loss
        if self.lambda_ssim > 0:
            # Move ssim module to the same device as the tensors
            self.ssim.to(student_depth.device)

            # The torchmetrics SSIM implementation returns a value between -1 and 1.
            # A value of 1 indicates perfect similarity.
            # To use it as a loss, we subtract it from 1.
            d_ssim = self.ssim(student_depth, teacher_depth)
            loss_ssim = (1 - d_ssim) / 2 # Normalize to be between 0 and 1
            total_loss += self.lambda_ssim * loss_ssim

        return total_loss

### The Training Function

In [7]:
def train_knowledge_distillation(teacher, student, train_dataloader, val_dataloader, criterion, optimizer, epochs, device):
    """
    Train the student model using Response-Based knowledge distillation.
    """
    teacher.eval() # Teacher should always be in evaluation mode

    print(f"Starting Knowledge Distillation Training on {device}...")
    min_loss = float('inf')
    for epoch in range(epochs):
        student.train() # Student in training mode
        running_loss = 0.0
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        start_time = time.time()

        for images in progress_bar:
            images = images.to(device)
            optimizer.zero_grad()

            # Forward pass with Teacher model (no_grad as teacher is fixed)
            with torch.no_grad():
                teacher_outputs = teacher(images) # Returns depth map

            # Forward pass with Student model
            student_outputs = student(images) # Returns depth map

            # Calculate distillation loss
            loss = criterion(student_outputs, teacher_outputs)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()


        epoch_loss = running_loss / len(train_dataloader)
        current_lr = scheduler.get_last_lr()[0]
        end_time = time.time()
        print(f"End of Epoch {epoch+1},Time: {end_time - start_time:.2f}s, Current LR: {current_lr:.6f}, Average Loss: {epoch_loss:.4f}")
        scheduler.step()

        # Validation loop
        student.eval() # Student in evaluation mode for validation
        val_running_loss = 0.0
        with torch.no_grad():
            progress_bar_val = tqdm(val_dataloader, desc=f"Epoch {epoch+1}/{epochs} [Validation]")
            for val_images in progress_bar_val:
                val_images = val_images.to(device)
                teacher_outputs = teacher(val_images)
                student_outputs = student(val_images)
                val_loss = criterion(student_outputs, teacher_outputs)
                val_running_loss += val_loss.item()

        val_epoch_loss = val_running_loss / len(val_dataloader)
        print(f"Average Validation Loss: {val_epoch_loss:.4f}")

        if val_epoch_loss < min_loss:
            min_loss = val_epoch_loss
            print("Validation loss improved. Saving the model.")
            torch.save(student.state_dict(), f"/content/drive/MyDrive/best_student.pth")

    print("Knowledge Distillation Training Finished!")

# Training Process

### Define Parameters & Models

In [23]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load teacher model
teacher = AutoModelForDepthEstimation.from_pretrained("depth-anything/Depth-Anything-V2-Large-hf").to(device)


teacher_model = DepthModel(teacher).to(device)

# student  model architecture
student_model = StudentModel().to(device)

######################################################
# Load the entire model object from the checkpoint
# student_model = torch.load('/content/drive/MyDrive/distillSkip1700.pth', weights_only=False).to(device)

###########################
# Initialize optimizer for the student model
# student_optimizer = optim.Adam(student_model.parameters(), lr=1e-4)
student_optimizer = optim.AdamW(student_model.parameters(), lr=1e-3, weight_decay=1e-3)

# Training parameters
num_epochs = 60
# scheduler = CosineAnnealingLR(student_optimizer, T_max=num_epochs, eta_min=1e-5ss)
scheduler = StepLR(student_optimizer, step_size=20, gamma=0.1)
# Distillation Loss
# Instantiate the custom loss function
distillation_criterion = DepthCombinedDistillationLoss(lambda_depth=0.5,  lambda_si=1.0, lambda_grad=1.0, lambda_ssim=1.0, window_size=11)

# transformations for input images (teacher and student will use the same)

input_size = (384, 384)

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Specify the path to your unlabeled data in Google Drive
unlabeled_data_path = '/content/drive/MyDrive/images/'

# Create dataset and data loader
unlabeled_dataset = UnlabeledImageDataset(root_dir=unlabeled_data_path, transform=transform, resize_size=input_size)

dataset_size = len(unlabeled_dataset)
train_size = int(0.7 * dataset_size)
val_size = dataset_size - train_size
# Split the dataset
train_dataset, val_dataset = torch.utils.data.random_split(unlabeled_dataset, [train_size, val_size])
# Create separate dataloaders for training and validation
train_dataloader = DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")

Using device: cuda
Found 152 images in /content/drive/MyDrive/images/
Training set size: 106
Validation set size: 46


Before Training

In [24]:
# Load train image
image_path = "/content/drive/MyDrive/images/image1.JPG"
train_image = cv2.imread(image_path)
train_image = cv2.cvtColor(train_image, cv2.COLOR_BGR2RGB)  # Convert to RGB

train_input_tensor = eval_transform(Image.fromarray(train_image)).unsqueeze(0).to(device)
# Load image
image_path = "/content/test.jpg"
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert to RGB

input_tensor = eval_transform(Image.fromarray(image)).unsqueeze(0).to(device)

student_model.eval() # Use the DepthModel instance


with torch.no_grad():
    # Student prediction (before training) using the DepthModel instance
    student_depth_before = student_model(input_tensor)
    student_depth_before_training = student_depth_before.squeeze().cpu().numpy()
    student_depth_before = student_model(train_input_tensor)
    student_depth_before_training_train_image = student_depth_before.squeeze().cpu().numpy()


### Freeze Student's Encoder

In [21]:
# --- Verify which layers are trainable ---
print("All parameters in student_model:")
for name, param in student_model.named_parameters():
    if param.requires_grad:
        print(f"  {name}")

# print("----------------------------------------")
# for name, param in student_model.named_parameters():
#     param.requires_grad = False

# # Unfreeze parameters in the 'head' layer
# for name, param in student_model.decoder.named_parameters():
#     param.requires_grad = True

# # --- Verify which layers are trainable ---
# print("Trainable parameters in student_model:")
# for name, param in student_model.named_parameters():
#     if param.requires_grad:
#         print(f"  {name}")


All parameters in student_model:
  skip_s4_conv.0.weight
  skip_s4_conv.1.weight
  skip_s4_conv.1.bias
  skip_s8_conv.0.weight
  skip_s8_conv.1.weight
  skip_s8_conv.1.bias
  skip_s16_conv.0.weight
  skip_s16_conv.1.weight
  skip_s16_conv.1.bias
  skip_s32_conv.0.weight
  skip_s32_conv.1.weight
  skip_s32_conv.1.bias
  decoder_block1.0.weight
  decoder_block1.1.weight
  decoder_block1.1.bias
  decoder_block2.0.weight
  decoder_block2.1.weight
  decoder_block2.1.bias
  decoder_block3.0.weight
  decoder_block3.1.weight
  decoder_block3.1.bias
  decoder_block4.0.weight
  decoder_block4.1.weight
  decoder_block4.1.bias
  final_conv.weight
  final_conv.bias


### Run the Training

In [None]:
    # Run the training
train_knowledge_distillation(
    teacher=teacher_model,
    student=student_model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    criterion=distillation_criterion,
    optimizer=student_optimizer,
    epochs=num_epochs,
    device=device
)

print("Training complete. Student model saved at specified checkpoints.")


Starting Knowledge Distillation Training on cuda...


Epoch 1/60: 100%|██████████| 11/11 [00:36<00:00,  3.34s/it]


End of Epoch 1,Time: 36.78s, Current LR: 0.001000, Average Loss: 24463.1326


Epoch 1/60 [Validation]: 100%|██████████| 12/12 [00:16<00:00,  1.38s/it]


Average Validation Loss: 18808.2149
Validation loss improved. Saving the model.


Epoch 2/60: 100%|██████████| 11/11 [00:34<00:00,  3.14s/it]


End of Epoch 2,Time: 34.51s, Current LR: 0.001000, Average Loss: 24694.9606


Epoch 2/60 [Validation]: 100%|██████████| 12/12 [00:16<00:00,  1.36s/it]


Average Validation Loss: 20035.3724


Epoch 3/60: 100%|██████████| 11/11 [00:33<00:00,  3.05s/it]


End of Epoch 3,Time: 33.59s, Current LR: 0.001000, Average Loss: 24025.5257


Epoch 3/60 [Validation]: 100%|██████████| 12/12 [00:16<00:00,  1.35s/it]


Average Validation Loss: 20200.0467


Epoch 4/60: 100%|██████████| 11/11 [00:34<00:00,  3.15s/it]


End of Epoch 4,Time: 34.69s, Current LR: 0.001000, Average Loss: 23803.8210


Epoch 4/60 [Validation]: 100%|██████████| 12/12 [00:16<00:00,  1.35s/it]


Average Validation Loss: 20155.3242


Epoch 5/60: 100%|██████████| 11/11 [00:34<00:00,  3.10s/it]


End of Epoch 5,Time: 34.10s, Current LR: 0.001000, Average Loss: 24288.4776


Epoch 5/60 [Validation]: 100%|██████████| 12/12 [00:15<00:00,  1.32s/it]


Average Validation Loss: 19749.7863


Epoch 6/60: 100%|██████████| 11/11 [00:33<00:00,  3.03s/it]


End of Epoch 6,Time: 33.39s, Current LR: 0.001000, Average Loss: 24152.2907


Epoch 6/60 [Validation]: 100%|██████████| 12/12 [00:16<00:00,  1.34s/it]


Average Validation Loss: 19844.2915


Epoch 7/60: 100%|██████████| 11/11 [00:34<00:00,  3.17s/it]


End of Epoch 7,Time: 34.89s, Current LR: 0.001000, Average Loss: 23967.5391


Epoch 7/60 [Validation]: 100%|██████████| 12/12 [00:16<00:00,  1.38s/it]


Average Validation Loss: 19367.0225


Epoch 8/60: 100%|██████████| 11/11 [00:36<00:00,  3.31s/it]


End of Epoch 8,Time: 36.44s, Current LR: 0.001000, Average Loss: 23483.5344


Epoch 8/60 [Validation]: 100%|██████████| 12/12 [00:17<00:00,  1.46s/it]


Average Validation Loss: 20456.0722


Epoch 9/60: 100%|██████████| 11/11 [00:34<00:00,  3.14s/it]


End of Epoch 9,Time: 34.56s, Current LR: 0.001000, Average Loss: 23414.9486


Epoch 9/60 [Validation]: 100%|██████████| 12/12 [00:16<00:00,  1.36s/it]


Average Validation Loss: 19490.9606


Epoch 10/60: 100%|██████████| 11/11 [00:35<00:00,  3.21s/it]


End of Epoch 10,Time: 35.31s, Current LR: 0.001000, Average Loss: 23102.4071


Epoch 10/60 [Validation]: 100%|██████████| 12/12 [00:16<00:00,  1.35s/it]


Average Validation Loss: 19588.9369


Epoch 11/60: 100%|██████████| 11/11 [00:34<00:00,  3.18s/it]


End of Epoch 11,Time: 34.97s, Current LR: 0.001000, Average Loss: 22678.6415


Epoch 11/60 [Validation]: 100%|██████████| 12/12 [00:15<00:00,  1.33s/it]


Average Validation Loss: 18611.4924
Validation loss improved. Saving the model.


Epoch 12/60: 100%|██████████| 11/11 [00:34<00:00,  3.16s/it]


End of Epoch 12,Time: 34.72s, Current LR: 0.001000, Average Loss: 22835.8116


Epoch 12/60 [Validation]: 100%|██████████| 12/12 [00:16<00:00,  1.37s/it]


Average Validation Loss: 19877.1902


Epoch 13/60: 100%|██████████| 11/11 [00:36<00:00,  3.33s/it]


End of Epoch 13,Time: 36.59s, Current LR: 0.001000, Average Loss: 22363.4302


Epoch 13/60 [Validation]: 100%|██████████| 12/12 [00:16<00:00,  1.36s/it]


Average Validation Loss: 19041.5875


Epoch 14/60: 100%|██████████| 11/11 [00:34<00:00,  3.12s/it]


End of Epoch 14,Time: 34.35s, Current LR: 0.001000, Average Loss: 22292.6509


Epoch 14/60 [Validation]: 100%|██████████| 12/12 [00:16<00:00,  1.37s/it]


Average Validation Loss: 19102.0377


Epoch 15/60: 100%|██████████| 11/11 [00:33<00:00,  3.08s/it]


End of Epoch 15,Time: 33.88s, Current LR: 0.001000, Average Loss: 21969.1618


Epoch 15/60 [Validation]: 100%|██████████| 12/12 [00:16<00:00,  1.36s/it]


Average Validation Loss: 18422.2500
Validation loss improved. Saving the model.


Epoch 16/60: 100%|██████████| 11/11 [00:35<00:00,  3.20s/it]


End of Epoch 16,Time: 35.23s, Current LR: 0.001000, Average Loss: 22355.8592


Epoch 16/60 [Validation]: 100%|██████████| 12/12 [00:15<00:00,  1.33s/it]


Average Validation Loss: 18793.0407


Epoch 17/60: 100%|██████████| 11/11 [00:35<00:00,  3.27s/it]


End of Epoch 17,Time: 35.95s, Current LR: 0.001000, Average Loss: 22156.9152


Epoch 17/60 [Validation]: 100%|██████████| 12/12 [00:16<00:00,  1.38s/it]


Average Validation Loss: 18236.5278
Validation loss improved. Saving the model.


Epoch 18/60:  64%|██████▎   | 7/11 [00:24<00:13,  3.37s/it]

# Evaluation

### On training

In [None]:
# Load image
image_path = "/content/drive/MyDrive/images/image1.JPG"
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert to RGB

input_tensor = eval_transform(Image.fromarray(image)).unsqueeze(0).to(device)

# Set models to evaluation mode
teacher_model.eval() # Use the DepthModel instance
student_model.eval() # Use the DepthModel instance


with torch.no_grad():
    # Student prediction (before training) using the DepthModel instance
    start_time = time.time()
    student_output_after = student_model(input_tensor)
    end_time = time.time()
    inference_time_ms = (end_time - start_time) * 1000
    print(f"✅ Student model inference time: {inference_time_ms:.2f} ms")

    student_output_after_training = student_output_after.squeeze().cpu().numpy()

    # Teacher prediction using the DepthModel instance
    teacher_depth = teacher_model(input_tensor)
    teacher_depth = teacher_depth.squeeze().cpu().numpy()

# loss = distillation_criterion(student_output_after_training, teacher_depth)
# print(loss.item())
#Befor training
plt.figure(figsize=(15, 5))

# Original Image
plt.subplot(1, 3, 1)
plt.imshow(image)
plt.title("Original Image")
plt.axis("off")

# Teacher Depth Map
plt.subplot(1, 3, 2)
plt.imshow(teacher_depth, cmap="viridis")
plt.title("Teacher Depth Estimation")
plt.axis("off")

# Student Depth Map
plt.subplot(1, 3, 3)
plt.imshow(student_depth_before_training_train_image, cmap="viridis")
plt.title("Student Depth Estimation (Before Training)")
plt.axis("off")

plt.show()

#After training
plt.figure(figsize=(15, 5))

# Original Image
plt.subplot(1, 3, 1)
plt.imshow(image)
plt.title("Original Image")
plt.axis("off")

# Teacher Depth Map
plt.subplot(1, 3, 2)
plt.imshow(teacher_depth, cmap="viridis")
plt.title("Teacher Depth Estimation")
plt.axis("off")

# Student Depth Map
plt.subplot(1, 3, 3)
plt.imshow(student_output_after_training, cmap="viridis")
plt.title("Student Depth Estimation (After Training)")
plt.axis("off")

plt.show()



### On Testing

In [None]:
# Load image
image_path = "/content/test.jpg"
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert to RGB

input_tensor = eval_transform(Image.fromarray(image)).unsqueeze(0).to(device)

# Set models to evaluation mode
teacher_model.eval() # Use the DepthModel instance
student_model.eval() # Use the DepthModel instance

loss = 0;
with torch.no_grad():
    # Student prediction (before training) using the DepthModel instance
    student_output_after = student_model(input_tensor)
    student_output_after_training = student_output_after.squeeze().cpu().numpy()

    # Teacher prediction using the DepthModel instance
    teacher_depth = teacher_model(input_tensor)
    loss = distillation_criterion(student_output_after, teacher_depth)
    teacher_depth = teacher_depth.squeeze().cpu().numpy()

print(loss.item())
#Befor training
plt.figure(figsize=(15, 5))

# Original Image
plt.subplot(1, 3, 1)
plt.imshow(image)
plt.title("Original Image")
plt.axis("off")

# Teacher Depth Map
plt.subplot(1, 3, 2)
plt.imshow(teacher_depth, cmap="viridis")
plt.title("Teacher Depth Estimation")
plt.axis("off")

# Student Depth Map
plt.subplot(1, 3, 3)
plt.imshow(student_depth_before_training, cmap="viridis")
plt.title("Student Depth Estimation (Before Training)")
plt.axis("off")

plt.show()

#After training
plt.figure(figsize=(15, 5))

# Original Image
plt.subplot(1, 3, 1)
plt.imshow(image)
plt.title("Original Image")
plt.axis("off")

# Teacher Depth Map
plt.subplot(1, 3, 2)
plt.imshow(teacher_depth, cmap="viridis")
plt.title("Teacher Depth Estimation")
plt.axis("off")

# Student Depth Map
plt.subplot(1, 3, 3)
plt.imshow(student_output_after_training, cmap="viridis")
plt.title("Student Depth Estimation (After Training)")
plt.axis("off")

plt.show()

### Saving a CheckPoint

In [None]:
torch.save(student_model, "/content/drive/MyDrive/distillSkipL2500.pth")