In [2]:
import os
import re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
import numpy as np
from transformers import DistilBertTokenizer, DistilBertModel
from sklearn.metrics import confusion_matrix, accuracy_score

# ======================
# 1. Custom Multimodal Dataset
# ======================
class MultiModalDataset(Dataset):
    def __init__(self, root_dir, tokenizer, max_len, image_transform=None):
        """
        Expects a folder structure:
            root_dir/
                class1/
                    image1.jpg
                    image2.jpg
                    ...
                class2/
                    ...
        The text is extracted from the file name (by removing extension, replacing underscores, etc.)
        """
        self.root_dir = root_dir
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.image_transform = image_transform
        self.samples = []  # each element is a tuple (image_path, text, label)
        self.class_folders = sorted(os.listdir(root_dir))
        self.label_map = {cls: idx for idx, cls in enumerate(self.class_folders)}
        
        for cls in self.class_folders:
            cls_path = os.path.join(root_dir, cls)
            if os.path.isdir(cls_path):
                # sort files to ensure consistent order
                for file in sorted(os.listdir(cls_path)):
                    file_path = os.path.join(cls_path, file)
                    if os.path.isfile(file_path):
                        # Extract text from file name
                        file_name_no_ext, _ = os.path.splitext(file)
                        text = file_name_no_ext.replace('_', ' ')
                        text = re.sub(r'\d+', '', text)
                        self.samples.append((file_path, text, self.label_map[cls]))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        image_path, text, label = self.samples[idx]
        # Load and transform image
        image = Image.open(image_path).convert('RGB')
        if self.image_transform:
            image = self.image_transform(image)
            
        # Tokenize text
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        
        return {
            'image': image,
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# ======================
# 2. Multimodal Model Definition (Gated Fusion with ResNet50)
# ======================
class MultiModalClassifier(nn.Module):
    def __init__(self, num_classes, fusion_dim=512, text_model_name='distilbert-base-uncased', freeze_image_layers=True):
        """
        fusion_dim: common embedding dimension for both modalities.
        freeze_image_layers: if True, freeze all image model parameters except those in layer4.
        """
        super(MultiModalClassifier, self).__init__()
        self.fusion_dim = fusion_dim
        
        # ---- Text model: DistilBERT ----
        self.text_model = DistilBertModel.from_pretrained(text_model_name)
        self.text_projection = nn.Linear(self.text_model.config.hidden_size, fusion_dim)
        self.text_dropout = nn.Dropout(0.5)
        
        # ---- Image model: ResNet18 ----
        self.image_model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        self.image_model.fc = nn.Identity()
        if freeze_image_layers:
            for name, param in self.image_model.named_parameters():
                if not name.startswith('layer4'):
                    param.requires_grad = False
        if 512 != fusion_dim:
            self.image_projection = nn.Linear(512, fusion_dim)
        else:
            self.image_projection = nn.Identity()
        self.image_dropout = nn.Dropout(0.5)
        
        # ---- Intermediate Fusion: Concatenation + MLP ----
        # Concatenates text and image features (resulting in 2*fusion_dim) and fuses via MLP.
        self.fusion_mlp = nn.Sequential(
            nn.Linear(fusion_dim * 2, fusion_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(fusion_dim, fusion_dim)
        )
        
        # ---- Classification Head ----
        self.classifier = nn.Linear(fusion_dim, num_classes)
    
    def forward(self, input_ids, attention_mask, image):
        # --- Process text ---
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        text_feat = text_outputs.last_hidden_state[:, 0, :]  # Use the [CLS] token representation
        text_feat = self.text_dropout(text_feat)
        text_feat = self.text_projection(text_feat)  # (batch, fusion_dim)
        
        # --- Process image ---
        image_feat = self.image_model(image)  # (batch, 512)
        image_feat = self.image_dropout(image_feat)
        image_feat = self.image_projection(image_feat)  # (batch, fusion_dim)
        
        # --- Intermediate Fusion: Concatenation + MLP ---
        # Concatenate along the feature dimension
        multimodal_feats = torch.cat((text_feat, image_feat), dim=1)  # (batch, fusion_dim*2)
        fused_feat = self.fusion_mlp(multimodal_feats)  # (batch, fusion_dim)
        
        # --- Classification ---
        logits = self.classifier(fused_feat)
        return logits

# ======================
# 3. Data Preparation with Enhanced Augmentation
# ======================
train_image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
val_test_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
max_len = 24  # as before

# Define dataset paths.
TRAIN_PATH = './CVPR_2024_dataset_Train'
VAL_PATH = './CVPR_2024_dataset_Val'
TEST_PATH = './CVPR_2024_dataset_Test'

train_dataset = MultiModalDataset(TRAIN_PATH, tokenizer, max_len, image_transform=train_image_transform)
val_dataset = MultiModalDataset(VAL_PATH, tokenizer, max_len, image_transform=val_test_transform)
test_dataset = MultiModalDataset(TEST_PATH, tokenizer, max_len, image_transform=val_test_transform)

# ======================
# Create DataLoaders with Balancing for Training Data
# ======================
# Compute sample weights based on class frequencies in the training set.
train_labels = [sample[2] for sample in train_dataset.samples]
class_counts = np.bincount(train_labels)
class_weights = 1.0 / class_counts  # Inverse frequency for each class.
sample_weights = [class_weights[label] for label in train_labels]

# Create the sampler for balanced sampling.
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

# Use the sampler in the training DataLoader.
train_loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# ======================
# 4. Model, Optimizer, and Loss Function Setup
# ======================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 4

# Instantiate the gating fusion model with ResNet50.
model = MultiModalClassifierGated(num_classes=num_classes, fusion_dim=512, freeze_image_layers=True)
model = model.to(device)

# Use weight decay for regularization.
optimizer = optim.Adam(model.parameters(), lr=2e-5, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

# Learning rate scheduler (without verbose parameter).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)

# ======================
# 5. Training and Evaluation Loops
# ======================
from tqdm import tqdm

def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    running_corrects = 0
    total_samples = 0

    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    
    for batch in progress_bar:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        images = batch['image'].to(device)
        labels = batch['label'].to(device)
        
        outputs = model(input_ids, attention_mask, images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        running_corrects += torch.sum(preds == labels).item()
        total_samples += labels.size(0)
        progress_bar.set_postfix(loss=loss.item(), accuracy=running_corrects / total_samples)
    
    epoch_loss = total_loss / len(dataloader)
    epoch_acc = running_corrects / total_samples
    return epoch_loss, epoch_acc

def evaluate_model(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            images = batch['image'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(input_ids, attention_mask, images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            
            preds = outputs.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            
    avg_loss = total_loss / len(dataloader)
    acc = accuracy_score(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds)
    return avg_loss, acc, cm

EPOCHS = 4
best_val_loss = float('inf')
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, val_cm = evaluate_model(model, val_loader, criterion, device)
    
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
    print(f"  Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc*100:.2f}%")
    print(f"  Confusion Matrix:\n{val_cm}\n")
    
    scheduler.step(val_loss)
    print("Current LR:", scheduler.optimizer.param_groups[0]['lr'])
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_multimodal_model.pth')

test_loss, test_acc, test_cm = evaluate_model(model, test_loader, criterion, device)
print("Test Results")
print(f"  Test Loss:   {test_loss:.4f} | Accuracy: {test_acc*100:.2f}%")
print(f"  Confusion Matrix:\n{test_cm}")


Epoch 1/4


                                                                                                                        

  Train Loss: 0.5016 | Train Acc: 82.26%
  Val   Loss: 0.3525 | Val   Acc: 87.22%
  Confusion Matrix:
[[326  23   1  22]
 [ 67 653  36  12]
 [ 10  13 324   5]
 [ 16  23   2 267]]

Current LR: 2e-05
Epoch 2/4


                                                                                                                        

  Train Loss: 0.2522 | Train Acc: 91.19%
  Val   Loss: 0.3399 | Val   Acc: 87.33%
  Confusion Matrix:
[[333  13   5  21]
 [ 71 639  40  18]
 [  7  12 329   4]
 [ 20  14   3 271]]

Current LR: 2e-05
Epoch 3/4


                                                                                                                        

  Train Loss: 0.1975 | Train Acc: 93.26%
  Val   Loss: 0.3097 | Val   Acc: 90.22%
  Confusion Matrix:
[[329  23   1  19]
 [ 35 700  12  21]
 [ 10  21 316   5]
 [ 13  15   1 279]]

Current LR: 2e-05
Epoch 4/4


                                                                                                                        

  Train Loss: 0.1627 | Train Acc: 94.52%
  Val   Loss: 0.3213 | Val   Acc: 89.22%
  Confusion Matrix:
[[311  29   9  23]
 [ 27 692  26  23]
 [  5  18 325   4]
 [ 12  14   4 278]]

Current LR: 2e-05
Test Results
  Test Loss:   0.4562 | Accuracy: 85.40%
  Confusion Matrix:
[[475  98  36  86]
 [ 47 967  47  25]
 [ 12  20 764   3]
 [ 52  59  16 725]]


Epoch 1/4

                                                                                                                        

  Train Loss: 0.5016 | Train Acc: 82.26%
  Val   Loss: 0.3525 | Val   Acc: 87.22%
  Confusion Matrix:
[[326  23   1  22]
 [ 67 653  36  12]
 [ 10  13 324   5]
 [ 16  23   2 267]]

Current LR: 2e-05
Epoch 2/4

                                                                                                                        

  Train Loss: 0.2522 | Train Acc: 91.19%
  Val   Loss: 0.3399 | Val   Acc: 87.33%
  Confusion Matrix:
[[333  13   5  21]
 [ 71 639  40  18]
 [  7  12 329   4]
 [ 20  14   3 271]]

Current LR: 2e-05
Epoch 3/4

                                                                                                                        

  Train Loss: 0.1975 | Train Acc: 93.26%
  Val   Loss: 0.3097 | Val   Acc: 90.22%
  Confusion Matrix:
[[329  23   1  19]
 [ 35 700  12  21]
 [ 10  21 316   5]
 [ 13  15   1 279]]

Current LR: 2e-05
Epoch 4/4

                                                                                                                        

  Train Loss: 0.1627 | Train Acc: 94.52%
  Val   Loss: 0.3213 | Val   Acc: 89.22%
  Confusion Matrix:
[[311  29   9  23]
 [ 27 692  26  23]
 [  5  18 325   4]
 [ 12  14   4 278]]

Current LR: 2e-05
Test Results
  Test Loss:   0.4562 | Accuracy: 85.40%
  Confusion Matrix:
[[475  98  36  86]
 [ 47 967  47  25]
 [ 12  20 764   3]
 [ 52  59  16 725]]


In [11]:
import os
import re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import numpy as np
import torch
import torchvision.models as models
from transformers import DistilBertTokenizer, DistilBertModel
from sklearn.metrics import confusion_matrix, accuracy_score

# ======================
# 1. Custom Multimodal Dataset
# ======================
class MultiModalDataset(Dataset):
    def __init__(self, root_dir, tokenizer, max_len, image_transform=None):
        """
        Expects a folder structure:
            root_dir/
                class1/
                    image1.jpg
                    image2.jpg
                    ...
                class2/
                    ...
        The text is extracted from the file name (by removing extension, replacing underscores, etc.)
        """
        self.root_dir = root_dir
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.image_transform = image_transform
        self.samples = []  # each element is a tuple (image_path, text, label)
        self.class_folders = sorted(os.listdir(root_dir))
        self.label_map = {cls: idx for idx, cls in enumerate(self.class_folders)}
        
        for cls in self.class_folders:
            cls_path = os.path.join(root_dir, cls)
            if os.path.isdir(cls_path):
                # sort files to ensure consistent order
                for file in sorted(os.listdir(cls_path)):
                    file_path = os.path.join(cls_path, file)
                    if os.path.isfile(file_path):
                        # Extract text from file name
                        file_name_no_ext, _ = os.path.splitext(file)
                        text = file_name_no_ext.replace('_', ' ')
                        text = re.sub(r'\d+', '', text)
                        self.samples.append((file_path, text, self.label_map[cls]))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        image_path, text, label = self.samples[idx]
        # Load and transform image
        image = Image.open(image_path).convert('RGB')
        if self.image_transform:
            image = self.image_transform(image)
            
        # Tokenize text
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        
        return {
            'image': image,
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# ======================
# 2. Multimodal Model Definition
# ======================
class MultiModalClassifier(nn.Module):
    def __init__(self, num_classes, fusion_dim=512, text_model_name='distilbert-base-uncased', freeze_image_layers=True):
        """
        fusion_dim: common embedding dimension for both modalities.
        freeze_image_layers: if True, freeze most image model parameters except those in the last transformer block.
        """
        super(MultiModalClassifier, self).__init__()
        self.fusion_dim = fusion_dim
        
        # ---- Text model: DistilBERT ----
        self.text_model = DistilBertModel.from_pretrained(text_model_name)
        # Project DistilBERT's output (768-dim) to fusion_dim.
        self.text_projection = nn.Linear(self.text_model.config.hidden_size, fusion_dim)
        self.text_dropout = nn.Dropout(0.3)
        
        # ---- Image model: Vision Transformer (ViT) ----
        self.image_model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
        # Remove the classification head to obtain features.
        self.image_model.heads = nn.Identity()
        
        if freeze_image_layers:
            # Freeze all parameters except for those in the final transformer block.
            # Note: vit_b_16 has 12 transformer blocks stored in encoder.layers.
            last_block_index = len(self.image_model.encoder.layers) - 1
            for name, param in self.image_model.named_parameters():
                if 'encoder.layers' in name:
                    # Expected format: "encoder.layers.encoder_layer_<index>...."
                    block_index = int(name.split('.')[2].split('_')[-1])
                    if block_index != last_block_index:
                        param.requires_grad = False
                else:
                    # Optionally freeze non-encoder parameters (e.g. patch embedding, positional encoding)
                    param.requires_grad = False


        
        # The ViT model outputs embeddings of size 768 by default.
        if 768 != fusion_dim:
            self.image_projection = nn.Linear(768, fusion_dim)
        else:
            self.image_projection = nn.Identity()
        self.image_dropout = nn.Dropout(0.3)
        
        # ---- Attention-based Fusion ----
        # We stack text and image embeddings (each of size fusion_dim) into a sequence of 2 tokens.
        self.attention = nn.MultiheadAttention(embed_dim=fusion_dim, num_heads=1, batch_first=True)
        self.fusion_dropout = nn.Dropout(0.3)
        
        # ---- Classification Head ----
        self.classifier = nn.Linear(fusion_dim, num_classes)
    
    def forward(self, input_ids, attention_mask, image):
        # --- Process text ---
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the [CLS] token representation (first token) as text feature.
        text_feat = text_outputs.last_hidden_state[:, 0, :]  # (batch, 768)
        text_feat = self.text_dropout(text_feat)
        text_feat = self.text_projection(text_feat)  # (batch, fusion_dim)
        
        # --- Process image ---
        image_feat = self.image_model(image)  # (batch, 768)
        image_feat = self.image_dropout(image_feat)
        image_feat = self.image_projection(image_feat)  # (batch, fusion_dim)
        
        # --- Stack features and apply attention fusion ---
        # Create a sequence of two tokens per sample: one from text and one from image.
        multimodal_feats = torch.stack([text_feat, image_feat], dim=1)  # (batch, 2, fusion_dim)
        # Self-attention: query, key, and value are the same.
        attn_output, _ = self.attention(multimodal_feats, multimodal_feats, multimodal_feats)
        # Aggregate the two tokens by averaging over the sequence dimension.
        fused_feat = attn_output.mean(dim=1)  # (batch, fusion_dim)
        fused_feat = self.fusion_dropout(fused_feat)
        
        # --- Classification ---
        logits = self.classifier(fused_feat)
        return logits

# ======================
# 3. Data Preparation
# ======================
# Define transforms for training and validation (you can adjust augmentation as needed)
train_image_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
val_test_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Instantiate tokenizer for DistilBERT.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
max_len = 24  # as in your example

# Define paths to your train, validation, and test folders.
TRAIN_PATH = './CVPR_2024_dataset_Train'
VAL_PATH = './CVPR_2024_dataset_Val'
TEST_PATH = './CVPR_2024_dataset_Test'

# Create dataset instances.
train_dataset = MultiModalDataset(TRAIN_PATH, tokenizer, max_len, image_transform=train_image_transform)
val_dataset = MultiModalDataset(VAL_PATH, tokenizer, max_len, image_transform=val_test_transform)
test_dataset = MultiModalDataset(TEST_PATH, tokenizer, max_len, image_transform=val_test_transform)

# Create dataloaders.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# ======================
# 4. Model, Optimizer, and Loss Function Setup
# ======================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 4
model = MultiModalClassifier(num_classes=num_classes, fusion_dim=512, freeze_image_layers=True)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

# ======================
# 5. Training and Evaluation Loops
# ======================
from tqdm import tqdm

def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    running_corrects = 0
    total_samples = 0

    # Initialize tqdm progress bar
    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    
    for batch in progress_bar:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        images = batch['image'].to(device)
        labels = batch['label'].to(device)
        
        outputs = model(input_ids, attention_mask, images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        # Update running loss
        total_loss += loss.item()
        
        # Calculate batch accuracy
        preds = outputs.argmax(dim=1)
        running_corrects += torch.sum(preds == labels).item()
        total_samples += labels.size(0)
        
        # Update progress bar with current loss and accuracy
        progress_bar.set_postfix(loss=loss.item(), accuracy=running_corrects / total_samples)
    
    epoch_loss = total_loss / len(dataloader)
    epoch_acc = running_corrects / total_samples
    return epoch_loss, epoch_acc

def evaluate_model(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            images = batch['image'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(input_ids, attention_mask, images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            
            preds = outputs.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            
    avg_loss = total_loss / len(dataloader)
    acc = accuracy_score(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds)
    return avg_loss, acc, cm

# Updated training loop with progress evaluation
EPOCHS = 4
best_val_loss = float('inf')
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, val_cm = evaluate_model(model, val_loader, criterion, device)
    
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
    print(f"  Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc*100:.2f}%")
    print(f"  Confusion Matrix:\n{val_cm}\n")
    
    # Save the best model based on validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_multimodal_model.pth')

# Optionally, run on test set.
test_loss, test_acc, test_cm = evaluate_model(model, test_loader, criterion, device)
print("Test Results")
print(f"  Test Loss:   {test_loss:.4f} | Accuracy: {test_acc*100:.2f}%")
print(f"  Confusion Matrix:\n{test_cm}")


Epoch 1/4


                                                                                                                        

  Train Loss: 0.5373 | Train Acc: 79.74%
  Val   Loss: 0.3192 | Val   Acc: 88.44%
  Confusion Matrix:
[[299  47   4  22]
 [ 31 711  15  11]
 [  6  21 322   3]
 [ 24  22   2 260]]

Epoch 2/4


                                                                                                                        

  Train Loss: 0.2649 | Train Acc: 90.67%
  Val   Loss: 0.2952 | Val   Acc: 89.78%
  Confusion Matrix:
[[318  34   1  19]
 [ 33 711   8  16]
 [  4  26 317   5]
 [ 21  17   0 270]]

Epoch 3/4


                                                                                                                        

  Train Loss: 0.1989 | Train Acc: 92.98%
  Val   Loss: 0.3014 | Val   Acc: 89.94%
  Confusion Matrix:
[[333  23   1  15]
 [ 42 695  14  17]
 [  7  21 321   3]
 [ 22  16   0 270]]

Epoch 4/4


                                                                                                                        

  Train Loss: 0.1586 | Train Acc: 94.27%
  Val   Loss: 0.2869 | Val   Acc: 90.33%
  Confusion Matrix:
[[329  29   1  13]
 [ 30 710  17  11]
 [  4  22 322   4]
 [ 21  22   0 265]]

Test Results
  Test Loss:   0.4506 | Accuracy: 85.23%
  Confusion Matrix:
[[516 108  15  56]
 [ 46 989  38  13]
 [ 24  26 745   4]
 [ 69  90  18 675]]


Epoch 1/4

                                                                                                                        

  Train Loss: 0.5373 | Train Acc: 79.74%
  Val   Loss: 0.3192 | Val   Acc: 88.44%
  Confusion Matrix:
[[299  47   4  22]
 [ 31 711  15  11]
 [  6  21 322   3]
 [ 24  22   2 260]]

Epoch 2/4

                                                                                                                        

  Train Loss: 0.2649 | Train Acc: 90.67%
  Val   Loss: 0.2952 | Val   Acc: 89.78%
  Confusion Matrix:
[[318  34   1  19]
 [ 33 711   8  16]
 [  4  26 317   5]
 [ 21  17   0 270]]

Epoch 3/4

                                                                                                                        

  Train Loss: 0.1989 | Train Acc: 92.98%
  Val   Loss: 0.3014 | Val   Acc: 89.94%
  Confusion Matrix:
[[333  23   1  15]
 [ 42 695  14  17]
 [  7  21 321   3]
 [ 22  16   0 270]]

Epoch 4/4

                                                                                                                        

  Train Loss: 0.1586 | Train Acc: 94.27%
  Val   Loss: 0.2869 | Val   Acc: 90.33%
  Confusion Matrix:
[[329  29   1  13]
 [ 30 710  17  11]
 [  4  22 322   4]
 [ 21  22   0 265]]

Test Results
  Test Loss:   0.4506 | Accuracy: 85.23%
  Confusion Matrix:
[[516 108  15  56]
 [ 46 989  38  13]
 [ 24  26 745   4]
 [ 69  90  18 675]]


In [13]:
import os
import re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import numpy as np
from transformers import DistilBertTokenizer, DistilBertModel
from sklearn.metrics import confusion_matrix, accuracy_score

# ======================
# 1. Custom Multimodal Dataset
# ======================
class MultiModalDataset(Dataset):
    def __init__(self, root_dir, tokenizer, max_len, image_transform=None):
        """
        Expects a folder structure:
            root_dir/
                class1/
                    image1.jpg
                    image2.jpg
                    ...
                class2/
                    ...
        The text is extracted from the file name (by removing extension, replacing underscores, etc.)
        """
        self.root_dir = root_dir
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.image_transform = image_transform
        self.samples = []  # each element is a tuple (image_path, text, label)
        self.class_folders = sorted(os.listdir(root_dir))
        self.label_map = {cls: idx for idx, cls in enumerate(self.class_folders)}
        
        for cls in self.class_folders:
            cls_path = os.path.join(root_dir, cls)
            if os.path.isdir(cls_path):
                # sort files to ensure consistent order
                for file in sorted(os.listdir(cls_path)):
                    file_path = os.path.join(cls_path, file)
                    if os.path.isfile(file_path):
                        # Extract text from file name
                        file_name_no_ext, _ = os.path.splitext(file)
                        text = file_name_no_ext.replace('_', ' ')
                        text = re.sub(r'\d+', '', text)
                        self.samples.append((file_path, text, self.label_map[cls]))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        image_path, text, label = self.samples[idx]
        # Load and transform image
        image = Image.open(image_path).convert('RGB')
        if self.image_transform:
            image = self.image_transform(image)
            
        # Tokenize text
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        
        return {
            'image': image,
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# ======================
# 2. Multimodal Model Definition
# ======================
class MultiModalClassifier(nn.Module):
    def __init__(self, num_classes, fusion_dim=512, text_model_name='distilbert-base-uncased', freeze_image_layers=True):
        """
        fusion_dim: common embedding dimension for both modalities.
        freeze_image_layers: if True, freeze most image model parameters except for the last two transformer blocks.
        """
        super(MultiModalClassifier, self).__init__()
        self.fusion_dim = fusion_dim
        
        # ---- Text model: DistilBERT ----
        self.text_model = DistilBertModel.from_pretrained(text_model_name)
        self.text_projection = nn.Linear(self.text_model.config.hidden_size, fusion_dim)
        self.text_dropout = nn.Dropout(0.3)
        
        # ---- Image model: Vision Transformer (ViT) ----
        self.image_model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
        self.image_model.heads = nn.Identity()  # Remove classification head
        
        if freeze_image_layers:
            # Unfreeze the last two transformer blocks; freeze earlier ones.
            last_block_index = len(self.image_model.encoder.layers) - 1
            second_last_index = last_block_index - 1
            for name, param in self.image_model.named_parameters():
                if 'encoder.layers' in name:
                    # Expecting names like "encoder.layers.encoder_layer_0...."
                    block_str = name.split('.')[2]
                    if block_str.startswith('encoder_layer_'):
                        block_index = int(block_str.split('_')[-1])
                    else:
                        block_index = int(block_str)
                    if block_index not in [second_last_index, last_block_index]:
                        param.requires_grad = False
                else:
                    param.requires_grad = False
        
        if 768 != fusion_dim:
            self.image_projection = nn.Linear(768, fusion_dim)
        else:            self.image_projection = nn.Identity()
        self.image_dropout = nn.Dropout(0.3)
        
        # ---- Gating Mechanism for Fusion ----
        self.gate = nn.Sequential(
            nn.Linear(fusion_dim * 2, fusion_dim),
            nn.ReLU(),
            nn.Linear(fusion_dim, fusion_dim),
            nn.Sigmoid()
        )
        self.fusion_dropout = nn.Dropout(0.3)
        
        # ---- Classification Head ----
        self.classifier = nn.Linear(fusion_dim, num_classes)
    
    def forward(self, input_ids, attention_mask, image):
        # Process text
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        text_feat = text_outputs.last_hidden_state[:, 0, :]  # Use [CLS] token
        text_feat = self.text_dropout(text_feat)
        text_feat = self.text_projection(text_feat)
        
        # Process image
        image_feat = self.image_model(image)
        image_feat = self.image_dropout(image_feat)
        image_feat = self.image_projection(image_feat)
        
        # Gated Fusion
        gate_weights = self.gate(torch.cat([text_feat, image_feat], dim=1))
        fused_feat = gate_weights * text_feat + (1 - gate_weights) * image_feat
        fused_feat = self.fusion_dropout(fused_feat)
        
        # Classification
        logits = self.classifier(fused_feat)
        return logits


# ======================
# 3. Data Preparation
# ======================
# Define transforms for training and validation (adjust augmentations as needed)
train_image_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
val_test_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Instantiate tokenizer for DistilBERT.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
max_len = 24  # as in your example

# Define paths to your train, validation, and test folders.
TRAIN_PATH = './CVPR_2024_dataset_Train'
VAL_PATH = './CVPR_2024_dataset_Val'
TEST_PATH = './CVPR_2024_dataset_Test'

# Create dataset instances.
train_dataset = MultiModalDataset(TRAIN_PATH, tokenizer, max_len, image_transform=train_image_transform)
val_dataset = MultiModalDataset(VAL_PATH, tokenizer, max_len, image_transform=val_test_transform)
test_dataset = MultiModalDataset(TEST_PATH, tokenizer, max_len, image_transform=val_test_transform)

# Create dataloaders.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# ======================
# 4. Model, Optimizer, and Loss Function Setup
# ======================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 4
model = MultiModalClassifier(num_classes=num_classes, fusion_dim=512, freeze_image_layers=True)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

# ======================
# 5. Training and Evaluation Loops
# ======================
from tqdm import tqdm

def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    running_corrects = 0
    total_samples = 0

    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    for batch in progress_bar:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        images = batch['image'].to(device)
        labels = batch['label'].to(device)
        
        outputs = model(input_ids, attention_mask, images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        running_corrects += torch.sum(preds == labels).item()
        total_samples += labels.size(0)
        progress_bar.set_postfix(loss=loss.item(), accuracy=running_corrects / total_samples)
    
    epoch_loss = total_loss / len(dataloader)
    epoch_acc = running_corrects / total_samples
    return epoch_loss, epoch_acc

def evaluate_model(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            images = batch['image'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(input_ids, attention_mask, images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            
            preds = outputs.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            
    avg_loss = total_loss / len(dataloader)
    acc = accuracy_score(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds)
    return avg_loss, acc, cm

# Training loop
EPOCHS = 4
best_val_loss = float('inf')
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, val_cm = evaluate_model(model, val_loader, criterion, device)
    
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
    print(f"  Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc*100:.2f}%")
    print(f"  Confusion Matrix:\n{val_cm}\n")
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_multimodal_model.pth')

# Evaluate on test set.
test_loss, test_acc, test_cm = evaluate_model(model, test_loader, criterion, device)
print("Test Results")
print(f"  Test Loss:   {test_loss:.4f} | Accuracy: {test_acc*100:.2f}%")
print(f"  Confusion Matrix:\n{test_cm}")


Epoch 1/4


                                                                                                                        

  Train Loss: 0.5098 | Train Acc: 80.93%
  Val   Loss: 0.3196 | Val   Acc: 88.61%
  Confusion Matrix:
[[305  46   1  20]
 [ 31 716  15   6]
 [  9  23 317   3]
 [ 24  27   0 257]]

Epoch 2/4


                                                                                                                        

  Train Loss: 0.2649 | Train Acc: 90.41%
  Val   Loss: 0.2830 | Val   Acc: 89.94%
  Confusion Matrix:
[[327  26   2  17]
 [ 35 710  12  11]
 [  5  21 322   4]
 [ 29  19   0 260]]

Epoch 3/4


                                                                                                                        

  Train Loss: 0.2037 | Train Acc: 92.89%
  Val   Loss: 0.3050 | Val   Acc: 90.06%
  Confusion Matrix:
[[333  28   1  10]
 [ 35 707  18   8]
 [  4  19 326   3]
 [ 34  19   0 255]]

Epoch 4/4


                                                                                                                        

  Train Loss: 0.1477 | Train Acc: 94.81%
  Val   Loss: 0.2943 | Val   Acc: 90.11%
  Confusion Matrix:
[[314  36   4  18]
 [ 24 717  15  12]
 [  2  20 326   4]
 [ 21  21   1 265]]

Test Results
  Test Loss:   0.4373 | Accuracy: 85.69%
  Confusion Matrix:
[[ 499  121   19   56]
 [  30 1002   39   15]
 [  17   25  751    6]
 [  58   85   20  689]]


Epoch 1/4

                                                                                                                        

  Train Loss: 0.5098 | Train Acc: 80.93%
  Val   Loss: 0.3196 | Val   Acc: 88.61%
  Confusion Matrix:
[[305  46   1  20]
 [ 31 716  15   6]
 [  9  23 317   3]
 [ 24  27   0 257]]

Epoch 2/4

                                                                                                                        

  Train Loss: 0.2649 | Train Acc: 90.41%
  Val   Loss: 0.2830 | Val   Acc: 89.94%
  Confusion Matrix:
[[327  26   2  17]
 [ 35 710  12  11]
 [  5  21 322   4]
 [ 29  19   0 260]]

Epoch 3/4

                                                                                                                        

  Train Loss: 0.2037 | Train Acc: 92.89%
  Val   Loss: 0.3050 | Val   Acc: 90.06%
  Confusion Matrix:
[[333  28   1  10]
 [ 35 707  18   8]
 [  4  19 326   3]
 [ 34  19   0 255]]

Epoch 4/4

                                                                                                                        

  Train Loss: 0.1477 | Train Acc: 94.81%
  Val   Loss: 0.2943 | Val   Acc: 90.11%
  Confusion Matrix:
[[314  36   4  18]
 [ 24 717  15  12]
 [  2  20 326   4]
 [ 21  21   1 265]]

Test Results
  Test Loss:   0.4373 | Accuracy: 85.69%
  Confusion Matrix:
[[ 499  121   19   56]
 [  30 1002   39   15]
 [  17   25  751    6]
 [  58   85   20  689]]
