In [None]:
import torch
import pandas as pd
from torch.utils.data import Dataset
import os
from torch_geometric.data import Data
import pickle
import numpy as np
import pytorch_lightning as pl
from typing import Optional, Sequence
from pathlib import Path
import random
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
import csv
from torch_geometric.data import DataLoader
from torch_geometric.data import Batch

from dBandDiff.data_utils import preprocess, preprocess_tensors, add_scaled_lattice_prop, get_scaler_from_data_list
from dBandDiff.data_load import main
from dBandDiff.diffusion import model


def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy's legacy global generator and torch
    (CPU, current CUDA device and all CUDA devices), then forces cuDNN into
    deterministic mode with benchmark autotuning turned off.

    Args:
        seed: integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Autotuning can select different kernels between runs; pin it down.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(666)


In [None]:
# Build the data module through the project's loader entry point, then take
# the training and validation loaders from it (train first, then val).
datamodule = main()
train_loader, val_loader = datamodule.train_dataloader(), datamodule.val_dataloader()

In [None]:
# Sanity check: number of samples in the training split, then the validation split.
for split_loader in (train_loader, val_loader):
    print(len(split_loader.dataset))

In [None]:

# Training configuration. Output paths are relative to the working directory.
LOSS_CSV = 'Loss.csv'          # per-epoch loss log; overwritten at the start of each run
CHECKPOINT_DIR = 'model_file'  # model state_dict snapshots are written here
SAVE_EVERY = 5                 # checkpoint interval, in epochs
num_epochs = 1000

# Entries read from the dict returned by model(batch). This order also fixes
# the order of the train/val columns in LOSS_CSV.
LOSS_KEYS = ('loss', 'loss_lattice', 'loss_coord', 'loss_type')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

optimizer = Adam(model.parameters(), lr=1e-4)
# Multiplies the learning rate by 0.95 every 10 scheduler steps; it is stepped once per epoch below.
scheduler = StepLR(optimizer, step_size=10, gamma=0.95)

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Start a fresh log with just the header row; each epoch appends one row.
with open(LOSS_CSV, mode="w", newline="") as csv_file:
    csv.writer(csv_file).writerow([
        "epoch", "train_loss", "train_loss_lattice", "train_loss_coord", "train_loss_type",
        "val_loss", "val_loss_lattice", "val_loss_coord", "val_loss_type"
    ])


def _run_epoch(loader, train):
    """Run one pass over `loader` and return the mean of each LOSS_KEYS entry.

    With `train=True` the model is put in training mode and updated through the
    global `optimizer` after every batch. With `train=False` it runs in eval mode
    with gradient tracking disabled. Each average is the mean of per-batch
    values, so every batch is weighted equally.

    Args:
        loader: iterable of batches exposing `.to(device)`.
        train: whether to backpropagate and step the optimizer.

    Returns:
        Tuple of floats ordered like LOSS_KEYS.

    Raises:
        ValueError: if `loader` yields no batches (otherwise the average is undefined).
    """
    model.train(train)  # train(False) is equivalent to model.eval()
    totals = [0.0] * len(LOSS_KEYS)
    n_batches = 0
    with torch.set_grad_enabled(train):
        for batch in loader:
            batch = batch.to(device)
            if train:
                optimizer.zero_grad()
            output = model(batch)
            if train:
                output['loss'].backward()
                optimizer.step()
            for i, key in enumerate(LOSS_KEYS):
                totals[i] += output[key].item()
            n_batches += 1
    if n_batches == 0:
        split = 'training' if train else 'validation'
        raise ValueError(f'{split} loader yielded no batches; cannot average losses')
    return tuple(total / n_batches for total in totals)


for epoch in range(num_epochs):
    (epoch_train_loss, epoch_train_loss_lattice,
     epoch_train_loss_coord, epoch_train_loss_type) = _run_epoch(train_loader, train=True)

    scheduler.step()

    (epoch_val_loss, epoch_val_loss_lattice,
     epoch_val_loss_coord, epoch_val_loss_type) = _run_epoch(val_loader, train=False)

    # Reopen in append mode every epoch so rows reach disk even if training is interrupted.
    with open(LOSS_CSV, mode="a", newline="") as csv_file:
        csv.writer(csv_file).writerow([
            epoch + 1,
            epoch_train_loss, epoch_train_loss_lattice, epoch_train_loss_coord, epoch_train_loss_type,
            epoch_val_loss, epoch_val_loss_lattice, epoch_val_loss_coord, epoch_val_loss_type
        ])

    # Periodic snapshot; the final epoch is always saved even if num_epochs is not a multiple of SAVE_EVERY.
    if (epoch + 1) % SAVE_EVERY == 0 or epoch + 1 == num_epochs:
        torch.save(model.state_dict(), f'{CHECKPOINT_DIR}/epoch_{epoch+1}.pth')
