# In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.loader import DataLoader as GeoDataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
import math
import warnings
warnings.filterwarnings("ignore")

# Pick the compute device once: the GPU when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# ---------------------------------------------------------------------------
# Fixed training hyperparameters
# ---------------------------------------------------------------------------
LEARNING_RATE = 1e-4   # Adam learning rate used during fine-tuning
EMB_DIM = 512          # embedding width: encoder hidden_dim and fusion input size
PATIENCE_ES = 20       # epochs without val improvement before early stopping
PATIENCE_LR = 10       # epochs without improvement before ReduceLROnPlateau fires
LR_FACTOR = 0.5        # multiplicative LR decay applied by the scheduler

# User modules
from model.featurisation import smiles2graph
from model.CL_model_vas_info import GNNModelWithNewLoss
from model.fusion import TransformerFusionModel, WeightedFusion, MLP, FusionFineTuneModel

# Data loading
def load_data_for_visualization(name, batch_size=32, val_split=0.1, test_split=0.2, seed=42):
    """Load ``data/<name>.csv`` and split it into train/val/test graph loaders.

    The CSV is read with a ``smiles`` column and a label column named after
    the dataset itself (``name``). Molecules are converted to graphs with
    ``smiles2graph``. ``test_split`` of the graphs is held out first; then
    ``val_split`` (expressed as a fraction of the *whole* dataset) is taken
    from the remainder for validation.

    Args:
        name: dataset name; selects the CSV file and the label column.
        batch_size: batch size for all three loaders.
        val_split: validation fraction of the full dataset.
        test_split: test fraction of the full dataset.
        seed: ``random_state`` for both ``train_test_split`` calls.

    Returns:
        (train_loader, val_loader, test_loader, test_smiles_list). The test
        loader is not shuffled and keeps every sample, so ``test_smiles_list``
        lines up one-to-one with the test predictions.
    """
    df = pd.read_csv(f'data/{name}.csv')
    smiles_list = df['smiles'].tolist()
    labels = df[name].tolist()
    data_list = smiles2graph(smiles_list, labels)

    train_val, test_data = train_test_split(data_list, test_size=test_split, random_state=seed)
    # Rescale so that val_split is a fraction of the full dataset, not of train_val.
    train_data, val_data = train_test_split(
        train_val, test_size=val_split / (1 - test_split), random_state=seed
    )

    # drop_last is kept only for training (presumably to avoid a tiny trailing
    # optimisation batch). Evaluation loaders must see every sample; dropping
    # the last partial batch would compute RMSE on a subset and misalign the
    # test predictions with test_smiles_list.
    train_loader = GeoDataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = GeoDataLoader(val_data, batch_size=batch_size, shuffle=False, drop_last=False)
    test_loader = GeoDataLoader(test_data, batch_size=batch_size, shuffle=False, drop_last=False)

    test_smiles_list = [data.smiles for data in test_data]

    return train_loader, val_loader, test_loader, test_smiles_list

# Load pre-trained encoders
def load_pretrained_encoders(sample, encoders_to_use=(0, 1, 2)):
    """Build one ``GNNModelWithNewLoss`` per index and load its pretrained weights.

    Input sizes are read from ``sample`` (a single graph):
    ``sample.x.shape[1]`` node features, ``sample.edge_attr.shape[1]`` edge
    features and ``sample.global_features.shape[0]`` global features. Encoder
    ``i`` gets the ``'encoder_state_dict'`` entry of
    ``premodels/<i>/best_model.pth``.

    Args:
        sample: one graph used only to infer feature dimensions.
        encoders_to_use: iterable of checkpoint indices. The default is a tuple
            rather than a list to avoid a shared mutable default argument.

    Returns:
        List of encoders on ``device``, in the order of ``encoders_to_use``.
    """
    encoders = []
    for i in encoders_to_use:
        enc = GNNModelWithNewLoss(
            num_node_features=sample.x.shape[1],
            num_edge_features=sample.edge_attr.shape[1],
            num_global_features=sample.global_features.shape[0],
            hidden_dim=EMB_DIM
        ).to(device)
        ckpt = torch.load(f'premodels/{i}/best_model.pth', map_location=device)
        enc.load_state_dict(ckpt['encoder_state_dict'])
        encoders.append(enc)
    return encoders

# Build fusion model based on encoder selection
def get_finetune_model(fusion_method, sample, dropout, encoders_to_use=(0, 1, 2)):
    """Assemble a ``FusionFineTuneModel`` from pretrained encoders and a fusion head.

    Args:
        fusion_method: ``'attention'`` (TransformerFusionModel),
            ``'weighted'`` (WeightedFusion) or ``'concat'`` (MLP over the
            concatenated embeddings).
        sample: one graph used to infer encoder input sizes.
        dropout: passed to ``WeightedFusion`` only; the other heads do not
            receive it.
        encoders_to_use: checkpoint indices of the encoders to load. The
            default is a tuple to avoid a shared mutable default argument.

    Returns:
        The fine-tuning model on ``device``.

    Raises:
        ValueError: for an unknown ``fusion_method``. This is checked before
            any checkpoint is loaded so a typo fails fast.
    """
    if fusion_method not in ('attention', 'weighted', 'concat'):
        raise ValueError(f'Unknown fusion method {fusion_method}')

    encoders = load_pretrained_encoders(sample, encoders_to_use)
    if fusion_method == 'attention':
        fusion = TransformerFusionModel(emb_dim=EMB_DIM).to(device)
    elif fusion_method == 'weighted':
        fusion = WeightedFusion(num_inputs=len(encoders), emb_dim=EMB_DIM, dropout=dropout).to(device)
    else:  # 'concat': the MLP sees all encoder embeddings side by side
        fusion = MLP(emb_dim=EMB_DIM * len(encoders)).to(device)
    return FusionFineTuneModel(encoders, fusion, fusion_method).to(device)

# Training routine
def _fused_predictions(model, batch):
    """Return ``model``'s fused predictions for ``batch`` as a 1-D tensor."""
    # assumes each encoder returns [B, D] -> stacked shape [B, N_encoders, D]
    embs = torch.stack([encoder(batch) for encoder in model.encoders], dim=1)
    if model.fusion_method == 'concat':
        embs = embs.view(embs.size(0), -1)  # flatten for the MLP: [B, N_encoders * D]
    out = model.fusion(embs)
    # Some fusion heads return a tuple; the prediction is its first element.
    pred = out[0] if isinstance(out, tuple) else out
    return pred.view(-1)


def _evaluate_rmse(model, loader, criterion):
    """Return the RMSE of ``model`` over every batch of ``loader`` (eval mode, no grad).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    preds, labs = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            preds.append(_fused_predictions(model, batch).cpu())
            labs.append(batch.y.view(-1).float().cpu())
    if not preds:
        raise ValueError("validation loader yielded no batches; cannot compute RMSE")
    return criterion(torch.cat(preds), torch.cat(labs)).sqrt().item()


def train_and_validate(model, train_loader, val_loader, epochs, verbose=False):
    """Fine-tune ``model`` with Adam, ReduceLROnPlateau and early stopping.

    Each epoch minimises the per-batch RMSE on ``train_loader``, then computes
    the RMSE over the whole of ``val_loader``. The scheduler multiplies the
    learning rate by ``LR_FACTOR`` (patience ``PATIENCE_LR``). Training stops
    after ``PATIENCE_ES`` consecutive epochs without a validation improvement
    larger than 1e-6. On return the model holds the weights of the best
    validation epoch.

    Args:
        model: fusion model exposing ``encoders``, ``fusion`` and ``fusion_method``.
        train_loader: training graph loader; batches carry ``y`` labels.
        val_loader: validation graph loader.
        epochs: maximum number of epochs.
        verbose: if True, print the train loss and validation RMSE every epoch.

    Returns:
        (train_losses, val_rmses): per-epoch mean of the per-batch training
        RMSEs, and per-epoch validation RMSE.

    Raises:
        ValueError: if either loader yields no batches.
    """
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.MSELoss()
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=LR_FACTOR, patience=PATIENCE_LR)

    train_losses, val_rmses = [], []
    best_val_rmse = float('inf')
    best_state = None
    patience_cnt = 0

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        n_batches = 0
        for batch in train_loader:
            batch = batch.to(device)
            pred = _fused_predictions(model, batch)
            label = batch.y.view(-1).float()

            loss = criterion(pred, label).sqrt()  # per-batch RMSE as the objective
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1

        if n_batches == 0:
            raise ValueError("training loader yielded no batches; check dataset size vs batch_size")
        avg_loss = total_loss / n_batches
        train_losses.append(avg_loss)

        rmse = _evaluate_rmse(model, val_loader, criterion)
        val_rmses.append(rmse)
        scheduler.step(rmse)

        # Early stopping
        if rmse < best_val_rmse - 1e-6:
            best_val_rmse = rmse
            # state_dict() returns references to the live parameter tensors, so a
            # plain dict copy would be overwritten by later optimizer steps.
            # Clone every tensor to keep a real snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            patience_cnt = 0
        else:
            patience_cnt += 1

        if verbose:
            print(f"[Epoch {epoch:03d}] Train Loss={avg_loss:.4f}, Val RMSE={rmse:.4f}")

        if patience_cnt >= PATIENCE_ES:
            break

    # best_state stays None when no epoch ran (epochs=0) or no RMSE ever improved (e.g. NaN).
    if best_state is not None:
        model.load_state_dict(best_state)
    return train_losses, val_rmses

# Final test evaluation
def test_model(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and return its RMSE with the raw outputs.

    Predictions are produced the same way as in training. Every encoder in
    ``model.encoders`` is applied, the embeddings are stacked (and flattened
    for ``'concat'`` fusion) and passed through ``model.fusion``.

    Returns:
        (rmse, preds, labels): the RMSE as a float, plus 1-D numpy arrays of
        predictions and labels in loader order.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    criterion = nn.MSELoss()
    model.eval()
    preds, labs = [], []
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            embs = torch.stack([encoder(batch) for encoder in model.encoders], dim=1)

            if model.fusion_method == 'concat':
                embs = embs.view(embs.size(0), -1)

            out = model.fusion(embs)
            pred = out[0] if isinstance(out, tuple) else out
            # Flatten to 1-D like the labels (as the training loop does). If the
            # fusion head returns [B, 1], comparing it with [B] labels would
            # broadcast to [B, B] inside MSELoss and give a meaningless RMSE.
            # The broadcast warning is silenced by the module-level
            # warnings filter.
            preds.append(pred.view(-1).cpu())
            labs.append(batch.y.view(-1).float().cpu())

    if not preds:
        raise ValueError("test loader yielded no batches; cannot compute RMSE")
    preds = torch.cat(preds)
    labs = torch.cat(labs)
    rmse = criterion(preds, labs).sqrt().item()

    return rmse, preds.numpy(), labs.numpy()

# Direct usage of fixed parameters (no hyperparameter search)
def run_multiple(fusion_method, name, params, encoders_to_use, runs=10):
    """Repeat fine-tuning ``runs`` times on differently seeded splits and aggregate RMSEs.

    Run ``i`` uses ``seed = params.get('seed', 42) + i``. This seed drives both
    the train/val/test split and ``torch.manual_seed``, which covers model
    initialisation and loader shuffling. Previously ``params['seed']`` was
    ignored and torch's RNG was never seeded.

    Args:
        fusion_method: fusion head name passed to ``get_finetune_model``.
        name: dataset name passed to ``load_data_for_visualization``.
        params: dict providing 'batch_size', 'val_split', 'test_split',
            'dropout', 'epochs' and optionally 'seed' (default 42).
        encoders_to_use: encoder checkpoint indices.
        runs: number of repetitions (must be >= 1).

    Returns:
        A dict with:
          'mean_rmse', 'var_rmse': mean and population variance (``np.var``)
              of each run's final-epoch validation RMSE;
          'test_rmses', 'mean_test_rmse', 'var_test_rmse': per-run test RMSE
              and its mean / population variance;
          'all_histories': per-run ``(train_losses, val_rmses)``;
          'all_preds', 'all_labels': per-run test predictions and labels.

    Raises:
        ValueError: if ``runs < 1``.
    """
    if runs < 1:
        raise ValueError(f"runs must be >= 1, got {runs}")
    base_seed = params.get('seed', 42)

    final_val_rmses, test_rmses, histories = [], [], []
    all_preds, all_labels = [], []
    for i in range(runs):
        seed = base_seed + i
        torch.manual_seed(seed)

        tr, vl, te, _ = load_data_for_visualization(
            name, batch_size=params['batch_size'], val_split=params['val_split'],
            test_split=params['test_split'], seed=seed)

        sample = tr.dataset[0]  # a single training graph, used to size the encoders
        model = get_finetune_model(fusion_method, sample, dropout=params['dropout'],
                                   encoders_to_use=encoders_to_use)

        tr_losses, val_rmses = train_and_validate(model, tr, vl, epochs=params['epochs'])
        final_val_rmses.append(val_rmses[-1])
        histories.append((tr_losses, val_rmses))

        rmse_test, preds, labs = test_model(model, te)
        test_rmses.append(rmse_test)
        all_preds.append(preds)
        all_labels.append(labs)

    return {
        'mean_rmse': np.mean(final_val_rmses),
        'var_rmse': np.var(final_val_rmses),
        'test_rmses': test_rmses,
        'mean_test_rmse': np.mean(test_rmses),
        'var_test_rmse': np.var(test_rmses),
        'all_histories': histories,
        'all_preds': all_preds,
        'all_labels': all_labels,
    }

def run_ablations(best_params, runs=3):
    """Run the encoder-subset ablation and write a single summary CSV.

    For every dataset and fusion method, ``run_multiple`` is called once per
    encoder combination with ``runs`` repetitions. Each summary cell holds
    ``"<mean_rmse> ± <var_rmse>"`` (4 decimals) from that call. The table has
    one row per (fusion method, dataset) pair. It is written to
    ``abl_result/ablation_summary.csv``, and the directory is created if
    missing.

    Args:
        best_params: hyperparameter dict forwarded unchanged to ``run_multiple``.
        runs: number of repetitions per configuration.

    Returns:
        The summary ``pandas.DataFrame`` that was written to disk.
    """
    datasets = ['freesolv']
    fusion_methods = ['concat', 'weighted', 'attention']

    # All encoder subsets to evaluate
    encoder_combinations = [
        [0, 1, 2],  # All encoders
        [0],        # Only encoder 0
        [1],        # Only encoder 1
        [2],        # Only encoder 2
        [0, 1],     # Encoders 0 and 1
        [0, 2],     # Encoders 0 and 2
        [1, 2],     # Encoders 1 and 2
    ]

    def combo_tag(encoders):
        """Return a column-friendly tag, e.g. '0-1-2' for [0, 1, 2]."""
        return '-'.join(map(str, encoders))

    # Results are keyed by (dataset, fusion_method) so each summary row gets its
    # own fusion method's numbers. Indexing by dataset position alone would give
    # every fusion-method row the first ('concat') results.
    cells_by_config = {}
    for ds in datasets:
        for fusion_method in fusion_methods:
            print(f"Running ablation for dataset: {ds}, fusion method: {fusion_method}")
            cells = []
            for encoders_to_use in encoder_combinations:
                print(f"Running with encoders: {encoders_to_use}")
                res = run_multiple(fusion_method, ds, best_params,
                                   encoders_to_use=encoders_to_use, runs=runs)
                cells.append(f"{res['mean_rmse']:.4f} ± {res['var_rmse']:.4f}")
            cells_by_config[(ds, fusion_method)] = cells

    summary_data = [
        [ds, fusion_method] + cells_by_config[(ds, fusion_method)]
        for fusion_method in fusion_methods
        for ds in datasets
    ]
    columns = ["Dataset", "Fusion Method"] + [
        f"Encoders_{combo_tag(encoders)}" for encoders in encoder_combinations
    ]
    summary_df = pd.DataFrame(summary_data, columns=columns)

    # One file holds every fusion method, so the name must not be taken from a
    # leaked loop variable (which always produced 'attention_...').
    out_dir = "abl_result"
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, "ablation_summary.csv")
    summary_df.to_csv(out_path, index=False)
    print(f"Results saved to {out_path}")
    return summary_df


# Main execution
if __name__ == '__main__':
    # One fixed configuration shared by every ablation run; no hyperparameter search.
    best_params = dict(
        batch_size=16,     # mini-batch size for all loaders
        dropout=0.1,       # dropout for the fusion head
        epochs=100,        # upper bound on training epochs (early stopping may end sooner)
        val_split=0.1,     # validation fraction of the full dataset
        test_split=0.2,    # test fraction of the full dataset
        seed=42,           # base random seed
        weight_decay=0,    # NOTE(review): no code in this file reads this key — confirm intent
    )

    # Launch the encoder/fusion ablation with three repetitions per configuration.
    run_ablations(best_params, runs=3)
