In [None]:
from data import GaussianDistance, CIFData, data_loader

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import DataLoader

# Load each pre-split dataset from disk and wrap it in a loader right away.
# NOTE(review): torch.load unpickles the file contents, so these .pth files
# must come from a trusted source.
# Training split: the only one that is shuffled, and it uses the larger batch.
train_list = torch.load('train_data.pth')
train_loader = data_loader(train_list, batch_size=256, shuffle=True)

# Validation split: fixed order.
val_list = torch.load('val_data.pth')
val_loader = data_loader(val_list, batch_size=64, shuffle=False)

# Test split: fixed order.
test_list = torch.load('test_data.pth')
test_loader = data_loader(test_list, batch_size=64, shuffle=False)


from models.CGCNN import CGCNN
from models.CGAT import CGAT
from models.CGT import CGT

In [None]:

# Architecture to train. The same string labels the loss CSV and the
# checkpoint directory created further down.
name = 'CGCNN' # or 'CGAT' or 'CGT'

# Explicit name -> class table. Resolving the class through this mapping
# (rather than an arbitrary lookup in globals()) limits the choice to the
# architectures imported above and lets a typo fail with a readable message.
_MODEL_CLASSES = {
    'CGCNN': CGCNN,
    'CGAT': CGAT,
    'CGT': CGT,
}
if name not in _MODEL_CLASSES:
    raise ValueError(
        f"Unknown model name {name!r}; expected one of {sorted(_MODEL_CLASSES)}"
    )

# Model class (not an instance); it is instantiated once the device is known.
variable = _MODEL_CLASSES[name]

import csv
import torch
import numpy as np
import random
import torch.optim as optim
from tqdm import tqdm
import torch.nn as nn
import os


# ---------- run configuration ----------
# Use the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
epochs = 300

# Build the selected architecture and move it onto the chosen device.
# out_dim is 201*4 = 804 output values per sample.
model = variable(edge_dim=14, out_dim=201*4, seed=123)
model = model.to(device)

# Dump the bias of every entry in model.layers as a sanity check on the
# freshly initialised network.
for i, layer in enumerate(model.layers):
    print(f"Bias in layer {i+1}: : {layer.bias}")

# Mean-squared-error objective, optimised with AdamW.
criterion = nn.MSELoss()
optimizer = optim.AdamW(params=model.parameters(), lr=1e-5, weight_decay=1e-6)

# Drop any gradients and optimizer state so training starts from a clean slate.
for weight in model.parameters():
    weight.grad = None
optimizer.state.clear()


# Directory that receives the periodic checkpoints. exist_ok=True already
# tolerates an existing directory, so no separate existence check is needed;
# if a regular file occupies the path this raises immediately instead of
# failing later when the first checkpoint is saved.
os.makedirs(f'model_file_{name}', exist_ok=True)

# Start a fresh loss log for this run: mode "w" truncates any previous file,
# and the header row names the columns appended after every epoch.
with open(f'Loss{name}.csv', mode="w", newline="") as csv_file:
    writer = csv.writer(csv_file)
    writer.writerow(["epoch", "train_mse", "val_mse", "train_rmse", "val_rmse"])


def _mean_mse(net, loader, loss_fn, target_device):
    """Return the dataset-averaged loss of `net` over `loader`.

    The network is switched to eval mode and gradients are disabled. Each
    batch's loss (a mean, given the default nn.MSELoss reduction) is weighted
    by the number of graphs in that batch, and the weighted sum is divided by
    len(loader.dataset).

    Args:
        net: model to evaluate; called as net(batch).
        loader: iterable of batches exposing .to(), .batch and .y, with a
            sized .dataset attribute.
        loss_fn: callable taking (prediction, target) and returning a scalar
            tensor.
        target_device: device each batch is moved to before the forward pass.

    Returns:
        The averaged loss as a Python float.
    """
    net.eval()
    weighted_sum = 0.0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(target_device)
            output = net(batch)
            # Number of graphs in this batch.
            # assumes batch.batch holds zero-based graph indices per node
            # (PyG convention) -- TODO confirm against data_loader
            n_graphs = batch.batch.max().item() + 1
            weighted_sum += loss_fn(output, batch.y).item() * n_graphs
    return weighted_sum / len(loader.dataset)


for epoch in range(epochs):
    # ---------- train ----------
    model.train()
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):
        batch = batch.to(device)
        optimizer.zero_grad()
        output = model(batch)
        loss = criterion(output, batch.y)
        loss.backward()
        optimizer.step()

    # ---------- calculate train set loss ----------
    # Measured in a separate eval-mode pass after the epoch's updates, so the
    # figure is computed the same way as the validation loss.
    train_mse_avg = _mean_mse(model, train_loader, criterion, device)
    train_rmse_avg = np.sqrt(train_mse_avg)

    # ---------- calculate validation set loss ----------
    val_mse_avg = _mean_mse(model, val_loader, criterion, device)
    val_rmse_avg = np.sqrt(val_mse_avg)

    # ---------- record ----------
    print(f"Epoch {epoch+1}/{epochs}, Train RMSE: {train_rmse_avg:.4f}, Val RMSE: {val_rmse_avg:.4f}")

    # Append one row per epoch and close the file each time, so the log on
    # disk is complete up to the last finished epoch.
    with open(f'Loss{name}.csv', mode="a", newline="") as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow([epoch+1, train_mse_avg, val_mse_avg, train_rmse_avg, val_rmse_avg])

    if (epoch + 1) % 5 == 0:  # save model weights every 5 epochs
        torch.save(model.state_dict(), f'model_file_{name}/epoch_{epoch+1}.pth')

    