### Imports


In [8]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
from sklearn.model_selection import LeaveOneGroupOut, KFold
from sklearn.preprocessing import LabelEncoder
from torchvision.models import mobilenet_v3_small
from tqdm import tqdm

from src.dataset_loaders import ISAdetectDataset
from src.transforms import GrayScaleImage

### Setup


In [9]:
# Specify the model
MODEL = mobilenet_v3_small
TARGET_FEATURE = "architecture"

# Model hyperparameters
BATCH_SIZE = 64
NUM_EPOCHS = 5
LEARNING_RATE = 1e-4

# Specify which groups to use as validation set. Set to None to validate all groups.
VALIDATION_GROUPS = None
# VALIDATION_GROUPS = ["mips", "mipsel"]

# Set to an integer to limit the dataset size. Set to None to disable limit.
MAX_FILES_PER_ISA = None

### Helper functions


In [10]:
def get_device():
    """
    Returns 'cuda' if CUDA is available, else 'mps' if Apple Silicon GPU is available,
    otherwise 'cpu'.
    """
    device = None
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    print(f"Using device: {device}")
    return device

### Prepare


In [11]:
device = get_device()

dataset = ISAdetectDataset(
    dataset_path="../../dataset/ISAdetect/ISAdetect_full_dataset",
    transform=GrayScaleImage(224, 224),
    file_byte_read_limit=224 * 224,
    per_architecture_limit=MAX_FILES_PER_ISA,
)

groups = list(map(lambda x: x["architecture"], dataset.metadata))
target_feature = list(map(lambda x: x[TARGET_FEATURE], dataset.metadata))
print(f"groups: {set(groups)}")
print(f"features: {set(target_feature)}")

Using device: cuda
groups: {'armel', 'arm64', 'ppc64el', 'ia64', 'sh4', 's390x', 'amd64', 'sparc', 'ppc64', 'riscv64', 's390', 'x32', 'mipsel', 'powerpc', 'm68k', 'hppa', 'alpha', 'mips64el', 'sparc64', 'armhf', 'powerpcspe', 'mips', 'i386'}
features: {'armel', 'arm64', 'ppc64el', 'ia64', 'sh4', 's390x', 'amd64', 'sparc', 'ppc64', 'riscv64', 's390', 'x32', 'mipsel', 'powerpc', 'm68k', 'hppa', 'alpha', 'mips64el', 'sparc64', 'armhf', 'powerpcspe', 'mips', 'i386'}


### Train and evaluate


In [12]:
kfold = KFold(n_splits=10, shuffle=True)
label_encoder = LabelEncoder()

scaler = torch.cuda.amp.GradScaler()

fold = 1
accuracies = {}
corrects = {}
totals = {}
for train_idx, test_idx in kfold.split(X=range(len(dataset))):

    print(f"\n=== Fold {fold} ===")
    fold += 1

    all_train_labels = [dataset.metadata[i][TARGET_FEATURE] for i in train_idx]
    label_encoder.fit(all_train_labels)

    train_dataset = Subset(dataset, train_idx)
    test_dataset = Subset(dataset, test_idx)

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
    )

    model = MODEL(num_classes=23, weights=None)
    model.features[0][0] = nn.Conv2d(
        1, 16, kernel_size=3, stride=2, padding=1, bias=False
    )

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    # Train model
    for epoch in range(NUM_EPOCHS):
        model.train()
        print(f"Epoch {epoch+1}:")

        for images, labels in tqdm(train_loader):
            images = images.to(device)

            encoded_labels = torch.from_numpy(
                label_encoder.transform(labels[TARGET_FEATURE])
            ).to(device)

            optimizer.zero_grad()

            # with torch.cuda.amp.autocast():
            predictions = model(images)
            loss = criterion(predictions, encoded_labels)

            loss.backward()
            optimizer.step()
            # scaler.scale(loss).backward()
            # scaler.step(optimizer)
            # scaler.update()

        # Evaluate model
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device)
                encoded_labels = torch.from_numpy(
                    label_encoder.transform(labels[TARGET_FEATURE])
                ).to(device)

                outputs = model(images)

                _, predicted = torch.max(outputs, 1)
                correct += (predicted == encoded_labels).sum().item()
                total += encoded_labels.size(0)

                for i in range(len(predicted)):
                    # add accuracy per architecture
                    arch = label_encoder.inverse_transform([encoded_labels[i].item()])[
                        0
                    ]
                    if predicted[i] == encoded_labels[i]:
                        corrects[arch] = corrects.get(arch, 0) + 1
                    totals[arch] = totals.get(arch, 0) + 1

        accuracy = correct / total
        print(f"Test Accuracy: {100*accuracy:.2f}%")

    accuracies[fold] = accuracy


=== Fold 1 ===
Epoch 1:


100%|██████████| 1356/1356 [00:28<00:00, 46.93it/s]


Test Accuracy: 86.06%
Epoch 2:


100%|██████████| 1356/1356 [00:28<00:00, 47.58it/s]


Test Accuracy: 93.58%
Epoch 3:


100%|██████████| 1356/1356 [00:28<00:00, 47.50it/s]


Test Accuracy: 93.59%
Epoch 4:


100%|██████████| 1356/1356 [00:29<00:00, 46.59it/s]


Test Accuracy: 96.39%
Epoch 5:


100%|██████████| 1356/1356 [00:29<00:00, 45.70it/s]


Test Accuracy: 96.99%

=== Fold 2 ===
Epoch 1:


100%|██████████| 1356/1356 [00:28<00:00, 47.92it/s]


Test Accuracy: 89.47%
Epoch 2:


100%|██████████| 1356/1356 [00:28<00:00, 48.27it/s]


Test Accuracy: 94.09%
Epoch 3:


100%|██████████| 1356/1356 [00:28<00:00, 47.20it/s]


Test Accuracy: 95.93%
Epoch 4:


100%|██████████| 1356/1356 [00:28<00:00, 46.83it/s]


Test Accuracy: 96.73%
Epoch 5:


100%|██████████| 1356/1356 [00:28<00:00, 47.21it/s]


Test Accuracy: 97.44%

=== Fold 3 ===
Epoch 1:


100%|██████████| 1356/1356 [00:28<00:00, 47.05it/s]


Test Accuracy: 81.78%
Epoch 2:


100%|██████████| 1356/1356 [00:28<00:00, 47.15it/s]


Test Accuracy: 81.08%
Epoch 3:


100%|██████████| 1356/1356 [00:29<00:00, 46.58it/s]


Test Accuracy: 93.94%
Epoch 4:


100%|██████████| 1356/1356 [00:28<00:00, 47.72it/s]


Test Accuracy: 94.62%
Epoch 5:


100%|██████████| 1356/1356 [00:29<00:00, 45.90it/s]


Test Accuracy: 92.15%

=== Fold 4 ===
Epoch 1:


100%|██████████| 1356/1356 [00:28<00:00, 46.84it/s]


Test Accuracy: 83.03%
Epoch 2:


100%|██████████| 1356/1356 [00:28<00:00, 46.82it/s]


Test Accuracy: 92.30%
Epoch 3:


100%|██████████| 1356/1356 [00:28<00:00, 47.00it/s]


Test Accuracy: 91.44%
Epoch 4:


100%|██████████| 1356/1356 [00:29<00:00, 46.14it/s]


Test Accuracy: 96.32%
Epoch 5:


100%|██████████| 1356/1356 [00:28<00:00, 47.00it/s]


Test Accuracy: 95.87%

=== Fold 5 ===
Epoch 1:


100%|██████████| 1356/1356 [00:29<00:00, 46.51it/s]


Test Accuracy: 88.31%
Epoch 2:


100%|██████████| 1356/1356 [00:30<00:00, 43.88it/s]


Test Accuracy: 93.96%
Epoch 3:


100%|██████████| 1356/1356 [00:28<00:00, 47.05it/s]


Test Accuracy: 96.67%
Epoch 4:


100%|██████████| 1356/1356 [00:29<00:00, 46.07it/s]


Test Accuracy: 97.91%
Epoch 5:


100%|██████████| 1356/1356 [00:36<00:00, 37.44it/s]


Test Accuracy: 97.67%

=== Fold 6 ===
Epoch 1:


100%|██████████| 1356/1356 [00:29<00:00, 46.60it/s]


Test Accuracy: 88.20%
Epoch 2:


100%|██████████| 1356/1356 [00:28<00:00, 47.17it/s]


Test Accuracy: 94.11%
Epoch 3:


100%|██████████| 1356/1356 [00:29<00:00, 45.75it/s]


Test Accuracy: 95.52%
Epoch 4:


100%|██████████| 1356/1356 [00:29<00:00, 45.50it/s]


Test Accuracy: 97.29%
Epoch 5:


100%|██████████| 1356/1356 [00:31<00:00, 42.66it/s]


Test Accuracy: 97.50%

=== Fold 7 ===
Epoch 1:


100%|██████████| 1356/1356 [00:29<00:00, 45.93it/s]


Test Accuracy: 88.95%
Epoch 2:


100%|██████████| 1356/1356 [00:28<00:00, 47.60it/s]


Test Accuracy: 93.11%
Epoch 3:


100%|██████████| 1356/1356 [00:29<00:00, 46.18it/s]


Test Accuracy: 94.22%
Epoch 4:


100%|██████████| 1356/1356 [00:29<00:00, 46.75it/s]


Test Accuracy: 94.82%
Epoch 5:


100%|██████████| 1356/1356 [00:28<00:00, 47.25it/s]


Test Accuracy: 96.02%

=== Fold 8 ===
Epoch 1:


100%|██████████| 1356/1356 [00:29<00:00, 45.41it/s]


Test Accuracy: 62.66%
Epoch 2:


100%|██████████| 1356/1356 [00:29<00:00, 46.60it/s]


Test Accuracy: 91.11%
Epoch 3:


100%|██████████| 1356/1356 [00:29<00:00, 46.09it/s]


Test Accuracy: 93.11%
Epoch 4:


100%|██████████| 1356/1356 [00:29<00:00, 45.83it/s]


Test Accuracy: 94.17%
Epoch 5:


100%|██████████| 1356/1356 [00:30<00:00, 44.74it/s]


Test Accuracy: 95.11%

=== Fold 9 ===
Epoch 1:


100%|██████████| 1356/1356 [00:29<00:00, 46.03it/s]


Test Accuracy: 90.05%
Epoch 2:


100%|██████████| 1356/1356 [00:29<00:00, 46.28it/s]


Test Accuracy: 94.34%
Epoch 3:


100%|██████████| 1356/1356 [00:30<00:00, 44.71it/s]


Test Accuracy: 96.15%
Epoch 4:


100%|██████████| 1356/1356 [00:29<00:00, 45.80it/s]


Test Accuracy: 95.70%
Epoch 5:


100%|██████████| 1356/1356 [00:28<00:00, 46.99it/s]


Test Accuracy: 98.54%

=== Fold 10 ===
Epoch 1:


100%|██████████| 1356/1356 [00:28<00:00, 46.84it/s]


Test Accuracy: 82.01%
Epoch 2:


100%|██████████| 1356/1356 [00:28<00:00, 46.91it/s]


Test Accuracy: 91.16%
Epoch 3:


100%|██████████| 1356/1356 [00:29<00:00, 46.40it/s]


Test Accuracy: 95.63%
Epoch 4:


100%|██████████| 1356/1356 [00:29<00:00, 46.44it/s]


Test Accuracy: 95.67%
Epoch 5:


100%|██████████| 1356/1356 [00:29<00:00, 46.07it/s]


Test Accuracy: 97.64%


### Evaluate


In [13]:
print("Test accuracies for each architecture:")
for arch, correct_count in corrects.items():
    print(f"{arch}: {100*correct_count/totals[arch]:.2f}%")


# Print overall performance across folds
mean_acc = np.mean(list(accuracies.values()))
std_acc = np.std(list(accuracies.values()))
print(f"\nAverage LOGO cross-validated test accuracy: {mean_acc:.4f} ± {std_acc:.4f}")

Test accuracies for each architecture:
alpha: 98.76%
mips: 96.49%
sh4: 96.53%
armel: 98.26%
powerpcspe: 75.19%
sparc64: 84.35%
x32: 74.36%
sparc: 93.44%
powerpc: 77.18%
ia64: 99.43%
i386: 90.71%
hppa: 97.13%
mipsel: 96.19%
m68k: 96.53%
ppc64: 96.91%
armhf: 94.88%
s390x: 98.33%
ppc64el: 96.36%
mips64el: 96.41%
riscv64: 94.89%
arm64: 97.51%
s390: 97.34%
amd64: 78.64%

Average LOGO cross-validated test accuracy: 0.9649 ± 0.0174
