In [1]:
import random
from src.transforms import GrayScaleImage
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from src.models import MINOS
from src.dataset_loaders import (
    ISAdetectDataset,
    random_train_test_split,
)

num_epochs = 1
learning_rate = 0.001
batch_size = 4
SEED = random.randint(0, 1000)

dataset = ISAdetectDataset(
    dataset_path="../../dataset/ISAdetect/ISAdetect_full_dataset",
    transform=GrayScaleImage(100, 100),
    file_byte_read_limit=100 * 100,
)
train_set, test_set = random_train_test_split(dataset, test_split=0.2, seed=SEED)


train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=1)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=1)

### Define model and train


In [2]:
model = MINOS(num_classes=1)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

endianness_map = {"little": 0, "big": 1}

print("Training model...")
# Training loop
for epoch in range(num_epochs):
    print(f"before enumerate {epoch}")
    for i, (batch_x, batch_y) in enumerate(train_loader):

        mask = [
            i
            for i, arch in enumerate(batch_y["architecture"])
            if arch not in ["mips", "mipsel"]
        ]
        if not mask:
            continue  # Skip batch if all samples are MIPS

        # Apply mask to batch
        filtered_batch_x = batch_x[mask]
        filtered_batch_y = {k: [v[i] for i in mask] for k, v in batch_y.items()}

        optimizer.zero_grad()
        output = model(filtered_batch_x / 255.0)
        targets = torch.tensor(
            [endianness_map[e] for e in filtered_batch_y["endianness"]],
            dtype=torch.float32,
        )
        loss = criterion(output, targets.unsqueeze(1))
        loss.backward()
        optimizer.step()

        if (i + 1) % (len(train_loader) // 20) == 0:
            print(f"Step {i+1}, Loss: {loss.item()}")
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

Training model...
before enumerate 0
Step 963, Loss: 0.08947305381298065
Step 1926, Loss: 0.2363634705543518
Step 2889, Loss: 0.05780516192317009
Step 3852, Loss: 0.8281068801879883
Step 4815, Loss: 2.700714958336903e-07
Step 5778, Loss: 0.004636387340724468
Step 6741, Loss: 0.0009812796488404274
Step 7704, Loss: 0.010266322642564774
Step 8667, Loss: 8.74203294642939e-07
Step 9630, Loss: 7.294008537428454e-05
Step 10593, Loss: 4.711683686764445e-06
Step 11556, Loss: 1.886288373498246e-05
Step 12519, Loss: 0.0599493533372879
Step 13482, Loss: 5.624382879432233e-07
Step 14445, Loss: 0.0023599041160196066
Step 15408, Loss: 3.05547753765012e-10
Step 16371, Loss: 9.417663932254072e-06
Step 17334, Loss: 0.2260298728942871
Step 18297, Loss: 1.1801858818216715e-05
Step 19260, Loss: 1.831167537602596e-06
Epoch 1/1, Loss: 1.593092383700423e-05


### Test model on .code only


In [4]:
# Test model on code only
dataset.use_code_only = True
print(f"{dataset.use_code_only=}")
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    arch_stats = {}
    for batch_x, batch_y in test_loader:
        output = model(batch_x / 255.0)
        batch_y_endian = [endianness_map[e] for e in batch_y["endianness"]]
        batch_y_arch = batch_y["architecture"]
        for i in range(len(output)):
            if output[i] >= 0.5:
                pred = 1
            else:
                pred = 0

            current_arch = batch_y_arch[i]
            if current_arch not in arch_stats:
                arch_stats[current_arch] = {"correct": 0, "total": 0}

            if pred == batch_y_endian[i]:
                correct += 1
                arch_stats[current_arch]["correct"] += 1

            arch_stats[current_arch]["total"] += 1
            total += 1

    overall_accuracy = correct / total
    print(f"\nOverall Accuracy: {overall_accuracy:.4f} ({correct}/{total})")

    # Print per-architecture accuracies
    print("\nPer-Architecture Accuracies:")
    for arch in sorted(arch_stats.keys()):
        arch_correct = arch_stats[arch]["correct"]
        arch_total = arch_stats[arch]["total"]
        arch_accuracy = arch_correct / arch_total
        print(f"{arch:10s}: {arch_accuracy:.4f} ({arch_correct}/{arch_total})")

dataset.use_code_only=True

Overall Accuracy: 0.9607 (18521/19279)

Per-Architecture Accuracies:
alpha     : 1.0000 (795/795)
amd64     : 1.0000 (887/887)
arm64     : 0.9972 (702/704)
armel     : 1.0000 (765/765)
armhf     : 0.9987 (746/747)
hppa      : 0.9408 (921/979)
i386      : 1.0000 (1007/1007)
ia64      : 1.0000 (1018/1018)
m68k      : 0.9772 (900/921)
mips      : 0.2345 (163/695)
mips64el  : 1.0000 (883/883)
mipsel    : 0.9659 (680/704)
powerpc   : 0.9986 (715/716)
powerpcspe: 0.9964 (836/839)
ppc64     : 0.9825 (561/571)
ppc64el   : 1.0000 (693/693)
riscv64   : 0.9989 (879/880)
s390      : 0.9775 (1045/1069)
s390x     : 0.9720 (694/714)
sh4       : 1.0000 (1225/1225)
sparc     : 0.9758 (969/993)
sparc64   : 0.9439 (623/660)
x32       : 1.0000 (814/814)


In [6]:
# Test model on full program
dataset.use_code_only = False
print(f"{dataset.use_code_only=}")
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    arch_stats = {}
    for batch_x, batch_y in test_loader:
        output = model(batch_x / 255.0)
        batch_y_endian = [endianness_map[e] for e in batch_y["endianness"]]
        batch_y_arch = batch_y["architecture"]
        for i in range(len(output)):
            if output[i] >= 0.5:
                pred = 1
            else:
                pred = 0

            current_arch = batch_y_arch[i]
            if current_arch not in arch_stats:
                arch_stats[current_arch] = {"correct": 0, "total": 0}

            if pred == batch_y_endian[i]:
                correct += 1
                arch_stats[current_arch]["correct"] += 1

            arch_stats[current_arch]["total"] += 1
            total += 1

    overall_accuracy = correct / total
    print(f"\nOverall Accuracy: {overall_accuracy:.4f} ({correct}/{total})")

    # Print per-architecture accuracies
    print("\nPer-Architecture Accuracies:")
    for arch in sorted(arch_stats.keys()):
        arch_correct = arch_stats[arch]["correct"]
        arch_total = arch_stats[arch]["total"]
        arch_accuracy = arch_correct / arch_total
        print(f"{arch:10s}: {arch_accuracy:.4f} ({arch_correct}/{arch_total})")

dataset.use_code_only=False

Overall Accuracy: 0.5836 (11251/19279)

Per-Architecture Accuracies:
alpha     : 1.0000 (795/795)
amd64     : 1.0000 (887/887)
arm64     : 0.9986 (703/704)
armel     : 1.0000 (765/765)
armhf     : 0.9987 (746/747)
hppa      : 0.0051 (5/979)
i386      : 1.0000 (1007/1007)
ia64      : 1.0000 (1018/1018)
m68k      : 0.0000 (0/921)
mips      : 0.0014 (1/695)
mips64el  : 1.0000 (883/883)
mipsel    : 0.9943 (700/704)
powerpc   : 0.0279 (20/716)
powerpcspe: 0.0608 (51/839)
ppc64     : 0.0105 (6/571)
ppc64el   : 1.0000 (693/693)
riscv64   : 1.0000 (880/880)
s390      : 0.0000 (0/1069)
s390x     : 0.0000 (0/714)
sh4       : 0.9943 (1218/1225)
sparc     : 0.0483 (48/993)
sparc64   : 0.0167 (11/660)
x32       : 1.0000 (814/814)


In [10]:
from src.dataset_loaders import MipsMipselDataset as MD


dataset_mips = MD(
    mips_dir="../../dataset/ISAdetect/ISAdetect_full_dataset/mips",
    mipsel_dir="../../dataset/ISAdetect/ISAdetect_full_dataset/mipsel",
    transform=GrayScaleImage(100, 100),
)

train_set, test_set = random_train_test_split(dataset, test_split=0.2, seed=SEED)

mips_dataloader = DataLoader(dataset=dataset_mips, batch_size=4, shuffle=False)

model.eval()
with torch.no_grad():
    correct = 0
    correct_class_1 = 0
    class_1 = 0
    total = 0
    arch_stats = {}
    for batch_x, batch_y in mips_dataloader:
        output = model(batch_x / 255.0)
        for i in range(len(output)):
            if output[i] >= 0.5:
                pred = 1
            else:
                pred = 0
            if pred == batch_y[i]:
                correct += 1
            if batch_y[i] == 0:
                class_1 += 1
                if pred == 0:
                    correct_class_1 += 1
            total += 1

    print(f"Accuracy: {correct/total}")
    print(
        f"Class 1: {class_1}, Total: {total}, Percentage: {class_1/total}, Correct: {correct_class_1}"
    )

Accuracy: 0.3874659400544959
Class 1: 3565, Total: 7340, Percentage: 0.48569482288828336, Correct: 2754
