In [None]:
from google.colab import drive
import zipfile
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import pandas as pd
from tqdm.notebook import tqdm

# 1. Mount Google Drive
drive.mount('/content/drive')

# 2. Extract CelebA ZIP (update your path)
zip_path = '/content/drive/MyDrive/img_align_celeba.zip'  # Change to your path
extract_path = '/content/celeba'

# The marker is written only after every archive member has been extracted.
# Checking it (instead of the mere existence of the directory) means an
# interrupted extraction is redone on the next run rather than leaving a
# silently incomplete dataset behind.
extract_marker = os.path.join(extract_path, '.extraction_complete')

if not os.path.exists(extract_marker):
    os.makedirs(extract_path, exist_ok=True)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        # Extract member by member so tqdm can show progress.
        for member in tqdm(zip_ref.namelist(), desc='Extracting'):
            zip_ref.extract(member, extract_path)
    with open(extract_marker, 'w') as marker_file:
        marker_file.write('ok\n')

# 3. Dataset Class with Splitting Capability
class CelebADataset(Dataset):
    """CelebA images paired with their binary attribute vectors.

    Reads ``list_attr_celeba.txt`` and ``list_eval_partition.txt`` from
    ``root_dir`` and loads images from ``root_dir/img_align_celeba``.

    Args:
        root_dir: directory holding the two annotation files and the image folder.
        split: ``'train'``, ``'valid'`` or ``'test'`` selects partition 0, 1 or 2
            of the split file; any other value keeps every file.
        transform: optional callable applied to the loaded PIL image.
    """

    # Partition codes used in list_eval_partition.txt.
    _SPLIT_CODES = {'train': 0, 'valid': 1, 'test': 2}
    # (width, height) of the placeholder used when an image cannot be read.
    # Presumably the native size of the aligned CelebA images (178x218), which is
    # consistent with the 156x128 batches produced by Resize(128) -- TODO confirm.
    _FALLBACK_SIZE = (178, 218)

    def __init__(self, root_dir, split='train', transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # header=1: the second line of the file carries the attribute names.
        # sep=r'\s+' replaces the deprecated delim_whitespace=True.
        raw_attributes = pd.read_csv(os.path.join(root_dir, 'list_attr_celeba.txt'),
                                     sep=r'\s+', header=1)
        self.filenames = raw_attributes.index.tolist()
        self.attributes = (raw_attributes + 1) // 2  # Convert -1/1 to 0/1

        # Load the official split file
        split_file = os.path.join(root_dir, 'list_eval_partition.txt')
        split_info = pd.read_csv(split_file, sep=r'\s+', header=None,
                                 index_col=0, names=['partition'])

        # Filter based on requested split. One vectorised lookup replaces a
        # per-filename .loc call; a filename missing from the split file still
        # raises KeyError.
        code = self._SPLIT_CODES.get(split)
        if code is not None:
            partitions = split_info.loc[self.filenames, 'partition'].to_numpy()
            self.filenames = [name for name, part in zip(self.filenames, partitions)
                              if part == code]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(image, attrs)`` where ``attrs`` is a float32 tensor of 0/1 values."""
        filename = self.filenames[idx]
        img_path = os.path.join(self.root_dir, 'img_align_celeba', filename)
        try:
            img = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Fall back to a blank *PIL* image: the transform pipeline expects a
            # PIL input (a tensor placeholder would fail in ToTensor), and using
            # the regular image size keeps the sample collatable with the rest
            # of the batch after resizing.
            img = Image.new('RGB', self._FALLBACK_SIZE)

        attrs = self.attributes.loc[filename].values.astype('float32')
        if self.transform:
            img = self.transform(img)
        return img, torch.from_numpy(attrs)

# 4. Define Transforms
# Resize to 128, convert to a tensor, then normalise every channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# 5. Create DataLoaders
# One dataset per official partition, built in train/valid/test order.
train_dataset, valid_dataset, test_dataset = [
    CelebADataset(extract_path, split=split_name, transform=transform)
    for split_name in ('train', 'valid', 'test')
]

# Only the training loader shuffles; the held-out loaders keep file order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
valid_loader, test_loader = (
    DataLoader(held_out, shuffle=False, batch_size=64)
    for held_out in (valid_dataset, test_dataset)
)

# Verify: pull a single training batch and report shapes and split sizes.
images, attrs = next(iter(train_loader))
print(f"Train batch: {images.shape}, {attrs.shape}")
print(f"Train samples: {len(train_loader.dataset)}, Valid samples: {len(valid_loader.dataset)}, Test samples: {len(test_loader.dataset)}")


Mounted at /content/drive


Extracting:   0%|          | 0/202600 [00:00<?, ?it/s]

  self.attributes = pd.read_csv(os.path.join(root_dir, 'list_attr_celeba.txt'),
  split_info = pd.read_csv(split_file, delim_whitespace=True, header=None, index_col=0, names=['partition'])
  self.attributes = pd.read_csv(os.path.join(root_dir, 'list_attr_celeba.txt'),
  split_info = pd.read_csv(split_file, delim_whitespace=True, header=None, index_col=0, names=['partition'])
  self.attributes = pd.read_csv(os.path.join(root_dir, 'list_attr_celeba.txt'),
  split_info = pd.read_csv(split_file, delim_whitespace=True, header=None, index_col=0, names=['partition'])


Train batch: torch.Size([64, 3, 156, 128]), torch.Size([64, 40])
Train samples: 162770, Valid samples: 19867, Test samples: 19962


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from tqdm import tqdm
import numpy as np

# Configuration
class Config:
    """Hyper-parameters and constants for the logit-distillation cell."""

    # Use the GPU when one is visible, otherwise fall back to the CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Model / optimisation settings
    num_classes = 40
    batch_size = 128
    lr = 3e-4
    max_epochs = 30

    # Distillation settings
    temperature = 1.0  # Start with simpler temperature
    alpha = 0.5  # Balanced weight between teacher and ground truth

    # Early stopping, gradient clipping and gradient accumulation
    patience = 7
    grad_clip = 1.0
    grad_accum_steps = 2

    # Per-channel normalisation statistics
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # Attribute names, one per output class, in class-index order.
    attribute_names = (
        "5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes "
        "Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair "
        "Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin "
        "Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones "
        "Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard "
        "Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks "
        "Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings "
        "Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young"
    ).split()

# Teacher Model (ResNet50)
class TeacherModel(nn.Module):
    """ResNet-50 teacher.

    The backbone starts from the IMAGENET1K_V2 weights. All of its parameters
    are frozen except those in ``layer3`` and ``layer4``, and the final ``fc``
    layer is replaced by a freshly initialised two-layer head that outputs
    ``num_classes`` logits.
    """

    def __init__(self, num_classes=40):
        super().__init__()
        self.backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

        # Freeze the whole pretrained network, then re-enable the last two stages.
        self.backbone.requires_grad_(False)
        for stage in (self.backbone.layer3, self.backbone.layer4):
            stage.requires_grad_(True)

        # New classifier head (created after freezing, so it stays trainable).
        feature_dim = self.backbone.fc.in_features
        head_layers = [
            nn.Dropout(0.3),
            nn.Linear(feature_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, num_classes),
        ]
        self.backbone.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the head's logits for the input batch ``x``."""
        logits = self.backbone(x)
        return logits

# Student Model (ResNet18)
class StudentModel(nn.Module):
    """ResNet-18 student.

    Starts from the IMAGENET1K_V1 weights and swaps the final ``fc`` layer for
    a two-layer head that outputs ``num_classes`` logits.
    """

    def __init__(self, num_classes=40):
        super().__init__()
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

        # Replacement classifier head.
        feature_dim = self.backbone.fc.in_features
        head_layers = [
            nn.Dropout(0.2),
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        ]
        self.backbone.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the head's logits for the input batch ``x``."""
        logits = self.backbone(x)
        return logits

# Proper Distillation Loss for Multi-Label Classification
class DistillationLoss(nn.Module):
    """Distillation loss for multi-label (independent sigmoid) outputs.

    Combines
      * the Bernoulli KL divergence between temperature-softened teacher and
        student probabilities, averaged over all elements and scaled by
        ``temperature ** 2``, and
      * ``BCEWithLogitsLoss`` between the raw student logits and the targets,
    as ``alpha * kl + (1 - alpha) * bce``.

    Args:
        temperature: divisor applied to both sets of logits before the sigmoid.
        alpha: weight of the KL term; the BCE term gets ``1 - alpha``.
    """

    def __init__(self, temperature, alpha):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, student_logits, teacher_logits, targets):
        """Return ``(total_loss, {'kl_loss': float, 'bce_loss': float})``."""
        # Work in float32 so the log terms stay accurate under mixed precision.
        # The teacher is only a target, so it is detached.
        student_scaled = student_logits.float() / self.temperature
        teacher_scaled = teacher_logits.detach().float() / self.temperature

        teacher_probs = torch.sigmoid(teacher_scaled)

        # log(sigmoid(x)) and log(1 - sigmoid(x)) = log(sigmoid(-x)), computed
        # stably via logsigmoid instead of log(p + eps).
        log_p = nn.functional.logsigmoid(teacher_scaled)
        log_not_p = nn.functional.logsigmoid(-teacher_scaled)
        log_q = nn.functional.logsigmoid(student_scaled)
        log_not_q = nn.functional.logsigmoid(-student_scaled)

        # Full Bernoulli KL(p || q) = p*log(p/q) + (1-p)*log((1-p)/(1-q)).
        # Keeping only the first term is not a divergence: it can be negative
        # and is minimised by driving every student probability towards 1.
        kl_per_element = (teacher_probs * (log_p - log_q) +
                          (1.0 - teacher_probs) * (log_not_p - log_not_q))
        # T^2 keeps the soft-target gradient magnitude comparable across
        # temperatures (a no-op for temperature == 1).
        kl_loss = kl_per_element.mean() * (self.temperature ** 2)

        # BCE loss with original logits (not temperature-scaled)
        bce_loss = self.bce_loss(student_logits, targets)

        # Combined loss
        total_loss = (self.alpha * kl_loss +
                     (1 - self.alpha) * bce_loss)

        return total_loss, {'kl_loss': kl_loss.item(), 'bce_loss': bce_loss.item()}

# Training Function
def train_with_distillation(student, teacher, train_loader, val_loader):
    """Train ``student`` against ground-truth labels and ``teacher`` logits.

    Uses AdamW, ReduceLROnPlateau on the validation loss, mixed precision (on
    CUDA only), gradient accumulation over ``Config.grad_accum_steps`` batches,
    gradient-norm clipping and early stopping after ``Config.patience`` epochs
    without improvement. The best checkpoint is written to
    ``best_student_model.pth``.

    Args:
        student: model being trained; must already live on ``Config.device``.
        teacher: model providing soft targets; it is put in eval mode and only
            ever run under ``torch.no_grad()``.
        train_loader: iterable of ``(images, labels)`` training batches.
        val_loader: iterable of ``(images, labels)`` validation batches.

    Returns:
        ``student`` with the best checkpoint's weights loaded, or with its
        last-epoch weights if the validation loss never improved.
    """
    optimizer = optim.AdamW(
        student.parameters(),
        lr=Config.lr,
        weight_decay=1e-4
    )

    # `verbose=` is deprecated; the current LR is printed every epoch below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=3
    )

    criterion = DistillationLoss(Config.temperature, Config.alpha)

    # torch.amp replaces the deprecated torch.cuda.amp entry points. Mixed
    # precision is simply switched off when running on the CPU.
    device_type = Config.device.type
    use_amp = device_type == 'cuda'
    scaler = torch.amp.GradScaler(device_type, enabled=use_amp)

    checkpoint_path = 'best_student_model.pth'
    best_val_loss = float('inf')
    patience_counter = 0
    saved_checkpoint = False
    num_train_batches = len(train_loader)

    # The teacher only supplies targets: keep dropout/batch-norm in inference mode.
    teacher.eval()

    for epoch in range(Config.max_epochs):
        student.train()
        optimizer.zero_grad(set_to_none=True)
        train_loss = 0.0
        loss_components = {'kl_loss': 0.0, 'bce_loss': 0.0}

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

        for step, (images, labels) in enumerate(pbar):
            images = images.to(Config.device, non_blocking=True)
            labels = labels.to(Config.device, non_blocking=True)

            with torch.amp.autocast(device_type=device_type, enabled=use_amp):
                student_logits = student(images)
                with torch.no_grad():
                    teacher_logits = teacher(images)
                loss, comp = criterion(student_logits, teacher_logits, labels)
                loss = loss / Config.grad_accum_steps

            scaler.scale(loss).backward()

            # Step at the end of every accumulation window, and also on the
            # final batch so a trailing partial window is applied instead of
            # leaking its gradients into the next epoch.
            window_complete = (step + 1) % Config.grad_accum_steps == 0
            last_batch = (step + 1) == num_train_batches
            if window_complete or last_batch:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(student.parameters(), Config.grad_clip)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)

            # Undo the accumulation scaling so the logged loss is per batch.
            train_loss += loss.item() * Config.grad_accum_steps
            for k in loss_components:
                loss_components[k] += comp[k]
            pbar.set_postfix({
                'loss': f"{train_loss/(step+1):.4f}",
                'kl': f"{loss_components['kl_loss']/(step+1):.4f}",
                'bce': f"{loss_components['bce_loss']/(step+1):.4f}"
            })

        # Validation
        val_loss = 0.0
        val_components = {'kl_loss': 0.0, 'bce_loss': 0.0}
        student.eval()
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(Config.device)
                labels = labels.to(Config.device)

                with torch.amp.autocast(device_type=device_type, enabled=use_amp):
                    student_logits = student(images)
                    teacher_logits = teacher(images)
                    loss, comp = criterion(student_logits, teacher_logits, labels)
                val_loss += loss.item()
                for k in val_components:
                    val_components[k] += comp[k]

        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)

        scheduler.step(avg_val_loss)

        print(f"\nEpoch {epoch+1}/{Config.max_epochs}")
        print(f"Train Loss: {avg_train_loss:.4f} (KL: {loss_components['kl_loss']/len(train_loader):.4f}, BCE: {loss_components['bce_loss']/len(train_loader):.4f})")
        print(f"Val Loss: {avg_val_loss:.4f} (KL: {val_components['kl_loss']/len(val_loader):.4f}, BCE: {val_components['bce_loss']/len(val_loader):.4f})")
        print(f"Current LR: {optimizer.param_groups[0]['lr']:.2e}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save({
                'model': student.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'best_val_loss': best_val_loss
            }, checkpoint_path)
            saved_checkpoint = True
            print("Saved best student model")
        else:
            patience_counter += 1
            if patience_counter >= Config.patience:
                print(f"Early stopping after {epoch+1} epochs")
                break

    # Only reload a checkpoint written by *this* run; otherwise (e.g. the
    # validation loss was NaN every epoch) a stale file from an earlier run
    # would be loaded, or the load would fail outright.
    if not saved_checkpoint:
        print("Warning: validation loss never improved; returning last-epoch weights")
        return student

    checkpoint = torch.load(checkpoint_path, map_location=Config.device)
    student.load_state_dict(checkpoint['model'])
    return student

# Evaluation Function
def evaluate_model(model, loader, attribute_names=None):
    """Evaluate a multi-label classifier at a 0.5 threshold and print a summary.

    Args:
        model: module mapping an image batch to ``Config.num_classes`` logits.
        loader: iterable of ``(images, labels)`` batches.
        attribute_names: optional sequence with one name per class. When given,
            the ten attributes with the highest F1 are printed as a table;
            otherwise metrics are keyed ``"Attr <index>"``.

    Returns:
        ``(results, attribute_metrics)``. ``results`` holds ``strict_acc``
        (all attributes of a sample correct), ``mean_acc`` (mean per-sample
        attribute accuracy), ``top_k_acc`` (for each k, the mean fraction of
        the k highest-probability attributes whose label is 1),
        ``per_attribute_acc`` and the raw ``confusion_matrix`` of shape
        ``(2, 2, num_classes)`` laid out as ``[[tp, fp], [fn, tn]]``.
        ``attribute_metrics`` maps attribute name to accuracy, precision,
        recall, F1 and support.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    results = {
        'strict_acc': 0.0,
        'mean_acc': 0.0,
        'top_k_acc': {5: 0.0, 10: 0.0, 20: 0.0},
        'per_attribute_acc': np.zeros(Config.num_classes),
        'confusion_matrix': np.zeros((2, 2, Config.num_classes))
    }

    total_samples = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(Config.device)
            labels = labels.to(Config.device).float()

            outputs = model(images)
            probs = torch.sigmoid(outputs)
            preds = (probs > 0.5).float()
            matches = preds == labels

            results['strict_acc'] += matches.all(dim=1).sum().item()
            results['mean_acc'] += matches.float().mean(dim=1).sum().item()

            for k in results['top_k_acc'].keys():
                _, topk_indices = torch.topk(probs, k, dim=1)
                correct = torch.gather(labels, 1, topk_indices).sum(dim=1)
                results['top_k_acc'][k] += (correct.float() / k).sum().item()

            # Per-attribute counts for the whole batch in one shot: reducing
            # over the batch dimension replaces a Python loop over attributes
            # that issued several host<->device syncs per attribute.
            pred_pos = preds == 1
            pred_neg = ~pred_pos  # preds only ever holds 0.0 or 1.0
            label_pos = labels == 1
            label_neg = labels == 0
            batch_confusion = torch.stack([
                torch.stack([(pred_pos & label_pos).sum(dim=0),    # tp
                             (pred_pos & label_neg).sum(dim=0)]),  # fp
                torch.stack([(pred_neg & label_pos).sum(dim=0),    # fn
                             (pred_neg & label_neg).sum(dim=0)]),  # tn
            ])
            results['per_attribute_acc'] += matches.sum(dim=0).cpu().numpy()
            results['confusion_matrix'] += batch_confusion.cpu().numpy()

            total_samples += images.size(0)

    if total_samples == 0:
        raise ValueError("evaluate_model received a loader that yielded no samples")

    results['strict_acc'] /= total_samples
    results['mean_acc'] /= total_samples
    for k in results['top_k_acc']:
        results['top_k_acc'][k] /= total_samples
    results['per_attribute_acc'] /= total_samples

    attribute_metrics = {}
    for attr in range(Config.num_classes):
        tp = results['confusion_matrix'][0, 0, attr]
        fp = results['confusion_matrix'][0, 1, attr]
        fn = results['confusion_matrix'][1, 0, attr]
        tn = results['confusion_matrix'][1, 1, attr]

        # The small epsilon avoids 0/0 for attributes never predicted / never present.
        precision = tp / (tp + fp + 1e-10)
        recall = tp / (tp + fn + 1e-10)
        f1 = 2 * (precision * recall) / (precision + recall + 1e-10)

        attr_name = attribute_names[attr] if attribute_names else f"Attr {attr}"
        attribute_metrics[attr_name] = {
            'accuracy': (tp + tn) / (tp + tn + fp + fn),
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'support': tp + fn
        }

    print("\n=== Evaluation Results ===")
    print(f"Strict Accuracy: {results['strict_acc']:.4f}")
    print(f"Mean Accuracy: {results['mean_acc']:.4f}")
    for k, acc in sorted(results['top_k_acc'].items()):
        print(f"Top-{k} Accuracy: {acc:.4f}")

    if attribute_names:
        print("\nPer-Attribute Metrics (Top 10 by F1 score):")
        print("-" * 85)
        print(f"{'Attribute':<25}{'Accuracy':<10}{'Precision':<10}{'Recall':<10}{'F1':<10}{'Support':<10}")
        print("-" * 85)

        sorted_attrs = sorted(attribute_metrics.items(),
                            key=lambda x: x[1]['f1'],
                            reverse=True)[:10]

        for name, metrics in sorted_attrs:
            print(f"{name:<25}{metrics['accuracy']:.4f}    {metrics['precision']:.4f}    {metrics['recall']:.4f}    {metrics['f1']:.4f}    {metrics['support']:>10}")

    return results, attribute_metrics

# Main execution
# Main execution: load the trained teacher, distil it into a fresh student,
# evaluate the student on the test split and save its weights.
if __name__ == "__main__":
    # Initialize data loaders (you need to define these)
    # train_loader = ...
    # valid_loader = ...
    # test_loader = ...

    # Load teacher model
    print("Loading teacher model...")
    teacher = TeacherModel(num_classes=Config.num_classes)
    teacher_checkpoint = torch.load('/content/drive/MyDrive/best_model.pth', map_location=Config.device)

    # Handle potential DataParallel wrapper
    if all(k.startswith('module.') for k in teacher_checkpoint.keys()):
        prefix_len = len('module.')
        teacher_checkpoint = {k[prefix_len:]: v for k, v in teacher_checkpoint.items()}

    # strict=False tolerates key mismatches, so the outcome has to be checked:
    # if the checkpoint uses different parameter names, nothing is loaded and
    # the student would be distilled from an untrained classifier head without
    # any error being raised.
    load_result = teacher.load_state_dict(teacher_checkpoint, strict=False)
    if load_result.unexpected_keys:
        print(f"Warning: {len(load_result.unexpected_keys)} checkpoint keys were not used "
              f"(e.g. {load_result.unexpected_keys[:3]})")
    if load_result.missing_keys:
        raise RuntimeError(
            f"Teacher checkpoint is missing {len(load_result.missing_keys)} parameters "
            f"(e.g. {load_result.missing_keys[:3]}); the checkpoint's key names do not "
            "match TeacherModel, so the teacher would not be fully initialised."
        )
    teacher = teacher.to(Config.device)
    teacher.eval()

    # Create student model
    print("Initializing student model...")
    student = StudentModel(num_classes=Config.num_classes)
    student = student.to(Config.device)

    # Train with distillation
    print("\nStarting knowledge distillation training...")
    trained_student = train_with_distillation(student, teacher, train_loader, valid_loader)

    # Evaluate on test set
    print("\nEvaluating student model on test set:")
    test_results, attribute_metrics = evaluate_model(trained_student, test_loader, Config.attribute_names)

    # Save final student model
    torch.save(trained_student.state_dict(), 'final_student_model.pth')
    print("\nSaved final student model")

Loading teacher model...
Initializing student model...


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 118MB/s]
  scaler = torch.cuda.amp.GradScaler()



Starting knowledge distillation training...


  with torch.cuda.amp.autocast():
Epoch 1: 100%|██████████| 2544/2544 [08:21<00:00,  5.07it/s, loss=0.2638, kl=0.0572, bce=0.4704]
  with torch.cuda.amp.autocast():



Epoch 1/30
Train Loss: 0.2638 (KL: 0.0572, BCE: 0.4704)
Val Loss: 0.2580 (KL: 0.0593, BCE: 0.4567)
Current LR: 3.00e-04
Saved best student model


Epoch 2: 100%|██████████| 2544/2544 [07:32<00:00,  5.62it/s, loss=0.2557, kl=0.0589, bce=0.4524]



Epoch 2/30
Train Loss: 0.2557 (KL: 0.0589, BCE: 0.4524)
Val Loss: 0.2558 (KL: 0.0569, BCE: 0.4548)
Current LR: 3.00e-04
Saved best student model


Epoch 3: 100%|██████████| 2544/2544 [07:27<00:00,  5.68it/s, loss=0.2526, kl=0.0595, bce=0.4456]



Epoch 3/30
Train Loss: 0.2526 (KL: 0.0595, BCE: 0.4456)
Val Loss: 0.2558 (KL: 0.0619, BCE: 0.4496)
Current LR: 3.00e-04
Saved best student model


Epoch 4: 100%|██████████| 2544/2544 [07:31<00:00,  5.63it/s, loss=0.2497, kl=0.0603, bce=0.4392]



Epoch 4/30
Train Loss: 0.2497 (KL: 0.0603, BCE: 0.4392)
Val Loss: 0.2555 (KL: 0.0612, BCE: 0.4499)
Current LR: 3.00e-04
Saved best student model


Epoch 5: 100%|██████████| 2544/2544 [07:28<00:00,  5.67it/s, loss=0.2465, kl=0.0612, bce=0.4319]



Epoch 5/30
Train Loss: 0.2465 (KL: 0.0612, BCE: 0.4319)
Val Loss: 0.2570 (KL: 0.0575, BCE: 0.4566)
Current LR: 3.00e-04


Epoch 6: 100%|██████████| 2544/2544 [07:29<00:00,  5.66it/s, loss=0.2427, kl=0.0622, bce=0.4231]



Epoch 6/30
Train Loss: 0.2427 (KL: 0.0622, BCE: 0.4231)
Val Loss: 0.2586 (KL: 0.0632, BCE: 0.4540)
Current LR: 3.00e-04


Epoch 7: 100%|██████████| 2544/2544 [07:34<00:00,  5.60it/s, loss=0.2385, kl=0.0634, bce=0.4136]



Epoch 7/30
Train Loss: 0.2385 (KL: 0.0634, BCE: 0.4136)
Val Loss: 0.2636 (KL: 0.0627, BCE: 0.4645)
Current LR: 3.00e-04


Epoch 8: 100%|██████████| 2544/2544 [07:34<00:00,  5.60it/s, loss=0.2345, kl=0.0646, bce=0.4045]



Epoch 8/30
Train Loss: 0.2345 (KL: 0.0646, BCE: 0.4045)
Val Loss: 0.2675 (KL: 0.0631, BCE: 0.4718)
Current LR: 1.50e-04


Epoch 9: 100%|██████████| 2544/2544 [07:36<00:00,  5.57it/s, loss=0.2259, kl=0.0671, bce=0.3846]



Epoch 9/30
Train Loss: 0.2259 (KL: 0.0671, BCE: 0.3846)
Val Loss: 0.2774 (KL: 0.0650, BCE: 0.4897)
Current LR: 1.50e-04


Epoch 10: 100%|██████████| 2544/2544 [07:30<00:00,  5.64it/s, loss=0.2210, kl=0.0686, bce=0.3734]



Epoch 10/30
Train Loss: 0.2210 (KL: 0.0686, BCE: 0.3734)
Val Loss: 0.2877 (KL: 0.0616, BCE: 0.5138)
Current LR: 1.50e-04


Epoch 11: 100%|██████████| 2544/2544 [07:35<00:00,  5.58it/s, loss=0.2179, kl=0.0696, bce=0.3661]



Epoch 11/30
Train Loss: 0.2179 (KL: 0.0696, BCE: 0.3661)
Val Loss: 0.2961 (KL: 0.0672, BCE: 0.5250)
Current LR: 1.50e-04
Early stopping after 11 epochs

Evaluating student model on test set:

=== Evaluation Results ===
Strict Accuracy: 0.0132
Mean Accuracy: 0.8977
Top-5 Accuracy: 0.9106
Top-10 Accuracy: 0.7392
Top-20 Accuracy: 0.4549

Per-Attribute Metrics (Top 10 by F1 score):
-------------------------------------------------------------------------------------
Attribute                Accuracy  Precision Recall    F1        Support   
-------------------------------------------------------------------------------------
No_Beard                 0.9562    0.9579    0.9923    0.9748       17041.0
Eyeglasses               0.9966    0.9788    0.9682    0.9735        1289.0
Male                     0.9782    0.9599    0.9846    0.9721        7715.0
Wearing_Lipstick         0.9428    0.9209    0.9742    0.9468       10418.0
Mouth_Slightly_Open      0.9354    0.9345    0.9350    0.9348     

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from tqdm import tqdm
import numpy as np
from collections import OrderedDict

# Configuration
class Config:
    """Hyper-parameters and constants for the attention-transfer cell."""

    # Use the GPU when one is visible, otherwise fall back to the CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Model / optimisation settings
    num_classes = 40
    batch_size = 128
    lr = 3e-4
    max_epochs = 30

    # Loss settings
    temperature = 3.0  # Increased temperature for softer probabilities
    alpha = 0.3        # Weight for attention transfer
    beta = 0.3         # Weight for distillation loss

    # Early stopping, gradient clipping and gradient accumulation
    patience = 7
    grad_clip = 1.0
    grad_accum_steps = 2

    # Per-channel normalisation statistics
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # Attribute names, one per output class, in class-index order.
    attribute_names = (
        "5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes "
        "Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair "
        "Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin "
        "Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones "
        "Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard "
        "Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks "
        "Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings "
        "Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young"
    ).split()

# Teacher Model
class TeacherModel(nn.Module):
    """ResNet-50 teacher that records intermediate feature maps.

    The network is built without pretrained weights and its ``fc`` layer is
    replaced by a two-layer head emitting ``num_classes`` logits. Forward hooks
    on ``layer1`` .. ``layer4`` store each stage's detached output in
    ``self.attention_maps`` under the stage name.
    """

    def __init__(self, num_classes=40):
        super().__init__()
        self.base_model = models.resnet50(weights=None)

        # Replacement classifier head.
        feature_dim = self.base_model.fc.in_features
        head_layers = [
            nn.Dropout(0.3),
            nn.Linear(feature_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, num_classes),
        ]
        self.base_model.fc = nn.Sequential(*head_layers)

        # Filled by the forward hooks on every call to forward().
        self.attention_maps = {}
        self._register_hooks()

    def _register_hooks(self):
        """Attach one output-capturing forward hook to each residual stage."""
        def make_hook(stage_name):
            def store_output(module, inputs, output):
                # Detached: the teacher's features are targets, not trainable.
                self.attention_maps[stage_name] = output.detach()
            return store_output

        for stage_name in ('layer1', 'layer2', 'layer3', 'layer4'):
            stage = getattr(self.base_model, stage_name)
            stage.register_forward_hook(make_hook(stage_name))

    def forward(self, x):
        """Clear the stored maps, run the network and return its logits."""
        self.attention_maps.clear()
        logits = self.base_model(x)
        return logits

# Student Model
class StudentModel(nn.Module):
    """ResNet-18 student that records intermediate feature maps.

    Starts from the IMAGENET1K_V1 weights with the ``fc`` layer replaced by a
    two-layer head emitting ``num_classes`` logits. Forward hooks on
    ``layer1`` .. ``layer4`` store each stage's output (not detached, so
    gradients can flow through it) in ``self.attention_maps``.
    """

    def __init__(self, num_classes=40):
        super().__init__()
        self.base_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

        # Replacement classifier head.
        feature_dim = self.base_model.fc.in_features
        head_layers = [
            nn.Dropout(0.2),
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        ]
        self.base_model.fc = nn.Sequential(*head_layers)

        # Filled by the forward hooks on every call to forward().
        self.attention_maps = {}
        self._register_hooks()

    def _register_hooks(self):
        """Attach one output-capturing forward hook to each residual stage."""
        def make_hook(stage_name):
            def store_output(module, inputs, output):
                self.attention_maps[stage_name] = output
            return store_output

        for stage_name in ('layer1', 'layer2', 'layer3', 'layer4'):
            stage = getattr(self.base_model, stage_name)
            stage.register_forward_hook(make_hook(stage_name))

    def forward(self, x):
        """Clear the stored maps, run the network and return its logits."""
        self.attention_maps.clear()
        logits = self.base_model(x)
        return logits

# Corrected Attention Transfer Loss
class AttentionTransferLoss(nn.Module):
    """Attention transfer + logit distillation + BCE, combined into one loss.

    ``total = alpha * attn + beta * kl + (1 - alpha - beta) * bce`` where
      * ``attn`` is a weighted sum, over the layers present in both map dicts,
        of the MSE between L2-normalised spatial attention maps,
      * ``kl`` is the temperature-softened KL divergence between teacher and
        student distributions, scaled by ``temperature ** 2``,
      * ``bce`` is ``BCEWithLogitsLoss`` on the raw student logits.

    Args:
        temperature: softening temperature for the KL term.
        alpha: weight of the attention-transfer term.
        beta: weight of the KL term.
    """

    def __init__(self, temperature=3.0, alpha=0.3, beta=0.3):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha  # Weight for attention transfer
        self.beta = beta    # Weight for distillation loss
        self.bce_loss = nn.BCEWithLogitsLoss()

    def _attention_transfer(self, teacher_feat, student_feat):
        """Compute attention transfer loss between teacher and student features.

        Each feature map is collapsed over dim 1 by summing squared
        activations, flattened per sample and L2-normalised; the loss is the
        MSE between the two results. Summing over dim 1 is what lets teacher
        and student differ in size along that dimension.
        """
        # float32 keeps the sum of squares and the normalisation independent
        # of whatever reduced precision the features were produced in.
        teacher_attention = teacher_feat.float().pow(2).sum(dim=1)
        student_attention = student_feat.float().pow(2).sum(dim=1)

        # Normalize attention maps
        teacher_attention = F.normalize(teacher_attention.flatten(start_dim=1), p=2, dim=1)
        student_attention = F.normalize(student_attention.flatten(start_dim=1), p=2, dim=1)

        return F.mse_loss(student_attention, teacher_attention)

    def forward(self, student_logits, teacher_logits, student_maps, teacher_maps, targets):
        """Return ``(total_loss, {'attn_loss', 'kl_loss', 'bce_loss'})`` (floats)."""
        # Start from a tensor, not the int 0: when no layer is shared between
        # the two dicts the accumulator is returned as-is and `.item()` below
        # must still work.
        attn_loss = torch.zeros((), dtype=torch.float32, device=student_logits.device)
        layer_weights = {'layer1': 0.1, 'layer2': 0.2, 'layer3': 0.3, 'layer4': 0.4}

        for layer, weight in layer_weights.items():
            if layer in teacher_maps and layer in student_maps:
                attn_loss = attn_loss + weight * self._attention_transfer(
                    teacher_maps[layer],
                    student_maps[layer]
                )

        # Distillation loss (KL divergence between probabilities)
        # NOTE(review): softmax treats the classes as mutually exclusive, while
        # the BCE term below treats them as independent labels -- confirm this
        # mix is intended.
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=1)
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=1)

        kl_loss = F.kl_div(
            student_log_probs,
            teacher_probs,
            reduction='batchmean'
        ) * (self.temperature ** 2)  # Scale by temperature squared

        # BCE loss
        bce_loss = self.bce_loss(student_logits, targets)

        # Combined loss (all components should be positive)
        total_loss = (self.alpha * attn_loss +
                     self.beta * kl_loss +
                     (1 - self.alpha - self.beta) * bce_loss)

        return total_loss, {
            'attn_loss': attn_loss.item(),
            'kl_loss': kl_loss.item(),
            'bce_loss': bce_loss.item()
        }

# Training Function
def train_with_attention_transfer(student, teacher, train_loader, val_loader):
    """Train ``student`` with attention transfer, logit distillation and BCE.

    Uses AdamW, ReduceLROnPlateau on the validation loss, mixed precision (on
    CUDA only), gradient accumulation over ``Config.grad_accum_steps`` batches,
    gradient-norm clipping and early stopping after ``Config.patience`` epochs
    without improvement. The best checkpoint is written to ``best_student.pth``.

    Args:
        student: model being trained; must expose ``attention_maps`` filled
            during its forward pass and already live on ``Config.device``.
        teacher: model providing logits and ``attention_maps``; it is put in
            eval mode and only ever run under ``torch.no_grad()``.
        train_loader: iterable of ``(images, labels)`` training batches.
        val_loader: iterable of ``(images, labels)`` validation batches.

    Returns:
        ``student`` with the best checkpoint's weights loaded, or with its
        last-epoch weights if the validation loss never improved.
    """
    optimizer = optim.AdamW(student.parameters(), lr=Config.lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
    criterion = AttentionTransferLoss(Config.temperature, Config.alpha, Config.beta)

    # torch.amp.GradScaler replaces the deprecated torch.cuda.amp.GradScaler.
    # Mixed precision is switched off when running on the CPU.
    device_type = Config.device.type
    use_amp = device_type == 'cuda'
    scaler = torch.amp.GradScaler(device_type, enabled=use_amp)

    checkpoint_path = 'best_student.pth'
    best_val_loss = float('inf')
    patience_counter = 0
    saved_checkpoint = False
    num_train_batches = len(train_loader)

    # The teacher only supplies targets: keep dropout/batch-norm in inference mode.
    teacher.eval()

    for epoch in range(Config.max_epochs):
        student.train()
        optimizer.zero_grad()
        train_loss = 0
        loss_stats = {'attn_loss': 0, 'kl_loss': 0, 'bce_loss': 0}

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        for step, (images, labels) in enumerate(pbar):
            images = images.to(Config.device)
            labels = labels.to(Config.device).float()

            with torch.amp.autocast(device_type=device_type, dtype=torch.float16, enabled=use_amp):
                # Forward pass
                student_logits = student(images)
                with torch.no_grad():
                    teacher_logits = teacher(images)

                # Compute loss
                loss, loss_dict = criterion(
                    student_logits, teacher_logits,
                    student.attention_maps, teacher.attention_maps,
                    labels
                )
                loss = loss / Config.grad_accum_steps

            # Backward pass with gradient accumulation
            scaler.scale(loss).backward()

            # Step at the end of every accumulation window, and also on the
            # final batch so a trailing partial window is applied instead of
            # leaking its gradients into the next epoch.
            window_complete = (step + 1) % Config.grad_accum_steps == 0
            last_batch = (step + 1) == num_train_batches
            if window_complete or last_batch:
                # Gradient clipping
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(student.parameters(), Config.grad_clip)

                # Update weights
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            # Update statistics (undo the accumulation scaling for logging)
            train_loss += loss.item() * Config.grad_accum_steps
            for k in loss_stats:
                loss_stats[k] += loss_dict[k]

            # Update progress bar
            pbar.set_postfix({
                'loss': f"{train_loss/(step+1):.4f}",
                'attn': f"{loss_stats['attn_loss']/(step+1):.4f}",
                'kl': f"{loss_stats['kl_loss']/(step+1):.4f}",
                'bce': f"{loss_stats['bce_loss']/(step+1):.4f}"
            })

        # Validation
        student.eval()
        val_loss = 0
        val_stats = {'attn_loss': 0, 'kl_loss': 0, 'bce_loss': 0}

        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(Config.device)
                labels = labels.to(Config.device).float()

                with torch.amp.autocast(device_type=device_type, dtype=torch.float16, enabled=use_amp):
                    student_logits = student(images)
                    teacher_logits = teacher(images)

                    loss, loss_dict = criterion(
                        student_logits, teacher_logits,
                        student.attention_maps, teacher.attention_maps,
                        labels
                    )

                # Track the same weighted objective that training optimises,
                # so the scheduler, early stopping and the printed train/val
                # losses are all on one scale (an unweighted sum of the
                # components is a different quantity).
                val_loss += loss.item()
                for k in val_stats:
                    val_stats[k] += loss_dict[k]

        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)

        # Update learning rate
        scheduler.step(avg_val_loss)

        # Print epoch summary
        print(f"\nEpoch {epoch+1}/{Config.max_epochs}")
        print(f"Train Loss: {avg_train_loss:.4f} (Attn: {loss_stats['attn_loss']/len(train_loader):.4f}, "
              f"KL: {loss_stats['kl_loss']/len(train_loader):.4f}, BCE: {loss_stats['bce_loss']/len(train_loader):.4f})")
        print(f"Val Loss: {avg_val_loss:.4f} (Attn: {val_stats['attn_loss']/len(val_loader):.4f}, "
              f"KL: {val_stats['kl_loss']/len(val_loader):.4f}, BCE: {val_stats['bce_loss']/len(val_loader):.4f})")

        # Early stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save({
                'model': student.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'val_loss': best_val_loss
            }, checkpoint_path)
            saved_checkpoint = True
            print("Saved best model")
        else:
            patience_counter += 1
            if patience_counter >= Config.patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Only reload a checkpoint written by *this* run; otherwise (e.g. the
    # validation loss was NaN every epoch) a stale file from an earlier run
    # would be loaded, or the load would fail outright.
    if not saved_checkpoint:
        print("Warning: validation loss never improved; returning last-epoch weights")
        return student

    # Load best model
    checkpoint = torch.load(checkpoint_path, map_location=Config.device)
    student.load_state_dict(checkpoint['model'])
    return student

# Evaluation Function
def evaluate_model(model, loader, attribute_names=None):
    """Evaluate a multi-label attribute classifier and print a summary table.

    Predictions are obtained by thresholding ``sigmoid(logits)`` at 0.5 and
    compared element-wise with the labels.

    Args:
        model: Module mapping an image batch to per-attribute logits.
        loader: Iterable yielding ``(images, labels)`` batches. Labels are
            matched against the values 0 and 1 when building the confusion
            matrix.
        attribute_names: Optional sequence of names indexed by attribute
            position. When omitted, attributes are named ``Attr<i>`` and the
            per-attribute table is not printed.

    Returns:
        ``(results, attr_metrics)`` where ``results`` holds the aggregate
        metrics (``strict_acc``, ``mean_acc``, ``top_k_acc``,
        ``per_attribute_acc``, ``confusion_matrix``) and ``attr_metrics`` maps
        each attribute name to its accuracy, precision, recall, F1 and support.

    Raises:
        ValueError: If ``loader`` yields no samples (the metrics would
            otherwise require a division by zero).
    """
    model.eval()
    num_attrs = Config.num_classes
    results = {
        'strict_acc': 0.0,
        'mean_acc': 0.0,
        'top_k_acc': {5: 0.0, 10: 0.0, 20: 0.0},
        'per_attribute_acc': np.zeros(num_attrs),
        # Layout: [0, 0] = TP, [0, 1] = FP, [1, 0] = FN, [1, 1] = TN (per attribute).
        'confusion_matrix': np.zeros((2, 2, num_attrs))
    }
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(Config.device)
            labels = labels.to(Config.device).float()
            batch_size = images.size(0)

            outputs = model(images)
            probs = torch.sigmoid(outputs)
            preds = (probs > 0.5).float()

            # Sample-level accuracy: all attributes right / fraction right.
            matches = preds == labels
            results['strict_acc'] += matches.all(dim=1).sum().item()
            results['mean_acc'] += matches.float().mean(dim=1).sum().item()

            # Top-k: fraction of the k highest-scoring attributes whose label is 1.
            for k in results['top_k_acc']:
                _, topk = torch.topk(probs, k, dim=1)
                correct = torch.gather(labels, 1, topk).sum(dim=1)
                results['top_k_acc'][k] += (correct.float() / k).sum().item()

            # Per-attribute counts for every attribute at once. This replaces a
            # Python loop that issued five host synchronisations per attribute
            # per batch; the counts themselves are unchanged.
            attr_preds = preds[:, :num_attrs]
            attr_labels = labels[:, :num_attrs]
            results['per_attribute_acc'] += (attr_preds == attr_labels).sum(dim=0).cpu().numpy()

            pred_pos, pred_neg = attr_preds == 1, attr_preds == 0
            label_pos, label_neg = attr_labels == 1, attr_labels == 0
            confusion = results['confusion_matrix']
            confusion[0, 0] += (pred_pos & label_pos).sum(dim=0).cpu().numpy()
            confusion[0, 1] += (pred_pos & label_neg).sum(dim=0).cpu().numpy()
            confusion[1, 0] += (pred_neg & label_pos).sum(dim=0).cpu().numpy()
            confusion[1, 1] += (pred_neg & label_neg).sum(dim=0).cpu().numpy()

            total += batch_size

    if total == 0:
        raise ValueError("evaluate_model received an empty loader; nothing to evaluate")

    # Normalize metrics
    results['strict_acc'] /= total
    results['mean_acc'] /= total
    for k in results['top_k_acc']:
        results['top_k_acc'][k] /= total
    results['per_attribute_acc'] /= total

    # Per-attribute precision / recall / F1 from the aggregated counts.
    # The 1e-10 terms keep the ratios defined when an attribute has no
    # predicted or no actual positives.
    attr_metrics = {}
    for attr in range(num_attrs):
        tp = results['confusion_matrix'][0, 0, attr]
        fp = results['confusion_matrix'][0, 1, attr]
        fn = results['confusion_matrix'][1, 0, attr]
        tn = results['confusion_matrix'][1, 1, attr]

        precision = tp / (tp + fp + 1e-10)
        recall = tp / (tp + fn + 1e-10)
        f1 = 2 * (precision * recall) / (precision + recall + 1e-10)

        name = attribute_names[attr] if attribute_names else f"Attr{attr}"
        attr_metrics[name] = {
            'accuracy': (tp + tn) / (tp + tn + fp + fn),
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'support': tp + fn
        }

    # Print results
    print("\n=== Evaluation Results ===")
    print(f"Strict Accuracy: {results['strict_acc']:.4f}")
    print(f"Mean Accuracy: {results['mean_acc']:.4f}")
    for k in sorted(results['top_k_acc'].keys()):
        print(f"Top-{k} Accuracy: {results['top_k_acc'][k]:.4f}")

    if attribute_names:
        print("\nTop 10 Attributes by F1 Score:")
        print("-" * 85)
        print(f"{'Attribute':<25}{'Accuracy':<10}{'Precision':<10}{'Recall':<10}{'F1':<10}{'Support':<10}")
        print("-" * 85)
        sorted_attrs = sorted(attr_metrics.items(), key=lambda x: x[1]['f1'], reverse=True)[:10]
        for name, metrics in sorted_attrs:
            # Support is a sample count; print it as an integer rather than "9883.0".
            print(f"{name:<25}{metrics['accuracy']:.4f}    {metrics['precision']:.4f}    "
                  f"{metrics['recall']:.4f}    {metrics['f1']:.4f}    {int(metrics['support']):>10}")

    return results, attr_metrics

# Main execution
if __name__ == "__main__":
    # The data loaders must already be defined before this cell runs.
    # Names used below: train_loader, valid_loader, test_loader

    # Initialize models
    print("Initializing models...")
    teacher = TeacherModel(Config.num_classes).to(Config.device)
    student = StudentModel(Config.num_classes).to(Config.device)

    # Load teacher weights with proper handling
    print("Loading teacher weights...")
    teacher_weights = torch.load('/content/drive/MyDrive/best_model.pth', map_location=Config.device)

    # Checkpoints saved from nn.DataParallel prefix every key with 'module.'.
    # Strip only that leading prefix; str.replace would also remove any
    # 'module.' occurring later in a key.
    dp_prefix = 'module.'
    if all(k.startswith(dp_prefix) for k in teacher_weights.keys()):
        teacher_weights = OrderedDict((k[len(dp_prefix):], v) for k, v in teacher_weights.items())

    # strict=False tolerates key mismatches, but a silent mismatch would leave
    # the teacher (partly) randomly initialised, so report what was skipped.
    load_report = teacher.load_state_dict(teacher_weights, strict=False)
    if load_report.missing_keys or load_report.unexpected_keys:
        print(f"Warning: teacher checkpoint does not fully match the model - "
              f"{len(load_report.missing_keys)} missing key(s), "
              f"{len(load_report.unexpected_keys)} unexpected key(s)")
        print(f"  missing (first 5): {load_report.missing_keys[:5]}")
        print(f"  unexpected (first 5): {load_report.unexpected_keys[:5]}")
    teacher.eval()

    # Train with attention transfer
    print("\nStarting attention transfer training...")
    trained_student = train_with_attention_transfer(student, teacher, train_loader, valid_loader)

    # Evaluate
    print("\nEvaluating on test set...")
    test_results, attr_metrics = evaluate_model(trained_student, test_loader, Config.attribute_names)

    # Save final model
    torch.save(trained_student.state_dict(), 'final_student.pth')
    print("\nTraining complete. Saved final student model.")

Initializing models...
Loading teacher weights...


  scaler = torch.cuda.amp.GradScaler()



Starting attention transfer training...


Epoch 1: 100%|██████████| 2544/2544 [05:25<00:00,  7.81it/s, loss=0.4468, attn=0.0007, kl=0.2653, bce=0.9174]



Epoch 1/30
Train Loss: 0.4468 (Attn: 0.0007, KL: 0.2653, BCE: 0.9174)
Val Loss: 1.0881 (Attn: 0.0004, KL: 0.2206, BCE: 0.8671)
Saved best model


Epoch 2: 100%|██████████| 2544/2544 [05:22<00:00,  7.88it/s, loss=0.4229, attn=0.0004, kl=0.2194, bce=0.8924]



Epoch 2/30
Train Loss: 0.4229 (Attn: 0.0004, KL: 0.2194, BCE: 0.8924)
Val Loss: 1.0723 (Attn: 0.0004, KL: 0.1957, BCE: 0.8762)
Saved best model


Epoch 3: 100%|██████████| 2544/2544 [05:23<00:00,  7.86it/s, loss=0.4196, attn=0.0003, kl=0.2187, bce=0.8847]



Epoch 3/30
Train Loss: 0.4196 (Attn: 0.0003, KL: 0.2187, BCE: 0.8847)
Val Loss: 1.0809 (Attn: 0.0003, KL: 0.2264, BCE: 0.8542)


Epoch 4: 100%|██████████| 2544/2544 [05:22<00:00,  7.88it/s, loss=0.4175, attn=0.0003, kl=0.2199, bce=0.8787]



Epoch 4/30
Train Loss: 0.4175 (Attn: 0.0003, KL: 0.2199, BCE: 0.8787)
Val Loss: 1.0754 (Attn: 0.0003, KL: 0.2134, BCE: 0.8616)


Epoch 5: 100%|██████████| 2544/2544 [05:22<00:00,  7.88it/s, loss=0.4153, attn=0.0003, kl=0.2207, bce=0.8726]



Epoch 5/30
Train Loss: 0.4153 (Attn: 0.0003, KL: 0.2207, BCE: 0.8726)
Val Loss: 1.0923 (Attn: 0.0003, KL: 0.2518, BCE: 0.8402)


Epoch 6: 100%|██████████| 2544/2544 [05:21<00:00,  7.91it/s, loss=0.4130, attn=0.0004, kl=0.2222, bce=0.8655]



Epoch 6/30
Train Loss: 0.4130 (Attn: 0.0004, KL: 0.2222, BCE: 0.8655)
Val Loss: 1.0797 (Attn: 0.0004, KL: 0.2243, BCE: 0.8550)


Epoch 7: 100%|██████████| 2544/2544 [05:21<00:00,  7.90it/s, loss=0.4073, attn=0.0004, kl=0.2230, bce=0.8506]



Epoch 7/30
Train Loss: 0.4073 (Attn: 0.0004, KL: 0.2230, BCE: 0.8506)
Val Loss: 1.0815 (Attn: 0.0004, KL: 0.2184, BCE: 0.8628)


Epoch 8: 100%|██████████| 2544/2544 [05:22<00:00,  7.89it/s, loss=0.4043, attn=0.0004, kl=0.2259, bce=0.8410]



Epoch 8/30
Train Loss: 0.4043 (Attn: 0.0004, KL: 0.2259, BCE: 0.8410)
Val Loss: 1.0865 (Attn: 0.0004, KL: 0.2284, BCE: 0.8578)


Epoch 9: 100%|██████████| 2544/2544 [05:22<00:00,  7.88it/s, loss=0.4016, attn=0.0004, kl=0.2276, bce=0.8331]



Epoch 9/30
Train Loss: 0.4016 (Attn: 0.0004, KL: 0.2276, BCE: 0.8331)
Val Loss: 1.0905 (Attn: 0.0004, KL: 0.2340, BCE: 0.8561)
Early stopping at epoch 9

Evaluating on test set...

=== Evaluation Results ===
Strict Accuracy: 0.0000
Mean Accuracy: 0.6929
Top-5 Accuracy: 0.2111
Top-10 Accuracy: 0.2963
Top-20 Accuracy: 0.3463

Top 10 Attributes by F1 Score:
-------------------------------------------------------------------------------------
Attribute                Accuracy  Precision Recall    F1        Support   
-------------------------------------------------------------------------------------
Mouth_Slightly_Open      0.7373    0.6728    0.9137    0.7750        9883.0
High_Cheekbones          0.4818    0.4818    1.0000    0.6503        9618.0
Heavy_Makeup             0.4995    0.4471    0.9964    0.6172        8084.0
Smiling                  0.6402    0.9996    0.2810    0.4386        9987.0
Oval_Face                0.7261    0.5613    0.3367    0.4209        5901.0
Bangs         

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from tqdm import tqdm
import numpy as np
from collections import OrderedDict

# Configuration
class Config:
    """Shared hyperparameters and dataset constants for training and evaluation."""

    # --- Runtime / model size ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 40
    batch_size = 64

    # --- Optimisation ---
    lr = 5e-5              # reduced learning rate
    max_epochs = 50
    patience = 10          # early-stopping patience, in epochs
    grad_clip = 0.1        # tight gradient-norm clipping
    grad_accum_steps = 4

    # --- Distillation loss ---
    temperature = 1.0      # conservative temperature
    alpha = 0.2            # attention-transfer weight
    beta = 0.2             # distillation (KL) weight

    # --- Input normalisation statistics ---
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # Per-attribute positive-class weights for the BCE term. These are
    # hand-entered placeholders and should be derived from dataset statistics.
    pos_weight = torch.tensor(
        [
            0.5, 0.7, 0.6, 0.8, 0.3,
            0.4, 0.6, 0.7, 0.5, 0.4,
            0.3, 0.5, 0.4, 0.6, 0.3,
            0.2, 0.3, 0.4, 0.5, 0.6,
            0.7, 0.5, 0.3, 0.4, 0.8,
            0.5, 0.3, 0.5, 0.4, 0.3,
            0.3, 0.6, 0.5, 0.5, 0.4,
            0.3, 0.6, 0.4, 0.3, 0.7,
        ],
        device=device,
    )

    # Attribute names, listed in the order of the model's output columns.
    attribute_names = [
        '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 'Bald',
        'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
        'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
        'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
        'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
        'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks',
        'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings',
        'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young',
    ]

# Teacher Model
class TeacherModel(nn.Module):
    """ResNet-50 backbone with an MLP classification head.

    Forward hooks on ``layer1``..``layer4`` store each stage's output, detached
    from the autograd graph, in ``self.attention_maps`` under the stage name.
    The dictionary is emptied at the start of every forward pass.
    """

    def __init__(self, num_classes=40):
        super().__init__()
        backbone = models.resnet50(weights=None)
        feature_dim = backbone.fc.in_features

        # Replace the stock classifier with a deeper head.
        head_layers = [
            nn.Dropout(0.4),
            nn.Linear(feature_dim, 2048),
            nn.ReLU(),
            nn.BatchNorm1d(2048),
            nn.Dropout(0.3),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, num_classes),
        ]
        backbone.fc = nn.Sequential(*head_layers)
        self.base_model = backbone

        self.attention_maps = {}
        self._register_hooks()

    def _register_hooks(self):
        """Attach a forward hook to each residual stage of the backbone."""
        def make_hook(stage_name):
            def store_output(module, inputs, output):
                # Detached: teacher activations serve only as targets.
                self.attention_maps[stage_name] = output.detach()
            return store_output

        for stage_name in ('layer1', 'layer2', 'layer3', 'layer4'):
            stage = getattr(self.base_model, stage_name)
            stage.register_forward_hook(make_hook(stage_name))

    def forward(self, x):
        """Return the backbone's logits for ``x``, refreshing ``attention_maps``."""
        self.attention_maps.clear()
        logits = self.base_model(x)
        return logits

# Student Model
class StudentModel(nn.Module):
    """ImageNet-pretrained ResNet-18 backbone with an MLP classification head.

    Forward hooks on ``layer1``..``layer4`` store each stage's output in
    ``self.attention_maps`` under the stage name. The outputs are kept attached
    to the autograd graph. The dictionary is emptied at the start of every
    forward pass.
    """

    def __init__(self, num_classes=40):
        super().__init__()
        backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        feature_dim = backbone.fc.in_features

        # Replace the stock classifier with a deeper head.
        head_layers = [
            nn.Dropout(0.3),
            nn.Linear(feature_dim, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, num_classes),
        ]
        backbone.fc = nn.Sequential(*head_layers)
        self.base_model = backbone

        self.attention_maps = {}
        self._register_hooks()

    def _register_hooks(self):
        """Attach a forward hook to each residual stage of the backbone."""
        def make_hook(stage_name):
            def store_output(module, inputs, output):
                # Not detached: gradients can flow back through these activations.
                self.attention_maps[stage_name] = output
            return store_output

        for stage_name in ('layer1', 'layer2', 'layer3', 'layer4'):
            stage = getattr(self.base_model, stage_name)
            stage.register_forward_hook(make_hook(stage_name))

    def forward(self, x):
        """Return the backbone's logits for ``x``, refreshing ``attention_maps``."""
        self.attention_maps.clear()
        logits = self.base_model(x)
        return logits

# Corrected Attention Transfer Loss
class AttentionTransferLoss(nn.Module):
    """Combined attention-transfer, distillation (KL) and weighted BCE loss.

    The total is
    ``alpha * attn + beta * kl + (1 - alpha - beta) * bce``.

    Args:
        temperature: Exponent ``1 / temperature`` applied to the sigmoid
            probabilities before they are renormalised for the KL term.
        alpha: Weight of the attention-transfer term.
        beta: Weight of the KL term.
    """

    def __init__(self, temperature=1.0, alpha=0.2, beta=0.2):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.beta = beta
        self.bce_loss = nn.BCEWithLogitsLoss(pos_weight=Config.pos_weight)

    def _attention_transfer(self, teacher_feat, student_feat):
        """Return the MSE between L2-normalised spatial attention maps.

        Each attention map is the channel-wise mean of the squared activations,
        flattened per sample, so teacher and student may differ in channel
        count. If the spatial sizes differ, the teacher features are
        average-pooled to the student's spatial size first.
        """
        # Resize teacher features to match student dimensions
        if teacher_feat.shape[2:] != student_feat.shape[2:]:
            teacher_feat = F.adaptive_avg_pool2d(teacher_feat, student_feat.shape[2:])

        # Compute normalized attention maps
        teacher_attention = F.normalize(teacher_feat.pow(2).mean(1).view(teacher_feat.size(0), -1), p=2, dim=1)
        student_attention = F.normalize(student_feat.pow(2).mean(1).view(student_feat.size(0), -1), p=2, dim=1)

        return F.mse_loss(student_attention, teacher_attention)

    def forward(self, student_logits, teacher_logits, student_maps, teacher_maps, targets):
        """Compute the combined loss and its components.

        Args:
            student_logits: Student logits, one column per attribute.
            teacher_logits: Teacher logits with the same layout.
            student_maps: Mapping from layer name to student feature map.
            teacher_maps: Mapping from layer name to teacher feature map.
            targets: Float labels passed to ``BCEWithLogitsLoss``.

        Returns:
            ``(total_loss, components)`` where ``components`` holds the Python
            floats ``attn_loss``, ``kl_loss`` and ``bce_loss``. Non-finite
            attention or KL values are reported as 0. If the combined loss is
            non-finite, ``total_loss`` falls back to the BCE term alone.
        """
        # Start from a tensor, not the int 0: when no layer pair is present in
        # both maps, the isfinite()/.item() calls below must still work.
        attn_loss = torch.zeros((), device=student_logits.device)
        layer_pairs = [
            ('layer1', 'layer1', 0.1),
            ('layer2', 'layer2', 0.2),
            ('layer3', 'layer3', 0.3),
            ('layer4', 'layer4', 0.4)
        ]

        for t_layer, s_layer, weight in layer_pairs:
            if t_layer in teacher_maps and s_layer in student_maps:
                attn_loss = attn_loss + weight * self._attention_transfer(
                    teacher_maps[t_layer],
                    student_maps[s_layer]
                )

        # KL term: the per-attribute sigmoid probabilities are tempered and then
        # renormalised across attributes so each row sums to 1.
        teacher_probs = torch.sigmoid(teacher_logits)
        student_probs = torch.sigmoid(student_logits)

        # Temperature scaling; 1e-10 guards the division and the log below.
        teacher_probs = teacher_probs.pow(1/self.temperature)
        teacher_probs = teacher_probs / (teacher_probs.sum(dim=1, keepdim=True) + 1e-10)

        student_probs = student_probs.pow(1/self.temperature)
        student_probs = student_probs / (student_probs.sum(dim=1, keepdim=True) + 1e-10)

        kl_loss = F.kl_div(
            torch.log(student_probs + 1e-10),
            teacher_probs,
            reduction='batchmean'
        )

        # BCE loss
        bce_loss = self.bce_loss(student_logits, targets)

        # Combined loss with safeguards
        total_loss = (self.alpha * attn_loss +
                     self.beta * kl_loss +
                     (1 - self.alpha - self.beta) * bce_loss)

        # Verify all losses are finite
        if not torch.isfinite(total_loss):
            print(f"Warning: Non-finite loss detected - attn: {attn_loss}, kl: {kl_loss}, bce: {bce_loss}")
            total_loss = bce_loss  # Fall back to BCE only if other losses become unstable

        return total_loss, {
            'attn_loss': attn_loss.item() if torch.isfinite(attn_loss) else 0,
            'kl_loss': kl_loss.item() if torch.isfinite(kl_loss) else 0,
            'bce_loss': bce_loss.item()
        }

# Training Function
def train_with_attention_transfer(student, teacher, train_loader, val_loader):
    """Train ``student`` against ``teacher`` with the combined distillation loss.

    Uses AdamW, ReduceLROnPlateau on the validation loss, mixed-precision
    autocast with gradient scaling, gradient accumulation over
    ``Config.grad_accum_steps`` micro-batches and gradient-norm clipping.

    A checkpoint is written to ``best_student.pth`` whenever the validation
    loss reaches a new minimum OR the validation mean accuracy reaches a new
    maximum; the later of those saves is what gets reloaded at the end.
    Training stops early after ``Config.patience`` epochs with neither.

    Args:
        student: Model to train; must expose ``attention_maps``.
        teacher: Model providing targets; must expose ``attention_maps``.
            It is switched to eval mode here.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Validation loader; ``len(val_loader.dataset)`` is used to
            normalise the mean accuracy.

    Returns:
        ``student`` with the last saved checkpoint's weights loaded. If no
        checkpoint was saved in this call, it is returned in its final state.
    """
    optimizer = optim.AdamW(student.parameters(), lr=Config.lr, weight_decay=1e-4)
    # `verbose` is deprecated on LR schedulers; LR changes are reported
    # explicitly after scheduler.step() below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=3
    )
    criterion = AttentionTransferLoss(Config.temperature, Config.alpha, Config.beta)
    scaler = torch.amp.GradScaler()

    def apply_accumulated_gradients():
        """Unscale, clip and apply the accumulated gradients, then reset them."""
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(student.parameters(), Config.grad_clip)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    # The teacher contains Dropout/BatchNorm layers; keep it in inference mode
    # so its targets are deterministic regardless of what the caller did.
    teacher.eval()

    best_val_loss = float('inf')
    best_val_acc = 0.0
    patience_counter = 0
    checkpoint_saved = False  # guards against loading a stale file from an earlier run

    for epoch in range(Config.max_epochs):
        student.train()
        train_loss = 0
        loss_stats = {'attn_loss': 0, 'kl_loss': 0, 'bce_loss': 0}
        pending_micro_batches = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        for step, (images, labels) in enumerate(pbar):
            images = images.to(Config.device)
            labels = labels.to(Config.device).float()

            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                student_logits = student(images)
                with torch.no_grad():
                    teacher_logits = teacher(images)

                loss, loss_dict = criterion(
                    student_logits, teacher_logits,
                    student.attention_maps, teacher.attention_maps,
                    labels
                )

                # The criterion already falls back to BCE when the combined
                # loss is non-finite, so reaching this means BCE itself
                # diverged: abandon the rest of the epoch.
                if not torch.isfinite(loss):
                    print(f"Non-finite loss detected at step {step}")
                    optimizer.zero_grad()
                    pending_micro_batches = 0
                    break

                loss = loss / Config.grad_accum_steps

            scaler.scale(loss).backward()
            pending_micro_batches += 1

            if pending_micro_batches == Config.grad_accum_steps:
                apply_accumulated_gradients()
                pending_micro_batches = 0

            # Undo the accumulation scaling so the logged loss is per batch.
            train_loss += loss.item() * Config.grad_accum_steps
            for k in loss_stats:
                loss_stats[k] += loss_dict[k]

            pbar.set_postfix({
                'loss': f"{train_loss/(step+1):.4f}",
                'attn': f"{loss_stats['attn_loss']/(step+1):.4f}",
                'kl': f"{loss_stats['kl_loss']/(step+1):.4f}",
                'bce': f"{loss_stats['bce_loss']/(step+1):.4f}"
            })

        # Apply gradients from a trailing partial accumulation group so they
        # are neither dropped nor mixed into the next epoch's first update.
        if pending_micro_batches:
            apply_accumulated_gradients()

        # Validation
        student.eval()
        val_loss = 0
        val_stats = {'attn_loss': 0, 'kl_loss': 0, 'bce_loss': 0}
        val_mean_acc = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(Config.device)
                labels = labels.to(Config.device).float()

                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                    student_logits = student(images)
                    teacher_logits = teacher(images)

                    val_total, loss_dict = criterion(
                        student_logits, teacher_logits,
                        student.attention_maps, teacher.attention_maps,
                        labels
                    )

                preds = (torch.sigmoid(student_logits) > 0.5).float()
                val_mean_acc += (preds == labels).float().mean(dim=1).sum().item()

                # Use the same weighted objective as training so that the
                # train/val losses are comparable and the scheduler and early
                # stopping track the quantity actually being optimised.
                val_loss += val_total.item()
                for k in val_stats:
                    val_stats[k] += loss_dict[k]

        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        val_mean_acc /= len(val_loader.dataset)

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(avg_val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after != lr_before:
            print(f"Reducing learning rate from {lr_before:.4e} to {lr_after:.4e}")

        print(f"\nEpoch {epoch+1}/{Config.max_epochs}")
        print(f"Train Loss: {avg_train_loss:.4f} (Attn: {loss_stats['attn_loss']/len(train_loader):.4f}, "
              f"KL: {loss_stats['kl_loss']/len(train_loader):.4f}, BCE: {loss_stats['bce_loss']/len(train_loader):.4f})")
        print(f"Val Loss: {avg_val_loss:.4f} (Attn: {val_stats['attn_loss']/len(val_loader):.4f}, "
              f"KL: {val_stats['kl_loss']/len(val_loader):.4f}, BCE: {val_stats['bce_loss']/len(val_loader):.4f})")
        print(f"Val Mean Accuracy: {val_mean_acc:.4f}")

        # Early stopping
        if avg_val_loss < best_val_loss or val_mean_acc > best_val_acc:
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
            if val_mean_acc > best_val_acc:
                best_val_acc = val_mean_acc
            patience_counter = 0
            # Record the metrics of the weights being saved, not the running
            # bests, which may belong to different epochs.
            torch.save({
                'model': student.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'val_loss': avg_val_loss,
                'val_acc': val_mean_acc
            }, 'best_student.pth')
            checkpoint_saved = True
            print("Saved best model")
        else:
            patience_counter += 1
            if patience_counter >= Config.patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    if checkpoint_saved:
        checkpoint = torch.load('best_student.pth', map_location=Config.device)
        student.load_state_dict(checkpoint['model'])
    else:
        print("Warning: no checkpoint was saved during this run; returning the student as-is")
    return student

# Evaluation Function
def evaluate_model(model, loader, attribute_names=None):
    """Evaluate a multi-label attribute classifier and print a summary table.

    Predictions are obtained by thresholding ``sigmoid(logits)`` at 0.5 and
    compared element-wise with the labels.

    Args:
        model: Module mapping an image batch to per-attribute logits.
        loader: Iterable yielding ``(images, labels)`` batches. Labels are
            matched against the values 0 and 1 when building the confusion
            matrix.
        attribute_names: Optional sequence of names indexed by attribute
            position. When omitted, attributes are named ``Attr<i>`` and the
            per-attribute table is not printed.

    Returns:
        ``(results, attr_metrics)`` where ``results`` holds the aggregate
        metrics (``strict_acc``, ``mean_acc``, ``top_k_acc``,
        ``per_attribute_acc``, ``per_attribute_f1``, ``confusion_matrix``) and
        ``attr_metrics`` maps each attribute name to its accuracy, precision,
        recall, F1 and support.

    Raises:
        ValueError: If ``loader`` yields no samples (the metrics would
            otherwise require a division by zero).
    """
    model.eval()
    num_attrs = Config.num_classes
    results = {
        'strict_acc': 0.0,
        'mean_acc': 0.0,
        'top_k_acc': {3: 0.0, 5: 0.0, 10: 0.0},
        'per_attribute_acc': np.zeros(num_attrs),
        'per_attribute_f1': np.zeros(num_attrs),
        # Layout: [0, 0] = TP, [0, 1] = FP, [1, 0] = FN, [1, 1] = TN (per attribute).
        'confusion_matrix': np.zeros((2, 2, num_attrs))
    }
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(Config.device)
            labels = labels.to(Config.device).float()
            batch_size = images.size(0)

            outputs = model(images)
            probs = torch.sigmoid(outputs)
            preds = (probs > 0.5).float()

            # Sample-level accuracy: all attributes right / fraction right.
            matches = preds == labels
            results['strict_acc'] += matches.all(dim=1).sum().item()
            results['mean_acc'] += matches.float().mean(dim=1).sum().item()

            # Top-k: fraction of the k highest-scoring attributes whose label is 1.
            for k in results['top_k_acc']:
                _, topk = torch.topk(probs, k, dim=1)
                correct = torch.gather(labels, 1, topk).sum(dim=1)
                results['top_k_acc'][k] += (correct.float() / k).sum().item()

            # Per-attribute counts for every attribute at once. This replaces a
            # Python loop that issued several host synchronisations per
            # attribute per batch; the counts themselves are unchanged.
            attr_preds = preds[:, :num_attrs]
            attr_labels = labels[:, :num_attrs]
            results['per_attribute_acc'] += (attr_preds == attr_labels).sum(dim=0).cpu().numpy()

            pred_pos, pred_neg = attr_preds == 1, attr_preds == 0
            label_pos, label_neg = attr_labels == 1, attr_labels == 0
            confusion = results['confusion_matrix']
            confusion[0, 0] += (pred_pos & label_pos).sum(dim=0).cpu().numpy()
            confusion[0, 1] += (pred_pos & label_neg).sum(dim=0).cpu().numpy()
            confusion[1, 0] += (pred_neg & label_pos).sum(dim=0).cpu().numpy()
            confusion[1, 1] += (pred_neg & label_neg).sum(dim=0).cpu().numpy()

            total += batch_size

    if total == 0:
        raise ValueError("evaluate_model received an empty loader; nothing to evaluate")

    results['strict_acc'] /= total
    results['mean_acc'] /= total
    for k in results['top_k_acc']:
        results['top_k_acc'][k] /= total
    results['per_attribute_acc'] /= total

    # Per-attribute precision / recall / F1 from the aggregated counts.
    # F1 is not additive across batches, so averaging per-batch F1 scores
    # (as an earlier revision did) under-reports attributes whose positives
    # are missing from some batches. The 1e-10 terms keep the ratios defined
    # when an attribute has no predicted or no actual positives.
    attr_metrics = {}
    for attr in range(num_attrs):
        tp = results['confusion_matrix'][0, 0, attr]
        fp = results['confusion_matrix'][0, 1, attr]
        fn = results['confusion_matrix'][1, 0, attr]
        tn = results['confusion_matrix'][1, 1, attr]

        precision = tp / (tp + fp + 1e-10)
        recall = tp / (tp + fn + 1e-10)
        f1 = 2 * (precision * recall) / (precision + recall + 1e-10)
        results['per_attribute_f1'][attr] = f1

        name = attribute_names[attr] if attribute_names else f"Attr{attr}"
        attr_metrics[name] = {
            'accuracy': (tp + tn) / (tp + tn + fp + fn),
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'support': tp + fn
        }

    print("\n=== Evaluation Results ===")
    print(f"Strict Accuracy: {results['strict_acc']:.4f}")
    print(f"Mean Accuracy: {results['mean_acc']:.4f}")
    for k in sorted(results['top_k_acc'].keys()):
        print(f"Top-{k} Accuracy: {results['top_k_acc'][k]:.4f}")

    if attribute_names:
        print("\nTop 10 Attributes by F1 Score:")
        print("-" * 85)
        print(f"{'Attribute':<25}{'Accuracy':<10}{'Precision':<10}{'Recall':<10}{'F1':<10}{'Support':<10}")
        print("-" * 85)
        sorted_attrs = sorted(attr_metrics.items(), key=lambda x: x[1]['f1'], reverse=True)[:10]
        for name, metrics in sorted_attrs:
            # Support is a sample count; print it as an integer rather than "9883.0".
            print(f"{name:<25}{metrics['accuracy']:.4f}    {metrics['precision']:.4f}    "
                  f"{metrics['recall']:.4f}    {metrics['f1']:.4f}    {int(metrics['support']):>10}")

    return results, attr_metrics

# Main execution
if __name__ == "__main__":
    # The data loaders must already be defined before this cell runs.
    # Names used below: train_loader, valid_loader, test_loader

    # Initialize models
    print("Initializing models...")
    teacher = TeacherModel(Config.num_classes).to(Config.device)
    student = StudentModel(Config.num_classes).to(Config.device)

    # Load teacher weights
    print("Loading teacher weights...")
    teacher_weights = torch.load('/content/drive/MyDrive/best_model.pth', map_location=Config.device)

    # Checkpoints saved from nn.DataParallel prefix every key with 'module.'.
    # Strip only that leading prefix; str.replace would also remove any
    # 'module.' occurring later in a key.
    dp_prefix = 'module.'
    if all(k.startswith(dp_prefix) for k in teacher_weights.keys()):
        teacher_weights = OrderedDict((k[len(dp_prefix):], v) for k, v in teacher_weights.items())

    # strict=False tolerates key mismatches, but a silent mismatch would leave
    # the teacher (partly) randomly initialised, so report what was skipped.
    load_report = teacher.load_state_dict(teacher_weights, strict=False)
    if load_report.missing_keys or load_report.unexpected_keys:
        print(f"Warning: teacher checkpoint does not fully match the model - "
              f"{len(load_report.missing_keys)} missing key(s), "
              f"{len(load_report.unexpected_keys)} unexpected key(s)")
        print(f"  missing (first 5): {load_report.missing_keys[:5]}")
        print(f"  unexpected (first 5): {load_report.unexpected_keys[:5]}")
    teacher.eval()

    # Train with attention transfer
    print("\nStarting attention transfer training...")
    trained_student = train_with_attention_transfer(student, teacher, train_loader, valid_loader)

    # Evaluate
    print("\nEvaluating on test set...")
    test_results, attr_metrics = evaluate_model(trained_student, test_loader, Config.attribute_names)

    # Config.__dict__ is a mappingproxy, which cannot be pickled, so saving it
    # directly makes torch.save raise after training has already finished.
    # Store a plain dict of the user-defined settings instead.
    config_snapshot = {}
    for key, value in vars(Config).items():
        if key.startswith('__') or callable(value):
            continue
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu().tolist()
        elif isinstance(value, torch.device):
            value = str(value)
        config_snapshot[key] = value

    # Save final model
    torch.save({
        'model': trained_student.state_dict(),
        'config': config_snapshot,
        'test_results': test_results
    }, 'final_student.pth')
    print("\nTraining complete. Saved final student model.")

Initializing models...




Loading teacher weights...

Starting attention transfer training...


Epoch 1: 100%|██████████| 2544/2544 [05:28<00:00,  7.74it/s, loss=0.3285, attn=0.0032, kl=0.1486, bce=0.4969]



Epoch 1/50
Train Loss: 0.3285 (Attn: 0.0032, KL: 0.1486, BCE: 0.4969)
Val Loss: 0.5367 (Attn: 0.0025, KL: 0.1515, BCE: 0.3827)
Val Mean Accuracy: 0.8921
Saved best model


Epoch 2: 100%|██████████| 2544/2544 [05:28<00:00,  7.74it/s, loss=0.2233, attn=0.0021, kl=0.1821, bce=0.3107]



Epoch 2/50
Train Loss: 0.2233 (Attn: 0.0021, KL: 0.1821, BCE: 0.3107)
Val Loss: 0.4593 (Attn: 0.0018, KL: 0.2098, BCE: 0.2477)
Val Mean Accuracy: 0.8886
Saved best model


Epoch 3: 100%|██████████| 2544/2544 [05:29<00:00,  7.73it/s, loss=0.1903, attn=0.0016, kl=0.2180, bce=0.2440]



Epoch 3/50
Train Loss: 0.1903 (Attn: 0.0016, KL: 0.2180, BCE: 0.2440)
Val Loss: 0.4553 (Attn: 0.0015, KL: 0.2228, BCE: 0.2311)
Val Mean Accuracy: 0.8764
Saved best model


Epoch 4: 100%|██████████| 2544/2544 [05:32<00:00,  7.64it/s, loss=0.1854, attn=0.0013, kl=0.2108, bce=0.2382]



Epoch 4/50
Train Loss: 0.1854 (Attn: 0.0013, KL: 0.2108, BCE: 0.2382)
Val Loss: 0.4368 (Attn: 0.0012, KL: 0.2014, BCE: 0.2342)
Val Mean Accuracy: 0.8654
Saved best model


Epoch 5: 100%|██████████| 2544/2544 [05:33<00:00,  7.62it/s, loss=0.1835, attn=0.0012, kl=0.2042, bce=0.2373]



Epoch 5/50
Train Loss: 0.1835 (Attn: 0.0012, KL: 0.2042, BCE: 0.2373)
Val Loss: 0.4337 (Attn: 0.0011, KL: 0.1988, BCE: 0.2338)
Val Mean Accuracy: 0.8625
Saved best model


Epoch 6: 100%|██████████| 2544/2544 [05:33<00:00,  7.63it/s, loss=0.1822, attn=0.0011, kl=0.2007, bce=0.2364]



Epoch 6/50
Train Loss: 0.1822 (Attn: 0.0011, KL: 0.2007, BCE: 0.2364)
Val Loss: 0.4368 (Attn: 0.0010, KL: 0.2044, BCE: 0.2313)
Val Mean Accuracy: 0.8651


Epoch 7: 100%|██████████| 2544/2544 [05:34<00:00,  7.61it/s, loss=0.1812, attn=0.0010, kl=0.1998, bce=0.2350]



Epoch 7/50
Train Loss: 0.1812 (Attn: 0.0010, KL: 0.1998, BCE: 0.2350)
Val Loss: 0.4342 (Attn: 0.0010, KL: 0.2008, BCE: 0.2324)
Val Mean Accuracy: 0.8643


Epoch 8: 100%|██████████| 2544/2544 [05:33<00:00,  7.63it/s, loss=0.1803, attn=0.0010, kl=0.1995, bce=0.2336]



Epoch 8/50
Train Loss: 0.1803 (Attn: 0.0010, KL: 0.1995, BCE: 0.2336)
Val Loss: 0.4289 (Attn: 0.0010, KL: 0.1939, BCE: 0.2340)
Val Mean Accuracy: 0.8600
Saved best model


Epoch 9: 100%|██████████| 2544/2544 [05:29<00:00,  7.72it/s, loss=0.1793, attn=0.0010, kl=0.2004, bce=0.2318]



Epoch 9/50
Train Loss: 0.1793 (Attn: 0.0010, KL: 0.2004, BCE: 0.2318)
Val Loss: 0.4419 (Attn: 0.0010, KL: 0.2042, BCE: 0.2368)
Val Mean Accuracy: 0.8627


Epoch 10: 100%|██████████| 2544/2544 [05:29<00:00,  7.72it/s, loss=0.1782, attn=0.0009, kl=0.2007, bce=0.2297]



Epoch 10/50
Train Loss: 0.1782 (Attn: 0.0009, KL: 0.2007, BCE: 0.2297)
Val Loss: 0.4313 (Attn: 0.0009, KL: 0.1972, BCE: 0.2332)
Val Mean Accuracy: 0.8613


Epoch 11: 100%|██████████| 2544/2544 [05:28<00:00,  7.75it/s, loss=0.1770, attn=0.0009, kl=0.2016, bce=0.2275]



Epoch 11/50
Train Loss: 0.1770 (Attn: 0.0009, KL: 0.2016, BCE: 0.2275)
Val Loss: 0.4408 (Attn: 0.0010, KL: 0.2068, BCE: 0.2331)
Val Mean Accuracy: 0.8681


Epoch 12: 100%|██████████| 2544/2544 [05:27<00:00,  7.76it/s, loss=0.1760, attn=0.0010, kl=0.2036, bce=0.2252]



Epoch 12/50
Train Loss: 0.1760 (Attn: 0.0010, KL: 0.2036, BCE: 0.2252)
Val Loss: 0.4384 (Attn: 0.0009, KL: 0.2056, BCE: 0.2318)
Val Mean Accuracy: 0.8657


Epoch 13: 100%|██████████| 2544/2544 [05:29<00:00,  7.71it/s, loss=0.1744, attn=0.0010, kl=0.2065, bce=0.2214]



Epoch 13/50
Train Loss: 0.1744 (Attn: 0.0010, KL: 0.2065, BCE: 0.2214)
Val Loss: 0.4394 (Attn: 0.0010, KL: 0.2064, BCE: 0.2320)
Val Mean Accuracy: 0.8664


Epoch 14: 100%|██████████| 2544/2544 [05:29<00:00,  7.73it/s, loss=0.1736, attn=0.0010, kl=0.2090, bce=0.2194]



Epoch 14/50
Train Loss: 0.1736 (Attn: 0.0010, KL: 0.2090, BCE: 0.2194)
Val Loss: 0.4400 (Attn: 0.0010, KL: 0.2062, BCE: 0.2328)
Val Mean Accuracy: 0.8660


Epoch 15: 100%|██████████| 2544/2544 [05:28<00:00,  7.75it/s, loss=0.1730, attn=0.0010, kl=0.2110, bce=0.2177]



Epoch 15/50
Train Loss: 0.1730 (Attn: 0.0010, KL: 0.2110, BCE: 0.2177)
Val Loss: 0.4422 (Attn: 0.0010, KL: 0.2084, BCE: 0.2328)
Val Mean Accuracy: 0.8657


Epoch 16: 100%|██████████| 2544/2544 [05:27<00:00,  7.77it/s, loss=0.1724, attn=0.0010, kl=0.2130, bce=0.2160]



Epoch 16/50
Train Loss: 0.1724 (Attn: 0.0010, KL: 0.2130, BCE: 0.2160)
Val Loss: 0.4469 (Attn: 0.0010, KL: 0.2137, BCE: 0.2321)
Val Mean Accuracy: 0.8689


Epoch 17:   5%|▍         | 120/2544 [00:15<05:21,  7.53it/s, loss=0.1718, attn=0.0010, kl=0.2140, bce=0.2146]


KeyboardInterrupt: 

In [None]:
# Rebuild the student architecture before restoring its weights
student = StudentModel(Config.num_classes).to(Config.device)

# Load weights from checkpoint. map_location lets a checkpoint that was saved
# on GPU be restored on whatever device this runtime provides.
checkpoint = torch.load('best_student.pth', map_location=Config.device)  # or 'final_student.pth'
student.load_state_dict(checkpoint['model'])
student.eval()  # Set to evaluation mode
test_results, attr_metrics = evaluate_model(student, test_loader, Config.attribute_names)


=== Evaluation Results ===
Strict Accuracy: 0.0024
Mean Accuracy: 0.8533
Top-3 Accuracy: 0.9419
Top-5 Accuracy: 0.8767
Top-10 Accuracy: 0.6915

Top 10 Attributes by F1 Score:
-------------------------------------------------------------------------------------
Attribute                Accuracy  Precision Recall    F1        Support   
-------------------------------------------------------------------------------------
Wearing_Lipstick         0.8577    0.9866    0.7373    0.8439       10418.0
Mouth_Slightly_Open      0.8532    0.9896    0.7109    0.8274        9883.0
High_Cheekbones          0.8265    0.9699    0.6603    0.7857        9618.0
No_Beard                 0.6959    0.9932    0.6482    0.7844       17041.0
Heavy_Makeup             0.8425    0.9750    0.6270    0.7632        8084.0
Attractive               0.7649    0.9043    0.5880    0.7126        9898.0
Smiling                  0.7647    0.9981    0.5306    0.6929        9987.0
Male                     0.7989    0.9962   