In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from src.models import EndiannessModel
from src.dataset_loaders import ISAdetectCodeOnlyDataset, random_train_test_split
from src.transforms import EndiannessCount

num_epochs = 2
learning_rate = 0.0005
batch_size = 4
SEED = 423

# Create dataloaders
mips_dir = "../../dataset/ISAdetect/ISAdetect_full_dataset/mips"
mipsel_dir = "../../dataset/ISAdetect/ISAdetect_full_dataset/mipsel"

dataset = ISAdetectCodeOnlyDataset(
    dataset_path="../../dataset/ISAdetect/ISAdetect_full_dataset",
    transform=EndiannessCount(),
    file_byte_read_limit=None,  # 10 * (2**10),
)
train_set, test_set = random_train_test_split(dataset, test_split=0.2, seed=SEED)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=1)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=1)

In [2]:
model = EndiannessModel(with_sigmoid=False)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

endianness_map = {"little": 0, "big": 1}

print("Training model...")
# Training loop
for epoch in range(num_epochs):
    print(f"before enumerate {epoch}")
    for i, (batch_x, batch_y) in enumerate(train_loader):
        if i == 0:
            print(f"starting_epoch {epoch}")
        optimizer.zero_grad()
        output = model(batch_x)
        targets = torch.tensor(
            [endianness_map[e] for e in batch_y["endianness"]], dtype=torch.float32
        )
        loss = criterion(output, targets.unsqueeze(1))
        loss.backward()
        optimizer.step()

        # if epoch == 0 and i < 10:
        #     print(f"Initial loss: {loss.item()}")

        if (i + 1) % (len(train_loader) // 10) == 0:
            print(f"Step {i+1}, Loss: {loss.item()}")
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

Training model...
before enumerate 0
starting_epoch 0
Step 1927, Loss: 0.8184065818786621
Step 3854, Loss: 0.3589358329772949
Step 5781, Loss: 1.0378791093826294
Step 7708, Loss: 0.21004444360733032
Step 9635, Loss: 0.28510743379592896
Step 11562, Loss: 0.10131195932626724
Step 13489, Loss: 0.13254953920841217
Step 15416, Loss: 0.15119773149490356
Step 17343, Loss: 0.27797284722328186
Step 19270, Loss: 0.0006362693966366351
Epoch 1/2, Loss: 0.00014799779455643147
before enumerate 1
starting_epoch 1
Step 1927, Loss: 0.13367818295955658
Step 3854, Loss: 0.20184054970741272
Step 5781, Loss: 0.3954026699066162
Step 7708, Loss: 0.24861712753772736
Step 9635, Loss: 0.1988215297460556
Step 11562, Loss: 0.2767378091812134
Step 13489, Loss: 0.3927507996559143
Step 15416, Loss: 0.16777241230010986
Step 17343, Loss: 0.07518836110830307
Step 19270, Loss: 0.004490221850574017
Epoch 2/2, Loss: 0.015949279069900513


In [3]:
# Test model
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    arch_stats = {}
    for batch_x, batch_y in test_loader:
        output = model(batch_x)
        batch_y_endian = [endianness_map[e] for e in batch_y["endianness"]]
        batch_y_arch = batch_y["architecture"]
        for i in range(len(output)):
            if output[i] >= 0.5:
                pred = 1
            else:
                pred = 0

            current_arch = batch_y_arch[i]
            if current_arch not in arch_stats:
                arch_stats[current_arch] = {"correct": 0, "total": 0}

            if pred == batch_y_endian[i]:
                correct += 1
                arch_stats[current_arch]["correct"] += 1

            arch_stats[current_arch]["total"] += 1
            total += 1

    overall_accuracy = correct / total
    print(f"\nOverall Accuracy: {overall_accuracy:.4f} ({correct}/{total})")

    # Print per-architecture accuracies
    print("\nPer-Architecture Accuracies:")
    for arch in sorted(arch_stats.keys()):
        arch_correct = arch_stats[arch]["correct"]
        arch_total = arch_stats[arch]["total"]
        arch_accuracy = arch_correct / arch_total
        print(f"{arch:10s}: {arch_accuracy:.4f} ({arch_correct}/{arch_total})")


Overall Accuracy: 0.9315 (17958/19279)

Per-Architecture Accuracies:
alpha     : 0.9988 (827/828)
amd64     : 1.0000 (877/877)
arm64     : 1.0000 (738/738)
armel     : 1.0000 (751/751)
armhf     : 0.9145 (706/772)
hppa      : 0.3549 (346/975)
i386      : 1.0000 (1004/1004)
ia64      : 1.0000 (992/992)
m68k      : 0.9890 (895/905)
mips      : 0.8931 (618/692)
mips64el  : 0.9775 (867/887)
mipsel    : 0.9747 (733/752)
powerpc   : 0.9201 (691/751)
powerpcspe: 0.9179 (660/719)
ppc64     : 0.9946 (555/558)
ppc64el   : 1.0000 (692/692)
riscv64   : 0.8979 (783/872)
s390      : 0.9469 (981/1036)
s390x     : 0.9986 (719/720)
sh4       : 0.9613 (1142/1188)
sparc     : 0.9005 (923/1025)
sparc64   : 0.8701 (583/670)
x32       : 1.0000 (875/875)
