# Library import

In [1]:
import torch
from torch.utils.data import DataLoader, random_split
from torch.optim import Adam
import torch.nn as nn
from model.mpnn import SolvationModel
import numpy as np
from model.custoum_dataset import SolPropDataset
from torch.utils.tensorboard import SummaryWriter
from types import SimpleNamespace
import pickle
from tqdm import tqdm
import random

# Fix the random state for reproduction

In [2]:
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed(42)


# Hyperparameters (need to be optimized)

In [3]:
class Args:
    def __init__(self):
        self.hidden_size = 200             # D-MPNN hidden dim # hyperparameter
        self.ffn_hidden_size = 100         # Feedforward hidden dim # hyperparameter
        self.output_size = 2               # # of properties
        self.dropout = 0.1                 # dropout # hyperparameter
        self.bias = True
        self.depth = 2                     
        self.activation = "ReLU"           # activation
        self.cuda = True                   # GPU
        self.property = "solvation"
        self.aggregation = "mean"
        self.atomMessage = False           # False: only atom


# Basic functions

In [4]:
def collate_fn(batch):
    batched_data = {
        'solute': [item['solute'] for item in batch],                 # list of dict
        'solvent_list': [item['solvent_list'] for item in batch],     # list of list of dict
        'mol_frac': torch.stack([item['mol_frac'] for item in batch]),
        'target': torch.stack([item['target'] for item in batch])     # (B, 2)
    }
    return batched_data

def to_namespace(obj):
    return obj if isinstance(obj, SimpleNamespace) else SimpleNamespace(**obj)

# move dict tensors to device
def move_batch_to_device(batch, device):
    batch['mol_frac'] = batch['mol_frac'].to(device)
    batch['target']   = batch['target'].to(device)

    # Move solutes
    solutes = []
    for solute in batch['solute']:
        solute_ns = to_namespace(solute)
        for k, v in vars(solute_ns).items():
            if isinstance(v, torch.Tensor):
                setattr(solute_ns, k, v.to(device))
        solutes.append(solute_ns)
    batch['solute'] = solutes

    # Move solvents
    solvents_out = []
    for solvent_list in batch['solvent_list']:
        tmp = []
        for solvent in solvent_list:
            solvent_ns = to_namespace(solvent)
            for k, v in vars(solvent_ns).items():
                if isinstance(v, torch.Tensor):
                    setattr(solvent_ns, k, v.to(device))
            tmp.append(solvent_ns)
        solvents_out.append(tmp)
    batch['solvent_list'] = solvents_out
    return batch

# metrics
def rmse(y_hat, y):
    return torch.sqrt(torch.mean((y_hat - y) ** 2)).item()
def mae(y_hat, y):
    return torch.mean(torch.abs(y_hat - y)).item()

# Load processed dataset 

In [None]:
# if you want to use another data, go to 'preprocessing' folder and save it based on 'preprocessing_binary.py' file
data_path = "train_binary.pkl"   
with open(data_path, 'rb') as f:
    data_list = pickle.load(f)

# datasplit

In [None]:
dataset = SolPropDataset(data_list)
train_size = int(0.8 * len(dataset)) # 0.8 -> train:test = 8:2
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# data loading

In [13]:
batch_size = 128 # hyperparameter
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, drop_last=False)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False)


# scaling

In [None]:
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import joblib

targets = []
for item in dataset:
    targets.append(item['target'].cpu().numpy())
targets = np.array(targets)

scaler = StandardScaler() # MinMaxScaler # hyperparameter
scaler.fit(targets)
# scaler save
joblib.dump(scaler, "scaler.pkl")

['scaler.pkl']

# model/optimizer

In [None]:
args = Args()
device = torch.device('cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu')
model = SolvationModel(args).to(device)
optimizer = Adam(model.parameters(), lr=1e-3) # hyperparameter
criterion = nn.MSELoss(reduction='none')

# model train

In [None]:
# log
writer = SummaryWriter(log_dir='runs/solprop_test')
# train loop 
best_val_loss = float('inf')
patience = 10 # hyperparameter
epochs_without_improve = 0
max_epochs = 100 # hyperparameter

for epoch in range(max_epochs):
    # === Training ===
    model.train()
    train_loss = 0.0
    train_rmse = 0.0
    train_mae  = 0.0
    n_train    = 0

    train_gsolv_loss = 0.0
    train_gsolv_rmse = 0.0
    train_gsolv_mae = 0.0

    train_hsolv_loss = 0.0
    train_hsolv_rmse = 0.0
    train_hsolv_mae = 0.0
    

    tepoch = tqdm(train_loader, unit='batch', desc=f"[Train] Epoch {epoch+1}/{max_epochs}")
    for batch in tepoch:
        batch = move_batch_to_device(batch, device)

        optimizer.zero_grad()
        output = model(batch)   # (B, 2)
        target = batch['target'].cpu().numpy()  # (B, 2)    
        target_sc = torch.tensor(scaler.transform(target), dtype=torch.float, device=device)

        loss_per_target = criterion(output, target_sc)  # (B, 2)
        loss_Gsolv = loss_per_target[:, 0]
        loss_Hsolv = loss_per_target[:, 1]
        
        loss = loss_Gsolv.mean() + loss_Hsolv.mean()
        loss.backward()
        optimizer.step()

        B = output.size(0)
        n_train += B
        train_loss += loss.item() * B
        train_rmse += rmse(output.detach(), target_sc)
        train_mae  += mae(output.detach(), target_sc)

        train_gsolv_loss += loss_Gsolv.mean().item() * B
        train_gsolv_rmse += rmse(output[:, 0].detach(), target_sc[:, 0])
        train_gsolv_mae  += mae(output[:, 0].detach(), target_sc[:, 0])

        train_hsolv_loss += loss_Hsolv.mean().item() * B
        train_hsolv_rmse += rmse(output[:, 1].detach(), target_sc[:, 1])
        train_hsolv_mae  += mae(output[:, 1].detach(), target_sc[:, 1])

        tepoch.set_postfix(loss=loss.item(), refresh=False)

    train_loss /= n_train
    train_rmse /= len(train_loader)
    train_mae  /= len(train_loader)

    train_gsolv_loss /= n_train
    train_gsolv_rmse /= len(train_loader)
    train_gsolv_mae /= len(train_loader)

    train_hsolv_loss /= n_train
    train_hsolv_rmse /= len(train_loader)
    train_hsolv_mae /= len(train_loader)

    writer.add_scalar("Loss/train", train_loss, epoch)
    writer.add_scalar("RMSE/train", train_rmse, epoch)
    writer.add_scalar("MAE/train",  train_mae,  epoch)

    writer.add_scalar("Loss_Gsolv/train", train_gsolv_loss, epoch)
    writer.add_scalar("RMSE_Gsolv/train", train_gsolv_rmse, epoch)
    writer.add_scalar("MAE_Gsolv/train", train_gsolv_mae, epoch)

    writer.add_scalar("Loss_Hsolv/train", train_hsolv_loss, epoch)
    writer.add_scalar("Loss_Hsolv/train", train_hsolv_rmse, epoch)
    writer.add_scalar("Loss_Hsolv/train", train_hsolv_mae, epoch)

    # === Validation ===
    model.eval()
    val_loss = 0.0
    val_rmse = 0.0
    val_mae  = 0.0
    n_val    = 0

    val_gsolv_loss = 0.0
    val_gsolv_rmse = 0.0
    val_gsolv_mae = 0.0

    val_hsolv_loss = 0.0
    val_hsolv_rmse = 0.0
    val_hsolv_mae = 0.0

    vepoch = tqdm(val_loader, unit='batch', desc=f"[Val]   Epoch {epoch+1}/{max_epochs}")
    with torch.no_grad():
        for batch in vepoch:
            batch = move_batch_to_device(batch, device)

            output = model(batch)
            target = batch['target'].cpu().numpy()
            target_sc = torch.tensor(scaler.transform(target), dtype=torch.float, device=device)

            loss_Gsolv = loss_per_target[:, 0]
            loss_Hsolv = loss_per_target[:, 1]
            loss = loss_Gsolv.mean() + loss_Hsolv.mean()

            B = output.size(0)
            n_val += B
            val_loss += loss.item() * B
            val_rmse += rmse(output.detach(), target_sc)
            val_mae  += mae(output.detach(), target_sc)

            val_gsolv_loss += loss_Gsolv.mean().item() * B
            val_gsolv_rmse += rmse(output[:, 0].detach(), target_sc[:, 0])
            val_gsolv_mae  += mae(output[:, 0].detach(), target_sc[:, 0])

            val_hsolv_loss += loss_Hsolv.mean().item() * B
            val_hsolv_rmse += rmse(output[:, 1].detach(), target_sc[:, 1])
            val_hsolv_mae  += mae(output[:, 1].detach(), target_sc[:, 1])

            vepoch.set_postfix(loss=loss.item(), refresh=False)

    val_loss /= n_val
    val_rmse /= len(val_loader)
    val_mae  /= len(val_loader)

    val_gsolv_loss /= n_val
    val_gsolv_rmse /= len(val_loader)
    val_gsolv_mae /= len(val_loader)

    val_hsolv_loss /= n_val
    val_hsolv_rmse /= len(val_loader)
    val_hsolv_mae /= len(val_loader)


    writer.add_scalar("Loss/val",  val_loss, epoch)
    writer.add_scalar("RMSE/val",  val_rmse, epoch)
    writer.add_scalar("MAE/val",   val_mae,  epoch)

    writer.add_scalar("Loss_Gsolv/val", val_gsolv_loss, epoch)
    writer.add_scalar("RMSE_Gsolv/val", val_gsolv_rmse, epoch)
    writer.add_scalar("MAE_Gsolv/val", val_gsolv_mae, epoch)

    writer.add_scalar("Loss_Hsolv/val", val_hsolv_loss, epoch)
    writer.add_scalar("Loss_Hsolv/val", val_hsolv_rmse, epoch)
    writer.add_scalar("Loss_Hsolv/val", val_hsolv_mae, epoch)

    print(f"Epoch {epoch:03d}: "
          f"Train Loss={train_loss:.4f} RMSE={train_rmse:.4f} MAE={train_mae:.4f} | "
          f"Val Loss={val_loss:.4f} RMSE={val_rmse:.4f} MAE={val_mae:.4f}")

    # Early Stopping
    if val_loss < best_val_loss - 1e-6:     
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_model.pt")
        epochs_without_improve = 0
    else:
        epochs_without_improve += 1
        if epochs_without_improve >= patience:
            print("Early stopping triggered.")
            break

writer.close()

[Train] Epoch 1/100:   0%|          | 0/3 [00:00<?, ?batch/s]

[Train] Epoch 1/100: 100%|██████████| 3/3 [00:04<00:00,  1.54s/batch, loss=1.51]
[Val]   Epoch 1/100: 100%|██████████| 1/1 [00:00<00:00,  1.69batch/s, loss=1.51]


Epoch 000: Train Loss=2.4446 RMSE=1.0871 MAE=0.7122 | Val Loss=1.5125 RMSE=1.0354 MAE=0.7470


[Train] Epoch 2/100: 100%|██████████| 3/3 [00:04<00:00,  1.59s/batch, loss=1.24]
[Val]   Epoch 2/100: 100%|██████████| 1/1 [00:00<00:00,  1.76batch/s, loss=1.24]


Epoch 001: Train Loss=2.0605 RMSE=0.9911 MAE=0.7062 | Val Loss=1.2384 RMSE=1.0513 MAE=0.7882


[Train] Epoch 3/100: 100%|██████████| 3/3 [00:04<00:00,  1.63s/batch, loss=2.1] 
[Val]   Epoch 3/100: 100%|██████████| 1/1 [00:00<00:00,  1.72batch/s, loss=2.1]


Epoch 002: Train Loss=2.1681 RMSE=1.0373 MAE=0.7579 | Val Loss=2.1004 RMSE=1.0126 MAE=0.7434


[Train] Epoch 4/100: 100%|██████████| 3/3 [00:04<00:00,  1.59s/batch, loss=1.19]
[Val]   Epoch 4/100: 100%|██████████| 1/1 [00:00<00:00,  1.82batch/s, loss=1.19]


Epoch 003: Train Loss=1.9202 RMSE=0.9644 MAE=0.6698 | Val Loss=1.1936 RMSE=1.0199 MAE=0.7339


[Train] Epoch 5/100: 100%|██████████| 3/3 [00:05<00:00,  1.74s/batch, loss=1.15]
[Val]   Epoch 5/100: 100%|██████████| 1/1 [00:00<00:00,  1.73batch/s, loss=1.15]


Epoch 004: Train Loss=1.9308 RMSE=0.9571 MAE=0.6300 | Val Loss=1.1523 RMSE=1.0249 MAE=0.7316


[Train] Epoch 6/100: 100%|██████████| 3/3 [00:04<00:00,  1.60s/batch, loss=1.95]
[Val]   Epoch 6/100: 100%|██████████| 1/1 [00:00<00:00,  2.03batch/s, loss=1.95]


Epoch 005: Train Loss=1.9085 RMSE=0.9765 MAE=0.6253 | Val Loss=1.9544 RMSE=0.9886 MAE=0.7052


[Train] Epoch 7/100: 100%|██████████| 3/3 [00:04<00:00,  1.63s/batch, loss=2.83]
[Val]   Epoch 7/100: 100%|██████████| 1/1 [00:00<00:00,  1.84batch/s, loss=2.83]


Epoch 006: Train Loss=1.7842 RMSE=0.9347 MAE=0.6291 | Val Loss=2.8290 RMSE=0.9761 MAE=0.6987


[Train] Epoch 8/100: 100%|██████████| 3/3 [00:05<00:00,  1.71s/batch, loss=1.85]
[Val]   Epoch 8/100: 100%|██████████| 1/1 [00:00<00:00,  1.72batch/s, loss=1.85]


Epoch 007: Train Loss=1.7770 RMSE=0.9354 MAE=0.6473 | Val Loss=1.8464 RMSE=0.9696 MAE=0.6890


[Train] Epoch 9/100: 100%|██████████| 3/3 [00:04<00:00,  1.59s/batch, loss=2.18]
[Val]   Epoch 9/100: 100%|██████████| 1/1 [00:00<00:00,  1.62batch/s, loss=2.18]


Epoch 008: Train Loss=1.7243 RMSE=0.9282 MAE=0.6209 | Val Loss=2.1775 RMSE=0.9626 MAE=0.6660


[Train] Epoch 10/100: 100%|██████████| 3/3 [00:05<00:00,  1.79s/batch, loss=2.69] 
[Val]   Epoch 10/100: 100%|██████████| 1/1 [00:00<00:00,  1.67batch/s, loss=2.69]


Epoch 009: Train Loss=1.6722 RMSE=0.8989 MAE=0.5935 | Val Loss=2.6891 RMSE=0.9586 MAE=0.6522


[Train] Epoch 11/100: 100%|██████████| 3/3 [00:04<00:00,  1.61s/batch, loss=1.09]
[Val]   Epoch 11/100: 100%|██████████| 1/1 [00:00<00:00,  1.77batch/s, loss=1.09]


Epoch 010: Train Loss=1.6487 RMSE=0.8865 MAE=0.5885 | Val Loss=1.0913 RMSE=0.9539 MAE=0.6498


[Train] Epoch 12/100: 100%|██████████| 3/3 [00:04<00:00,  1.50s/batch, loss=1.09]
[Val]   Epoch 12/100: 100%|██████████| 1/1 [00:00<00:00,  1.78batch/s, loss=1.09]


Epoch 011: Train Loss=1.6223 RMSE=0.8845 MAE=0.5863 | Val Loss=1.0855 RMSE=0.9567 MAE=0.6456


[Train] Epoch 13/100: 100%|██████████| 3/3 [00:04<00:00,  1.64s/batch, loss=1.4] 
[Val]   Epoch 13/100: 100%|██████████| 1/1 [00:00<00:00,  1.85batch/s, loss=1.4]


Epoch 012: Train Loss=1.5954 RMSE=0.8846 MAE=0.5792 | Val Loss=1.4032 RMSE=0.9590 MAE=0.6503


[Train] Epoch 14/100: 100%|██████████| 3/3 [00:04<00:00,  1.55s/batch, loss=0.876]
[Val]   Epoch 14/100: 100%|██████████| 1/1 [00:00<00:00,  1.81batch/s, loss=0.876]


Epoch 013: Train Loss=1.5857 RMSE=0.8450 MAE=0.5876 | Val Loss=0.8755 RMSE=0.9622 MAE=0.6510


[Train] Epoch 15/100: 100%|██████████| 3/3 [00:05<00:00,  1.68s/batch, loss=1.25]
[Val]   Epoch 15/100: 100%|██████████| 1/1 [00:00<00:00,  1.79batch/s, loss=1.25]


Epoch 014: Train Loss=1.5169 RMSE=0.8594 MAE=0.5552 | Val Loss=1.2459 RMSE=0.9691 MAE=0.6430


[Train] Epoch 16/100: 100%|██████████| 3/3 [00:04<00:00,  1.54s/batch, loss=1.13]
[Val]   Epoch 16/100: 100%|██████████| 1/1 [00:00<00:00,  1.88batch/s, loss=1.13]


Epoch 015: Train Loss=1.5366 RMSE=0.8583 MAE=0.5527 | Val Loss=1.1264 RMSE=0.9682 MAE=0.6482


[Train] Epoch 17/100: 100%|██████████| 3/3 [00:04<00:00,  1.57s/batch, loss=1.8]  
[Val]   Epoch 17/100: 100%|██████████| 1/1 [00:00<00:00,  1.86batch/s, loss=1.8]


Epoch 016: Train Loss=1.5345 RMSE=0.8654 MAE=0.5580 | Val Loss=1.8002 RMSE=0.9713 MAE=0.6577


[Train] Epoch 18/100: 100%|██████████| 3/3 [00:04<00:00,  1.63s/batch, loss=2.29] 
[Val]   Epoch 18/100: 100%|██████████| 1/1 [00:00<00:00,  1.76batch/s, loss=2.29]


Epoch 017: Train Loss=1.5507 RMSE=0.8738 MAE=0.5848 | Val Loss=2.2938 RMSE=0.9765 MAE=0.6661


[Train] Epoch 19/100: 100%|██████████| 3/3 [00:04<00:00,  1.63s/batch, loss=0.736]
[Val]   Epoch 19/100: 100%|██████████| 1/1 [00:00<00:00,  1.65batch/s, loss=0.736]


Epoch 018: Train Loss=1.5434 RMSE=0.8271 MAE=0.5695 | Val Loss=0.7358 RMSE=0.9807 MAE=0.6575


[Train] Epoch 20/100: 100%|██████████| 3/3 [00:05<00:00,  1.73s/batch, loss=2.26]
[Val]   Epoch 20/100: 100%|██████████| 1/1 [00:00<00:00,  1.67batch/s, loss=2.26]


Epoch 019: Train Loss=1.5261 RMSE=0.8694 MAE=0.5421 | Val Loss=2.2594 RMSE=0.9881 MAE=0.6523


[Train] Epoch 21/100: 100%|██████████| 3/3 [00:04<00:00,  1.66s/batch, loss=1.36]
[Val]   Epoch 21/100: 100%|██████████| 1/1 [00:00<00:00,  1.65batch/s, loss=1.36]


Epoch 020: Train Loss=1.5007 RMSE=0.8616 MAE=0.5476 | Val Loss=1.3585 RMSE=0.9877 MAE=0.6697


[Train] Epoch 22/100: 100%|██████████| 3/3 [00:04<00:00,  1.59s/batch, loss=1.41] 
[Val]   Epoch 22/100: 100%|██████████| 1/1 [00:00<00:00,  1.66batch/s, loss=1.41]


Epoch 021: Train Loss=1.5194 RMSE=0.8563 MAE=0.5632 | Val Loss=1.4075 RMSE=0.9889 MAE=0.6678


[Train] Epoch 23/100: 100%|██████████| 3/3 [00:04<00:00,  1.56s/batch, loss=1.46] 
[Val]   Epoch 23/100: 100%|██████████| 1/1 [00:00<00:00,  1.66batch/s, loss=1.46]


Epoch 022: Train Loss=1.4921 RMSE=0.8480 MAE=0.5485 | Val Loss=1.4640 RMSE=0.9925 MAE=0.6624


[Train] Epoch 24/100: 100%|██████████| 3/3 [00:04<00:00,  1.61s/batch, loss=0.924]
[Val]   Epoch 24/100: 100%|██████████| 1/1 [00:00<00:00,  1.79batch/s, loss=0.924]


Epoch 023: Train Loss=1.5199 RMSE=0.8537 MAE=0.5421 | Val Loss=0.9240 RMSE=0.9926 MAE=0.6713


[Train] Epoch 25/100: 100%|██████████| 3/3 [00:04<00:00,  1.59s/batch, loss=1.29]
[Val]   Epoch 25/100: 100%|██████████| 1/1 [00:00<00:00,  1.71batch/s, loss=1.29]


Epoch 024: Train Loss=1.4605 RMSE=0.8468 MAE=0.5453 | Val Loss=1.2920 RMSE=0.9904 MAE=0.6669


[Train] Epoch 26/100: 100%|██████████| 3/3 [00:05<00:00,  1.75s/batch, loss=1.87] 
[Val]   Epoch 26/100: 100%|██████████| 1/1 [00:00<00:00,  1.74batch/s, loss=1.87]


Epoch 025: Train Loss=1.4998 RMSE=0.8586 MAE=0.5457 | Val Loss=1.8718 RMSE=0.9923 MAE=0.6674


[Train] Epoch 27/100: 100%|██████████| 3/3 [00:04<00:00,  1.61s/batch, loss=0.92]
[Val]   Epoch 27/100: 100%|██████████| 1/1 [00:00<00:00,  1.86batch/s, loss=0.92]


Epoch 026: Train Loss=1.4564 RMSE=0.8378 MAE=0.5396 | Val Loss=0.9200 RMSE=0.9916 MAE=0.6787


[Train] Epoch 28/100: 100%|██████████| 3/3 [00:05<00:00,  1.67s/batch, loss=1.15]
[Val]   Epoch 28/100: 100%|██████████| 1/1 [00:00<00:00,  1.77batch/s, loss=1.15]


Epoch 027: Train Loss=1.4671 RMSE=0.8442 MAE=0.5498 | Val Loss=1.1489 RMSE=0.9900 MAE=0.6661


[Train] Epoch 29/100: 100%|██████████| 3/3 [00:04<00:00,  1.64s/batch, loss=2.58] 
[Val]   Epoch 29/100: 100%|██████████| 1/1 [00:00<00:00,  1.81batch/s, loss=2.58]

Epoch 028: Train Loss=1.4875 RMSE=0.8437 MAE=0.5400 | Val Loss=2.5826 RMSE=0.9906 MAE=0.6621
Early stopping triggered.





In [None]:
print()