In [76]:
import torch
import torch.optim as optim
import torch.nn as nn
import lightning as L
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, Dataset
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
from sklearn.model_selection import train_test_split
import random
torch.set_float32_matmul_precision('medium')

In [77]:
def set_seed(seed):
    random.seed(seed)                   # Python random module
    np.random.seed(seed)                # NumPy random seed
    torch.manual_seed(seed)             # PyTorch CPU seed
    torch.cuda.manual_seed(seed)        # PyTorch CUDA seed
    torch.cuda.manual_seed_all(seed)    # All CUDA devices if using multi-GPU
    torch.backends.cudnn.deterministic = True   # For reproducibility
    torch.backends.cudnn.benchmark = False      # Disable for reproducibility

# Example usage:
set_seed(42)

In [78]:
dataset_mean = [0.54148953, 0.42486119, 0.37428667]
dataset_std = [0.23021227, 0.2072772, 0.1976669 ]

class ExpDataset(Dataset):
    def __init__(self, imgs_path, csv_path, transform=None):
        super().__init__()
        self.imgs_path = imgs_path
        self.csv_data = pd.read_csv(csv_path)
        self.transform = transform

    def __len__(self):
        return len(self.csv_data)
    def __getitem__(self, index):
        record = self.csv_data.iloc[index]

        img_name = record['image_name']
        img_path = os.path.join(self.imgs_path, img_name+".jpg")
        img = Image.open(img_path).convert("RGB")

        # Apply transforms to the image
        if self.transform:
            img = self.transform(img)
        else:
            img = transforms.Compose([
                transforms.Resize((224, 224)).interpolation,
                transforms.ToTensor(),
                transforms.Normalize(mean=dataset_mean, std=dataset_std)
            ])(img)
        
        label = record['expression_label']
        return img, label

imgs_path = "../expW/origin_cleaned"
labels_path = "../expW/new_label.csv"

exp_dataset = ExpDataset(imgs_path, labels_path)
N = len(exp_dataset)
dataset_range = range(N)

train_indices, val_indices = train_test_split(
    dataset_range,
    train_size=0.8,
    random_state=42,
    stratify=exp_dataset.csv_data['expression_label']
)

In [79]:
train_transforms = transforms.Compose([
    # transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.25),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomAffine(degrees=2.5, translate=(0.05, 0.05), scale=(1.05, 1.05)),
    transforms.ToTensor(),
    transforms.Normalize(mean=dataset_mean, std=dataset_std)
])

val_transforms = transforms.Compose([
    # transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=dataset_mean, std=dataset_std)
])

train_dataset = ExpDataset(imgs_path, labels_path, transform=train_transforms)
val_dataset = ExpDataset(imgs_path, labels_path, transform=val_transforms)

# Final Splits
train_dataset = Subset(train_dataset, train_indices)
val_dataset = Subset(val_dataset, val_indices)

img, label = train_dataset[36]
img
# len(train_dataset)

tensor([[[-2.3521, -2.1307, -2.1307,  ..., -2.1307, -2.1307, -2.1307],
         [-2.3521, -2.1307, -2.1307,  ..., -2.1307, -2.1307, -2.1307],
         [-2.3521, -2.1307, -2.1307,  ..., -2.1307, -2.1307, -2.1307],
         ...,
         [-2.3521, -2.3521, -2.3521,  ..., -2.1307, -2.1307, -2.1307],
         [-2.3521, -2.3521, -2.3521,  ..., -2.1307, -2.1307, -2.1307],
         [-2.3521, -2.3521, -2.3521,  ..., -2.1307, -2.1307, -2.1307]],

        [[-2.0497, -1.8038, -1.8038,  ..., -1.8038, -1.8038, -1.8038],
         [-2.0497, -1.8038, -1.8038,  ..., -1.8038, -1.8038, -1.8038],
         [-2.0497, -1.8038, -1.8038,  ..., -1.8038, -1.8038, -1.8038],
         ...,
         [-2.0497, -2.0497, -2.0497,  ..., -1.8038, -1.8038, -1.8038],
         [-2.0497, -2.0497, -2.0497,  ..., -1.8038, -1.8038, -1.8038],
         [-2.0497, -2.0497, -2.0497,  ..., -1.8038, -1.8038, -1.8038]],

        [[-1.8935, -1.6356, -1.6356,  ..., -1.6356, -1.6356, -1.6356],
         [-1.8935, -1.6356, -1.6356,  ..., -1

In [80]:
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False
)

In [81]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, padding=1, dropout=False, d=0.05, kernel_size=3):
        super(ResidualBlock, self).__init__()
        
        # Reduce Spatial Dimenstion here by half or more by taking (stride > 1)
        # Use the kernel size to determine what 'kind' of features to focus on, smaller kernels for finer detail but needs more stacks
        # to achieve the perfect receptive field
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.dropout1 = nn.Dropout2d(d) if dropout else nn.Identity()
        self.relu = nn.ReLU(inplace=True)

        # Use stride 1 here always
        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=1,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.dropout2 = nn.Dropout2d(d) if dropout else nn.Identity()

        # Skip connection Logic
        if in_channels != out_channels or stride != 1:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=1,
                    stride=stride,  # match spatial downsample
                    padding=0,
                    bias=False
                ),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        identity = x
        out_intermediate = self.relu(self.bn1(self.conv1(x)))
        out_intermediate = self.dropout1(out_intermediate)

        out_intermediate = self.bn2(self.conv2(out_intermediate))
        out_intermediate = self.dropout2(out_intermediate)

        out = out_intermediate + self.shortcut(identity)
        out = self.relu(out)
        return out, out_intermediate

In [82]:
# Attention
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=8):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1=nn.Conv2d(in_planes, in_planes//ratio, 1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        return self.sigmoid(avg_out + max_out)
    
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        padding = (kernel_size - 1) // 2
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x_cat = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv1(x_cat))

class CBAM(nn.Module):
    def __init__(self, in_planes, ratio=8, spatial_kernel_size=7):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(in_planes, ratio)
        self.sa = SpatialAttention(spatial_kernel_size)

    def forward(self, x):
        out = x * self.ca(x)
        out = out * self.sa(out)
        return out

In [83]:
class Branch1(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = ResidualBlock(16, 16, stride=1, kernel_size=3, padding=1)
        self.block2 = ResidualBlock(16, 32, stride=2, kernel_size=3, padding=1)
        self.cbam1 = CBAM(32, ratio=4)
        self.block3 = ResidualBlock(32, 64, stride=2, kernel_size=3, padding=1)
        self.cbam2 = CBAM(64, ratio=8)
        self.block4 = ResidualBlock(64, 128, stride=2, kernel_size=3, padding=1, dropout=True)
        self.cbam3 = CBAM(128, ratio=8)
        self.block5 = ResidualBlock(128, 128, stride=1, kernel_size=3, padding=1, dropout=True)

    def forward(self, x):
        out, _ = self.block1(x)
        out, _ = self.block2(out)
        out = self.cbam1(out)
        out, _ = self.block3(out)
        out = self.cbam2(out)
        out, _ = self.block4(out)
        out = self.cbam3(out)
        out, final_intermediate = self.block5(out)
        return out, final_intermediate

class Branch2(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = ResidualBlock(16, 32, stride=2, kernel_size=5, padding=2)
        self.cbam1 = CBAM(32, ratio=4)
        self.block2 = ResidualBlock(32, 64, stride=2, kernel_size=5, padding=2, dropout=True)
        self.cbam2 = CBAM(64, ratio=8)
        self.block3 = ResidualBlock(64, 128, stride=2, kernel_size=5, padding=2, dropout=True)

    def forward(self, x):
        out, _ = self.block1(x)
        out = self.cbam1(out)
        out, _ = self.block2(out)
        out = self.cbam2(out)
        out, final_intermediate = self.block3(out)
        return out, final_intermediate

In [84]:
class MildCBAM(nn.Module):
    def __init__(self, in_planes):
        super().__init__()
        self.cbam = CBAM(in_planes, ratio=16)  # mild attention
    
    def forward(self, x):
        attn = self.cbam(x)
        return x * attn

class AttentionWeightedConcat(nn.Module):
    def __init__(self, channels1, channels2):
        super().__init__()
        self.attn1 = MildCBAM(channels1)
        self.attn2 = MildCBAM(channels2)
        self.weight1 = nn.Parameter(torch.ones(1, channels1, 1, 1))
        self.weight2 = nn.Parameter(torch.ones(1, channels2, 1, 1))
        self.batch_norm = nn.BatchNorm2d(channels1 + channels2)
    
    def forward(self, feat1, feat2):
        modulated1 = self.attn1(feat1)
        modulated2 = self.attn2(feat2)
        weighted1 = modulated1 * self.weight1
        weighted2 = modulated2 * self.weight2
        fused = torch.cat([weighted1, weighted2], dim=1)
        fused = self.batch_norm(fused)
        return fused

In [85]:
class ResnetBranchModel(nn.Module):
    def __init__(self, in_channels=3):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        )
        self.branch1 = Branch1()  # Your previously defined branch1
        self.branch2 = Branch2()  # Your previously defined branch2
        self.fusion = AttentionWeightedConcat(128, 128)  # 128 channels from each branch
    
    def forward(self, x):
        shared = self.initial(x)
        b1_out, b1_intermediate = self.branch1(shared)
        b2_out, b2_intermediate = self.branch2(shared)

        if b1_out.shape[2:] != b2_out.shape[2:]:
            b2_out = nn.functional.interpolate(b2_out, size=b1_out.shape[2:], mode='bilinear', align_corners=False)

        fused = self.fusion(b1_out, b2_out)

        return fused, b1_intermediate, b2_intermediate

In [86]:
model = ResnetBranchModel()
x = torch.randn(1, 3, 224, 224)
fused, b1_int, b2_int = model(x)

print(f"Branch1 intermediate: {b1_int.shape}")
print(f"Branch2 intermediate: {b2_int.shape}")
print(f"Fused features: {fused.shape}")

Branch1 intermediate: torch.Size([1, 128, 14, 14])
Branch2 intermediate: torch.Size([1, 128, 14, 14])
Fused features: torch.Size([1, 256, 14, 14])


In [87]:
class MultiInputTransformerEncoder(nn.Module):
    def __init__(self, input_channels_fused, input_channels_int, embed_dim=256, num_heads=8, num_layers=4, height=14, width=14):
        super().__init__()
        self.height = height
        self.width = width
        
        # Project each input (fused and intermediates) to a common embedding dimension
        self.proj_fused = nn.Conv2d(input_channels_fused, embed_dim, kernel_size=1)
        self.proj_b1_int = nn.Conv2d(input_channels_int, embed_dim, kernel_size=1)
        self.proj_b2_int = nn.Conv2d(input_channels_int, embed_dim, kernel_size=1)

        # Learnable positional embeddings for spatial information
        self.pos_embed = nn.Parameter(torch.randn(1, embed_dim, height, width))

        # Learnable path embeddings to distinguish between fused and branch inputs
        self.path_embed = nn.Parameter(torch.randn(3, 1, embed_dim))  # for fused, branch1 intermediate, branch2 intermediate

        # Transformer encoder stack with multi-head attention
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dropout=0.1,
            dim_feedforward=embed_dim * 4,
            activation='gelu',
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, fused, b1_int, b2_int):
        B = fused.size(0)

        # Project inputs and add positional embedding
        fused_embed = self.proj_fused(fused) + self.pos_embed
        b1_embed = self.proj_b1_int(b1_int) + self.pos_embed
        b2_embed = self.proj_b2_int(b2_int) + self.pos_embed

        # Flatten spatial dimensions (B, C, H, W) -> (B, H*W, C)
        def flatten_spatial(x):
            B, C, H, W = x.shape
            return x.flatten(2).permute(0, 2, 1)  # shape: (B, H*W, C)

        fused_seq = flatten_spatial(fused_embed)
        b1_seq = flatten_spatial(b1_embed)
        b2_seq = flatten_spatial(b2_embed)

        # Add path embeddings broadcasted across batch and sequence length
        fused_seq = fused_seq + self.path_embed[0]
        b1_seq = b1_seq + self.path_embed[1]
        b2_seq = b2_seq + self.path_embed[2]

        # Concatenate sequences along token dimension (B, 3*H*W, embed_dim)
        combined_seq = torch.cat([fused_seq, b1_seq, b2_seq], dim=1)

        # Transformer expects input shape (sequence length, batch, embedding dim)
        combined_seq = combined_seq.permute(1, 0, 2)

        # Pass through transformer encoder
        transformed_seq = self.transformer_encoder(combined_seq)

        # Back to (batch, sequence length, embedding dim)
        transformed_seq = transformed_seq.permute(1, 0, 2)

        # Layer norm on transformer output
        transformed_seq = self.layer_norm(transformed_seq)

        return transformed_seq  # Shape: (B, 3*H*W, embed_dim)

In [None]:
class ResnetBranchHybridTransformer(L.LightningModule):
    def __init__(self, backbone, transformer, num_classes=7,
                 lr=1e-3, weight_decay=1e-4, max_epochs=40, min_lr=1e-6):
        super().__init__()
        self.save_hyperparameters(ignore=['backbone', 'transformer'])

        self.backbone = backbone
        self.transformer = transformer
        self.num_classes = num_classes

        # Set classifier input size from transformer params
        embed_dim = transformer.layer_norm.normalized_shape[0]
        H, W = transformer.height, transformer.width
        self.classifier = nn.Linear(embed_dim * 3 * H * W, num_classes)

        # Metrics
        self.train_acc = MulticlassAccuracy(num_classes=num_classes, average='weighted')
        self.val_acc = MulticlassAccuracy(num_classes=num_classes, average='weighted')
        self.train_f1 = MulticlassF1Score(num_classes=num_classes, average='weighted')
        self.val_f1 = MulticlassF1Score(num_classes=num_classes, average='weighted')

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        fused, b1_int, b2_int = self.backbone(x)
        transformer_out = self.transformer(fused, b1_int, b2_int)  # [B, 3*H*W, embed_dim]
        B = transformer_out.size(0)
        logits = self.classifier(transformer_out.view(B, -1))
        return logits

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.train_acc.update(preds, y)
        self.train_f1.update(preds, y)
        self.log('train_loss', loss, prog_bar=True, on_epoch=True)
        return loss

    def on_train_epoch_end(self):
        avg_loss = self.trainer.callback_metrics.get('train_loss')
        acc = self.train_acc.compute().item()
        f1 = self.train_f1.compute().item()

        print(f"[Epoch {self.current_epoch}] "
              f"Train Loss: {avg_loss:.4f} | "
              f"Train Acc: {acc:.4f} | "
              f"Train F1: {f1:.4f}")

        self.train_acc.reset()
        self.train_f1.reset()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.val_acc.update(preds, y)
        self.val_f1.update(preds, y)
        self.log('val_loss', loss, prog_bar=True, on_epoch=True)
        self.log('val_acc', self.val_acc, prog_bar=True, on_epoch=True)
        self.log('val_f1', self.val_f1, prog_bar=True, on_epoch=True)
        return loss

    def on_validation_epoch_end(self):
        avg_loss = self.trainer.callback_metrics.get('val_loss')
        acc = self.val_acc.compute().item()
        f1 = self.val_f1.compute().item()

        print(f"[Epoch {self.current_epoch}] "
              f"Val Loss: {avg_loss:.4f} | "
              f"Val Acc: {acc:.4f} | "
              f"Val F1: {f1:.4f}")

        self.val_acc.reset()
        self.val_f1.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay
        )
        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=self.hparams.max_epochs,
            eta_min=self.hparams.min_lr
        )
        return {
            "optimizer": optimizer,
            'lr_scheduler': {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1
            }
        }

In [89]:
from lightning.pytorch.callbacks import ModelCheckpoint

backbone = ResnetBranchModel(in_channels=3)
# Instantiate the MultiInputTransformerEncoder
# Input channel details from your model:
# - fused features channels = 256 (128 from each branch concatenated)
# - intermediate features channels = 128 (each branch outputs 128 channels intermediate)
transformer = MultiInputTransformerEncoder(
    input_channels_fused=256,
    input_channels_int=128,
    embed_dim=256,       # embedding dimension for transformer tokens
    num_heads=8,         # number of attention heads
    num_layers=4,        # number of transformer encoder layers
    height=14,           # spatial height of feature maps
    width=14             # spatial width of feature maps
)
model = ResnetBranchHybridTransformer(
    backbone=backbone,
    transformer=transformer,
    num_classes=7
)
checkpoint_cb = ModelCheckpoint(
    monitor='val_loss',
    mode='max',
    save_top_k=1,
    dirpath='../models/checkpoints_',
    filename='best_model'
)
trainer = L.Trainer(
    accelerator='gpu',
    max_epochs=40,
    callbacks=[checkpoint_cb],
    log_every_n_steps=1,
    logger=False,
    devices=1,
    precision='16-mixed',
    num_sanity_val_steps=0
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [90]:
trainer.fit(
    model,
    train_loader,
    val_loader
)

d:\Coding\emotion_project\venv\Lib\site-packages\lightning\pytorch\callbacks\model_checkpoint.py:658: Checkpoint directory D:\Coding\emotion_project\models\checkpoints_ exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type                         | Params | Mode 
---------------------------------------------------------------------
0 | backbone    | ResnetBranchModel            | 1.4 M  | train
1 | transformer | MultiInputTransformerEncoder | 3.3 M  | train
2 | classifier  | Linear                       | 1.1 M  | train
3 | train_acc   | MulticlassAccuracy           | 0      | train
4 | val_acc     | MulticlassAccuracy           | 0      | train
5 | train_f1    | MulticlassF1Score            | 0      | train
6 | val_f1      | MulticlassF1Score            | 0      | train
7 | loss_fn     | CrossEntropyLoss             | 0      | train
---------------------------------------------------------------------
5.8 M     Trainable params
0         Non-traina

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

TypeError: unsupported format string passed to NoneType.__format__

In [None]:
best_model_path = checkpoint_cb.best_model_path
print(f"Best model saved at: {best_model_path}")

best_model = ResnetBranchHybridTransformer.load_from_checkpoint(
    best_model_path,
    backbone=backbone,
    transformer=transformer,
    num_classes=7
)
dummy_input = torch.randn(1, 3, 224, 224)
best_model.to_onnx('../models/resnet_branch_hybrid.onnx', dummy_input, export_params=True)