In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
import numpy as np
import re
import torch.nn.utils.rnn as rnn_utils
from torchvision import models, transforms
# All computation in this notebook runs on the CPU; no GPU selection is attempted.
device = torch.device("cpu")

# Standard ImageNet per-channel mean/std, matching the pretrained ResNet50 used below.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_INPUT_SIZE = (224, 224)  # ResNet50 expects 224x224 input

# Training-time preprocessing: random flip / rotation / colour jitter for augmentation,
# followed by resize, tensor conversion and normalisation.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Evaluation-time preprocessing: deterministic resize + normalisation only (no augmentation).
test_transform = transforms.Compose([
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Vocabulary and tokenization for text processing
class Vocabulary:
    """Bidirectional word <-> integer-index mapping with a reserved unknown-word token.

    Index 0 is always ``"<unk>"``; every newly added word receives the next consecutive
    index starting at 1. Calling the instance maps a word to its index, falling back to
    the ``"<unk>"`` index for words that were never added.
    """

    def __init__(self):
        unk = "<unk>"
        self.word2idx = {unk: 0}
        self.idx2word = {0: unk}
        self.idx = 1  # next free index (0 is taken by "<unk>")

    def add_word(self, word):
        """Register ``word`` under the next free index; known words are left unchanged."""
        if word in self.word2idx:
            return
        new_index = self.idx
        self.word2idx[word] = new_index
        self.idx2word[new_index] = word
        self.idx = new_index + 1

    def __call__(self, word):
        """Return the index of ``word``, or the ``"<unk>"`` index if it is not in the vocabulary."""
        fallback = self.word2idx["<unk>"]
        return self.word2idx.get(word, fallback)

    def __len__(self):
        """Return the number of entries, including ``"<unk>"``."""
        return len(self.word2idx)

def tokenize_text(text):
    """Lower-case ``text`` and split it on runs of whitespace into a list of tokens."""
    lowered = text.lower()
    return lowered.split()

# Custom Dataset class for both image and text input
class CustomImageTextDataset(Dataset):
    """Image + filename-text dataset over a ``root_dir/<class_name>/<image file>`` tree.

    Each sample is ``(image, text_indices, label)``:
      * ``image`` -- the RGB PIL image, passed through ``transform`` when one is given;
      * ``text_indices`` -- 1-D ``torch.long`` tensor of vocabulary indices for the words in
        the file name (extension and digits removed, lower-cased, whitespace-split); never
        empty -- a file name with no words yields a single ``"<unk>"`` index;
      * ``label`` -- the position of the sample's folder in ``class_names``.

    Args:
        root_dir: Directory containing one sub-folder per entry of ``class_names``.
        transform: Optional callable applied to the PIL image.
        vocab: Callable mapping a token to an integer index (e.g. ``Vocabulary``);
            required by ``__getitem__``.
    """

    # File extensions accepted as images (compared case-insensitively).
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None, vocab=None):
        self.root_dir = root_dir
        self.transform = transform
        self.class_names = ['Black', 'Blue', 'Green', 'TTR']
        self.vocab = vocab

        # Collect all image file paths and their class labels from the per-class subfolders.
        self.image_paths = []
        self.labels = []
        for label, class_name in enumerate(self.class_names):
            class_dir = os.path.join(self.root_dir, class_name)
            # os.listdir order is arbitrary; sort so sample order is reproducible across runs.
            for img_file in sorted(os.listdir(class_dir)):
                # Case-insensitive so files like "IMG_01.JPG" are not silently skipped.
                if img_file.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(class_dir, img_file))
                    self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        # Close the file handle promptly; convert() materialises the pixel data first.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # The text input is derived from the file name: strip the extension and digits.
        img_name = os.path.basename(img_path)
        label_text = re.sub(r'\d+', '', img_name.split('.')[0]).strip().lower()
        text_tokens = tokenize_text(label_text)
        text_indices = [self.vocab(token) for token in text_tokens]
        if not text_indices:
            # An all-digit file name has no words; an empty sequence cannot be fed to the
            # LSTM, so fall back to a single unknown-word token.
            text_indices = [self.vocab("<unk>")]

        if self.transform:
            image = self.transform(image)

        # Explicit dtype: torch.tensor of an empty/int list must be long for embedding lookup.
        return image, torch.tensor(text_indices, dtype=torch.long), label

# Load GloVe Embeddings
def load_glove_embeddings(glove_file, vocab, embedding_dim=100):
    """Build an embedding matrix for ``vocab`` from a GloVe text file.

    Each line of ``glove_file`` is expected to be ``word v1 v2 ... vD`` separated by
    whitespace. Only vectors for words present in ``vocab.word2idx`` are kept, so the full
    GloVe table is never held in memory.

    Args:
        glove_file: Path to the GloVe ``.txt`` file (read as UTF-8).
        vocab: Object with a ``word2idx`` dict (word -> row index) and ``__len__``.
        embedding_dim: Vector width ``D`` of the file (default 100, for ``glove.*.100d``).

    Returns:
        ``np.ndarray`` of shape ``(len(vocab), embedding_dim)`` and dtype ``float32``; row
        ``idx`` holds the GloVe vector for the word with that index. Words absent from the
        file (including ``"<unk>"`` unless the file contains that token) get a random
        ``N(0, 0.6**2)`` vector drawn from ``np.random``.

    Raises:
        ValueError: if a vocabulary word's vector in the file does not have
            ``embedding_dim`` components.
    """
    wanted = set(vocab.word2idx)
    embeddings_index = {}

    with open(glove_file, 'r', encoding='utf-8') as f:
        for line in f:
            values = line.split()
            # Skip blank lines and words we will never look up.
            if not values or values[0] not in wanted:
                continue
            vector = np.asarray(values[1:], dtype=np.float32)
            if vector.shape[0] != embedding_dim:
                raise ValueError(
                    f"GloVe vector for {values[0]!r} has {vector.shape[0]} components, "
                    f"expected embedding_dim={embedding_dim}"
                )
            embeddings_index[values[0]] = vector

    # float32 so the resulting nn.Embedding matches the float32 LSTM weights without a cast.
    embedding_matrix = np.zeros((len(vocab), embedding_dim), dtype=np.float32)

    for word, idx in vocab.word2idx.items():
        embedding_vector = embeddings_index.get(word)
        if embedding_vector is not None:
            embedding_matrix[idx] = embedding_vector
        else:
            embedding_matrix[idx] = np.random.normal(scale=0.6, size=(embedding_dim,))

    return embedding_matrix

# Define the Model
class ImageTextResNet50(nn.Module):
    """Classifier that fuses ResNet50 image features with LSTM text features.

    Image branch: a torchvision ResNet50 loaded with ``pretrained=True`` whose final layer is
    replaced by a ``feature_dim``-wide linear layer. Text branch: token embedding ->
    single-layer LSTM -> final hidden state -> linear (``feature_dim``) -> ReLU. The two
    feature vectors are concatenated and mapped to ``num_classes`` logits.

    Args:
        vocab_size: Rows of the embedding table (unused when ``pretrained_embeddings`` is given).
        embedding_dim: Embedding width (unused when ``pretrained_embeddings`` is given; the
            LSTM is always sized from the actual embedding table).
        hidden_dim: LSTM hidden size.
        num_classes: Number of output logits.
        pretrained_embeddings: Optional 2-D tensor used to initialise a trainable embedding table.
        feature_dim: Width of each branch's feature vector (default 128).
        padding_idx: Token index used to right-pad text batches (default 0, the value
            ``collate_fn`` pads with). Trailing tokens equal to it are excluded from the
            LSTM. Note ``Vocabulary`` also maps ``"<unk>"`` to 0, so trailing unknown words
            are indistinguishable from padding.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes, pretrained_embeddings=None,
                 feature_dim=128, padding_idx=0):
        super().__init__()

        # NOTE: `pretrained=True` is deprecated in torchvision>=0.13 in favour of
        # `weights=models.ResNet50_Weights.DEFAULT`; kept here for older torchvision versions.
        self.resnet = models.resnet50(pretrained=True)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, feature_dim)  # Replace final layer

        if pretrained_embeddings is not None:
            self.embedding = nn.Embedding.from_pretrained(pretrained_embeddings, freeze=False)
        else:
            self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Size the LSTM from the real table so a pretrained matrix of a different width still fits.
        self.lstm = nn.LSTM(self.embedding.embedding_dim, hidden_dim, batch_first=True)
        self.fc_text = nn.Linear(hidden_dim, feature_dim)

        # Combine the features from both image and text.
        self.fc_combined = nn.Linear(2 * feature_dim, num_classes)
        self.padding_idx = padding_idx

    @staticmethod
    def _infer_lengths(text, padding_idx=0):
        """Return per-row lengths: 1-based position of the last non-padding token, at least 1."""
        positions = torch.arange(1, text.size(1) + 1, device=text.device)
        last_real = ((text != padding_idx).long() * positions).max(dim=1).values
        return last_real.clamp(min=1)

    def forward(self, image, text, lengths=None):
        """Return ``(batch, num_classes)`` logits.

        Args:
            image: Image batch accepted by the ResNet50 branch (cast to float32).
            text: ``(batch, seq_len)`` right-padded token indices (cast to long).
            lengths: Optional true sequence lengths, each in ``1..seq_len``. When omitted
                they are inferred from trailing ``padding_idx`` tokens.
        """
        image = image.float()
        text = text.long()

        x_img = self.resnet(image)

        if text.size(1) == 0:
            # A batch of empty sequences cannot be packed; feed one padding token per row.
            text = text.new_full((text.size(0), 1), self.padding_idx)
        if lengths is None:
            lengths = self._infer_lengths(text, self.padding_idx)

        embedded_text = self.embedding(text).float()  # Embeddings may be float64 if built from numpy
        # Packing stops the LSTM at each sequence's real end. Reading lstm_out[:, -1] instead
        # would take the state *after* the padding for every shorter sequence in the batch.
        packed = rnn_utils.pack_padded_sequence(
            embedded_text,
            torch.as_tensor(lengths, dtype=torch.long).cpu(),
            batch_first=True,
            enforce_sorted=False,
        )
        _, (h_n, _) = self.lstm(packed)
        x_text = torch.relu(self.fc_text(h_n[-1]))  # final hidden state of each sequence

        x_combined = torch.cat((x_img, x_text), dim=1)
        return self.fc_combined(x_combined)


# Define custom collate function for padding text sequences
def collate_fn(batch):
    """Collate ``(image, text_indices, label)`` samples into batch tensors.

    Args:
        batch: Sequence of samples; images must share a shape, text sequences may differ
            in length.

    Returns:
        ``(images, padded_texts, labels)``: images stacked along a new dim 0,
        ``padded_texts`` of shape ``(batch, max_len)`` and dtype long, right-padded with 0,
        and ``labels`` as a tensor.
    """
    images, texts, labels = zip(*batch)

    # Stack images and labels.
    images = torch.stack(images, 0)
    labels = torch.tensor(labels)

    # Force long dtype: an empty sequence created via torch.tensor([]) is float32 and would
    # otherwise give the padded batch a dtype unusable for embedding lookup.
    texts = [torch.as_tensor(text).long() for text in texts]
    padded_texts = rnn_utils.pad_sequence(texts, batch_first=True, padding_value=0)

    return images, padded_texts, labels

# Training loop
def _mean_or_nan(total, count):
    """Return ``total / count``, or NaN when ``count`` is zero (e.g. an empty loader)."""
    return total / count if count else float('nan')


def _evaluate(model, loader, criterion, device):
    """Return ``(mean_batch_loss, accuracy_percent)`` of ``model`` over ``loader`` without gradients."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, text, labels in loader:
            images, text, labels = images.to(device), text.to(device), labels.to(device)
            outputs = model(images, text)
            total_loss += criterion(outputs, labels).item()

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return _mean_or_nan(total_loss, len(loader)), _mean_or_nan(100 * correct, total)


def train_model(model, train_loader, val_loader, num_epochs, criterion, optimizer, scheduler=None,
                device=None):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Each loader must yield ``(images, text, labels)`` batches and ``model(images, text)``
    must return logits accepted by ``criterion``. Per-epoch losses/accuracy are printed.

    Args:
        model: The network to train (updated in place).
        train_loader: Training batches; must support ``len()``.
        val_loader: Validation batches; must support ``len()``.
        num_epochs: Number of passes over ``train_loader``.
        criterion: Loss function ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Optional LR scheduler, stepped once per epoch.
        device: Device to move batches to; defaults to the device of the model's first
            parameter (CPU if the model has no parameters).

    Returns:
        List of per-epoch dicts with keys ``epoch``, ``train_loss``, ``val_loss`` and
        ``val_accuracy``; means over empty loaders are reported as NaN instead of raising.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    history = []
    for epoch in range(num_epochs):
        model.train()  # Set the model to training mode
        running_loss = 0.0
        for images, text, labels in train_loader:
            images, text, labels = images.to(device), text.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images, text)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        train_loss = _mean_or_nan(running_loss, len(train_loader))
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')

        if scheduler is not None:
            scheduler.step()

        # Validation after each epoch.
        val_loss, val_accuracy = _evaluate(model, val_loader, criterion, device)
        print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

        history.append({
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_accuracy': val_accuracy,
        })
    return history

# Load datasets and DataLoader
# NOTE(review): `train_data_path`, `val_data_path` and `vocab` are not defined in this cell;
# they must come from an earlier notebook cell.
train_dataset = CustomImageTextDataset(root_dir=train_data_path, transform=train_transform, vocab=vocab)
val_dataset = CustomImageTextDataset(root_dir=val_data_path, transform=test_transform, vocab=vocab)

train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

# Load GloVe embeddings (100-dimensional file).
glove_file = '../../data/glove.6B/glove.6B.100d.txt'
embedding_matrix = load_glove_embeddings(glove_file, vocab)

# Initialize the model. The class count is taken from the dataset so the two cannot drift apart.
num_classes = len(train_dataset.class_names)
# Cast explicitly to float32: a numpy matrix built with np.zeros defaults to float64, which
# would make the embedding table a different dtype from the float32 LSTM.
model = ImageTextResNet50(
    vocab_size=len(vocab),
    embedding_dim=100,
    hidden_dim=128,
    num_classes=num_classes,
    pretrained_embeddings=torch.tensor(embedding_matrix, dtype=torch.float32),
)

# Move the model to the appropriate device.
model = model.to(device)

# Loss and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training and validation. `train_model` (defined earlier in this cell) performs the per-epoch
# train-then-validate loop, so it is reused here instead of duplicating that loop inline.
num_epochs = 5
train_model(model, train_loader, val_loader, num_epochs, criterion, optimizer)

# Testing phase.
# NOTE(review): `test_loader` is not created in this cell; it must be defined elsewhere
# (e.g. an earlier notebook cell) before this runs.
model.eval()  # Set the model to evaluation mode
test_loss = 0.0
correct = 0
total = 0
with torch.no_grad():  # Disable gradient calculation for testing
    for images, text, labels in test_loader:
        images, text, labels = images.to(device), text.to(device), labels.to(device)
        outputs = model(images, text)
        loss = criterion(outputs, labels)
        test_loss += loss.item()

        # Calculate accuracy.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Report NaN rather than raising ZeroDivisionError when the test loader is empty.
test_accuracy = 100 * correct / total if total else float('nan')
mean_test_loss = test_loss / len(test_loader) if len(test_loader) else float('nan')
print(f'Test Loss: {mean_test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
