In [36]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from sklearn.utils.class_weight import compute_class_weight
from tqdm import tqdm

In [37]:
# Configuration
data_dir = "../DATA_PREPARE_ATT_04/Grayscale_Face_images"
batch_size = 16
num_classes = 8
initial_lr = 1e-3
num_epochs = 20
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define the emotion classes
emotion_classes = ["Anger", "Contempt", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]


Using device: cuda


In [38]:
# Prepare Data Loaders
def prepare_data_loaders(data_dir, batch_size):
    transform_train = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    transform_val_test = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    print("Loading datasets...")
    train_dataset = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=transform_train)
    val_dataset = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=transform_val_test)
    test_dataset = datasets.ImageFolder(os.path.join(data_dir, "test"), transform=transform_val_test)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    return train_loader, val_loader, test_loader, train_dataset

In [39]:
# Compute Class Weights
def compute_weights(train_dataset, num_classes):
    targets = [label for _, label in train_dataset.samples]
    class_weights = compute_class_weight('balanced', classes=np.arange(num_classes), y=targets)
    return torch.tensor(class_weights, dtype=torch.float).to(device)

In [40]:
# Load Teacher Model (ResNet-50)
def load_teacher_model(model_path='teacher_model.pth'):
    if os.path.exists(model_path):
        print("Loading teacher model from saved path...")
        teacher_model = models.resnet50()
        teacher_model.fc = nn.Linear(teacher_model.fc.in_features, num_classes)
        teacher_model.load_state_dict(torch.load(model_path))
    else:
        print("Downloading and saving ResNet-50 as teacher model...")
        teacher_model = models.resnet50(pretrained=True)
        teacher_model.fc = nn.Linear(teacher_model.fc.in_features, num_classes)
        torch.save(teacher_model.state_dict(), model_path)
    teacher_model.eval()  # Set to evaluation mode
    return teacher_model

In [41]:
# Load Student Model (MobileNetV2)
def load_student_model(model_path='student_model.pth'):
    if os.path.exists(model_path):
        print("Loading student model from saved path...")
        student_model = models.mobilenet_v2()
        student_model.classifier[1] = nn.Linear(student_model.last_channel, num_classes)
        student_model.load_state_dict(torch.load(model_path))
    else:
        print("Downloading and saving MobileNetV2 as student model...")
        student_model = models.mobilenet_v2(pretrained=True)
        student_model.classifier[1] = nn.Linear(student_model.last_channel, num_classes)
        torch.save(student_model.state_dict(), model_path)
    return student_model

In [42]:
# Knowledge Distillation Loss
class DistillationLoss(nn.Module):
    def __init__(self, alpha=0.5, temperature=3.0):
        super(DistillationLoss, self).__init__()
        self.alpha = alpha
        self.temperature = temperature
        self.criterion_ce = nn.CrossEntropyLoss()
        self.criterion_kl = nn.KLDivLoss(reduction="batchmean")

    def forward(self, student_outputs, teacher_outputs, labels):
        loss_ce = self.criterion_ce(student_outputs, labels)
        loss_kl = self.criterion_kl(
            torch.log_softmax(student_outputs / self.temperature, dim=1),
            torch.softmax(teacher_outputs / self.temperature, dim=1),
        )
        return self.alpha * loss_ce + (1 - self.alpha) * loss_kl

In [43]:
# Training Loop
def train_model(teacher_model, student_model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs):
    teacher_model.eval()  # Ensure teacher remains frozen
    best_val_accuracy = 0.0

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        student_model.train()
        train_loss = 0.0

        for inputs, labels in tqdm(train_loader, desc="Training"):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # Forward pass
            student_outputs = student_model(inputs)
            with torch.no_grad():
                teacher_outputs = teacher_model(inputs)  # Teacher predictions

            # Compute loss
            loss = criterion(student_outputs, teacher_outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)
        print(f"Training Loss: {train_loss:.4f}")

        # Validation
        student_model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc="Validating"):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = student_model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_accuracy = correct / total
        print(f"Validation Accuracy: {val_accuracy * 100:.2f}%")

        # Save best model
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(student_model.state_dict(), "mobilenetv2_student_model.pth")
            print(f"New best model saved with accuracy: {val_accuracy * 100:.2f}%")

        scheduler.step()

In [44]:
# Main Script
if __name__ == "__main__":
    # Load Data
    train_loader, val_loader, test_loader, train_dataset = prepare_data_loaders(data_dir, batch_size)

    # Compute Class Weights
    class_weights = compute_weights(train_dataset, num_classes)
    criterion = DistillationLoss(alpha=0.7, temperature=4.0)

    # Load Models
    teacher_model = load_teacher_model().to(device)
    student_model = load_student_model().to(device)

    # Optimizer and Scheduler
    optimizer = optim.AdamW(student_model.parameters(), lr=initial_lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)

    # Train
    train_model(teacher_model, student_model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs)

Loading datasets...
Loading teacher model from saved path...
Loading student model from saved path...

Epoch 1/20


Training: 100%|██████████| 5628/5628 [43:48<00:00,  2.14it/s]  


Training Loss: 1.0383


Validating: 100%|██████████| 1206/1206 [02:16<00:00,  8.84it/s]


Validation Accuracy: 53.14%
New best model saved with accuracy: 53.14%

Epoch 2/20


Training: 100%|██████████| 5628/5628 [31:43<00:00,  2.96it/s]  


Training Loss: 0.9064


Validating: 100%|██████████| 1206/1206 [03:51<00:00,  5.21it/s]


Validation Accuracy: 55.44%
New best model saved with accuracy: 55.44%

Epoch 3/20


Training: 100%|██████████| 5628/5628 [31:50<00:00,  2.95it/s]  


Training Loss: 0.8552


Validating: 100%|██████████| 1206/1206 [02:25<00:00,  8.31it/s]


Validation Accuracy: 59.44%
New best model saved with accuracy: 59.44%

Epoch 4/20


Training: 100%|██████████| 5628/5628 [12:53<00:00,  7.27it/s]


Training Loss: 0.8160


Validating: 100%|██████████| 1206/1206 [00:36<00:00, 32.63it/s]


Validation Accuracy: 60.61%
New best model saved with accuracy: 60.61%

Epoch 5/20


Training: 100%|██████████| 5628/5628 [12:07<00:00,  7.73it/s]


Training Loss: 0.7772


Validating: 100%|██████████| 1206/1206 [00:36<00:00, 33.00it/s]


Validation Accuracy: 62.00%
New best model saved with accuracy: 62.00%

Epoch 6/20


Training: 100%|██████████| 5628/5628 [12:09<00:00,  7.71it/s]


Training Loss: 0.7433


Validating: 100%|██████████| 1206/1206 [00:36<00:00, 32.63it/s]


Validation Accuracy: 62.46%
New best model saved with accuracy: 62.46%

Epoch 7/20


Training: 100%|██████████| 5628/5628 [13:17<00:00,  7.05it/s]


Training Loss: 0.7071


Validating: 100%|██████████| 1206/1206 [00:37<00:00, 32.43it/s]


Validation Accuracy: 64.22%
New best model saved with accuracy: 64.22%

Epoch 8/20


Training: 100%|██████████| 5628/5628 [12:06<00:00,  7.74it/s]


Training Loss: 0.6725


Validating: 100%|██████████| 1206/1206 [00:37<00:00, 31.91it/s]


Validation Accuracy: 64.99%
New best model saved with accuracy: 64.99%

Epoch 9/20


Training: 100%|██████████| 5628/5628 [14:41<00:00,  6.38it/s] 


Training Loss: 0.6431


Validating: 100%|██████████| 1206/1206 [02:18<00:00,  8.74it/s]


Validation Accuracy: 65.91%
New best model saved with accuracy: 65.91%

Epoch 10/20


Training: 100%|██████████| 5628/5628 [36:13<00:00,  2.59it/s]   


Training Loss: 0.6252


Validating: 100%|██████████| 1206/1206 [05:32<00:00,  3.63it/s] 


Validation Accuracy: 66.44%
New best model saved with accuracy: 66.44%

Epoch 11/20


Training: 100%|██████████| 5628/5628 [52:12<00:00,  1.80it/s]  


Training Loss: 0.7697


Validating: 100%|██████████| 1206/1206 [03:03<00:00,  6.56it/s]


Validation Accuracy: 59.24%

Epoch 12/20


Training: 100%|██████████| 5628/5628 [53:55<00:00,  1.74it/s]  


Training Loss: 0.7791


Validating: 100%|██████████| 1206/1206 [02:02<00:00,  9.81it/s]


Validation Accuracy: 59.37%

Epoch 13/20


Training:   1%|          | 44/5628 [00:28<59:58,  1.55it/s]  


KeyboardInterrupt: 