### Imports


In [None]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Function
from torch.utils.data import Subset, DataLoader
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.preprocessing import LabelEncoder
from torchvision.models import resnet50
from tqdm import tqdm

from src.dataset_loaders import ISAdetectDataset
from src.transforms import GrayScaleImage

### Gradient Reversal Layer (GRL) and Domain-Adversarial Model


In [2]:
class GradReverse(Function):
    """
    Gradient reversal layer.

    The forward pass returns the input unchanged (as a view); the backward
    pass multiplies the incoming gradient by ``-alpha``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Keep the scaling factor around for the backward pass.
        ctx.alpha = alpha
        # Return a view rather than ``x`` itself so the output is a distinct
        # tensor in the autograd graph.
        return x.view(x.size())

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign and scale; ``alpha`` itself receives no gradient.
        reversed_grad = grad_output * (-ctx.alpha)
        return reversed_grad, None


def grad_reverse(x, alpha=1.0):
    """
    Functional form of the gradient reversal layer.

    Returns ``x`` unchanged in the forward pass while the gradient flowing
    back through it is multiplied by ``-alpha``.
    """
    reversed_x = GradReverse.apply(x, alpha)
    return reversed_x


class DANNResnet(nn.Module):
    """
    A wrapper around a feature extractor that outputs:
      (a) a prediction over endianness (target)
      (b) a prediction over domains (ISA/architecture) where the gradients to the
          feature extractor are reversed.

    Args:
        base_model: Module mapping an input batch to a feature tensor whose
            last dimension is ``feature_dim`` (e.g. a ResNet-50 whose ``fc``
            layer has been replaced by ``nn.Identity``).
        num_endian_classes: Number of output classes of the target head.
        num_domains: Number of output classes of the domain head.
        feature_dim: Width of the feature vector produced by ``base_model``.
            Defaults to 2048, the value previously hard-coded in both heads.
    """

    def __init__(
        self,
        base_model,
        num_endian_classes=2,
        num_domains=23,
        feature_dim=2048,
    ):
        super().__init__()
        self.feature_extractor = base_model
        # Target classification head (endianness)
        self.classifier = nn.Linear(feature_dim, num_endian_classes)
        # Domain classification head (ISA): note that GRL will reverse gradient
        self.domain_classifier = nn.Linear(feature_dim, num_domains)

    def forward(self, x, grl_alpha=1.0):
        """
        Run both heads on the shared features.

        Args:
            x: Input batch forwarded to the feature extractor.
            grl_alpha: Scaling factor applied (negated) to the gradient that
                flows from the domain head back into the feature extractor.

        Returns:
            Tuple ``(endian_logits, domain_logits)``.
        """
        features = self.feature_extractor(x)
        # Endianness (target) prediction head
        endian_logits = self.classifier(features)
        # Domain prediction head with gradient reversal: the head itself is
        # trained normally, only the gradient reaching the features is flipped.
        reverse_features = grad_reverse(features, grl_alpha)
        domain_logits = self.domain_classifier(reverse_features)
        return endian_logits, domain_logits

### Setup


In [3]:
# Backbone constructor; its default classification head is replaced later.
MODEL = resnet50
# Metadata key of the label the target head is trained to predict.
TARGET_FEATURE = "endianness"

# Model hyperparameters
BATCH_SIZE = 64
NUM_EPOCHS = 5
LEARNING_RATE = 1e-4

# ISAs to use as held-out groups. Set to None to run a fold for every ISA in
# the dataset, or to a list of ISA names to run only those folds.
# Alternative subset: ["mips", "mipsel"]
VALIDATION_GROUPS = ["ia64", "arm64", "m68k", "hppa", "ppc64", "s390", "s390x"]

# Set to an integer to limit the dataset size per ISA, or None to disable.
MAX_FILES_PER_ISA = None

### Helper Functions


In [4]:
def set_seed(seed: int = 42):
    """
    Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    and configures cuDNN for deterministic behaviour with autotuning disabled.
    """
    # Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators (CPU, then every CUDA device).
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN: no benchmark-driven algorithm selection, deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def get_device():
    """
    Pick the compute device, preferring CUDA, then MPS, then CPU.

    Prints the selected device and returns it as a ``torch.device``.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"

    selected = torch.device(backend)
    print(f"Using device: {selected}")
    return selected

### Prepare Data


In [None]:
device = get_device()
# Gradient scaling is only needed for CUDA mixed precision. Disabling it
# explicitly on other devices makes the scaler a pass-through without the
# "CUDA is not available" warning that an unconditional GradScaler() emits.
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")

dataset = ISAdetectDataset(
    dataset_path="../../dataset/ISAdetect/ISAdetect_full_dataset",
    feature_csv_path="../../dataset/ISAdetect-features.csv",
    transform=GrayScaleImage(224, 224),
    # Read exactly as many bytes as there are pixels in a 224x224 image.
    file_byte_read_limit=224 * 224,
    per_architecture_limit=MAX_FILES_PER_ISA,
)

# The groups for leave-one-group-out will be the ISA names.
groups = [entry["architecture"] for entry in dataset.metadata]
# Endianness labels
target_feature_list = [entry[TARGET_FEATURE] for entry in dataset.metadata]

### Train and Evaluate with Gradient Reversal (Adversarial Domain Training)


In [None]:
# Leave-one-group-out cross-validation: each fold holds out one ISA as the
# test set and trains a fresh DANN model on all remaining ISAs.
logo = LeaveOneGroupOut()

# `fold` only counts folds that are actually run (skipped groups do not
# advance it); `accuracies` maps held-out ISA name -> final-epoch accuracy.
fold = 1
accuracies = {}
for train_idx, test_idx in logo.split(
    X=range(len(dataset)), y=target_feature_list, groups=groups
):
    # Reseed at the start of every iteration so each fold starts from the
    # same RNG state.
    set_seed()

    # All test indices of a LeaveOneGroupOut fold share one group, so the
    # first index identifies the held-out ISA.
    group_left_out = groups[test_idx[0]]
    if VALIDATION_GROUPS is not None and group_left_out not in VALIDATION_GROUPS:
        continue

    print(f"\n=== Fold {fold} – leaving out group '{group_left_out}' ===")
    fold += 1

    # Build target label encoder (for endianness) using training data.
    # NOTE(review): the same encoder is applied to the held-out fold below;
    # LabelEncoder.transform rejects labels it was not fitted on, so this
    # assumes every test label also occurs in the training folds - confirm.
    all_train_targets = [dataset.metadata[i][TARGET_FEATURE] for i in train_idx]
    target_label_encoder = LabelEncoder()
    target_label_encoder.fit(all_train_targets)

    # Build domain label encoder (for ISA/domain) using training data.
    # The held-out ISA is not part of it, so num_domains is the number of
    # distinct ISAs present in the training folds.
    all_train_domains = [dataset.metadata[i]["architecture"] for i in train_idx]
    domain_label_encoder = LabelEncoder()
    domain_label_encoder.fit(all_train_domains)
    num_domains = len(domain_label_encoder.classes_)

    train_dataset = Subset(dataset, train_idx)
    test_dataset = Subset(dataset, test_idx)

    # Only the training loader shuffles; the test loader keeps dataset order.
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
        prefetch_factor=2,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
        prefetch_factor=2,
    )

    # Build the backbone model and update it:
    # (weights=None: the backbone is trained from scratch in every fold.)
    base_model = MODEL(num_classes=2, weights=None)
    # Change the first Conv2d layer to accept a single (grayscale) channel.
    base_model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # Remove the final classification head so that the backbone returns a 2048-dim
    # feature vector.
    base_model.fc = nn.Identity()
    base_model = base_model.to(device)

    # Create the adversarial (DANN) model with two heads:
    model = DANNResnet(base_model, num_endian_classes=2, num_domains=num_domains).to(
        device
    )

    # The same cross-entropy criterion is used for both heads.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    domain_loss_weight = 1.0  # Weight for the ISA/domain loss

    # Train model (and evaluate on the held-out ISA after every epoch).
    for epoch in range(NUM_EPOCHS):
        model.train()
        print(f"\nEpoch {epoch+1}:")
        total_training_loss = 0

        # `labels` is indexed by feature name below; presumably a dict of
        # batched metadata fields - verify against ISAdetectDataset.
        for images, labels in tqdm(train_loader):
            images = images.to(device)
            # Prepare endianness (target) labels.
            target_labels = torch.from_numpy(
                target_label_encoder.transform(labels[TARGET_FEATURE])
            ).to(device)
            # Prepare domain (ISA) labels.
            domain_labels = torch.from_numpy(
                domain_label_encoder.transform(labels["architecture"])
            ).to(device)

            optimizer.zero_grad()
            # You can adjust the grl_alpha if desired.
            with torch.cuda.amp.autocast():
                target_logits, domain_logits = model(images, grl_alpha=0.5)
                loss_target = criterion(target_logits, target_labels)
                loss_domain = criterion(domain_logits, domain_labels)
                # Combined objective; the gradient reversal inside the model
                # flips the domain-loss gradient before it reaches the
                # feature extractor.
                loss = loss_target + domain_loss_weight * loss_domain

            # `scaler` is the module-level GradScaler created during data
            # preparation; it is shared by all folds.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_training_loss += loss.item()

        # Mean of the per-batch combined (target + domain) losses.
        avg_training_loss = total_training_loss / len(train_loader)

        # Evaluate model: (we only evaluate the endianness classification)
        model.eval()
        correct = 0
        total = 0
        total_test_loss = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device)
                target_labels = torch.from_numpy(
                    target_label_encoder.transform(labels[TARGET_FEATURE])
                ).to(device)

                # Only use the target (endianness) output.
                # grl_alpha only affects the backward pass, so its value is
                # irrelevant here.
                target_logits, _ = model(images, grl_alpha=0.5)
                loss = criterion(target_logits, target_labels)
                total_test_loss += loss.item()

                # Predicted class = index of the largest logit.
                _, predicted = torch.max(target_logits, 1)
                correct += (predicted == target_labels).sum().item()
                total += target_labels.size(0)

        # Test loss is the mean over batches and covers the target head only,
        # so it is not directly comparable to the combined training loss.
        avg_test_loss = total_test_loss / len(test_loader)
        accuracy = correct / total

        print(
            f"Training Loss: {avg_training_loss:.4f} | "
            f"Test Loss: {avg_test_loss:.4f}"
        )
        print(f"Test Accuracy: {100*accuracy:.2f}%")

    # Record the accuracy measured after the final epoch for this fold.
    accuracies[group_left_out] = accuracy

### Evaluate


In [None]:
# Per-fold results, printed as percentages.
print("Test accuracies for each fold/group:")
for group, acc in accuracies.items():
    percent = 100 * acc
    print(f"{group}: {percent:.2f}%")

# Summary across folds; mean and standard deviation are printed as fractions.
fold_scores = list(accuracies.values())
mean_acc = np.mean(fold_scores)
std_acc = np.std(fold_scores)
summary = f"\nAverage LOGO cross-validated test accuracy: {mean_acc:.4f} ± {std_acc:.4f}"
print(summary)