In [None]:
# Mount Drive, extract CelebA, and create DataLoaders in one cell
from google.colab import drive
import zipfile
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
from tqdm.notebook import tqdm

# 1. Mount Google Drive
drive.mount('/content/drive')

# 2. Extract CelebA ZIP (update your path)
zip_path = '/content/drive/MyDrive/img_align_celeba.zip'  # Change to your path
extract_path = '/content/celeba'

# Extract only when the image folder is missing, so the cell works on a fresh
# runtime ("Restart & Run All") and is a cheap no-op on re-runs.
# NOTE(review): assumes the archive unpacks into an 'img_align_celeba/' folder,
# which is the layout the dataset class below reads from -- confirm for your zip.
# A partially extracted folder is treated as complete; delete it to force a redo.
celeba_image_dir = os.path.join(extract_path, 'img_align_celeba')
if not os.path.isdir(celeba_image_dir):
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        for member in tqdm(zip_ref.namelist(), desc='Extracting'):
            zip_ref.extract(member, extract_path)

# 3. Dataset Class (Loads All Data Without Splitting)
class CelebADataset(Dataset):
    """CelebA images paired with their attribute vectors (no train/val split).

    ``root_dir`` must contain ``list_attr_celeba.txt`` and an
    ``img_align_celeba/`` folder holding the files named in that table.

    Args:
        root_dir: Directory holding the attribute file and the image folder.
        transform: Optional callable applied to every PIL image.
        fallback_size: ``(width, height)`` of the black placeholder image that
            is substituted when a file cannot be read. Defaults to 178x218,
            the aligned-CelebA image size, so placeholders go through the same
            transform and collate with real images.
    """

    def __init__(self, root_dir, transform=None, fallback_size=(178, 218)):
        self.root_dir = root_dir
        self.transform = transform
        self.fallback_size = fallback_size
        # sep=r'\s+' is the supported spelling of the deprecated
        # delim_whitespace=True. header=1 skips the first line of the file and
        # reads column names from the second; data rows carry one extra leading
        # field (the filename), which pandas therefore uses as the index.
        self.attributes = pd.read_csv(os.path.join(root_dir, 'list_attr_celeba.txt'),
                                      sep=r'\s+', header=1)
        self.filenames = self.attributes.index.tolist()  # Load all images
        self.attributes = (self.attributes + 1) // 2  # Convert -1/1 to 0/1

    def __len__(self):
        """Return the number of rows in the attribute table."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(image, attrs)`` for row ``idx``.

        ``image`` is the RGB PIL image (or whatever ``transform`` returns for
        it); ``attrs`` is a float32 tensor of that row's 0/1 attributes.
        Unreadable files are reported and replaced by a black placeholder.
        """
        img_path = os.path.join(self.root_dir, 'img_align_celeba', self.filenames[idx])
        try:
            # The context manager closes the file handle; convert() returns an
            # independent, fully loaded copy that outlives it.
            with Image.open(img_path) as handle:
                img = handle.convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Keep the placeholder a PIL image: the transform pipeline expects
            # one, and a raw tensor would not survive ToTensor().
            img = Image.new('RGB', self.fallback_size)

        attrs = self.attributes.iloc[idx].values.astype('float32')
        if self.transform:
            img = self.transform(img)
        return img, torch.from_numpy(attrs)

# 4. Create DataLoader for Entire Dataset
# Resize(128) with a single int scales the shorter image side to 128 and keeps
# the aspect ratio, so batches come out as 3x156x128 (see the printed shape
# below), not square 128x128.
transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# Decode JPEGs in background workers and pin host memory when a GPU is present,
# so the training/evaluation loops are not stalled by single-process loading.
full_loader = DataLoader(
    CelebADataset(extract_path, transform),
    batch_size=64, shuffle=True,
    num_workers=min(4, os.cpu_count() or 1),
    pin_memory=torch.cuda.is_available(),
)

# Verify: pull one batch and report its shape plus the dataset size
images, attrs = next(iter(full_loader))
print(f"Loaded batch: {images.shape}, {attrs.shape}")
print(f"Total samples: {len(full_loader.dataset)}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


  self.attributes = pd.read_csv(os.path.join(root_dir, 'list_attr_celeba.txt'),


Loaded batch: torch.Size([64, 3, 156, 128]), torch.Size([64, 40])
Total samples: 202599


In [None]:
import os

# Sanity check: count the directory entries in the extracted image folder.
with os.scandir('/content/celeba/img_align_celeba') as entries:
    num_images = sum(1 for _ in entries)
print(f"Number of images found: {num_images}")

Number of images found: 202599


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import models
from torch.utils.data import DataLoader
from tqdm import tqdm

# Configuration
class Config:
    """Hyperparameters and paths for the teacher -> student distillation run.

    Plain namespace of class attributes; it is never instantiated in the
    visible cells.
    """
    # NOTE(review): root_dir, csv_path and batch_size are not referenced by the
    # visible cells -- training reuses the DataLoader built in the first cell.
    # The first cell passes '/content/celeba' as the dataset root (the dataset
    # appends 'img_align_celeba' itself), so root_dir is one level deeper.
    root_dir = '/content/celeba/img_align_celeba'  # Image folder; differs from the first cell's dataset root
    csv_path = '/content/celeba/list_attr_celeba.txt'
    batch_size = 64
    # Learning rate passed to the Adam optimizer.
    lr = 1e-4
    # Default divisor applied to logits in distillation_loss.
    temperature = 2.0
    # Weight of the distillation term; the attribute loss gets (1 - alpha).
    alpha = 0.7
    # Output width of both the teacher and the student classification heads.
    num_classes = 40
    # Number of passes over train_loader.
    max_epochs = 10
    # File the student's state_dict is written to when the epoch loss improves.
    model_save_path = 'distilled_student.pth'
    # Weights used to initialise the ResNet-50 teacher before its checkpoint is loaded.
    pretrained_weights = models.ResNet50_Weights.IMAGENET1K_V2

# Use the same DataLoader from previous cell (full_loader)
train_loader = full_loader  # Reuse the already loaded DataLoader

# Load Teacher Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher = models.resnet50(weights=Config.pretrained_weights)

# Modify FC layer to match expected output
teacher.fc = nn.Sequential(
    nn.Linear(teacher.fc.in_features, 512),
    nn.ReLU(),
    nn.Linear(512, Config.num_classes)
)

# Load pre-trained weights
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust
# (consider weights_only=True if the file holds nothing but tensors).
checkpoint = torch.load('/content/drive/MyDrive/best_model.pth', map_location=device)
if isinstance(checkpoint, dict):
    # Accept training checkpoints that nest the weights under a wrapper key.
    for wrapper_key in ('model_state_dict', 'state_dict'):
        if isinstance(checkpoint.get(wrapper_key), dict):
            checkpoint = checkpoint[wrapper_key]
            break
    # Drop the 'module.' prefix that nn.DataParallel adds to parameter names.
    checkpoint = {
        (key[len('module.'):] if isinstance(key, str) and key.startswith('module.') else key): value
        for key, value in checkpoint.items()
    }

# strict=False silently skips keys that do not line up, which would leave those
# layers at their ImageNet / random initialisation. Surface any mismatch so a
# badly loaded teacher is noticed before it is used as the distillation target.
load_result = teacher.load_state_dict(checkpoint, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"WARNING: teacher checkpoint only partially loaded: "
          f"{len(load_result.missing_keys)} missing keys (e.g. {load_result.missing_keys[:5]}), "
          f"{len(load_result.unexpected_keys)} unexpected keys (e.g. {load_result.unexpected_keys[:5]})")
teacher = teacher.to(device).eval()
teacher.requires_grad_(False)  # The teacher is never optimised; skip its gradient bookkeeping

# Load Student Model
student = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
student.fc = nn.Linear(student.fc.in_features, Config.num_classes)
student = student.to(device).train()

# Loss functions
def distillation_loss(student_logits, teacher_logits, temperature=Config.temperature):
    """Binary cross-entropy between temperature-softened student and teacher outputs.

    Both sets of logits are divided by ``temperature``; the teacher's sigmoid
    probabilities act as soft targets for the student. The result is not
    rescaled by ``temperature ** 2``.

    Args:
        student_logits: Raw student outputs.
        teacher_logits: Raw teacher outputs with the same shape.
        temperature: Divisor applied to both logits before the sigmoid.

    Returns:
        Scalar tensor holding the mean loss over all elements.
    """
    # Targets are constants: detach so no gradient can flow back into the teacher.
    soft_teacher = torch.sigmoid(teacher_logits.detach() / temperature)
    # The *_with_logits form fuses the sigmoid into the loss, which stays
    # numerically stable when logits are large, unlike sigmoid followed by BCE.
    return F.binary_cross_entropy_with_logits(student_logits / temperature, soft_teacher)

def attribute_loss(student_logits, targets):
    """Mean binary cross-entropy of the raw student logits against ``targets``."""
    hard_label_bce = F.binary_cross_entropy_with_logits(
        input=student_logits,
        target=targets,
        reduction='mean',
    )
    return hard_label_bce

# Optimizer: only the student's parameters are updated
optimizer = optim.Adam(params=student.parameters(), lr=Config.lr)

# Training loop
# NOTE: despite its name, best_val_loss holds the lowest average *training*
# loss seen so far -- this notebook has no validation split.
best_val_loss = float('inf')
for epoch in range(Config.max_epochs):
    student.train()
    total_loss = 0
    progress = tqdm(train_loader, desc=f'Epoch {epoch+1}')
    for images, attrs in progress:
        images = images.to(device)
        attrs = attrs.to(device)

        # Teacher forward pass runs without autograd; its logits are targets only.
        with torch.no_grad():
            teacher_logits = teacher(images)

        optimizer.zero_grad()
        student_logits = student(images)

        # Blend the soft-target term with the ground-truth attribute term.
        soft_term = distillation_loss(student_logits, teacher_logits)
        hard_term = attribute_loss(student_logits, attrs)
        loss = Config.alpha * soft_term + (1 - Config.alpha) * hard_term

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} | Train Loss: {avg_loss:.4f}")

    # Checkpoint the student only when this epoch beat the best loss so far
    if not avg_loss < best_val_loss:
        continue
    best_val_loss = avg_loss
    torch.save(student.state_dict(), Config.model_save_path)
    print("Best model saved!")


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 88.1MB/s]
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 71.0MB/s]
Epoch 1: 100%|██████████| 3166/3166 [08:26<00:00,  6.26it/s]


Epoch 1 | Train Loss: 0.6188
Best model saved!


Epoch 2: 100%|██████████| 3166/3166 [08:23<00:00,  6.29it/s]


Epoch 2 | Train Loss: 0.6155
Best model saved!


Epoch 3: 100%|██████████| 3166/3166 [08:22<00:00,  6.30it/s]


Epoch 3 | Train Loss: 0.6141
Best model saved!


Epoch 4: 100%|██████████| 3166/3166 [08:23<00:00,  6.28it/s]


Epoch 4 | Train Loss: 0.6128
Best model saved!


Epoch 5: 100%|██████████| 3166/3166 [08:22<00:00,  6.29it/s]


Epoch 5 | Train Loss: 0.6111
Best model saved!


Epoch 6: 100%|██████████| 3166/3166 [08:24<00:00,  6.28it/s]


Epoch 6 | Train Loss: 0.6092
Best model saved!


Epoch 7: 100%|██████████| 3166/3166 [08:23<00:00,  6.29it/s]


Epoch 7 | Train Loss: 0.6072
Best model saved!


Epoch 8: 100%|██████████| 3166/3166 [08:24<00:00,  6.27it/s]


Epoch 8 | Train Loss: 0.6051
Best model saved!


Epoch 9: 100%|██████████| 3166/3166 [08:26<00:00,  6.25it/s]


Epoch 9 | Train Loss: 0.6032
Best model saved!


Epoch 10: 100%|██████████| 3166/3166 [08:25<00:00,  6.26it/s]


Epoch 10 | Train Loss: 0.6016
Best model saved!


In [None]:
# Function to calculate accuracy
def accuracy(predictions, targets):
    """Fraction of elements whose thresholded prediction equals the target.

    ``predictions`` are raw logits; they are passed through a sigmoid and
    thresholded at 0.5 before being compared element-wise with ``targets``.
    Returns a 0-dim tensor.
    """
    probabilities = torch.sigmoid(predictions)
    hard_preds = probabilities.gt(0.5).to(torch.float32)
    hits = hard_preds.eq(targets).to(torch.float32)
    return hits.sum() / hits.numel()

# Function to evaluate a model (either Teacher or Student)
def evaluate_model(model, data_loader, device, name=None):
    """Compute element-wise attribute accuracy of ``model`` over ``data_loader``.

    Args:
        model: Module mapping an image batch to per-attribute logits.
        data_loader: Iterable yielding ``(images, attrs)`` batches.
        device: Device the batches are moved to before the forward pass.
        name: Optional label for the progress bar; defaults to the model's
            class name (which is identical for two models of the same class).

    Returns:
        0-dim tensor: per-batch accuracies averaged with batch-size weights.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no batches.
    """
    label = name if name is not None else model.__class__.__name__
    was_training = model.training  # Remember the mode so evaluation has no lasting side effect
    model.eval()  # Set the model to evaluation mode
    total_correct = 0
    total_samples = 0
    try:
        with torch.no_grad():  # Disable gradient calculation during evaluation
            for images, attrs in tqdm(data_loader, desc=f'Evaluating {label}'):
                images, attrs = images.to(device), attrs.to(device)

                # Forward pass through the model
                logits = model(images)

                # Weight each batch's accuracy by its size so a short final
                # batch does not skew the average.
                total_correct += accuracy(logits, attrs) * attrs.size(0)
                total_samples += attrs.size(0)
    finally:
        # Restore train/eval mode even if evaluation is interrupted.
        model.train(was_training)

    # Final accuracy
    accuracy_value = total_correct / total_samples
    return accuracy_value

# Score the teacher on every sample served by full_loader.
# NOTE: full_loader is the same data the student was trained on, so these
# numbers are training-set accuracies, not held-out performance.
teacher_accuracy = evaluate_model(model=teacher, data_loader=full_loader, device=device)
print("Teacher Model Accuracy: {:.4f}".format(teacher_accuracy))

# Same pass for the distilled student
student_accuracy = evaluate_model(model=student, data_loader=full_loader, device=device)
print("Student Model Accuracy: {:.4f}".format(student_accuracy))


Evaluating ResNet: 100%|██████████| 3166/3166 [05:29<00:00,  9.61it/s]


Teacher Model Accuracy: 0.5224


Evaluating ResNet: 100%|██████████| 3166/3166 [05:10<00:00, 10.19it/s]

Student Model Accuracy: 0.9816



