# In [None]:
# File: train.py
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from thop import profile
import pickle
import numpy as np
from tqdm import tqdm
from net import *

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Experiment registry: maps an experiment id (also used as the results file
# stem) to its settings.
#   "type"      -> which model class train_experiment builds
#   "trainable" -> forwarded to BaseRNN (also decides whether W_init is passed)
#   "pruning"   -> None, "drosophila" or "hungarian"
#   "optim"     -> "adam"; any other value selects FISTAOptimizer
# Entries left commented out are disabled; uncomment one to enable it.
experiments = {
    "FixedInit_NoPruning": {
        "type": "basernn",
        "trainable": False,
        "pruning": None,
        "optim": "adam",
    },
    "Learnable_RandomInit": {
        "type": "basernn",
        "trainable": True,
        "pruning": None,
        "optim": "adam",
    },
    # "CWS_NoPruning": {"type": "cwsrnn",  "train_W": True, "pruning": None, "optim": "adam"},
    # "CWS_TrainC_Pruning": {"type": "cwsrnn", "train_W": True, "train_C": True, "pruning": "drosophila", "optim": "adam"},
    # "CWS_FixedC_Pruning": {"type": "cwsrnn", "train_W": True, "train_C": False, "pruning": "drosophila", "optim": "adam"},
    # "L1_FISTA": {"type": "basernn", "trainable": True,  "pruning": None, "optim": "fista"},
    # "FixedInit_L1": {"type": "basernn", "trainable": False, "pruning": None, "optim": "fista"},
    # "Learnable_Hungarian": {"type": "basernn", "trainable": True,  "pruning": "hungarian", "optim": "adam"},
    # "CNN_RNN": {"type": "cnnrnn",  "trainable": False, "pruning": None, "optim": "adam"}
}

def get_input_shape(model_type):
    """Return the dummy-input shape used when profiling ``model_type``.

    ``"cnnrnn"`` maps to the 4-D shape ``(1, 1, 28, 28)``; every other model
    type maps to the 3-D shape ``(1, 28, 28)``.
    """
    if model_type == "cnnrnn":
        return (1, 1, 28, 28)
    return (1, 28, 28)

def prepare_input(data, model):
    """Shape a batch so it matches what ``model`` is fed in this script.

    ``CNNRNN`` instances receive ``data`` untouched. For any other model,
    dimension 1 is squeezed out (``squeeze(1)`` only removes it when its size
    is 1).
    """
    if isinstance(model, CNNRNN):
        return data
    return data.squeeze(1)

def train_epoch(model, optimizer, criterion, train_loader, flops_per_sample, cumulative_flops):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Network to optimise; switched to training mode here.
        optimizer: Optimizer whose gradients are zeroed and which is stepped
            once per batch.
        criterion: Loss callable invoked as ``criterion(output, target)``.
        train_loader: Iterable yielding ``(data, target)`` batches.
        flops_per_sample: Estimated forward-pass FLOPs for one sample.
        cumulative_flops: Running FLOP total carried in from earlier epochs.

    Returns:
        ``(mean_loss, accuracy, flops_acc_pairs, cumulative_flops)``.
        ``flops_acc_pairs`` holds one ``(cumulative_flops, running_accuracy)``
        snapshot per 100 batches. When the loader yields no samples, the loss
        and accuracy are both ``0.0`` instead of raising
        ``ZeroDivisionError``.
    """
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    flops_acc_pairs = []

    with tqdm(train_loader, unit="batch", desc="Training") as pbar:
        for batch_number, (data, target) in enumerate(pbar, start=1):
            data = prepare_input(data.to(device), model)
            target = target.to(device)
            batch_size = data.size(0)

            # NOTE(review): the factor 3 presumably charges the backward pass
            # as twice the forward cost (forward + backward = 3x) -- confirm.
            cumulative_flops += flops_per_sample * batch_size * 3

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            # Convert the loss tensor to a Python float once and reuse it for
            # both the running total and the progress bar.
            loss_value = loss.item()
            total_loss += loss_value * batch_size
            correct += output.argmax(dim=1).eq(target).sum().item()
            total += batch_size

            # Guard against a degenerate zero-sized batch.
            running_acc = correct / total if total else 0.0

            if batch_number % 100 == 0:
                flops_acc_pairs.append((cumulative_flops, running_acc))

            pbar.set_postfix({
                'loss': f"{loss_value:.4f}",
                'acc': f"{running_acc:.2%}",
                'FLOPs': f"{cumulative_flops/1e9:.2f}G"
            })

    if total == 0:
        # Nothing was trained on; report neutral metrics rather than dividing by zero.
        return 0.0, 0.0, flops_acc_pairs, cumulative_flops
    return total_loss / total, correct / total, flops_acc_pairs, cumulative_flops

def evaluate(model, test_loader):
    """Return the classification accuracy of ``model`` over ``test_loader``.

    The model is switched to evaluation mode and run without gradient
    tracking. A prediction is the ``argmax`` over dimension 1 of the model
    output.

    Args:
        model: Network to evaluate.
        test_loader: Iterable yielding ``(data, target)`` batches.

    Returns:
        Fraction of correctly classified samples, or ``0.0`` when the loader
        yields no samples (instead of raising ``ZeroDivisionError``).
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data = prepare_input(data.to(device), model)
            target = target.to(device)
            predictions = model(data).argmax(dim=1)
            correct += predictions.eq(target).sum().item()
            total += target.size(0)
    if total == 0:
        return 0.0
    return correct / total

def train_experiment(exp_id, config, epochs=5):
    """Train and evaluate one experiment on MNIST, then pickle its metrics.

    Args:
        exp_id: Experiment name; metrics are written to ``f"{exp_id}.pkl"``.
        config: Experiment settings. Keys read here:
            ``type`` -- ``"basernn"``, ``"cwsrnn"`` or ``"cnnrnn"``;
            ``optim`` -- ``"adam"`` selects Adam, any other value selects
            ``FISTAOptimizer``;
            ``pruning`` (optional) -- ``"drosophila"`` loads the connectivity
            matrix with ``apply_pruning=True`` and forwards the returned
            non-zero count to ``CWSRNN``;
            ``trainable`` (``basernn``), ``train_W`` and optional ``train_C``
            (``cwsrnn``).
        epochs: Number of training epochs. Defaults to 5.

    Returns:
        The metrics dict that was pickled: per-epoch train loss/accuracy,
        per-epoch test accuracy, ``(flops, accuracy)`` snapshots and the
        total FLOP estimate.

    Raises:
        ValueError: If ``config["type"]`` is not a supported model type.
    """
    model_type = config['type']
    # Fail fast with a clear message; otherwise an unknown type would only
    # surface later as an unbound ``model`` variable.
    if model_type not in ('basernn', 'cwsrnn', 'cnnrnn'):
        raise ValueError(
            f"Unknown model type {model_type!r}; expected 'basernn', 'cwsrnn' or 'cnnrnn'"
        )

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_set = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_set = datasets.MNIST('./data', train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=256, shuffle=False)

    pruning = config.get('pruning')
    matrix_path = './data/ad_connectivity_matrix.csv'
    # Only the 'drosophila' pruning mode loads the matrix differently; it also
    # returns a second value that is forwarded to CWSRNN as non_zero_count.
    non_zero = None
    if pruning == 'drosophila':
        W_droso, non_zero = load_drosophila_matrix(matrix_path, apply_pruning=True)
    else:
        W_droso = load_drosophila_matrix(matrix_path)

    hidden_size = W_droso.shape[0]
    if model_type == 'basernn':
        trainable = config['trainable']
        model = BaseRNN(
            28, hidden_size, 10,
            # The loaded matrix is only used as the initial weights when the
            # recurrent weights are not trainable.
            W_init=None if trainable else W_droso,
            trainable=trainable,
            pruning_method=pruning
        )
    elif model_type == 'cwsrnn':
        model = CWSRNN(
            28, hidden_size, 10, W_droso,
            train_W=config['train_W'],
            train_C=config.get('train_C', False),
            # non_zero stays None unless pruning == 'drosophila'.
            non_zero_count=non_zero
        )
    else:  # 'cnnrnn'
        model = CNNRNN(torch.tensor(W_droso))
    model.to(device)

    if config['optim'] == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
    else:
        optimizer = FISTAOptimizer(model.parameters(), lr=1e-3, lambda_l1=1e-5)

    # Profile a single dummy sample to estimate the per-sample forward cost.
    dummy_input = torch.randn(get_input_shape(model_type)).to(device)
    macs, _ = profile(model, inputs=(dummy_input,))
    # One multiply-accumulate is counted as two floating-point operations.
    flops_forward = macs * 2

    results = {
        "epoch_train_loss": [], "epoch_train_acc": [],
        "epoch_test_acc": [], "flops_acc": [],
        "total_flops": 0
    }
    cumulative_flops = 0
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        epoch_loss, epoch_acc, flops_pairs, cumulative_flops = train_epoch(
            model, optimizer, criterion, train_loader, flops_forward, cumulative_flops
        )
        test_acc = evaluate(model, test_loader)

        results["epoch_train_loss"].append(epoch_loss)
        results["epoch_train_acc"].append(epoch_acc)
        results["epoch_test_acc"].append(test_acc)
        results["flops_acc"].extend(flops_pairs)
        print(f"Epoch {epoch+1} | Test Acc: {test_acc:.2%}")

    results["total_flops"] = cumulative_flops
    with open(f"{exp_id}.pkl", "wb") as f:
        pickle.dump(results, f)
    return results

if __name__ == "__main__":
    # Run every enabled experiment in the order it is declared.
    for exp_id in experiments:
        print(f"\n{'='*30}\nTraining {exp_id}\n{'='*30}")
        train_experiment(exp_id, experiments[exp_id])