# Training Regression - Reaction

# Import packages

In [None]:
import os
import sys

# Make the repository root (the parent of the directory this notebook runs from)
# importable, so `chemprop` resolves to the local checkout — presumably the fork
# that provides models.MPNN_Simple and the array-based data.ReactionDataset used
# below. The entry is placed FIRST on sys.path: an appended entry loses to any
# `chemprop` already installed in site-packages.
current_path = os.getcwd()
print(current_path)

parent_path = os.path.dirname(current_path)
print(parent_path)

# Move (rather than duplicate) the entry to the front; re-running the cell is idempotent.
if parent_path in sys.path:
    sys.path.remove(parent_path)
sys.path.insert(0, parent_path)

In [None]:
import torch

import pandas as pd
from lightning import pytorch as pl
from pathlib import Path

from chemprop import data, featurizers, models, nn

# Change data inputs here

## Load data

In [None]:
import numpy as np

# Repository root: the parent of the directory the notebook is launched from.
chemprop_dir = Path.cwd().parent

# DataLoader worker processes (0 = load data in the main process).
num_workers = 20

# Column names for CSV-based loading; unused here because the data come from .npz files.
# smiles_column = 'AAM'
# target_columns = ['lograte']

In [None]:
# Raw CSV splits (kept for reference; the processed .npz files below are what is loaded).
_rdb7_csv_dir = chemprop_dir / "tests" / "data" / "regression" / "rxn" / "barriers_rdb7"
train_path = _rdb7_csv_dir / "train.csv"
val_path = _rdb7_csv_dir / "val.csv"
test_path = _rdb7_csv_dir / "test.csv"

# Pre-featurized reaction graphs, one .npz file per split.
# Same location as the relative '../chemprop/data/...' path, but anchored on chemprop_dir.
processed_dir = chemprop_dir / "chemprop" / "data" / "normal" / "barriers_rdb7"


def _load_processed_split(split):
    """Load the pre-featurized arrays for one split ("train", "val" or "test").

    Returns the tuple (node_attrs, edge_attrs, edge_indices, ys) read from
    ``barriers_rdb7_aam_{split}_processed_data.npz`` in ``processed_dir``.
    """
    npz_path = processed_dir / f"barriers_rdb7_aam_{split}_processed_data.npz"
    # allow_pickle=True is needed if the arrays were saved as object arrays
    # (presumably ragged per-reaction graphs). Unpickling can execute arbitrary
    # code, so only load files from a trusted source.
    # The `with` block closes the file handle once the arrays have been read into memory.
    with np.load(npz_path, allow_pickle=True) as npz:
        return npz["node_attrs"], npz["edge_attrs"], npz["edge_indices"], npz["ys"]


train_v, train_e, train_idx_g, train_y = _load_processed_split("train")
val_v, val_e, val_idx_g, val_y = _load_processed_split("val")
test_v, test_e, test_idx_g, test_y = _load_processed_split("test")

In [None]:
# Shape of the training node-attribute array.
print(np.shape(train_v))

In [None]:
# Quick size check: train edge-index array vs. validation/test target arrays.
print(*(arr.shape for arr in (train_idx_g, val_y, test_y)))

## Training, validation, and test splits (already split on disk; one processed file per split is loaded above)

## Get ReactionDatasets

In [None]:
# Wrap each split's pre-featurized arrays (node attrs, edge attrs, edge indices, targets)
# in a ReactionDataset.
train_dset = data.ReactionDataset(train_v, train_e, train_idx_g, train_y)
# Show field 3 of the first sample before and after normalization.
# NOTE(review): index 3 is presumably the target — confirm in ReactionDataset.__getitem__.
print(train_dset[0][3])
# Normalize the training targets; the returned scaler is reused for the validation
# split and for the model's output transform (UnscaleTransform.from_standard_scaler).
scaler = train_dset.normalize_targets()
# print(scaler)
print(train_dset[0][3])

# Validation targets are normalized with the *training* scaler.
val_dset = data.ReactionDataset(val_v, val_e, val_idx_g, val_y)
val_dset.normalize_targets(scaler)
# Test targets are intentionally left unnormalized.
# NOTE(review): presumably because the output transform un-scales predictions at
# test time, so metrics come out in original units — confirm.
test_dset = data.ReactionDataset(test_v, test_e, test_idx_g, test_y)

In [None]:
# Display field 3 of the first training sample (presumably the normalized target).
first_sample = train_dset[0]
first_sample[3]

In [None]:
# Inspect the second training sample. Its first field's last two items are
# presumably the edge index and the reverse-edge index — confirm in ReactionDataset.
graph_fields = train_dset[1][0]
edge_index, reverse_index = graph_fields[-2], graph_fields[-1]
print(f'edge_index: {edge_index}')
print(f'reverse_index: {reverse_index}')

In [None]:
import numpy as np

# Scratch check of the pair-swap trick: group indices into (a, b) pairs, swap each
# pair, and flatten. Expected: array([1, 0, 3, 2, 5, 4]).
# NOTE(review): presumably mirrors how reverse_index pairs directed edges — confirm.
np.arange(6).reshape(-1, 2)[:, [1, 0]].ravel()

## Get dataloaders

In [None]:
# Only the training split is shuffled; validation and test keep a fixed order.
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader, test_loader = (
    data.build_dataloader(dset, num_workers=num_workers, shuffle=False)
    for dset in (val_dset, test_dset)
)

# Change Message-Passing Neural Network (MPNN) inputs here

## Message passing

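Message-passing blocks must be given the dimensions of the atom and bond features. No featurizer is used here; the dimensions are read from the pre-featurized arrays (`train_v`, `train_e`).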

Options are `mp = nn.BondMessagePassing()` or `mp = nn.AtomMessagePassing()`

In [None]:
# Atom (node) feature width, read from the first training graph.
first_node_attrs = train_v[0]
first_node_attrs.shape[1]

In [None]:
# Feature dimensions as (atom_dims, bond_dims), read from the first training graph's
# node- and edge-attribute arrays (no featurizer is used in this notebook).
fdims = tuple(attrs[0].shape[1] for attrs in (train_v, train_e))
mp = nn.BondMessagePassing(*fdims)

In [None]:
# Print the atom and bond feature dimensions, space-separated.
atom_dim, bond_dim = fdims
print(atom_dim, bond_dim)

## Aggregation

In [None]:
# Show the available aggregation choices.
aggregation_registry = nn.agg.AggregationRegistry
print(aggregation_registry)

In [None]:
# Graph-level aggregation; swap in nn.SumAggregation to compare.
# NOTE: `agg` is not passed to models.MPNN_Simple in the ensemble loop.
AggregationClass = nn.MeanAggregation
agg = AggregationClass()

## Feed-Forward Network (FFN)

In [None]:
# Show the available predictor (FFN head) choices.
predictor_registry = nn.PredictorRegistry
print(predictor_registry)

In [None]:
# Output transform built from the training-target scaler
# (presumably inverts the target normalization on predictions).
build_unscale = nn.UnscaleTransform.from_standard_scaler
output_transform = build_unscale(scaler)

In [None]:
# Regression head; predictions pass through the output transform defined above.
ffn_kwargs = {"output_transform": output_transform}
ffn = nn.RegressionFFN(**ffn_kwargs)

## Batch norm

In [None]:
# Batch-norm flag. Only the commented-out MPNN variants take it; models.MPNN_Simple
# in the ensemble loop does not receive it.
batch_norm: bool = True

## Metrics

In [None]:
# Show the available metrics.
metric_registry = nn.metrics.MetricRegistry
print(metric_registry)

In [None]:
# Evaluation metrics. Per the original note, only the first one is used for
# training and early stopping.
metric_list = [
    nn.metrics.RMSE(),
    nn.metrics.MAE(),
]

## Construct MPNN

In [None]:
# Jump/hop count for multi-hop MPNN variants; currently not passed to any model
# (the `k_jump=k_jump` arguments are commented out).
k_jump: int = 3

In [None]:
# mpnn = models(mp, agg, ffn, batch_norm, metric_list, k_jump)
# mpnn

In [None]:
# mpnn = models.MPNN_Simple(
#     message_passing=mp,
#     predictor=ffn,        
#     # k_jump=k_jump,
#     metrics=metric_list
#     # init_lr=1e-4,
#     # max_lr=1e-3,
# )

# print("Khởi tạo mô hình thành công!")

# Training and testing

## Set up trainer

In [None]:
# trainer = pl.Trainer(
#     logger=False,
#     enable_checkpointing=True,  # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
#     enable_progress_bar=True,
#     accelerator="auto",
#     devices=1,
#     max_epochs=200,  # number of epochs to train for
# )

## Start training

In [None]:
# trainer.fit(mpnn, train_loader, val_loader)

## Test results

In [None]:
# results = trainer.test(mpnn, test_loader)

    # mpnn = models.MPNN_Simple(
    # message_passing=mp,
    # predictor=ffn,        
    # # k_jump=k_jump,
    # metrics=metric_list
    # ) 

    # mpnn = models.MPNN_1(mp, agg, ffn, batch_norm, metric_list, FA_layer) 
    
    #     mpnn = models.MPNN_MixHop_Pool(
    # message_passing=mp,
    # predictor=ffn,        
    # metrics=metric_list
    # ) 


In [None]:
# ===================================================================
# Replaces the whole "Training and testing" section above: trains an
# ensemble of independently initialized models and reports test metrics.
# ===================================================================
import copy

# --- Ensemble settings ---
ENSEMBLE_SIZE = 15
MAX_EPOCHS = 50
# Pre-trained encoder (message-passing) weights. If the file is missing, every
# member is trained from scratch.
PRETRAINED_ENCODER_PATH = Path(
    "/home/labhhc2/Documents/workspace/D20/Tam/repo/chemprop_1/examples/pretrained_dmpnn_nhap_encoder.pt"
)

all_test_results = []
output_dir = Path("./reaction_ensemble_results")
output_dir.mkdir(exist_ok=True)

# Read the pre-trained state dict once instead of once per member.
# load_state_dict copies the tensors, so the same dict can seed every new model.
# map_location="cpu" lets a file saved on GPU load on a CPU-only machine.
try:
    pretrained_weights = torch.load(PRETRAINED_ENCODER_PATH, map_location="cpu")
except FileNotFoundError:
    pretrained_weights = None

print(f"Bắt đầu huấn luyện ensemble với kích thước = {ENSEMBLE_SIZE}")
print("-" * 30)

# --- Ensemble training loop ---
for i in range(ENSEMBLE_SIZE):
    print(f"\n--- Đang huấn luyện mô hình {i+1}/{ENSEMBLE_SIZE} ---")

    # Distinct seed per member: members differ from each other, yet a re-run
    # reproduces the same ensemble (workers=True also seeds DataLoader workers).
    pl.seed_everything(i, workers=True)

    # 1. Build NEW modules for every member. Reusing the module-level `mp` and `ffn`
    #    would hand the same parameter tensors to every MPNN, so each "new" model
    #    would keep training the previous one instead of starting from fresh weights.
    #    Keep these constructors in sync with the `mp` and `ffn` cells above.
    mpnn = models.MPNN_Simple(
        message_passing=nn.BondMessagePassing(*fdims),
        predictor=nn.RegressionFFN(output_transform=output_transform),
        # Separate metric instances per member (metric objects may hold state).
        metrics=copy.deepcopy(metric_list),
    )

    if pretrained_weights is not None:
        mpnn.message_passing.load_state_dict(pretrained_weights)
        print(f"Model {i+1}: Tải trọng số pre-train thành công!")
    else:
        print(f"Model {i+1}: Không tìm thấy file pre-trained. Huấn luyện từ đầu.")

    # 2. Fresh Trainer per member; keep only the best checkpoint by validation loss.
    model_checkpoint_dir = output_dir / f"model_{i}"
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=model_checkpoint_dir,
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        filename='best_model'
    )

    trainer = pl.Trainer(
        logger=False,
        enable_checkpointing=True,
        callbacks=[checkpoint_callback],
        enable_progress_bar=True,
        accelerator="auto",
        devices=1,
        max_epochs=MAX_EPOCHS,
    )

    # 3. Train.
    trainer.fit(mpnn, train_loader, val_loader)

    # 4. Test the best saved checkpoint and store this member's metrics.
    print(f"--- Đang kiểm tra mô hình {i+1} ---")
    # best_model_path is "" when no checkpoint was written; pass None in that case
    # so the current in-memory weights are tested instead of an invalid path.
    best_model_path = checkpoint_callback.best_model_path or None
    results = trainer.test(mpnn, test_loader, ckpt_path=best_model_path)
    all_test_results.append(results[0])  # trainer.test returns one dict per test dataloader

# --- Summarize results ---
print("\n" + "="*30)
print("HUẤN LUYỆN ENSEMBLE HOÀN TẤT!")
print("="*30)

# One row per ensemble member, one column per logged test metric.
results_df = pd.DataFrame(all_test_results)

print("\n--- Kết quả kiểm tra của từng mô hình ---")
print(results_df)

# NOTE: this is the mean of the members' individual test metrics, not the metric
# of an averaged (ensembled) prediction.
print("\n--- Kết quả trung bình của Ensemble ---")
print(results_df.mean())

# Spread across members, to judge how stable the mean is.
print("\n--- Độ lệch chuẩn giữa các mô hình ---")
print(results_df.std())