In [26]:
import numpy as np
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader
import random
from PIL import Image
import os
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim

# Pair Sampling Dataset Loader

In [27]:
class PairFERDataset(Dataset):
    """Dataset yielding two images of the same class plus their shared label.

    Images are read through ``ImageFolder(root=root_dir)``. Every image is
    converted to grayscale, resized to 112x112, turned into a tensor and
    normalised with mean 0.5 / std 0.5.

    Args:
        root_dir: Directory passed to ``ImageFolder`` as ``root``.
    """

    def __init__(self, root_dir):
        self.dataset = ImageFolder(root=root_dir)
        self.transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.Resize((112, 112)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])
        # label index -> list of sample indices carrying that label
        self.class_to_images = {i: [] for i in range(len(self.dataset.classes))}
        for idx, (_, label) in enumerate(self.dataset.samples):
            self.class_to_images[label].append(idx)

    def __len__(self):
        """Number of anchor images (one pair is produced per image)."""
        return len(self.dataset)

    def _sample_partner(self, idx, label):
        """Pick a random index with the same label, avoiding ``idx`` when possible."""
        candidates = self.class_to_images[label]
        if len(candidates) < 2:
            # The class has no other member, so the anchor is the only choice.
            return idx
        partner = random.choice(candidates)
        # Rejection sampling: a pair made of the same image twice carries no
        # signal about what two *different* images of a class have in common.
        while partner == idx:
            partner = random.choice(candidates)
        return partner

    def __getitem__(self, idx):
        """Return ``(anchor, partner, label)`` with both images transformed."""
        img1, label = self.dataset[idx]
        idx2 = self._sample_partner(idx, label)
        img2, _ = self.dataset[idx2]
        return self.transform(img1), self.transform(img2), label

# Model Components

In [28]:
class Encoder(nn.Module):
    """Image encoder: ResNet-18 trunk followed by a linear projection.

    The trunk is every child module of torchvision's ResNet-18 (loaded with
    its default weights) except the last one; its 512-channel output is
    flattened and projected to ``output_dim`` features.
    """

    def __init__(self, output_dim=64):
        super().__init__()
        backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        trunk_layers = list(backbone.children())
        trunk_layers.pop()  # drop the final child module of the backbone
        self.feature_extractor = nn.Sequential(*trunk_layers)
        self.projector = nn.Linear(512, output_dim)

    def forward(self, x):
        """Map a batch of images to embeddings of shape (B, output_dim)."""
        pooled = self.feature_extractor(x)          # (B, 512, 1, 1)
        batch_size = pooled.size(0)
        flat = pooled.view(batch_size, -1)          # (B, 512)
        embedding = self.projector(flat)            # (B, output_dim)
        return embedding

# Mutual Information Estimator (MINE)

In [30]:
class StatisticsNetwork(nn.Module):
    """Statistics network T(x, y) scored by ``estimate_mi`` (MINE).

    This cell previously defined ``Discriminator`` (which the next cell
    defines again, identically), while ``StatisticsNetwork`` -- the class
    that ``train()`` and ``demo_model()`` instantiate -- was not defined
    anywhere, so a fresh kernel failed with ``NameError``. The layer stack is
    kept exactly as it was in this cell.

    Args:
        dim: Size of each of the two embeddings; they are concatenated, so
            the first linear layer takes ``dim * 2`` inputs.

    NOTE(review): the final Sigmoid keeps every score in (0, 1), which bounds
    the value ``estimate_mi`` can return to (-1, 1). MINE is normally used
    with an unbounded scalar output -- confirm whether the Sigmoid is wanted.
    """

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim * 2, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x, y):
        """Score each (x, y) row pair; returns a tensor of shape (B, 1)."""
        return self.net(torch.cat([x, y], dim=1))

# Discriminator for Adversarial MI Minimization

In [31]:
# Discriminator for Adversarial MI Minimization
class Discriminator(nn.Module):
    """Binary critic over a pair of embeddings.

    The two inputs are concatenated along the feature axis and passed through
    a two-layer MLP that ends in a Sigmoid, giving one score in (0, 1) per row.
    """

    def __init__(self, dim):
        super().__init__()
        hidden_width = 256
        stages = [
            nn.Linear(2 * dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, e, i):
        """Return the (B, 1) score for each row pair of ``e`` and ``i``."""
        paired = torch.cat((e, i), dim=1)
        score = self.net(paired)
        return score

# Mutual Information Estimation (MINE Loss)

In [36]:
def estimate_mi(mine, x, y):
    """Donsker-Varadhan style MI estimate: ``E[T(x, y)] - log E[exp(T(x, y'))]``.

    ``y'`` is ``y`` with its rows randomly permuted, which breaks the pairing
    with ``x`` and stands in for samples from the product of marginals.

    Args:
        mine: Callable ``mine(x, y)`` returning one score per row.
        x: First batch, rows aligned with ``y``.
        y: Second batch; ``y.size(0)`` rows are permuted for the marginal term.

    Returns:
        A 0-dim tensor holding the estimate.
    """
    joint_term = mine(x, y).mean()

    # Build the permutation on y's device so indexing needs no host/device copy.
    shuffle = torch.randperm(y.size(0), device=y.device)
    marginal_scores = mine(x, y[shuffle]).reshape(-1)

    # log(mean(exp(s))) computed as logsumexp(s) - log(N). Unlike
    # exp().mean() followed by a clamp, this does not overflow for large
    # scores, does not zero the gradient at a clamp boundary, and needs no
    # epsilon inside the log.
    log_count = torch.log(marginal_scores.new_tensor(float(marginal_scores.numel())))
    log_mean_exp = torch.logsumexp(marginal_scores, dim=0) - log_count

    return joint_term - log_mean_exp

# Classifier (Final FER model)

In [37]:
class ExpressionClassifier(nn.Module):
    """Two-layer MLP head mapping an embedding to class logits.

    Args:
        input_dim: Width of the incoming embedding.
        num_classes: Number of logits produced per sample.
    """

    def __init__(self, input_dim=64, num_classes=7):
        super().__init__()
        hidden_units = 128
        head = (
            nn.Linear(input_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, num_classes),
        )
        self.fc = nn.Sequential(*head)

    def forward(self, x):
        """Return raw logits of shape (B, num_classes); no softmax is applied."""
        logits = self.fc(x)
        return logits

# Train

In [40]:
def train(root_dir="train", num_epochs=10, batch_size=32, lr=1e-6,
          log_every=90, checkpoint_dir="."):
    """Train the encoders, MI network, discriminator and expression classifier.

    Called with no arguments this behaves like the previous hard-coded
    version, except that (a) only dataset-loading errors are caught, and the
    cause is printed, and (b) the expression encoder and classifier weights
    are saved after every epoch under the file names that ``test()`` loads.

    Args:
        root_dir: Directory handed to ``PairFERDataset``.
        num_epochs: Number of passes over the data.
        batch_size: Batch size for the ``DataLoader``.
        lr: Learning rate used by both Adam optimisers.
        log_every: Print a progress line every ``log_every`` batches (> 0).
        checkpoint_dir: Directory for ``expression_encoder_epoch_{n}.pth`` and
            ``classifier_epoch_{n}.pth``; pass ``None`` to skip saving.

    Returns:
        None.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Only failures to read the dataset are handled here; anything else
    # (e.g. a NameError from a cell that was never run) should surface.
    try:
        dataset = PairFERDataset(root_dir)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    except (OSError, RuntimeError, ValueError) as err:
        print(f"Dataset not found. Please ensure '{root_dir}' directory exists with proper structure.")
        print(f"  ({type(err).__name__}: {err})")
        return

    # Models
    E_exp = Encoder().to(device)
    E_id = Encoder().to(device)
    stat_net = StatisticsNetwork(64).to(device)
    disc = Discriminator(64).to(device)
    clf = ExpressionClassifier().to(device)

    # Optimizers
    # NOTE(review): encoders, stat_net and disc share one optimiser that
    # minimises loss_total. As a result stat_net is pushed to *lower* mi_id
    # rather than to tighten the bound, and disc is updated in the same
    # direction as the encoders instead of against them. The one stat_net is
    # also reused for three different pairings (EM/EN, EM/IM, EN/IN).
    # Confirm this is the intended objective.
    joint_params = (list(E_exp.parameters()) + list(E_id.parameters()) +
                    list(stat_net.parameters()) + list(disc.parameters()))
    opt_all = optim.Adam(joint_params, lr=lr)
    opt_clf = optim.Adam(clf.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    if checkpoint_dir is not None:
        os.makedirs(checkpoint_dir, exist_ok=True)

    print("Starting training...")

    for epoch in range(num_epochs):
        epoch_loss_exp = 0
        epoch_loss_clf = 0
        num_batches = 0

        for batch_idx, (m, n, label) in enumerate(dataloader):
            m, n, label = m.to(device), n.to(device), label.to(device)

            # The ResNet backbone expects 3 channels; tile grayscale input.
            if m.shape[1] == 1:
                m = m.repeat(1, 3, 1, 1)
            if n.shape[1] == 1:
                n = n.repeat(1, 3, 1, 1)

            # Encode images
            EM, EN = E_exp(m), E_exp(n)  # Expression embeddings
            IM, IN = E_id(m), E_id(n)    # Identity embeddings

            # Expression MI maximisation between the two same-class embeddings
            mi_exp = estimate_mi(stat_net, EM, EN)

            # L1 term pulling same-class expression embeddings together
            l1 = torch.mean(torch.abs(EM - EN))
            loss_exp = -mi_exp + 0.1 * l1

            # MI between expression and identity embeddings of the same image
            mi_id_exp = estimate_mi(stat_net, EM, IM)
            mi_id_exp_n = estimate_mi(stat_net, EN, IN)
            mi_id = mi_id_exp + mi_id_exp_n

            # Discriminator loss: matched (EM, IM) vs. EM with shuffled IN
            real = disc(EM, IM)
            shuffled = torch.randperm(IN.size(0), device=IN.device)
            fake = disc(EM, IN[shuffled])
            loss_adv = -torch.mean(torch.log(real + 1e-6) + torch.log(1 - fake + 1e-6))

            # Total loss for the jointly optimised modules
            loss_total = loss_exp + mi_id + 0.1 * loss_adv

            opt_all.zero_grad()
            loss_total.backward()
            torch.nn.utils.clip_grad_norm_(joint_params, max_norm=1.0)
            opt_all.step()

            # Classifier is trained on detached embeddings, so its loss does
            # not send gradients into the expression encoder.
            pred = clf(EM.detach())
            loss_clf = loss_fn(pred, label)

            opt_clf.zero_grad()
            loss_clf.backward()
            opt_clf.step()

            epoch_loss_exp += loss_exp.item()
            epoch_loss_clf += loss_clf.item()
            num_batches += 1

            if batch_idx % log_every == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}: MI Loss={loss_exp.item():.4f}, Clf Loss={loss_clf.item():.4f}")

        if num_batches == 0:
            # Avoid dividing by zero when the loader yields nothing.
            print(f"Epoch {epoch}: the data loader produced no batches; stopping.")
            return

        avg_loss_exp = epoch_loss_exp / num_batches
        avg_loss_clf = epoch_loss_clf / num_batches
        print(f"Epoch {epoch} Complete: Avg MI Loss={avg_loss_exp:.4f}, Avg Clf Loss={avg_loss_clf:.4f}")

        # Persist what test() loads; without this the trained weights are
        # lost when the function returns.
        if checkpoint_dir is not None:
            torch.save(E_exp.state_dict(),
                       os.path.join(checkpoint_dir, f"expression_encoder_epoch_{epoch}.pth"))
            torch.save(clf.state_dict(),
                       os.path.join(checkpoint_dir, f"classifier_epoch_{epoch}.pth"))

    print("Training completed!")

# Demo function to show model architecture
def demo_model():
    """Build every model once, push a random batch through, and print shapes and sizes."""
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Instantiate the five components on the chosen device
    expression_encoder = Encoder().to(device)
    identity_encoder = Encoder().to(device)
    mi_network = StatisticsNetwork(64).to(device)
    pair_critic = Discriminator(64).to(device)
    classifier = ExpressionClassifier().to(device)

    # Random stand-in for a batch of four 3x112x112 images
    sample_batch = torch.randn(4, 3, 112, 112).to(device)

    print("Model Architecture Demo:")
    print(f"Input shape: {sample_batch.shape}")

    # Embeddings from both encoders
    expression_code = expression_encoder(sample_batch)
    identity_code = identity_encoder(sample_batch)
    print(f"Expression embedding shape: {expression_code.shape}")
    print(f"Identity embedding shape: {identity_code.shape}")

    # Pairwise scorers
    print(f"MI score shape: {mi_network(expression_code, identity_code).shape}")
    print(f"Discriminator score shape: {pair_critic(expression_code, identity_code).shape}")

    # Classification head
    print(f"Classification output shape: {classifier(expression_code).shape}")

    print("\nModel summary:")
    summary_rows = (
        ("Expression Encoder", expression_encoder),
        ("Identity Encoder", identity_encoder),
        ("Statistics Network", mi_network),
        ("Discriminator", pair_critic),
        ("Classifier", classifier),
    )
    for title, module in summary_rows:
        total = sum(weight.numel() for weight in module.parameters())
        print(f"{title} parameters: {total}")

if __name__ == '__main__':
    # Print the architecture summary first.
    demo_model()

    # Then train. The call sits inside the guard so that importing this code
    # as a module does not start a training run; train() prints a message and
    # returns if the dataset directory is missing.
    train()

Model Architecture Demo:
Input shape: torch.Size([4, 3, 112, 112])
Expression embedding shape: torch.Size([4, 64])
Identity embedding shape: torch.Size([4, 64])
MI score shape: torch.Size([4, 1])
Discriminator score shape: torch.Size([4, 1])
Classification output shape: torch.Size([4, 7])

Model summary:
Expression Encoder parameters: 11209344
Identity Encoder parameters: 11209344
Statistics Network parameters: 33281
Discriminator parameters: 33281
Classifier parameters: 9223
Using device: cuda
Starting training...
Epoch 0, Batch 0: MI Loss=0.0604, Clf Loss=1.9309
Epoch 0, Batch 90: MI Loss=0.0569, Clf Loss=1.9881
Epoch 0, Batch 180: MI Loss=0.0644, Clf Loss=1.9395
Epoch 0, Batch 270: MI Loss=0.0656, Clf Loss=1.9639
Epoch 0, Batch 360: MI Loss=0.0559, Clf Loss=1.9826
Epoch 0, Batch 450: MI Loss=0.0710, Clf Loss=1.9705
Epoch 0, Batch 540: MI Loss=0.0691, Clf Loss=1.9585
Epoch 0, Batch 630: MI Loss=0.0813, Clf Loss=1.9671
Epoch 0, Batch 720: MI Loss=0.0618, Clf Loss=1.9611
Epoch 0, Batch

In [46]:
def _class_names_for(dataset, fallback):
    """Return the dataset's own ``classes`` list when it exposes one, else ``fallback``."""
    # Look on the dataset itself and on a wrapped ``.dataset`` attribute
    # (the layout PairFERDataset uses for its ImageFolder).
    for holder in (dataset, getattr(dataset, "dataset", None)):
        names = getattr(holder, "classes", None)
        if names:
            return list(names)
    return list(fallback)


def test():
    """Evaluate the saved expression encoder and classifier on the test set.

    Loads ``expression_encoder_epoch_9.pth`` and ``classifier_epoch_9.pth``
    from the working directory, runs the test loader without gradients, and
    prints overall accuracy plus a per-class report.

    Returns:
        None. Returns early, after printing a message, when the dataset or
        the weights cannot be loaded or when the loader yields no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load test dataset
    # NOTE(review): "/test" is an absolute path, while the message below and
    # train() both refer to a relative directory -- confirm the intended root.
    try:
        test_dataset = TestFERDataset("/test")
        test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
        print(f"Test dataset loaded: {len(test_dataset)} images")
    except (OSError, RuntimeError, ValueError) as err:
        print("Test dataset not found. Please ensure 'test' directory exists with proper structure.")
        print(f"  ({type(err).__name__}: {err})")
        return

    # Load trained models
    try:
        E_exp = Encoder().to(device)
        clf = ExpressionClassifier().to(device)

        # map_location lets checkpoints written on a GPU load on a CPU-only
        # machine. torch.load unpickles the file: only load checkpoints you
        # produced yourself.
        E_exp.load_state_dict(torch.load('expression_encoder_epoch_9.pth', map_location=device))
        clf.load_state_dict(torch.load('classifier_epoch_9.pth', map_location=device))

        print("Models loaded successfully")
    except (OSError, RuntimeError) as err:
        print("Model weights not found. Please train the model first.")
        print(f"  ({type(err).__name__}: {err})")
        return

    # Evaluation mode: fixed BatchNorm statistics
    E_exp.eval()
    clf.eval()

    all_predictions = []
    all_labels = []

    print("Starting testing...")

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(test_loader):
            images = images.to(device)
            labels = labels.to(device)

            # The ResNet backbone expects 3 channels; tile grayscale input.
            if images.shape[1] == 1:
                images = images.repeat(1, 3, 1, 1)

            expression_embeddings = E_exp(images)
            predictions = clf(expression_embeddings)
            predicted_classes = torch.argmax(predictions, dim=1)

            all_predictions.extend(predicted_classes.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

            if batch_idx % 40 == 0:
                print(f"Processed batch {batch_idx}/{len(test_loader)}")

    if not all_labels:
        print("No test samples were produced; nothing to evaluate.")
        return

    # NOTE(review): accuracy_score and classification_report are not imported
    # in the visible cells -- presumably `from sklearn.metrics import ...` in
    # a cell not shown here; confirm.
    accuracy = accuracy_score(all_labels, all_predictions)

    # Prefer the names the dataset itself reports: ImageFolder assigns label
    # indices from the sorted folder names, which need not match the
    # hard-coded order below. The list is only a fallback.
    class_names = _class_names_for(
        test_dataset,
        ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral'],
    )

    print(f"\n=== TEST RESULTS ===")
    print(f"Total test samples: {len(all_labels)}")
    print(f"Overall Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")

    # Passing labels= keeps the report aligned with class_names even when a
    # class is absent from both the labels and the predictions.
    print(f"\n=== CLASSIFICATION REPORT ===")
    print(classification_report(all_labels, all_predictions,
                                labels=list(range(len(class_names))),
                                target_names=class_names))

def demo_model():
    """Smoke-test the architecture: instantiate each model, run one random batch, print sizes.

    Note: this notebook also defines ``demo_model`` in an earlier cell with the
    same behaviour; whichever definition was executed last is the one in effect.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def count_parameters(module):
        # Total number of scalar entries across all of the module's parameters.
        return sum(tensor.numel() for tensor in module.parameters())

    # Keyed by the label used in the summary; insertion order is print order.
    nets = {}
    nets["Expression Encoder"] = Encoder().to(device)
    nets["Identity Encoder"] = Encoder().to(device)
    nets["Statistics Network"] = StatisticsNetwork(64).to(device)
    nets["Discriminator"] = Discriminator(64).to(device)
    nets["Classifier"] = ExpressionClassifier().to(device)

    batch = torch.randn(4, 3, 112, 112).to(device)

    print("Model Architecture Demo:")
    print(f"Input shape: {batch.shape}")

    # Encoders
    expr_vec = nets["Expression Encoder"](batch)
    ident_vec = nets["Identity Encoder"](batch)
    print(f"Expression embedding shape: {expr_vec.shape}")
    print(f"Identity embedding shape: {ident_vec.shape}")

    # Statistics network
    stat_out = nets["Statistics Network"](expr_vec, ident_vec)
    print(f"MI score shape: {stat_out.shape}")

    # Discriminator
    critic_out = nets["Discriminator"](expr_vec, ident_vec)
    print(f"Discriminator score shape: {critic_out.shape}")

    # Classifier
    logits = nets["Classifier"](expr_vec)
    print(f"Classification output shape: {logits.shape}")

    print("\nModel summary:")
    for label, net in nets.items():
        print(f"{label} parameters: {count_parameters(net)}")