# In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
import math
import warnings
from torch.utils.data import TensorDataset, DataLoader
from model.getdata import smiles2graph
from model.CL_model_vas_info import GNNModelWithNewLoss
from model.fusion import TransformerFusionModel

warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ========== Step 1: Data Preparation ==========
def prepare_dataset(smiles2graph_fn, csv_file, model_dir):
    """Build the tensors consumed by the fusion-model training loop.

    Reads ``csv_file``, converts its ``"smiles"`` column to graph objects via
    ``smiles2graph_fn``, and computes embeddings for those graphs with
    ``load_model_embeddings``.

    Args:
        smiles2graph_fn: Callable invoked as ``smiles2graph_fn(smiles_list, y=y)``;
            it must return a sequence of graph objects, each exposing a
            ``global_features`` tensor.
        csv_file: Path to a CSV file with a ``"smiles"`` column. The LAST column
            of the file is used as the regression target.
        model_dir: Directory forwarded to ``load_model_embeddings``.

    Returns:
        A tuple ``(embeddings, y, global_features)`` where ``embeddings`` is the
        result of ``load_model_embeddings``, ``y`` is a float32 tensor of the
        targets, and ``global_features`` is the per-graph ``global_features``
        tensors stacked along a new leading dimension.
    """
    df = pd.read_csv(csv_file)
    smiles_list = df["smiles"].tolist()
    # The target is taken positionally: whatever column comes last in the CSV.
    y = torch.tensor(df[df.columns[-1]].values, dtype=torch.float32)
    data_list = smiles2graph_fn(smiles_list, y=y)
    embeddings = load_model_embeddings(data_list, model_dir)
    global_features = torch.stack([data.global_features for data in data_list])
    return embeddings, y, global_features

# ========== Step 2: Load model embeddings ==========
def load_model_embeddings(data_list, model_dir):
    """Embed every graph with each of the three pretrained encoders.

    For ``i`` in ``0..2`` the checkpoint ``<model_dir>/<i>/best_model.pth`` is
    loaded (its ``'encoder_state_dict'`` entry) into a fresh
    ``GNNModelWithNewLoss`` and run over ``data_list`` on the CPU in eval mode
    with gradients disabled.

    Args:
        data_list: Non-empty sequence of graph objects exposing ``x`` and
            ``edge_attr`` 2-D tensors (their second dimension sizes the model)
            and a ``to(device)`` method.
        model_dir: Directory containing the sub-directories ``0``, ``1``, ``2``.

    Returns:
        A tensor of shape ``(len(data_list), 3, D)`` where ``D`` is the number
        of elements in one encoder output for a single graph.
    """
    cpu = torch.device("cpu")
    # Feature sizes are the same for every encoder, so read them once.
    num_node_features = data_list[0].x.shape[1]
    num_edge_features = data_list[0].edge_attr.shape[1]

    model_embeddings = []
    for i in range(3):
        model_path = os.path.join(model_dir, str(i), "best_model.pth")
        model = GNNModelWithNewLoss(
            num_node_features=num_node_features,
            num_edge_features=num_edge_features,
            num_global_features=0,
            hidden_dim=512,
            cov_num=3
        )
        state_dict = torch.load(model_path, map_location=cpu)
        model.load_state_dict(state_dict['encoder_state_dict'])
        model.eval()
        with torch.no_grad():
            # The encoder returns a 2-D tensor per graph (presumably (1, D) —
            # TODO confirm). Stacking those un-flattened gave a 4-D result and
            # made the final permute(1, 0, 2) raise, so each embedding is
            # flattened to 1-D first.
            embeddings = [model(data.to(cpu)).reshape(-1) for data in data_list]
        model_embeddings.append(torch.stack(embeddings))
    # (3, N, D) -> (N, 3, D): one row per sample, one slot per encoder.
    return torch.stack(model_embeddings, dim=0).permute(1, 0, 2)

# ========== Step 3: Training Function ==========
def train_model_batched(model, train_loader, val_loader, train_ds_len, val_ds_len, epochs=600, lr=1e-5):
    """Train ``model`` with Adam on an MSE objective, reporting RMSE per epoch.

    Each loader must yield ``(embeddings, targets, global_features)`` batches.
    The model is called as ``model(embeddings, global_features)`` and must
    return a pair whose first element is the prediction tensor.

    Args:
        model: Module to optimise (already placed on ``device`` by the caller).
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        train_ds_len: Number of training samples, used to average the loss.
        val_ds_len: Number of validation samples, used to average the loss.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        ``(train_losses, val_losses)``: two lists with one RMSE value per epoch.
    """
    criterion = nn.MSELoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    train_losses = []
    val_losses = []

    for epoch in range(1, epochs + 1):
        # ---- optimisation pass ----
        model.train()
        train_sq_err = 0.0
        for emb, target, gfeat in train_loader:
            emb = emb.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            gfeat = gfeat.to(device, non_blocking=True)

            optimizer.zero_grad()
            preds, _ = model(emb, gfeat)
            # NOTE(review): assumes preds and target have matching shapes;
            # otherwise MSELoss would broadcast — confirm against the model.
            loss = criterion(preds, target)
            loss.backward()
            optimizer.step()

            # Batch-mean MSE times batch size gives this batch's summed error.
            train_sq_err += loss.item() * emb.size(0)
        train_rmse = (train_sq_err / train_ds_len) ** 0.5
        train_losses.append(train_rmse)

        # ---- evaluation pass ----
        model.eval()
        val_sq_err = 0.0
        with torch.no_grad():
            for emb, target, gfeat in val_loader:
                emb = emb.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)
                gfeat = gfeat.to(device, non_blocking=True)
                preds, _ = model(emb, gfeat)
                batch_mse = criterion(preds, target).item()
                val_sq_err += batch_mse * emb.size(0)
        val_rmse = (val_sq_err / val_ds_len) ** 0.5
        val_losses.append(val_rmse)

        print(f"Epoch {epoch}/{epochs} — Train Loss: {train_rmse:.4f}, Val Loss: {val_rmse:.4f}")

    return train_losses, val_losses

# ========== Step 4: Run Training ==========
if __name__ == '__main__':
    # Run configuration: dataset, pretrained-encoder directory, batch size.
    batch_size = 64
    model_dir = 'premodels_new_og/3'
    csv_path = 'data/esol.csv'

    embeddings, y_tensor, global_features = prepare_dataset(smiles2graph, csv_path, model_dir)

    # Fixed-seed 80/20 split over sample indices.
    all_indices = np.arange(len(y_tensor))
    train_idx, val_idx = train_test_split(all_indices, test_size=0.2, random_state=42)

    train_ds = TensorDataset(
        embeddings[train_idx], y_tensor[train_idx], global_features[train_idx]
    )
    val_ds = TensorDataset(
        embeddings[val_idx], y_tensor[val_idx], global_features[val_idx]
    )
    # Only the training loader shuffles.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)

    emb_dim = embeddings.shape[-1]
    model = TransformerFusionModel(emb_dim=emb_dim).to(device)
    train_losses, val_losses = train_model_batched(
        model,
        train_loader,
        val_loader,
        train_ds_len=len(train_ds),
        val_ds_len=len(val_ds),
        epochs=200,
        lr=1e-4,
    )

    # ========== Step 5: Plot ==========
    plt.figure(figsize=(10, 5))
    for curve, curve_label in ((train_losses, 'Train RMSE'), (val_losses, 'Val RMSE')):
        plt.plot(curve, label=curve_label)
    plt.title('Training Curve')
    plt.xlabel('Epoch')
    plt.ylabel('RMSE')
    plt.legend()
    plt.grid()
    plt.tight_layout()
    # Save before show(): the figure is written to disk first, as originally.
    plt.savefig('training_curve.png')
    plt.show()


# Error observed when running this cell:
# RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 4 is not equal to len(dims) = 3