### Imports


In [1]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

from src.dataset_loaders import ISAdetectDataset
from src.models import EmbeddingAndCNNModel
from src.transforms import Vector1D

### Setup


In [2]:
# Model class to instantiate per fold, and the metadata field it learns to predict.
MODEL = EmbeddingAndCNNModel
TARGET_FEATURE = "endianness"

# Optimisation hyperparameters shared by every fold.
BATCH_SIZE = 64
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4

# Architecture groups to hold out as LOGO test folds; None runs every group.
# Example subset: ["ia64", "arm64", "m68k", "hppa", "ppc64"]
VALIDATION_GROUPS = None

# Optional per-ISA cap on the number of files loaded (int); None loads everything.
MAX_FILES_PER_ISA = None

### Helper functions


In [3]:
def set_seed(seed: int = 42):
    """
    Seed every random number generator used by this notebook and pin cuDNN
    to deterministic kernels, so reruns produce the same results.

    Args:
        seed: value given to Python's ``random``, NumPy's global RNG, and
            PyTorch's CPU and all-CUDA-device generators.
    """
    # Python / NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators (CPU, then every visible CUDA device).
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Disable cuDNN autotuning and force deterministic algorithms:
    # slower, but run-to-run reproducible.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def get_device():
    """
    Choose the best available torch device and announce it.

    Preference order: CUDA GPU, then Apple Silicon GPU (MPS), then CPU.
    Each probe is only evaluated if the previous one failed.

    Returns:
        torch.device for the chosen backend (also printed as "Using device: ...").
    """
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"

    chosen = torch.device(backend)
    print(f"Using device: {chosen}")
    return chosen

### Prepare


In [4]:
device = get_device()
set_seed(42)

# At most 2048 bytes are read per file; Vector1D(2048) presumably shapes each
# sample to that fixed length — confirm against src.transforms.
dataset = ISAdetectDataset(
    dataset_path="../../dataset/ISAdetect/ISAdetect_full_dataset",
    transform=Vector1D(2048),
    file_byte_read_limit=2048,
    per_architecture_limit=MAX_FILES_PER_ISA,
)

# One entry per sample, aligned with dataset indices:
# the LOGO grouping key (architecture) and the raw prediction target.
groups = [entry["architecture"] for entry in dataset.metadata]
target_feature = [entry[TARGET_FEATURE] for entry in dataset.metadata]

Using device: cuda


### Train and evaluate


In [5]:
def _encode_labels(labels, encoder, device):
    """Map a batch's raw TARGET_FEATURE values to a LongTensor of class ids on `device`."""
    # CrossEntropyLoss expects int64 class indices.
    return torch.as_tensor(
        encoder.transform(labels[TARGET_FEATURE]), dtype=torch.long, device=device
    )


def _train_one_epoch(model, loader, criterion, optimizer, encoder, device):
    """Run one optimisation pass over `loader`; return the per-sample mean training loss."""
    model.train()
    total_loss = 0.0
    total_samples = 0
    for images, labels in tqdm(loader):
        images = images.to(device)
        targets = _encode_labels(labels, encoder, device)

        optimizer.zero_grad()
        predictions = model(images)
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()

        # Assumes criterion averages over the batch (CrossEntropyLoss default);
        # scale back by batch size so a smaller final batch is not over-weighted.
        batch_size = targets.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size
    return total_loss / total_samples


def _evaluate(model, loader, criterion, encoder, device):
    """Return (per-sample mean loss, accuracy) of `model` on `loader` without updating it."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            targets = _encode_labels(labels, encoder, device)

            outputs = model(images)
            batch_size = targets.size(0)
            total_loss += criterion(outputs, targets).item() * batch_size
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            total += batch_size
    return total_loss / total, correct / total


logo = LeaveOneGroupOut()
label_encoder = LabelEncoder()

fold = 1
accuracies = {}
for train_idx, test_idx in logo.split(
    X=range(len(dataset)), y=target_feature, groups=groups
):
    # Reseed per fold so each fold's initialisation and shuffling do not depend
    # on how many folds ran (or were skipped) before it.
    set_seed(42)

    # LOGO puts exactly one group in the test split, so any test index identifies it.
    group_left_out = groups[test_idx[0]]

    if VALIDATION_GROUPS is not None and group_left_out not in VALIDATION_GROUPS:
        continue

    print(f"\n=== Fold {fold} – leaving out group '{group_left_out}' ===")
    fold += 1

    # Fit on training labels only. A label that appears only in the held-out
    # group will make `transform` raise during evaluation.
    label_encoder.fit([dataset.metadata[i][TARGET_FEATURE] for i in train_idx])
    num_classes = len(label_encoder.classes_)

    train_loader = DataLoader(
        Subset(dataset, train_idx), batch_size=BATCH_SIZE, shuffle=True, num_workers=8
    )
    test_loader = DataLoader(
        Subset(dataset, test_idx), batch_size=BATCH_SIZE, shuffle=False, num_workers=8
    )

    # Output width follows the encoder instead of a hard-coded 2, so other
    # TARGET_FEATURE values work without editing this cell.
    model = MODEL(input_length=2048, num_classes=num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Reset so a stale `accuracy` from an earlier fold (or an earlier run of this
    # cell) can never be recorded, e.g. when NUM_EPOCHS == 0.
    accuracy = float("nan")
    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch+1}:")

        avg_training_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, label_encoder, device
        )
        avg_test_loss, accuracy = _evaluate(
            model, test_loader, criterion, label_encoder, device
        )

        print(
            f"Training Loss: {avg_training_loss:.4f} | Test loss: {avg_test_loss:.4f}"
        )
        print(f"Test Accuracy: {100*accuracy:.2f}%")

    # Per-epoch test metrics are for monitoring only; the final epoch's accuracy is kept.
    accuracies[group_left_out] = accuracy


=== Fold 1 – leaving out group 'alpha' ===

Epoch 1:


100%|██████████| 1445/1445 [00:15<00:00, 94.58it/s] 


Training Loss: 0.0424 | Test loss: 5.0345
Test Accuracy: 7.73%

Epoch 2:


100%|██████████| 1445/1445 [00:14<00:00, 98.10it/s] 


Training Loss: 0.0002 | Test loss: 2.4602
Test Accuracy: 43.29%

Epoch 3:


100%|██████████| 1445/1445 [00:14<00:00, 101.81it/s]


Training Loss: 0.0006 | Test loss: 0.4524
Test Accuracy: 90.43%

Epoch 4:


100%|██████████| 1445/1445 [00:14<00:00, 96.52it/s] 


Training Loss: 0.0003 | Test loss: 0.4482
Test Accuracy: 88.44%

Epoch 5:


100%|██████████| 1445/1445 [00:14<00:00, 96.97it/s] 


Training Loss: 0.0005 | Test loss: 0.2102
Test Accuracy: 93.43%

Epoch 6:


100%|██████████| 1445/1445 [00:14<00:00, 98.79it/s] 


Training Loss: 0.0007 | Test loss: 0.0280
Test Accuracy: 99.55%

Epoch 7:


100%|██████████| 1445/1445 [00:14<00:00, 97.30it/s] 


Training Loss: 0.0002 | Test loss: 1.1378
Test Accuracy: 81.72%

Epoch 8:


100%|██████████| 1445/1445 [00:14<00:00, 97.57it/s] 


Training Loss: 0.0000 | Test loss: 0.1664
Test Accuracy: 96.20%

Epoch 9:


100%|██████████| 1445/1445 [00:14<00:00, 100.78it/s]


Training Loss: 0.0002 | Test loss: 1.5995
Test Accuracy: 79.75%

Epoch 10:


100%|██████████| 1445/1445 [00:14<00:00, 98.84it/s] 


Training Loss: 0.0001 | Test loss: 0.0447
Test Accuracy: 99.55%

=== Fold 2 – leaving out group 'amd64' ===

Epoch 1:


100%|██████████| 1438/1438 [00:14<00:00, 97.60it/s] 


Training Loss: 0.0409 | Test loss: 0.2030
Test Accuracy: 91.11%

Epoch 2:


100%|██████████| 1438/1438 [00:14<00:00, 101.94it/s]


Training Loss: 0.0009 | Test loss: 0.0143
Test Accuracy: 99.54%

Epoch 3:


100%|██████████| 1438/1438 [00:14<00:00, 96.53it/s] 


Training Loss: 0.0005 | Test loss: 0.0031
Test Accuracy: 99.91%

Epoch 4:


100%|██████████| 1438/1438 [00:15<00:00, 95.58it/s] 


Training Loss: 0.0007 | Test loss: 0.0057
Test Accuracy: 99.93%

Epoch 5:


100%|██████████| 1438/1438 [00:15<00:00, 95.54it/s] 


Training Loss: 0.0001 | Test loss: 0.0074
Test Accuracy: 99.75%

Epoch 6:


100%|██████████| 1438/1438 [00:15<00:00, 94.03it/s] 


Training Loss: 0.0000 | Test loss: 0.0058
Test Accuracy: 99.82%

Epoch 7:


100%|██████████| 1438/1438 [00:14<00:00, 98.36it/s] 


Training Loss: 0.0007 | Test loss: 0.0608
Test Accuracy: 98.19%

Epoch 8:


100%|██████████| 1438/1438 [00:14<00:00, 96.16it/s] 


Training Loss: 0.0005 | Test loss: 0.0029
Test Accuracy: 99.93%

Epoch 9:


100%|██████████| 1438/1438 [00:14<00:00, 99.55it/s] 


Training Loss: 0.0000 | Test loss: 0.0025
Test Accuracy: 99.91%

Epoch 10:


100%|██████████| 1438/1438 [00:14<00:00, 98.28it/s] 


Training Loss: 0.0000 | Test loss: 0.0031
Test Accuracy: 99.91%

=== Fold 3 – leaving out group 'arm64' ===

Epoch 1:


100%|██████████| 1450/1450 [00:14<00:00, 97.69it/s] 


Training Loss: 0.0390 | Test loss: 4.4477
Test Accuracy: 20.43%

Epoch 2:


100%|██████████| 1450/1450 [00:14<00:00, 99.19it/s] 


Training Loss: 0.0008 | Test loss: 2.7006
Test Accuracy: 42.90%

Epoch 3:


100%|██████████| 1450/1450 [00:14<00:00, 98.44it/s] 


Training Loss: 0.0005 | Test loss: 2.4486
Test Accuracy: 48.38%

Epoch 4:


100%|██████████| 1450/1450 [00:14<00:00, 98.12it/s] 


Training Loss: 0.0002 | Test loss: 5.3136
Test Accuracy: 36.96%

Epoch 5:


100%|██████████| 1450/1450 [00:14<00:00, 97.80it/s] 


Training Loss: 0.0004 | Test loss: 5.2990
Test Accuracy: 31.74%

Epoch 6:


100%|██████████| 1450/1450 [00:14<00:00, 96.98it/s] 


Training Loss: 0.0002 | Test loss: 2.3403
Test Accuracy: 61.91%

Epoch 7:


100%|██████████| 1450/1450 [00:15<00:00, 95.65it/s] 


Training Loss: 0.0000 | Test loss: 1.9418
Test Accuracy: 71.45%

Epoch 8:


100%|██████████| 1450/1450 [00:14<00:00, 97.54it/s] 


Training Loss: 0.0004 | Test loss: 0.9916
Test Accuracy: 85.12%

Epoch 9:


100%|██████████| 1450/1450 [00:14<00:00, 96.83it/s] 


Training Loss: 0.0005 | Test loss: 2.9491
Test Accuracy: 51.07%

Epoch 10:


100%|██████████| 1450/1450 [00:14<00:00, 97.44it/s] 


Training Loss: 0.0000 | Test loss: 1.9253
Test Accuracy: 64.93%

=== Fold 4 – leaving out group 'armel' ===

Epoch 1:


100%|██████████| 1444/1444 [00:15<00:00, 95.71it/s] 


Training Loss: 0.0430 | Test loss: 0.0531
Test Accuracy: 99.00%

Epoch 2:


100%|██████████| 1444/1444 [00:14<00:00, 98.73it/s] 


Training Loss: 0.0009 | Test loss: 0.1169
Test Accuracy: 97.72%

Epoch 3:


100%|██████████| 1444/1444 [00:14<00:00, 98.27it/s] 


Training Loss: 0.0003 | Test loss: 0.0852
Test Accuracy: 98.88%

Epoch 4:


100%|██████████| 1444/1444 [00:14<00:00, 101.94it/s]


Training Loss: 0.0009 | Test loss: 0.0385
Test Accuracy: 99.33%

Epoch 5:


100%|██████████| 1444/1444 [00:14<00:00, 96.76it/s] 


Training Loss: 0.0004 | Test loss: 0.0860
Test Accuracy: 98.67%

Epoch 6:


100%|██████████| 1444/1444 [00:14<00:00, 97.12it/s] 


Training Loss: 0.0004 | Test loss: 0.0554
Test Accuracy: 98.78%

Epoch 7:


100%|██████████| 1444/1444 [00:14<00:00, 99.02it/s] 


Training Loss: 0.0000 | Test loss: 0.0090
Test Accuracy: 99.90%

Epoch 8:


100%|██████████| 1444/1444 [00:15<00:00, 93.59it/s] 


Training Loss: 0.0003 | Test loss: 0.0149
Test Accuracy: 99.75%

Epoch 9:


100%|██████████| 1444/1444 [00:15<00:00, 91.08it/s] 


Training Loss: 0.0004 | Test loss: 0.0326
Test Accuracy: 99.33%

Epoch 10:


100%|██████████| 1444/1444 [00:08<00:00, 172.59it/s]


Training Loss: 0.0000 | Test loss: 0.0100
Test Accuracy: 99.85%

=== Fold 5 – leaving out group 'armhf' ===

Epoch 1:


100%|██████████| 1444/1444 [00:08<00:00, 172.49it/s]


Training Loss: 0.0420 | Test loss: 0.0510
Test Accuracy: 98.57%

Epoch 2:


100%|██████████| 1444/1444 [00:08<00:00, 177.30it/s]


Training Loss: 0.0009 | Test loss: 0.0370
Test Accuracy: 99.17%

Epoch 3:


 14%|█▍        | 208/1444 [00:01<00:11, 105.45it/s]


KeyboardInterrupt: 

### Summarize LOGO results across folds


In [7]:
# Only folds that finished training are in `accuracies` (an interrupted run
# leaves later groups out), so the number of folds is reported with the mean.
print("Test accuracies for each fold/group:")
for group, acc in accuracies.items():
    print(f"{group}: {100*acc:.2f}%")

acc_values = np.array(list(accuracies.values()), dtype=float)
if acc_values.size:
    # Population std (NumPy default ddof=0) across folds, reported in % to
    # match the per-group lines above.
    mean_acc = acc_values.mean()
    std_acc = acc_values.std()
    print(
        f"\nAverage LOGO cross-validated test accuracy over {acc_values.size} fold(s): "
        f"{100*mean_acc:.2f}% ± {100*std_acc:.2f}%"
    )
else:
    # Keep the names defined for later cells without triggering NumPy's
    # empty-slice RuntimeWarning.
    mean_acc = std_acc = float("nan")
    print("\nNo completed folds to summarize.")

Test accuracies for each fold/group:
alpha: 99.55%
amd64: 99.91%
arm64: 64.93%
armel: 99.85%

Average LOGO cross-validated test accuracy: 0.9106 ± 0.1508
