In [None]:
from google.colab import drive
import zipfile
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import pandas as pd
from tqdm.notebook import tqdm

# Step 1: expose the user's Google Drive under /content/drive.
drive.mount('/content/drive')

# Step 2: unpack the CelebA archive onto the local disk.
zip_path = '/content/drive/MyDrive/img_align_celeba.zip'  # Point this at your copy of the archive
extract_path = '/content/celeba'

# The mere existence of the target directory is treated as "already extracted".
if not os.path.exists(extract_path):
    os.makedirs(extract_path, exist_ok=True)
    with zipfile.ZipFile(zip_path, 'r') as archive:
        # Extract member by member so tqdm can display progress.
        for member in tqdm(archive.namelist(), desc='Extracting'):
            archive.extract(member, extract_path)

# 3. Dataset Class with Splitting Capability
class CelebADataset(Dataset):
    """CelebA face images paired with their binary attribute vectors.

    ``root_dir`` must contain ``list_attr_celeba.txt``,
    ``list_eval_partition.txt`` and an ``img_align_celeba/`` image folder.

    Args:
        root_dir: Directory holding the two annotation files and the image folder.
        split: ``'train'``, ``'valid'`` or ``'test'`` keeps only the files whose
            partition code in ``list_eval_partition.txt`` is 0, 1 or 2
            respectively. Any other value keeps every file listed in the
            attribute table.
        transform: Optional callable applied to the loaded PIL image.

    Each item is ``(image, attrs)`` where ``attrs`` is a float32 tensor of 0/1
    flags, one per attribute column.

    Raises:
        KeyError: If a file from the attribute table is missing from the
            partition table while a known split is requested.
    """

    # Partition codes used in list_eval_partition.txt.
    _PARTITION_CODES = {'train': 0, 'valid': 1, 'test': 2}

    # (width, height) of the black placeholder returned for unreadable images.
    # Assumed to match the aligned CelebA JPEGs (178x218) so placeholders
    # batch with real images after the transform; adjust if your copy differs.
    _PLACEHOLDER_SIZE = (178, 218)

    def __init__(self, root_dir, split='train', transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # Line 0 is skipped, line 1 holds the attribute names. Data rows have
        # one extra leading field (the filename), which pandas uses as index.
        # sep=r'\s+' replaces the deprecated delim_whitespace=True.
        self.attributes = pd.read_csv(os.path.join(root_dir, 'list_attr_celeba.txt'),
                                      sep=r'\s+', header=1)
        self.filenames = self.attributes.index.tolist()
        self.attributes = (self.attributes + 1) // 2  # Convert -1/1 to 0/1

        # Load the split file: one "<filename> <partition code>" row per image.
        split_file = os.path.join(root_dir, 'list_eval_partition.txt')
        split_info = pd.read_csv(split_file, sep=r'\s+', header=None,
                                 names=['filename', 'partition'], index_col='filename')

        # Row-aligned label matrix so __getitem__ is a plain array lookup
        # instead of a per-sample DataFrame .loc.
        labels = self.attributes.to_numpy(dtype='float32')

        # Filter with a single vectorised lookup rather than one .loc per file.
        code = self._PARTITION_CODES.get(split)
        if code is not None:
            in_split = split_info.loc[self.filenames, 'partition'].to_numpy() == code
            self.filenames = [name for name, keep in zip(self.filenames, in_split) if keep]
            labels = labels[in_split]
        self._labels = labels

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, 'img_align_celeba', self.filenames[idx])
        try:
            # The context manager closes the file handle once decoding is done.
            with Image.open(img_path) as raw:
                img = raw.convert('RGB')
        except (OSError, ValueError) as e:
            print(f"Error loading image {img_path}: {e}")
            # Fall back to a blank *PIL* image so the same transform pipeline
            # applies; a raw tensor would be rejected by ToTensor.
            img = Image.new('RGB', self._PLACEHOLDER_SIZE)

        if self.transform:
            img = self.transform(img)

        # Copy so the returned tensor does not alias the cached label matrix.
        attrs = torch.from_numpy(self._labels[idx].copy())
        return img, attrs

# 4. Preprocessing pipeline: resize, convert to tensor, then normalise each
#    channel with mean 0.5 / std 0.5.
_preprocess_steps = [
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_preprocess_steps)

# 5. One dataset per split, all sharing the same preprocessing.
train_dataset, valid_dataset, test_dataset = (
    CelebADataset(extract_path, split=split_name, transform=transform)
    for split_name in ('train', 'valid', 'test')
)

# Only the training loader shuffles; evaluation loaders keep file order.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_loader, test_loader = (
    DataLoader(eval_dataset, batch_size=64, shuffle=False)
    for eval_dataset in (valid_dataset, test_dataset)
)

# Sanity check: pull one batch and report tensor shapes and split sizes.
images, attrs = next(iter(train_loader))
print(f"Train batch: {images.shape}, {attrs.shape}")
print(f"Train samples: {len(train_loader.dataset)}, Valid samples: {len(valid_loader.dataset)}, Test samples: {len(test_loader.dataset)}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


  self.attributes = pd.read_csv(os.path.join(root_dir, 'list_attr_celeba.txt'),
  split_info = pd.read_csv(split_file, delim_whitespace=True, header=None, index_col=0, names=['partition'])
  self.attributes = pd.read_csv(os.path.join(root_dir, 'list_attr_celeba.txt'),
  split_info = pd.read_csv(split_file, delim_whitespace=True, header=None, index_col=0, names=['partition'])
  self.attributes = pd.read_csv(os.path.join(root_dir, 'list_attr_celeba.txt'),
  split_info = pd.read_csv(split_file, delim_whitespace=True, header=None, index_col=0, names=['partition'])


Train batch: torch.Size([64, 3, 156, 128]), torch.Size([64, 40])
Train samples: 162770, Valid samples: 19867, Test samples: 19962


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from tqdm import tqdm
import numpy as np

# Configuration
class Config:
    """Namespace of hyper-parameters and constants shared by the training code."""

    # Use CUDA when available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Number of output logits of both models (one per attribute name below).
    num_classes = 40
    # NOTE(review): not referenced by the visible code; the loaders above are
    # built with batch_size=64.
    batch_size = 128

    # Optimisation settings.
    lr = 3e-4
    max_epochs = 30
    patience = 7          # epochs without validation improvement before early stop
    grad_clip = 1.0       # max gradient norm passed to clip_grad_norm_
    grad_accum_steps = 2  # micro-batches accumulated per optimizer step

    # Distillation settings passed to DistillationLoss.
    temperature = 1.0  # Start with simpler temperature
    alpha = 0.5        # Balanced weight between teacher and ground truth

    # Per-channel statistics (presumably the ImageNet mean/std; TODO confirm).
    # NOTE(review): not referenced by the visible code.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # Attribute labels, in output-index order.
    attribute_names = """
        5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes
        Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair
        Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin
        Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones
        Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard
        Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks
        Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings
        Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young
    """.split()

# Teacher Model (ResNet50)
class TeacherModel(nn.Module):
    """ResNet-50 with its classifier replaced by a two-layer head.

    The network is built from the IMAGENET1K_V2 weights. Only ``layer3``,
    ``layer4`` and the new head are trainable; every other backbone parameter
    is frozen.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=40):
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

        # Freeze the whole network, then re-enable the two deepest stages.
        resnet.requires_grad_(False)
        for stage in (resnet.layer3, resnet.layer4):
            stage.requires_grad_(True)

        # The head is created after the freeze, so its parameters stay trainable.
        feature_dim = resnet.fc.in_features
        resnet.fc = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(feature_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, num_classes),
        )
        self.backbone = resnet

    def forward(self, x):
        """Return the raw (pre-sigmoid) logits for the batch ``x``."""
        logits = self.backbone(x)
        return logits

# Student Model (ResNet18)
class StudentModel(nn.Module):
    """ResNet-18 (IMAGENET1K_V1 weights) with a two-layer classification head.

    Unlike the teacher, no backbone parameters are frozen here.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=40):
        super().__init__()
        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

        # Replace the stock classifier with a dropout-regularised MLP head.
        feature_dim = resnet.fc.in_features
        resnet.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        )
        self.backbone = resnet

    def forward(self, x):
        """Return the raw (pre-sigmoid) logits for the batch ``x``."""
        logits = self.backbone(x)
        return logits

# Proper Distillation Loss for Multi-Label Classification
class DistillationLoss(nn.Module):
    """Distillation loss for multi-label (independent sigmoid) outputs.

    The loss is ``alpha * KL + (1 - alpha) * BCE`` where

    * ``BCE`` is binary cross-entropy between the unscaled student logits and
      the ground-truth targets, and
    * ``KL`` is the mean Bernoulli KL divergence ``KL(teacher || student)``
      between the temperature-softened per-label probabilities.

    Args:
        temperature: Positive divisor applied to both sets of logits before
            the sigmoid.
        alpha: Weight of the KL term; ``1 - alpha`` weights the BCE term.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """

    def __init__(self, temperature, alpha):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        self.temperature = temperature
        self.alpha = alpha
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, student_logits, teacher_logits, targets):
        """Compute the combined loss.

        Args:
            student_logits: Student logits; gradients flow through these.
            teacher_logits: Teacher logits of the same shape; treated as constants.
            targets: Ground-truth 0/1 labels of the same shape.

        Returns:
            ``(total_loss, components)`` where ``total_loss`` is a scalar
            tensor and ``components`` is a dict of Python floats with keys
            ``'kl_loss'`` and ``'bce_loss'``.
        """
        # Work in float32: under autocast the logits may be fp16, where tiny
        # probabilities underflow to 0 and log() produces inf/NaN.
        student_logits = student_logits.float()
        teacher_logits = teacher_logits.detach().float()

        # BCE loss with original logits (not temperature-scaled)
        bce_loss = self.bce_loss(student_logits, targets.float())

        soft_student = student_logits / self.temperature
        soft_teacher = teacher_logits / self.temperature

        # log(sigmoid(x)) and log(1 - sigmoid(x)) == log(sigmoid(-x)), computed
        # stably from the logits instead of log(prob + eps).
        log_p_pos = nn.functional.logsigmoid(soft_teacher)
        log_p_neg = nn.functional.logsigmoid(-soft_teacher)
        log_q_pos = nn.functional.logsigmoid(soft_student)
        log_q_neg = nn.functional.logsigmoid(-soft_student)
        p_pos = torch.sigmoid(soft_teacher)

        # Full Bernoulli KL: both the "label on" and "label off" outcomes are
        # included. Using only the p*log(p/q) half is not a divergence: it can
        # go negative and rewards pushing every student probability towards 1.
        kl_per_label = (p_pos * (log_p_pos - log_q_pos) +
                        (1.0 - p_pos) * (log_p_neg - log_q_neg))
        kl_loss = kl_per_label.mean()

        # Combined loss
        total_loss = (self.alpha * kl_loss +
                      (1 - self.alpha) * bce_loss)

        return total_loss, {'kl_loss': kl_loss.item(), 'bce_loss': bce_loss.item()}

# Training Function
def train_with_distillation(student, teacher, train_loader, val_loader,
                            checkpoint_path='best_student_model.pth'):
    """Train ``student`` to match ``teacher`` and the ground-truth labels.

    Uses AdamW, ReduceLROnPlateau on the validation loss, mixed precision
    (on CUDA only), gradient accumulation and gradient-norm clipping, all
    configured through ``Config``. The best student (lowest validation loss)
    is checkpointed and restored before returning.

    Args:
        student: Model to optimise; must already live on ``Config.device``.
        teacher: Model providing soft targets; it is switched to eval mode and
            only ever run under ``torch.no_grad()``.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        checkpoint_path: Where the best student checkpoint is written.

    Returns:
        The ``student`` module, loaded with the best checkpointed weights when
        at least one checkpoint was written during this call.

    Raises:
        ValueError: If either loader yields no batches.
    """
    num_train_batches = len(train_loader)
    num_val_batches = len(val_loader)
    if num_train_batches == 0 or num_val_batches == 0:
        # Fail fast instead of dividing by zero after a wasted epoch.
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    optimizer = optim.AdamW(
        student.parameters(),
        lr=Config.lr,
        weight_decay=1e-4
    )

    # `verbose` is deprecated; the learning rate is printed every epoch below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=3
    )

    criterion = DistillationLoss(Config.temperature, Config.alpha)

    # Mixed precision only makes sense on CUDA; on CPU both objects are no-ops.
    use_amp = Config.device.type == 'cuda'
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    # A distillation teacher must be deterministic: no dropout, and no
    # BatchNorm running-stat updates (which happen in train mode even under
    # no_grad).
    teacher.eval()

    best_val_loss = float('inf')
    patience_counter = 0
    saved_checkpoint = False

    for epoch in range(Config.max_epochs):
        student.train()
        # Start every epoch with clean gradients.
        optimizer.zero_grad(set_to_none=True)
        train_loss = 0.0
        loss_components = {'kl_loss': 0.0, 'bce_loss': 0.0}

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

        for step, (images, labels) in enumerate(pbar):
            images = images.to(Config.device, non_blocking=True)
            labels = labels.to(Config.device, non_blocking=True)

            with torch.autocast(device_type=Config.device.type, enabled=use_amp):
                student_logits = student(images)
                with torch.no_grad():
                    teacher_logits = teacher(images)
                loss, comp = criterion(student_logits, teacher_logits, labels)
                # Scale so that accumulated gradients average over the group.
                loss = loss / Config.grad_accum_steps

            scaler.scale(loss).backward()

            # Step after every full accumulation group, and also on the last
            # batch so a trailing partial group is neither dropped nor carried
            # into the next epoch.
            is_last_batch = (step + 1) == num_train_batches
            if (step + 1) % Config.grad_accum_steps == 0 or is_last_batch:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(student.parameters(), Config.grad_clip)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)

            # Undo the accumulation scaling for reporting.
            train_loss += loss.item() * Config.grad_accum_steps
            for k in loss_components:
                loss_components[k] += comp[k]
            pbar.set_postfix({
                'loss': f"{train_loss/(step+1):.4f}",
                'kl': f"{loss_components['kl_loss']/(step+1):.4f}",
                'bce': f"{loss_components['bce_loss']/(step+1):.4f}"
            })

        # Validation
        val_loss = 0.0
        val_components = {'kl_loss': 0.0, 'bce_loss': 0.0}
        student.eval()
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(Config.device)
                labels = labels.to(Config.device)

                with torch.autocast(device_type=Config.device.type, enabled=use_amp):
                    student_logits = student(images)
                    teacher_logits = teacher(images)
                    loss, comp = criterion(student_logits, teacher_logits, labels)
                val_loss += loss.item()
                for k in val_components:
                    val_components[k] += comp[k]

        avg_train_loss = train_loss / num_train_batches
        avg_val_loss = val_loss / num_val_batches

        scheduler.step(avg_val_loss)

        print(f"\nEpoch {epoch+1}/{Config.max_epochs}")
        print(f"Train Loss: {avg_train_loss:.4f} (KL: {loss_components['kl_loss']/num_train_batches:.4f}, BCE: {loss_components['bce_loss']/num_train_batches:.4f})")
        print(f"Val Loss: {avg_val_loss:.4f} (KL: {val_components['kl_loss']/num_val_batches:.4f}, BCE: {val_components['bce_loss']/num_val_batches:.4f})")
        print(f"Current LR: {optimizer.param_groups[0]['lr']:.2e}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save({
                'model': student.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'best_val_loss': best_val_loss
            }, checkpoint_path)
            saved_checkpoint = True
            print("Saved best student model")
        else:
            patience_counter += 1
            if patience_counter >= Config.patience:
                print(f"Early stopping after {epoch+1} epochs")
                break

    if saved_checkpoint:
        checkpoint = torch.load(checkpoint_path, map_location=Config.device)
        student.load_state_dict(checkpoint['model'])
    else:
        # E.g. the validation loss was NaN throughout. Do not load whatever
        # file may be left at checkpoint_path from an earlier run.
        print("Warning: validation loss never improved; returning the student from the last epoch")
    return student

# Evaluation Function
def evaluate_model(model, loader, attribute_names=None):
    """Evaluate a multi-label classifier and print a summary.

    Predictions are ``sigmoid(logits) > 0.5``.

    Args:
        model: Module returning one logit per attribute; must live on
            ``Config.device``. It is switched to eval mode.
        loader: Iterable of ``(images, labels)`` batches with 0/1 labels.
        attribute_names: Optional list of names indexed by attribute. When
            given, the ten attributes with the highest F1 are printed.

    Returns:
        ``(results, attribute_metrics)``.

        ``results`` holds:
          * ``'strict_acc'``: fraction of samples with every attribute correct.
          * ``'mean_acc'``: per-sample attribute accuracy, averaged over samples.
          * ``'top_k_acc'``: for k in 5/10/20, the average fraction of the k
            highest-scoring attributes that are truly positive (i.e.
            precision@k, despite the key name).
          * ``'per_attribute_acc'``: array of per-attribute accuracies.
          * ``'confusion_matrix'``: ``(2, 2, num_classes)`` array of counts,
            ``[0,0]=TP, [0,1]=FP, [1,0]=FN, [1,1]=TN``.

        ``attribute_metrics`` maps attribute name to accuracy, precision,
        recall, f1 and support (number of positive labels).

        With an empty ``loader`` every metric is reported as 0.
    """
    model.eval()
    num_classes = Config.num_classes
    results = {
        'strict_acc': 0.0,
        'mean_acc': 0.0,
        'top_k_acc': {5: 0.0, 10: 0.0, 20: 0.0},
        'per_attribute_acc': np.zeros(num_classes),
        'confusion_matrix': np.zeros((2, 2, num_classes))
    }
    confusion = results['confusion_matrix']

    total_samples = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(Config.device)
            labels = labels.to(Config.device).float()

            probs = torch.sigmoid(model(images))
            preds = (probs > 0.5).float()
            matches = preds == labels

            results['strict_acc'] += matches.all(dim=1).sum().item()
            results['mean_acc'] += matches.float().mean(dim=1).sum().item()

            for k in results['top_k_acc']:
                _, topk_indices = torch.topk(probs, k, dim=1)
                hits = torch.gather(labels, 1, topk_indices).sum(dim=1)
                results['top_k_acc'][k] += (hits / k).sum().item()

            # Count all attributes at once and transfer the counts in a single
            # device-to-host copy instead of several .item() calls per attribute.
            pred_pos, pred_neg = preds == 1, preds == 0
            true_pos, true_neg = labels == 1, labels == 0
            counts = torch.stack([
                (pred_pos & true_pos).sum(dim=0),   # TP
                (pred_pos & true_neg).sum(dim=0),   # FP
                (pred_neg & true_pos).sum(dim=0),   # FN
                (pred_neg & true_neg).sum(dim=0),   # TN
                matches.sum(dim=0),                 # correct predictions
            ]).cpu().numpy()
            confusion[0, 0] += counts[0]
            confusion[0, 1] += counts[1]
            confusion[1, 0] += counts[2]
            confusion[1, 1] += counts[3]
            results['per_attribute_acc'] += counts[4]

            total_samples += images.size(0)

    if total_samples == 0:
        print("Warning: evaluation loader produced no samples; all metrics are 0")
    # Avoid ZeroDivisionError on an empty loader; numerators are all 0 then.
    denom = max(total_samples, 1)

    results['strict_acc'] /= denom
    results['mean_acc'] /= denom
    for k in results['top_k_acc']:
        results['top_k_acc'][k] /= denom
    results['per_attribute_acc'] /= denom

    attribute_metrics = {}
    for attr in range(num_classes):
        tp = confusion[0, 0, attr]
        fp = confusion[0, 1, attr]
        fn = confusion[1, 0, attr]
        tn = confusion[1, 1, attr]

        # The epsilon keeps attributes with no positives from dividing by zero.
        precision = tp / (tp + fp + 1e-10)
        recall = tp / (tp + fn + 1e-10)
        f1 = 2 * (precision * recall) / (precision + recall + 1e-10)

        counted = tp + tn + fp + fn
        attr_name = attribute_names[attr] if attribute_names else f"Attr {attr}"
        attribute_metrics[attr_name] = {
            'accuracy': (tp + tn) / counted if counted > 0 else 0.0,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            # Support is a count, so report it as an integer.
            'support': int(tp + fn)
        }

    print("\n=== Evaluation Results ===")
    print(f"Strict Accuracy: {results['strict_acc']:.4f}")
    print(f"Mean Accuracy: {results['mean_acc']:.4f}")
    for k, acc in sorted(results['top_k_acc'].items()):
        print(f"Top-{k} Accuracy: {acc:.4f}")

    if attribute_names:
        print("\nPer-Attribute Metrics (Top 10 by F1 score):")
        print("-" * 85)
        print(f"{'Attribute':<25}{'Accuracy':<10}{'Precision':<10}{'Recall':<10}{'F1':<10}{'Support':<10}")
        print("-" * 85)

        sorted_attrs = sorted(attribute_metrics.items(),
                              key=lambda x: x[1]['f1'],
                              reverse=True)[:10]

        for name, metrics in sorted_attrs:
            print(f"{name:<25}{metrics['accuracy']:.4f}    {metrics['precision']:.4f}    {metrics['recall']:.4f}    {metrics['f1']:.4f}    {metrics['support']:>10}")

    return results, attribute_metrics

# Main execution
if __name__ == "__main__":
    # train_loader, valid_loader and test_loader are not created here; they
    # must already exist in the notebook namespace (the data-loading cell
    # above defines them).

    # Location of the pre-trained teacher weights; change to your own path.
    teacher_checkpoint_path = 'C:/Users/akash/Downloads/Soft.pth'

    # Load teacher model
    print("Loading teacher model...")
    teacher = TeacherModel(num_classes=Config.num_classes)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    teacher_checkpoint = torch.load(teacher_checkpoint_path, map_location=Config.device)

    # Unwrap training checkpoints that nest the weights under a key (this
    # notebook itself saves {'model': state_dict, ...}). A plain state dict
    # maps names to tensors, so a dict-valued entry identifies a wrapper.
    for wrapper_key in ('model', 'state_dict', 'model_state_dict'):
        if isinstance(teacher_checkpoint, dict) and isinstance(teacher_checkpoint.get(wrapper_key), dict):
            teacher_checkpoint = teacher_checkpoint[wrapper_key]
            break

    # Handle potential DataParallel wrapper ('module.' prefix on every key).
    if teacher_checkpoint and all(k.startswith('module.') for k in teacher_checkpoint.keys()):
        teacher_checkpoint = {k[len('module.'):]: v for k, v in teacher_checkpoint.items()}

    # strict=False tolerates partial checkpoints, but a mismatch must not pass
    # silently: an unloaded teacher would distil meaningless targets.
    load_result = teacher.load_state_dict(teacher_checkpoint, strict=False)
    num_expected = len(teacher.state_dict())
    num_missing = len(load_result.missing_keys)
    if num_missing == num_expected:
        raise RuntimeError(
            f"No parameter in {teacher_checkpoint_path!r} matches TeacherModel; "
            "check the checkpoint layout and key prefixes"
        )
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"Warning: teacher checkpoint only partially loaded "
              f"({num_missing} missing, {len(load_result.unexpected_keys)} unexpected keys)")
        print(f"  first missing keys: {load_result.missing_keys[:5]}")
        print(f"  first unexpected keys: {load_result.unexpected_keys[:5]}")

    teacher = teacher.to(Config.device)
    teacher.eval()

    # Create student model
    print("Initializing student model...")
    student = StudentModel(num_classes=Config.num_classes)
    student = student.to(Config.device)

    # Train with distillation
    print("\nStarting knowledge distillation training...")
    trained_student = train_with_distillation(student, teacher, train_loader, valid_loader)

    # Evaluate on test set
    print("\nEvaluating student model on test set:")
    test_results, attribute_metrics = evaluate_model(trained_student, test_loader, Config.attribute_names)

    # Save final student model
    torch.save(trained_student.state_dict(), 'final_student_model.pth')
    print("\nSaved final student model")

Loading teacher model...
Initializing student model...


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 118MB/s]
  scaler = torch.cuda.amp.GradScaler()



Starting knowledge distillation training...


  with torch.cuda.amp.autocast():
Epoch 1: 100%|██████████| 2544/2544 [08:21<00:00,  5.07it/s, loss=0.2638, kl=0.0572, bce=0.4704]
  with torch.cuda.amp.autocast():



Epoch 1/30
Train Loss: 0.2638 (KL: 0.0572, BCE: 0.4704)
Val Loss: 0.2580 (KL: 0.0593, BCE: 0.4567)
Current LR: 3.00e-04
Saved best student model


Epoch 2: 100%|██████████| 2544/2544 [07:32<00:00,  5.62it/s, loss=0.2557, kl=0.0589, bce=0.4524]



Epoch 2/30
Train Loss: 0.2557 (KL: 0.0589, BCE: 0.4524)
Val Loss: 0.2558 (KL: 0.0569, BCE: 0.4548)
Current LR: 3.00e-04
Saved best student model


Epoch 3: 100%|██████████| 2544/2544 [07:27<00:00,  5.68it/s, loss=0.2526, kl=0.0595, bce=0.4456]



Epoch 3/30
Train Loss: 0.2526 (KL: 0.0595, BCE: 0.4456)
Val Loss: 0.2558 (KL: 0.0619, BCE: 0.4496)
Current LR: 3.00e-04
Saved best student model


Epoch 4: 100%|██████████| 2544/2544 [07:31<00:00,  5.63it/s, loss=0.2497, kl=0.0603, bce=0.4392]



Epoch 4/30
Train Loss: 0.2497 (KL: 0.0603, BCE: 0.4392)
Val Loss: 0.2555 (KL: 0.0612, BCE: 0.4499)
Current LR: 3.00e-04
Saved best student model


Epoch 5: 100%|██████████| 2544/2544 [07:28<00:00,  5.67it/s, loss=0.2465, kl=0.0612, bce=0.4319]



Epoch 5/30
Train Loss: 0.2465 (KL: 0.0612, BCE: 0.4319)
Val Loss: 0.2570 (KL: 0.0575, BCE: 0.4566)
Current LR: 3.00e-04


Epoch 6: 100%|██████████| 2544/2544 [07:29<00:00,  5.66it/s, loss=0.2427, kl=0.0622, bce=0.4231]



Epoch 6/30
Train Loss: 0.2427 (KL: 0.0622, BCE: 0.4231)
Val Loss: 0.2586 (KL: 0.0632, BCE: 0.4540)
Current LR: 3.00e-04


Epoch 7: 100%|██████████| 2544/2544 [07:34<00:00,  5.60it/s, loss=0.2385, kl=0.0634, bce=0.4136]



Epoch 7/30
Train Loss: 0.2385 (KL: 0.0634, BCE: 0.4136)
Val Loss: 0.2636 (KL: 0.0627, BCE: 0.4645)
Current LR: 3.00e-04


Epoch 8: 100%|██████████| 2544/2544 [07:34<00:00,  5.60it/s, loss=0.2345, kl=0.0646, bce=0.4045]



Epoch 8/30
Train Loss: 0.2345 (KL: 0.0646, BCE: 0.4045)
Val Loss: 0.2675 (KL: 0.0631, BCE: 0.4718)
Current LR: 1.50e-04


Epoch 9: 100%|██████████| 2544/2544 [07:36<00:00,  5.57it/s, loss=0.2259, kl=0.0671, bce=0.3846]



Epoch 9/30
Train Loss: 0.2259 (KL: 0.0671, BCE: 0.3846)
Val Loss: 0.2774 (KL: 0.0650, BCE: 0.4897)
Current LR: 1.50e-04


Epoch 10: 100%|██████████| 2544/2544 [07:30<00:00,  5.64it/s, loss=0.2210, kl=0.0686, bce=0.3734]



Epoch 10/30
Train Loss: 0.2210 (KL: 0.0686, BCE: 0.3734)
Val Loss: 0.2877 (KL: 0.0616, BCE: 0.5138)
Current LR: 1.50e-04


Epoch 11: 100%|██████████| 2544/2544 [07:35<00:00,  5.58it/s, loss=0.2179, kl=0.0696, bce=0.3661]



Epoch 11/30
Train Loss: 0.2179 (KL: 0.0696, BCE: 0.3661)
Val Loss: 0.2961 (KL: 0.0672, BCE: 0.5250)
Current LR: 1.50e-04
Early stopping after 11 epochs

Evaluating student model on test set:

=== Evaluation Results ===
Strict Accuracy: 0.0132
Mean Accuracy: 0.8977
Top-5 Accuracy: 0.9106
Top-10 Accuracy: 0.7392
Top-20 Accuracy: 0.4549

Per-Attribute Metrics (Top 10 by F1 score):
-------------------------------------------------------------------------------------
Attribute                Accuracy  Precision Recall    F1        Support   
-------------------------------------------------------------------------------------
No_Beard                 0.9562    0.9579    0.9923    0.9748       17041.0
Eyeglasses               0.9966    0.9788    0.9682    0.9735        1289.0
Male                     0.9782    0.9599    0.9846    0.9721        7715.0
Wearing_Lipstick         0.9428    0.9209    0.9742    0.9468       10418.0
Mouth_Slightly_Open      0.9354    0.9345    0.9350    0.9348     