#### Set up hyperparams

In [1]:
import random
from torchvision import models
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Lambda, ToTensor
from tqdm import tqdm
from src.dataset_loaders import (
    ISAdetectDataset,
    random_train_test_split,
)
from src.transforms import GrayScaleImage

device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"Using device: {device}")


num_epochs = 1
learning_rate = 0.001
batch_size = 16
SEED = random.randint(0, 1000)

transform = Compose(
    [
        GrayScaleImage(255, 255),
        Lambda(lambda x: x.float() / 255.0),
        # 2. Add a dimension to make it a single-channel image (1, 255, 255)
        # Lambda(lambda x: x.unsqueeze(0)),
        # 3. Duplicate the grayscale channel to create 3 channels (3, 255, 255)
        Lambda(lambda x: x.repeat(3, 1, 1)),
    ]
)

dataset = ISAdetectDataset(
    "../../dataset/ISAdetect/ISAdetect_full_dataset",
    transform=transform,
    file_byte_read_limit=255 * 255,
)
train_set, test_set = random_train_test_split(
    dataset=dataset, test_split=0.2, seed=SEED
)

train_dataloader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)

Using device: cuda


#### Set up model

In [2]:
model = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1)

# Freeze all layers in the model
for param in model.parameters():
    param.requires_grad = True

num_inputs = model.classifier[1].in_features

model.classifier[1] = nn.Linear(num_inputs, 1)

for param in model.classifier[1].parameters():
    param.requires_grad = True

#### Criterions and optim

In [3]:
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params=model.classifier[1].parameters(), lr=learning_rate)

with torch.no_grad():
    it = iter(train_dataloader)
    first_x, first_y = next(it)
    print(first_x[0])
    print(first_x.shape)
    out = model(first_x)
    print(out)

torch.Size([16, 3, 255, 255])
tensor([[ 0.0014],
        [ 0.1326],
        [ 0.5038],
        [-0.0303],
        [-0.1013],
        [ 0.2860],
        [ 0.2360],
        [ 0.1214],
        [ 0.0934],
        [ 0.5258],
        [-0.3594],
        [-0.0244],
        [ 0.0216],
        [ 0.0906],
        [-0.1930],
        [-0.1185]])


#### Training loop

In [4]:
model.to(device)
dataset.use_code_only = True
endianness_map = {"little": 0, "big": 1}
for epoch in range(num_epochs):
    for i, (batch_x, batch_y) in enumerate(tqdm(train_dataloader)):
        batch_x = batch_x.to(device)
        optimizer.zero_grad()
        output = model(batch_x)
        targets = torch.tensor(
            [endianness_map[e] for e in batch_y["endianness"]],
            dtype=torch.float32,
        ).to(device)
        loss = criterion(output, targets.unsqueeze(1))
        loss.backward()
        optimizer.step()
        if (i + 1) % 400 == 0:
            print(f"Loss {loss.item()}")

  8%|▊         | 403/4820 [00:32<05:41, 12.93it/s]

Loss 0.3886141777038574


 17%|█▋        | 801/4820 [01:06<05:52, 11.41it/s]

Loss 0.5172271132469177


 25%|██▍       | 1201/4820 [01:40<05:18, 11.38it/s]

Loss 0.4752216339111328


 33%|███▎      | 1603/4820 [02:12<04:21, 12.30it/s]

Loss 0.44637545943260193


 42%|████▏     | 2003/4820 [02:46<03:36, 12.98it/s]

Loss 0.48386505246162415


 50%|████▉     | 2403/4820 [03:18<03:09, 12.77it/s]

Loss 0.8157467842102051


 58%|█████▊    | 2801/4820 [03:53<02:44, 12.31it/s]

Loss 0.6231052875518799


 66%|██████▋   | 3201/4820 [04:27<02:21, 11.40it/s]

Loss 0.4467373788356781


 75%|███████▍  | 3603/4820 [04:59<01:33, 13.06it/s]

Loss 0.29321324825286865


 83%|████████▎ | 4001/4820 [05:31<01:11, 11.43it/s]

Loss 0.3485111594200134


 91%|█████████▏| 4401/4820 [06:04<00:37, 11.03it/s]

Loss 0.5008001327514648


100%|█████████▉| 4801/4820 [06:36<00:01, 11.37it/s]

Loss 0.7154148817062378


100%|██████████| 4820/4820 [06:38<00:00, 12.10it/s]


In [5]:
dataset.use_code_only = True
print(f"{dataset.use_code_only=}")
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    arch_stats = {}
    for batch_x, batch_y in tqdm(test_dataloader):
        batch_x = batch_x.to(device)
        output = model(batch_x / 255.0)
        batch_y_endian = [endianness_map[e] for e in batch_y["endianness"]]
        batch_y_arch = batch_y["architecture"]
        for i in range(len(output)):
            if output[i] >= 0.5:
                pred = 1
            else:
                pred = 0

            current_arch = batch_y_arch[i]
            if current_arch not in arch_stats:
                arch_stats[current_arch] = {"correct": 0, "total": 0}

            if pred == batch_y_endian[i]:
                correct += 1
                arch_stats[current_arch]["correct"] += 1

            arch_stats[current_arch]["total"] += 1
            total += 1

    overall_accuracy = correct / total
    print("\nPer-Architecture Accuracies:")
    for arch in sorted(arch_stats.keys()):
        arch_correct = arch_stats[arch]["correct"]
        arch_total = arch_stats[arch]["total"]
        arch_accuracy = arch_correct / arch_total
        arch_stats[arch]["accuracy"] = arch_accuracy
        print(f"{arch:10s}: {arch_accuracy:.4f} ({arch_correct}/{arch_total})")

dataset.use_code_only=True


100%|██████████| 1205/1205 [00:33<00:00, 35.47it/s]


Per-Architecture Accuracies:
alpha     : 1.0000 (832/832)
amd64     : 1.0000 (869/869)
arm64     : 1.0000 (752/752)
armel     : 1.0000 (824/824)
armhf     : 1.0000 (789/789)
hppa      : 0.0000 (0/982)
i386      : 1.0000 (1009/1009)
ia64      : 1.0000 (967/967)
m68k      : 0.0000 (0/879)
mips      : 0.0000 (0/678)
mips64el  : 1.0000 (894/894)
mipsel    : 1.0000 (770/770)
powerpc   : 0.0000 (0/694)
powerpcspe: 0.0000 (0/812)
ppc64     : 0.0000 (0/554)
ppc64el   : 1.0000 (690/690)
riscv64   : 1.0000 (851/851)
s390      : 0.0000 (0/1028)
s390x     : 0.0000 (0/695)
sh4       : 1.0000 (1213/1213)
sparc     : 0.0000 (0/964)
sparc64   : 0.0000 (0/691)
x32       : 1.0000 (842/842)



