### Imports


In [59]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.preprocessing import LabelEncoder
from torchvision.models import mobilenet_v3_small
from tqdm import tqdm

from src.dataset_loaders import ISAdetectDataset
from src.transforms import GrayScaleImage

### Setup


In [60]:
# Network constructor to instantiate for every fold, and the metadata key
# whose value the network is trained to predict.
MODEL = mobilenet_v3_small
TARGET_FEATURE = "endianness"

# Training hyperparameters.
BATCH_SIZE = 64
NUM_EPOCHS = 1
LEARNING_RATE = 1e-4

# Architectures (groups) to hold out as validation folds. None means every
# group gets its own fold. To restrict the run to a few groups, assign a list
# of group names instead, e.g. ["mips", "mipsel"].
VALIDATION_GROUPS = None

# Upper bound on the number of files loaded per ISA; None disables the limit.
MAX_FILES_PER_ISA = None

### Helper functions


In [61]:
def get_device():
    """
    Select the compute device for training and report the choice on stdout.

    Preference order: CUDA first, then the Apple Silicon MPS backend, and the
    CPU when neither accelerator is available.

    Returns:
        torch.device: the selected device.
    """
    # Start from the CPU fallback and upgrade if an accelerator is present.
    backend = "cpu"
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"

    selected = torch.device(backend)
    print(f"Using device: {selected}")
    return selected

### Prepare


In [62]:
device = get_device()

# The byte read limit equals the pixel count of the 224x224 transform
# (presumably one byte per grayscale pixel -- confirm against GrayScaleImage).
dataset = ISAdetectDataset(
    dataset_path="../../dataset/ISAdetect/ISAdetect_full_dataset",
    transform=GrayScaleImage(224, 224),
    file_byte_read_limit=224 * 224,
    per_architecture_limit=MAX_FILES_PER_ISA,
)

# Per-sample lists consumed by the cross-validation splitter; dataset.metadata
# is assumed to be index-aligned with the dataset samples.
# `groups` holds the architecture of each sample (the unit that is left out).
groups = [entry["architecture"] for entry in dataset.metadata]
# `target_feature` holds the label being predicted. It is read through
# TARGET_FEATURE so that it always matches the feature the model is trained on
# instead of a hard-coded metadata key.
target_feature = [entry[TARGET_FEATURE] for entry in dataset.metadata]

Using device: cuda


### Train and evaluate


In [63]:
# Leave-one-group-out cross-validation: each architecture is held out once as
# the test set while a fresh model is trained on all remaining architectures.
logo = LeaveOneGroupOut()
label_encoder = LabelEncoder()


def _encode_batch_labels(labels):
    """Turn a batch's raw TARGET_FEATURE values into a class-index tensor on `device`.

    Uses the module-level `label_encoder`, which must already be fitted for
    the current fold.
    """
    return torch.from_numpy(
        label_encoder.transform(labels[TARGET_FEATURE])
    ).to(device)


fold = 1
accuracies = {}  # held-out group name -> test accuracy after the last epoch
for train_idx, test_idx in logo.split(
    X=range(len(dataset)), y=target_feature, groups=groups
):
    # Every test index of a fold belongs to the same group, so the first one
    # identifies the group that was left out.
    group_left_out = groups[test_idx[0]]

    if VALIDATION_GROUPS is not None and group_left_out not in VALIDATION_GROUPS:
        continue

    print(f"\n=== Fold {fold} – leaving out group '{group_left_out}' ===")
    fold += 1

    # Fit the encoder on the same feature that is encoded in the batches
    # below. Fitting on a different metadata key would make transform() see
    # values it was never fitted on.
    all_train_labels = [dataset.metadata[i][TARGET_FEATURE] for i in train_idx]
    label_encoder.fit(all_train_labels)

    train_dataset = Subset(dataset, train_idx)
    test_dataset = Subset(dataset, test_idx)

    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2
    )
    test_loader = DataLoader(
        test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2
    )

    # Fresh, randomly initialised model per fold. The output size follows the
    # number of classes seen in training (2 for a binary target).
    model = MODEL(num_classes=len(label_encoder.classes_), weights=None)
    # Swap the first convolution for a 1-input-channel one so the network
    # accepts single-channel (grayscale) images.
    model.features[0][0] = nn.Conv2d(
        1, 16, kernel_size=3, stride=2, padding=1, bias=False
    )
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Train model
    for epoch in range(NUM_EPOCHS):
        model.train()
        print(f"Epoch {epoch+1}:")

        for images, labels in tqdm(train_loader):
            images = images.to(device)
            encoded_labels = _encode_batch_labels(labels)

            optimizer.zero_grad()
            predictions = model(images)
            loss = criterion(predictions, encoded_labels)
            loss.backward()
            optimizer.step()

        # Evaluate model on the held-out group after every epoch
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device)
                encoded_labels = _encode_batch_labels(labels)

                outputs = model(images)

                # Predicted class = index of the largest logit
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == encoded_labels).sum().item()
                total += encoded_labels.size(0)

        accuracy = correct / total
        print(f"Test Accuracy: {100*accuracy:.2f}%")

    # Only the accuracy measured after the final epoch is kept for this fold.
    accuracies[group_left_out] = accuracy


=== Fold 1 – leaving out group 'alpha' ===
Epoch 1:


  0%|          | 0/1445 [00:00<?, ?it/s]

100%|██████████| 1445/1445 [00:47<00:00, 30.70it/s]


Test Accuracy: 99.70%

=== Fold 2 – leaving out group 'amd64' ===
Epoch 1:


100%|██████████| 1438/1438 [00:47<00:00, 30.07it/s]


Test Accuracy: 94.07%

=== Fold 3 – leaving out group 'arm64' ===
Epoch 1:


100%|██████████| 1450/1450 [00:48<00:00, 30.19it/s]


Test Accuracy: 68.43%

=== Fold 4 – leaving out group 'armel' ===
Epoch 1:


100%|██████████| 1444/1444 [00:42<00:00, 33.71it/s]


Test Accuracy: 99.88%

=== Fold 5 – leaving out group 'armhf' ===
Epoch 1:


100%|██████████| 1444/1444 [00:45<00:00, 31.79it/s]


Test Accuracy: 96.15%

=== Fold 6 – leaving out group 'hppa' ===
Epoch 1:


100%|██████████| 1431/1431 [00:46<00:00, 30.71it/s]


Test Accuracy: 45.73%

=== Fold 7 – leaving out group 'i386' ===
Epoch 1:


100%|██████████| 1427/1427 [00:43<00:00, 33.01it/s]


Test Accuracy: 90.64%

=== Fold 8 – leaving out group 'ia64' ===
Epoch 1:


100%|██████████| 1429/1429 [00:45<00:00, 31.31it/s]


Test Accuracy: 69.62%

=== Fold 9 – leaving out group 'm68k' ===
Epoch 1:


100%|██████████| 1438/1438 [00:48<00:00, 29.95it/s]


Test Accuracy: 0.82%

=== Fold 10 – leaving out group 'mips' ===
Epoch 1:


100%|██████████| 1451/1451 [00:50<00:00, 28.54it/s]


Test Accuracy: 84.71%

=== Fold 11 – leaving out group 'mips64el' ===
Epoch 1:


100%|██████████| 1439/1439 [00:51<00:00, 28.05it/s]


Test Accuracy: 99.79%

=== Fold 12 – leaving out group 'mipsel' ===
Epoch 1:


100%|██████████| 1448/1448 [00:51<00:00, 27.96it/s]


Test Accuracy: 61.62%

=== Fold 13 – leaving out group 'powerpc' ===
Epoch 1:


100%|██████████| 1450/1450 [00:51<00:00, 28.03it/s]


Test Accuracy: 99.92%

=== Fold 14 – leaving out group 'powerpcspe' ===
Epoch 1:


100%|██████████| 1445/1445 [00:52<00:00, 27.56it/s]


Test Accuracy: 99.50%

=== Fold 15 – leaving out group 'ppc64' ===
Epoch 1:


100%|██████████| 1462/1462 [00:52<00:00, 27.69it/s]


Test Accuracy: 79.14%

=== Fold 16 – leaving out group 'ppc64el' ===
Epoch 1:


100%|██████████| 1452/1452 [00:52<00:00, 27.86it/s]


Test Accuracy: 100.00%

=== Fold 17 – leaving out group 'riscv64' ===
Epoch 1:


100%|██████████| 1437/1437 [00:51<00:00, 27.73it/s]


Test Accuracy: 65.32%

=== Fold 18 – leaving out group 's390' ===
Epoch 1:


100%|██████████| 1426/1426 [00:49<00:00, 28.53it/s]


Test Accuracy: 2.25%

=== Fold 19 – leaving out group 's390x' ===
Epoch 1:


100%|██████████| 1451/1451 [00:52<00:00, 27.62it/s]


Test Accuracy: 47.60%

=== Fold 20 – leaving out group 'sh4' ===
Epoch 1:


100%|██████████| 1414/1414 [00:51<00:00, 27.29it/s]


Test Accuracy: 60.99%

=== Fold 21 – leaving out group 'sparc' ===
Epoch 1:


100%|██████████| 1429/1429 [00:51<00:00, 28.01it/s]


Test Accuracy: 78.71%

=== Fold 22 – leaving out group 'sparc64' ===
Epoch 1:


100%|██████████| 1455/1455 [00:55<00:00, 26.13it/s]


Test Accuracy: 95.79%

=== Fold 23 – leaving out group 'x32' ===
Epoch 1:


100%|██████████| 1441/1441 [00:54<00:00, 26.44it/s]


Test Accuracy: 97.70%


### Evaluate


In [64]:
# Per-group results, in the order the folds were run.
print("Test accuracies for each fold/group:")
for left_out_group, group_accuracy in accuracies.items():
    print(f"{left_out_group}: {100*group_accuracy:.2f}%")


# Summary over all folds: mean and (population) standard deviation of the
# per-group accuracies.
fold_scores = np.array(list(accuracies.values()))
mean_acc = fold_scores.mean()
std_acc = fold_scores.std()
print(f"\nAverage LOGO cross-validated test accuracy: {mean_acc:.4f} ± {std_acc:.4f}")

Test accuracies for each fold/group:
alpha: 99.70%
amd64: 94.07%
arm64: 68.43%
armel: 99.88%
armhf: 96.15%
hppa: 45.73%
i386: 90.64%
ia64: 69.62%
m68k: 0.82%
mips: 84.71%
mips64el: 99.79%
mipsel: 61.62%
powerpc: 99.92%
powerpcspe: 99.50%
ppc64: 79.14%
ppc64el: 100.00%
riscv64: 65.32%
s390: 2.25%
s390x: 47.60%
sh4: 60.99%
sparc: 78.71%
sparc64: 95.79%
x32: 97.70%

Average LOGO cross-validated test accuracy: 0.7557 ± 0.2856
