In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from transformers import BertModel, BertTokenizer
from PIL import Image
from sklearn.model_selection import train_test_split

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f'Using device: {device}')

# Data preparation
def load_data(data_dir):
    image_paths = []
    text_descriptions = []
    labels = []
    label_map = {"Black": 0, "Blue": 1, "Green": 2, "TTR": 3}

    # Traverse each folder corresponding to a label
    for label_name, label_idx in label_map.items():
        folder_path = os.path.join(data_dir, label_name)
        if not os.path.isdir(folder_path):
            continue

        # Iterate through all files in the folder
        for filename in os.listdir(folder_path):
            if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                image_path = os.path.join(folder_path, filename)
                image_paths.append(image_path)

                # Use the filename as the text description
                text_description = ' '.join(filename.split('_')[:-1])  # Assuming file format: description_words_number.jpg
                text_descriptions.append(text_description)

                labels.append(label_idx)

    return image_paths, text_descriptions, labels

# Load data
train_dir = '../../data/enel645_2024f/garbage_data/CVPR_2024_dataset_Train'
image_paths, text_descriptions, labels = load_data(train_dir)

# Split into training and validation sets
train_image_paths, val_image_paths, train_text_descriptions, val_text_descriptions, train_labels, val_labels = train_test_split(
    image_paths, text_descriptions, labels, test_size=0.2, random_state=42
)

# Define the custom dataset
class GarbageDataset(Dataset):
    def __init__(self, image_paths, text_descriptions, labels, transform=None, tokenizer=None, max_length=128):
        self.image_paths = image_paths
        self.text_descriptions = text_descriptions
        self.labels = labels
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load image
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            image = self.transform(image)

        # Tokenize text
        text = self.text_descriptions[idx]
        encoded_text = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt',
            truncation=True
        )
        input_ids = encoded_text['input_ids'].squeeze(0)
        attention_mask = encoded_text['attention_mask'].squeeze(0)

        # Load label
        label = self.labels[idx]

        return image, input_ids, attention_mask, label

# Data preprocessing
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Create datasets and dataloaders
train_dataset = GarbageDataset(train_image_paths, train_text_descriptions, train_labels, transform=transform, tokenizer=tokenizer)
val_dataset = GarbageDataset(val_image_paths, val_text_descriptions, val_labels, transform=transform, tokenizer=tokenizer)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

# Multimodal model combining ResNet50 and BERT
class MultimodalModel(nn.Module):
    def __init__(self, num_classes):
        super(MultimodalModel, self).__init__()
        # Load pretrained ResNet50 model
        self.resnet = models.resnet50(pretrained=True)
        # Replace the last fully connected layer
        self.resnet.fc = nn.Identity()
        self.resnet_feature_dim = 2048  # ResNet50 output feature size

        # Load pretrained BERT model
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.bert_feature_dim = 768  # BERT output feature size

        # Enhanced fully connected layers to combine features
        self.fc = nn.Sequential(
            nn.Linear(self.resnet_feature_dim + self.bert_feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, images, input_ids, attention_mask):
        # Extract image features
        image_features = self.resnet(images)

        # Extract text features
        text_features = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output

        # Concatenate image and text features
        combined_features = torch.cat((image_features, text_features), dim=1)

        # Pass through the fully connected layers
        output = self.fc(combined_features)
        return output

# Hyperparameters
num_classes = 4  # e.g., Blue, Green, Black, TTR
learning_rate = 0.001
num_epochs = 5

# Model, loss function, and optimizer
model = MultimodalModel(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, input_ids, attention_mask, labels in train_loader:
        images = images.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images, input_ids, attention_mask)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Validation phase
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, input_ids, attention_mask, labels in val_loader:
            images = images.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            outputs = model(images, input_ids, attention_mask)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_accuracy = 100 * correct / total
    print(f'Epoch [{epoch + 1}/{num_epochs}], Training Loss: {running_loss / len(train_loader):.4f}, Validation Loss: {val_loss / len(val_loader):.4f}, Validation Accuracy: {val_accuracy:.2f}%')

# Save the model
torch.save(model.state_dict(), 'multimodal_model_resnet50.pth')


Using device: mps


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Epoch [1/5], Training Loss: 1.2628, Validation Loss: 1.1942, Validation Accuracy: 50.39%
Epoch [2/5], Training Loss: 1.1783, Validation Loss: 1.1819, Validation Accuracy: 48.24%


KeyboardInterrupt: 