In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from utils.data_loader import load_data
from models.attention_ae import AttentionAE
from models.mil_model import SimpleMIL
from torch.utils.data import DataLoader, Dataset
from torch.nn import MSELoss, BCELoss
from torch.optim import Adam

In [None]:
# --- Experiment configuration ------------------------------------------------
DATA_DIR         = './data'      # data location handed to load_data
LEADS            = [0, 1, 2]     # lead selection handed to load_data
SEG_SEC          = 10            # segment-length argument handed to load_data (presumably seconds — confirm)
BATCH_AE         = 64            # mini-batch size for autoencoder training
EPOCH_AE         = 10            # number of autoencoder training epochs
BATCH_MIL        = 1             # NOTE(review): not referenced by any cell visible here
EPOCH_MIL        = 5             # passes over the training bags for each MIL model
ATTN_TH_VALUES   = np.linspace(0.1, 0.9, 9)   # thresholds to sweep: 0.1, 0.2, ..., 0.9
device           = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the RNGs before anything stochastic runs (weight initialisation,
# DataLoader shuffling) so that Restart & Run All reproduces the same numbers.
SEED             = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Bags and labels for both splits; each bag is later treated as an array whose
# first axis indexes instances.
train_bags, train_labels, test_bags, test_labels = load_data(
    DATA_DIR, LEADS, SEG_SEC
)

In [None]:
# Train the attention autoencoder on every training instance, ignoring which
# bag each instance came from.
# NOTE(review): input_dim assumes each flattened instance has exactly
# SEG_SEC * len(LEADS) features — confirm against what load_data returns.
ae = AttentionAE(input_dim=SEG_SEC * len(LEADS), latent_dim=8).to(device)
opt_ae = Adam(ae.parameters(), lr=1e-4)
crit_ae = MSELoss()

# Flatten all instances for AE training: stack the bags along the instance
# axis, then collapse every instance to a single feature vector.
all_instances = np.concatenate(train_bags, axis=0)
instances_tensor = torch.tensor(
    all_instances.reshape(len(all_instances), -1),
    dtype=torch.float32
).to(device)
ae_dataset = torch.utils.data.TensorDataset(instances_tensor)
ae_loader  = DataLoader(ae_dataset, batch_size=BATCH_AE, shuffle=True)

for epoch in range(1, EPOCH_AE + 1):
    ae.train()
    total_loss = 0.0
    for (x_batch,) in ae_loader:
        # The model returns (reconstruction, second output); only the
        # reconstruction is used for the training objective.
        recon, _ = ae(x_batch)
        loss = crit_ae(recon, x_batch)
        opt_ae.zero_grad()
        loss.backward()
        opt_ae.step()
        # MSELoss already averages within the batch, so weight each batch by
        # its size; otherwise a short final batch is over-represented in the
        # reported epoch loss.
        total_loss += loss.item() * x_batch.size(0)
    print(f"[AE] Epoch {epoch}/{EPOCH_AE}, Loss: {total_loss/len(ae_dataset):.4f}")

In [None]:
def select_bag(ae_model, bag, threshold):
    """Keep the instances of one bag whose score reaches ``threshold``.

    Args:
        ae_model: Module whose forward pass returns a pair; the second element
            is used as the per-instance score (one score per instance). The
            module is switched to eval mode and left in eval mode.
        bag: NumPy array whose first axis indexes instances. Each instance is
            flattened to a vector before being scored.
        threshold: Minimum score, inclusive, for an instance to be kept.

    Returns:
        The sub-array of ``bag`` holding the selected instances, with the
        instance axis preserved. If no instance reaches the threshold, the
        single highest-scoring instance is returned as a length-1 slice, so the
        result always contains at least one instance.

    Raises:
        ValueError: If ``bag`` contains no instances.
    """
    if len(bag) == 0:
        # Fail with a clear message instead of an opaque reshape error.
        raise ValueError("select_bag requires a bag with at least one instance")

    # Score on whatever device the model lives on rather than relying on a
    # notebook-level global; a parameter-free model is scored on the CPU.
    first_param = next(ae_model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    ae_model.eval()
    with torch.no_grad():
        X = torch.tensor(
            bag.reshape(len(bag), -1),
            dtype=torch.float32
        ).to(model_device)
        _, scores = ae_model(X)         # [N,1]
        # reshape(-1) instead of squeeze(): squeeze() turns a one-instance bag
        # into a 0-d tensor, and indexing with a 0-d boolean mask adds an axis
        # to the result instead of selecting rows.
        scores = scores.reshape(-1)
        mask = (scores >= threshold).cpu().numpy()

    selected = bag[mask]
    # Ensure at least one instance
    if len(selected) == 0:
        idx = int(scores.argmax().item())
        selected = bag[idx:idx + 1]
    return selected

In [None]:
# Sweep the attention threshold: for each value, filter every bag with the
# trained autoencoder, train a fresh MIL classifier on the filtered training
# bags, and record its accuracy on the filtered test bags.
# NOTE(review): every threshold is scored on the test set, so picking the best
# threshold from these results tunes on test data; a separate validation split
# would be needed for an unbiased final estimate.
results = []
for th in ATTN_TH_VALUES:
    # Select instances for each bag
    sel_train = [select_bag(ae, b, th) for b in train_bags]
    sel_test  = [select_bag(ae, b, th) for b in test_bags]

    # Train MIL Model (a new model per threshold so runs do not share weights)
    mil = SimpleMIL(input_dim=SEG_SEC * len(LEADS)).to(device)
    opt_mil = Adam(mil.parameters(), lr=1e-4)
    crit_mil = BCELoss()

    for _ in range(EPOCH_MIL):
        mil.train()
        # One optimisation step per bag.
        for bag, label in zip(sel_train, train_labels):
            bag_tensor = torch.tensor(
                bag.reshape(len(bag), -1),
                dtype=torch.float32
            ).to(device)
            pred = mil(bag_tensor).squeeze()
            loss = crit_mil(pred, torch.tensor(label, dtype=torch.float32).to(device))
            opt_mil.zero_grad()
            loss.backward()
            opt_mil.step()

    # Evaluate on Test Set
    mil.eval()
    correct = 0
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every test bag.
    with torch.no_grad():
        for bag, label in zip(sel_test, test_labels):
            bag_tensor = torch.tensor(
                bag.reshape(len(bag), -1),
                dtype=torch.float32
            ).to(device)
            # A bag is predicted positive when the model output is at least 0.5.
            p = (mil(bag_tensor).item() >= 0.5)
            correct += (p == bool(label))
    accuracy = correct / len(test_labels)
    results.append((th, accuracy))
    print(f"Threshold {th:.2f} → Accuracy: {accuracy:.4f}")

In [None]:
# Plot test accuracy against the attention threshold used for instance
# selection; `results` holds one (threshold, accuracy) pair per sweep step.
thresholds, accuracies = zip(*results)
fig, ax = plt.subplots()
ax.plot(thresholds, accuracies)
ax.set_xlabel('Attention Threshold')
ax.set_ylabel('Test Accuracy')
ax.set_title('Threshold Tuning Results')
plt.show()