# Set Up the Environment

### Importing Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import matplotlib.pyplot as plt
from transformers import AutoModelForDepthEstimation
from torchvision import transforms
from typing import Tuple, List
import timm
import os
from torch.utils.data import Dataset, DataLoader
import time
import numpy as np
from tqdm import tqdm
from torchvision import transforms
from PIL import Image
import config
device = config.DEVICE

### Mount Drive

In [2]:
from google.colab import drive

drive.mount('/content/drive')



Mounted at /content/drive


In [None]:
!mkdir -p /content/drive/MyDrive/Coco

# !wget http://images.cocodataset.org/zips/train2017.zip
# !wget http://images.cocodataset.org/zips/val2017.zip

# # Download the 2017 annotations
# !wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip

# !wget -q http://images.cocodataset.org/zips/test2017.zip -P /content/drive/MyDrive/Coco
# !wget -q http://images.cocodataset.org/annotations/image_info_test2017.zip -P /content/drive/MyDrive/Coco
!unzip -q /content/drive/MyDrive/Coco/test2017.zip -d /content/drive/MyDrive/Coco/
!unzip -q /content/drive/MyDrive/Coco/image_info_test2017.zip -d /content/drive/MyDrive/Coco/annotations/
print("COCO download and extraction complete.")

COCO download and extraction complete.


In [None]:
!unzip -q /content/drive/MyDrive/Coco/image_info_test2017.zip -d /content/drive/MyDrive/Coco/annotations/

replace /content/drive/MyDrive/Coco/annotations/annotations/image_info_test-dev2017.json? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace /content/drive/MyDrive/Coco/annotations/annotations/image_info_test2017.json? [y]es, [n]o, [A]ll, [N]one, [r]ename: A


# Define Needed Classes & Functions

### Class for Depth Model

In [3]:
class TeacherWrapper(nn.Module):
    """
    A wrapper for the teacher depth estimation model (Depth-Anything-V2).

    This class provides a unified interface for the teacher model. It handles
    loading the pre-trained model from Hugging Face (or a local cache) and
    performs the necessary pre- and post-processing steps. During inference,
    it extracts the final depth prediction and intermediate feature maps,
    which serve as targets for training the student model via knowledge
    distillation.
    """
    def __init__(self, model_id: str = 'depth-anything/depth-anything-v2-small-hf',
                cache_dir: str = None,
                selected_features_indices: List[int] = [3, 5, 7, 11]
        ):
        """
        Initializes the TeacherWrapper.

        Args:
            model_id (str): The identifier for the pre-trained model on the
                            Hugging Face Hub, or a path to a local directory
                            containing the model files.
            cache_dir (str): The directory where the downloaded model should be
                             cached.
        """
        super().__init__()
        self.model_id = model_id
        # Load the pre-trained depth estimation model
        self.model = AutoModelForDepthEstimation.from_pretrained(model_id, cache_dir=cache_dir)
        # Set the model to evaluation mode, as we don't want to train it
        self.model.eval()
        self.selected_features_indices = selected_features_indices

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Forward pass for the teacher model.

        This method should always be called within a `torch.no_grad()` context,
        as the teacher's weights should remain frozen during distillation.

        Args:
            x (torch.Tensor): The input image tensor.

        Returns:
            A tuple containing:
            - The final, normalized depth map (torch.Tensor).
            - A list of intermediate feature maps from the teacher's encoder
              (List[torch.Tensor]).
        """
        original_size = x.shape[2:]

        # 1. Get Model Outputs
        # We get the model's outputs, including the hidden states, which we
        # will use as feature targets for the student.
        outputs = self.model(x, output_hidden_states=True)
        predicted_depth = outputs.predicted_depth
        hidden_states = outputs.hidden_states

        # 2. Normalize Depth Map
        # The raw output of the model is not normalized, so we normalize it to
        # the range [0, 1] for consistent training.
        if predicted_depth.dim() == 3:
            predicted_depth = predicted_depth.unsqueeze(1)
        b, c, h, w = predicted_depth.shape
        predicted_depth_flat = predicted_depth.view(b, -1)
        max_vals = predicted_depth_flat.max(dim=1, keepdim=True)[0]
        max_vals[max_vals == 0] = 1.0  # Avoid division by zero
        normalized_depth = (predicted_depth_flat / max_vals).view(b, c, h, w)

        # 3. Interpolate to Original Size
        # The model's output may be smaller than the input image, so we
        # interpolate it back to the original size.
        final_depth = F.interpolate(normalized_depth, size=original_size, mode='bilinear', align_corners=False)

        # 4. Select Feature Maps for Distillation
        # We select a subset of the hidden states to use as feature targets.
        # For ViT-based models like DINOv2, these indices correspond to the
        # outputs of different blocks in the encoder.

        selected_features = [hidden_states[i] for i in self.selected_features_indices]

        # 5. Reshape ViT Features
        # The feature maps from Vision Transformer (ViT) models have a different
        # shape ([B, SeqLen, C]) than those from CNNs ([B, C, H, W]). We need to
        # reshape them to be compatible with the student's CNN-based features.
        reshaped_features = []
        patch_size = self.model.config.patch_size
        H_grid = x.shape[2] // patch_size
        W_grid = x.shape[3] // patch_size

        for feature_map in selected_features:
            batch_size, seq_len, num_channels = feature_map.shape
            # The first token in the sequence is the [CLS] token, which we remove
            image_patch_tokens = feature_map[:, 1:, :]
            # Reshape the sequence of patch tokens into a 2D feature map
            reshaped_map = image_patch_tokens.transpose(1, 2).reshape(batch_size, num_channels, H_grid, W_grid)
            reshaped_features.append(reshaped_map)

        return final_depth, reshaped_features


### Class for Student Model

In [4]:
import torch
import torch.nn as nn


class UpsampleBlock(nn.Module):
    """
    A building block for the decoder that upsamples feature maps and refines them.

    This block first increases the spatial resolution of the input feature map by a
    factor of 2 using bilinear interpolation. It then applies a series of
    convolutional layers to refine the upsampled features. For efficiency,
    it uses depthwise separable convolutions.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Initializes the UpsampleBlock.

        Args:
            in_channels (int): The number of channels in the input feature map.
            out_channels (int): The number of channels in the output feature map.
        """
        super().__init__()

        # Upsampling layer to increase spatial resolution
        self.upsample = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )

        # Convolutional layers to refine the upsampled features
        self.conv = nn.Sequential(
            # First 3x3 convolution
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            # Second 1x1 convolution
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),

            # (This second block operates on out_channels)
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, groups=out_channels, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),

            nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the UpsampleBlock.

        Args:
            x (torch.Tensor): The input feature map.

        Returns:
            torch.Tensor: The upsampled and refined feature map.
        """
        # Apply upsampling and then the convolutional layers
        upsampled_features = self.upsample(x)
        return self.conv(upsampled_features)

In [5]:
class FeatureFusionBlock(nn.Module):
    """
    A block that fuses features from two different sources.

    This block is used to combine features from a higher-level (more abstract)
    decoder stage with features from a lower-level (more detailed) encoder stage
    via a skip connection. The features are concatenated along the channel
    dimension and then refined using a series of convolutional layers.
    """
    def __init__(self, channels: int):
        """
        Initializes the FeatureFusionBlock.

        Args:
            channels (int): The number of channels in each of the input feature maps.
                            The output will also have this many channels.
        """
        super().__init__()

        # Convolutional layers to process the fused features
        self.conv = nn.Sequential(
            # The input to this conv layer has 2 * channels because we concatenate two feature maps
            nn.Conv2d(channels * 2, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            # Another conv layer to further refine the features
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )

    def forward(self, higher_level_features: torch.Tensor, skip_features: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the FeatureFusionBlock.

        Args:
            higher_level_features (torch.Tensor): The feature map from the previous,
                                                  higher-level decoder stage. It is
                                                  assumed to have been upsampled to match
                                                  the spatial dimensions of `skip_features`.
            skip_features (torch.Tensor): The feature map from the corresponding
                                          encoder stage (skip connection).

        Returns:
            torch.Tensor: The fused and refined feature map.
        """
        # Concatenate the two feature maps along the channel dimension
        fused_features = torch.cat([higher_level_features, skip_features], dim=1)
        # Process the fused features with the convolutional layers
        return self.conv(fused_features)



In [6]:
class MiniDPT(nn.Module):
    """
    A lightweight, DPT-inspired decoder for monocular depth estimation.

    This decoder takes a list of feature maps from an encoder at different
    spatial resolutions and progressively fuses them to generate a high-resolution
    depth map. The architecture is inspired by the Dense Prediction Transformer (DPT)
    but is simplified for use with a lightweight backbone like MobileViT.
    """
    def __init__(self, encoder_channels: List[int], decoder_channels: List[int]):
        """
        Initializes the MiniDPT decoder.

        Args:
            encoder_channels (List[int]): A list of the number of channels for each
                                          feature map extracted from the encoder.
                                          The list should be ordered from the lowest
                                          level (largest spatial resolution) to the
                                          highest level (smallest spatial resolution).
                                          Example: [64, 128, 256, 512]
            decoder_channels (List[int]): A list of the number of channels for each
                                          stage of the decoder. The length of this
                                          list must be the same as `encoder_channels`.
                                          Example: [256, 128, 96, 64]
        """
        super().__init__()

        if len(encoder_channels) != len(decoder_channels):
            raise ValueError("Encoder and decoder channel lists must have the same length.")

        # Reverse for processing from high-level to low-level
        encoder_channels = encoder_channels[::-1]
        decoder_channels = decoder_channels[::-1]

        # 1. Projection Convolutions
        # These 1x1 convolutions project the encoder features to the number of
        # channels specified for the decoder.
        self.projection_convs = nn.ModuleList()
        for i in range(len(encoder_channels)):
            self.projection_convs.append(nn.Sequential(
                nn.Conv2d(encoder_channels[i], decoder_channels[i], kernel_size=1, bias=False),
                nn.BatchNorm2d(decoder_channels[i]),
                nn.ReLU(inplace=True),
            ))

        # 2. Upsampling and Fusion Blocks
        # These blocks are used to upsample the features from a higher decoder
        # level and fuse them with the projected features from the corresponding
        # encoder level (skip connection).
        self.upsample_blocks = nn.ModuleList()
        self.fusion_blocks = nn.ModuleList()

        for i in range(len(decoder_channels) - 1):
            # Upsample from the current decoder channel count to the next (lower) one
            self.upsample_blocks.append(UpsampleBlock(decoder_channels[i], decoder_channels[i+1]))
            # Fusion block takes the upsampled features and the projected skip connection
            self.fusion_blocks.append(FeatureFusionBlock(decoder_channels[i+1]))

        # 3. Prediction Head
        # This final part of the decoder takes the fused features from the last
        # stage and produces the final single-channel depth map.
        self.prediction_head = nn.Sequential(
            nn.Conv2d(decoder_channels[-1], decoder_channels[-1] // 2, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(decoder_channels[-1] // 2, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid()
        )


    def forward(self, encoder_features: List[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass of the MiniDPT decoder.

        Args:
            encoder_features (List[torch.Tensor]): A list of feature maps from the
                                                   encoder, ordered from the lowest
                                                   level to the highest level.

        Returns:
            torch.Tensor: The final predicted depth map.
        """

        # Reverse the features to process from the highest level to the lowest
        features = encoder_features[::-1]

        # Project all encoder features to the decoder's channel dimensions
        projected_features = [self.projection_convs[i](features[i]) for i in range(len(features))]

        # Start with the highest-level (most abstract) feature map
        current_features = projected_features[0]

        # Iteratively upsample and fuse with lower-level skip connections
        for i in range(len(self.fusion_blocks)):
            upsampled = self.upsample_blocks[i](current_features)
            skip_connection = projected_features[i+1]
            current_features = self.fusion_blocks[i](upsampled, skip_connection)

        # Generate final prediction using the prediction head
        return self.prediction_head(current_features)


In [7]:
class StudentDepthModel(nn.Module):
    """
    The student model for monocular depth estimation.

    This model consists of a lightweight, pre-trained encoder (e.g., MobileViT)
    and a custom lightweight decoder (MiniDPT). It is designed to be trained
    efficiently, making it suitable for deployment on resource-constrained
    devices. The training is done via knowledge distillation from a larger,
    more powerful teacher model.
    """
    def __init__(self, feature_indices: Tuple[int, ...] = (0, 1, 2, 3),
                 decoder_channels: Tuple[int, ...] = (64, 128, 160, 256),
                 pretrained: bool = True):
        """
        Initializes the StudentDepthModel.

        Args:
            encoder_name (str): The name of the encoder model to use from the `timm`
                                library.
            feature_indices (Tuple[int, ...]): A tuple of indices specifying which
                                               feature maps to extract from the encoder.
            decoder_channels (Tuple[int, ...]): A tuple of channel counts for the
                                                decoder stages.
            pretrained (bool): Whether to load pre-trained weights for the encoder.
        """
        super().__init__()
        if len(feature_indices) != len(decoder_channels):
            raise ValueError("The number of feature indices must match the number of decoder channel dimensions.")

        # 1. Instantiate the Encoder
        # We use the `timm` library to create a pre-trained encoder.
        # `features_only=True` makes the model return a List of feature maps
        # at different stages, instead of a final classification output.
        self.encoder = timm.create_model(
            'mobilevit_xs',
            pretrained=pretrained,
            features_only=True, # This returns a List of feature maps
        )
        self.feature_indices = feature_indices

        # 2. Determine Encoder Output Channels
        # To connect the encoder to the decoder, we need to know the number of
        # channels in the feature maps that the encoder produces. We can find
        # this by doing a dummy forward pass.
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, 224, 224)
            features = self.encoder(dummy_input)
            encoder_channels = [features[i].shape[1] for i in self.feature_indices]

        # 3. Instantiate the Decoder
        # The decoder takes the feature maps from the encoder and upsamples them
        # to produce the final depth map.
        self.decoder = MiniDPT(encoder_channels, list(decoder_channels))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Forward pass of the StudentDepthModel.

        Args:
            x (torch.Tensor): The input image tensor.

        Returns:
            A tuple containing:
            - The final predicted depth map (torch.Tensor).
            - A list of intermediate feature maps from the encoder, which will be
              used for feature-based distillation (List[torch.Tensor]).
        """
        # Get the feature maps from the encoder
        features = self.encoder(x)
        # Select the feature maps at the specified indices
        selected_features = [features[i] for i in self.feature_indices]
        # Pass the selected features to the decoder to get the depth map
        depth_map = self.decoder(selected_features)
        return depth_map, selected_features


### Class for Dataset Loading & Preprocessing

In [8]:
from torchvision.datasets import CocoDetection
class CocoUnlabeledDataset(Dataset):
    """
    Custom dataset for unlabeled COCO images.
    This class wraps torchvision's CocoDetection to provide only the images,
    ignoring the annotations, suitable for an unlabeled training task.
    """
    def __init__(self, root_dir, ann_file, transform=None, resize_size=None):
        self.resize_size = resize_size
        self.coco_dataset = CocoDetection(root=root_dir, annFile=ann_file, transform=transform)
        print(f"Found {len(self.coco_dataset)} images in {root_dir}")

    def __len__(self):
        return len(self.coco_dataset)

    def __getitem__(self, idx):
        """
        Returns only the transformed image, ignoring the target annotations.
        The transform is applied by the underlying CocoDetection dataset.
        """
        try:
            image, _ = self.coco_dataset[idx]
            # if self.resize_size:
            #   image = image.resize(self.resize_size)
            return image
        except Exception as e:
            print(f"Warning: Skipping image at index {idx} due to error: {e}")
            # Return a placeholder tensor if an image fails to load
            return torch.zeros((3, self.resize_size[0], self.resize_size[1]))


In [9]:
class UnlabeledImageDataset(Dataset):
    """
    Custom dataset for unlabeled images.
    """
    def __init__(self, root_dir, transform=None, resize_size=None):
        self.root_dir = root_dir
        self.transform = transform
        self.resize_size = resize_size
        self.image_paths = []

        safe = os.path.join(root_dir, 'safe')
        not_safe = os.path.join(root_dir, 'note_safe')
        for dirpath, _, filenames in os.walk(root_dir):
            for f in filenames:
                if f.lower().endswith(('png', 'jpg', 'jpeg')):
                    self.image_paths.append(os.path.join(dirpath, f))


        for dirpath, _, filenames in os.walk(safe):
            for f in filenames:
                if f.lower().endswith(('png', 'jpg', 'jpeg')):
                    self.image_paths.append(os.path.join(dirpath, f))


        for dirpath, _, filenames in os.walk(not_safe):
            for f in filenames:
                if f.lower().endswith(('png', 'jpg', 'jpeg')):
                    self.image_paths.append(os.path.join(dirpath, f))

        print(f"Found {len(self.image_paths)} images in {root_dir}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')

        if self.resize_size:
            image = image.resize(self.resize_size)

        if self.transform:
            image = self.transform(image)

        return image

### Distillation Loss

In [10]:
def compute_depth_gradients(depth_map: torch.Tensor) -> torch.Tensor:
    """
    Computes the image gradients (dy, dx) for a batch of depth maps.

    This is done by applying Sobel filters to the depth map. The gradients
    are used to compute a loss that encourages the student model to preserve
    edges and fine details from the teacher's prediction.

    Args:
        depth_map (torch.Tensor): A batch of single-channel depth maps.

    Returns:
        torch.Tensor: A tensor containing the absolute gradients in the y and x
                      directions, concatenated along the channel dimension.
    """
    # Create Sobel filters for GPU computation
    sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32, device=depth_map.device).view(1, 1, 3, 3)
    sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32, device=depth_map.device).view(1, 1, 3, 3)

    # Apply filters using depthwise convolution
    padded_map = F.pad(depth_map, (1, 1, 1, 1), mode='replicate')

    grad_y = F.conv2d(padded_map, sobel_y, padding=0)
    grad_x = F.conv2d(padded_map, sobel_x, padding=0)

    # Return the absolute gradients, stacked along the channel dimension
    return torch.cat([grad_y.abs(), grad_x.abs()], dim=1)


class DistillationLoss(nn.Module):
    """
    A comprehensive loss function for knowledge distillation in depth estimation.

    This loss function combines four different components to train the student
    model effectively:
    1.  Scale-Invariant Log (SILog) Loss: Measures the overall accuracy of the
        predicted depth map.
    2.  Gradient Matching Loss (L1): Enforces that the student's depth map
        has similar edges and fine details as the teacher's.
    3.  Feature Matching Loss (L1): Encourages the student's intermediate
        feature representations to be similar to the teacher's.
    4.  Attention Matching Loss (L2): Encourages the student to focus on the
        same spatial regions of the image as the teacher.
    """
    def __init__(self, lambda_silog: float = 1.0, lambda_grad: float = 0.2,
                 lambda_feat: float = 0.1, lambda_attn: float = 1.0, alpha: float = 0.5):
        """
        Initializes the DistillationLoss.

        Args:
            lambda_silog (float): The weight for the SILog depth loss.
            lambda_grad (float): The weight for the gradient matching loss.
            lambda_feat (float): The weight for the feature matching loss.
            lambda_attn (float): The weight for the attention matching loss.
            alpha (float): A parameter for the SILog loss that balances between
                           scale and shift invariance.
        """
        super().__init__()
        self.lambda_silog = lambda_silog
        self.lambda_grad = lambda_grad
        self.lambda_feat = lambda_feat
        self.lambda_attn = lambda_attn
        self.alpha = alpha

        self.l1_loss = nn.L1Loss()
        self.l2_loss = nn.MSELoss()

        self.projection_convs = None

    def _initialize_projections(self, student_features: List[torch.Tensor],
                                teacher_features: List[torch.Tensor], device: torch.device):
        """
        Dynamically creates projection layers to match the channel counts of the
        student and teacher features. This is necessary because the student and
        teacher models may have different numbers of channels in their
        intermediate feature maps.
        """
        self.projection_convs = nn.ModuleList()
        for s_feat, t_feat in zip(student_features, teacher_features):
            s_chan, t_chan = s_feat.shape[1], t_feat.shape[1]
            if s_chan != t_chan:
                # Create a 1x1 convolution to project student channels to teacher channels
                proj = nn.Conv2d(s_chan, t_chan, kernel_size=1, bias=False).to(device)
            else:
                proj = nn.Identity().to(device)
            self.projection_convs.append(proj)

    def _compute_attention_map(self, feature_map: torch.Tensor) -> torch.Tensor:
        """
        Computes a spatial attention map from a feature map by summarizing
        across the channel dimension. This provides a simple way to capture
        which spatial regions the model is focusing on.
        """
        return torch.mean(torch.abs(feature_map), dim=1, keepdim=True)


    def forward(
        self,
        student_depth: torch.Tensor,
        teacher_depth: torch.Tensor,
        student_features: List[torch.Tensor],
        teacher_features: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Calculates the combined distillation loss.

        Args:
            student_depth (torch.Tensor): The depth map predicted by the student.
            teacher_depth (torch.Tensor): The depth map predicted by the teacher.
            student_features (List[torch.Tensor]): Intermediate features from the student.
            teacher_features (List[torch.Tensor]): Intermediate features from the teacher.

        Returns:
            torch.Tensor: The total combined loss.
        """
        device = student_depth.device

        # Initialize projection layers on the first pass
        if self.projection_convs is None:
            self._initialize_projections(student_features, teacher_features, device)

        # --- 1. SILog Depth Loss ---
        valid_mask = (student_depth > 1e-8) & (teacher_depth > 1e-8)
        log_diff = torch.log(student_depth[valid_mask]) - torch.log(teacher_depth[valid_mask])
        num_pixels = log_diff.numel()
        silog_loss = torch.sum(log_diff ** 2) / num_pixels - self.alpha * (torch.sum(log_diff) ** 2) / (num_pixels ** 2) if num_pixels > 0 else torch.tensor(0.0, device=device)

        # --- 2. Gradient Matching Loss ---
        student_grads = compute_depth_gradients(student_depth)
        teacher_grads = compute_depth_gradients(teacher_depth)
        grad_loss = self.l1_loss(student_grads, teacher_grads)

        # --- 3. Feature & Attention Matching Loss ---
        feature_loss = torch.tensor(0.0, device=device)
        attention_loss = torch.tensor(0.0, device=device)

        for i, (s_feat, t_feat) in enumerate(zip(student_features, teacher_features)):
            # Project the student feature to match the teacher's channel dimension
            s_feat_projected = self.projection_convs[i](s_feat)

            # Interpolate if spatial sizes don't match (essential for ViT vs CNN features)
            if s_feat_projected.shape[2:] != t_feat.shape[2:]:
                s_feat_resized = F.interpolate(s_feat_projected, size=t_feat.shape[2:], mode='bilinear', align_corners=False)
            else:
                s_feat_resized = s_feat_projected

            feature_loss += self.l1_loss(s_feat_resized, t_feat)

            # Calculate the attention map loss
            s_attn = self._compute_attention_map(s_feat_resized)
            t_attn = self._compute_attention_map(t_feat)
            attention_loss += self.l2_loss(s_attn, t_attn)

        # --- 4. Combine All Losses ---
        total_loss = (self.lambda_silog * silog_loss) + \
                     (self.lambda_grad * grad_loss) + \
                     (self.lambda_feat * feature_loss) + \
                     (self.lambda_attn * attention_loss)

        return total_loss


### transforms

In [11]:

def get_train_transforms(input_size=(config.IMG_HEIGHT, config.IMG_WIDTH)):
    """Returns a composition of transforms for training."""
    return transforms.Compose([
        transforms.RandomHorizontalFlip(p=config.FLIP_PROP),
        transforms.RandomRotation(degrees=config.ROTATION_DEG),
        transforms.RandomResizedCrop(input_size, scale=(config.MIN_SCALE, config.MAX_SCALE)),
        transforms.ColorJitter(brightness=config.BRIGHTNESS, contrast=config.CONTRAST, saturation=config.SATURATION, hue=config.HUE),
        transforms.ToTensor(),
        transforms.Normalize(mean=config.IMGNET_NORMALIZE_MEAN, std=config.IMGNET_NORMALIZE_STD)
    ])

def get_eval_transforms(input_size=(config.IMG_HEIGHT, config.IMG_WIDTH)):
    """Returns a composition of transforms for evaluation."""
    return transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=config.IMGNET_NORMALIZE_MEAN, std=config.IMGNET_NORMALIZE_STD)
    ])

### Visuals

In [12]:
import numpy as np
import matplotlib.pyplot as plt


def apply_color_map(depth_map, cmap='inferno'):
    """
    Applies a colormap to a grayscale depth map for visualization.

    Args:
        depth_map (np.ndarray): The input depth map as a 2D numpy array.
                                Values can be in any range.
        cmap (str): The name of the matplotlib colormap to use.
                    Defaults to 'inferno'.

    Returns:
        np.ndarray: The colorized depth map as a numpy array with RGB values
                    in the range [0, 255].
    """
    # 1. Normalize the depth map to be in the range [0, 1]
    # This is necessary for the colormap to be applied correctly.
    depth_range = np.max(depth_map) - np.min(depth_map)
    if depth_range == 0:
        depth_range = np.max(depth_map)
    depth_normalized = (depth_map - np.min(depth_map)) / depth_range

    # 2. Get the colormap from matplotlib
    colormap = plt.get_cmap(cmap)

    # 3. Apply the colormap to the normalized depth map
    # The colormap function returns RGBA values in the range [0, 1].
    colored_depth = colormap(depth_normalized)

    # 4. Convert to an 8-bit RGB image
    # We discard the alpha channel and scale the values to [0, 255].
    colored_depth_rgb = (colored_depth[:, :, :3] * 255).astype(np.uint8)

    return colored_depth_rgb

def plot_depth_comparison(original_img, teacher_depth, student_depth, title=""):
    """Plots the original image, teacher depth, and student depth side-by-side."""
    plt.figure(figsize=(18, 6))

    plt.subplot(1, 3, 1)
    plt.imshow(original_img)
    plt.title("Original Image")
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.imshow(teacher_depth, cmap="viridis")
    plt.title("Teacher Depth Map")
    plt.axis("off")

    plt.subplot(1, 3, 3)
    plt.imshow(student_depth, cmap="viridis")
    plt.title("Student Depth Map")
    plt.axis("off")

    if title:
        plt.suptitle(title)
    plt.show()

### The Training Function

In [13]:
def train_knowledge_distillation(teacher, student, train_dataloader, val_dataloader, criterion, optimizer, epochs, scheduler, checkpoint_dir, device):
    """
    Train the student model using Response-Based knowledge distillation.
    """
    teacher.eval() # Teacher should always be in evaluation mode

    print(f"Starting Knowledge Distillation Training on {device}...")
    min_loss = float('inf')
    train_losses = []  # List to store training losses
    val_losses = []    # List to store validation losses

    for epoch in range(epochs):
        student.train() # Student in training mode
        running_loss = 0.0
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        start_time = time.time()

        for images in progress_bar:
            images = images.to(device)
            optimizer.zero_grad()

            # Forward pass with Teacher model (no_grad as teacher is fixed)
            with torch.no_grad():
                teacher_depth, teacher_features = teacher(images) # Returns depth map

            # Forward pass with Student model
            student_depth, student_features  = student(images) # Returns depth map

            # Calculate distillation loss
            loss = criterion(student_depth, teacher_depth, student_features, teacher_features)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()


        epoch_loss = running_loss / len(train_dataloader)
        train_losses.append(epoch_loss) # Store training loss
        current_lr = scheduler.get_last_lr()[0]
        end_time = time.time()
        print(f"End of Epoch {epoch+1},Time: {end_time - start_time:.2f}s, Current LR: {current_lr:.6f}, Average Loss: {epoch_loss:.4f}")
        scheduler.step()

        # Validation loop
        student.eval() # Student in evaluation mode for validation
        val_running_loss = 0.0
        with torch.no_grad():
            progress_bar_val = tqdm(val_dataloader, desc=f"Epoch {epoch+1}/{epochs} [Validation]")
            for val_images in progress_bar_val:
                val_images = val_images.to(device)
                teacher_depth, TFeat = teacher(val_images)
                student_depth, SFeat = student(val_images)
                val_loss = criterion(student_depth, teacher_depth, SFeat, TFeat)
                val_running_loss += val_loss.item()

        val_epoch_loss = val_running_loss / len(val_dataloader)
        val_losses.append(val_epoch_loss) # Store validation loss
        print(f"Average Validation Loss: {val_epoch_loss:.4f}")

        if val_epoch_loss < min_loss:
            min_loss = val_epoch_loss
            print("Validation loss improved. Saving the model.")

            torch.save(student.state_dict(), f"/content/drive/MyDrive/FINALStudentCoCo.pth")

        # Save losses to a file in Google Drive after each epoch
        loss_data = {'train_loss': train_losses, 'val_loss': val_losses}
        loss_filepath = "/content/drive/MyDrive/FINAL_training_losses_CoCo.pth"
        torch.save(loss_data, loss_filepath)
        print(f"Training and validation losses saved to {loss_filepath}")


    print("Knowledge Distillation Training Finished!")

    return train_losses, val_losses # Return the lists of losses

# Training Process

### Define Parameters & Models

In [14]:
    # --- Setup ---
    device = config.DEVICE
    print(f"Using device: {device}")

    # --- Models ---
    print("Loading teacher model...")
    teacher_model = TeacherWrapper().to(device)
    print("Loaded teacher model sucessfully")

    print("Initializing student model...")
    student_model = StudentDepthModel(pretrained=True).to(device)
    student_model.load_state_dict(torch.load('/content/drive/MyDrive/FINAL.pth', map_location=device))

    print("Initialized student model sucessfully")

    # Get parameters for the encoder and decoder
    encoder_params = student_model.encoder.parameters()
    decoder_params = student_model.decoder.parameters()

    # --- Optimizer, Loss, and Data ---
    student_optimizer = optim.AdamW([
        {'params': encoder_params, 'lr': config.LEARNING_RATE_ENCODER},  # A lower learning rate for the encoder
        {'params': decoder_params, 'lr': config.LEARNING_RATE_DECODER}   # A higher learning rate for the decoder
    ], weight_decay=config.WEIGHT_DECAY)

    num_epochs = config.EPOCHS
    scheduler = CosineAnnealingLR(student_optimizer, T_max=num_epochs, eta_min=config.MIN_LEARNING_RATE)

    criterion = DistillationLoss(
        lambda_silog = config.LAMBDA_SILOG,
        lambda_grad = config.LAMBDA_GRAD,
        lambda_feat = config.LAMBDA_FEAT,
        lambda_attn = config.LAMBDA_ATTN,
        alpha = config.ALPHA).to(device)

    input_size=(config.IMG_HEIGHT, config.IMG_WIDTH)

    transform = get_train_transforms(input_size=input_size)
    eval_transform = get_eval_transforms(input_size=input_size)

    # Create two separate datasets with their respective transforms
    # train_full_dataset = UnlabeledImageDataset(root_dir='/content/drive/MyDrive/images/', transform=transform, resize_size=input_size)
    # val_full_dataset = UnlabeledImageDataset(root_dir='/content/drive/MyDrive/images/', transform=eval_transform, resize_size=input_size)

    train_full_dataset = CocoUnlabeledDataset(root_dir='/content/drive/MyDrive/Coco/test2017', ann_file='/content/drive/MyDrive/Coco/annotations/annotations/image_info_test-dev2017.json', transform=transform, resize_size=input_size)
    val_full_dataset = CocoUnlabeledDataset(root_dir='/content/drive/MyDrive/Coco/test2017', ann_file='/content/drive/MyDrive/Coco/annotations/annotations/image_info_test-dev2017.json', transform=eval_transform, resize_size=input_size)

    # Use the same indices to split both datasets
    dataset_size = len(train_full_dataset)
    train_size = int(0.8 * dataset_size)
    val_size = dataset_size - train_size

    indices = list(range(dataset_size))
    SUBSET_SIZE = 5000
    print(f"Using a random subset of {SUBSET_SIZE} images.")
    indices = indices[:SUBSET_SIZE]

    np.random.seed(config.RANDOM_SEED)
    np.random.shuffle(indices)
    split = int(np.floor(0.8 * len(indices)))
    train_indices, val_indices = indices[:split], indices[split:]

    # Create subsets for training and validation
    train_dataset = torch.utils.data.Subset(train_full_dataset, train_indices)
    val_dataset = torch.utils.data.Subset(val_full_dataset, val_indices)


    # Create separate dataloaders for training and validation
    train_dataloader = DataLoader(train_dataset, batch_size=12, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True)
    val_dataloader = DataLoader(val_dataset, batch_size=12, shuffle=False, num_workers=config.NUM_WORKERS, pin_memory=True)

    print(f"Training set size: {len(train_dataset)}")
    print(f"Validation set size: {len(val_dataset)}")
    # --- Checkpoint Directory ---
    checkpoint_dir = config.CHECKPOINT_DIR
    os.makedirs(checkpoint_dir, exist_ok=True)





Using device: cuda
Loading teacher model...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/950 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/99.2M [00:00<?, ?B/s]

Loaded teacher model sucessfully
Initializing student model...


model.safetensors:   0%|          | 0.00/9.34M [00:00<?, ?B/s]

Initialized student model sucessfully
loading annotations into memory...
Done (t=1.67s)
creating index...
index created!
Found 20288 images in /content/drive/MyDrive/Coco/test2017
loading annotations into memory...
Done (t=0.05s)
creating index...
index created!
Found 20288 images in /content/drive/MyDrive/Coco/test2017
Using a random subset of 5000 images.
Training set size: 4000
Validation set size: 1000


Before Training

In [16]:

import cv2
image_path = "/content/drive/MyDrive/Coco/test2017/000000000001.jpg"
train_image = cv2.imread(image_path)
train_image = cv2.cvtColor(train_image, cv2.COLOR_BGR2RGB)  # Convert to RGB

train_input_tensor = eval_transform(Image.fromarray(train_image)).unsqueeze(0).to(device)
# Load image
image_path = "/content/test.jpg"
# image_path = "/content/drive/MyDrive/images/image1.JPG"
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert to RGB

input_tensor = eval_transform(Image.fromarray(image)).unsqueeze(0).to(device)

student_model.eval() # Use the DepthModel instance


with torch.no_grad():
    # Student prediction (before training) using the DepthModel instance
    student_depth_before, Sfeat = student_model(input_tensor)
    student_depth_before_training = student_depth_before.squeeze().cpu().numpy()
    student_depth_before, Sfeat = student_model(train_input_tensor)
    student_depth_before_training_train_image = student_depth_before.squeeze().cpu().numpy()


### Verify which layers are trainable

In [17]:
# --- Verify which layers are trainable ---
total_param = 0
train_param = 0
for name, param in student_model.named_parameters():
    total_param += param.numel()
    if param.requires_grad:
        train_param += param.numel()
print(f"Total Parameters: {total_param}")
print(f"Trainable Parameters: {train_param}")

print(student_model)


Total Parameters: 3373265
Trainable Parameters: 3373265
StudentDepthModel(
  (encoder): FeatureListNet(
    (stem): ConvNormAct(
      (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNormAct2d(
        16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
        (drop): Identity()
        (act): SiLU(inplace=True)
      )
    )
    (stages_0): Sequential(
      (0): BottleneckBlock(
        (conv1_1x1): ConvNormAct(
          (conv): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNormAct2d(
            64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): SiLU(inplace=True)
          )
        )
        (conv2_kxk): ConvNormAct(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
          (bn): BatchNormAct2d(
            64, eps=1e-05, momentum=0.1, affine=True, tra

In [None]:
# --- Verify which layers are trainable ---
total_param = 0
train_param = 0
for name, param in teacher_model.named_parameters():
    total_param += param.numel()
    if param.requires_grad:
        train_param += param.numel()
print(f"Total Parameters: {total_param}")
print(f"Trainable Parameters: {train_param}")

print(teacher_model)


Total Parameters: 24785089
Trainable Parameters: 24785089
TeacherWrapper(
  (model): DepthAnythingForDepthEstimation(
    (backbone): Dinov2Backbone(
      (embeddings): Dinov2Embeddings(
        (patch_embeddings): Dinov2PatchEmbeddings(
          (projection): Conv2d(3, 384, kernel_size=(14, 14), stride=(14, 14))
        )
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): Dinov2Encoder(
        (layer): ModuleList(
          (0-11): 12 x Dinov2Layer(
            (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
            (attention): Dinov2Attention(
              (attention): Dinov2SelfAttention(
                (query): Linear(in_features=384, out_features=384, bias=True)
                (key): Linear(in_features=384, out_features=384, bias=True)
                (value): Linear(in_features=384, out_features=384, bias=True)
              )
              (output): Dinov2SelfOutput(
                (dense): Linear(in_features=384, out_features=38

### Run the Training

In [None]:
import gc

# Clear memory
gc.collect()
torch.cuda.empty_cache()

### Run 1

In [None]:
trainLoss, valLoss = train_knowledge_distillation(
      teacher=teacher_model,
      student=student_model,
      train_dataloader=train_dataloader,
      val_dataloader=val_dataloader,
      criterion=criterion,
      optimizer=student_optimizer,
      epochs=num_epochs,
      scheduler=scheduler,
      device=device,
      checkpoint_dir=checkpoint_dir
  )



Starting Knowledge Distillation Training on cuda...


Epoch 1/60: 100%|██████████| 334/334 [15:32<00:00,  2.79s/it]


End of Epoch 1,Time: 932.01s, Current LR: 0.000010, Average Loss: 0.8446


Epoch 1/60 [Validation]: 100%|██████████| 84/84 [03:33<00:00,  2.55s/it]


Average Validation Loss: 0.8306
Validation loss improved. Saving the model.
Training and validation losses saved to /content/drive/MyDrive/FINAL_training_losses_CoCo.pth


Epoch 2/60: 100%|██████████| 334/334 [04:58<00:00,  1.12it/s]


End of Epoch 2,Time: 298.53s, Current LR: 0.000010, Average Loss: 0.8282


Epoch 2/60 [Validation]: 100%|██████████| 84/84 [00:36<00:00,  2.28it/s]


Average Validation Loss: 0.8180
Validation loss improved. Saving the model.
Training and validation losses saved to /content/drive/MyDrive/FINAL_training_losses_CoCo.pth


Epoch 3/60: 100%|██████████| 334/334 [04:58<00:00,  1.12it/s]


End of Epoch 3,Time: 298.66s, Current LR: 0.000010, Average Loss: 0.8270


Epoch 3/60 [Validation]: 100%|██████████| 84/84 [00:36<00:00,  2.29it/s]


Average Validation Loss: 0.8810
Training and validation losses saved to /content/drive/MyDrive/FINAL_training_losses_CoCo.pth


Epoch 4/60: 100%|██████████| 334/334 [04:58<00:00,  1.12it/s]


End of Epoch 4,Time: 298.23s, Current LR: 0.000010, Average Loss: 0.8151


Epoch 4/60 [Validation]: 100%|██████████| 84/84 [00:36<00:00,  2.30it/s]


Average Validation Loss: 0.8037
Validation loss improved. Saving the model.
Training and validation losses saved to /content/drive/MyDrive/FINAL_training_losses_CoCo.pth


Epoch 5/60: 100%|██████████| 334/334 [04:57<00:00,  1.12it/s]


End of Epoch 5,Time: 297.84s, Current LR: 0.000010, Average Loss: 0.8029


Epoch 5/60 [Validation]: 100%|██████████| 84/84 [00:36<00:00,  2.27it/s]


Average Validation Loss: 0.7916
Validation loss improved. Saving the model.
Training and validation losses saved to /content/drive/MyDrive/FINAL_training_losses_CoCo.pth


Epoch 6/60: 100%|██████████| 334/334 [04:57<00:00,  1.12it/s]


End of Epoch 6,Time: 297.63s, Current LR: 0.000010, Average Loss: 0.7933


Epoch 6/60 [Validation]: 100%|██████████| 84/84 [00:36<00:00,  2.29it/s]


Average Validation Loss: 0.8084
Training and validation losses saved to /content/drive/MyDrive/FINAL_training_losses_CoCo.pth


Epoch 7/60: 100%|██████████| 334/334 [04:57<00:00,  1.12it/s]


End of Epoch 7,Time: 297.58s, Current LR: 0.000010, Average Loss: 0.7949


Epoch 7/60 [Validation]: 100%|██████████| 84/84 [00:36<00:00,  2.28it/s]


Average Validation Loss: 0.7930
Training and validation losses saved to /content/drive/MyDrive/FINAL_training_losses_CoCo.pth


Epoch 8/60: 100%|██████████| 334/334 [04:57<00:00,  1.12it/s]


End of Epoch 8,Time: 297.77s, Current LR: 0.000010, Average Loss: 0.7814


Epoch 8/60 [Validation]: 100%|██████████| 84/84 [00:36<00:00,  2.29it/s]


Average Validation Loss: 0.8032
Training and validation losses saved to /content/drive/MyDrive/FINAL_training_losses_CoCo.pth


Epoch 9/60: 100%|██████████| 334/334 [04:57<00:00,  1.12it/s]


End of Epoch 9,Time: 297.66s, Current LR: 0.000010, Average Loss: 0.7802


Epoch 9/60 [Validation]: 100%|██████████| 84/84 [00:36<00:00,  2.30it/s]


Average Validation Loss: 0.8004
Training and validation losses saved to /content/drive/MyDrive/FINAL_training_losses_CoCo.pth


Epoch 10/60: 100%|██████████| 334/334 [04:57<00:00,  1.12it/s]


End of Epoch 10,Time: 297.65s, Current LR: 0.000010, Average Loss: 0.7847


Epoch 10/60 [Validation]: 100%|██████████| 84/84 [00:36<00:00,  2.29it/s]


Average Validation Loss: 0.7882
Validation loss improved. Saving the model.
Training and validation losses saved to /content/drive/MyDrive/FINAL_training_losses_CoCo.pth


Epoch 11/60: 100%|██████████| 334/334 [04:57<00:00,  1.12it/s]


End of Epoch 11,Time: 297.43s, Current LR: 0.000009, Average Loss: 0.7691


Epoch 11/60 [Validation]: 100%|██████████| 84/84 [00:36<00:00,  2.27it/s]


Average Validation Loss: 0.7980
Training and validation losses saved to /content/drive/MyDrive/FINAL_training_losses_CoCo.pth


Epoch 12/60: 100%|██████████| 334/334 [04:57<00:00,  1.12it/s]


End of Epoch 12,Time: 297.83s, Current LR: 0.000009, Average Loss: 0.7747


Epoch 12/60 [Validation]: 100%|██████████| 84/84 [00:36<00:00,  2.30it/s]


Average Validation Loss: 0.7860
Validation loss improved. Saving the model.
Training and validation losses saved to /content/drive/MyDrive/FINAL_training_losses_CoCo.pth


Epoch 13/60: 100%|██████████| 334/334 [04:58<00:00,  1.12it/s]


End of Epoch 13,Time: 298.65s, Current LR: 0.000009, Average Loss: 0.7686


Epoch 13/60 [Validation]: 100%|██████████| 84/84 [00:36<00:00,  2.29it/s]


Average Validation Loss: 0.8056
Training and validation losses saved to /content/drive/MyDrive/FINAL_training_losses_CoCo.pth


Epoch 14/60: 100%|██████████| 334/334 [04:58<00:00,  1.12it/s]


End of Epoch 14,Time: 298.33s, Current LR: 0.000009, Average Loss: 0.7570


Epoch 14/60 [Validation]: 100%|██████████| 84/84 [00:36<00:00,  2.30it/s]


Average Validation Loss: 0.7761
Validation loss improved. Saving the model.
Training and validation losses saved to /content/drive/MyDrive/FINAL_training_losses_CoCo.pth


Epoch 15/60: 100%|██████████| 334/334 [04:57<00:00,  1.12it/s]


End of Epoch 15,Time: 297.40s, Current LR: 0.000009, Average Loss: 0.7535


Epoch 15/60 [Validation]: 100%|██████████| 84/84 [00:36<00:00,  2.28it/s]


Average Validation Loss: 0.8043
Training and validation losses saved to /content/drive/MyDrive/FINAL_training_losses_CoCo.pth


Epoch 16/60: 100%|██████████| 334/334 [04:57<00:00,  1.12it/s]


End of Epoch 16,Time: 297.52s, Current LR: 0.000009, Average Loss: 0.7526


Epoch 16/60 [Validation]: 100%|██████████| 84/84 [00:36<00:00,  2.29it/s]


Average Validation Loss: 0.7748
Validation loss improved. Saving the model.
Training and validation losses saved to /content/drive/MyDrive/FINAL_training_losses_CoCo.pth


Epoch 17/60:  37%|███▋      | 123/334 [01:50<03:07,  1.13it/s]

### Run 2

In [None]:
trainLoss, valLoss = train_knowledge_distillation(
      teacher=teacher_model,
      student=student_model,
      train_dataloader=train_dataloader,
      val_dataloader=val_dataloader,
      criterion=criterion,
      optimizer=student_optimizer,
      epochs=num_epochs,
      scheduler=scheduler,
      device=device,
      checkpoint_dir=checkpoint_dir
  )


# Evaluation

### On training

In [None]:
# Load image
image_path = "/content/drive/MyDrive/Coco/test2017/000000000001.jpg"
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert to RGB

input_tensor = eval_transform(Image.fromarray(image)).unsqueeze(0).to(device)

# Set models to evaluation mode
teacher_model.eval() # Use the DepthModel instance
student_model.eval() # Use the DepthModel instance


with torch.no_grad():
    # Student prediction (before training) using the DepthModel instance
    start_time = time.time()
    student_output_after, Sfeat = student_model(input_tensor)
    end_time = time.time()
    inference_time_ms = (end_time - start_time) * 1000
    print(f"✅ Student model inference time: {inference_time_ms:.2f} ms")

    student_output_after_training = student_output_after.squeeze().cpu().numpy()

    # Teacher prediction using the DepthModel instance
    teacher_depth, Tfeat = teacher_model(input_tensor)
    print(teacher_depth.shape)
    teacher_depth = teacher_depth.squeeze().cpu().numpy()

# loss = distillation_criterion(student_output_after_training, teacher_depth)
# print(loss.item())
#Befor training
plt.figure(figsize=(15, 5))

# Original Image
plt.subplot(1, 3, 1)
plt.imshow(image)
plt.title("Original Image")
plt.axis("off")

# Teacher Depth Map
plt.subplot(1, 3, 2)
plt.imshow(teacher_depth, cmap="viridis")
plt.title("Teacher Depth Estimation")
plt.axis("off")

# Student Depth Map
plt.subplot(1, 3, 3)
plt.imshow(student_depth_before_training_train_image, cmap="viridis")
plt.title("Student Depth Estimation (Before Training)")
plt.axis("off")

plt.show()

#After training
plt.figure(figsize=(15, 5))

# Original Image
plt.subplot(1, 3, 1)
plt.imshow(image)
plt.title("Original Image")
plt.axis("off")

# Teacher Depth Map
plt.subplot(1, 3, 2)
plt.imshow(teacher_depth, cmap="viridis")
plt.title("Teacher Depth Estimation")
plt.axis("off")

# Student Depth Map
plt.subplot(1, 3, 3)
plt.imshow(student_output_after_training, cmap="viridis")
plt.title("Student Depth Estimation (After Training)")
plt.axis("off")

plt.show()



### On Testing

In [None]:
# Load image
image_path = "/content/test.jpg"
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert to RGB

input_tensor = eval_transform(Image.fromarray(image)).unsqueeze(0).to(device)

# Set models to evaluation mode
teacher_model.eval() # Use the DepthModel instance
student_model.eval() # Use the DepthModel instance

loss = 0;
with torch.no_grad():
    # Student prediction (before training) using the DepthModel instance
    student_output_after, Sfeat = student_model(input_tensor)
    student_output_after_training = student_output_after.squeeze().cpu().numpy()

    # Teacher prediction using the DepthModel instance
    teacher_depth, Tfeat = teacher_model(input_tensor)
    # loss = distillation_criterion(student_output_after, teacher_depth)
    teacher_depth = teacher_depth.squeeze().cpu().numpy()

# print(loss.item())
#Befor training
plt.figure(figsize=(15, 5))

# Original Image
plt.subplot(1, 3, 1)
plt.imshow(image)
plt.title("Original Image")
plt.axis("off")

# Teacher Depth Map
plt.subplot(1, 3, 2)
plt.imshow(teacher_depth, cmap="viridis")
plt.title("Teacher Depth Estimation")
plt.axis("off")

# Student Depth Map
plt.subplot(1, 3, 3)
plt.imshow(student_depth_before_training, cmap="viridis")
plt.title("Student Depth Estimation (Before Training)")
plt.axis("off")

plt.show()

#After training
plt.figure(figsize=(15, 5))

# Original Image
plt.subplot(1, 3, 1)
plt.imshow(image)
plt.title("Original Image")
plt.axis("off")

# Teacher Depth Map
plt.subplot(1, 3, 2)
plt.imshow(teacher_depth, cmap="viridis")
plt.title("Teacher Depth Estimation")
plt.axis("off")

# Student Depth Map
plt.subplot(1, 3, 3)
plt.imshow(student_output_after_training, cmap="viridis")
plt.title("Student Depth Estimation (After Training)")
plt.axis("off")

plt.show()

### Export To onnx

In [None]:
%pip install onnx

Collecting onnx
  Downloading onnx-1.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.9 kB)
Downloading onnx-1.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.6/17.6 MB[0m [31m107.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: onnx
Successfully installed onnx-1.18.0


In [None]:
import onnx
dummy_input = torch.randn(1, 3, 384, 384).to(device)
onnx_model_path = '/content/drive/MyDrive/LastWrapped.onnx'
print(f"Exporting model to ONNX format at {onnx_model_path}...")


student_model.eval()

torch.onnx.export(
    student_model,                # The model to export
    dummy_input,                 # A sample input tensor
    onnx_model_path,             # Where to save the model
    export_params=True,          # Store the trained parameter weights inside the model file
    opset_version=14,            # The ONNX version to use (11, 12 are good choices)
    do_constant_folding=True,    # A performance optimization
    input_names=['input'],       # A name for the model's input
    output_names=['output_depth'], # A name for the model's output
    dynamic_axes={               # Allows for variable input image sizes
        'input': {0: 'batch_size', 2: 'height', 3: 'width'},
        'output_depth': {0: 'batch_size', 2: 'height', 3: 'width'}
    }
)

print("\nONNX export complete!")
print(f"Model saved to: {onnx_model_path}")
print("\nNext step: You can now use a tool like 'onnx-tf' to convert this .onnx file to a TensorFlow SavedModel, and then to TFLite.")

Exporting model to ONNX format at /content/drive/MyDrive/LastWrapped.onnx...


  new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w
  if new_h != H or new_w != W:



ONNX export complete!
Model saved to: /content/drive/MyDrive/LastWrapped.onnx

Next step: You can now use a tool like 'onnx-tf' to convert this .onnx file to a TensorFlow SavedModel, and then to TFLite.
