In [None]:
from torch.utils.data import DataLoader
from statsmodels.nonparametric.smoothers_lowess import lowess
import torch 
import gc
import plotly.graph_objects as go
from torch import nn 
from torch.utils.data import Dataset
from torchvision.transforms.v2 import PILToTensor,Compose
from sklearn.model_selection import train_test_split
from datetime import datetime
import re
from torch.utils.tensorboard import SummaryWriter
import torchvision
import pandas as pd
import math
import os
from tqdm import tqdm
from einops import rearrange 
import yaml
import pickle
import matplotlib.pyplot as plt 
from functools import partial, reduce
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import numpy as np
import seaborn as sns
import matplotlib

import panel as pn

import holoviews as hv

import random
import sys
import math
import time
import random
from yellowbrick.text import TSNEVisualizer
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity

from holoviews import dim, opts
from dataclasses import dataclass

sys.path.append(os.path.abspath(os.path.join('..')))
from src.md import MDDataset, generate_phases_for_dense_validation, MDDenseSet, MDLoadable, MDDense, MDDensePhaseData
from src.models.denoiser import DenoiserModelPipeline, DenoiserModel, get_forward_diffusion_params, forward_add_noise
from src.preprocessing import ChromoDataContainer, ChromoData
from src.models.unified_classifier import evaluate_test_set, evaluate_unified_classifier_model, predictCLS, ClassifierModelPipeline


# Select the compute device once for the whole notebook: use CUDA when
# torch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [2]:
# Training function for unified classifier model
def train_unified_classifier_model(
    dataset,
    model,
    run_name = 'unnamed_unified_classifier_model',
    n_epochs = 100,
    batch_size = 32,
    
    timestep = 1, # fixed timestep for the base denoiser model
    label_num = 5, # fixed label num for the base denoiser model
    label_count = 10, # fixed label count for the base denoiser model
    
    lr=10e-3, # note: 10e-3 == 0.01
    validation_portion=0.1,
    validation_per_epoch = 0,
    validation_samples = 1024,
    
    distilled_val_sets = None, # TODO: refactor this to have better generalization
    
    checkpoints_folder='../checkpoints/classifier',
    num_workers = 10,
):
    """Train the classifier head of a unified model on top of a frozen denoiser.

    All parameters of ``model.denoiser_model`` are frozen and all parameters of
    ``model.classifier_head`` are made trainable. Every batch is rescaled to
    ``gxx * 2 - 1``, noised with ``forward_add_noise`` at the fixed ``timestep``
    and passed through ``model(x, t, y)``; the first returned element is
    compared against ``labels_classifier`` with cross-entropy loss.

    Scalars are written to TensorBoard under ``../logs/<timestamp>_<run_name>``
    and the model ``state_dict`` is saved to
    ``<checkpoints_folder>/<run_name>.pth``.

    Args:
        dataset: Object exposing ``split(test_size=...)`` that returns a
            ``(train, validation)`` pair of datasets. Batches must be dicts
            with the keys ``'gxx'`` and ``'labels_classifier'``.
        model: Module exposing ``denoiser_model`` and ``classifier_head``
            sub-modules, callable as ``model(x, t, y)``.
        run_name: Base name for the checkpoint file and the log directory.
        n_epochs: Number of passes over the training split.
        batch_size: Batch size for both data loaders (incomplete last batches
            are dropped).
        timestep: Fixed diffusion timestep used for every sample.
        label_num: Fixed conditioning label passed to the model as ``y``.
        label_count: Forwarded unchanged to the evaluation helpers.
        lr: Learning rate for AdamW.
        validation_portion: Fraction handed to ``dataset.split`` as
            ``test_size``.
        validation_per_epoch: Validate (and checkpoint, for epochs > 0) every
            this many epochs. ``0`` or ``None`` disables periodic validation;
            previously ``0`` raised ``ZeroDivisionError``.
        validation_samples: Validation runs for
            ``validation_samples // batch_size`` iterations.
        distilled_val_sets: Forwarded unchanged to
            ``evaluate_unified_classifier_model``. When ``None`` that extra
            validation step is skipped.
        checkpoints_folder: Directory for checkpoints; created if missing.
        num_workers: Worker processes for both data loaders. Persistent
            workers are enabled only when this is greater than zero.

    Returns:
        None. The trained weights are written to the checkpoint path.
    """
    # Only the cumulative alpha products are used here; the remaining forward
    # diffusion parameters are not needed for classifier training.
    _, _, _, alphas_cumprod, _, _ = get_forward_diffusion_params()
    alphas_cumprod = alphas_cumprod.to(device)

    model = model.to(device)

    print("Freezing denoiser_model")
    # freeze all parameters except classification head
    for param in model.denoiser_model.parameters():
        param.requires_grad = False

    print("Unfreezing classifier_head")
    # Ensure the classification head parameters require gradients
    for param in model.classifier_head.parameters():
        param.requires_grad = True

    # loss for classification
    loss_fn = nn.CrossEntropyLoss()

    # Hand the optimizer only the parameters that are actually trainable so
    # that it matches the freezing done above.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=lr)

    # define scheduler: multiply the learning rate by 0.9 every 10 epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)
    train_ds, val_ds = dataset.split(test_size=validation_portion)

    # creating directory for checkpoints
    os.makedirs(checkpoints_folder, exist_ok=True)

    # setting checkpoint paths
    path_checkpoint = os.path.join(checkpoints_folder, f"{run_name}.pth")

    # Both loaders share the same settings. DataLoader rejects
    # persistent_workers=True when num_workers == 0, hence the guard.
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )
    dataloader = DataLoader(train_ds, **loader_kwargs)
    dataloader_eval = DataLoader(val_ds, **loader_kwargs)

    # NOTE(review): train() also switches the frozen denoiser into training
    # mode (dropout / batch-norm layers, if any, stay active) - confirm that
    # this is intended.
    model.train()
    timestamp_int = int(datetime.now().timestamp())
    log_name = f"{timestamp_int}_{run_name}"
    print(f"Log name: {log_name}")
    writer = SummaryWriter(log_dir=f'../logs/{log_name}')

    # Fixed conditions: with drop_last=True every batch has exactly batch_size
    # samples, so the timestep and label tensors can be built once.
    t = torch.full((batch_size,), timestep, dtype=torch.long, device=device)
    y = torch.full((batch_size,), label_num, dtype=torch.long, device=device)

    iter_count = 0
    try:
        for epoch in tqdm(range(n_epochs)):

            # A falsy validation_per_epoch (0 / None) disables validation
            # instead of dividing by zero.
            if validation_per_epoch and epoch % validation_per_epoch == 0:
                if epoch > 0:
                    torch.save(model.state_dict(), path_checkpoint)
                    print("Model saved")
                print("Validating")

                # VALIDATION
                model.eval()

                with torch.no_grad():

                    # validation on random samples from validation set
                    print("Validation on val samples")
                    result_val = evaluate_test_set(
                        model,
                        predictCLS,
                        dataloader_eval,
                        alphas_cumprod,
                        timestep,
                        label_num,
                        label_count,
                        iterations=validation_samples // batch_size,
                    )
                    writer.add_scalar('Error/val', result_val['error_percentage'], iter_count)

                    # TODO: refactor this to have better generalization
                    # validation on RD (skipped when no distilled sets were given)
                    if distilled_val_sets is not None:
                        result_rel_mu, result_rel_lowess, result_gen_mu, result_gen_lowess = evaluate_unified_classifier_model(
                            model,
                            distilled_val_sets,
                            alphas_cumprod,
                            timestep,
                            label_num,
                            label_count,
                            batch_size,
                            predictCLS
                        )

                        writer.add_scalar('Error/real_mu', result_rel_mu['total_abs_percentage_error'], iter_count)
                        writer.add_scalar('Error/real_lowess', result_rel_lowess['total_abs_percentage_error'], iter_count)
                        writer.add_scalar('Error/gen_mu', result_gen_mu['total_abs_percentage_error'], iter_count)
                        writer.add_scalar('Error/gen_lowess', result_gen_lowess['total_abs_percentage_error'], iter_count)

                # switch back to train
                model.train()

                print("Validation done")

            for datachunk in dataloader:

                gxx = datachunk['gxx'].to(device)
                labels_classifier = datachunk['labels_classifier'].to(device)

                # rescale inputs: gxx * 2 - 1
                x = gxx * 2 - 1
                # the sampled noise itself is not needed for the classifier loss
                x, _ = forward_add_noise(x, t, alphas_cumprod)
                labels_pred, _ = model(x, t, y)

                loss = loss_fn(labels_pred, labels_classifier)

                optimizer.zero_grad()
                loss.backward()
                # grad clip
                # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                iter_count += 1

                if iter_count % 100 == 0:
                    # log a plain float rather than a tensor attached to the graph
                    writer.add_scalar('Loss/train', loss.item(), iter_count)

            scheduler.step()
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()

    print("Training done, saving model")
    torch.save(model.state_dict(), path_checkpoint)

In [None]:
# # TRAINING OF UNIFIED MODEL

# # creating val datasets for real data
chrom_data = ChromoDataContainer.load_from_pkl()

# TODO: Split MD and DenseD to separate notebooks
# 1) MD generates source for Dense MD
# 2) DenseD generates Both Dense MD and Dense RD

# Validation sets derived from the real data, once per do_real_mu setting.
distiled_real_mu = generate_phases_for_dense_validation(chrom_data.gausians, do_real_mu=True)
distiled_real_lowess = generate_phases_for_dense_validation(chrom_data.gausians, do_real_mu=False)

# Pre-generated control sets loaded from disk (mu and lowess variants).
distiled_mock_mu = MDLoadable.load(
    config_name="dataset_128_128_784_classifier_control_set_mu",
    config_folder="../configs/MDDense",
    dataset_folder="../data/MDDense",
    dataset_cls=MDDenseSet)
distiled_mock_lowess = MDLoadable.load(
    config_name="dataset_128_128_784_classifier_control_set_lowess",
    config_folder="../configs/MDDense",
    dataset_folder="../data/MDDense",
    dataset_cls=MDDenseSet)


# # Automated trainer

configs_dir = '../configs/classifier'
# os.listdir returns entries in arbitrary order and also lists hidden entries
# (for example the .ipynb_checkpoints folder Jupyter may create). Sort for a
# reproducible training order and skip dot-prefixed entries so they are not
# handed to the pipeline loader as configs.
configs = sorted(
    entry for entry in os.listdir(configs_dir)
    if not entry.startswith('.')
)

print(f"Found {len(configs)} configs")
for config in configs:
    print(config)
    
pipelines = []

# Build a pipeline for every config that is not flagged as skipped; dataset
# loading is deferred so skipped configs never touch their data.
for config in tqdm(configs):
    print("=====================================")
    classifierModelPipeline = ClassifierModelPipeline.load(config, skip_data_load=True)
    if classifierModelPipeline.config.training.skip:
        print(f"Skipping training for {config}")
        continue
    pipelines.append(classifierModelPipeline)
    print(f"Loaded model {classifierModelPipeline.config.model.description}")
    print(f"Loaded model {config}")
    
print("Done loading models")

# Loading training datasets

for pipeline in pipelines:
    pipeline.load_dataset()
    
print("Done loading datasets")

# Training models

for pipeline in pipelines:
    print("=====================================")
    print(f"Training model {pipeline.config.model.description}")
    print("Training params:")
    print(pipeline.config.training)
    train_unified_classifier_model(
        pipeline.dataset,
        pipeline.model,
        run_name = pipeline.file_name,
        n_epochs = pipeline.config.training.n_epochs,
        lr = pipeline.config.training.learning_rate,
        batch_size=pipeline.config.training.batch_size,
        validation_per_epoch = pipeline.config.training.validation_per_epoch,
        validation_portion = pipeline.config.training.validation_portion,
        
        label_num=pipeline.config.training.label_num,
        timestep=pipeline.config.training.timestep,
        label_count=pipeline.config.training.label_count,
        
        # order: real mu, real lowess, mock mu, mock lowess
        distilled_val_sets=(
            distiled_real_mu, 
            distiled_real_lowess,
            distiled_mock_mu.data, 
            distiled_mock_lowess.data
        ) # refactor this to have better generalization
    )
    print(f"Training model {pipeline.config.model.description} done")
    
