In [None]:
# self attentition

In [2]:
import os
import re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models import resnet18, ResNet18_Weights
from PIL import Image
import numpy as np
from transformers import DistilBertTokenizer, DistilBertModel
from sklearn.metrics import confusion_matrix, accuracy_score

# ======================
# 1. Custom Multimodal Dataset
# ======================
class MultiModalDataset(Dataset):
    def __init__(self, root_dir, tokenizer, max_len, image_transform=None):
        """
        Expects a folder structure:
            root_dir/
                class1/
                    image1.jpg
                    image2.jpg
                    ...
                class2/
                    ...
        The text is extracted from the file name (by removing extension, replacing underscores, etc.)
        """
        self.root_dir = root_dir
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.image_transform = image_transform
        self.samples = []  # each element is a tuple (image_path, text, label)
        self.class_folders = sorted(os.listdir(root_dir))
        self.label_map = {cls: idx for idx, cls in enumerate(self.class_folders)}
        
        for cls in self.class_folders:
            cls_path = os.path.join(root_dir, cls)
            if os.path.isdir(cls_path):
                # sort files to ensure consistent order
                for file in sorted(os.listdir(cls_path)):
                    file_path = os.path.join(cls_path, file)
                    if os.path.isfile(file_path):
                        # Extract text from file name
                        file_name_no_ext, _ = os.path.splitext(file)
                        text = file_name_no_ext.replace('_', ' ')
                        text = re.sub(r'\d+', '', text)
                        self.samples.append((file_path, text, self.label_map[cls]))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        image_path, text, label = self.samples[idx]
        # Load and transform image
        image = Image.open(image_path).convert('RGB')
        if self.image_transform:
            image = self.image_transform(image)
            
        # Tokenize text
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        
        return {
            'image': image,
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# ======================
# 2. Multimodal Model Definition
# ======================
class MultiModalClassifier(nn.Module):
    def __init__(self, num_classes, fusion_dim=512, text_model_name='distilbert-base-uncased', freeze_image_layers=True):
        """
        fusion_dim: common embedding dimension for both modalities.
        freeze_image_layers: if True, freeze all image model parameters except those in layer4.
        """
        super(MultiModalClassifier, self).__init__()
        self.fusion_dim = fusion_dim
        
        # ---- Text model: DistilBERT ----
        self.text_model = DistilBertModel.from_pretrained(text_model_name)
        # Project DistilBERT's output (768-dim) to fusion_dim.
        self.text_projection = nn.Linear(self.text_model.config.hidden_size, fusion_dim)
        self.text_dropout = nn.Dropout(0.3)
        
        # ---- Image model: ResNet18 ----
        self.image_model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        # Remove the final fc layer to obtain features (ResNet18 outputs 512-dim).
        self.image_model.fc = nn.Identity()
        if freeze_image_layers:
            # Freeze all parameters except those in layer4 (or adjust as needed)
            for name, param in self.image_model.named_parameters():
                if not name.startswith('layer4'):
                    param.requires_grad = False
        # If necessary, project image features to fusion_dim.
        if 512 != fusion_dim:
            self.image_projection = nn.Linear(512, fusion_dim)
        else:
            self.image_projection = nn.Identity()
        self.image_dropout = nn.Dropout(0.3)
        
        # ---- Attention-based Fusion ----
        # We stack text and image embeddings (each of size fusion_dim) into a sequence of 2 tokens.
        # Using a single-head self-attention (batch_first=True so that input is (B, SeqLen, fusion_dim)).
        self.attention = nn.MultiheadAttention(embed_dim=fusion_dim, num_heads=1, batch_first=True)
        self.fusion_dropout = nn.Dropout(0.3)
        
        # ---- Classification Head ----
        self.classifier = nn.Linear(fusion_dim, num_classes)
    
    def forward(self, input_ids, attention_mask, image):
        # --- Process text ---
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the [CLS] token representation (first token) as text feature.
        text_feat = text_outputs.last_hidden_state[:, 0, :]  # (batch, 768)
        text_feat = self.text_dropout(text_feat)
        text_feat = self.text_projection(text_feat)  # (batch, fusion_dim)
        
        # --- Process image ---
        image_feat = self.image_model(image)  # (batch, 512)
        image_feat = self.image_dropout(image_feat)
        image_feat = self.image_projection(image_feat)  # (batch, fusion_dim)
        
        # --- Stack features and apply attention fusion ---
        # Create a sequence of two tokens per sample: one from text and one from image.
        multimodal_feats = torch.stack([text_feat, image_feat], dim=1)  # (batch, 2, fusion_dim)
        # Self-attention: here query, key, and value are the same.
        attn_output, attn_weights = self.attention(multimodal_feats, multimodal_feats, multimodal_feats)
        # Aggregate the two tokens by averaging over the sequence dimension.
        fused_feat = attn_output.mean(dim=1)  # (batch, fusion_dim)
        fused_feat = self.fusion_dropout(fused_feat)
        
        # --- Classification ---
        logits = self.classifier(fused_feat)
        return logits

# ======================
# 3. Data Preparation
# ======================
# Define transforms for training and validation (you can adjust augmentation as needed)
train_image_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
val_test_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Instantiate tokenizer for DistilBERT.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
max_len = 24  # as in your example

# Define paths to your train, validation, and test folders.
TRAIN_PATH = './CVPR_2024_dataset_Train'
VAL_PATH = './CVPR_2024_dataset_Val'
TEST_PATH = './CVPR_2024_dataset_Test'

# Create dataset instances.
train_dataset = MultiModalDataset(TRAIN_PATH, tokenizer, max_len, image_transform=train_image_transform)
val_dataset = MultiModalDataset(VAL_PATH, tokenizer, max_len, image_transform=val_test_transform)
test_dataset = MultiModalDataset(TEST_PATH, tokenizer, max_len, image_transform=val_test_transform)

# Create dataloaders.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# ======================
# 4. Model, Optimizer, and Loss Function Setup
# ======================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 4
model = MultiModalClassifier(num_classes=num_classes, fusion_dim=512, freeze_image_layers=True)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

# ======================
# 5. Training and Evaluation Loops
# ======================
from tqdm import tqdm

def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    running_corrects = 0
    total_samples = 0

    # Initialize tqdm progress bar
    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    
    for batch in progress_bar:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        images = batch['image'].to(device)
        labels = batch['label'].to(device)
        
        outputs = model(input_ids, attention_mask, images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        # Update running loss
        total_loss += loss.item()
        
        # Calculate batch accuracy
        preds = outputs.argmax(dim=1)
        running_corrects += torch.sum(preds == labels).item()
        total_samples += labels.size(0)
        
        # Update progress bar with current loss and accuracy
        progress_bar.set_postfix(loss=loss.item(), accuracy=running_corrects / total_samples)
    
    epoch_loss = total_loss / len(dataloader)
    epoch_acc = running_corrects / total_samples
    return epoch_loss, epoch_acc

def evaluate_model(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            images = batch['image'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(input_ids, attention_mask, images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            
            preds = outputs.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            
    avg_loss = total_loss / len(dataloader)
    acc = accuracy_score(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds)
    return avg_loss, acc, cm

# Updated training loop with progress evaluation
EPOCHS = 4
best_val_loss = float('inf')
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, val_cm = evaluate_model(model, val_loader, criterion, device)
    
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
    print(f"  Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc*100:.2f}%")
    print(f"  Confusion Matrix:\n{val_cm}\n")
    
    # Save the best model based on validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_multimodal_model.pth')

# Optionally, run on test set.
test_loss, test_acc, test_cm = evaluate_model(model, test_loader, criterion, device)
print("Test Results")
print(f"  Test Loss:   {test_loss:.4f} | Accuracy: {test_acc*100:.2f}%")
print(f"  Confusion Matrix:\n{test_cm}")


Started next Epoch
Epoch 1/4
  Train Loss: 0.4570
  Val Loss:   0.3340 | Accuracy: 87.17%
  Confusion Matrix:
[[333  22   3  14]
 [ 68 674  15  11]
 [ 10  20 319   3]
 [ 49  15   1 243]]

Started next Epoch
Epoch 2/4
  Train Loss: 0.2606
  Val Loss:   0.3123 | Accuracy: 89.72%
  Confusion Matrix:
[[329  31   1  11]
 [ 35 718   7   8]
 [  7  20 321   4]
 [ 40  21   0 247]]

Started next Epoch
Epoch 3/4
  Train Loss: 0.1973
  Val Loss:   0.2871 | Accuracy: 90.11%
  Confusion Matrix:
[[328  30   2  12]
 [ 33 706  16  13]
 [  7  21 320   4]
 [ 23  17   0 268]]

Started next Epoch
Epoch 4/4
  Train Loss: 0.1560
  Val Loss:   0.3175 | Accuracy: 90.39%
  Confusion Matrix:
[[321  27   8  16]
 [ 28 711  20   9]
 [  4  17 328   3]
 [ 18  18   5 267]]

Test Results
  Test Loss:   0.4996 | Accuracy: 85.02%
  Confusion Matrix:
[[500 101  33  61]
 [ 51 982  37  16]
 [ 15  20 761   3]
 [ 64  88  25 675]]


Started next Epoch
Epoch 1/4
  Train Loss: 0.4570
  Val Loss:   0.3340 | Accuracy: 87.17%
  Confusion Matrix:
[[333  22   3  14]
 [ 68 674  15  11]
 [ 10  20 319   3]
 [ 49  15   1 243]]

Started next Epoch
Epoch 2/4
  Train Loss: 0.2606
  Val Loss:   0.3123 | Accuracy: 89.72%
  Confusion Matrix:
[[329  31   1  11]
 [ 35 718   7   8]
 [  7  20 321   4]
 [ 40  21   0 247]]

Started next Epoch
Epoch 3/4
  Train Loss: 0.1973
  Val Loss:   0.2871 | Accuracy: 90.11%
  Confusion Matrix:
[[328  30   2  12]
 [ 33 706  16  13]
 [  7  21 320   4]
 [ 23  17   0 268]]

Started next Epoch
Epoch 4/4
  Train Loss: 0.1560
  Val Loss:   0.3175 | Accuracy: 90.39%
  Confusion Matrix:
[[321  27   8  16]
 [ 28 711  20   9]
 [  4  17 328   3]
 [ 18  18   5 267]]

Test Results
  Test Loss:   0.4996 | Accuracy: 85.02%
  Confusion Matrix:
[[500 101  33  61]
 [ 51 982  37  16]
 [ 15  20 761   3]
 [ 64  88  25 675]]

In [None]:
# concatenation

In [4]:
import os
import re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models import resnet18, ResNet18_Weights
from PIL import Image
import numpy as np
from transformers import DistilBertTokenizer, DistilBertModel
from sklearn.metrics import confusion_matrix, accuracy_score

# ======================
# 1. Custom Multimodal Dataset
# ======================
class MultiModalDataset(Dataset):
    def __init__(self, root_dir, tokenizer, max_len, image_transform=None):
        """
        Expects a folder structure:
            root_dir/
                class1/
                    image1.jpg
                    image2.jpg
                    ...
                class2/
                    ...
        The text is extracted from the file name (by removing extension, replacing underscores, etc.)
        """
        self.root_dir = root_dir
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.image_transform = image_transform
        self.samples = []  # each element is a tuple (image_path, text, label)
        self.class_folders = sorted(os.listdir(root_dir))
        self.label_map = {cls: idx for idx, cls in enumerate(self.class_folders)}
        
        for cls in self.class_folders:
            cls_path = os.path.join(root_dir, cls)
            if os.path.isdir(cls_path):
                # sort files to ensure consistent order
                for file in sorted(os.listdir(cls_path)):
                    file_path = os.path.join(cls_path, file)
                    if os.path.isfile(file_path):
                        # Extract text from file name
                        file_name_no_ext, _ = os.path.splitext(file)
                        text = file_name_no_ext.replace('_', ' ')
                        text = re.sub(r'\d+', '', text)
                        self.samples.append((file_path, text, self.label_map[cls]))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        image_path, text, label = self.samples[idx]
        # Load and transform image
        image = Image.open(image_path).convert('RGB')
        if self.image_transform:
            image = self.image_transform(image)
            
        # Tokenize text
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        
        return {
            'image': image,
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# ======================
# 2. Multimodal Model Definition
# ======================
class MultiModalClassifier(nn.Module):
    def __init__(self, num_classes, fusion_dim=512, text_model_name='distilbert-base-uncased', freeze_image_layers=True):
        """
        fusion_dim: common embedding dimension for both modalities.
        freeze_image_layers: if True, freeze all image model parameters except those in layer4.
        """
        super(MultiModalClassifier, self).__init__()
        self.fusion_dim = fusion_dim
        
        # ---- Text model: DistilBERT ----
        self.text_model = DistilBertModel.from_pretrained(text_model_name)
        self.text_projection = nn.Linear(self.text_model.config.hidden_size, fusion_dim)
        self.text_dropout = nn.Dropout(0.3)
        
        # ---- Image model: ResNet18 ----
        self.image_model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        self.image_model.fc = nn.Identity()
        if freeze_image_layers:
            for name, param in self.image_model.named_parameters():
                if not name.startswith('layer4'):
                    param.requires_grad = False
        if 512 != fusion_dim:
            self.image_projection = nn.Linear(512, fusion_dim)
        else:
            self.image_projection = nn.Identity()
        self.image_dropout = nn.Dropout(0.3)
        
        # ---- Intermediate Fusion: Concatenation + MLP ----
        # Concatenates text and image features (resulting in 2*fusion_dim) and fuses via MLP.
        self.fusion_mlp = nn.Sequential(
            nn.Linear(fusion_dim * 2, fusion_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(fusion_dim, fusion_dim)
        )
        
        # ---- Classification Head ----
        self.classifier = nn.Linear(fusion_dim, num_classes)
    
    def forward(self, input_ids, attention_mask, image):
        # --- Process text ---
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        text_feat = text_outputs.last_hidden_state[:, 0, :]  # Use the [CLS] token representation
        text_feat = self.text_dropout(text_feat)
        text_feat = self.text_projection(text_feat)  # (batch, fusion_dim)
        
        # --- Process image ---
        image_feat = self.image_model(image)  # (batch, 512)
        image_feat = self.image_dropout(image_feat)
        image_feat = self.image_projection(image_feat)  # (batch, fusion_dim)
        
        # --- Intermediate Fusion: Concatenation + MLP ---
        # Concatenate along the feature dimension
        multimodal_feats = torch.cat((text_feat, image_feat), dim=1)  # (batch, fusion_dim*2)
        fused_feat = self.fusion_mlp(multimodal_feats)  # (batch, fusion_dim)
        
        # --- Classification ---
        logits = self.classifier(fused_feat)
        return logits

# ======================
# 3. Data Preparation
# ======================
# Define transforms for training and validation (you can adjust augmentation as needed)
train_image_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
val_test_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Instantiate tokenizer for DistilBERT.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
max_len = 24  # as in your example

# Define paths to your train, validation, and test folders.
TRAIN_PATH = './CVPR_2024_dataset_Train'
VAL_PATH = './CVPR_2024_dataset_Val'
TEST_PATH = './CVPR_2024_dataset_Test'

# Create dataset instances.
train_dataset = MultiModalDataset(TRAIN_PATH, tokenizer, max_len, image_transform=train_image_transform)
val_dataset = MultiModalDataset(VAL_PATH, tokenizer, max_len, image_transform=val_test_transform)
test_dataset = MultiModalDataset(TEST_PATH, tokenizer, max_len, image_transform=val_test_transform)

# Create dataloaders.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# ======================
# 4. Model, Optimizer, and Loss Function Setup
# ======================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 4
model = MultiModalClassifier(num_classes=num_classes, fusion_dim=512, freeze_image_layers=True)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

# ======================
# 5. Training and Evaluation Loops
# ======================
from tqdm import tqdm

def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    running_corrects = 0
    total_samples = 0

    # Initialize tqdm progress bar
    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    
    for batch in progress_bar:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        images = batch['image'].to(device)
        labels = batch['label'].to(device)
        
        outputs = model(input_ids, attention_mask, images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        # Update running loss
        total_loss += loss.item()
        
        # Calculate batch accuracy
        preds = outputs.argmax(dim=1)
        running_corrects += torch.sum(preds == labels).item()
        total_samples += labels.size(0)
        
        # Update progress bar with current loss and accuracy
        progress_bar.set_postfix(loss=loss.item(), accuracy=running_corrects / total_samples)
    
    epoch_loss = total_loss / len(dataloader)
    epoch_acc = running_corrects / total_samples
    return epoch_loss, epoch_acc

def evaluate_model(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            images = batch['image'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(input_ids, attention_mask, images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            
            preds = outputs.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            
    avg_loss = total_loss / len(dataloader)
    acc = accuracy_score(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds)
    return avg_loss, acc, cm

# Updated training loop with progress evaluation
EPOCHS = 4
best_val_loss = float('inf')
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, val_cm = evaluate_model(model, val_loader, criterion, device)
    
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
    print(f"  Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc*100:.2f}%")
    print(f"  Confusion Matrix:\n{val_cm}\n")
    
    # Save the best model based on validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_multimodal_model.pth')

# Optionally, run on test set.
test_loss, test_acc, test_cm = evaluate_model(model, test_loader, criterion, device)
print("Test Results")
print(f"  Test Loss:   {test_loss:.4f} | Accuracy: {test_acc*100:.2f}%")
print(f"  Confusion Matrix:\n{test_cm}")


Epoch 1/4


                                                                                                                        

  Train Loss: 0.4989 | Train Acc: 81.43%
  Val   Loss: 0.3339 | Val   Acc: 88.17%
  Confusion Matrix:
[[300  41  10  21]
 [ 41 691  26  10]
 [  4  17 328   3]
 [ 17  17   6 268]]

Epoch 2/4


                                                                                                                        

  Train Loss: 0.2668 | Train Acc: 90.78%
  Val   Loss: 0.2975 | Val   Acc: 89.72%
  Confusion Matrix:
[[325  29   1  17]
 [ 37 712  10   9]
 [ 10  21 317   4]
 [ 24  23   0 261]]

Epoch 3/4


                                                                                                                        

  Train Loss: 0.1959 | Train Acc: 93.29%
  Val   Loss: 0.2983 | Val   Acc: 90.17%
  Confusion Matrix:
[[334  20   4  14]
 [ 47 702   7  12]
 [  8  26 314   4]
 [ 23  12   0 273]]

Epoch 4/4


                                                                                                                        

  Train Loss: 0.1569 | Train Acc: 94.48%
  Val   Loss: 0.3457 | Val   Acc: 89.56%
  Confusion Matrix:
[[307  39   6  20]
 [ 26 710  22  10]
 [  4  19 325   4]
 [ 16  19   3 270]]

Test Results
  Test Loss:   0.4953 | Accuracy: 85.20%
  Confusion Matrix:
[[476 117  33  69]
 [ 36 989  44  17]
 [ 10  25 760   4]
 [ 51  78  24 699]]


Epoch 1/4

                                                                                                                        

  Train Loss: 0.4989 | Train Acc: 81.43%
  Val   Loss: 0.3339 | Val   Acc: 88.17%
  Confusion Matrix:
[[300  41  10  21]
 [ 41 691  26  10]
 [  4  17 328   3]
 [ 17  17   6 268]]

Epoch 2/4

                                                                                                                        

  Train Loss: 0.2668 | Train Acc: 90.78%
  Val   Loss: 0.2975 | Val   Acc: 89.72%
  Confusion Matrix:
[[325  29   1  17]
 [ 37 712  10   9]
 [ 10  21 317   4]
 [ 24  23   0 261]]

Epoch 3/4

                                                                                                                        

  Train Loss: 0.1959 | Train Acc: 93.29%
  Val   Loss: 0.2983 | Val   Acc: 90.17%
  Confusion Matrix:
[[334  20   4  14]
 [ 47 702   7  12]
 [  8  26 314   4]
 [ 23  12   0 273]]

Epoch 4/4

                                                                                                                        

  Train Loss: 0.1569 | Train Acc: 94.48%
  Val   Loss: 0.3457 | Val   Acc: 89.56%
  Confusion Matrix:
[[307  39   6  20]
 [ 26 710  22  10]
 [  4  19 325   4]
 [ 16  19   3 270]]

Test Results
  Test Loss:   0.4953 | Accuracy: 85.20%
  Confusion Matrix:
[[476 117  33  69]
 [ 36 989  44  17]
 [ 10  25 760   4]
 [ 51  78  24 699]]


In [None]:
# bilinear pooling

In [6]:
# import os
import re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models import resnet18, ResNet18_Weights
from PIL import Image
import numpy as np
from transformers import DistilBertTokenizer, DistilBertModel
from sklearn.metrics import confusion_matrix, accuracy_score

# ======================
# 1. Custom Multimodal Dataset
# ======================
class MultiModalDataset(Dataset):
    def __init__(self, root_dir, tokenizer, max_len, image_transform=None):
        """
        Expects a folder structure:
            root_dir/
                class1/
                    image1.jpg
                    image2.jpg
                    ...
                class2/
                    ...
        The text is extracted from the file name (by removing extension, replacing underscores, etc.)
        """
        self.root_dir = root_dir
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.image_transform = image_transform
        self.samples = []  # each element is a tuple (image_path, text, label)
        self.class_folders = sorted(os.listdir(root_dir))
        self.label_map = {cls: idx for idx, cls in enumerate(self.class_folders)}
        
        for cls in self.class_folders:
            cls_path = os.path.join(root_dir, cls)
            if os.path.isdir(cls_path):
                # sort files to ensure consistent order
                for file in sorted(os.listdir(cls_path)):
                    file_path = os.path.join(cls_path, file)
                    if os.path.isfile(file_path):
                        # Extract text from file name
                        file_name_no_ext, _ = os.path.splitext(file)
                        text = file_name_no_ext.replace('_', ' ')
                        text = re.sub(r'\d+', '', text)
                        self.samples.append((file_path, text, self.label_map[cls]))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        image_path, text, label = self.samples[idx]
        # Load and transform image
        image = Image.open(image_path).convert('RGB')
        if self.image_transform:
            image = self.image_transform(image)
            
        # Tokenize text
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        
        return {
            'image': image,
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# ======================
# 2. Multimodal Model Definition
# ======================
class MultiModalClassifierBilinear(nn.Module):
    def __init__(self, num_classes, fusion_dim=512, text_model_name='distilbert-base-uncased', freeze_image_layers=True):
        """
        fusion_dim: common embedding dimension for both modalities.
        freeze_image_layers: if True, freeze all image model parameters except those in layer4.
        """
        super(MultiModalClassifierBilinear, self).__init__()
        self.fusion_dim = fusion_dim
        
        # ---- Text model: DistilBERT ----
        self.text_model = DistilBertModel.from_pretrained(text_model_name)
        self.text_projection = nn.Linear(self.text_model.config.hidden_size, fusion_dim)
        self.text_dropout = nn.Dropout(0.3)
        
        # ---- Image model: ResNet18 ----
        self.image_model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        self.image_model.fc = nn.Identity()
        if freeze_image_layers:
            for name, param in self.image_model.named_parameters():
                if not name.startswith('layer4'):
                    param.requires_grad = False
        if 512 != fusion_dim:
            self.image_projection = nn.Linear(512, fusion_dim)
        else:
            self.image_projection = nn.Identity()
        self.image_dropout = nn.Dropout(0.3)
        
        # ---- Bilinear Pooling Fusion ----
        # The bilinear layer learns a transformation: fused = x^T W y + b,
        # where x and y are the projected text and image features.
        self.bilinear = nn.Bilinear(fusion_dim, fusion_dim, fusion_dim)
        self.pool_dropout = nn.Dropout(0.3)
        
        # ---- Classification Head ----
        self.classifier = nn.Linear(fusion_dim, num_classes)
    
    def forward(self, input_ids, attention_mask, image):
        # --- Process text ---
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the [CLS] token representation (first token) as text feature.
        text_feat = text_outputs.last_hidden_state[:, 0, :]  # (batch, 768)
        text_feat = self.text_dropout(text_feat)
        text_feat = self.text_projection(text_feat)  # (batch, fusion_dim)
        
        # --- Process image ---
        image_feat = self.image_model(image)  # (batch, 512)
        image_feat = self.image_dropout(image_feat)
        image_feat = self.image_projection(image_feat)  # (batch, fusion_dim)
        
        # --- Bilinear Fusion ---
        fused_feat = self.bilinear(text_feat, image_feat)  # (batch, fusion_dim)
        fused_feat = self.pool_dropout(fused_feat)
        
        # --- Classification ---
        logits = self.classifier(fused_feat)
        return logits

# ======================
# 3. Data Preparation
# ======================
# Define transforms for training and validation (you can adjust augmentation as needed)
train_image_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
val_test_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Instantiate tokenizer for DistilBERT.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
max_len = 24  # as in your example

# Define paths to your train, validation, and test folders.
TRAIN_PATH = './CVPR_2024_dataset_Train'
VAL_PATH = './CVPR_2024_dataset_Val'
TEST_PATH = './CVPR_2024_dataset_Test'

# Create dataset instances.
train_dataset = MultiModalDataset(TRAIN_PATH, tokenizer, max_len, image_transform=train_image_transform)
val_dataset = MultiModalDataset(VAL_PATH, tokenizer, max_len, image_transform=val_test_transform)
test_dataset = MultiModalDataset(TEST_PATH, tokenizer, max_len, image_transform=val_test_transform)

# Create dataloaders.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# ======================
# 4. Model, Optimizer, and Loss Function Setup
# ======================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 4
model = MultiModalClassifier(num_classes=num_classes, fusion_dim=512, freeze_image_layers=True)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

# ======================
# 5. Training and Evaluation Loops
# ======================
from tqdm import tqdm

def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    running_corrects = 0
    total_samples = 0

    # Initialize tqdm progress bar
    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    
    for batch in progress_bar:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        images = batch['image'].to(device)
        labels = batch['label'].to(device)
        
        outputs = model(input_ids, attention_mask, images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        # Update running loss
        total_loss += loss.item()
        
        # Calculate batch accuracy
        preds = outputs.argmax(dim=1)
        running_corrects += torch.sum(preds == labels).item()
        total_samples += labels.size(0)
        
        # Update progress bar with current loss and accuracy
        progress_bar.set_postfix(loss=loss.item(), accuracy=running_corrects / total_samples)
    
    epoch_loss = total_loss / len(dataloader)
    epoch_acc = running_corrects / total_samples
    return epoch_loss, epoch_acc

def evaluate_model(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            images = batch['image'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(input_ids, attention_mask, images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            
            preds = outputs.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            
    avg_loss = total_loss / len(dataloader)
    acc = accuracy_score(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds)
    return avg_loss, acc, cm

# Updated training loop with progress evaluation
EPOCHS = 4
best_val_loss = float('inf')
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, val_cm = evaluate_model(model, val_loader, criterion, device)
    
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
    print(f"  Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc*100:.2f}%")
    print(f"  Confusion Matrix:\n{val_cm}\n")
    
    # Save the best model based on validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_multimodal_model.pth')

# Optionally, run on test set.
test_loss, test_acc, test_cm = evaluate_model(model, test_loader, criterion, device)
print("Test Results")
print(f"  Test Loss:   {test_loss:.4f} | Accuracy: {test_acc*100:.2f}%")
print(f"  Confusion Matrix:\n{test_cm}")


Epoch 1/4


                                                                                                                        

  Train Loss: 0.4767 | Train Acc: 82.91%
  Val   Loss: 0.3170 | Val   Acc: 88.39%
  Confusion Matrix:
[[285  60   2  25]
 [ 24 714  15  15]
 [  1  27 320   4]
 [ 16  19   1 272]]

Epoch 2/4


                                                                                                                        

  Train Loss: 0.2635 | Train Acc: 90.90%
  Val   Loss: 0.2875 | Val   Acc: 90.11%
  Confusion Matrix:
[[305  50   2  15]
 [ 25 726   8   9]
 [  3  22 324   3]
 [ 18  21   2 267]]

Epoch 3/4


                                                                                                                        

  Train Loss: 0.2024 | Train Acc: 92.79%
  Val   Loss: 0.2992 | Val   Acc: 90.61%
  Confusion Matrix:
[[327  27   2  16]
 [ 33 716  11   8]
 [  4  19 325   4]
 [ 21  23   1 263]]

Epoch 4/4


                                                                                                                        

  Train Loss: 0.1627 | Train Acc: 94.37%
  Val   Loss: 0.3404 | Val   Acc: 88.78%
  Confusion Matrix:
[[342  16   3  11]
 [ 56 688  16   8]
 [  6  20 322   4]
 [ 42  19   1 246]]

Test Results
  Test Loss:   0.5367 | Accuracy: 83.77%
  Confusion Matrix:
[[599  64  18  14]
 [ 99 941  39   7]
 [ 32  22 742   3]
 [148  96  15 593]]


Epoch 1/4

                                                                                                                        

  Train Loss: 0.4767 | Train Acc: 82.91%
  Val   Loss: 0.3170 | Val   Acc: 88.39%
  Confusion Matrix:
[[285  60   2  25]
 [ 24 714  15  15]
 [  1  27 320   4]
 [ 16  19   1 272]]

Epoch 2/4

                                                                                                                        

  Train Loss: 0.2635 | Train Acc: 90.90%
  Val   Loss: 0.2875 | Val   Acc: 90.11%
  Confusion Matrix:
[[305  50   2  15]
 [ 25 726   8   9]
 [  3  22 324   3]
 [ 18  21   2 267]]

Epoch 3/4

                                                                                                                        

  Train Loss: 0.2024 | Train Acc: 92.79%
  Val   Loss: 0.2992 | Val   Acc: 90.61%
  Confusion Matrix:
[[327  27   2  16]
 [ 33 716  11   8]
 [  4  19 325   4]
 [ 21  23   1 263]]

Epoch 4/4

                                                                                                                        

  Train Loss: 0.1627 | Train Acc: 94.37%
  Val   Loss: 0.3404 | Val   Acc: 88.78%
  Confusion Matrix:
[[342  16   3  11]
 [ 56 688  16   8]
 [  6  20 322   4]
 [ 42  19   1 246]]

Test Results
  Test Loss:   0.5367 | Accuracy: 83.77%
  Confusion Matrix:
[[599  64  18  14]
 [ 99 941  39   7]
 [ 32  22 742   3]
 [148  96  15 593]]


In [None]:
# gating mechanism

In [9]:
import os
import re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models import resnet18, ResNet18_Weights
from PIL import Image
import numpy as np
from transformers import DistilBertTokenizer, DistilBertModel
from sklearn.metrics import confusion_matrix, accuracy_score

# ======================
# 1. Custom Multimodal Dataset
# ======================
class MultiModalDataset(Dataset):
    def __init__(self, root_dir, tokenizer, max_len, image_transform=None):
        """
        Expects a folder structure:
            root_dir/
                class1/
                    image1.jpg
                    image2.jpg
                    ...
                class2/
                    ...
        The text is extracted from the file name (by removing extension, replacing underscores, etc.)
        """
        self.root_dir = root_dir
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.image_transform = image_transform
        self.samples = []  # each element is a tuple (image_path, text, label)
        self.class_folders = sorted(os.listdir(root_dir))
        self.label_map = {cls: idx for idx, cls in enumerate(self.class_folders)}
        
        for cls in self.class_folders:
            cls_path = os.path.join(root_dir, cls)
            if os.path.isdir(cls_path):
                # sort files to ensure consistent order
                for file in sorted(os.listdir(cls_path)):
                    file_path = os.path.join(cls_path, file)
                    if os.path.isfile(file_path):
                        # Extract text from file name
                        file_name_no_ext, _ = os.path.splitext(file)
                        text = file_name_no_ext.replace('_', ' ')
                        text = re.sub(r'\d+', '', text)
                        self.samples.append((file_path, text, self.label_map[cls]))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        image_path, text, label = self.samples[idx]
        # Load and transform image
        image = Image.open(image_path).convert('RGB')
        if self.image_transform:
            image = self.image_transform(image)
            
        # Tokenize text
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        
        return {
            'image': image,
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# ======================
# 2. Multimodal Model Definition (Gated Fusion)
# ======================
class MultiModalClassifierGated(nn.Module):
    def __init__(self, num_classes, fusion_dim=512, text_model_name='distilbert-base-uncased', freeze_image_layers=True):
        """
        fusion_dim: common embedding dimension for both modalities.
        freeze_image_layers: if True, freeze all image model parameters except those in layer4.
        """
        super(MultiModalClassifierGated, self).__init__()
        self.fusion_dim = fusion_dim
        
        # ---- Text model: DistilBERT ----
        self.text_model = DistilBertModel.from_pretrained(text_model_name)
        self.text_projection = nn.Linear(self.text_model.config.hidden_size, fusion_dim)
        self.text_dropout = nn.Dropout(0.5)  # increased dropout
        
        # ---- Image model: ResNet18 ----
        self.image_model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        self.image_model.fc = nn.Identity()
        if freeze_image_layers:
            for name, param in self.image_model.named_parameters():
                if not name.startswith('layer4'):
                    param.requires_grad = False
        if 512 != fusion_dim:
            self.image_projection = nn.Linear(512, fusion_dim)
        else:
            self.image_projection = nn.Identity()
        self.image_dropout = nn.Dropout(0.5)  # increased dropout
        
        # ---- Gating Mechanism ----
        # The gate network computes element-wise weights for text and image features.
        # It takes the concatenated features (of size fusion_dim*2) and outputs gating weights in [0, 1]
        self.gate = nn.Sequential(
            nn.Linear(fusion_dim * 2, fusion_dim),
            nn.ReLU(),
            nn.Linear(fusion_dim, fusion_dim),
            nn.Sigmoid()
        )
        
        # ---- Classification Head ----
        self.classifier = nn.Linear(fusion_dim, num_classes)
    
    def forward(self, input_ids, attention_mask, image):
        # --- Process text ---
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the [CLS] token representation (first token) as text feature.
        text_feat = text_outputs.last_hidden_state[:, 0, :]  # (batch, 768)
        text_feat = self.text_dropout(text_feat)
        text_feat = self.text_projection(text_feat)  # (batch, fusion_dim)
        
        # --- Process image ---
        image_feat = self.image_model(image)  # (batch, 512)
        image_feat = self.image_dropout(image_feat)
        image_feat = self.image_projection(image_feat)  # (batch, fusion_dim)
        
        # --- Gating Fusion ---
        # Concatenate features from both modalities.
        combined_feats = torch.cat([text_feat, image_feat], dim=1)  # (batch, fusion_dim*2)
        # Compute gate weights (element-wise) from the combined features.
        gate_weights = self.gate(combined_feats)  # (batch, fusion_dim)
        # Fuse the modalities by weighting text and image features.
        fused_feat = gate_weights * text_feat + (1 - gate_weights) * image_feat  # (batch, fusion_dim)
        
        # --- Classification ---
        logits = self.classifier(fused_feat)
        return logits

# ======================
# 3. Data Preparation with Enhanced Augmentation
# ======================
# Enhanced training transforms with additional augmentation.
train_image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
val_test_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Instantiate tokenizer for DistilBERT.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
max_len = 24  # as in your example

# Define paths to your train, validation, and test folders.
TRAIN_PATH = './CVPR_2024_dataset_Train'
VAL_PATH = './CVPR_2024_dataset_Val'
TEST_PATH = './CVPR_2024_dataset_Test'

# Create dataset instances.
train_dataset = MultiModalDataset(TRAIN_PATH, tokenizer, max_len, image_transform=train_image_transform)
val_dataset = MultiModalDataset(VAL_PATH, tokenizer, max_len, image_transform=val_test_transform)
test_dataset = MultiModalDataset(TEST_PATH, tokenizer, max_len, image_transform=val_test_transform)

# Create dataloaders.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# ======================
# 4. Model, Optimizer, and Loss Function Setup
# ======================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 4
# Instantiate the gating fusion model.
model = MultiModalClassifierGated(num_classes=num_classes, fusion_dim=512, freeze_image_layers=True)
model = model.to(device)

# Using weight decay for regularization.
optimizer = optim.Adam(model.parameters(), lr=2e-5, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

# Learning rate scheduler: ReduceLROnPlateau reduces the LR when the validation loss plateaus.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)

# ======================
# 5. Training and Evaluation Loops
# ======================
from tqdm import tqdm

def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    running_corrects = 0
    total_samples = 0

    # Initialize tqdm progress bar
    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    
    for batch in progress_bar:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        images = batch['image'].to(device)
        labels = batch['label'].to(device)
        
        outputs = model(input_ids, attention_mask, images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        # Update running loss and accuracy
        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        running_corrects += torch.sum(preds == labels).item()
        total_samples += labels.size(0)
        progress_bar.set_postfix(loss=loss.item(), accuracy=running_corrects / total_samples)
    
    epoch_loss = total_loss / len(dataloader)
    epoch_acc = running_corrects / total_samples
    return epoch_loss, epoch_acc

def evaluate_model(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            images = batch['image'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(input_ids, attention_mask, images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            
            preds = outputs.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            
    avg_loss = total_loss / len(dataloader)
    acc = accuracy_score(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds)
    return avg_loss, acc, cm

# Updated training loop with scheduler integration
EPOCHS = 4
best_val_loss = float('inf')
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, val_cm = evaluate_model(model, val_loader, criterion, device)
    
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
    print(f"  Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc*100:.2f}%")
    print(f"  Confusion Matrix:\n{val_cm}\n")
    
    # Step the scheduler based on validation loss
    scheduler.step(val_loss)
    
    # Save the best model based on validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_multimodal_model.pth')

# Evaluate on the test set.
test_loss, test_acc, test_cm = evaluate_model(model, test_loader, criterion, device)
print("Test Results")
print(f"  Test Loss:   {test_loss:.4f} | Accuracy: {test_acc*100:.2f}%")
print(f"  Confusion Matrix:\n{test_cm}")




Epoch 1/4


                                                                                                                        

  Train Loss: 0.5415 | Train Acc: 80.17%
  Val   Loss: 0.3330 | Val   Acc: 87.94%
  Confusion Matrix:
[[298  51   6  17]
 [ 32 710  17   9]
 [  4  22 323   3]
 [ 29  22   5 252]]

Epoch 2/4


                                                                                                                        

  Train Loss: 0.2960 | Train Acc: 89.71%
  Val   Loss: 0.3107 | Val   Acc: 88.61%
  Confusion Matrix:
[[313  34   3  22]
 [ 36 675  31  26]
 [  6  17 326   3]
 [ 15   8   4 281]]

Epoch 3/4


                                                                                                                        

  Train Loss: 0.2313 | Train Acc: 91.86%
  Val   Loss: 0.3369 | Val   Acc: 89.00%
  Confusion Matrix:
[[316  28   1  27]
 [ 43 690   5  30]
 [ 11  27 309   5]
 [ 12   9   0 287]]

Epoch 4/4


                                                                                                                        

  Train Loss: 0.1926 | Train Acc: 93.02%
  Val   Loss: 0.3289 | Val   Acc: 89.61%
  Confusion Matrix:
[[303  49   3  17]
 [ 25 707  20  16]
 [  5  16 327   4]
 [ 10  18   4 276]]

Test Results
  Test Loss:   0.4768 | Accuracy: 84.82%
  Confusion Matrix:
[[442 151  28  74]
 [ 29 995  40  22]
 [ 10  21 762   6]
 [ 45  78  17 712]]


Epoch 1/4

                                                                                                                        

  Train Loss: 0.5415 | Train Acc: 80.17%
  Val   Loss: 0.3330 | Val   Acc: 87.94%
  Confusion Matrix:
[[298  51   6  17]
 [ 32 710  17   9]
 [  4  22 323   3]
 [ 29  22   5 252]]

Epoch 2/4

                                                                                                                        

  Train Loss: 0.2960 | Train Acc: 89.71%
  Val   Loss: 0.3107 | Val   Acc: 88.61%
  Confusion Matrix:
[[313  34   3  22]
 [ 36 675  31  26]
 [  6  17 326   3]
 [ 15   8   4 281]]

Epoch 3/4

                                                                                                                        

  Train Loss: 0.2313 | Train Acc: 91.86%
  Val   Loss: 0.3369 | Val   Acc: 89.00%
  Confusion Matrix:
[[316  28   1  27]
 [ 43 690   5  30]
 [ 11  27 309   5]
 [ 12   9   0 287]]

Epoch 4/4

                                                                                                                        

  Train Loss: 0.1926 | Train Acc: 93.02%
  Val   Loss: 0.3289 | Val   Acc: 89.61%
  Confusion Matrix:
[[303  49   3  17]
 [ 25 707  20  16]
 [  5  16 327   4]
 [ 10  18   4 276]]

Test Results
  Test Loss:   0.4768 | Accuracy: 84.82%
  Confusion Matrix:
[[442 151  28  74]
 [ 29 995  40  22]
 [ 10  21 762   6]
 [ 45  78  17 712]]


In [10]:
import os
import re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
import numpy as np
from transformers import DistilBertTokenizer, DistilBertModel
from sklearn.metrics import confusion_matrix, accuracy_score

# ======================
# 1. Custom Multimodal Dataset
# ======================
class MultiModalDataset(Dataset):
    def __init__(self, root_dir, tokenizer, max_len, image_transform=None):
        """
        Expects a folder structure:
            root_dir/
                class1/
                    image1.jpg
                    image2.jpg
                    ...
                class2/
                    ...
        The text is extracted from the file name (by removing extension, replacing underscores, etc.)
        """
        self.root_dir = root_dir
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.image_transform = image_transform
        self.samples = []  # each element is a tuple (image_path, text, label)
        self.class_folders = sorted(os.listdir(root_dir))
        self.label_map = {cls: idx for idx, cls in enumerate(self.class_folders)}
        
        for cls in self.class_folders:
            cls_path = os.path.join(root_dir, cls)
            if os.path.isdir(cls_path):
                # sort files to ensure consistent order
                for file in sorted(os.listdir(cls_path)):
                    file_path = os.path.join(cls_path, file)
                    if os.path.isfile(file_path):
                        # Extract text from file name
                        file_name_no_ext, _ = os.path.splitext(file)
                        text = file_name_no_ext.replace('_', ' ')
                        text = re.sub(r'\d+', '', text)
                        self.samples.append((file_path, text, self.label_map[cls]))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        image_path, text, label = self.samples[idx]
        # Load and transform image
        image = Image.open(image_path).convert('RGB')
        if self.image_transform:
            image = self.image_transform(image)
            
        # Tokenize text
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        
        return {
            'image': image,
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# ======================
# 2. Multimodal Model Definition (Gated Fusion with ResNet50)
# ======================
class MultiModalClassifierGated(nn.Module):
    def __init__(self, num_classes, fusion_dim=512, text_model_name='distilbert-base-uncased', freeze_image_layers=True):
        """
        fusion_dim: common embedding dimension for both modalities.
        freeze_image_layers: if True, freeze all image model parameters except those in layer4.
        """
        super(MultiModalClassifierGated, self).__init__()
        self.fusion_dim = fusion_dim
        
        # ---- Text model: DistilBERT ----
        self.text_model = DistilBertModel.from_pretrained(text_model_name)
        self.text_projection = nn.Linear(self.text_model.config.hidden_size, fusion_dim)
        self.text_dropout = nn.Dropout(0.5)  # increased dropout
        
        # ---- Image model: ResNet50 ----
        self.image_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        self.image_model.fc = nn.Identity()
        if freeze_image_layers:
            for name, param in self.image_model.named_parameters():
                if not name.startswith('layer4'):
                    param.requires_grad = False
        # ResNet50 outputs 2048-dimensional features.
        if 2048 != fusion_dim:
            self.image_projection = nn.Linear(2048, fusion_dim)
        else:
            self.image_projection = nn.Identity()
        self.image_dropout = nn.Dropout(0.5)  # increased dropout
        
        # ---- Gating Mechanism ----
        # Computes element-wise weights for text and image features.
        self.gate = nn.Sequential(
            nn.Linear(fusion_dim * 2, fusion_dim),
            nn.ReLU(),
            nn.Linear(fusion_dim, fusion_dim),
            nn.Sigmoid()
        )
        
        # ---- Classification Head ----
        self.classifier = nn.Linear(fusion_dim, num_classes)
    
    def forward(self, input_ids, attention_mask, image):
        # --- Process text ---
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        text_feat = text_outputs.last_hidden_state[:, 0, :]  # use [CLS] token representation
        text_feat = self.text_dropout(text_feat)
        text_feat = self.text_projection(text_feat)  # project to fusion_dim
        
        # --- Process image ---
        image_feat = self.image_model(image)  # get 2048-dim features from ResNet50
        image_feat = self.image_dropout(image_feat)
        image_feat = self.image_projection(image_feat)  # project to fusion_dim
        
        # --- Gating Fusion ---
        combined_feats = torch.cat([text_feat, image_feat], dim=1)  # (batch, fusion_dim*2)
        gate_weights = self.gate(combined_feats)  # (batch, fusion_dim)
        fused_feat = gate_weights * text_feat + (1 - gate_weights) * image_feat  # element-wise fusion
        
        # --- Classification ---
        logits = self.classifier(fused_feat)
        return logits

# ======================
# 3. Data Preparation with Enhanced Augmentation
# ======================
train_image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
val_test_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
max_len = 24  # as before

# Define dataset paths.
TRAIN_PATH = './CVPR_2024_dataset_Train'
VAL_PATH = './CVPR_2024_dataset_Val'
TEST_PATH = './CVPR_2024_dataset_Test'

train_dataset = MultiModalDataset(TRAIN_PATH, tokenizer, max_len, image_transform=train_image_transform)
val_dataset = MultiModalDataset(VAL_PATH, tokenizer, max_len, image_transform=val_test_transform)
test_dataset = MultiModalDataset(TEST_PATH, tokenizer, max_len, image_transform=val_test_transform)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# ======================
# 4. Model, Optimizer, and Loss Function Setup
# ======================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 4

# Instantiate the gating fusion model with ResNet50.
model = MultiModalClassifierGated(num_classes=num_classes, fusion_dim=512, freeze_image_layers=True)
model = model.to(device)

# Use weight decay for regularization.
optimizer = optim.Adam(model.parameters(), lr=2e-5, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

# Learning rate scheduler (without verbose parameter).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)

# ======================
# 5. Training and Evaluation Loops
# ======================
from tqdm import tqdm

def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    running_corrects = 0
    total_samples = 0

    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    
    for batch in progress_bar:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        images = batch['image'].to(device)
        labels = batch['label'].to(device)
        
        outputs = model(input_ids, attention_mask, images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        running_corrects += torch.sum(preds == labels).item()
        total_samples += labels.size(0)
        progress_bar.set_postfix(loss=loss.item(), accuracy=running_corrects / total_samples)
    
    epoch_loss = total_loss / len(dataloader)
    epoch_acc = running_corrects / total_samples
    return epoch_loss, epoch_acc

def evaluate_model(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            images = batch['image'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(input_ids, attention_mask, images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            
            preds = outputs.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            
    avg_loss = total_loss / len(dataloader)
    acc = accuracy_score(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds)
    return avg_loss, acc, cm

EPOCHS = 4
best_val_loss = float('inf')
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, val_cm = evaluate_model(model, val_loader, criterion, device)
    
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
    print(f"  Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc*100:.2f}%")
    print(f"  Confusion Matrix:\n{val_cm}\n")
    
    scheduler.step(val_loss)
    print("Current LR:", scheduler.optimizer.param_groups[0]['lr'])
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_multimodal_model.pth')

test_loss, test_acc, test_cm = evaluate_model(model, test_loader, criterion, device)
print("Test Results")
print(f"  Test Loss:   {test_loss:.4f} | Accuracy: {test_acc*100:.2f}%")
print(f"  Confusion Matrix:\n{test_cm}")


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /home/jacks/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|███████████████████████████████████████████████████████████████████████████████| 97.8M/97.8M [00:00<00:00, 120MB/s]


Epoch 1/4


                                                                                                                        

  Train Loss: 0.4458 | Train Acc: 83.95%
  Val   Loss: 0.3161 | Val   Acc: 89.22%
  Confusion Matrix:
[[328  27   4  13]
 [ 45 702  14   7]
 [  9  18 321   4]
 [ 33  18   2 255]]

Current LR: 2e-05
Epoch 2/4


                                                                                                                        

  Train Loss: 0.2574 | Train Acc: 91.04%
  Val   Loss: 0.2884 | Val   Acc: 89.56%
  Confusion Matrix:
[[317  29   8  18]
 [ 30 695  30  13]
 [  5  17 326   4]
 [ 21  13   0 274]]

Current LR: 2e-05
Epoch 3/4


                                                                                                                        

  Train Loss: 0.1943 | Train Acc: 93.16%
  Val   Loss: 0.3003 | Val   Acc: 89.78%
  Confusion Matrix:
[[322  33   6  11]
 [ 29 710  17  12]
 [  6  22 320   4]
 [ 25  18   1 264]]

Current LR: 2e-05
Epoch 4/4


                                                                                                                        

  Train Loss: 0.1584 | Train Acc: 94.48%
  Val   Loss: 0.3171 | Val   Acc: 89.72%
  Confusion Matrix:
[[318  31   6  17]
 [ 30 705  22  11]
 [  6  19 325   2]
 [ 19  18   4 267]]

Current LR: 1e-05
Test Results
  Test Loss:   0.4935 | Accuracy: 84.99%
  Confusion Matrix:
[[495 120  29  51]
 [ 38 991  43  14]
 [ 17  21 756   5]
 [ 53 100  24 675]]


Epoch 1/4

                                                                                                                        

  Train Loss: 0.4458 | Train Acc: 83.95%
  Val   Loss: 0.3161 | Val   Acc: 89.22%
  Confusion Matrix:
[[328  27   4  13]
 [ 45 702  14   7]
 [  9  18 321   4]
 [ 33  18   2 255]]

Current LR: 2e-05
Epoch 2/4

                                                                                                                        

  Train Loss: 0.2574 | Train Acc: 91.04%
  Val   Loss: 0.2884 | Val   Acc: 89.56%
  Confusion Matrix:
[[317  29   8  18]
 [ 30 695  30  13]
 [  5  17 326   4]
 [ 21  13   0 274]]

Current LR: 2e-05
Epoch 3/4

                                                                                                                        

  Train Loss: 0.1943 | Train Acc: 93.16%
  Val   Loss: 0.3003 | Val   Acc: 89.78%
  Confusion Matrix:
[[322  33   6  11]
 [ 29 710  17  12]
 [  6  22 320   4]
 [ 25  18   1 264]]

Current LR: 2e-05
Epoch 4/4

                                                                                                                        

  Train Loss: 0.1584 | Train Acc: 94.48%
  Val   Loss: 0.3171 | Val   Acc: 89.72%
  Confusion Matrix:
[[318  31   6  17]
 [ 30 705  22  11]
 [  6  19 325   2]
 [ 19  18   4 267]]

Current LR: 1e-05
Test Results
  Test Loss:   0.4935 | Accuracy: 84.99%
  Confusion Matrix:
[[495 120  29  51]
 [ 38 991  43  14]
 [ 17  21 756   5]
 [ 53 100  24 675]]


In [None]:
import os
import re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
import numpy as np
from transformers import DistilBertTokenizer, DistilBertModel
from sklearn.metrics import confusion_matrix, accuracy_score

# ======================
# 1. Custom Multimodal Dataset
# ======================
class MultiModalDataset(Dataset):
    def __init__(self, root_dir, tokenizer, max_len, image_transform=None):
        """
        Expects a folder structure:
            root_dir/
                class1/
                    image1.jpg
                    image2.jpg
                    ...
                class2/
                    ...
        The text is extracted from the file name (by removing extension, replacing underscores, etc.)
        """
        self.root_dir = root_dir
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.image_transform = image_transform
        self.samples = []  # each element is a tuple (image_path, text, label)
        self.class_folders = sorted(os.listdir(root_dir))
        self.label_map = {cls: idx for idx, cls in enumerate(self.class_folders)}
        
        for cls in self.class_folders:
            cls_path = os.path.join(root_dir, cls)
            if os.path.isdir(cls_path):
                # sort files to ensure consistent order
                for file in sorted(os.listdir(cls_path)):
                    file_path = os.path.join(cls_path, file)
                    if os.path.isfile(file_path):
                        # Extract text from file name
                        file_name_no_ext, _ = os.path.splitext(file)
                        text = file_name_no_ext.replace('_', ' ')
                        text = re.sub(r'\d+', '', text)
                        self.samples.append((file_path, text, self.label_map[cls]))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        image_path, text, label = self.samples[idx]
        # Load and transform image
        image = Image.open(image_path).convert('RGB')
        if self.image_transform:
            image = self.image_transform(image)
            
        # Tokenize text
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        
        return {
            'image': image,
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# ======================
# 2. Multimodal Model Definition (Gated Fusion with ResNet50)
# ======================
class MultiModalClassifierGated(nn.Module):
    def __init__(self, num_classes, fusion_dim=512, text_model_name='distilbert-base-uncased', freeze_image_layers=True):
        """
        fusion_dim: common embedding dimension for both modalities.
        freeze_image_layers: if True, freeze all image model parameters except those in layer4.
        """
        super(MultiModalClassifierGated, self).__init__()
        self.fusion_dim = fusion_dim
        
        # ---- Text model: DistilBERT ----
        self.text_model = DistilBertModel.from_pretrained(text_model_name)
        self.text_projection = nn.Linear(self.text_model.config.hidden_size, fusion_dim)
        self.text_dropout = nn.Dropout(0.5)  # increased dropout
        
        # ---- Image model: ResNet50 ----
        self.image_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        self.image_model.fc = nn.Identity()
        if freeze_image_layers:
            for name, param in self.image_model.named_parameters():
                if not name.startswith('layer4'):
                    param.requires_grad = False
        # ResNet50 outputs 2048-dimensional features.
        if 2048 != fusion_dim:
            self.image_projection = nn.Linear(2048, fusion_dim)
        else:
            self.image_projection = nn.Identity()
        self.image_dropout = nn.Dropout(0.5)  # increased dropout
        
        # ---- Gating Mechanism ----
        # Computes element-wise weights for text and image features.
        self.gate = nn.Sequential(
            nn.Linear(fusion_dim * 2, fusion_dim),
            nn.ReLU(),
            nn.Linear(fusion_dim, fusion_dim),
            nn.Sigmoid()
        )
        
        # ---- Classification Head ----
        self.classifier = nn.Linear(fusion_dim, num_classes)
    
    def forward(self, input_ids, attention_mask, image):
        # --- Process text ---
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        text_feat = text_outputs.last_hidden_state[:, 0, :]  # use [CLS] token representation
        text_feat = self.text_dropout(text_feat)
        text_feat = self.text_projection(text_feat)  # project to fusion_dim
        
        # --- Process image ---
        image_feat = self.image_model(image)  # get 2048-dim features from ResNet50
        image_feat = self.image_dropout(image_feat)
        image_feat = self.image_projection(image_feat)  # project to fusion_dim
        
        # --- Gating Fusion ---
        combined_feats = torch.cat([text_feat, image_feat], dim=1)  # (batch, fusion_dim*2)
        gate_weights = self.gate(combined_feats)  # (batch, fusion_dim)
        fused_feat = gate_weights * text_feat + (1 - gate_weights) * image_feat  # element-wise fusion
        
        # --- Classification ---
        logits = self.classifier(fused_feat)
        return logits

# ======================
# 3. Data Preparation with Enhanced Augmentation
# ======================
train_image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
val_test_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
max_len = 24  # as before

# Define dataset paths.
TRAIN_PATH = './CVPR_2024_dataset_Train'
VAL_PATH = './CVPR_2024_dataset_Val'
TEST_PATH = './CVPR_2024_dataset_Test'

train_dataset = MultiModalDataset(TRAIN_PATH, tokenizer, max_len, image_transform=train_image_transform)
val_dataset = MultiModalDataset(VAL_PATH, tokenizer, max_len, image_transform=val_test_transform)
test_dataset = MultiModalDataset(TEST_PATH, tokenizer, max_len, image_transform=val_test_transform)

# ======================
# Create DataLoaders with Balancing for Training Data
# ======================
# Compute sample weights based on class frequencies in the training set.
train_labels = [sample[2] for sample in train_dataset.samples]
class_counts = np.bincount(train_labels)
class_weights = 1.0 / class_counts  # Inverse frequency for each class.
sample_weights = [class_weights[label] for label in train_labels]

# Create the sampler for balanced sampling.
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

# Use the sampler in the training DataLoader.
train_loader = DataLoader(train_dataset, batch_size=8, sampler=sampler)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# ======================
# 4. Model, Optimizer, and Loss Function Setup
# ======================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 4

# Instantiate the gating fusion model with ResNet50.
model = MultiModalClassifierGated(num_classes=num_classes, fusion_dim=512, freeze_image_layers=True)
model = model.to(device)

# Use weight decay for regularization.
optimizer = optim.Adam(model.parameters(), lr=2e-5, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

# Learning rate scheduler (without verbose parameter).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)

# ======================
# 5. Training and Evaluation Loops
# ======================
from tqdm import tqdm

def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    running_corrects = 0
    total_samples = 0

    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    
    for batch in progress_bar:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        images = batch['image'].to(device)
        labels = batch['label'].to(device)
        
        outputs = model(input_ids, attention_mask, images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        running_corrects += torch.sum(preds == labels).item()
        total_samples += labels.size(0)
        progress_bar.set_postfix(loss=loss.item(), accuracy=running_corrects / total_samples)
    
    epoch_loss = total_loss / len(dataloader)
    epoch_acc = running_corrects / total_samples
    return epoch_loss, epoch_acc

def evaluate_model(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            images = batch['image'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(input_ids, attention_mask, images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            
            preds = outputs.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            
    avg_loss = total_loss / len(dataloader)
    acc = accuracy_score(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds)
    return avg_loss, acc, cm

EPOCHS = 4
best_val_loss = float('inf')
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, val_cm = evaluate_model(model, val_loader, criterion, device)
    
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
    print(f"  Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc*100:.2f}%")
    print(f"  Confusion Matrix:\n{val_cm}\n")
    
    scheduler.step(val_loss)
    print("Current LR:", scheduler.optimizer.param_groups[0]['lr'])
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_multimodal_model.pth')

test_loss, test_acc, test_cm = evaluate_model(model, test_loader, criterion, device)
print("Test Results")
print(f"  Test Loss:   {test_loss:.4f} | Accuracy: {test_acc*100:.2f}%")
print(f"  Confusion Matrix:\n{test_cm}")


/home/jacks/miniconda3/envs/enen_645/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

Epoch 1/4

                                                                                                                        

  Train Loss: 0.4106 | Train Acc: 85.17%
  Val   Loss: 0.3581 | Val   Acc: 86.89%
  Confusion Matrix:
[[324  19   5  24]
 [ 68 646  34  20]
 [  8  14 327   3]
 [ 22  16   3 267]]

Current LR: 2e-05
Epoch 2/4

                                                                                                                        

  Train Loss: 0.2303 | Train Acc: 92.28%
  Val   Loss: 0.3172 | Val   Acc: 90.17%
  Confusion Matrix:
[[333  21   1  17]
 [ 46 692  13  17]
 [ 11  15 321   5]
 [ 13  18   0 277]]

Current LR: 2e-05
Epoch 3/4

                                                                                                                        

  Train Loss: 0.1822 | Train Acc: 93.75%
  Val   Loss: 0.3428 | Val   Acc: 90.22%
  Confusion Matrix:
[[326  23   6  17]
 [ 34 708  12  14]
 [  7  20 320   5]
 [ 21  16   1 270]]

Current LR: 2e-05
Epoch 4/4

                                                                                                                        

  Train Loss: 0.1457 | Train Acc: 95.00%
  Val   Loss: 0.3300 | Val   Acc: 90.22%
  Confusion Matrix:
[[324  25   6  17]
 [ 32 712  15   9]
 [  7  24 317   4]
 [ 15  21   1 271]]

Current LR: 1e-05
Test Results
  Test Loss:   0.5175 | Accuracy: 84.47%
  Confusion Matrix:
[[488 106  23  78]
 [ 51 987  37  11]
 [ 23  28 739   9]
 [ 50 102  15 685]]
