In [6]:
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import pandas as pd

from data_loader_darts import get_dataloaders_simple
from darts_search_bdp import train_darts_search_bdp
from model_build import FinalNetwork
from cell_plot import plot_cell

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN kernels.

    Args:
        seed: integer seed applied to every RNG.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly cover every visible GPU, not just the current one; no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    # deterministic=True alone is not enough: with benchmark=True cuDNN autotunes per run
    # and may pick a different (non-deterministic) algorithm.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [7]:

def train_final_model(model, train_loader, val_loader, device, epochs=25):
    """Train ``model`` with SGD + cosine annealing and keep the best validation checkpoint.

    NOTE(review): a later cell redefines ``train_final_model`` (Adam, ReduceLROnPlateau,
    early stopping); once that cell has run, calls resolve to the later version.

    Args:
        model: network called as ``model(x.squeeze(-1))`` that returns class logits.
        train_loader: iterable of ``(x, y)`` training batches.
        val_loader: iterable of ``(x, y)`` batches used for checkpoint selection.
        device: device each batch is moved to (the model is expected to already be there).
        epochs: number of epochs; also the cosine schedule's ``T_max``.

    Side effects:
        Writes ``best_final_model.pt`` (state_dict of the best-val-accuracy epoch),
        ``final_loss.png`` and ``final_accuracy.png`` to the working directory.

    Returns:
        dict with per-epoch lists ``train_loss``, ``val_loss``, ``train_acc``,
        ``val_acc`` and the scalar ``best_acc``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.025, momentum=0.9, weight_decay=3e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Start below any reachable accuracy so the first epoch always writes a checkpoint.
    # With 0 and a strict `>`, a model stuck at 0% accuracy would leave no file
    # (or a stale one from an earlier run) behind.
    best_acc = float("-inf")
    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}

    for epoch in range(epochs):
        model.train()
        total_loss, correct, total = 0.0, 0, 0
        for x, y in train_loader:
            x, y = x.to(device).squeeze(-1), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += y.size(0)
        train_loss = total_loss / len(train_loader)
        train_acc = correct / total

        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device).squeeze(-1), y.to(device)
                logits = model(x)
                val_loss += criterion(logits, y).item()
                val_correct += (logits.argmax(dim=1) == y).sum().item()
                val_total += y.size(0)
        val_loss = val_loss / len(val_loader)
        val_acc = val_correct / val_total

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "best_final_model.pt")

        scheduler.step()
        print(f"[Final Train Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    # One figure per metric: (history key suffix, legend suffix, y label, title, output file).
    for metric, short, ylabel, title, path in (
        ("loss", "Loss", "Loss", "Loss Curve", "final_loss.png"),
        ("acc", "Acc", "Accuracy", "Accuracy Curve", "final_accuracy.png"),
    ):
        fig, ax = plt.subplots()
        ax.plot(history[f"train_{metric}"], label=f"Train {short}")
        ax.plot(history[f"val_{metric}"], label=f"Val {short}")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.set_title(title)
        fig.savefig(path)
        plt.close(fig)

    history["best_acc"] = best_acc
    return history

def evaluate_model(model, val_loader, device):
    """Predict on ``val_loader``, print the confusion matrix and write CSV reports.

    Side effects:
        Writes ``confusion_matrix.csv`` (the matrix, no index) and
        ``val_predictions.csv`` (columns ``y_true``, ``y_pred``).

    Args:
        model: network called as ``model(x.squeeze(-1))`` that returns class logits.
        val_loader: iterable of ``(x, y)`` batches.
        device: device the inputs are moved to.

    Returns:
        The confusion matrix produced by ``sklearn.metrics.confusion_matrix``.
    """
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for x, y in val_loader:
            logits = model(x.to(device).squeeze(-1))
            y_pred.extend(logits.argmax(dim=1).cpu().numpy())
            # .cpu(): labels may already live on the GPU (e.g. a TensorDataset of CUDA
            # tensors); calling .numpy() on a CUDA tensor raises.
            y_true.extend(y.cpu().numpy())

    cm = confusion_matrix(y_true, y_pred)
    print("\nConfusion Matrix on Validation:")
    print(cm)
    pd.DataFrame(cm).to_csv("confusion_matrix.csv", index=False)
    print("[\u2713] Saved confusion matrix to confusion_matrix.csv")

    pd.DataFrame({"y_true": y_true, "y_pred": y_pred}).to_csv("val_predictions.csv", index=False)
    print("[\u2713] Saved predictions to val_predictions.csv")
    return cm




In [8]:

# 1. Fix every RNG seed, then pick the compute device.
set_seed(42)
device = "cuda" if torch.cuda.is_available() else "cpu"

# 2. Load the 50/50 train/validation split used by the architecture search.
print("[INFO] Loading 50/50 split data...")
train_loader, val_loader, num_classes = get_dataloaders_simple(batch_size=32)
n_train_samples = len(train_loader.dataset.y)
n_val_samples = len(val_loader.dataset.y)
print(f"[INFO] DARTS will run on {n_train_samples} train samples and {n_val_samples} val samples")

[INFO] Loading 50/50 split data...
[DEBUG] Loaded ./PSG/SC4001E0.npz → 841 samples
[DEBUG] Loaded ./PSG/SC4002E0.npz → 1127 samples
[DEBUG] Loaded ./PSG/SC4011E0.npz → 1103 samples
[DEBUG] Loaded ./PSG/SC4012E0.npz → 1186 samples
[DEBUG] Loaded ./PSG/SC4021E0.npz → 1025 samples
[DEBUG] Loaded ./PSG/SC4022E0.npz → 1009 samples
[DEBUG] Loaded ./PSG/SC4031E0.npz → 952 samples
[DEBUG] Loaded ./PSG/SC4032E0.npz → 911 samples
[DEBUG] Loaded ./PSG/SC4041E0.npz → 1235 samples
[DEBUG] Loaded ./PSG/SC4042E0.npz → 1200 samples
[DEBUG] Loaded ./PSG/SC4051E0.npz → 672 samples
[DEBUG] Loaded ./PSG/SC4052E0.npz → 1246 samples
[DEBUG] Loaded ./PSG/SC4061E0.npz → 843 samples
[DEBUG] Loaded ./PSG/SC4062E0.npz → 1016 samples
[DEBUG] Loaded ./PSG/SC4071E0.npz → 976 samples
[DEBUG] Loaded ./PSG/SC4072E0.npz → 1273 samples
[DEBUG] Loaded ./PSG/SC4081E0.npz → 1134 samples
[DEBUG] Loaded ./PSG/SC4082E0.npz → 1054 samples
[DEBUG] Loaded ./PSG/SC4091E0.npz → 1132 samples
[DEBUG] Loaded ./PSG/SC4092E0.npz → 1105

In [9]:
# 3. Run the DARTS architecture search with BDP; it returns the searched genotype
#    plus (presumably pruned — see train_darts_search_bdp) train/val loaders.
# NOTE(review): the recorded run of this cell died with
#   TypeError: unsupported operand type(s) for -: 'set' and 'int'
# raised inside train_darts_search_bdp (not visible here). Its log also printed
# "[Epoch 1/27]" although epochs=40 is passed — check how the epoch count is derived.
print("[INFO] Running DARTS search with BDP...")
search_kwargs = dict(epochs=40, prune_every=7, pt=0.05, pv=0.05, device=device)
searched_genotype, pruned_train_loader, pruned_val_loader = train_darts_search_bdp(
    train_loader, val_loader, num_classes, **search_kwargs
)

[INFO] Running DARTS search with BDP...

[Epoch 1/27] Starting...
  [Step 000] Loss: 1.5110 | Acc: 0.4688
  [Step 001] Loss: 1.6911 | Acc: 0.1250
  [Step 002] Loss: 1.6090 | Acc: 0.2812
  [Step 003] Loss: 1.4410 | Acc: 0.6250
  [Step 004] Loss: 1.5858 | Acc: 0.2500
  [Step 005] Loss: 1.4449 | Acc: 0.3750
  [Step 006] Loss: 1.3907 | Acc: 0.3750
  [Step 007] Loss: 1.3827 | Acc: 0.3125
  [Step 008] Loss: 1.2448 | Acc: 0.6250
  [Step 009] Loss: 1.3367 | Acc: 0.4688
  [Step 010] Loss: 1.2459 | Acc: 0.5938
  [Step 011] Loss: 1.1620 | Acc: 0.5625
  [Step 012] Loss: 1.1318 | Acc: 0.5625
  [Step 013] Loss: 1.5355 | Acc: 0.3750
  [Step 014] Loss: 1.2681 | Acc: 0.6562
  [Step 015] Loss: 1.2638 | Acc: 0.3750
  [Step 016] Loss: 1.4126 | Acc: 0.5000
  [Step 017] Loss: 1.3874 | Acc: 0.3125
  [Step 018] Loss: 1.0820 | Acc: 0.6250
  [Step 019] Loss: 1.4483 | Acc: 0.3438
  [Step 020] Loss: 1.2594 | Acc: 0.4062
  [Step 021] Loss: 1.3662 | Acc: 0.4375
  [Step 022] Loss: 1.1266 | Acc: 0.6562
  [Step 023] L

  [Step 204] Loss: 0.9853 | Acc: 0.5625
  [Step 205] Loss: 0.9651 | Acc: 0.6562
  [Step 206] Loss: 0.9857 | Acc: 0.6250
  [Step 207] Loss: 1.2731 | Acc: 0.5000
  [Step 208] Loss: 1.1585 | Acc: 0.5000
  [Step 209] Loss: 1.0533 | Acc: 0.5625
  [Step 210] Loss: 1.0168 | Acc: 0.5625
  [Step 211] Loss: 1.1847 | Acc: 0.3125
  [Step 212] Loss: 0.8887 | Acc: 0.7188
  [Step 213] Loss: 1.0430 | Acc: 0.6562
  [Step 214] Loss: 1.0698 | Acc: 0.5625
  [Step 215] Loss: 1.3146 | Acc: 0.4375
  [Step 216] Loss: 1.6129 | Acc: 0.3438
  [Step 217] Loss: 1.3905 | Acc: 0.4062
  [Step 218] Loss: 0.9440 | Acc: 0.6875
  [Step 219] Loss: 1.1626 | Acc: 0.5312
  [Step 220] Loss: 1.1555 | Acc: 0.5625
  [Step 221] Loss: 1.3042 | Acc: 0.4375
  [Step 222] Loss: 1.0982 | Acc: 0.6875
  [Step 223] Loss: 1.3437 | Acc: 0.4062
  [Step 224] Loss: 1.1960 | Acc: 0.4062
  [Step 225] Loss: 1.2634 | Acc: 0.3750
  [Step 226] Loss: 1.2906 | Acc: 0.3438
  [Step 227] Loss: 1.0891 | Acc: 0.5938
  [Step 228] Loss: 1.2838 | Acc: 0.2500


  [Step 000] Loss: 2.1371 | Acc: 0.6250
  [Step 001] Loss: 0.7069 | Acc: 0.6562
  [Step 002] Loss: 1.0855 | Acc: 0.7188
  [Step 003] Loss: 0.7747 | Acc: 0.6250
  [Step 004] Loss: 1.1593 | Acc: 0.6562
  [Step 005] Loss: 0.7341 | Acc: 0.6875
  [Step 006] Loss: 0.7525 | Acc: 0.6250
  [Step 007] Loss: 1.1172 | Acc: 0.5625
  [Step 008] Loss: 0.7614 | Acc: 0.7188
  [Step 009] Loss: 0.6957 | Acc: 0.6562
  [Step 010] Loss: 0.8032 | Acc: 0.6562
  [Step 011] Loss: 0.7873 | Acc: 0.7188
  [Step 012] Loss: 0.8711 | Acc: 0.7500
  [Step 013] Loss: 0.8290 | Acc: 0.6250
  [Step 014] Loss: 1.0345 | Acc: 0.4688
  [Step 015] Loss: 0.8826 | Acc: 0.5625
  [Step 016] Loss: 0.8321 | Acc: 0.6875
  [Step 017] Loss: 0.7020 | Acc: 0.7188
  [Step 018] Loss: 0.9260 | Acc: 0.6250
  [Step 019] Loss: 0.8277 | Acc: 0.7188
  [Step 020] Loss: 0.6818 | Acc: 0.6562
  [Step 021] Loss: 0.7442 | Acc: 0.7812
  [Step 022] Loss: 0.8156 | Acc: 0.6875
  [Step 023] Loss: 1.6198 | Acc: 0.5312
  [Step 024] Loss: 0.9145 | Acc: 0.5938


  [Step 205] Loss: 1.2275 | Acc: 0.5625
  [Step 206] Loss: 0.7924 | Acc: 0.5938
  [Step 207] Loss: 1.0251 | Acc: 0.6250
  [Step 208] Loss: 0.8469 | Acc: 0.6250
  [Step 209] Loss: 0.8677 | Acc: 0.6250
  [Step 210] Loss: 1.4360 | Acc: 0.4688
  [Step 211] Loss: 1.3553 | Acc: 0.5625
  [Step 212] Loss: 1.0047 | Acc: 0.5938
  [Step 213] Loss: 1.0088 | Acc: 0.6250
  [Step 214] Loss: 0.9857 | Acc: 0.5938
  [Step 215] Loss: 0.9435 | Acc: 0.5000
  [Step 216] Loss: 0.9178 | Acc: 0.6562
  [Step 217] Loss: 1.2286 | Acc: 0.4688
  [Step 218] Loss: 1.0072 | Acc: 0.5312
  [Step 219] Loss: 0.8894 | Acc: 0.7188
  [Step 220] Loss: 0.9891 | Acc: 0.5938
  [Step 221] Loss: 0.8701 | Acc: 0.7188
  [Step 222] Loss: 0.6740 | Acc: 0.7812
  [Step 223] Loss: 0.9275 | Acc: 0.6562
  [Step 224] Loss: 1.0161 | Acc: 0.6562
  [Step 225] Loss: 1.0837 | Acc: 0.6562
  [Step 226] Loss: 1.0858 | Acc: 0.5000
  [Step 227] Loss: 1.2749 | Acc: 0.4062
  [Step 228] Loss: 1.4103 | Acc: 0.4062
  [Step 229] Loss: 0.9567 | Acc: 0.6562


  [Step 000] Loss: 2.9531 | Acc: 0.5000
  [Step 001] Loss: 1.1821 | Acc: 0.5312
  [Step 002] Loss: 0.9799 | Acc: 0.6875
  [Step 003] Loss: 0.7580 | Acc: 0.7188
  [Step 004] Loss: 0.9939 | Acc: 0.5938
  [Step 005] Loss: 0.9512 | Acc: 0.5625
  [Step 006] Loss: 0.6557 | Acc: 0.7812
  [Step 007] Loss: 0.8379 | Acc: 0.6562
  [Step 008] Loss: 0.8892 | Acc: 0.6562
  [Step 009] Loss: 0.7287 | Acc: 0.6875
  [Step 010] Loss: 0.9066 | Acc: 0.6875
  [Step 011] Loss: 0.8788 | Acc: 0.7188
  [Step 012] Loss: 0.9314 | Acc: 0.7188
  [Step 013] Loss: 0.6285 | Acc: 0.6875
  [Step 014] Loss: 0.8765 | Acc: 0.5625
  [Step 015] Loss: 1.1068 | Acc: 0.4375
  [Step 016] Loss: 0.9805 | Acc: 0.5938
  [Step 017] Loss: 0.5520 | Acc: 0.7500
  [Step 018] Loss: 0.8393 | Acc: 0.6562
  [Step 019] Loss: 1.3302 | Acc: 0.4375
  [Step 020] Loss: 1.0900 | Acc: 0.4688
  [Step 021] Loss: 0.9740 | Acc: 0.5625
  [Step 022] Loss: 0.9492 | Acc: 0.5938
  [Step 023] Loss: 0.6272 | Acc: 0.7188
  [Step 024] Loss: 0.7616 | Acc: 0.6562


  [Step 205] Loss: 0.9425 | Acc: 0.6250
  [Step 206] Loss: 0.7783 | Acc: 0.7188
  [Step 207] Loss: 0.7761 | Acc: 0.7812
  [Step 208] Loss: 0.8005 | Acc: 0.6875
  [Step 209] Loss: 1.1095 | Acc: 0.5625
  [Step 210] Loss: 0.9112 | Acc: 0.6250
  [Step 211] Loss: 0.6963 | Acc: 0.6562
  [Step 212] Loss: 0.7610 | Acc: 0.7188
  [Step 213] Loss: 0.4580 | Acc: 0.9062
  [Step 214] Loss: 0.8250 | Acc: 0.6562
  [Step 215] Loss: 1.1210 | Acc: 0.6250
  [Step 216] Loss: 0.6523 | Acc: 0.7812
  [Step 217] Loss: 0.7810 | Acc: 0.7500
  [Step 218] Loss: 0.7861 | Acc: 0.6562
  [Step 219] Loss: 0.6678 | Acc: 0.6875
  [Step 220] Loss: 0.7593 | Acc: 0.6562
  [Step 221] Loss: 0.6044 | Acc: 0.8125
  [Step 222] Loss: 0.7125 | Acc: 0.7500
  [Step 223] Loss: 0.7188 | Acc: 0.6875
  [Step 224] Loss: 0.9521 | Acc: 0.5938
  [Step 225] Loss: 0.9192 | Acc: 0.6562
  [Step 226] Loss: 0.5971 | Acc: 0.7500
  [Step 227] Loss: 0.6158 | Acc: 0.8750
  [Step 228] Loss: 0.8814 | Acc: 0.6250
  [Step 229] Loss: 0.9164 | Acc: 0.6562


  [Step 000] Loss: 1.8237 | Acc: 0.5625
  [Step 001] Loss: 0.7165 | Acc: 0.7812
  [Step 002] Loss: 0.9831 | Acc: 0.6875
  [Step 003] Loss: 0.8710 | Acc: 0.6875
  [Step 004] Loss: 0.7885 | Acc: 0.6250
  [Step 005] Loss: 0.6350 | Acc: 0.8125
  [Step 006] Loss: 0.6992 | Acc: 0.8125
  [Step 007] Loss: 0.7671 | Acc: 0.6875
  [Step 008] Loss: 1.0306 | Acc: 0.4688
  [Step 009] Loss: 0.8068 | Acc: 0.6875
  [Step 010] Loss: 0.6454 | Acc: 0.7188
  [Step 011] Loss: 1.0335 | Acc: 0.5938
  [Step 012] Loss: 0.7480 | Acc: 0.6875
  [Step 013] Loss: 0.8449 | Acc: 0.5938
  [Step 014] Loss: 0.8027 | Acc: 0.6250
  [Step 015] Loss: 1.0112 | Acc: 0.6250
  [Step 016] Loss: 0.7496 | Acc: 0.7188
  [Step 017] Loss: 0.8197 | Acc: 0.6875
  [Step 018] Loss: 0.7852 | Acc: 0.7500
  [Step 019] Loss: 0.6264 | Acc: 0.7812
  [Step 020] Loss: 0.8527 | Acc: 0.5938
  [Step 021] Loss: 0.6035 | Acc: 0.7188
  [Step 022] Loss: 0.7227 | Acc: 0.6250
  [Step 023] Loss: 0.9818 | Acc: 0.5000
  [Step 024] Loss: 0.7478 | Acc: 0.6562


  [Step 205] Loss: 0.5262 | Acc: 0.8125
  [Step 206] Loss: 1.2007 | Acc: 0.6250
  [Step 207] Loss: 0.8466 | Acc: 0.7188
  [Step 208] Loss: 0.6894 | Acc: 0.6562
  [Step 209] Loss: 0.6747 | Acc: 0.6875
  [Step 210] Loss: 0.5263 | Acc: 0.8438
  [Step 211] Loss: 0.9774 | Acc: 0.6875
  [Step 212] Loss: 0.7696 | Acc: 0.6562
  [Step 213] Loss: 0.8347 | Acc: 0.7500
  [Step 214] Loss: 0.5795 | Acc: 0.7500
  [Step 215] Loss: 1.0167 | Acc: 0.6250
  [Step 216] Loss: 1.0573 | Acc: 0.5625
  [Step 217] Loss: 0.5088 | Acc: 0.8750
  [Step 218] Loss: 1.0218 | Acc: 0.4688
  [Step 219] Loss: 0.6709 | Acc: 0.8438
  [Step 220] Loss: 1.0388 | Acc: 0.5625
  [Step 221] Loss: 0.8067 | Acc: 0.6562
  [Step 222] Loss: 0.9528 | Acc: 0.7188
  [Step 223] Loss: 0.7449 | Acc: 0.7188
  [Step 224] Loss: 0.8049 | Acc: 0.6250
  [Step 225] Loss: 0.5919 | Acc: 0.7812
  [Step 226] Loss: 0.9400 | Acc: 0.7188
  [Step 227] Loss: 0.6899 | Acc: 0.7500
  [Step 228] Loss: 0.8331 | Acc: 0.6562
  [Step 229] Loss: 1.1425 | Acc: 0.5938


  [Step 000] Loss: 1.7545 | Acc: 0.3750
  [Step 001] Loss: 1.0028 | Acc: 0.5938
  [Step 002] Loss: 0.7988 | Acc: 0.6562
  [Step 003] Loss: 0.9595 | Acc: 0.5938
  [Step 004] Loss: 0.7440 | Acc: 0.6875
  [Step 005] Loss: 0.6413 | Acc: 0.6562
  [Step 006] Loss: 0.9647 | Acc: 0.5312
  [Step 007] Loss: 0.9811 | Acc: 0.5938
  [Step 008] Loss: 0.7432 | Acc: 0.6562
  [Step 009] Loss: 0.8550 | Acc: 0.6562
  [Step 010] Loss: 0.9420 | Acc: 0.5312
  [Step 011] Loss: 0.6981 | Acc: 0.6875
  [Step 012] Loss: 0.8445 | Acc: 0.6250
  [Step 013] Loss: 0.7145 | Acc: 0.7188
  [Step 014] Loss: 0.7594 | Acc: 0.6875
  [Step 015] Loss: 0.4842 | Acc: 0.8438
  [Step 016] Loss: 0.8283 | Acc: 0.7188
  [Step 017] Loss: 0.7042 | Acc: 0.7812
  [Step 018] Loss: 1.0358 | Acc: 0.6250
  [Step 019] Loss: 0.8721 | Acc: 0.5312
  [Step 020] Loss: 0.5858 | Acc: 0.8750
  [Step 021] Loss: 0.7050 | Acc: 0.7812
  [Step 022] Loss: 0.8689 | Acc: 0.6250
  [Step 023] Loss: 0.4954 | Acc: 0.8438
  [Step 024] Loss: 0.7403 | Acc: 0.7188


  [Step 205] Loss: 0.9722 | Acc: 0.7188
  [Step 206] Loss: 0.5365 | Acc: 0.8125
  [Step 207] Loss: 0.9048 | Acc: 0.6562
  [Step 208] Loss: 0.6458 | Acc: 0.7500
  [Step 209] Loss: 0.8780 | Acc: 0.6562
  [Step 210] Loss: 0.6533 | Acc: 0.7812
  [Step 211] Loss: 0.7340 | Acc: 0.6875
  [Step 212] Loss: 0.7282 | Acc: 0.7500
  [Step 213] Loss: 0.6247 | Acc: 0.6875
  [Step 214] Loss: 0.9919 | Acc: 0.6562
  [Step 215] Loss: 1.0398 | Acc: 0.7812
  [Step 216] Loss: 0.7760 | Acc: 0.6875
  [Step 217] Loss: 0.8458 | Acc: 0.6562
  [Step 218] Loss: 0.5450 | Acc: 0.8125
  [Step 219] Loss: 0.5244 | Acc: 0.8125
  [Step 220] Loss: 0.7513 | Acc: 0.6562
  [Step 221] Loss: 0.5625 | Acc: 0.7812
  [Step 222] Loss: 0.6162 | Acc: 0.8125
  [Step 223] Loss: 0.6785 | Acc: 0.6875
  [Step 224] Loss: 0.6904 | Acc: 0.7500
  [Step 225] Loss: 0.7872 | Acc: 0.8125
  [Step 226] Loss: 0.7226 | Acc: 0.7812
  [Step 227] Loss: 0.5483 | Acc: 0.8438
  [Step 228] Loss: 0.8019 | Acc: 0.7812
  [Step 229] Loss: 0.7996 | Acc: 0.6875


  [Step 000] Loss: 1.6495 | Acc: 0.5938
  [Step 001] Loss: 1.1327 | Acc: 0.5625
  [Step 002] Loss: 1.0268 | Acc: 0.5938
  [Step 003] Loss: 0.9358 | Acc: 0.6250
  [Step 004] Loss: 0.7954 | Acc: 0.6875
  [Step 005] Loss: 0.6523 | Acc: 0.7500
  [Step 006] Loss: 0.5697 | Acc: 0.6875
  [Step 007] Loss: 0.8231 | Acc: 0.6875
  [Step 008] Loss: 0.7369 | Acc: 0.7500
  [Step 009] Loss: 0.5635 | Acc: 0.8125
  [Step 010] Loss: 0.4695 | Acc: 0.8125
  [Step 011] Loss: 0.5153 | Acc: 0.7812
  [Step 012] Loss: 0.8198 | Acc: 0.6875
  [Step 013] Loss: 1.0590 | Acc: 0.5625
  [Step 014] Loss: 0.5632 | Acc: 0.8125
  [Step 015] Loss: 0.5675 | Acc: 0.7812
  [Step 016] Loss: 0.4042 | Acc: 0.9062
  [Step 017] Loss: 0.8071 | Acc: 0.6875
  [Step 018] Loss: 0.5989 | Acc: 0.8438
  [Step 019] Loss: 0.5787 | Acc: 0.7812
  [Step 020] Loss: 0.7178 | Acc: 0.6562
  [Step 021] Loss: 0.6671 | Acc: 0.7812
  [Step 022] Loss: 0.6481 | Acc: 0.7188
  [Step 023] Loss: 0.9265 | Acc: 0.7188
  [Step 024] Loss: 1.1517 | Acc: 0.5938


  [Step 205] Loss: 0.4778 | Acc: 0.9062
  [Step 206] Loss: 0.5676 | Acc: 0.8125
  [Step 207] Loss: 0.6565 | Acc: 0.7812
  [Step 208] Loss: 0.5339 | Acc: 0.8750
  [Step 209] Loss: 0.8188 | Acc: 0.8125
  [Step 210] Loss: 0.6961 | Acc: 0.7500
  [Step 211] Loss: 0.6671 | Acc: 0.7188
  [Step 212] Loss: 0.3646 | Acc: 0.8750
  [Step 213] Loss: 0.8949 | Acc: 0.6562
  [Step 214] Loss: 0.7112 | Acc: 0.7812
  [Step 215] Loss: 0.4975 | Acc: 0.8750
  [Step 216] Loss: 0.7553 | Acc: 0.6875
  [Step 217] Loss: 0.6970 | Acc: 0.7812
  [Step 218] Loss: 0.5649 | Acc: 0.8125
  [Step 219] Loss: 0.7789 | Acc: 0.7188
  [Step 220] Loss: 0.6876 | Acc: 0.8125
  [Step 221] Loss: 0.6533 | Acc: 0.7500
  [Step 222] Loss: 0.5183 | Acc: 0.8125
  [Step 223] Loss: 0.7364 | Acc: 0.7188
  [Step 224] Loss: 0.8164 | Acc: 0.6250
  [Step 225] Loss: 0.6850 | Acc: 0.6875
  [Step 226] Loss: 0.6764 | Acc: 0.7500
  [Step 227] Loss: 0.7646 | Acc: 0.6875
  [Step 228] Loss: 0.8401 | Acc: 0.6875
  [Step 229] Loss: 0.7643 | Acc: 0.7188


  [Step 000] Loss: 1.5848 | Acc: 0.3125
  [Step 001] Loss: 0.7830 | Acc: 0.6875
  [Step 002] Loss: 0.6972 | Acc: 0.6875
  [Step 003] Loss: 0.5621 | Acc: 0.8125
  [Step 004] Loss: 0.9042 | Acc: 0.5938
  [Step 005] Loss: 0.4954 | Acc: 0.8125
  [Step 006] Loss: 0.7186 | Acc: 0.7188
  [Step 007] Loss: 0.4404 | Acc: 0.8438
  [Step 008] Loss: 0.8759 | Acc: 0.6562
  [Step 009] Loss: 0.7109 | Acc: 0.6875
  [Step 010] Loss: 0.9137 | Acc: 0.6562
  [Step 011] Loss: 0.5257 | Acc: 0.8125
  [Step 012] Loss: 0.7300 | Acc: 0.7188
  [Step 013] Loss: 0.9626 | Acc: 0.7812
  [Step 014] Loss: 1.1104 | Acc: 0.7188
  [Step 015] Loss: 0.6509 | Acc: 0.6875
  [Step 016] Loss: 0.8641 | Acc: 0.5938
  [Step 017] Loss: 0.6902 | Acc: 0.6875
  [Step 018] Loss: 0.9769 | Acc: 0.6875
  [Step 019] Loss: 0.5850 | Acc: 0.8125
  [Step 020] Loss: 0.9863 | Acc: 0.5938
  [Step 021] Loss: 1.3117 | Acc: 0.6250
  [Step 022] Loss: 0.6524 | Acc: 0.7188
  [Step 023] Loss: 0.8854 | Acc: 0.6875
  [Step 024] Loss: 0.6177 | Acc: 0.7500


  [Step 205] Loss: 0.5354 | Acc: 0.8125
  [Step 206] Loss: 0.9184 | Acc: 0.6562
  [Step 207] Loss: 0.6436 | Acc: 0.7812
  [Step 208] Loss: 0.7199 | Acc: 0.7500
  [Step 209] Loss: 0.3105 | Acc: 0.9062
  [Step 210] Loss: 0.5700 | Acc: 0.7812
  [Step 211] Loss: 0.7570 | Acc: 0.6250
  [Step 212] Loss: 0.5753 | Acc: 0.7812
  [Step 213] Loss: 0.6155 | Acc: 0.7188
  [Step 214] Loss: 0.5919 | Acc: 0.8125
  [Step 215] Loss: 0.8935 | Acc: 0.7500
  [Step 216] Loss: 0.4647 | Acc: 0.8125
  [Step 217] Loss: 0.4957 | Acc: 0.8438
  [Step 218] Loss: 0.3521 | Acc: 0.8438
  [Step 219] Loss: 1.0441 | Acc: 0.6250
  [Step 220] Loss: 0.4724 | Acc: 0.8438
  [Step 221] Loss: 0.5890 | Acc: 0.8125
  [Step 222] Loss: 0.5881 | Acc: 0.7812
  [Step 223] Loss: 0.9946 | Acc: 0.6250
  [Step 224] Loss: 0.5744 | Acc: 0.8438
  [Step 225] Loss: 0.6345 | Acc: 0.7812
  [Step 226] Loss: 0.6401 | Acc: 0.7500
  [Step 227] Loss: 0.5138 | Acc: 0.8438
  [Step 228] Loss: 0.6673 | Acc: 0.7812
  [Step 229] Loss: 0.5869 | Acc: 0.7812


TypeError: unsupported operand type(s) for -: 'set' and 'int'

In [None]:

# 4. Visualize the searched normal and reduction cells.
print("[INFO] Visualizing searched cells...")
for cell_kind in ("normal", "reduce"):
    plot_cell(searched_genotype, cell_kind)

In [None]:
from sklearn.model_selection import KFold
from torch.utils.data import TensorDataset, DataLoader
import seaborn as sns
# 5. Merge the pruned train/val splits into one pool; the final train/test split is made later.
# NOTE(review): KFold is imported above, but no visible cell runs cross-validation.
# The previous log line announced a 5-fold CV that never happens.
print("[INFO] Merging pruned train/val data for final training...")
X_all = torch.cat([pruned_train_loader.dataset.X, pruned_val_loader.dataset.X], dim=0)
y_all = torch.cat([pruned_train_loader.dataset.y, pruned_val_loader.dataset.y], dim=0)
dataset = TensorDataset(X_all, y_all)

In [None]:

import torch
import matplotlib.pyplot as plt
import os

def train_final_model(model, train_loader, val_loader, device, epochs=30, patience=15):
    """Train ``model`` with Adam, ReduceLROnPlateau and early stopping on validation accuracy.

    This redefines (and shadows) the SGD/cosine ``train_final_model`` from an earlier cell.

    ``val_loader`` drives LR scheduling, checkpoint selection and early stopping, so it
    must not be the split whose accuracy is finally reported; otherwise that number is
    biased upward.

    Args:
        model: network called as ``model(x.squeeze(-1))`` that returns class logits.
        train_loader: iterable of ``(x, y)`` training batches.
        val_loader: iterable of ``(x, y)`` model-selection batches.
        device: device each batch is moved to.
        epochs: maximum number of epochs.
        patience: epochs without a val-accuracy improvement before stopping early.

    Side effects:
        Writes ``final_model.pt`` (state_dict of the best epoch), plus
        ``logs/final_loss_full.png`` and ``logs/final_accuracy_full.png``.

    Returns:
        dict with per-epoch lists ``train_loss``, ``val_loss``, ``train_acc``,
        ``val_acc`` and the scalar ``best_acc``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0002, weight_decay=0.0005)
    # `verbose=` is deprecated for LR schedulers in recent PyTorch; LR drops are logged below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=3
    )

    # Start below any reachable accuracy so the first epoch always writes a checkpoint.
    # With 0 and a strict `>`, a model stuck at 0% accuracy would leave no (or a stale)
    # final_model.pt for the caller to load.
    best_acc = float("-inf")
    early_stop_counter = 0
    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}

    for epoch in range(epochs):
        model.train()
        total_loss, correct, total = 0.0, 0, 0
        for x, y in train_loader:
            x, y = x.to(device).squeeze(-1), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += y.size(0)
        train_loss = total_loss / len(train_loader)
        train_acc = correct / total

        # Validation
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device).squeeze(-1), y.to(device)
                logits = model(x)
                val_loss += criterion(logits, y).item()
                val_correct += (logits.argmax(dim=1) == y).sum().item()
                val_total += y.size(0)
        val_loss = val_loss / len(val_loader)
        val_acc = val_correct / val_total

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        lrs_before = [group["lr"] for group in optimizer.param_groups]
        scheduler.step(val_acc)
        lrs_after = [group["lr"] for group in optimizer.param_groups]
        if lrs_after != lrs_before:
            print(f"[LR] ReduceLROnPlateau lowered learning rate to {lrs_after}")

        print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Test Loss: {val_loss:.4f} | Test Acc: {val_acc:.4f}")

        # Early stopping on validation accuracy
        if val_acc > best_acc:
            best_acc = val_acc
            early_stop_counter = 0
            torch.save(model.state_dict(), "final_model.pt")
        else:
            early_stop_counter += 1
            if early_stop_counter >= patience:
                print("🛑 Early stopping triggered.")
                break

    os.makedirs("logs", exist_ok=True)
    # One figure per metric: (history key suffix, legend suffix, y label, title, output file).
    for metric, short, ylabel, title, path in (
        ("loss", "Loss", "Loss", "Final Training - Loss Curve", "logs/final_loss_full.png"),
        ("acc", "Acc", "Accuracy", "Final Training - Accuracy Curve", "logs/final_accuracy_full.png"),
    ):
        fig, ax = plt.subplots()
        ax.plot(history[f"train_{metric}"], label=f"Train {short}")
        ax.plot(history[f"val_{metric}"], label=f"Test {short}")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        fig.savefig(path)
        plt.close(fig)

    history["best_acc"] = best_acc
    return history


In [None]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, TensorDataset
import torch

# --- STEP 1: train / validation / test split ---
# The test split is evaluated exactly once, at the end. `train_final_model` uses its
# second loader for LR scheduling, checkpoint selection and early stopping. It therefore
# gets a validation split carved from the training portion; passing the test loader
# there would pick the checkpoint on the test set and bias the reported test accuracy upward.
SPLIT_SEED = 47
VAL_FRACTION = 0.1
FINAL_EPOCHS = 1  # NOTE(review): looks like a debug value; train_final_model defaults to 30
CHECKPOINT_PATH = "final_model.pt"  # best-epoch state_dict written by train_final_model

X_trainval, X_test, y_trainval, y_test = train_test_split(
    X_all, y_all, test_size=0.2, stratify=y_all, random_state=SPLIT_SEED
)
X_train, X_val, y_train, y_val = train_test_split(
    X_trainval, y_trainval, test_size=VAL_FRACTION, stratify=y_trainval, random_state=SPLIT_SEED
)

train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=64, shuffle=True)
final_val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=64)
test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=64)

# --- STEP 2: build the final network from the searched genotype ---
model = FinalNetwork(C=8, num_classes=num_classes, layers=5, genotype=searched_genotype).to(device)

# Print the model architecture before training
print("\n[INFO] === Model Architecture ===")
print(model)
print(f"[INFO] Total parameters: {sum(p.numel() for p in model.parameters())}")
print(f"[INFO] Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

# --- STEP 3: train, selecting the checkpoint on the validation split ---
train_final_model(model, train_loader, final_val_loader, device=device, epochs=FINAL_EPOCHS)

# --- STEP 4: evaluate the best checkpoint on the untouched test split ---
# map_location: the checkpoint may have been written from a different device.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model.eval()

y_true, y_pred = [], []
criterion = torch.nn.CrossEntropyLoss()
test_loss_total = 0.0

with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device).squeeze(-1), y.to(device)
        logits = model(x)
        loss = criterion(logits, y)
        test_loss_total += loss.item()
        y_pred.extend(logits.argmax(dim=1).cpu().numpy())
        y_true.extend(y.cpu().numpy())

test_loss = test_loss_total / len(test_loader)
test_acc = accuracy_score(y_true, y_pred)

print(f"[✓] Final Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}")



In [None]:
# === structured_pruning_pipeline.py ===
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score

# === Compute filter importance ===
def compute_importance(weight_tensor):
    """L1 norm of each output filter of a 3-D conv weight.

    For an ``nn.Conv1d`` weight of shape ``(out_channels, in_channels // groups, kernel)``
    this sums ``|w|`` over the last two axes, giving one score per output channel.
    """
    return torch.sum(torch.abs(weight_tensor), dim=(1, 2))

# === Structured pruning with group-awareness ===
def _select_filters(score, threshold, groups=1):
    """Ascending indices of the output filters to keep for one conv layer."""
    per_group = score.numel() // groups
    n_above = int((score > threshold).sum().item())
    # Same count in every group (at least one), so the pruned conv stays a valid grouped conv.
    k = min(max(n_above // groups, 1), per_group)
    # A grouped conv's outputs are `groups` contiguous blocks, each tied to its own input
    # slice, so the top-k is taken within each block rather than globally.
    local = score.reshape(groups, per_group).topk(k, dim=1).indices
    offsets = torch.arange(groups, device=score.device).unsqueeze(1) * per_group
    return (local + offsets).flatten().sort().values


def structured_prune_CLR_kRNF(model, prune_ratio=0.5):
    """Choose which output filters of every ``nn.Conv1d`` in ``model`` to keep.

    Each filter is scored with ``compute_importance``. A single global threshold is the
    ``prune_ratio`` quantile of all filter scores across every Conv1d layer. In each layer,
    the number of filters scoring strictly above that threshold sets how many are kept,
    with at least one per group.

    Grouped convs keep the same number of filters in every group, chosen by score within
    the group. Selecting globally would take whole groups' worth of channels from one group
    and wire them to another group's inputs after pruning.

    Args:
        model: ``nn.Module`` to inspect; it is not modified.
        prune_ratio: quantile in [0, 1] used as the global threshold.

    Returns:
        dict mapping each Conv1d's qualified module name to a 1-D LongTensor of kept
        output-channel indices in ascending (original) order. Returns an empty dict if
        ``model`` has no Conv1d layers.
    """
    convs = [(name, m) for name, m in model.named_modules() if isinstance(m, nn.Conv1d)]
    if not convs:
        # torch.cat([]) would raise; nothing to prune.
        return {}

    importance_scores = {name: compute_importance(m.weight.data) for name, m in convs}
    threshold = torch.quantile(torch.cat(list(importance_scores.values())), prune_ratio)

    return {
        name: _select_filters(importance_scores[name], threshold, m.groups)
        for name, m in convs
    }

# === Apply pruning (Conv + BatchNorm) ===
def _prune_conv1d_outputs(conv, keep_idx):
    """Copy of ``conv`` restricted to output channels ``keep_idx`` (all other settings kept)."""
    new_conv = nn.Conv1d(
        in_channels=conv.in_channels,
        out_channels=len(keep_idx),
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=conv.bias is not None,
        padding_mode=conv.padding_mode,
        device=conv.weight.device,
        dtype=conv.weight.dtype,
    )
    with torch.no_grad():
        new_conv.weight.copy_(conv.weight[keep_idx])
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias[keep_idx])
    new_conv.train(conv.training)
    return new_conv


def _prune_bn1d(bn, keep_idx):
    """Copy of ``bn`` restricted to channels ``keep_idx``, keeping eps/momentum/affine and stats."""
    ref = bn.weight if bn.weight is not None else bn.running_mean
    factory = {} if ref is None else {"device": ref.device, "dtype": ref.dtype}
    idx = keep_idx if ref is None else keep_idx.to(ref.device)
    new_bn = nn.BatchNorm1d(
        len(keep_idx),
        eps=bn.eps,
        momentum=bn.momentum,
        affine=bn.affine,
        track_running_stats=bn.track_running_stats,
        **factory,
    )
    with torch.no_grad():
        if bn.affine:
            new_bn.weight.copy_(bn.weight[idx])
            new_bn.bias.copy_(bn.bias[idx])
        if bn.running_mean is not None:
            new_bn.running_mean.copy_(bn.running_mean[idx])
            new_bn.running_var.copy_(bn.running_var[idx])
            new_bn.num_batches_tracked.copy_(bn.num_batches_tracked)
    new_bn.train(bn.training)
    return new_bn


def apply_structured_filter_prune(model, keep_filters_dict):
    """Prune the output filters of the Conv1d layers named in ``keep_filters_dict``, in place.

    Each listed ``nn.Conv1d`` is replaced by a new conv holding only the output channels in
    ``keep_filters_dict[name]``. Input channels, kernel, stride, padding, dilation, groups,
    bias, padding mode, device, dtype and train/eval mode are preserved.

    A ``BatchNorm1d`` may be registered right after the conv in the same container
    (``<parent>.<i+1>`` for a conv at ``<parent>.<i>``). If it normalizes the same number
    of channels, it is sliced with the same indices.

    For grouped convs, the kept indices must contain the same number of channels from every
    group; otherwise channels end up wired to the wrong input group. The root module itself
    is never replaced.

    NOTE(review): only the conv outputs (and that BN) shrink. Layers consuming those outputs
    keep their original ``in_channels``, so unless they are rebuilt too, the next forward
    pass will fail on a shape mismatch. Verify against FinalNetwork's wiring.

    Args:
        model: ``nn.Module`` modified in place.
        keep_filters_dict: qualified module name -> sequence or 1-D tensor of
            output-channel indices to keep.
    """
    # Snapshot first: submodules are swapped while walking the tree.
    targets = [
        (name, m) for name, m in model.named_modules()
        if name and isinstance(m, nn.Conv1d) and name in keep_filters_dict
    ]
    for name, conv in targets:
        keep_idx = torch.as_tensor(
            keep_filters_dict[name], dtype=torch.long, device=conv.weight.device
        )
        parent_name, _, attr = name.rpartition(".")
        parent = model.get_submodule(parent_name)
        setattr(parent, attr, _prune_conv1d_outputs(conv, keep_idx))

        # Only index-named containers (e.g. nn.Sequential: stem.0 -> stem.1) have a "next" module.
        if not attr.isdigit():
            continue
        bn_attr = str(int(attr) + 1)
        bn = getattr(parent, bn_attr, None)
        if isinstance(bn, nn.BatchNorm1d) and bn.num_features == conv.out_channels:
            setattr(parent, bn_attr, _prune_bn1d(bn, keep_idx))

# === Save/load model ===
def save_pruned_model(model, path="pruned_model_structured_full.pt"):
    """Pickle the whole ``model`` object (architecture + weights) to ``path``.

    The full module is saved rather than a state_dict because pruning changed layer shapes;
    a state_dict alone could not be loaded back into a freshly built, unpruned network.
    """
    torch.save(obj=model, f=path)
    print(f"[\u2713] Pruned model saved to: {path}")

def load_pruned_model(path, device):
    """Load a whole pickled model saved with ``torch.save(model, path)`` and move it to ``device``.

    ``weights_only=False`` is required: since PyTorch 2.6, ``torch.load`` defaults to
    ``weights_only=True``, which refuses to unpickle full ``nn.Module`` objects.
    SECURITY: unpickling can execute arbitrary code — only load checkpoints you created.
    ``map_location`` lets a checkpoint written on GPU load on a CPU-only machine.
    """
    model = torch.load(path, map_location=device, weights_only=False)
    return model.to(device)

# === Display filter stats ===
def print_filters_remaining(model, keep_filters_dict):
    """Print ``kept / out_channels`` for every Conv1d listed in ``keep_filters_dict``.

    The denominator is the layer's *current* ``out_channels``. Call this before
    ``apply_structured_filter_prune``; afterwards both numbers are the same.
    """
    print("\n[INFO] Remaining filters per Conv1d layer:")
    tracked = (
        (layer_name, conv) for layer_name, conv in model.named_modules()
        if isinstance(conv, nn.Conv1d) and layer_name in keep_filters_dict
    )
    for layer_name, conv in tracked:
        n_kept = len(keep_filters_dict[layer_name])
        print(f" - {layer_name}: {n_kept} / {conv.out_channels} kept")

# === Train pruned model ===
def train_model(model, train_loader, test_loader, device, epochs=1):
    """Fine-tune ``model`` with Adam (lr=1e-3) and print train/validation metrics each epoch.

    ``test_loader`` is only evaluated after each epoch. Nothing is checkpointed, and the
    function returns None. Inputs are passed to the model as ``x.squeeze(-1)``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch_idx in range(epochs):
        # --- training pass ---
        model.train()
        loss_sum, n_right, n_seen = 0.0, 0, 0
        for inputs, targets in train_loader:
            inputs = inputs.to(device).squeeze(-1)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            batch_loss = criterion(outputs, targets)
            batch_loss.backward()
            optimizer.step()
            loss_sum += batch_loss.item()
            n_right += outputs.argmax(dim=1).eq(targets).sum().item()
            n_seen += targets.size(0)
        train_loss = loss_sum / len(train_loader)
        train_acc = n_right / n_seen

        # --- validation pass ---
        model.eval()
        eval_loss_sum, eval_right, eval_seen = 0.0, 0, 0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(device).squeeze(-1)
                targets = targets.to(device)
                outputs = model(inputs)
                eval_loss_sum += criterion(outputs, targets).item()
                eval_right += outputs.argmax(dim=1).eq(targets).sum().item()
                eval_seen += targets.size(0)
        val_loss = eval_loss_sum / len(test_loader)
        val_acc = eval_right / eval_seen

        epoch = epoch_idx
        print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} || Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")


In [None]:
# === 1. Prune model
keep_filters = structured_prune_CLR_kRNF(model, prune_ratio=0.5)
apply_structured_filter_prune(model, keep_filters)
print_filters_remaining(model, keep_filters)

# === 2. Save pruned model
save_pruned_model(model, "pruned_finalnetwork_structured_full.pt")

# === 3. Load lại và train tiếp
model = load_pruned_model("pruned_finalnetwork_structured_full.pt", device)
# ... rồi train như bình thường


In [None]:
# Reload the pruned checkpoint and fine-tune it for one epoch.
pruned_ckpt_path = "pruned_finalnetwork_structured_full.pt"
model = load_pruned_model(pruned_ckpt_path, device)
train_model(model, train_loader, test_loader, device=device, epochs=1)


In [None]:
import torch
import torch.nn as nn

# 1. Tính độ quan trọng (L1-norm)
def compute_importance(weight_tensor):
    """Per-output-filter L1 norm of a 3-D conv weight (sum of |w| over axes 1 and 2).

    Redefines the function of the same name from the earlier pruning cell.
    """
    magnitudes = weight_tensor.abs()
    return magnitudes.sum(dim=(1, 2))

# 2. CLR + k-RNF Structured Pruning
def _select_filters(score, threshold, groups=1):
    """Ascending indices of the output filters to keep for one conv layer."""
    per_group = score.numel() // groups
    n_above = int((score > threshold).sum().item())
    # Same count in every group (at least one), so the pruned conv stays a valid grouped conv.
    k = min(max(n_above // groups, 1), per_group)
    # A grouped conv's outputs are `groups` contiguous blocks, each tied to its own input
    # slice, so the top-k is taken within each block rather than globally.
    local = score.reshape(groups, per_group).topk(k, dim=1).indices
    offsets = torch.arange(groups, device=score.device).unsqueeze(1) * per_group
    return (local + offsets).flatten().sort().values


def structured_prune_CLR_kRNF(model, prune_ratio=0.5):
    """Choose which output filters of every ``nn.Conv1d`` in ``model`` to keep.

    Redefines the function of the same name from the earlier pruning cell; keep both in sync.

    Each filter is scored with ``compute_importance``. A single global threshold is the
    ``prune_ratio`` quantile of all filter scores across every Conv1d layer. In each layer,
    the number of filters scoring strictly above that threshold sets how many are kept,
    with at least one per group.

    Grouped convs keep the same number of filters in every group, chosen by score within
    the group. Selecting globally would take whole groups' worth of channels from one group
    and wire them to another group's inputs after pruning.

    Args:
        model: ``nn.Module`` to inspect; it is not modified.
        prune_ratio: quantile in [0, 1] used as the global threshold.

    Returns:
        dict mapping each Conv1d's qualified module name to a 1-D LongTensor of kept
        output-channel indices in ascending (original) order. Returns an empty dict if
        ``model`` has no Conv1d layers.
    """
    convs = [(name, m) for name, m in model.named_modules() if isinstance(m, nn.Conv1d)]
    if not convs:
        # torch.cat([]) would raise; nothing to prune.
        return {}

    importance_scores = {name: compute_importance(m.weight.data) for name, m in convs}
    threshold = torch.quantile(torch.cat(list(importance_scores.values())), prune_ratio)

    return {
        name: _select_filters(importance_scores[name], threshold, m.groups)
        for name, m in convs
    }

def _slice_batchnorm1d(bn, keep_idx):
    """Return a new BatchNorm1d holding only channels ``keep_idx`` of ``bn``, with the same hyper-parameters."""
    new_bn = nn.BatchNorm1d(
        len(keep_idx),
        eps=bn.eps,
        momentum=bn.momentum,
        affine=bn.affine,
        track_running_stats=bn.track_running_stats,
    )
    # Place the new layer wherever the old one lived (either tensor may be None).
    ref = bn.weight if bn.weight is not None else bn.running_mean
    if ref is not None:
        new_bn = new_bn.to(device=ref.device)
    with torch.no_grad():
        if bn.affine:
            new_bn.weight.copy_(bn.weight[keep_idx])
            new_bn.bias.copy_(bn.bias[keep_idx])
        if bn.track_running_stats:
            new_bn.running_mean.copy_(bn.running_mean[keep_idx])
            new_bn.running_var.copy_(bn.running_var[keep_idx])
            new_bn.num_batches_tracked.copy_(bn.num_batches_tracked)
    return new_bn


def apply_structured_filter_prune(model, keep_filters_dict):
    """Physically remove pruned output filters from Conv1d layers, in place.

    For every ``nn.Conv1d`` whose qualified name is in ``keep_filters_dict``, the layer is
    replaced by a new Conv1d holding only the kept filters. Weights, bias and all conv
    hyper-parameters (including ``padding_mode``) are preserved, on the original device and
    dtype. If the conv lives in an indexed container (leaf name such as ``"0"``) and the next
    sibling (``"1"``) is a BatchNorm1d of the same width, that BN is sliced the same way.

    Args:
        model: module to modify in place.
        keep_filters_dict: maps a Conv1d qualified name to the indices of the output filters
            to keep (list or LongTensor), e.g. the output of ``structured_prune_CLR_kRNF``.

    NOTE(review): only *output* channels are pruned. The next layer's ``in_channels`` is left
    unchanged, so the forward pass will fail with a channel mismatch unless downstream layers
    are adjusted elsewhere. For grouped convs, the kept filters are not guaranteed to stay
    evenly spread over the groups.
    """
    # Snapshot the targets first: submodules are swapped while we walk the module tree.
    targets = [
        (name, module) for name, module in model.named_modules()
        if isinstance(module, nn.Conv1d) and name in keep_filters_dict
    ]

    for name, module in targets:
        device, dtype = module.weight.device, module.weight.dtype
        keep_idx = torch.as_tensor(keep_filters_dict[name], dtype=torch.long, device=device)
        new_out_channels = keep_idx.numel()

        # === Build the sliced Conv1d ===
        new_conv = nn.Conv1d(
            in_channels=module.in_channels,
            out_channels=new_out_channels,
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode,
        ).to(device=device, dtype=dtype)
        with torch.no_grad():
            new_conv.weight.copy_(module.weight[keep_idx])
            if module.bias is not None:
                new_conv.bias.copy_(module.bias[keep_idx])

        # === Swap it into its parent ===
        parts = name.split('.')
        parent = model
        for p in parts[:-1]:
            parent = getattr(parent, p)
        setattr(parent, parts[-1], new_conv)

        # === Accompanying BatchNorm: only looked up in indexed containers (e.g. stem.0 -> stem.1) ===
        leaf = parts[-1]
        if not leaf.isdigit():
            continue
        bn_key = str(int(leaf) + 1)
        bn_ref = getattr(parent, bn_key, None)
        # Only slice a BN that normalizes exactly this conv's (original) outputs.
        if isinstance(bn_ref, nn.BatchNorm1d) and bn_ref.num_features == module.out_channels:
            setattr(parent, bn_key, _slice_batchnorm1d(bn_ref, keep_idx))


# 4. Print how many filters survive in each Conv1d layer.
def print_filters_remaining(model, keep_filters_dict):
    """Print "kept / total" filter counts for every Conv1d listed in ``keep_filters_dict``.

    "total" is the layer's current ``out_channels``, so call this before the layers are sliced.
    """
    print("\n[INFO] Remaining filters per Conv1d layer:")
    selected = (
        (layer_name, layer)
        for layer_name, layer in model.named_modules()
        if isinstance(layer, nn.Conv1d) and layer_name in keep_filters_dict
    )
    for layer_name, layer in selected:
        n_kept = len(keep_filters_dict[layer_name])
        print(f" - {layer_name}: {n_kept} / {layer.out_channels} filters kept")

# 5. Save the pruned model.
def save_pruned_model(model, path="pruned_model_structured.pt"):
    """Pickle the whole pruned module (architecture + weights) to ``path``.

    The full ``nn.Module`` is saved rather than a state_dict, presumably because pruning changed
    layer shapes so the stock constructor can no longer rebuild it. Loading it back requires the
    model's classes to be importable. (Previously ``path`` was ignored and the file always went to
    "pruned_finalnetwork_structured_full.pt".)
    """
    torch.save(model, path)



In [None]:

# Rank filters, report per-layer survivors (before slicing, so totals are the original widths),
# slice the convs in place, and save.
keep_filters = structured_prune_CLR_kRNF(model, prune_ratio=0.5)
print_filters_remaining(model, keep_filters)
apply_structured_filter_prune(model, keep_filters)
# Save under the file name the retraining cell loads ("..._full.pt"), so the two cells agree
# regardless of whether save_pruned_model honours its path argument.
save_pruned_model(model, "pruned_finalnetwork_structured_full.pt")


In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import torch.nn as nn

# === STEP 1: Split train/test (stratified, fixed seed) ===
X_train, X_test, y_train, y_test = train_test_split(
    X_all, y_all, test_size=0.2, stratify=y_all, random_state=47
)
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=64, shuffle=True)
test_loader  = DataLoader(TensorDataset(X_test,  y_test),  batch_size=64)

# === STEP 2: Load the structurally pruned model (the whole module was pickled).
# map_location lets a checkpoint written on GPU be loaded on a CPU-only machine.
# NOTE(review): from torch 2.6 on, torch.load defaults to weights_only=True, which refuses
# full-module pickles; on those versions pass weights_only=False (trusted local file only).
model = torch.load("pruned_finalnetwork_structured_full.pt", map_location=device)
model.to(device)


# === STEP 3: Retrain the model
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(1):  # train for 1 epoch (increase if needed)
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for x, y in train_loader:
        x, y = x.to(device).squeeze(-1), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    train_loss = running_loss / len(train_loader)
    train_acc = correct / total

    # === Validation
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device).squeeze(-1), y.to(device)
            output = model(x)
            loss = criterion(output, y)
            val_loss += loss.item()
            pred = output.argmax(dim=1)
            val_correct += (pred == y).sum().item()
            val_total += y.size(0)
    val_loss /= len(test_loader)
    val_acc = val_correct / val_total

    print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} || "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

# === STEP 4: Final evaluation
model.eval()
y_true, y_pred = [], []
test_loss_total = 0.0

with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device).squeeze(-1), y.to(device)
        logits = model(x)
        loss = criterion(logits, y)
        test_loss_total += loss.item()
        y_pred.extend(logits.argmax(dim=1).cpu().numpy())
        y_true.extend(y.cpu().numpy())

test_loss = test_loss_total / len(test_loader)
test_acc = accuracy_score(y_true, y_pred)
print(f"[✓] Final Test Loss: {test_loss:.4f} | Final Test Accuracy: {test_acc:.4f}")


In [None]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, TensorDataset
import torch
import torch.nn as nn
from torch.nn.utils import prune

# === STEP 1: Split train/test (stratified, fixed seed) ===
X_train, X_test, y_train, y_test = train_test_split(
    X_all, y_all, test_size=0.2, stratify=y_all, random_state=47
)
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=64, shuffle=True)
test_loader  = DataLoader(TensorDataset(X_test,  y_test),  batch_size=64)

# === STEP 2: Load the trained model ===
model = FinalNetwork(C=8, num_classes=num_classes, layers=5, genotype=searched_genotype).to(device)
# map_location: the state_dict may have been saved from a different device (e.g. GPU -> CPU).
model.load_state_dict(torch.load("final_model.pt", map_location=device))

# === STEP 3: Magnitude (L1) pruning of 50% of the weights
def prune_model(model, amount=0.5):
    """Attach L1-magnitude unstructured pruning to the ``weight`` of every Conv1d and Linear layer.

    ``amount`` is forwarded to ``prune.l1_unstructured``. The masks remain attached
    (``weight_orig`` + ``weight_mask``) until ``prune.remove`` is called. Returns the same model object.
    """
    prunable = [layer for layer in model.modules() if isinstance(layer, (nn.Conv1d, nn.Linear))]
    for layer in prunable:
        prune.l1_unstructured(layer, name="weight", amount=amount)
    return model

# === STEP 3 (apply): attach 50% L1 masks to every Conv1d/Linear weight
model = prune_model(model, amount=0.5)

# === STEP 4 (helper): make the pruning permanent
def remove_pruning_masks(model):
    """Fold each attached pruning mask into its weight and drop the pruning re-parametrization.

    Afterwards ``weight`` is an ordinary Parameter whose pruned entries are 0. Tensor shapes
    are unchanged (nothing is physically cut), and the zeros are no longer enforced: any
    further optimizer step can make them non-zero again.
    """
    for module in model.modules():
        if isinstance(module, (nn.Conv1d, nn.Linear)) and hasattr(module, "weight_mask"):
            prune.remove(module, "weight")

# === STEP 5 (helper): count how many Conv1d/Linear weights are exactly zero
def count_pruned_weights(model):
    """Print total, non-zero and zero weight counts over all Conv1d/Linear layers.

    This works both while masks are attached (``module.weight`` is then the masked tensor)
    and after ``remove_pruning_masks``.
    """
    total, nonzero = 0, 0
    for module in model.modules():
        if isinstance(module, (nn.Conv1d, nn.Linear)):
            w = module.weight.detach()
            total += w.numel()
            nonzero += int(torch.count_nonzero(w))
    zero = total - nonzero
    ratio = 100 * zero / total if total else 0.0  # guard: model without Conv1d/Linear layers
    print(f"[INFO] Total weights: {total}")
    print(f"[INFO] Non-zero weights: {nonzero}")
    print(f"[INFO] Pruned weights (==0): {zero}")
    print(f"[INFO] Pruned ratio: {ratio:.2f}%")

count_pruned_weights(model)

# === STEP 6: Fine-tune the pruned model with the masks STILL attached.
# While masked, the forward pre-hook recomputes weight = weight_orig * weight_mask on every
# call, so the pruned entries stay 0 during training. If the masks were removed first, the
# zeros would become ordinary trainable weights, Adam would regrow them, and the 50%
# sparsity would be lost. The optimizer is created after pruning so it tracks weight_orig.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(50):
    # === Training ===
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for x, y in train_loader:
        x, y = x.to(device).squeeze(-1), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    train_loss = running_loss / len(train_loader)
    train_acc = correct / total

    # === Validation ===
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device).squeeze(-1), y.to(device)
            output = model(x)
            loss = criterion(output, y)
            val_loss += loss.item()
            pred = output.argmax(dim=1)
            val_correct += (pred == y).sum().item()
            val_total += y.size(0)
    val_loss /= len(test_loader)
    val_acc = val_correct / val_total

    print(f"[Epoch {epoch+1}] "
          f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} || "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

# === STEP 7: Bake the masks in now that training is done, and confirm the sparsity survived.
remove_pruning_masks(model)
count_pruned_weights(model)



In [None]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from copy import deepcopy
from collections import defaultdict

# === Utility to count parameters ===
def count_params(model):
    """Print the total and trainable (requires_grad) parameter counts of ``model``, thousands-separated."""
    n_total, n_trainable = 0, 0
    for param in model.parameters():
        n_total += param.numel()
        if param.requires_grad:
            n_trainable += param.numel()
    print(f"[INFO] Total parameters: {n_total:,}")
    print(f"[INFO] Trainable parameters: {n_trainable:,}")

# === Pruning ===
def apply_l1_pruning(model, amount=0.5):
    """Attach L1-magnitude unstructured pruning to the ``weight`` of every Conv1d in ``model``.

    ``amount`` is forwarded to ``prune.l1_unstructured``. Masks stay attached until removed.
    Returns the same model object.
    """
    conv_layers = (layer for _, layer in model.named_modules() if isinstance(layer, nn.Conv1d))
    for conv in conv_layers:
        prune.l1_unstructured(conv, name='weight', amount=amount)
    return model

# === Remove reparam hooks after pruning ===
def finalize_pruning(model):
    """Make pruning permanent on every Conv1d ``weight`` that currently carries a pruning mask.

    ``prune.remove`` folds ``weight_orig * weight_mask`` into a plain ``weight`` Parameter and
    drops the mask buffer and forward pre-hook. The zeros remain, but tensor shapes do not
    change. Conv1d layers that were never pruned are skipped.

    Returns:
        The same ``model`` object.
    """
    for _, module in model.named_modules():
        # prune.remove raises ValueError on an unpruned module. Check for the mask buffer
        # explicitly instead of a bare ``except:``, which also hid real errors and even
        # KeyboardInterrupt.
        if isinstance(module, nn.Conv1d) and hasattr(module, "weight_mask"):
            prune.remove(module, 'weight')
    return model

# === Training ===
def train_model(model, train_loader, val_loader, device, epochs=10, lr=1e-3):
    """Fine-tune ``model`` with Adam + cross-entropy, printing train and validation metrics each epoch.

    Batches from both loaders are moved to ``device`` and the trailing input dim is squeezed.
    The reported training loss is the sample-weighted epoch mean; accuracies are in percent.
    ``model`` is moved to ``device`` in place; nothing is returned.
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    def _validation_accuracy():
        # Percentage of correctly classified samples in val_loader (eval mode, no gradients).
        model.eval()
        hits = seen = 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device).squeeze(-1), targets.to(device)
                hits += (model(inputs).argmax(1) == targets).sum().item()
                seen += targets.size(0)
        return 100 * hits / seen

    for ep in range(1, epochs + 1):
        model.train()
        loss_sum, hits, seen = 0.0, 0, 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device).squeeze(-1), targets.to(device)
            optimizer.zero_grad()
            logits = model(inputs)
            loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()

            loss_sum += loss.item() * inputs.size(0)
            hits += (logits.argmax(1) == targets).sum().item()
            seen += targets.size(0)

        train_pct = 100 * hits / seen
        print(f"[Train] Epoch {ep}/{epochs} | Loss: {loss_sum/seen:.4f} | Acc: {train_pct:.2f}%")

        val_pct = _validation_accuracy()
        print(f"[Eval ] Epoch {ep}/{epochs} | Val Acc: {val_pct:.2f}%")

# === Evaluation ===
def evaluate_model(model, test_loader, device):
    """Move ``model`` to ``device``, evaluate it on ``test_loader`` and print the accuracy in percent."""
    model.to(device)
    model.eval()
    n_correct = n_seen = 0
    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            batch_x = batch_x.to(device).squeeze(-1)
            batch_y = batch_y.to(device)
            predictions = model(batch_x).argmax(1)
            n_correct += (predictions == batch_y).sum().item()
            n_seen += batch_y.size(0)
    accuracy_pct = 100 * n_correct / n_seen
    print(f"[✓] Final Test Accuracy: {accuracy_pct:.2f}%")

# === Instantiate model ===
# NOTE(review): this model is freshly initialised (no weights are loaded), so magnitude
# pruning here ranks random weights — confirm that is intended.
model = FinalNetwork(C=8, num_classes=5, layers=5, genotype=searched_genotype)
print("=== [ORIGINAL] ===")
count_params(model)

# === Apply pruning ===
model = apply_l1_pruning(model, amount=0.5)
model = finalize_pruning(model)
print("=== [PRUNED] ===")
# Unstructured pruning zeroes entries but keeps every tensor's shape, so count_params prints
# the same totals as before. Report the zeroed Conv1d weights to show what actually changed.
count_params(model)
conv_weights = [m.weight for m in model.modules() if isinstance(m, nn.Conv1d)]
n_conv_total = sum(w.numel() for w in conv_weights)
n_conv_zero = sum(int((w == 0).sum()) for w in conv_weights)
print(f"[INFO] Zeroed Conv1d weights: {n_conv_zero:,} / {n_conv_total:,}")

# === Train / Eval ===
# train_model(model, train_loader, test_loader, device, epochs=10)
# evaluate_model(model, test_loader, device)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from collections import defaultdict
from model_build import FinalNetwork
from genotypes import Genotype

# === GENOTYPE ===
# Hard-coded searched cell architecture (normal and reduction cells). The cell below uses it
# to build FinalNetwork without re-running the architecture search.
# NOTE(review): entries look like DARTS-style (operation name, input index) pairs, and
# *_concat appears to list the intermediate nodes concatenated into the cell output —
# confirm against genotypes.Genotype / model_build.FinalNetwork.
genotype = Genotype(
    normal=[('sep_conv_1x5', 1), ('sep_conv_1x3', 0), ('dil_conv_1x3', 1), ('dil_conv_1x3', 0),
            ('dil_conv_1x5', 1), ('conv_1x1', 0), ('conv_1x1', 1), ('conv_3x3', 4),
            ('dil_conv_1x5', 5), ('sep_conv_1x5', 1), ('sep_conv_1x5', 6), ('sep_conv_1x5', 5),
            ('dil_conv_1x3', 7), ('sep_conv_1x3', 6)],
    normal_concat=[2, 3, 4, 5, 6, 7, 8],
    reduce=[('dil_conv_1x3', 1), ('max_pool_3x3', 0), ('dil_conv_1x3', 1), ('max_pool_3x3', 0),
            ('skip_connect', 1), ('max_pool_3x3', 0), ('dil_conv_1x5', 1), ('dil_conv_1x3', 3),
            ('dil_conv_1x3', 5), ('dil_conv_1x3', 1), ('skip_connect', 1), ('sep_conv_1x5', 3),
            ('conv_1x1', 1), ('sep_conv_1x5', 3)],
    reduce_concat=[2, 3, 4, 5, 6, 7, 8]
)

# === IMPORTS ===
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from collections import defaultdict

# === PARAMETER COUNT ===
def count_params(model):
    """Print how many parameters ``model`` has in total and how many of them require gradients."""
    sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total = sum(n for n, _ in sizes)
    trainable = sum(n for n, needs_grad in sizes if needs_grad)
    print(f"[INFO] Total parameters: {total:,}")
    print(f"[INFO] Trainable parameters: {trainable:,}")

# === L1 SCORE ===
def compute_importance_scores(model):
    """Map each dense (``groups == 1``) Conv1d's name to the L1 norm of each of its output filters.

    Grouped and depthwise convs are skipped. The scores are returned on the CPU.
    """
    return {
        layer_name: torch.abs(layer.weight.data).sum(dim=(1, 2)).cpu()
        for layer_name, layer in model.named_modules()
        if isinstance(layer, nn.Conv1d) and layer.groups == 1
    }

# === GLOBAL RANKING ===
def get_global_ranking(importance_scores):
    """Flatten per-layer filter scores into one global ranking, highest score first.

    Each filter gets a flat id, assigned in the iteration order of ``importance_scores`` and
    then by filter order within each layer.

    Returns:
        ranked_indices: flat ids ordered by descending score (ties keep flat-id order).
        mapping: flat id -> (layer_name, filter_index).
        all_scores: list of (score, flat_id) tuples in the same order as ``ranked_indices``.
    """
    entries = []
    mapping = {}
    for layer_name, layer_scores in importance_scores.items():
        for filt_idx, value in enumerate(layer_scores):
            flat_id = len(entries)
            mapping[flat_id] = (layer_name, filt_idx)
            entries.append((value.item(), flat_id))
    entries = sorted(entries, key=lambda pair: pair[0], reverse=True)  # stable: ties keep order
    return [flat_id for _, flat_id in entries], mapping, entries

# === TOP FILTERS ===
def select_top_filters(ranked_indices, mapping, prune_rate=0.5):
    """Keep the globally top ``1 - prune_rate`` fraction of filters and group them per layer.

    Returns:
        dict layer_name -> ascending list of kept filter indices. Layers that keep no filter
        are absent. Layers appear in the order they are first hit in the ranking.
    """
    budget = int(len(ranked_indices) * (1 - prune_rate))
    kept_per_layer = defaultdict(list)
    for flat_id in ranked_indices[:budget]:
        layer_name, filt_idx = mapping[flat_id]
        kept_per_layer[layer_name].append(filt_idx)
    return {layer_name: sorted(idxs) for layer_name, idxs in kept_per_layer.items()}

# === PRUNE CONV1D ===
def build_pruned_conv1d(original, keep_out, keep_in=None):
    """Build a dense Conv1d that keeps only selected output (and optionally input) channels.

    Args:
        original: source ``nn.Conv1d``; must have ``groups == 1``.
        keep_out: indices of the output filters to keep (list or LongTensor).
        keep_in: indices of the input channels to keep, or None to keep all of them.

    Returns:
        A new ``nn.Conv1d`` with sliced copies of the weight and bias. Kernel size, stride,
        padding, dilation and padding mode are preserved, and the layer sits on the same
        device/dtype as ``original``.

    Raises:
        ValueError: if ``original`` is grouped (the slicing assumes a dense kernel).
    """
    if original.groups != 1:
        raise ValueError(f"build_pruned_conv1d only supports groups=1, got groups={original.groups}")
    in_ch = len(keep_in) if keep_in is not None else original.in_channels
    conv = nn.Conv1d(
        in_channels=in_ch,
        out_channels=len(keep_out),
        kernel_size=original.kernel_size,
        stride=original.stride,
        padding=original.padding,
        dilation=original.dilation,
        groups=1,
        bias=original.bias is not None,
        padding_mode=original.padding_mode,
    ).to(device=original.weight.device, dtype=original.weight.dtype)  # was always CPU/float32
    with torch.no_grad():
        w = original.weight[keep_out]
        if keep_in is not None:
            w = w[:, keep_in, :]
        conv.weight.copy_(w)
        if original.bias is not None:
            conv.bias.copy_(original.bias[keep_out])
    return conv

# === PRUNE BN1D ===
def build_pruned_bn1d(original, keep_idxs):
    """Build a BatchNorm1d holding only channels ``keep_idxs`` of ``original``.

    The affine parameters, running statistics and ``num_batches_tracked`` are copied.
    eps, momentum, affine and track_running_stats are preserved, so the new layer normalizes
    the kept channels exactly as the original did. It is placed on the original's device.

    Raises:
        ValueError: if ``keep_idxs`` is empty.
        IndexError: if an index is >= ``original.num_features``. The previous version
            returned a freshly initialised BN in that case, silently discarding the stats.
    """
    idx = torch.as_tensor(keep_idxs, dtype=torch.long)
    if idx.numel() == 0:
        raise ValueError("keep_idxs must not be empty")
    if int(idx.max()) >= original.num_features:
        raise IndexError(
            f"keep_idxs max {int(idx.max())} out of range for BatchNorm1d({original.num_features})"
        )
    new_bn = nn.BatchNorm1d(
        idx.numel(),
        eps=original.eps,
        momentum=original.momentum,
        affine=original.affine,
        track_running_stats=original.track_running_stats,
    )
    ref = original.weight if original.weight is not None else original.running_mean
    if ref is not None:
        new_bn = new_bn.to(device=ref.device)
        idx = idx.to(ref.device)
    with torch.no_grad():
        if original.affine:  # weight/bias are None when affine=False
            new_bn.weight.copy_(original.weight[idx])
            new_bn.bias.copy_(original.bias[idx])
        if original.track_running_stats:
            new_bn.running_mean.copy_(original.running_mean[idx])
            new_bn.running_var.copy_(original.running_var[idx])
            new_bn.num_batches_tracked.copy_(original.num_batches_tracked)
    return new_bn

# === APPLY PRUNING ===
def apply_k_rnf_pruning(model, importance_scores, keep_map, k=3):
    """Return a pruned deep copy of ``model`` with the output filters of dense Conv1d layers sliced.

    For every ``nn.Conv1d`` with ``groups == 1`` whose name does not start with ``"stem.0"``,
    the layer in the copy is replaced by ``build_pruned_conv1d(module, keep_out, keep_in)``
    with ``keep_out = keep_map[name]``. Layers missing from ``keep_map`` are left unchanged and
    a warning is printed. If the conv's parent has an attribute named like the conv with
    "conv" replaced by "bn", and it is a BatchNorm1d, that BN is sliced to ``keep_out`` too.

    Args:
        model: source model; not modified (all changes are made on ``deepcopy(model)``).
        importance_scores: currently unused.
        keep_map: dict from Conv1d qualified name to the list of output-filter indices to keep.
        k: currently unused.

    Returns:
        The pruned copy.

    NOTE(review): ``prev_kept`` is written under the *current* layer's name and read back
    under that same name before it is written, so ``keep_in`` is always None and input
    channels are never pruned. Downstream layers therefore keep their original in_channels
    while upstream outputs shrink, which will likely break the forward pass — confirm. A real
    fix needs the cell's data-flow graph, not module iteration order.
    NOTE(review): inside nn.Sequential the leaf name is an index (e.g. "0"), so
    ``replace("conv", "bn")`` finds no BN there and those BatchNorm layers keep full width.
    """
    new_model = deepcopy(model)
    prev_kept = {}  # track previous layer's kept filters

    # Iterate the untouched original; write replacements into the copy via the same dotted path.
    for name, module in model.named_modules():
        # Only dense convs are handled; "stem.0" (presumably the input stem) is skipped.
        if isinstance(module, nn.Conv1d) and module.groups == 1 and not name.startswith("stem.0"):
            keep_out = keep_map.get(name, None)
            if keep_out is None:
                print(f"[WARN] No filters kept in {name}, skipping")
                continue

            # NOTE(review): always None; see the docstring.
            keep_in = prev_kept.get(name, None)
            # Resolve the parent of this conv inside the copy.
            parent = new_model
            subnames = name.split(".")
            for s in subnames[:-1]:
                parent = getattr(parent, s)

            new_conv = build_pruned_conv1d(module, keep_out, keep_in)
            setattr(parent, subnames[-1], new_conv)

            # Adjust BatchNorm
            # Sibling BN is found only when the conv attribute name contains "conv" (e.g. conv1 -> bn1).
            bn_name = subnames[-1].replace("conv", "bn")
            if hasattr(parent, bn_name):
                bn_module = getattr(parent, bn_name)
                if isinstance(bn_module, nn.BatchNorm1d):
                    try:
                        new_bn = build_pruned_bn1d(bn_module, keep_out)
                        setattr(parent, bn_name, new_bn)
                    except Exception as e:
                        print(f"[!] Could not adjust BN: {bn_name} | {e}")

            prev_kept[name] = keep_out
            print(f"[✓] Pruned {name} → out={len(keep_out)}" + (f", in={len(keep_in)}" if keep_in else ""))

    return new_model

# === TRAINING ===
def train_model(model, train_loader, val_loader, device, epochs=10, lr=1e-3):
    """Train with Adam/cross-entropy for ``epochs`` epochs, printing train loss/acc and val acc each epoch.

    Batches are moved to ``device`` and the trailing input dim is squeezed. The reported train
    loss is the sample-weighted epoch mean; accuracies are percentages. ``model`` is moved to
    ``device`` in place; nothing is returned.
    """
    model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch_idx in range(epochs):
        model.train()
        n_seen, n_right, weighted_loss = 0, 0, 0
        for xb, yb in train_loader:
            xb = xb.to(device).squeeze(-1)
            yb = yb.to(device)
            opt.zero_grad()
            scores = model(xb)
            batch_loss = loss_fn(scores, yb)
            batch_loss.backward()
            opt.step()

            weighted_loss += batch_loss.item() * xb.size(0)
            n_right += (scores.argmax(dim=1) == yb).sum().item()
            n_seen += yb.size(0)

        train_pct = 100. * n_right / n_seen
        print(f"[Epoch {epoch_idx+1}] Train Loss: {weighted_loss/n_seen:.4f} | Train Acc: {train_pct:.2f}%")

        # Eval
        model.eval()
        n_seen, n_right = 0, 0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(device).squeeze(-1)
                yb = yb.to(device)
                n_right += (model(xb).argmax(dim=1) == yb).sum().item()
                n_seen += yb.size(0)
        val_pct = 100. * n_right / n_seen
        print(f"[Eval] Epoch {epoch_idx+1} | Val Acc: {val_pct:.2f}%")

# === EVALUATION ===
def evaluate_model(model, test_loader, device, fold=None):
    """Evaluate ``model`` on ``test_loader``, print the accuracy in percent, and return it.

    Args:
        model: classifier. It is moved to ``device`` in place (as train_model does), so a CPU
            model can be evaluated on a CUDA ``device`` without a device-mismatch error.
        test_loader: yields ``(x, y)``; ``x`` has its trailing dim squeezed before the forward pass.
        device: target device.
        fold: optional fold number. When given, it is shown in the printed line; when None,
            the output is unchanged ("[✓] Final Test Accuracy: ...").

    Returns:
        Accuracy in percent (float).
    """
    model.to(device)
    model.eval()
    total, correct = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device).squeeze(-1), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    acc = 100. * correct / total
    label = "Final Test Accuracy" if fold is None else f"Fold {fold} Test Accuracy"
    print(f"[✓] {label}: {acc:.2f}%")
    return acc

In [None]:
# === Build the (untrained) model ===
model = FinalNetwork(C=8, num_classes=5, layers=5, genotype=genotype).to(device)
count_params(model)

# === Score filters by L1 norm, keep the global top 40%, prune a copy, then retrain it ===
importance = compute_importance_scores(model)
ranked, mapping, _ = get_global_ranking(importance_scores=importance)
keep_map = select_top_filters(ranked_indices=ranked, mapping=mapping, prune_rate=0.6)
pruned_model = apply_k_rnf_pruning(model=model, importance_scores=importance, keep_map=keep_map, k=3)

print("=== [AFTER PRUNING] ===")
count_params(pruned_model)

train_model(model=pruned_model, train_loader=train_loader, val_loader=test_loader, device=device, epochs=10)
evaluate_model(model=pruned_model, test_loader=test_loader, device=device)


In [None]:
import torch
import matplotlib.pyplot as plt
import os


def _save_train_val_curve(train_values, val_values, train_label, val_label, ylabel, title, path):
    """Plot one train/val metric pair against the epoch index, save it to ``path``, and close the figure."""
    fig, ax = plt.subplots()
    ax.plot(train_values, label=train_label)
    ax.plot(val_values, label=val_label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    fig.savefig(path)
    plt.close(fig)


def train_final_model(model, train_loader, val_loader, device, fold=1, epochs=30):
    """Train ``model`` with SGD + cosine LR, checkpoint the best epoch, and save learning curves.

    Each epoch runs one pass over ``train_loader`` and one evaluation pass over ``val_loader``.
    Inputs are moved to ``device`` and their trailing dim is squeezed. The state_dict of the
    epoch with the highest validation accuracy is written to ``best_final_model_fold{fold}.pt``.
    Loss and accuracy curves are saved as PNGs under ``logs/``.

    NOTE: this redefines the earlier ``train_final_model`` (which has a different lr and
    signature); whichever cell ran last wins.

    Args:
        model: network already placed on ``device``.
        train_loader, val_loader: loaders yielding ``(x, y)`` batches.
        device: device the batches are moved to.
        fold: fold id used in the checkpoint and plot file names.
        epochs: number of epochs (also the cosine schedule's ``T_max``).

    Returns:
        dict with per-epoch lists ``train_loss``, ``val_loss``, ``train_acc``, ``val_acc`` and
        the scalar ``best_acc``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=3e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Start below any reachable accuracy so the first epoch always writes a checkpoint. With 0,
    # a run whose val accuracy never exceeds 0 saved nothing, and the caller's torch.load of
    # best_final_model_fold{fold}.pt failed or picked up a stale file from an earlier run.
    best_acc = -1.0
    train_loss_list, val_loss_list, train_acc_list, val_acc_list = [], [], [], []

    for epoch in range(epochs):
        model.train()
        total_loss, correct, total = 0, 0, 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            x = x.squeeze(-1)
            logits = model(x)
            loss = criterion(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += y.size(0)

        train_loss = total_loss / len(train_loader)
        train_acc = correct / total
        train_loss_list.append(train_loss)
        train_acc_list.append(train_acc)

        # --- Validation ---
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device).squeeze(-1), y.to(device)
                logits = model(x)
                loss = criterion(logits, y)
                val_loss += loss.item()
                val_correct += (logits.argmax(dim=1) == y).sum().item()
                val_total += y.size(0)

        val_loss = val_loss / len(val_loader)
        val_acc = val_correct / val_total
        val_loss_list.append(val_loss)
        val_acc_list.append(val_acc)

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), f"best_final_model_fold{fold}.pt")

        scheduler.step()

        print(f"[Final Train Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    # --- Save plots ---
    os.makedirs("logs", exist_ok=True)
    _save_train_val_curve(train_loss_list, val_loss_list, 'Train Loss', 'Val Loss', 'Loss',
                          f'Fold {fold} - Loss Curve', f'logs/final_loss_fold{fold}.png')
    _save_train_val_curve(train_acc_list, val_acc_list, 'Train Acc', 'Val Acc', 'Accuracy',
                          f'Fold {fold} - Accuracy Curve', f'logs/final_accuracy_fold{fold}.png')

    return {
        "train_loss": train_loss_list,
        "val_loss": val_loss_list,
        "train_acc": train_acc_list,
        "val_acc": val_acc_list,
        "best_acc": best_acc,
    }


In [None]:
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, TensorDataset
import torch

def _predict_with_loss(model, loader, criterion, device):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Returns:
        (mean per-batch loss, list of true labels, list of predicted labels).
    """
    model.eval()
    loss_total = 0.0
    y_true, y_pred = [], []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device).squeeze(-1), y.to(device)
            logits = model(x)
            loss_total += criterion(logits, y).item()
            y_pred.extend(logits.argmax(dim=1).cpu().numpy())
            y_true.extend(y.cpu().numpy())
    return loss_total / len(loader), y_true, y_pred


kf = KFold(n_splits=5, shuffle=True, random_state=42)
all_y_true, all_y_pred = [], []
criterion = torch.nn.CrossEntropyLoss()

for fold, (train_idx, val_idx) in enumerate(kf.split(X_all)):
    print(f"\n[INFO] Fold {fold+1}/5")

    fold_train = DataLoader(TensorDataset(X_all[train_idx], y_all[train_idx]), batch_size=64, shuffle=True)
    fold_val   = DataLoader(TensorDataset(X_all[val_idx],   y_all[val_idx]),   batch_size=64)

    # ✅ Fresh model per fold
    model = FinalNetwork(C=16, num_classes=num_classes, layers=7, genotype=searched_genotype).to(device)

    # ✅ Train; train_final_model checkpoints the best-val-accuracy epoch of this fold
    train_final_model(model, fold_train, fold_val, device=device, fold=fold+1, epochs=30)

    # ✅ Reload that checkpoint onto `device` explicitly
    model.load_state_dict(torch.load(f"best_final_model_fold{fold+1}.pt", map_location=device))

    # ✅ Print the fold's accuracy. Only (model, loader, device) is passed: that is the core
    # signature shared by every evaluate_model variant in this notebook (fold= raised TypeError).
    evaluate_model(model, fold_val, device)

    # ✅ Collect loss/accuracy and predictions for the final printout and confusion matrix
    val_loss, y_true_val, y_pred_val = _predict_with_loss(model, fold_val, criterion, device)
    val_acc = accuracy_score(y_true_val, y_pred_val)

    # Train-set metrics of the best checkpoint (optional sanity check)
    train_loss, y_true_train, y_pred_train = _predict_with_loss(model, fold_train, criterion, device)
    train_acc = accuracy_score(y_true_train, y_pred_train)

    print(f"[Fold {fold+1}] ✅ Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} || Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    all_y_pred.extend(y_pred_val)
    all_y_true.extend(y_true_val)


In [None]:
# 8. Aggregated confusion matrix over all folds
from sklearn.metrics import ConfusionMatrixDisplay

print("\n[INFO] Final Confusion Matrix (5-Fold CV on pruned data):")
cm = confusion_matrix(all_y_true, all_y_pred)
print(cm)

# 9. Heatmap + save.
# No visible active cell imports seaborn (only the commented-out cell does), so sns.heatmap
# raised NameError. sklearn's ConfusionMatrixDisplay draws the same annotated "Blues"
# heatmap using libraries this notebook already imports.
fig, ax = plt.subplots()
ConfusionMatrixDisplay(confusion_matrix=cm).plot(ax=ax, cmap="Blues", values_format="d")
ax.set_title("Confusion Matrix (Full Pruned Data - 5-Fold CV)")
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
fig.savefig("final_confusion_matrix.png")
plt.show()

# 10. Save CSV
pd.DataFrame(cm).to_csv("confusion_matrix.csv", index=False)
print("[✓] Confusion matrix saved to final_confusion_matrix.png and confusion_matrix.csv")
print("\n[✓] Training complete.")

In [None]:
# from sklearn.model_selection import KFold
# from torch.utils.data import TensorDataset, DataLoader
# import seaborn as sns
# if __name__ == '__main__':

#     set_seed(42)
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'

#     print("[INFO] Loading 80/20 split data...")
#     train_loader, val_loader, num_classes = get_dataloaders_simple(batch_size=16)

#     print(f"[INFO] DARTS will run on {len(train_loader.dataset.y)} train samples and {len(val_loader.dataset.y)} val samples")

#     print("[INFO] Running DARTS search with BDP...")
#     searched_genotype, pruned_train_loader, pruned_val_loader = train_darts_search_bdp(
#         train_loader, val_loader, num_classes,
#         epochs=5, prune_every=5, pt=0.15, pv=0.05,
#         device=device
#     )

#     print("[INFO] Visualizing searched cells...")
#     plot_cell(searched_genotype, 'normal')
#     plot_cell(searched_genotype, 'reduce')

#     print("[INFO] Running 5-Fold Cross Validation on pruned data...")

#     # ✅ Gộp lại dữ liệu sau pruning
#     X_all = torch.cat([pruned_train_loader.dataset.X, pruned_val_loader.dataset.X], dim=0)
#     y_all = torch.cat([pruned_train_loader.dataset.y, pruned_val_loader.dataset.y], dim=0)
#     dataset = TensorDataset(X_all, y_all)

#     # ✅ K-fold cross-validation
#     kf = KFold(n_splits=5, shuffle=True, random_state=42)
#     all_y_true, all_y_pred = [], []

#     for fold, (train_idx, val_idx) in enumerate(kf.split(X_all)):
#         print(f"\n[INFO] Fold {fold+1}/5")
#         DataLoader(TensorDataset(X_all[train_idx], y_all[train_idx]), batch_size=64, shuffle=True)
#         fold_val   = DataLoader(TensorDataset(X_all[val_idx],   y_all[val_idx]),   batch_size=64)

#         fold_train = model = FinalNetwork(C=16, num_classes=num_classes, layers=9, genotype=searched_genotype).to(device)
#         train_final_model(model, fold_train, fold_val, device=device, epochs=100)

#         model.load_state_dict(torch.load("best_final_model1.pt"))

#         # ✅ Evaluation
#         model.eval()
#         with torch.no_grad():
#             for x, y in fold_val:
#                 x = x.to(device).squeeze(-1)
#                 logits = model(x)
#                 preds = logits.argmax(dim=1).cpu().numpy()
#                 all_y_pred.extend(preds)
#                 all_y_true.extend(y.numpy())

#     # ✅ Tổng hợp confusion matrix
#     print("\n[INFO] Final Confusion Matrix (5-Fold CV on pruned data):")
#     from sklearn.metrics import confusion_matrix
#     import matplotlib.pyplot as plt
#     import pandas as pd

#     cm = confusion_matrix(all_y_true, all_y_pred)
#     print(cm)

#     sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
#     plt.title("Confusion Matrix (Full Pruned Data - 5-Fold CV)")
#     plt.xlabel("Predicted")
#     plt.ylabel("True")
#     plt.savefig("final_confusion_matrix.png")
#     plt.show()

#     pd.DataFrame(cm).to_csv("confusion_matrix.csv", index=False)
#     print("[✓] Confusion matrix saved to final_confusion_matrix.png and confusion_matrix.csv")
#     print("\n[✓] Training complete.")
