In [None]:

import pickle
import math
import panel as pn
import yaml
import holoviews as hv
import altair as alt
alt.data_transformers.disable_max_rows()
hv.extension("plotly")
pn.extension("plotly")
pn.config.theme = 'dark'
hv.renderer('plotly').theme = 'dark'
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

import os
import seaborn as sns
sns.set_style('darkgrid')
sns.set_palette('muted')
sns.set_context("notebook", font_scale=1.5,
                rc={"lines.linewidth": 2.5})
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from tqdm.auto import tqdm
import torch
import yaml
import torch.nn as nn
import numpy as np
from tqdm import tqdm, trange
import os
from tqdm import tqdm
import sys
from dataclasses import dataclass

# Append the absolute path of '..' (resolved against the current working
# directory) to the import search path so the project-local `src` package
# imported below can be found.
# NOTE(review): assumes the kernel's working directory sits one level below
# the directory that contains `src` — confirm.
sys.path.append(os.path.abspath(os.path.join('..')))
from src.md import MDDataset
from src.models.denoiser import DenoiserModelPipeline, train_diffusion_model

# Use the CUDA device when torch reports one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Automated trainer
#
# Workflow of this cell:
#   1. list the config files found in `configs_dir`,
#   2. build a DenoiserModelPipeline for each config and drop the ones whose
#      `training.skip` flag is set,
#   3. load the dataset of every remaining pipeline,
#   4. train the remaining models one after another.

configs_dir = '../configs/denoiser'

# os.listdir() returns entries in arbitrary order and also reports
# sub-directories and hidden entries (e.g. .DS_Store, .ipynb_checkpoints).
# Sort the names so the training order is reproducible between runs, and keep
# only regular, non-hidden files so stray entries are never handed to the
# pipeline loader.
configs = sorted(
    entry
    for entry in os.listdir(configs_dir)
    if not entry.startswith('.')
    and os.path.isfile(os.path.join(configs_dir, entry))
)

print(f"Found {len(configs)} configs")
for config in configs:
    print(config)

pipelines = []

for config in tqdm(configs):
    print("=====================================")
    # The loader receives the bare file name, exactly as before.
    # NOTE(review): presumably DenoiserModelPipeline.load resolves the name
    # against its own config directory — confirm against its implementation.
    # skip_data_load=True defers dataset loading to the explicit
    # load_dataset() calls below, which run only for pipelines kept here.
    denoiserModelPipeline = DenoiserModelPipeline.load(config, skip_data_load=True)
    if denoiserModelPipeline.config.training.skip:
        print(f"Skipping training for {config}")
        continue
    pipelines.append(denoiserModelPipeline)
    print(f"Loaded model {denoiserModelPipeline.config.model.description}")
    print(f"Loaded model {config}")

print("Done loading models")

# Loading training datasets
# NOTE(review): every dataset is loaded before the first training run starts,
# so all of them are held by `pipelines` at the same time.

for pipeline in pipelines:
    pipeline.load_dataset()

print("Done loading datasets")

# Training models
# Each model is trained with the hyper-parameters from its own
# `config.training` section; `run_name` is the pipeline's file name.

for pipeline in pipelines:
    print(f"Training model {pipeline.config.model.description}")
    print("Training params:")
    print(pipeline.config.training)
    train_diffusion_model(
        pipeline.dataset,
        pipeline.model,
        run_name=pipeline.file_name,
        n_epochs=pipeline.config.training.n_epochs,
        lr=pipeline.config.training.learning_rate,
        batch_size=pipeline.config.training.batch_size,
        validation_per_epoch=pipeline.config.training.validation_per_epoch,
        validation_portion=pipeline.config.training.validation_portion,
        ablation_in_epoch_per_each_epochs=pipeline.config.training.ablation_in_epoch_per_each_epochs,
        ablation_batch_slice=pipeline.config.training.ablation_batch_slice,
    )
    print(f"Training model {pipeline.config.model.description} done")
    