In [None]:
import os

import cv2
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
import torchvision.transforms.functional as F_v
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler, ConcatDataset
from torchvision import models
from tqdm.auto import tqdm

In [None]:
# Kaggle input/output locations for the ISIC challenge data.
INPUT_DIR = "/kaggle/input/kaggleisic-challenge"
WORKING_DIR = "/kaggle/working"

TRAIN_METADATA_CSV = f"{INPUT_DIR}/new-train-metadata.csv"
TEST_METADATA_CSV = f"{INPUT_DIR}/students-test-metadata.csv"
TRAIN_METADATA_PROCESSED_CSV = f"{INPUT_DIR}/train-metadata-processed.csv"
TEST_METADATA_PROCESSED_CSV = f"{INPUT_DIR}/test-metadata-processed.csv"
TRAIN_HDF5 = f"{INPUT_DIR}/train-image.hdf5"
TEST_HDF5 = f"{INPUT_DIR}/test-image.hdf5"
OUTPUT_FINAL_MODEL = f"{WORKING_DIR}/final_model.pth"
OUTPUT_FINAL_SUBMISSION = f"{WORKING_DIR}/final_submission.csv"

# Seed the global NumPy and PyTorch RNGs so weight initialisation and the
# WeightedRandomSampler draws are reproducible on Restart & Run All.
# (The train/valid split is seeded separately via load_metadata_dataset(seed=...).)
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Raw-metadata columns to discard. Not referenced in this section of the
# notebook — presumably used when building the *-processed.csv files (TODO confirm).
DROP_COLUMNS = [
    "image_type",
    "patient_id",
    "copyright_license",
    "attribution",
    "anatom_site_general",
    "tbp_lv_location_simple",
]

In [None]:
class ISIC_HDF5_Dataset(Dataset):
    """
    A PyTorch Dataset that loads images from an HDF5 file given a DataFrame of IDs.
    Applies image transforms.

    Items are ``(image, label, isic_id)`` when ``is_labelled`` is True and
    ``(image, isic_id)`` otherwise. ``image`` is whatever ``transform`` returns;
    without a transform it is a 224x224 tensor from ``ToTensor`` (not normalized).
    """

    def __init__(
        self, df: pd.DataFrame, hdf5_path: str, transform=None, is_labelled: bool = True
    ):
        """
        Args:
            df (pd.DataFrame): DataFrame containing 'isic_id' and optionally 'target'.
            hdf5_path (str): Path to the HDF5 file containing images.
            transform (callable): Optional transforms applied to the PIL image.
            is_labelled (bool): Whether the dataset includes labels (for train/val).
        """
        self.df = df.reset_index(drop=True)
        self.hdf5_path = hdf5_path
        self.transform = transform
        self.is_labelled = is_labelled

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        isic_id = row["isic_id"]

        # torchvision transforms here operate on PIL images, so convert the
        # decoded (H x W x 3) RGB array once, whichever branch runs below.
        image_pil = F_v.to_pil_image(self._load_image_from_hdf5(isic_id))
        if self.transform is not None:
            image = self.transform(image_pil)
        else:
            # Fallback: resize to the model input size and convert to a tensor.
            default_transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
            image = default_transform(image_pil)

        if self.is_labelled:
            label = torch.tensor(row["target"]).float()
            return image, label, isic_id
        return image, isic_id

    def _load_image_from_hdf5(self, isic_id: str):
        """
        Load and decode one image from the HDF5 file by ``isic_id``.

        Returns:
            np.ndarray: RGB image (H x W x 3).

        Raises:
            ValueError: if the stored bytes cannot be decoded as an image.
        """
        with h5py.File(self.hdf5_path, "r") as hf:
            encoded = hf[isic_id][()]

        # The stored value may come back as raw ``bytes`` or as a uint8 array;
        # ``np.frombuffer`` gives the flat uint8 buffer ``cv2.imdecode`` needs
        # in both cases.
        buffer = np.frombuffer(encoded, dtype=np.uint8)
        image_bgr = cv2.imdecode(buffer, cv2.IMREAD_COLOR)  # OpenCV decodes to BGR
        if image_bgr is None:
            # imdecode signals failure with None; fail here with a clear message
            # instead of an opaque error from cvtColor.
            raise ValueError(
                f"Could not decode image '{isic_id}' from {self.hdf5_path}"
            )
        return cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)


class ISIC_Multimodal_Dataset(Dataset):
    """
    A PyTorch Dataset that loads images from an HDF5 file and metadata from a DataFrame.
    Supports optional transforms and training/testing mode.

    Items are ``(metadata, image, label)`` when ``is_labelled`` is True and
    ``(metadata, image, isic_id)`` otherwise. ``metadata`` is a float32 vector of
    every column except ``isic_id`` and ``target``, in DataFrame column order.
    """

    def __init__(self, df, hdf5_path: str, transform=None, is_labelled: bool = True):
        """
        Args:
            df (pd.DataFrame): DataFrame containing 'isic_id', metadata features, and optionally 'target'.
            hdf5_path (str): Path to the HDF5 file containing images.
            transform (callable): Optional image transforms (applied to a PIL image).
            is_labelled (bool): Whether the dataset includes labels (for training/validation).
        """
        self.df = df.reset_index(drop=True)
        self.hdf5_path = hdf5_path
        self.transform = transform
        self.is_labelled = is_labelled

        # Identify metadata columns (exclude isic_id and target)
        self.metadata_cols = [
            col for col in self.df.columns if col not in ["isic_id", "target"]
        ]

        # Convert metadata/labels to float32 arrays once instead of slicing and
        # casting a mixed-dtype pandas row on every __getitem__ call. A
        # non-numeric metadata column therefore fails here, at construction.
        self._metadata = self.df[self.metadata_cols].to_numpy(dtype="float32")
        self._isic_ids = self.df["isic_id"].tolist()
        self._labels = (
            self.df["target"].to_numpy(dtype="float32") if is_labelled else None
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        isic_id = self._isic_ids[idx]

        # --- Load and transform image (transforms expect a PIL image) ---
        image_pil = F_v.to_pil_image(self._load_image_from_hdf5(isic_id))
        if self.transform is not None:
            image = self.transform(image_pil)
        else:
            default_transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
            image = default_transform(image_pil)

        # --- Metadata ---
        metadata = torch.tensor(self._metadata[idx])

        if self.is_labelled:
            label = torch.tensor(self._labels[idx])
            return metadata, image, label
        return metadata, image, isic_id

    def _load_image_from_hdf5(self, isic_id: str):
        """
        Load and decode one image from the HDF5 file by ``isic_id``.

        Returns:
            np.ndarray: RGB image (H x W x 3).

        Raises:
            ValueError: if the stored bytes cannot be decoded as an image.
        """
        with h5py.File(self.hdf5_path, "r") as hf:
            encoded = hf[isic_id][()]

        # Accept both raw ``bytes`` and uint8-array payloads.
        buffer = np.frombuffer(encoded, dtype=np.uint8)
        image_bgr = cv2.imdecode(buffer, cv2.IMREAD_COLOR)  # OpenCV decodes to BGR
        if image_bgr is None:
            raise ValueError(
                f"Could not decode image '{isic_id}' from {self.hdf5_path}"
            )
        return cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)


class ISIC_Metadata_Dataset(Dataset):
    """
    Tabular-only dataset built from a metadata DataFrame.

    Every column except ``isic_id`` and ``target`` becomes a float32 feature, in
    DataFrame column order. Items are ``(features, label, isic_id)`` when
    ``is_labelled`` is True, otherwise ``(features, isic_id)``.
    """

    def __init__(self, df: pd.DataFrame, is_labelled: bool = True):
        self.df = df.reset_index(drop=True)
        self.is_labelled = is_labelled

        # Materialise plain numpy arrays up front so __getitem__ does no pandas work.
        feature_frame = df.drop(columns=["target", "isic_id"], errors="ignore")
        self.features = feature_frame.values.astype("float32")
        self.isic_ids = df["isic_id"].values.astype("str")
        if is_labelled:
            self.labels = df["target"].values.astype("float32")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        isic_id = self.isic_ids[idx]
        features = torch.tensor(self.features[idx])
        if not self.is_labelled:
            return features, isic_id
        return features, torch.tensor(self.labels[idx]), isic_id

In [None]:
def load_metadata_dataset(train_frac=0.8, seed=42) -> tuple:
    """
    Read the processed metadata CSVs and split the labelled rows into a
    stratified train / validation pair.

    Args:
        train_frac (float): fraction of labelled rows kept for training.
        seed (int): ``random_state`` forwarded to ``train_test_split``.

    Returns:
        tuple: ``(train_df, valid_df, test_df)`` DataFrames, each re-indexed 0..n-1.
    """
    labelled_df = pd.read_csv(TRAIN_METADATA_PROCESSED_CSV)
    unlabelled_df = pd.read_csv(TEST_METADATA_PROCESSED_CSV)

    # Stratify on `target` so both splits keep the original class ratio.
    train_part, valid_part = train_test_split(
        labelled_df,
        train_size=train_frac,
        stratify=labelled_df["target"],
        random_state=seed,
    )

    train_part = train_part.reset_index(drop=True)
    valid_part = valid_part.reset_index(drop=True)
    test_part = unlabelled_df.reset_index(drop=True)

    print(f"train_dataset shape: {train_part.shape}")
    print(f"valid_dataset shape: {valid_part.shape}")
    print(f"test_dataset shape:  {test_part.shape}")

    return train_part, valid_part, test_part


def _build_split_datasets(dataset_cls, transform, train_frac, seed):
    """
    Build ``(train, valid, test)`` datasets of type ``dataset_cls``.

    Train/validation rows come from the stratified split of
    ``load_metadata_dataset`` and read images from ``TRAIN_HDF5``. The unlabelled
    test rows read from ``TEST_HDF5``. All three share ``transform``.
    """
    train_df_sub, valid_df_sub, test_df = load_metadata_dataset(
        train_frac=train_frac, seed=seed
    )

    train_dataset = dataset_cls(
        df=train_df_sub, hdf5_path=TRAIN_HDF5, transform=transform, is_labelled=True
    )
    valid_dataset = dataset_cls(
        df=valid_df_sub, hdf5_path=TRAIN_HDF5, transform=transform, is_labelled=True
    )
    test_dataset = dataset_cls(
        df=test_df, hdf5_path=TEST_HDF5, transform=transform, is_labelled=False
    )
    return train_dataset, valid_dataset, test_dataset


def load_hdf5_dataset(
    transform: T.Compose, train_frac=0.8, seed=42
) -> tuple[ISIC_HDF5_Dataset, ISIC_HDF5_Dataset, ISIC_HDF5_Dataset]:
    """
    Load the image-only ISIC datasets and split them into train, validation, and test sets.

    Args:
        transform (T.Compose): Transformations to apply to the images.
        train_frac (float): Fraction of the labelled rows to use for training.
        seed (int): Random seed for the stratified split.

    Returns:
        tuple: ``(train, valid, test)`` ``ISIC_HDF5_Dataset`` instances; train and
        valid are labelled, test is not.
    """
    return _build_split_datasets(ISIC_HDF5_Dataset, transform, train_frac, seed)


def load_multimodal_dataset(
    transform: T.Compose, train_frac=0.8, seed=42
) -> tuple[ISIC_Multimodal_Dataset, ISIC_Multimodal_Dataset, ISIC_Multimodal_Dataset]:
    """
    Load the image + metadata ISIC datasets and split them into train, validation, and test sets.

    Args:
        transform (T.Compose): Transformations to apply to the images.
        train_frac (float): Fraction of the labelled rows to use for training.
        seed (int): Random seed for the stratified split.

    Returns:
        tuple: ``(train, valid, test)`` ``ISIC_Multimodal_Dataset`` instances; train
        and valid are labelled, test is not.
    """
    return _build_split_datasets(ISIC_Multimodal_Dataset, transform, train_frac, seed)

In [None]:
train_meta_df, valid_meta_df, test_meta_df = load_metadata_dataset()

train_meta_dataset = ISIC_Metadata_Dataset(train_meta_df, is_labelled=True)
valid_meta_dataset = ISIC_Metadata_Dataset(valid_meta_df, is_labelled=True)
test_meta_dataset = ISIC_Metadata_Dataset(test_meta_df, is_labelled=False)

# ISIC_Metadata_Dataset feeds features positionally (by column order), so the test
# CSV must expose exactly the same feature columns, in the same order, as train.
train_feature_cols = [c for c in train_meta_df.columns if c not in ("isic_id", "target")]
test_feature_cols = [c for c in test_meta_df.columns if c not in ("isic_id", "target")]
assert train_feature_cols == test_feature_cols, (
    "Train/test metadata feature columns differ or are ordered differently: "
    f"{sorted(set(train_feature_cols) ^ set(test_feature_cols))}"
)

In [None]:
# Deterministic preprocessing (no augmentation): resize to 224x224, convert to a
# [0, 1] tensor, then normalise with the standard ImageNet channel mean/std values.
view_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
hdf5_splits = load_hdf5_dataset(transform=view_transform)
train_hdf5_dataset, valid_hdf5_dataset, test_hdf5_dataset = hdf5_splits

In [None]:
# Reuse `view_transform` from the image-only cell above rather than redefining an
# identical pipeline: both image pipelines are then guaranteed to preprocess alike.
train_multi_dataset, valid_multi_dataset, test_multi_dataset = load_multimodal_dataset(
    transform=view_transform
)

In [None]:
BATCH_SIZE = 16  # rows / images per mini-batch
NUM_SAMPLES = 5_000  # draws per epoch from the weighted sampler
NUM_WORKERS = min(4, os.cpu_count() or 2)  # DataLoader worker processes

# Inverse-frequency class weights, normalised to sum to 1. Because each sample's
# weight is 1 / count(its class), every class carries the same total weight, so
# the sampler draws the classes with roughly equal probability.
class_counts = train_meta_df["target"].value_counts().sort_index()
inverse_freq = 1.0 / class_counts
class_weights = inverse_freq / inverse_freq.sum()

sample_weights = train_meta_df["target"].map(class_weights).values

sampler = WeightedRandomSampler(
    weights=sample_weights, num_samples=NUM_SAMPLES, replacement=True
)

In [None]:
# Optimisation hyper-parameters shared by train_valid and train_eval.
EPOCHS = 15  # number of epochs; each epoch iterates the training loader once
LEARNING_RATE = 1e-3  # initial Adam learning rate
SCHEDULER_STEP_SIZE = 4  # StepLR: decay the learning rate every 4 epochs ...
SCHEDULER_GAMMA = 0.5  # ... multiplying it by this factor


def train_valid(model, train_loader, valid_loader, is_multimodal=False):
    """
    Train ``model`` for ``EPOCHS`` epochs, validating after every epoch, then
    plot the per-epoch train/validation accuracy curves.

    Args:
        model: binary classifier that outputs one logit per sample.
        train_loader: labelled loader used for optimisation.
        valid_loader: labelled loader used only for evaluation.
        is_multimodal: use the (metadata, image) epoch helpers instead of the
            single-input ones.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    if is_multimodal:
        run_train, run_valid = train_multimodal, validate_multimodal
    else:
        run_train, run_valid = train_singles, validate_singles

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = StepLR(optimizer, step_size=SCHEDULER_STEP_SIZE, gamma=SCHEDULER_GAMMA)

    history = {"train": [], "valid": []}
    for epoch in range(1, EPOCHS + 1):
        history["train"].append(
            run_train(model, device, train_loader, optimizer, criterion, epoch)
        )
        history["valid"].append(
            run_valid(model, device, valid_loader, criterion, epoch)
        )

        scheduler.step()
        print(f"Learning Rate: {scheduler.get_last_lr()[0]}")

    plot_train_valid_curves(history["train"], history["valid"])

    print("Training complete ✅")


def train_eval(
    model,
    full_loader,
    test_loader,
    is_multimodal=False,
    output_model_file=OUTPUT_FINAL_MODEL,
    output_submission_file=OUTPUT_FINAL_SUBMISSION,
):
    """
    Fit ``model`` on ``full_loader`` for ``EPOCHS`` epochs, save its weights,
    plot the training-accuracy curve, then predict on ``test_loader`` and write
    the submission CSV.

    Args:
        model: binary classifier that outputs one logit per sample.
        full_loader: labelled loader used for training (no held-out split).
        test_loader: unlabelled loader used for inference.
        is_multimodal: use the (metadata, image) helpers instead of single-input ones.
        output_model_file: path where ``model.state_dict()`` is saved.
        output_submission_file: path of the ``isic_id``/``target`` CSV written.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    fit_one_epoch = train_multimodal if is_multimodal else train_singles

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = StepLR(optimizer, step_size=SCHEDULER_STEP_SIZE, gamma=SCHEDULER_GAMMA)

    epoch_accuracies = []
    for epoch in range(1, EPOCHS + 1):
        epoch_accuracies.append(
            fit_one_epoch(model, device, full_loader, optimizer, criterion, epoch)
        )
        scheduler.step()
        print(f"Learning Rate: {scheduler.get_last_lr()[0]}")

    torch.save(model.state_dict(), output_model_file)
    print(f"Model saved to {output_model_file}")

    # Only a training curve exists here: there is no validation split.
    plot_train_curves(epoch_accuracies)

    print("Training complete ✅")

    predict_test = evaluate_multimodal if is_multimodal else evaluate_singles
    submission_df = predict_test(model, device, test_loader)
    submission_df.to_csv(output_submission_file, index=False)

    print(
        f"Saved submission with {len(submission_df)} rows to {output_submission_file} ✅"
    )


def train_multimodal(model, device, train_loader, optimizer, criterion, epoch):
    """
    Run one training epoch for an (image, metadata) model.

    Args:
        model: module called as ``model(images, metadatas)``; one logit per sample.
        device: device each batch is moved to.
        train_loader: iterable of ``(metadatas, images, labels)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: mean-reduced loss on logits (callers pass ``nn.BCEWithLogitsLoss()``).
        epoch: 1-based epoch number, used only for progress/log output.

    Returns:
        float: accuracy over every sample seen this epoch (0.0 for an empty loader).
    """
    model.train()
    running_loss = 0.0
    correct_preds = 0
    total_preds = 0

    for metadatas, images, labels in tqdm(
        train_loader, desc=f"Train Epoch {epoch}", leave=False
    ):
        metadatas, images, labels = (
            metadatas.to(device).float(),
            images.to(device),
            labels.to(device),
        )

        optimizer.zero_grad()
        logits = model(images, metadatas).view(-1)  # [batch_size]

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Re-weight the batch-mean loss by batch size so the epoch average is a
        # true per-sample mean even when the last batch is smaller.
        running_loss += loss.item() * batch_size
        # The model emits logits and sigmoid(0) == 0.5, so the probability-0.5
        # decision boundary is logit 0 (thresholding logits at 0.5 is wrong).
        predicted = (logits >= 0).float()
        correct_preds += (predicted == labels).sum().item()
        total_preds += batch_size

    denom = max(total_preds, 1)  # avoid ZeroDivisionError on an empty loader
    avg_train_loss = running_loss / denom
    train_accuracy = correct_preds / denom

    print(
        f"Epoch {epoch}/{EPOCHS} | Train Loss: {avg_train_loss:.4f} | Train Accuracy: {train_accuracy:.4f}"
    )
    return train_accuracy


def validate_multimodal(model, device, valid_loader, criterion, epoch):
    """
    Evaluate an (image, metadata) model on a labelled loader without gradients.

    Args:
        model: module called as ``model(images, metadatas)``; one logit per sample.
        device: device each batch is moved to.
        valid_loader: iterable of ``(metadatas, images, labels)`` batches.
        criterion: mean-reduced loss on logits (callers pass ``nn.BCEWithLogitsLoss()``).
        epoch: 1-based epoch number, used only for progress/log output.

    Returns:
        float: validation accuracy (0.0 for an empty loader).
    """
    model.eval()
    val_loss = 0.0
    val_correct_preds = 0
    val_total_preds = 0

    with torch.no_grad():
        for metadatas, images, labels in tqdm(
            valid_loader, desc=f"Validation Epoch {epoch}", leave=False
        ):
            metadatas, images, labels = (
                metadatas.to(device).float(),
                images.to(device),
                labels.to(device),
            )

            logits = model(images, metadatas).view(-1)
            loss = criterion(logits, labels)

            batch_size = labels.size(0)
            # Per-sample weighting so a short final batch does not skew the mean.
            val_loss += loss.item() * batch_size
            # Logit 0 <=> probability 0.5 (sigmoid(0) == 0.5).
            predicted = (logits >= 0).float()
            val_correct_preds += (predicted == labels).sum().item()
            val_total_preds += batch_size

    denom = max(val_total_preds, 1)  # avoid ZeroDivisionError on an empty loader
    avg_val_loss = val_loss / denom
    val_accuracy = val_correct_preds / denom
    print(
        f"Epoch {epoch}/{EPOCHS} | Validation Loss: {avg_val_loss:.4f} | Validation Accuracy: {val_accuracy:.4f}"
    )
    return val_accuracy


def evaluate_multimodal(model, device, test_loader):
    """
    Predict positive-class probabilities for an unlabelled multimodal loader.

    Args:
        model: module called as ``model(images, metadatas)``; one logit per sample.
        device: device each batch is moved to.
        test_loader: iterable of ``(metadatas, images, isic_ids)`` batches.

    Returns:
        pd.DataFrame: columns ``isic_id`` and ``target`` (sigmoid of the logit),
        sorted by ``isic_id``. It is empty, but still has both columns, when the
        loader yields no batches.
    """
    model.eval()
    all_ids = []
    all_probs = []

    with torch.no_grad():
        for metadatas, images, isic_ids in tqdm(test_loader, desc="Inference on Test"):
            metadatas, images = metadatas.to(device).float(), images.to(device)

            logits = model(images, metadatas).view(-1)  # shape [batch_size]
            # sigmoid maps each logit to a probability in [0, 1]
            all_probs.extend(torch.sigmoid(logits).cpu().tolist())
            all_ids.extend(isic_ids)

    # A dict of (possibly empty) lists keeps both columns, so sort_values works
    # even for an empty test loader.
    submission_df = pd.DataFrame({"isic_id": all_ids, "target": all_probs})
    return submission_df.sort_values(by="isic_id").reset_index(drop=True)


def train_singles(model, device, train_loader, optimizer, criterion, epoch):
    """
    Run one training epoch for a single-input model (metadata-only or image-only).

    Args:
        model: module called as ``model(inputs)``; one logit per sample.
        device: device each batch is moved to.
        train_loader: iterable of ``(inputs, labels, _)`` batches (third item ignored).
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: mean-reduced loss on logits (callers pass ``nn.BCEWithLogitsLoss()``).
        epoch: 1-based epoch number, used only for progress/log output.

    Returns:
        float: accuracy over every sample seen this epoch (0.0 for an empty loader).
    """
    model.train()
    running_loss = 0.0
    correct_preds = 0
    total_preds = 0

    for singles, labels, _ in tqdm(
        train_loader, desc=f"Train Epoch {epoch}", leave=False
    ):
        singles, labels = singles.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(singles).view(-1)  # [batch_size]

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Re-weight the batch-mean loss by batch size so the epoch average is a
        # true per-sample mean even when the last batch is smaller.
        running_loss += loss.item() * batch_size
        # The model emits logits and sigmoid(0) == 0.5, so the probability-0.5
        # decision boundary is logit 0 (thresholding logits at 0.5 is wrong).
        predicted = (logits >= 0).float()
        correct_preds += (predicted == labels).sum().item()
        total_preds += batch_size

    denom = max(total_preds, 1)  # avoid ZeroDivisionError on an empty loader
    avg_train_loss = running_loss / denom
    train_accuracy = correct_preds / denom

    print(
        f"Epoch {epoch}/{EPOCHS} | Train Loss: {avg_train_loss:.4f} | Train Accuracy: {train_accuracy:.4f}"
    )
    return train_accuracy


def validate_singles(model, device, valid_loader, criterion, epoch):
    """
    Evaluate a single-input model on a labelled loader without gradients.

    Args:
        model: module called as ``model(inputs)``; one logit per sample.
        device: device each batch is moved to.
        valid_loader: iterable of ``(inputs, labels, _)`` batches (third item ignored).
        criterion: mean-reduced loss on logits (callers pass ``nn.BCEWithLogitsLoss()``).
        epoch: 1-based epoch number, used only for progress/log output.

    Returns:
        float: validation accuracy (0.0 for an empty loader).
    """
    model.eval()
    val_loss = 0.0
    val_correct_preds = 0
    val_total_preds = 0

    with torch.no_grad():
        for singles, labels, _ in tqdm(
            valid_loader, desc=f"Validation Epoch {epoch}", leave=False
        ):
            singles, labels = singles.to(device), labels.to(device)

            logits = model(singles).view(-1)
            loss = criterion(logits, labels)

            batch_size = labels.size(0)
            # Per-sample weighting so a short final batch does not skew the mean.
            val_loss += loss.item() * batch_size
            # Logit 0 <=> probability 0.5 (sigmoid(0) == 0.5).
            predicted = (logits >= 0).float()
            val_correct_preds += (predicted == labels).sum().item()
            val_total_preds += batch_size

    denom = max(val_total_preds, 1)  # avoid ZeroDivisionError on an empty loader
    avg_val_loss = val_loss / denom
    val_accuracy = val_correct_preds / denom
    print(
        f"Epoch {epoch}/{EPOCHS} | Validation Loss: {avg_val_loss:.4f} | Validation Accuracy: {val_accuracy:.4f}"
    )
    return val_accuracy


def evaluate_singles(model, device, test_loader):
    """
    Predict positive-class probabilities for an unlabelled single-input loader.

    Args:
        model: module called as ``model(inputs)``; one logit per sample.
        device: device each batch is moved to.
        test_loader: iterable of ``(inputs, isic_ids)`` batches.

    Returns:
        pd.DataFrame: columns ``isic_id`` and ``target`` (sigmoid of the logit),
        sorted by ``isic_id``. It is empty, but still has both columns, when the
        loader yields no batches.
    """
    model.eval()
    all_ids = []
    all_probs = []

    with torch.no_grad():
        for singles, isic_ids in tqdm(test_loader, desc="Inference on Test"):
            logits = model(singles.to(device)).view(-1)  # shape [batch_size]
            # sigmoid maps each logit to a probability in [0, 1]
            all_probs.extend(torch.sigmoid(logits).cpu().tolist())
            all_ids.extend(isic_ids)

    # A dict of (possibly empty) lists keeps both columns, so sort_values works
    # even for an empty test loader.
    submission_df = pd.DataFrame({"isic_id": all_ids, "target": all_probs})
    return submission_df.sort_values(by="isic_id").reset_index(drop=True)


def plot_train_valid_curves(train_accs, val_accs):
    """
    Line-plot per-epoch train vs. validation accuracy on a shared [0, 1] axis.

    Args:
        train_accs (list[float]): training accuracy per epoch.
        val_accs (list[float]): validation accuracy per epoch (same length).
    """
    epoch_axis = list(range(1, len(train_accs) + 1))
    phases = ["Train"] * len(train_accs) + ["Validation"] * len(val_accs)
    # Long format: one row per (epoch, phase) so seaborn draws one line per phase.
    long_df = pd.DataFrame(
        {"Epoch": epoch_axis * 2, "Accuracy": train_accs + val_accs, "Phase": phases}
    )

    sns.set(style="whitegrid")
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.lineplot(data=long_df, x="Epoch", y="Accuracy", hue="Phase", marker="o", ax=ax)
    ax.set_title("Training vs Validation Accuracy")
    ax.set_ylim(0, 1)
    fig.tight_layout()
    plt.show()


def plot_train_curves(train_accs):
    """
    Line-plot per-epoch training accuracy (used when there is no validation split).

    Args:
        train_accs (list[float]): training accuracy per epoch.
    """
    history = pd.DataFrame(
        {"Epoch": list(range(1, len(train_accs) + 1)), "Accuracy": train_accs}
    )

    sns.set(style="whitegrid")
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.lineplot(data=history, x="Epoch", y="Accuracy", marker="o", ax=ax)
    ax.set_title("Training Accuracy (Full Dataset)")
    ax.set_ylim(0, 1)
    fig.tight_layout()
    plt.show()

# Only Metadata


In [None]:
train_meta_loader = DataLoader(
    train_meta_dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=NUM_WORKERS,
)

valid_meta_loader = DataLoader(
    valid_meta_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

test_meta_loader = DataLoader(
    test_meta_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

# The final full-data fit must follow the recipe validated with train_meta_loader:
# NUM_SAMPLES class-balanced draws per epoch. An unshuffled pass over the raw,
# imbalanced data would train a different model. Targets are concatenated in
# ConcatDataset order (train rows, then valid rows) so weights line up with indices.
full_meta_dataset = ConcatDataset([train_meta_dataset, valid_meta_dataset])
full_meta_targets = pd.concat(
    [train_meta_dataset.df["target"], valid_meta_dataset.df["target"]],
    ignore_index=True,
)
full_meta_sampler = WeightedRandomSampler(
    weights=full_meta_targets.map(class_weights).to_numpy(),
    num_samples=NUM_SAMPLES,
    replacement=True,
)
full_meta_loader = DataLoader(
    full_meta_dataset,
    batch_size=BATCH_SIZE,
    sampler=full_meta_sampler,
    num_workers=NUM_WORKERS,
)

print(
    f"Train loader: {len(train_meta_loader)} batches ({NUM_SAMPLES} samples per epoch, batch size {BATCH_SIZE})"
)
print(f"Valid loader: {len(valid_meta_loader)} batches")
print(f"Test loader:  {len(test_meta_loader)} batches")
print(
    f"Full loader:  {len(full_meta_loader)} batches ({NUM_SAMPLES} samples per epoch from {len(full_meta_dataset)} rows)"
)

In [None]:
class MLP_MetadataOnly(nn.Module):
    """
    MLP over the tabular metadata vector with two hidden blocks (128 -> 64),
    each Linear -> ReLU -> BatchNorm1d, ending in a single logit per row.
    """

    def __init__(self, num_features):
        super().__init__()
        layers = []
        width_in = num_features
        for width_out in (128, 64):
            layers.extend(
                [nn.Linear(width_in, width_out), nn.ReLU(), nn.BatchNorm1d(width_out)]
            )
            width_in = width_out
        layers.append(nn.Linear(width_in, 1))  # single logit for binary classification
        self.mlp = nn.Sequential(*layers)

    def forward(self, metadata):
        return self.mlp(metadata)

In [None]:
# Input width = number of feature columns the dataset actually feeds the model
# (every column except `isic_id` and `target`), rather than a hard-coded offset.
model = MLP_MetadataOnly(num_features=train_meta_dataset.features.shape[1])

train_valid(model, train_meta_loader, valid_meta_loader, is_multimodal=False)

In [None]:
# Fit a freshly initialised model on train+valid. `model` from the previous cell
# was already trained by `train_valid`, so reusing it would stack a second
# EPOCHS-long run on top of the validated one.
model = MLP_MetadataOnly(num_features=train_meta_dataset.features.shape[1])

train_eval(
    model,
    full_meta_loader,
    test_meta_loader,
    is_multimodal=False,
    output_model_file="/kaggle/working/mlp_metadata.pth",
    output_submission_file="/kaggle/working/mlp_metadata.csv",
)

# Only Images


In [None]:
train_hdf5_loader = DataLoader(
    train_hdf5_dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=NUM_WORKERS,
)

valid_hdf5_loader = DataLoader(
    valid_hdf5_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

test_hdf5_loader = DataLoader(
    test_hdf5_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

# The final full-data fit must follow the recipe validated with train_hdf5_loader:
# NUM_SAMPLES class-balanced draws per epoch, not one unshuffled, imbalanced pass.
# Targets are concatenated in ConcatDataset order (train rows, then valid rows).
full_hdf5_dataset = ConcatDataset([train_hdf5_dataset, valid_hdf5_dataset])
full_hdf5_targets = pd.concat(
    [train_hdf5_dataset.df["target"], valid_hdf5_dataset.df["target"]],
    ignore_index=True,
)
full_hdf5_sampler = WeightedRandomSampler(
    weights=full_hdf5_targets.map(class_weights).to_numpy(),
    num_samples=NUM_SAMPLES,
    replacement=True,
)
full_hdf5_loader = DataLoader(
    full_hdf5_dataset,
    batch_size=BATCH_SIZE,
    sampler=full_hdf5_sampler,
    num_workers=NUM_WORKERS,
)

print(
    f"Train loader: {len(train_hdf5_loader)} batches ({NUM_SAMPLES} samples per epoch, batch size {BATCH_SIZE})"
)
print(f"Valid loader: {len(valid_hdf5_loader)} batches")
print(f"Test loader:  {len(test_hdf5_loader)} batches")
print(
    f"Full loader:  {len(full_hdf5_loader)} batches ({NUM_SAMPLES} samples per epoch from {len(full_hdf5_dataset)} rows)"
)

In [None]:
class MLP_ImageOnly(nn.Module):
    """
    Pixel-level MLP baseline: flatten the (C, H, W) image and map it through
    512 -> 128 hidden units to a single logit.
    """

    def __init__(self, image_shape=(3, 224, 224)):
        super().__init__()
        num_pixels = image_shape[0] * image_shape[1] * image_shape[2]
        self.mlp = nn.Sequential(
            nn.Linear(num_pixels, 512), nn.ReLU(),
            nn.Linear(512, 128), nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, image):
        # Keep the batch dimension, collapse everything else.
        return self.mlp(image.view(image.shape[0], -1))

In [None]:
# Image-only baseline; the default input shape matches the 224x224 Resize used
# by `view_transform`.
model = MLP_ImageOnly(image_shape=(3, 224, 224))

train_valid(model, train_hdf5_loader, valid_hdf5_loader, is_multimodal=False)

In [None]:
# Fit a freshly initialised model on train+valid. `model` from the previous cell
# was already trained by `train_valid`, so reusing it would stack a second
# EPOCHS-long run on top of the validated one.
model = MLP_ImageOnly()

train_eval(
    model,
    full_hdf5_loader,
    test_hdf5_loader,
    is_multimodal=False,
    output_model_file="/kaggle/working/mlp_images.pth",
    output_submission_file="/kaggle/working/mlp_images.csv",
)

# Multimodal


In [None]:
train_multi_loader = DataLoader(
    train_multi_dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=NUM_WORKERS,
)

valid_multi_loader = DataLoader(
    valid_multi_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

test_multi_loader = DataLoader(
    test_multi_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

# The final full-data fit must follow the recipe validated with train_multi_loader:
# NUM_SAMPLES class-balanced draws per epoch, not one unshuffled, imbalanced pass.
# Targets are concatenated in ConcatDataset order (train rows, then valid rows).
full_multi_dataset = ConcatDataset([train_multi_dataset, valid_multi_dataset])
full_multi_targets = pd.concat(
    [train_multi_dataset.df["target"], valid_multi_dataset.df["target"]],
    ignore_index=True,
)
full_multi_sampler = WeightedRandomSampler(
    weights=full_multi_targets.map(class_weights).to_numpy(),
    num_samples=NUM_SAMPLES,
    replacement=True,
)
full_multi_loader = DataLoader(
    full_multi_dataset,
    batch_size=BATCH_SIZE,
    sampler=full_multi_sampler,
    num_workers=NUM_WORKERS,
)

print(
    f"Train loader: {len(train_multi_loader)} batches ({NUM_SAMPLES} samples per epoch, batch size {BATCH_SIZE})"
)
print(f"Valid loader: {len(valid_multi_loader)} batches")
print(f"Test loader:  {len(test_multi_loader)} batches")
print(
    f"Full loader:  {len(full_multi_loader)} batches ({NUM_SAMPLES} samples per epoch from {len(full_multi_dataset)} rows)"
)

In [None]:
class MLP_Multimodal(nn.Module):
    """
    Early-fusion MLP: flatten the image, append the metadata vector, and map the
    concatenation through 512 -> 128 hidden units to a single logit.
    """

    def __init__(self, num_metadata_features, image_shape=(3, 224, 224)):
        super().__init__()
        num_pixels = image_shape[0] * image_shape[1] * image_shape[2]
        fused_width = num_pixels + num_metadata_features
        self.mlp = nn.Sequential(
            nn.Linear(fused_width, 512), nn.ReLU(),
            nn.Linear(512, 128), nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, image, metadata):
        # Pixels first, metadata second: the column order must match training.
        flat_pixels = image.view(image.shape[0], -1)
        return self.mlp(torch.cat([flat_pixels, metadata], dim=1))

In [None]:
# Metadata width = the columns the multimodal dataset actually feeds the model
# (every column except `isic_id` and `target`), rather than a hard-coded offset.
model = MLP_Multimodal(num_metadata_features=len(train_multi_dataset.metadata_cols))

train_valid(model, train_multi_loader, valid_multi_loader, is_multimodal=True)

In [None]:
# Fit a freshly initialised model on train+valid. `model` from the previous cell
# was already trained by `train_valid`, so reusing it would stack a second
# EPOCHS-long run on top of the validated one.
model = MLP_Multimodal(num_metadata_features=len(train_multi_dataset.metadata_cols))

train_eval(
    model,
    full_multi_loader,
    test_multi_loader,
    is_multimodal=True,
    output_model_file="/kaggle/working/mlp_multimodal.pth",
    output_submission_file="/kaggle/working/mlp_multimodal.csv",
)