# In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
import numpy as np
import re

# Image preprocessing pipelines.
# Both pipelines share the same tail (resize -> tensor -> normalise); the
# training pipeline additionally applies random augmentation before it.
_INPUT_SIZE = (224, 224)  # ResNet50 expects 224x224 input
_NORM_MEAN = [0.485, 0.456, 0.406]  # per-channel mean used for normalisation
_NORM_STD = [0.229, 0.224, 0.225]  # per-channel std used for normalisation

train_transform = transforms.Compose(
    [
        # Random augmentation, applied to the PIL image before resizing.
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        # Shared tail.
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Evaluation pipeline: deterministic, no augmentation.
test_transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Vocabulary and tokenization for text processing
class Vocabulary:
    """Two-way mapping between words and integer indices.

    Index 0 is always the unknown-word token ``"<unk>"``; every word added
    later receives the next unused index. Calling the instance with a word
    returns its index, or the ``"<unk>"`` index when the word is not known.
    """

    def __init__(self):
        unknown = "<unk>"
        # Both directions are seeded with the unknown-word token at index 0.
        self.word2idx = {unknown: 0}
        self.idx2word = {0: unknown}
        # Next index to hand out (0 is already taken by "<unk>").
        self.idx = 1

    def add_word(self, word):
        """Register *word* under the next free index; no-op if already known."""
        if word in self.word2idx:
            return
        new_index = self.idx
        self.word2idx[word] = new_index
        self.idx2word[new_index] = word
        self.idx = new_index + 1

    def __call__(self, word):
        """Return the index of *word*, or the ``"<unk>"`` index if unseen."""
        unknown_index = self.word2idx["<unk>"]
        return self.word2idx.get(word, unknown_index)

    def __len__(self):
        """Return the number of entries, including ``"<unk>"``."""
        return len(self.word2idx)

def tokenize_text(text):
    """Lower-case *text* and split it on runs of whitespace.

    Returns a list of tokens; empty or whitespace-only text yields ``[]``.
    """
    lowered = text.lower()
    tokens = lowered.split()
    return tokens

# Custom Dataset class for both image and text input
class CustomImageTextDataset(Dataset):
    """Image dataset laid out as one sub-folder per class, with a text input
    derived from each image's file name.

    ``root_dir`` must contain one directory for every entry of
    ``class_names``; a sample's label is the position of its folder in that
    list. Only files ending in ``.jpg``, ``.jpeg`` or ``.png`` (case-sensitive)
    are collected.

    Args:
        root_dir: Directory containing the per-class sub-folders.
        transform: Optional callable applied to the loaded RGB PIL image.
        vocab: Callable mapping a token to an integer index. It is only used
            when items are accessed, but it is required there: indexing the
            dataset with ``vocab=None`` fails.

    Each item is a tuple ``(image, text_indices, label)`` where
    ``text_indices`` is a 1-D ``torch.long`` tensor with one entry per token.

    NOTE(review): the number of tokens can differ between samples, so
    batching through a ``DataLoader`` presumably needs a padding
    ``collate_fn`` -- confirm against the training code.
    """

    # File-name suffixes accepted as images.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None, vocab=None):
        self.root_dir = root_dir
        self.transform = transform
        self.class_names = ['Black', 'Blue', 'Green', 'TTR']
        self.vocab = vocab

        # Parallel lists: image_paths[i] is the file for sample i and
        # labels[i] is its class index.
        self.image_paths = []
        self.labels = []

        for label, class_name in enumerate(self.class_names):
            class_dir = os.path.join(self.root_dir, class_name)
            # os.listdir returns entries in arbitrary order; sorting keeps
            # the index -> sample mapping reproducible across runs/machines.
            for img_file in sorted(os.listdir(class_dir)):
                if img_file.endswith(self._IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(class_dir, img_file))
                    self.labels.append(label)

    def __len__(self):
        """Return the number of collected image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, text_indices, label)`` for sample *idx*."""
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        # Load the image. The context manager releases the file handle right
        # away; convert() returns an independent, fully loaded image.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')

        # Derive the text from the file name: keep the part before the first
        # '.', drop all digits, trim, lower-case, then tokenise.
        img_name = os.path.basename(img_path)
        label_text = re.sub(r'\d+', '', img_name.split('.')[0]).strip().lower()
        text_tokens = tokenize_text(label_text)

        # Convert text tokens to indices using the vocabulary
        text_indices = [self.vocab(token) for token in text_tokens]

        if self.transform:
            image = self.transform(image)

        # Force an integer dtype: torch.tensor([]) would otherwise be float32
        # for a file name that yields no tokens, which an embedding layer
        # cannot consume.
        return image, torch.tensor(text_indices, dtype=torch.long), label

# Load GloVe Embeddings
def load_glove_embeddings(glove_file, vocab):
    """Build an embedding matrix for *vocab* from a GloVe text file.

    Each non-blank line of *glove_file* is expected to be a word followed by
    100 whitespace-separated float components. Only vectors for words present
    in ``vocab.word2idx`` are parsed and kept, so the rest of the file costs
    no memory and malformed lines for unused words are ignored.

    Args:
        glove_file: Path to the UTF-8 encoded GloVe text file.
        vocab: Object exposing ``word2idx`` (word -> row index) and ``len()``.

    Returns:
        A ``float32`` numpy array of shape ``(len(vocab), 100)``. Row ``i``
        holds the GloVe vector of the word with index ``i``; words missing
        from the file get a random normal vector (std 0.6).

    Raises:
        ValueError: If a vocabulary word's line does not carry exactly 100
            components, or a component is not a valid float.
    """
    embedding_dim = 100  # GloVe 100D embeddings
    wanted = vocab.word2idx
    embeddings_index = {}

    with open(glove_file, 'r', encoding='utf-8') as f:
        for line in f:
            values = line.split()
            # Skip blank lines and every word the vocabulary does not use.
            if not values or values[0] not in wanted:
                continue
            word = values[0]
            if len(values) - 1 != embedding_dim:
                raise ValueError(
                    f"GloVe vector for {word!r} has {len(values) - 1} "
                    f"components, expected {embedding_dim}"
                )
            embeddings_index[word] = np.array(values[1:], dtype='float32')

    # float32 matches both the parsed vectors and PyTorch's default
    # parameter dtype.
    embedding_matrix = np.zeros((len(vocab), embedding_dim), dtype=np.float32)

    for word, idx in wanted.items():
        embedding_vector = embeddings_index.get(word)
        if embedding_vector is not None:
            embedding_matrix[idx] = embedding_vector
        else:
            # Word absent from GloVe: fall back to a random initialisation.
            embedding_matrix[idx] = np.random.normal(scale=0.6, size=(embedding_dim,))

    return embedding_matrix

# Define the Model
class ImageTextResNet50(nn.Module):
    """Two-branch classifier combining a ResNet50 image encoder with an
    embedding + LSTM text encoder.

    Each branch produces a 128-dimensional feature vector; the two vectors
    are concatenated and mapped to ``num_classes`` logits by a linear layer.

    Args:
        vocab_size: Number of embedding rows (used only when no pretrained
            embeddings are given).
        embedding_dim: Embedding width (used only when no pretrained
            embeddings are given; otherwise the width of the supplied matrix
            is used).
        hidden_dim: LSTM hidden size.
        num_classes: Number of output classes.
        pretrained_embeddings: Optional 2-D matrix of initial embedding
            weights, either a ``torch.Tensor`` or an array-like such as a
            numpy array. The weights stay trainable.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes, pretrained_embeddings=None):
        super().__init__()

        # Width of each branch's feature vector before fusion.
        feature_dim = 128

        # Load pretrained ResNet50 model and replace its final layer so it
        # emits a feature vector instead of class scores.
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in
        # favour of `weights=` -- confirm the installed version before
        # changing this call.
        self.resnet = models.resnet50(pretrained=True)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, feature_dim)

        # Use GloVe pretrained embeddings for text when supplied.
        if pretrained_embeddings is not None:
            # from_pretrained requires a tensor; accept numpy arrays too and
            # force float32 so the weights match the LSTM's parameter dtype.
            if isinstance(pretrained_embeddings, torch.Tensor):
                weights = pretrained_embeddings.to(torch.float32)
            else:
                weights = torch.tensor(pretrained_embeddings, dtype=torch.float32)
            self.embedding = nn.Embedding.from_pretrained(weights, freeze=False)
        else:
            self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Size the LSTM from the embedding layer actually built, so a
        # pretrained matrix whose width differs from `embedding_dim` still
        # yields a consistent model.
        self.lstm = nn.LSTM(self.embedding.embedding_dim, hidden_dim, batch_first=True)
        self.fc_text = nn.Linear(hidden_dim, feature_dim)

        # Combine the features from both image and text
        self.fc_combined = nn.Linear(2 * feature_dim, num_classes)

    def forward(self, image, text):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            image: Batch of images fed to ResNet50 (assumes the usual
                ``(batch, 3, H, W)`` float layout -- TODO confirm).
            text: 2-D integer tensor of token indices, ``(batch, seq_len)``.

        NOTE(review): the text feature is taken from the last time step, so
        ``seq_len`` must be at least 1, and with right-padded batches the
        last step is presumably a padding position -- confirm how batches
        are collated.
        """
        # Forward pass through ResNet50
        x_img = self.resnet(image)

        # Forward pass through embedding + LSTM for text
        embedded_text = self.embedding(text)
        lstm_out, _ = self.lstm(embedded_text)
        x_text = torch.relu(self.fc_text(lstm_out[:, -1, :]))  # Last output of LSTM

        # Combine image and text features
        x_combined = torch.cat((x_img, x_text), dim=1)
        output = self.fc_combined(x_combined)
        return output

# Training and validation code would be as before:
# Create DataLoaders, loss function, optimizer, etc.
