# Set Up the Environment

### Install Dependencies 

In [None]:
# Install PyTorch and dependencies
# NOTE(review): versions are unpinned; pin them (e.g. `%pip install -q torch==X.Y`) for reproducible runs.
%pip install torch torchvision

# Install Hugging Face Transformers for model loading
%pip install transformers

# Install dataset tools
# NOTE(review): `datasets` is not imported in the visible notebook code — confirm it is needed.
%pip install datasets
# OpenCV (imported as cv2), used to read the evaluation image
%pip install opencv-python

### Import Libraries and Mount Google Drive

In [None]:
# Mount Google Drive (Colab only) so the training images stored under
# /content/drive/MyDrive/ can be read by the unlabeled-image dataset.
from google.colab import drive
drive.mount('/content/drive')

import copy
import os
import time
import warnings

import cv2
import matplotlib.pyplot as plt
import numpy as np
import PIL
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import AutoModelForDepthEstimation
from transformers import pipeline

# Define Needed Classes & Functions

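Wrapper class that runs a Hugging Face depth model and returns its predicted depth map, plus selected hidden states when requested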

In [None]:
class DepthModel(nn.Module):
    """Thin wrapper around a Hugging Face depth-estimation model.

    Calls the wrapped model and returns its ``predicted_depth``. When
    ``features_to_extract`` is given, it also returns the selected entries of
    ``outputs.hidden_states`` as ``(predicted_depth, [features...])``; the
    wrapped model must then be loaded with ``output_hidden_states=True``.

    Args:
        mod: model whose output exposes ``predicted_depth`` and, optionally,
            ``hidden_states``.
        features_to_extract: indices into ``outputs.hidden_states`` to return.
            ``None`` (default) returns only the depth tensor.
    """

    def __init__(self, mod, features_to_extract=None):
        super().__init__()
        self.model = mod
        # Start in eval mode; the training loop switches the student back to train().
        self.eval()
        self.features_to_extract = features_to_extract

    def forward(self, x):
        outputs = self.model(x)
        predicted_depth = outputs.predicted_depth

        if self.features_to_extract is None:
            return predicted_depth

        model_name = type(self.model).__name__
        hidden_states = getattr(outputs, "hidden_states", None)
        extracted_features = []
        if hidden_states is None:
            # Without hidden states feature distillation silently does nothing, so say so.
            warnings.warn(
                f"{model_name} returned no hidden_states; load it with "
                f"output_hidden_states=True to enable feature distillation."
            )
        else:
            for i in self.features_to_extract:
                if 0 <= i < len(hidden_states):
                    extracted_features.append(hidden_states[i])
                else:
                    # warnings (not print) so this is not repeated for every batch.
                    warnings.warn(
                        f"Feature index {i} out of bounds for {model_name} "
                        f"({len(hidden_states)} hidden states)."
                    )

        return predicted_depth, extracted_features


Class for Dataset Loading & Preprocessing

In [None]:
class UnlabeledImageDataset(Dataset):
    """Image-only dataset over the image files directly inside ``root_dir``.

    Args:
        root_dir: directory to scan (non-recursive).
        transform: optional callable applied to each loaded RGB PIL image.
        extensions: accepted filename suffixes, matched case-insensitively.

    Each item is an RGB ``PIL.Image``, or whatever ``transform`` returns for it.
    """

    def __init__(self, root_dir, transform=None, extensions=('.png', '.jpg', '.jpeg')):
        self.root_dir = root_dir
        self.transform = transform
        suffixes = tuple(ext.lower() for ext in extensions)
        # Sorted so the index -> file mapping does not depend on os.listdir order;
        # isfile() skips directories whose names happen to end in an image suffix.
        self.image_paths = sorted(
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith(suffixes) and os.path.isfile(os.path.join(root_dir, name))
        )
        print(f"Found {len(self.image_paths)} images in {root_dir}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # The context manager closes the file; convert() returns a loaded copy.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)
        return image

Distillation Loss

In [None]:

class DepthDistillationLoss(nn.Module):
    """Knowledge-distillation loss between student and teacher depth models.

    total = lambda_depth * L1(student_depth, teacher_depth)
          + lambda_features * mean_i MSE(student_feature_i, teacher_feature_i)

    Either output may be a bare depth tensor or a ``(depth, [features])`` tuple
    (as returned by ``DepthModel`` with ``features_to_extract`` set). The feature
    term is only computed when both outputs are tuples and ``lambda_features > 0``.
    """

    def __init__(self, lambda_depth=1.0, lambda_features=0.0):
        super().__init__()
        self.lambda_depth = lambda_depth
        self.lambda_features = lambda_features
        self.mae_loss = nn.L1Loss()  # Mean Absolute Error for depth maps
        self.mse_loss = nn.MSELoss()  # Mean Squared Error for features

    def _feature_loss(self, sf, tf):
        """MSE for one student/teacher feature pair, aligning shapes first.

        * Identical shapes: plain MSE.
        * 4-D maps (B, C, H, W): the student map is bilinearly resized to the
          teacher's spatial size (original behaviour).
        * 3-D tensors with equal (B, N) but different last dimension, e.g. ViT
          token sequences (B, N, C) from a small vs. a large backbone: compares
          the token-to-token cosine-similarity matrices (B, N, N), which do not
          depend on the channel width.

        Raises:
            ValueError: for any other shape mismatch, instead of failing deep
                inside ``F.interpolate`` or the MSE loss.
        """
        if sf.shape == tf.shape:
            return self.mse_loss(sf, tf)
        if sf.dim() == 4 and tf.dim() == 4:
            sf = F.interpolate(sf, size=tf.shape[2:], mode='bilinear', align_corners=False)
            return self.mse_loss(sf, tf)
        if sf.dim() == 3 and tf.dim() == 3 and sf.shape[:2] == tf.shape[:2]:
            sf_n = F.normalize(sf, dim=-1)
            tf_n = F.normalize(tf, dim=-1)
            return self.mse_loss(sf_n @ sf_n.transpose(1, 2), tf_n @ tf_n.transpose(1, 2))
        raise ValueError(
            f"Cannot align student feature of shape {tuple(sf.shape)} "
            f"with teacher feature of shape {tuple(tf.shape)}."
        )

    def forward(self, student_outputs, teacher_outputs):
        """Return the weighted distillation loss as a scalar tensor.

        Raises:
            ValueError: if the student and teacher feature lists differ in
                length, or a feature pair cannot be shape-aligned.
        """
        student_depth = student_outputs[0] if isinstance(student_outputs, tuple) else student_outputs
        teacher_depth = teacher_outputs[0] if isinstance(teacher_outputs, tuple) else teacher_outputs

        total_loss = torch.tensor(0.0, device=student_depth.device)

        # 1. Depth map loss (MAE)
        if self.lambda_depth > 0:
            if (student_depth.shape != teacher_depth.shape
                    and student_depth.dim() == 3 and teacher_depth.dim() == 3):
                # (B, H, W) maps at different resolutions: resize student to teacher.
                student_depth = F.interpolate(
                    student_depth.unsqueeze(1), size=teacher_depth.shape[-2:],
                    mode='bilinear', align_corners=False,
                ).squeeze(1)
            total_loss = total_loss + self.lambda_depth * self.mae_loss(student_depth, teacher_depth)

        # 2. Feature loss (MSE) - only when both sides provide features
        if self.lambda_features > 0 and isinstance(student_outputs, tuple) and isinstance(teacher_outputs, tuple):
            student_features = student_outputs[1]
            teacher_features = teacher_outputs[1]
            if len(student_features) != len(teacher_features):
                raise ValueError("Number of student and teacher feature lists must match.")

            # Empty lists (e.g. a model returned no hidden_states) would divide by zero.
            if len(student_features) > 0:
                loss_features = sum(
                    self._feature_loss(sf, tf) for sf, tf in zip(student_features, teacher_features)
                )
                total_loss = total_loss + self.lambda_features * (loss_features / len(student_features))

        return total_loss

The Training Function

In [None]:
def train_knowledge_distillation(teacher, student, dataloader, criterion, optimizer, epochs, device,
                                 checkpoint_every=5, log_every=50, save_dir=None):
    """Distill ``teacher`` into ``student`` on unlabeled batches from ``dataloader``.

    The teacher is run in eval mode under ``torch.no_grad()``; only the student
    is optimized, using ``criterion(student_outputs, teacher_outputs)``.

    Args:
        teacher, student: modules called as ``model(inputs)``.
        dataloader: iterable of input tensors with a ``len()``.
        criterion: loss taking ``(student_outputs, teacher_outputs)``.
        optimizer: optimizer over the student's parameters.
        epochs: number of passes over ``dataloader``.
        device: device the input batches are moved to.
        checkpoint_every: take a student snapshot every N epochs (falsy disables).
        log_every: print the running loss every N batches (falsy disables).
        save_dir: if given, also write the student's ``state_dict`` there at
            each checkpoint.

    Returns:
        list of independent deep copies of the student, one per checkpoint.
        Later training does not modify earlier snapshots.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader is empty; check the dataset path and image extensions.")

    teacher.eval()  # Teacher should always be in evaluation mode
    student.train()  # Student in training mode

    print(f"Starting Knowledge Distillation Training on {device}...")
    snapshots = []
    for epoch in range(epochs):
        running_loss = 0.0
        start_time = time.time()

        for batch_idx, inputs in enumerate(dataloader):
            inputs = inputs.to(device)

            optimizer.zero_grad()

            # Teacher is fixed: no graph is needed for its forward pass.
            with torch.no_grad():
                teacher_outputs = teacher(inputs)

            student_outputs = student(inputs)

            loss = criterion(student_outputs, teacher_outputs)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if log_every and (batch_idx + 1) % log_every == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Batch [{batch_idx+1}/{num_batches}], Loss: {running_loss / (batch_idx+1):.4f}")

        epoch_loss = running_loss / num_batches
        end_time = time.time()
        print(f"Epoch {epoch+1} finished. Avg Loss: {epoch_loss:.4f}, Time: {end_time - start_time:.2f}s")

        if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
            # deepcopy: appending `student` itself would store the same live object
            # every time, so all "checkpoints" would equal the final weights.
            snapshots.append(copy.deepcopy(student))
            if save_dir is not None:
                ckpt_path = os.path.join(save_dir, f"distill_any_depth_student_epoch_{epoch+1}.pth")
                torch.save(student.state_dict(), ckpt_path)
                print(f"Student model saved to {ckpt_path}")

    print("Knowledge Distillation Training Finished!")
    return snapshots

# Training Process

### Define Parameters & Models

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed PyTorch's RNGs so DataLoader shuffling is repeatable between runs.
SEED = 42
torch.manual_seed(SEED)

# --- Models ------------------------------------------------------------------
TEACHER_CHECKPOINT = "depth-anything/Depth-Anything-V2-Large-hf"
STUDENT_CHECKPOINT = "xingyang1/Distill-Any-Depth-Small-hf"

# Indices into each model's `outputs.hidden_states` that DepthModel returns for
# feature distillation. They select hidden states, not conv/ReLU layers.
# NOTE(review): the Large teacher and Small student likely have different layer
# counts, so the same index may sit at a different relative depth — confirm.
teacher_feature_layers = [0, 2]
student_feature_layers = [0, 2]

# output_hidden_states=True makes the models return the hidden states above.
# Load onto `device` (not a hard-coded "cuda") so the notebook also runs on CPU.
teacher = AutoModelForDepthEstimation.from_pretrained(TEACHER_CHECKPOINT, output_hidden_states=True).to(device)
student = AutoModelForDepthEstimation.from_pretrained(STUDENT_CHECKPOINT, output_hidden_states=True).to(device)

# The teacher is frozen and only run under torch.no_grad() during training, so it
# needs no gradients; gradient checkpointing (compute-for-memory) helps the student only.
teacher.requires_grad_(False)
student.gradient_checkpointing_enable()

# Wrap both models so each forward pass returns (depth_map, [selected hidden states]).
teacher_model = DepthModel(teacher, features_to_extract=teacher_feature_layers).to(device)
student_model = DepthModel(student, features_to_extract=student_feature_layers).to(device)

# --- Optimization --------------------------------------------------------------
LEARNING_RATE = 1e-4
student_optimizer = optim.Adam(student_model.parameters(), lr=LEARNING_RATE)

num_epochs = 5

# Distillation loss: L1 on depth maps plus 0.5 x feature MSE.
distillation_criterion = DepthDistillationLoss(lambda_depth=1.0, lambda_features=0.5)

# --- Data ----------------------------------------------------------------------
# Teacher and student receive the same preprocessed input.
input_size = (384, 384)

transform = transforms.Compose([
    transforms.Resize(input_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet normalization
])

# Unlabeled training images on the mounted Google Drive.
unlabeled_data_path = '/content/drive/MyDrive/images/'
BATCH_SIZE = 5

unlabeled_dataset = UnlabeledImageDataset(root_dir=unlabeled_data_path, transform=transform)
unlabeled_dataloader = DataLoader(unlabeled_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)



### Run the Training

In [None]:
    # Run the training
mods = train_knowledge_distillation(
    teacher=teacher_model,
    student=student_model,
    dataloader=unlabeled_dataloader,
    criterion=distillation_criterion,
    optimizer=student_optimizer,
    epochs=num_epochs,
    device=device,
)

# `mods` holds the student snapshots taken during training (one every 5 epochs).
# This call does not write any checkpoint to disk, so report what was actually kept.
print(f"Training complete. Kept {len(mods)} student snapshot(s) in memory.")


# Evaluation 

In [None]:
# Load the test image; OpenCV reads BGR, so convert to RGB for the model and for plotting.
image_path = "/content/test.jpg"
image_bgr = cv2.imread(image_path)
if image_bgr is None:  # cv2.imread returns None instead of raising on a missing/unreadable file
    raise FileNotFoundError(f"Could not read image: {image_path}")
image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

input_tensor = transform(Image.fromarray(image)).unsqueeze(0).to(device)


def predict_depth(model, x):
    """Run ``model`` once on ``x`` in eval mode and return its depth map as a NumPy array.

    Handles the output forms used in this notebook: a Hugging Face output object
    exposing ``predicted_depth``, a ``(depth, features)`` tuple from ``DepthModel``,
    or a bare depth tensor. Singleton dimensions are squeezed away.
    """
    model.eval()
    with torch.no_grad():
        outputs = model(x)
    if hasattr(outputs, "predicted_depth"):
        depth = outputs.predicted_depth
    elif isinstance(outputs, tuple):
        depth = outputs[0]
    else:
        depth = outputs
    return depth.squeeze().cpu().numpy()


# `student` is the module wrapped by `student_model`, so training updated it in place.
# Reload the pretrained checkpoint to get a genuine before-training baseline.
student_before = AutoModelForDepthEstimation.from_pretrained("xingyang1/Distill-Any-Depth-Small-hf").to(device)

# train_knowledge_distillation keeps one student snapshot every CHECKPOINT_EVERY epochs.
CHECKPOINT_EVERY = 5

panels = [
    ("Original Image", image, None),
    ("Teacher Depth Estimation", predict_depth(teacher, input_tensor), "viridis"),
    ("Student Depth Estimation (Before Training)", predict_depth(student_before, input_tensor), "viridis"),
]
for snapshot_idx, snapshot in enumerate(mods, start=1):
    panels.append((
        f"Student Depth Estimation (After {snapshot_idx * CHECKPOINT_EVERY} Epochs)",
        predict_depth(snapshot, input_tensor),
        "viridis",
    ))

fig, axes = plt.subplots(1, len(panels), figsize=(5 * len(panels), 5))
for ax, (title, picture, cmap) in zip(axes, panels):
    ax.imshow(picture, cmap=cmap)
    ax.set_title(title)
    ax.axis("off")
plt.show()