In [1]:
# Cell 1: Imports and Dependencies
import os
import re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, accuracy_score
from transformers import DistilBertTokenizer, DistilBertModel
from tqdm import tqdm


In [2]:
# Cell 2: Custom Multimodal Dataset
# This dataset expects a folder structure where each subfolder corresponds to a class.
# The text is extracted from the image file names.
class MultiModalDataset(Dataset):
    def __init__(self, root_dir, tokenizer, max_len, image_transform=None):
        """
        Expects a folder structure:
            root_dir/
                class1/
                    image1.jpg
                    image2.jpg
                    ...
                class2/
                    ...
        The text is extracted from the file name (by removing extension, replacing underscores, etc.)
        """
        self.root_dir = root_dir
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.image_transform = image_transform
        self.samples = []  # each element is a tuple (image_path, text, label)
        self.class_folders = sorted(os.listdir(root_dir))
        self.label_map = {cls: idx for idx, cls in enumerate(self.class_folders)}
        
        for cls in self.class_folders:
            cls_path = os.path.join(root_dir, cls)
            if os.path.isdir(cls_path):
                for file in sorted(os.listdir(cls_path)):
                    file_path = os.path.join(cls_path, file)
                    if os.path.isfile(file_path):
                        file_name_no_ext, _ = os.path.splitext(file)
                        text = file_name_no_ext.replace('_', ' ')
                        text = re.sub(r'\d+', '', text)
                        self.samples.append((file_path, text, self.label_map[cls]))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        image_path, text, label = self.samples[idx]
        image = Image.open(image_path).convert('RGB')
        if self.image_transform:
            image = self.image_transform(image)
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        return {
            'image': image,
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }


In [3]:
# Cell 3: Multimodal Model Definition
# The model uses DistilBERT for text and a RegNet model for images.
# A MLP layer fuses the two modalities.
class MultiModalClassifier(nn.Module):
    def __init__(self, num_classes, fusion_dim=512, text_model_name='distilbert-base-uncased', freeze_image_layers=True):
        """
        fusion_dim: common embedding dimension for both modalities.
        freeze_image_layers: if True, freeze most image model parameters except for the last stage ('s4'),
        ensuring that only one layer (i.e. the s4 block) is tuned in the vision branch.
        """
        super(MultiModalClassifier, self).__init__()
        self.fusion_dim = fusion_dim
        
        # ---- Text model: DistilBERT ----
        self.text_model = DistilBertModel.from_pretrained(text_model_name)
        self.text_projection = nn.Linear(self.text_model.config.hidden_size, fusion_dim)
        self.text_dropout = nn.Dropout(0.5)
        
        # ---- Image model: RegNet_Y_128GF ----
        self.image_model = models.regnet_y_128gf(weights=models.RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_E2E_V1)
        # Remove the final fc layer to obtain features.
        image_feature_dim = self.image_model.fc.in_features
        self.image_model.fc = nn.Identity()
        
        if freeze_image_layers:
            # Freeze all parameters except those in the last stage ('s4').
            for name, param in self.image_model.named_parameters():
                if not name.startswith('s4'):
                    param.requires_grad = False
        
        # If necessary, project image features to fusion_dim.
        if image_feature_dim != fusion_dim:
            self.image_projection = nn.Linear(image_feature_dim, fusion_dim)
        else:
            self.image_projection = nn.Identity()
        self.image_dropout = nn.Dropout(0.5)
 
        # ---- MLP-based Fusion Mechanism ----
        # The MLP takes concatenated text and image features and learns a fused representation.
        self.fusion = nn.Sequential(
            nn.Linear(fusion_dim * 2, fusion_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(fusion_dim, fusion_dim)
        )
        
        # ---- Classification Head ----
        self.classifier = nn.Linear(fusion_dim, num_classes)
    
    def forward(self, input_ids, attention_mask, image):
        # Text branch: extract [CLS] token representation from DistilBERT.
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        text_feat = text_outputs.last_hidden_state[:, 0, :]
        text_feat = self.text_dropout(text_feat)
        text_feat = self.text_projection(text_feat)
        
        # Image branch: extract features from the image model.
        image_feat = self.image_model(image)
        image_feat = self.image_dropout(image_feat)
        image_feat = self.image_projection(image_feat)
        
        # Fusion: concatenate text and image features and pass through the MLP.
        fused_feat = self.fusion(torch.cat([text_feat, image_feat], dim=1))
        
        # Classification
        logits = self.classifier(fused_feat)
        return logits


In [4]:
# Cell 4: Data Preparation and Augmentation
# Define image transformations and initialize the tokenizer.
train_image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
val_test_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
max_len = 24

# Paths for train, validation, and test datasets.
TRAIN_PATH = './garbage_data/CVPR_2024_dataset_Train'
VAL_PATH = './garbage_data/CVPR_2024_dataset_Val'
TEST_PATH = './garbage_data/CVPR_2024_dataset_Test'

# Create dataset instances.
train_dataset = MultiModalDataset(TRAIN_PATH, tokenizer, max_len, image_transform=train_image_transform)
val_dataset = MultiModalDataset(VAL_PATH, tokenizer, max_len, image_transform=val_test_transform)
test_dataset = MultiModalDataset(TEST_PATH, tokenizer, max_len, image_transform=val_test_transform)


In [5]:
# Cell 5: DataLoader Setup
# Use a WeightedRandomSampler to handle class imbalance in the training data.
train_labels = [sample[2] for sample in train_dataset.samples]
class_counts = np.bincount(train_labels)
class_weights = 1.0 / class_counts
sample_weights = [class_weights[label] for label in train_labels]

sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

train_loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [6]:
# Cell 6: Model, Optimizer, and Loss Function Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 4

# Instantiate the multimodal classifier.
model = MultiModalClassifier(num_classes=num_classes, fusion_dim=512, freeze_image_layers=True)
model = model.to(device)

# Define the optimizer, loss function, and learning rate scheduler.
optimizer = optim.Adam(model.parameters(), lr=2e-5, weight_decay=1e-3)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)


In [7]:
# Cell 7: Training and Evaluation with Early Stopping
# The following functions handle a single training epoch and evaluation on validation data.
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    running_corrects = 0
    total_samples = 0
    batch_losses = []
    batch_accs = []
    
    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    for batch in progress_bar:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        images = batch['image'].to(device)
        labels = batch['label'].to(device)
        
        outputs = model(input_ids, attention_mask, images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        running_corrects += torch.sum(preds == labels).item()
        total_samples += labels.size(0)
        
        # Batch metrics
        batch_loss = loss.item()
        batch_accuracy = (torch.sum(preds == labels).item() / labels.size(0))
        batch_losses.append(batch_loss)
        batch_accs.append(batch_accuracy)
        
        progress_bar.set_postfix(loss=batch_loss, accuracy=batch_accuracy)
    
    epoch_loss = total_loss / len(dataloader)
    epoch_acc = running_corrects / total_samples
    return epoch_loss, epoch_acc, batch_losses, batch_accs

def evaluate_model(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    batch_losses = []
    batch_accs = []
    
    progress_bar = tqdm(dataloader, desc="Validation", leave=False)
    with torch.no_grad():
        for batch in progress_bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            images = batch['image'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(input_ids, attention_mask, images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            
            preds = outputs.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            
            # Batch metrics for evaluation
            batch_loss = loss.item()
            batch_accuracy = (torch.sum(outputs.argmax(dim=1) == labels).item() / labels.size(0))
            batch_losses.append(batch_loss)
            batch_accs.append(batch_accuracy)
            
            progress_bar.set_postfix(loss=batch_loss, accuracy=batch_accuracy)
    
    avg_loss = total_loss / len(dataloader)
    acc = accuracy_score(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds)
    return avg_loss, acc, cm, batch_losses, batch_accs



In [None]:

# Initialize metric tracking lists.
train_losses = []
train_accs = []
val_losses = []
val_accs = []

train_batches_losses = []
train_batches_accs = []
val_batches_losses = []
val_batches_accs = []

# Early stopping parameters.
EPOCHS = 10
best_val_loss = float('inf')
patience = 2
patience_counter = 0

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    
    train_loss, train_acc, epoch_train_batch_losses, epoch_train_batch_accs = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, val_cm, epoch_val_batch_losses, epoch_val_batch_accs = evaluate_model(model, val_loader, criterion, device)
    
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    val_losses.append(val_loss)
    val_accs.append(val_acc)
    
    train_batches_losses.append(epoch_train_batch_losses)
    train_batches_accs.append(epoch_train_batch_accs)
    val_batches_losses.append(epoch_val_batch_losses)
    val_batches_accs.append(epoch_val_batch_accs)
    
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
    print(f"  Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc*100:.2f}%")
    print(f"  Confusion Matrix:\n{val_cm}\n")
    
    scheduler.step(val_loss)
    print("Current LR:", scheduler.optimizer.param_groups[0]['lr'])
    
    # Early Stopping Check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), 'best_multimodal_model.pth')
        print("Model improved. Saving model.")
    else:
        patience_counter += 1
        print(f"No improvement in validation loss. Patience counter: {patience_counter}/{patience}")
    
    if patience_counter >= patience:
        print("Early stopping triggered.")
        break

Epoch 1/10


                                                                                                                        

  Train Loss: 0.6304 | Train Acc: 76.43%
  Val   Loss: 0.3683 | Val   Acc: 87.44%
  Confusion Matrix:
[[303  28   3  38]
 [ 49 666  19  34]
 [  9  18 315  10]
 [  3  12   3 290]]

Current LR: 2e-05
Model improved. Saving model.
Epoch 2/10


                                                                                                                        

  Train Loss: 0.2703 | Train Acc: 90.86%
  Val   Loss: 0.3267 | Val   Acc: 88.06%
  Confusion Matrix:
[[321  25   3  23]
 [ 63 670  18  17]
 [ 13  18 318   3]
 [ 12  17   3 276]]

Current LR: 2e-05
Model improved. Saving model.
Epoch 3/10


                                                                                                                        

  Train Loss: 0.2195 | Train Acc: 92.52%
  Val   Loss: 0.3296 | Val   Acc: 88.72%
  Confusion Matrix:
[[327  19   1  25]
 [ 63 670  16  19]
 [ 12  20 316   4]
 [  7  15   2 284]]

Current LR: 2e-05
No improvement in validation loss. Patience counter: 1/2
Epoch 4/10


Training:  12%|████▉                                      | 42/364 [01:41<12:02,  2.24s/it, accuracy=0.969, loss=0.0872]

In [None]:
# Load the best saved model before evaluating on the test set
model.load_state_dict(torch.load('best_multimodal_model.pth'))
model.to(device)  # Ensure the model is on the correct device

# Now evaluate on the test set
test_loss, test_acc, test_cm, test_batch_losses, test_batch_accs = evaluate_model(model, test_loader, criterion, device)

print("Test Results (Best Model)")
print(f"  Test Loss:   {test_loss:.4f} | Accuracy: {test_acc*100:.2f}%")
print(f"  Confusion Matrix:\n{test_cm}")


In [None]:
# Plot confusion matrix
plt.figure(figsize=(8, 6))
sns.heatmap(test_cm, annot=True, cmap='Blues', fmt='g', cbar=False)

plt.title('Confusion Matrix')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.show()

In [None]:
~7,4gb