### Imports


In [9]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
from sklearn.metrics import accuracy_score
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.preprocessing import LabelEncoder
from torchvision.models import mobilenet_v3_small
from tqdm import tqdm

from src.dataset_loaders import ISAdetectDataset
from src.transforms import GrayScaleImage

### Setup


In [10]:
# Specify the model
MODEL = mobilenet_v3_small

# Model hyperparameters
BATCH_SIZE = 64
NUM_EPOCHS = 1

# Set to an integer to limit the dataset size. Useful for debugging.
MAX_FILES_PER_ISA = None

### Helper functions


In [11]:
def get_device():
    """
    Returns 'cuda' if CUDA is available, else 'mps' if Apple Silicon GPU is available,
    otherwise 'cpu'.
    """
    device = None
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    print(f"Using device: {device}")
    return device

### Prepare


In [12]:
device = get_device()

dataset = ISAdetectDataset(
    dataset_path="../../dataset/ISAdetect/ISAdetect_full_dataset",
    transform=GrayScaleImage(224, 224),
    file_byte_read_limit=224 * 224,
    per_architecture_limit=MAX_FILES_PER_ISA,
)


groups = list(map(lambda x: x["architecture"], dataset.metadata))
target_feature = list(map(lambda x: x["endianness"], dataset.metadata))

Using device: cuda
ppc64el limit reached
powerpcspe limit reached
hppa limit reached
sparc64 limit reached
ppc64 limit reached
armel limit reached
sh4 limit reached
s390x limit reached
powerpc limit reached
ia64 limit reached
mips64el limit reached
riscv64 limit reached
i386 limit reached
mips limit reached
x32 limit reached
amd64 limit reached
s390 limit reached
alpha limit reached
armhf limit reached
mipsel limit reached
m68k limit reached
sparc limit reached
arm64 limit reached


### Train and evaluate


In [14]:
logo = LeaveOneGroupOut()
label_encoder = LabelEncoder()

fold = 1
all_fold_accuracies = []
for train_idx, test_idx in logo.split(
    X=range(len(dataset)), y=target_feature, groups=groups
):
    print(f"\n=== Fold {fold} – leaving out group '{groups[test_idx[0]]}' ===")
    fold += 1

    all_train_labels = [dataset.metadata[i]["endianness"] for i in train_idx]
    label_encoder.fit(all_train_labels)

    train_dataset = Subset(dataset, train_idx)
    test_dataset = Subset(dataset, test_idx)

    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2
    )
    test_loader = DataLoader(
        test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2
    )

    model = MODEL(num_classes=2, weights=None)
    model.features[0][0] = nn.Conv2d(
        1, 16, kernel_size=3, stride=2, padding=1, bias=False
    )
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Train model
    model.train()
    for epoch in range(NUM_EPOCHS):
        print(f"Epoch {epoch+1}:")

        for images, labels in tqdm(train_loader):
            images = images.to(device)

            encoded_labels = torch.from_numpy(
                label_encoder.transform(labels["endianness"])
            ).to(device)

            optimizer.zero_grad()
            predictions = model(images)
            loss = criterion(predictions, encoded_labels)
            loss.backward()
            optimizer.step()

    # Evaluate model
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            encoded_labels = torch.from_numpy(
                label_encoder.transform(labels["endianness"])
            ).to(device)

            outputs = model(images)

            _, predicted = torch.max(outputs, 1)
            correct += (predicted == encoded_labels).sum().item()
            total += encoded_labels.size(0)

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    all_fold_accuracies.append(accuracy)


=== Fold 1 – leaving out group 'alpha' ===
Epoch 1:


100%|██████████| 3/3 [00:01<00:00,  2.51it/s]


Test Accuracy: 100.00%

=== Fold 2 – leaving out group 'amd64' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  7.32it/s]


Test Accuracy: 100.00%

=== Fold 3 – leaving out group 'arm64' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  8.42it/s]


Test Accuracy: 100.00%

=== Fold 4 – leaving out group 'armel' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  8.63it/s]


Test Accuracy: 100.00%

=== Fold 5 – leaving out group 'armhf' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  6.98it/s]


Test Accuracy: 100.00%

=== Fold 6 – leaving out group 'hppa' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  7.45it/s]


Test Accuracy: 0.00%

=== Fold 7 – leaving out group 'i386' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  6.29it/s]


Test Accuracy: 100.00%

=== Fold 8 – leaving out group 'ia64' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  7.03it/s]


Test Accuracy: 100.00%

=== Fold 9 – leaving out group 'm68k' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  8.47it/s]


Test Accuracy: 0.00%

=== Fold 10 – leaving out group 'mips' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  7.69it/s]


Test Accuracy: 0.00%

=== Fold 11 – leaving out group 'mips64el' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  7.73it/s]


Test Accuracy: 100.00%

=== Fold 12 – leaving out group 'mipsel' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  7.86it/s]


Test Accuracy: 100.00%

=== Fold 13 – leaving out group 'powerpc' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  6.47it/s]


Test Accuracy: 0.00%

=== Fold 14 – leaving out group 'powerpcspe' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  8.32it/s]


Test Accuracy: 0.00%

=== Fold 15 – leaving out group 'ppc64' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  6.89it/s]


Test Accuracy: 0.00%

=== Fold 16 – leaving out group 'ppc64el' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  7.81it/s]


Test Accuracy: 100.00%

=== Fold 17 – leaving out group 'riscv64' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  7.97it/s]


Test Accuracy: 100.00%

=== Fold 18 – leaving out group 's390' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  6.41it/s]


Test Accuracy: 0.00%

=== Fold 19 – leaving out group 's390x' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  6.77it/s]


Test Accuracy: 0.00%

=== Fold 20 – leaving out group 'sh4' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  5.00it/s]


Test Accuracy: 100.00%

=== Fold 21 – leaving out group 'sparc' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  5.23it/s]


Test Accuracy: 0.00%

=== Fold 22 – leaving out group 'sparc64' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  6.75it/s]


Test Accuracy: 0.00%

=== Fold 23 – leaving out group 'x32' ===
Epoch 1:


100%|██████████| 3/3 [00:00<00:00,  5.99it/s]


Test Accuracy: 100.00%


### Evaluate


In [15]:
# Print overall performance across folds
mean_acc = np.mean(all_fold_accuracies)
std_acc = np.std(all_fold_accuracies)
print(f"\nOverall LOGO CV Accuracy: {mean_acc:.4f} ± {std_acc:.4f}")


Overall LOGO CV Accuracy: 56.5217 ± 49.5728
