# Train/Test Split Model Comparison

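This notebook trains six different models on a stratified 80/20 train/test split of the ISAdetect dataset, using 'endianness' as the target feature. It then compares their file-level test accuracies, where each file's prediction is the majority vote over the predictions for its samples.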


In [None]:
import os
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm  # Use notebook version of tqdm
from datetime import datetime
import random
from pathlib import Path

# Add src directory to sys.path to import project modules
# Assumes the notebook is run from the 'analysis' directory
module_path = os.path.abspath(os.path.join(".."))
if module_path not in sys.path:
    sys.path.append(module_path)

from dataset_loaders import get_dataset
from models import get_model
from transforms import get_transform
from validators.train_test_utils import set_seed

In [None]:
# Custom collate function to handle the dataset output
# Pads image tensors to the maximum size in the batch

def custom_collate_fn(batch):
    """Collate ``(tensor, label_dict, ...)`` samples into one zero-padded batch.

    Tensors in a batch may differ in size along any dimension. Each tensor is
    zero-padded at the END of every dimension up to the largest size seen in
    the batch for that dimension. The padded tensors are then stacked along a
    new leading batch dimension. For 1-D tensors this is plain right-padding.

    Args:
        batch: Sequence of samples from the dataset. Only the first element
            (the tensor) and the second element (the label dict) of each
            sample are used. All tensors must have the same number of
            dimensions.

    Returns:
        Tuple ``(images, labels)``:
        - ``images`` is the stacked, padded tensor.
        - ``labels`` is the list of per-sample label dicts, in batch order.

    Raises:
        ValueError: If ``batch`` is empty or the tensors differ in rank.
    """
    if not batch:
        raise ValueError("custom_collate_fn received an empty batch")

    tensors = [item[0] for item in batch]
    ndim = tensors[0].dim()
    if any(t.dim() != ndim for t in tensors):
        raise ValueError(
            "All tensors in a batch must have the same number of dimensions"
        )

    # Largest extent along each dimension across the whole batch.
    max_shape = [max(t.size(d) for t in tensors) for d in range(ndim)]

    padded_images = []
    for tensor in tensors:
        # F.pad expects (left, right) pairs starting from the LAST dimension,
        # so build the spec in reverse dimension order and pad only at the end.
        # (The old code measured dim 0 but padded the last dim, which broke
        # stacking for multi-dimensional tensors.)
        pad_spec = []
        for d in reversed(range(ndim)):
            pad_spec.extend((0, max_shape[d] - tensor.size(d)))
        padded_images.append(torch.nn.functional.pad(tensor, tuple(pad_spec)))

    images = torch.stack(padded_images)
    labels = [item[1] for item in batch]  # keep labels as a list of dicts
    return images, labels


## Configuration


In [None]:
# --- Configuration ---
# Prediction target and data source.
TARGET_FEATURE = "endianness"
DATASET_NAME = "ISAdetectDataset"  # Or choose another like 'CpuRecDataset'
# Dataset root: the DATASET_BASE_PATH environment variable when set, otherwise
# a path relative to the current working directory.
DATASET_BASE_PATH = Path(os.environ.get("DATASET_BASE_PATH", "../../dataset"))

# Model names passed to get_model, trained in this order.
MODEL_NAMES = [
    "Simple1d",
    "Simple1dEmbedding",
    "Simple2d",
    "Simple2dEmbedding",
    "ResNet50",
    "ResNet50Embedding",
]

# Split and training hyperparameters.
TRAIN_SPLIT_RATIO = 0.8
SEED = 42
EPOCHS = 2
BATCH_SIZE = 64
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-2
# Attribute names looked up on torch.optim and torch.nn, respectively.
OPTIMIZER = "AdamW"
CRITERION = "CrossEntropyLoss"

# Prefer CUDA, then Apple MPS, and fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

print(f"Using device: {DEVICE}")
print(f"Target Feature: {TARGET_FEATURE}")
print(f"Dataset: {DATASET_NAME}")
print(f"Models: {MODEL_NAMES}")

# Seed every RNG up front so the split and initialization are reproducible.
set_seed(SEED)

## Load Dataset and Prepare Splits


In [None]:
# --- Load Data ---
# No custom transform is passed; presumably get_dataset then applies its own
# defaults — confirm in dataset_loaders.
transforms = None

# Load the full dataset
full_dataset = get_dataset(
    name=DATASET_NAME,
    transform=transforms,
    dataset_base_path=DATASET_BASE_PATH,
    target_feature=TARGET_FEATURE,
    params={
        # This path should point to the directory containing architecture subfolders (arm, mips, etc.)
        "dataset_path": "ISAdetect/ISAdetect_full_dataset",
        "feature_csv_path": "ISAdetect-features.csv",  # Relative path within DATASET_BASE_PATH
    },
)

# One target label per dataset item. Assumes full_dataset.metadata[i]
# describes item i of full_dataset — TODO confirm.
targets = [item[TARGET_FEATURE] for item in full_dataset.metadata]
indices = list(range(len(full_dataset)))

# Stratified split: class proportions are preserved in both train and test.
# NOTE(review): the split is per dataset item. If items are chunks of larger
# files, chunks of the same file can land in both train and test — confirm
# whether that leakage is acceptable for this comparison.
# The third return value holds the targets of the train indices, so reuse it
# as the label-encoder fitting data.
train_idx, test_idx, train_labels, _ = train_test_split(
    indices,
    targets,
    stratify=targets,
    test_size=1.0 - TRAIN_SPLIT_RATIO,
    random_state=SEED,
)

# Create subset datasets
train_dataset = Subset(full_dataset, train_idx)
test_dataset = Subset(full_dataset, test_idx)

print(f"Full dataset size: {len(full_dataset)}")
print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

# Fit the label encoder on training labels only. Stratification ensures every
# class also appears in the training split.
label_encoder = LabelEncoder()
label_encoder.fit(train_labels)
num_classes = len(label_encoder.classes_)
print(f"Classes: {label_encoder.classes_}")

# Pinned (page-locked) host memory only benefits host-to-CUDA copies, so
# enable it only when training on CUDA (not on MPS or CPU).
# NOTE(review): with num_workers > 0, platforms that start workers via
# "spawn" (macOS, Windows) must unpickle custom_collate_fn in each worker;
# functions defined interactively in a notebook may fail to load there —
# set num_workers=0 if the loaders raise on startup.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=4,  # Adjust based on your system
    pin_memory=DEVICE.type == "cuda",
    collate_fn=custom_collate_fn,
)

# Create DataLoaders (only training data is shuffled).
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

## Training and Evaluation Loop


In [None]:
# --- Training & Evaluation ---
from collections import defaultdict


def _encode_labels(labels_list):
    """Encode each label dict's TARGET_FEATURE value as an int64 tensor on DEVICE."""
    str_labels = [label_item[TARGET_FEATURE] for label_item in labels_list]
    # CrossEntropyLoss expects int64 class indices.
    return torch.as_tensor(
        label_encoder.transform(str_labels), dtype=torch.long, device=DEVICE
    )


def _majority_vote_accuracy(file_predictions_map, file_true_labels_map, num_classes):
    """Return the fraction of files whose majority-vote prediction is correct.

    Args:
        file_predictions_map: Maps a file key to the list of encoded class
            predictions made for that file's samples.
        file_true_labels_map: Maps the same file key to its encoded true label.
        num_classes: Number of classes, used to size the vote histogram.

    Returns:
        Accuracy in [0, 1]. Returns 0.0 (and prints a warning) when there are
        no files.
    """
    if not file_predictions_map:
        print("Warning: No file-level predictions were made, accuracy set to 0.")
        return 0.0
    correct = 0
    for file_key, chunk_preds in file_predictions_map.items():
        # argmax returns the first maximum, so ties go to the lowest class index.
        file_prediction = np.bincount(chunk_preds, minlength=num_classes).argmax()
        correct += int(file_prediction == file_true_labels_map[file_key])
    return correct / len(file_predictions_map)


model_results = {}

for model_name in MODEL_NAMES:
    print(f"\n{'='*10} Training Model: {model_name} {'='*10}")
    set_seed(SEED)  # Reset seed for each model for consistent initialization

    # get_model returns a model class, instantiated here with the number of
    # classes found in the training labels.
    model_class = get_model(name=model_name, params={"num_classes": num_classes})
    model = model_class(num_classes=num_classes).to(DEVICE)

    # Criterion and optimizer are looked up by name from the configuration.
    criterion = getattr(nn, CRITERION)()
    optimizer = getattr(torch.optim, OPTIMIZER)(
        model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )

    # Training loop
    for epoch in range(EPOCHS):
        model.train()
        total_train_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]")
        for images, labels_list in progress_bar:
            images = images.to(DEVICE)
            encoded_labels = _encode_labels(labels_list)

            optimizer.zero_grad()
            predictions = model(images)
            loss = criterion(predictions, encoded_labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_train_loss += batch_loss
            progress_bar.set_postfix({"train_loss": batch_loss})

        avg_train_loss = total_train_loss / len(train_loader)
        print(f"Epoch {epoch+1} Average Training Loss: {avg_train_loss:.4f}")

    # Evaluation loop: collect per-sample predictions grouped by source file.
    model.eval()
    total_test_loss = 0.0
    file_predictions_map = defaultdict(list)
    file_true_labels_map = {}

    with torch.no_grad():
        progress_bar = tqdm(test_loader, desc="Evaluating")
        for images, labels_list in progress_bar:
            images = images.to(DEVICE)
            file_paths = [label_item["file_path"] for label_item in labels_list]
            encoded_labels = _encode_labels(labels_list)

            outputs = model(images)
            total_test_loss += criterion(outputs, encoded_labels).item()

            _, predicted = torch.max(outputs, 1)
            batch_predictions = predicted.cpu().numpy()
            batch_true_labels = encoded_labels.cpu().numpy()

            for pred, true_label, file_path in zip(
                batch_predictions, batch_true_labels, file_paths
            ):
                # Key on the FULL path: basenames can repeat across directories
                # (e.g. one per architecture folder), and grouping by basename
                # would merge different files with different labels.
                file_predictions_map[file_path].append(pred)
                first_label = file_true_labels_map.setdefault(file_path, true_label)
                if first_label != true_label:
                    print(
                        f"Warning: samples of {file_path} have different true labels; keeping the first"
                    )

    avg_test_loss = total_test_loss / len(test_loader)
    file_level_accuracy = _majority_vote_accuracy(
        file_predictions_map, file_true_labels_map, num_classes
    )

    print(f"Model: {model_name}")
    print(f"  Average Test Loss: {avg_test_loss:.4f}")
    print(f"  File-level Test Accuracy: {100 * file_level_accuracy:.2f}%")

    model_results[model_name] = file_level_accuracy

## Plot Results


In [None]:
# --- Plotting ---
# One bar per trained model, ordered alphabetically by model name.
model_names_sorted = sorted(model_results)
accuracies_sorted = [model_results[name] for name in model_names_sorted]

plt.figure(figsize=(10, 6))
bars = plt.bar(model_names_sorted, accuracies_sorted, color="skyblue")

# Annotate each bar with its accuracy, centred horizontally at the bar's top.
for rect in bars:
    height = rect.get_height()
    x_center = rect.get_x() + rect.get_width() / 2.0
    plt.text(x_center, height, f"{height:.3f}", ha="center", va="bottom")

plt.xlabel("Model")
plt.ylabel("Test Accuracy (File-Level)")
plt.title(f"Model Comparison on '{TARGET_FEATURE}' (Train/Test Split)")
plt.xticks(rotation=45, ha="right")
# Leave headroom above 1.0 so labels on near-perfect bars stay visible.
plt.ylim(0, 1.05)
plt.tight_layout()

plt.show()