  # Autoencoder - SNNL Dev

# Setup

In [1]:
# Notebook display / reload setup.
# %pwd
# %cd ~/WSL-shared/Cellpainting/pt-snnl/
from IPython.display import display, HTML
# Widen the classic-notebook container to 98% of the browser window.
display(HTML("<style>.container { width:98% !important; }</style>"))
# Re-import edited project modules automatically before each cell execution.
%load_ext autoreload  
%autoreload 2
from IPython.core.interactiveshell import InteractiveShell
# Echo EVERY bare top-level expression in a cell, not just the last one.
# Several later cells rely on this to show intermediate values mid-cell.
InteractiveShell.ast_node_interactivity = "all"

In [2]:
import os
import sys
import csv
import time
import types
import copy
import pprint
import logging
from datetime import datetime
# Make the project-local packages importable: each root is pushed to the
# front of sys.path once, so re-running the cell does not add duplicates.
for src_root in ('./src', '../..'):
    if src_root in sys.path:
        continue
    print(f"insert {src_root}")
    sys.path.insert(0, src_root)
print(sys.path)
from typing import Dict, List, Tuple


from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import scipy
from scipy.sparse import csr_matrix
import torch
from tqdm import tqdm
import torch.nn.functional as F
from torchinfo import summary
import wandb

# Console formatting for pretty-printer, pandas, torch and numpy output.
pp = pprint.PrettyPrinter(indent=4)
pd.options.display.width = 132
torch.set_printoptions(precision=None, threshold=None, edgeitems=None, linewidth=150, profile=None, sci_mode=None)
np.set_printoptions(edgeitems=3, infstr='inf', linewidth=150, nanstr='nan')

# Notebook file name exported for Weights & Biases via WANDB_NOTEBOOK_NAME.
os.environ["WANDB_NOTEBOOK_NAME"] = "AE-SNNL-MAIN.ipynb"
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# os.environ["CUDA_VISIBLE_DEVICES"] = "2"
torch.set_num_threads(4)  ## <--- caps torch CPU threads at 4 (older note said "~ 2 CPUs" -- TODO confirm intent)
# Echo the thread count that is actually in effect.
torch.get_num_threads()

insert ./src
insert ../..
['../..', './src', '/home/kevin/miniforge3/envs/ptsnnl/lib/python311.zip', '/home/kevin/miniforge3/envs/ptsnnl/lib/python3.11', '/home/kevin/miniforge3/envs/ptsnnl/lib/python3.11/lib-dynload', '', '/home/kevin/miniforge3/envs/ptsnnl/lib/python3.11/site-packages', '/home/kevin/miniforge3/envs/ptsnnl/lib/python3.11/site-packages/cytominer_eval-0.1-py3.11.egg']


4

In [3]:
from KevinsRoutines.utils.utils_wandb import init_wandb, wandb_log_metrics,wandb_watch
from KevinsRoutines.utils.utils_general import list_namespace, print_heading, print_underline

import snnl.utils as utils
from snnl.utils import parse_args, load_configuration, set_global_seed, get_device, set_device
from snnl.utils import plot_model_parms, plot_train_history, plot_regression_metrics
from snnl.utils import display_model_summary, display_model_hyperparameters, display_model_gradients, display_model_parameters
from snnl.utils import display_epoch_metrics, display_cellpainting_batch, display_model_state_dict
from snnl.utils import define_autoencoder_model, init_resume_training

In [4]:
# Configure root logging for the notebook and record library versions.
timestamp = datetime.now().strftime('%Y_%m_%d_%H:%M:%S')
logger = logging.getLogger(__name__)

# The log level is taken from the LOG_LEVEL environment variable (default INFO).
# Previously this value was computed but basicConfig was hard-wired to "INFO".
logLevel = os.environ.get('LOG_LEVEL', 'INFO').upper()
if not isinstance(logging.getLevelName(logLevel), int):
    # Unknown level name: fall back to INFO instead of letting basicConfig raise ValueError.
    logLevel = 'INFO'
FORMAT = '%(asctime)s - %(name)s - %(levelname)s: - %(message)s'
logging.basicConfig(level=logLevel, format=FORMAT)

logger.info(f" Excution started : {timestamp} ")
logger.info(f" Pytorch version  : {torch.__version__}  \t\t Number of threads: {torch.get_num_threads()}")
logger.info(f" Scipy version    : {scipy.__version__}  \t\t Numpy version : {np.__version__}")
logger.info(f" WandB version    : {wandb.__version__}  \t\t Pandas version: {pd.__version__}  ")
# logger.info(f" Search path      : {sys.path}")

2024-09-06 17:45:00,757 - __main__ - INFO: -  Excution started : 2024_09_06_17:45:00 
2024-09-06 17:45:00,758 - __main__ - INFO: -  Pytorch version  : 2.2.2  		 Number of threads: 4
2024-09-06 17:45:00,759 - __main__ - INFO: -  Scipy version    : 1.12.0  		 Numpy version : 1.26.4
2024-09-06 17:45:00,760 - __main__ - INFO: -  WandB version    : 0.17.4  		 Pandas version: 2.2.1  


### main(args)

In [5]:
# Run-level settings. Most are interpolated into the emulated command line
# built in the next cell; LATENT_DIM and HIDDEN_1 are only referenced by the
# commented-out checkpoint-name templates visible in this notebook.
GPU_ID              = 0        # --gpu_id
EPOCHS              = 600      # --epochs
COMPOUNDS_PER_BATCH = 200      # --cpb

LATENT_DIM          = 150
HIDDEN_1            = 512

LEARNING_RATE       = 1e-4     # --lr
ADAM_WEIGHT_DECAY   = 1e-3     # --adam_wd
TEMP                = 1.0      # --temp
TEMP_LR             = 0.0      # --temp_lr   (alternative tried: 1.0e-05)
SNNL_FACTOR         = 3.0      # --snnl_factor
# LOSS_FACTOR   = 134.0
# CHECKPT       = f"AE_snnl_dcpb{COMPOUNDS_PER_BATCH}_{LATENT_DIM}Ltnt_{HIDDEN_1}{HIDDEN_2}_{DATE}_LAST_ep_{LOAD_EPOCH:03d}.pt"
# CHECKPT       = "AE_snnl_dcpb200_150Ltnt_512_20240709_2235_LAST_ep_700.pt"
# RUN_ID        = "jzt6ecjz"
# print(CHECKPT)

In [6]:
# if __name__ == "__main__":
# Emulate the script's command line inside the notebook: the option string is
# assembled from the constants defined above, whitespace-split into argv-style
# tokens, parsed, and finally expanded with the YAML configuration.
cli_args = (
    f" --runmode             snnl "
    f" --configuration      ./hyperparameters/ae_snglopt_150_512_cpb.yaml"
    f" --epochs             {EPOCHS} "
    f" --single_loss        "
    f" --prim_opt           "
    f" --temp_annealing     "
    f" --anneal_patience    30"
    f" --adam_wd            {ADAM_WEIGHT_DECAY}"
    f" --lr                 {LEARNING_RATE} "
    f" --temp               {TEMP} "
    f" --snnl_factor        {SNNL_FACTOR}"
    f" --temp_lr            {TEMP_LR}"
    f" --cpb                {COMPOUNDS_PER_BATCH}"
    f" --seed               4321"
    f" --gpu_id             {GPU_ID} "
    f" --wandb              "
    f" "
)

# Options kept for reference (not currently passed):
# f" --loss_factor        {LOSS_FACTOR}"\
# f" --run_id             {RUN_ID} "\
# f" --ckpt               {CHECKPT} " \
# f" --temp_opt " \
# f" --ckpt               AE_snnl_dcpb200_{LATENT_DIM}Ltnt_{HIDDEN_1}{HIDDEN_2}_{DATE}_LAST_ep_{LOAD_EPOCH:03d}.pt " \
# f" --ckpt               AE_baseline_{DATE}_snglOpt-{LATENT_DIM}Ltnt{HIDDEN_1}{HIDDEN_2}_ep_{LOAD_EPOCH}.pt " \
# f" --configuration      hyperparameters/ae_sn_{LATENT_DIM:03d}{HIDDEN_1}{HIDDEN_2}_cpb.yaml" \
# f" --exp_title           snglOpt-050Ltnt_512_sig "
# f" --runmode            snnl" \

# From here on `cli_args` is the parsed Namespace, no longer the raw string.
cli_args = parse_args(cli_args.split())
cli_args

# `args` is the object every later cell reads its settings from.
args = load_configuration(cli_args)
args.ckpt

Namespace(configuration='./hyperparameters/ae_snglopt_150_512_cpb.yaml', ckpt=None, cpb=200, exp_title=None, epochs=600, gpu_id=0, learning_rate=0.0001, exp_id=None, runmode='snnl', random_seed=4321, use_prim_optimizer=True, use_temp_optimizer=False, use_annealing=True, anneal_patience=30, use_single_loss=True, temperature=1.0, adam_weight_decay=0.001, loss_factor=None, snnl_factor=3.0, temperatureLR=0.0, WANDB_ACTIVE=True)

2024-09-06 17:45:07,438 - snnl.utils.utils_ptsnnl - INFO: -  command line param configuration             : [./hyperparameters/ae_snglopt_150_512_cpb.yaml]
2024-09-06 17:45:07,439 - snnl.utils.utils_ptsnnl - INFO: -  command line param ckpt                      : [None]
2024-09-06 17:45:07,440 - snnl.utils.utils_ptsnnl - INFO: -  command line param cpb                       : [200]
2024-09-06 17:45:07,440 - snnl.utils.utils_ptsnnl - INFO: -  command line param exp_title                 : [None]
2024-09-06 17:45:07,441 - snnl.utils.utils_ptsnnl - INFO: -  command line param epochs                    : [600]
2024-09-06 17:45:07,442 - snnl.utils.utils_ptsnnl - INFO: -  command line param gpu_id                    : [0]
2024-09-06 17:45:07,442 - snnl.utils.utils_ptsnnl - INFO: -  command line param learning_rate             : [0.0001]
2024-09-06 17:45:07,443 - snnl.utils.utils_ptsnnl - INFO: -  command line param exp_id                    : [None]
2024-09-06 17:45:07,444 - snnl.utils.utils

In [7]:
# Print every field currently held by `args` (output of load_configuration).
list_namespace(args)


command line parms : 
-----------------------
SGD_momentum.............  0
SGD_weight_decay.........  0
WANDB_ACTIVE.............  True
adam_weight_decay........  0.001
anneal_patience..........  30
batch_size...............  1

    cellpainting_args   (dict)
    ----------------------------
    batch_size...............  1
    chunksize................  None
    compounds_per_batch......  200
    conversions..............  None
    iterator.................  True
    sample_size..............  3
    test_end.................  33600
    test_path................  /home/kevin/WSL-shared/cellpainting/cj-datasets/output_11102023/3_sample_profiles/3sample_profiles_1482_HashOrder_test.csv
    test_start...............  30000
    train_end................  240000
    train_start..............  0
    training_path............  /home/kevin/WSL-shared/cellpainting/cj-datasets/output_11102023/3_sample_profiles/3sample_profiles_1482_HashOrder.csv
    val_end..................  24000
    val_star

In [8]:
### Set random seed and gpu device
set_global_seed(args.random_seed)

# When resuming, fail fast if the requested checkpoint is missing from ./ckpts.
if args.ckpt is not None:
    ckpt_path = os.path.join('ckpts', args.ckpt)
    if not os.path.exists(ckpt_path):
        logger.error(f"*** Checkpoint {args.ckpt} not found *** \n")
        raise ValueError(f"\n *** Checkpoint DOESNT EXIST *** \n")
    logger.info(f"Checkpoint {args.ckpt} found")
    logger.info(f"Resuming training using checkpoint: {args.ckpt}")

# Select the requested GPU and remember it on `args` for later cells.
if args.gpu_id is not None:
    _ = get_device(verbose=True)
    args.current_device = set_device(args.gpu_id)
    print(f" args.current_device is : {args.current_device}")

 device: 0   Quadro GV100                   :  free: 33,744,814,080 B   (31.43 GB)    total: 34,069,872,640 B   (31.73 GB)
 device: 1   Quadro GV100                   :  free: 33,744,814,080 B   (31.43 GB)    total: 34,069,872,640 B   (31.73 GB)


2024-09-06 17:45:10,791 - snnl.utils.utils_ptsnnl - INFO: -  Current CUDA Device is:  cuda:0 - Quadro GV100
2024-09-06 17:45:10,792 - snnl.utils.utils_ptsnnl - INFO: -  Switched to: Quadro GV100 - 0


 device: 2   NVIDIA TITAN Xp                :  free: 12,193,497,088 B   (11.36 GB)    total: 12,774,539,264 B   (11.90 GB)
 args.current_device is : cuda:0


  ### WandB Setup

In [9]:
# Create the Weights & Biases run handle; it is closed later with wandb_run.finish().
wandb_run = utils.setup_wandb(args)

2024-09-06 17:45:12,057 - snnl.utils.utils_notebook - INFO: - ***** Initialize NEW  W&B Run *****


None, AE_20240906_1745, CellPainting_Profiles


[34m[1mwandb[0m: Currently logged in as: [33mkbardool[0m. Use [1m`wandb login --relogin`[0m to force relogin


2024-09-06 17:45:19,224 - snnl.utils.utils_notebook - INFO: -  Experiment Name  : AE_20240906_1745
2024-09-06 17:45:19,226 - snnl.utils.utils_notebook - INFO: -  Experiment Date  : 20240906_1745
2024-09-06 17:45:19,228 - snnl.utils.utils_notebook - INFO: - ***** Initialize NEW  W&B Run *****
2024-09-06 17:45:19,229 - snnl.utils.utils_notebook - INFO: - WANDB_ACTIVE     : True
2024-09-06 17:45:19,230 - snnl.utils.utils_notebook - INFO: - Project Name     : CellPainting_Profiles
2024-09-06 17:45:19,231 - snnl.utils.utils_notebook - INFO: - Experiment Id    : 28f3bwe5
2024-09-06 17:45:19,233 - snnl.utils.utils_notebook - INFO: - Experiment Name  : AE_20240906_1745
2024-09-06 17:45:19,233 - snnl.utils.utils_notebook - INFO: - Experiment Date  : 20240906_1745
2024-09-06 17:45:19,234 - snnl.utils.utils_notebook - INFO: - Experiment Title : dcpb200_150Ltnt_512
2024-09-06 17:45:19,235 - snnl.utils.utils_notebook - INFO: - Experiment Notes : AE snnl - DualOpt, 150 dim latent, 512 Midlayer, 200 

In [10]:
# if args.WANDB_ACTIVE:
#     wandb_run.finish()
#     WANDB_ACTIVE = False

  ### Define dataset and dataloaders

In [11]:
# args.cellpainting_args

#### Load CellPainting Dataset
# Only the training and validation loaders are built; the second statement
# echoes their lengths as a (train, val) tuple.
data_loaders = utils.build_dataloaders(args, data=['train', 'val'])
tuple(len(data_loaders[split]) for split in ('train', 'val'))

2024-09-06 17:45:19,384 - root - INFO: -  load cellpainting
2024-09-06 17:45:19,387 - snnl.utils.dataloader - INFO: -  Building CellPantingDataset for train
2024-09-06 17:45:19,388 - snnl.utils.dataloader - INFO: -  filename:  /home/kevin/WSL-shared/cellpainting/cj-datasets/output_11102023/3_sample_profiles/3sample_profiles_1482_HashOrder.csv
2024-09-06 17:45:19,389 - snnl.utils.dataloader - INFO: -  type    :  train
2024-09-06 17:45:19,390 - snnl.utils.dataloader - INFO: -  start   :  0
2024-09-06 17:45:19,391 - snnl.utils.dataloader - INFO: -  end     :  240000
2024-09-06 17:45:19,392 - snnl.utils.dataloader - INFO: -  numrows :  240000
2024-09-06 17:45:19,393 - snnl.utils.dataloader - INFO: -  names   :  None     usecols :  None
2024-09-06 17:45:19,395 - snnl.utils.dataloader - INFO: -  batch_size  :  1
2024-09-06 17:45:19,396 - snnl.utils.dataloader - INFO: -  sample_size :  3
2024-09-06 17:45:19,396 - snnl.utils.dataloader - INFO: -  compounds_per_batch :  200
2024-09-06 17:45:19,

 load cellpainting
 Dataset size: 240000   rows per batch: 600
 Dataset size: 24000   rows per batch: 600


(240000, 24000)

# Define autoencoder model

### Override arguments

In [12]:
# args.temperature   = 1.0
# args.loss_factor   = 1.0       ## 1.0e+00
# args.learning_rate = 1.0e-03    ## 0.001
# args.temperatureLR = 0.0e-04    ## 1e-4

# Report the settings that will drive model construction and training.
# Each row is (label printed verbatim, attribute read from `args`).
_arg_report = (
    ("   runmode               : ", "runmode"),
    ("   embedding_layer       : ", "embedding_layer"),
    ("   Latent dim            : ", "code_units"),
    ("   loss_factor           : ", "loss_factor"),
    ("   snnl_factor           : ", "snnl_factor"),
    ("   temperature           : ", "temperature"),
    ("   learning_rate         : ", "learning_rate"),
    ("   temperatureLR:        : ", "temperatureLR"),
    ("   use_annealing:        : ", "use_annealing"),
    ("   anneal_patience:      : ", "anneal_patience"),
    ("   use Primary Optimizer : ", "use_prim_optimizer"),
    ("   use Primary Scheduler : ", "use_prim_scheduler"),
)
for _label, _attr in _arg_report:
    print(f"{_label}{getattr(args, _attr)}")


   runmode               : snnl
   embedding_layer       : 4
   Latent dim            : 150
   loss_factor           : 1.0
   snnl_factor           : 3.0
   temperature           : 1.0
   learning_rate         : 0.0001
   temperatureLR:        : 0.0
   use_annealing:        : True
   anneal_patience:      : 30
   use Primary Optimizer : True
   use Primary Scheduler : True


In [13]:
# Print every field of `args` again, right before the model is built.
list_namespace(args)


command line parms : 
-----------------------
SGD_momentum.............  0
SGD_weight_decay.........  0
WANDB_ACTIVE.............  True
adam_weight_decay........  0.001
anneal_patience..........  30
batch_size...............  1

    cellpainting_args   (dict)
    ----------------------------
    batch_size...............  1
    chunksize................  None
    compounds_per_batch......  200
    conversions..............  None
    iterator.................  True
    sample_size..............  3
    test_end.................  33600
    test_path................  /home/kevin/WSL-shared/cellpainting/cj-datasets/output_11102023/3_sample_profiles/3sample_profiles_1482_HashOrder_test.csv
    test_start...............  30000
    train_end................  240000
    train_start..............  0
    training_path............  /home/kevin/WSL-shared/cellpainting/cj-datasets/output_11102023/3_sample_profiles/3sample_profiles_1482_HashOrder.csv
    val_end..................  24000
    val_star

In [14]:
### Define Model
# The bare expression echoes the run mode (the notebook is configured to
# display every top-level expression).
args.runmode
# Build the autoencoder from the settings in `args`.
model = define_autoencoder_model(args, verbose = True)

'snnl'

2024-09-06 17:48:23,985 - snnl.utils.utils_notebook - INFO: - Defining model in SNNL mode 
2024-09-06 17:48:23,991 - snnl.models.autoencoder - INFO: -     layer pair:    0  type:linear           input:   1471  output:   1024    weights: [1024, 1471]   
2024-09-06 17:48:24,008 - snnl.models.autoencoder - INFO: -     layer pair:    1  type:relu             input:      0  output:      0    weights: [0, 0]   
2024-09-06 17:48:24,009 - snnl.models.autoencoder - INFO: -     layer pair:    2  type:linear           input:   1024  output:    512    weights: [512, 1024]   
2024-09-06 17:48:24,016 - snnl.models.autoencoder - INFO: -     layer pair:    3  type:relu             input:      0  output:      0    weights: [0, 0]   
2024-09-06 17:48:24,017 - snnl.models.autoencoder - INFO: -     layer pair:    4  type:linear           input:    512  output:    150    weights: [150, 512]   
2024-09-06 17:48:24,019 - snnl.models.autoencoder - INFO: -     layer pair:    5  type:linear           input:    

 EMBEDDING LAYER: 4
 Device cuda:0 will be used

------------------------------------------------------------
 Building Base Model from NOTEBOOK
------------------------------------------------------------
    Model_init()_    -- mode:              latent_code
    Model_init()_    -- Unsupervised :     True
    Model_init()_    -- Support for unsupervised training  in 'latent_code' mode is True
    Model_init()_    -- Criterion:         MSELoss()
    Model_init()_    -- use_snnl :         True
    Model_init()_    -- temperature :      Parameter containing:
tensor([1.])
    Model_init()_    -- temperature LR:    0.0

------------------------------------------------------------
 Building Autoencoder from NOTEBOOK
------------------------------------------------------------
setup_prim_optimizer()
    AE init() -- mode               : latent_code
    AE init() -- unsupervised       : True
    AE init() -- layer_types        : ['linear', 'relu', 'linear', 'relu', 'linear', 'linear', 'relu'

## Resume from model checkpoint (if args.ckpt is not None)

In [15]:
# Attach the v5 checkpoint load/save callables to `args`.
# Presumably init_resume_training / model.fit pick them up from there -- TODO confirm.
args.load_checkpoint = utils.load_checkpoint_v5
args.save_checkpoint = utils.save_checkpoint_v5
# print(f" Checkpoint File :  {args.ckpt}")

# Returns the model to train (a fresh run here: args.ckpt echoed as None above).
model = init_resume_training(model, args)

if args.WANDB_ACTIVE:
    # Register the model with W&B and push the run configuration.
    # NOTE(review): `args` now also carries two function objects (set above);
    # confirm wandb.config.update handles them as intended.
    wandb_watch(item = model, criterion=None, log = 'all', log_freq = 1000, log_graph = False)
    wandb.config.update(args)

2024-09-06 17:48:28,463 - root - INFO: -  INITIALIZE TRAINING - Run 600 epochs: epoch 1 to 600 


In [16]:
# display_model_state_dict(model, 'test')
# display_model_parameters(model, 'title')
# display_model_gradients(model, 'test')
# display_model_hyperparameters(model)

# for k,v in model.__dict__.items():
#     if k == 'training_history':
#         print(f" {k:30}  type: {str(type(v)):25s}   values: {v.keys()} ")
#     else:
#         print(f" {k:30}  type: {str(type(v)):25s}   values: {v} ")

# display_model_hyperparameters(model)

# model.optimizers['prim'].state_dict()
# model.schedulers['prim'].state_dict()

# print(f" Optimizer ParamGroup[0]   : {model.optimizers['prim'].param_groups[0]}")
# for k,v in model.optimizers.items():
#     print(k, v)

# for k,v in enumerate(model.optimizers['prim'].param_groups[0]['params']):
#     print(f" {k}, {v.shape}")

# model.display_values('test')
# model.display_gradients('test')
# model.optimizers['prim'].state_dict()

# Autoencoder Fit 

In [17]:
# model.starting_epoch = 0
# Manually override the last epoch of this training leg.
# NOTE(review): 400 differs from EPOCHS (600) passed on the emulated command line -- confirm intended.
model.ending_epoch = 400
# model.starting_epoch, model.ending_epoch = 700, 705
# Echo the (start, end) window that model.fit is about to use.
model.starting_epoch, model.ending_epoch

print(f" {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}  epoch {model.starting_epoch+1:4d} of {model.ending_epoch:4d}")

(0, 400)

 2024-09-06 17:48:40  epoch    1 of  400


In [None]:
# Train the model, printing wall-clock timestamps before and after the call.
print(f" {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}  Start: epoch {model.starting_epoch+1:4d} of {model.ending_epoch:4d}")
model.fit(args, data_loaders)
# model.epoch is read after fit returns; presumably the 0-based index of the last epoch run -- TODO confirm.
print(f" {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}  End: epoch {model.epoch + 1:4d} of {model.ending_epoch:4d}")


 2024-09-06 17:49:18  Start: epoch    1 of  400
                                                                                                                                                                                                              

2024-09-06 17:52:10,023 - snnl.utils.utils_ptsnnl - INFO: -  Model exported to ckpts/AE_snnl_dcpb200_150Ltnt_512_20240906_1745_BEST.pt - epoch: 1


  time   ep / eps |  Trn_loss   Primary      SNNL  |   temp*         grad    |   R2      BestEp         |  Vld_loss   Primary      SNNL  |   R2       BestEp          |   LR        temp LR    |
------------------+--------------------------------+-------------------------+--------------------------+--------------------------------+----------------------------|------------------------|
17:52:09  1 /400  |  17.2008     0.8441    16.3567 |   1.000000   0.0000e+00 | -12.5555     1           |  16.5235     0.5711    15.9524 |  -3.3140     1             |  1.000e-04   0.000e+00 |
17:55:02  2 /400  |  16.1772     0.5046    15.6726 |   1.000000   0.0000e+00 |  -1.8010     2           |  16.0057     0.4415    15.5642 |  -0.5388     2             |  1.000e-04   0.000e+00 |              


2024-09-06 17:55:02,700 - snnl.utils.utils_ptsnnl - INFO: -  Model exported to ckpts/AE_snnl_dcpb200_150Ltnt_512_20240906_1745_BEST.pt - epoch: 2


17:57:53  3 /400  |  15.8851     0.4171    15.4680 |   1.000000   0.0000e+00 |  -0.2239     3           |  15.8440     0.3863    15.4576 |   0.1213     3             |  1.000e-04   0.000e+00 |              


2024-09-06 17:57:53,897 - snnl.utils.utils_ptsnnl - INFO: -  Model exported to ckpts/AE_snnl_dcpb200_150Ltnt_512_20240906_1745_BEST.pt - epoch: 3


18:00:45  4 /400  |  15.7549     0.3758    15.3790 |   1.000000   0.0000e+00 |   0.2707     4           |  15.7637     0.3588    15.4049 |   0.3820     4             |  1.000e-04   0.000e+00 |              


2024-09-06 18:00:46,410 - snnl.utils.utils_ptsnnl - INFO: -  Model exported to ckpts/AE_snnl_dcpb200_150Ltnt_512_20240906_1745_BEST.pt - epoch: 4


18:03:36  5 /400  |  15.6789     0.3542    15.3248 |   1.000000   0.0000e+00 |   0.4666     5           |  15.7135     0.3426    15.3708 |   0.5029     5             |  1.000e-04   0.000e+00 |              


2024-09-06 18:03:37,354 - snnl.utils.utils_ptsnnl - INFO: -  Model exported to ckpts/AE_snnl_dcpb200_150Ltnt_512_20240906_1745_BEST.pt - epoch: 5


18:06:27  6 /400  |  15.6253     0.3410    15.2842 |   1.000000   0.0000e+00 |   0.5485     6           |  15.6779     0.3326    15.3453 |   0.5564     6             |  1.000e-04   0.000e+00 |              


2024-09-06 18:06:27,803 - snnl.utils.utils_ptsnnl - INFO: -  Model exported to ckpts/AE_snnl_dcpb200_150Ltnt_512_20240906_1745_BEST.pt - epoch: 6


18:09:20  7 /400  |  15.5829     0.3326    15.2503 |   1.000000   0.0000e+00 |   0.5857     7           |  15.6507     0.3263    15.3244 |   0.5802     7             |  1.000e-04   0.000e+00 |              


2024-09-06 18:09:20,641 - snnl.utils.utils_ptsnnl - INFO: -  Model exported to ckpts/AE_snnl_dcpb200_150Ltnt_512_20240906_1745_BEST.pt - epoch: 7


18:12:12  8 /400  |  15.5462     0.3270    15.2192 |   1.000000   0.0000e+00 |   0.6047     8           |  15.6255     0.3217    15.3038 |   0.5907     8             |  1.000e-04   0.000e+00 |              


2024-09-06 18:12:12,868 - snnl.utils.utils_ptsnnl - INFO: -  Model exported to ckpts/AE_snnl_dcpb200_150Ltnt_512_20240906_1745_BEST.pt - epoch: 8


18:15:03  9 /400  |  15.5128     0.3233    15.1896 |   1.000000   0.0000e+00 |   0.6118     9           |  15.6030     0.3183    15.2847 |   0.5934     9             |  1.000e-04   0.000e+00 |              


2024-09-06 18:15:03,398 - snnl.utils.utils_ptsnnl - INFO: -  Model exported to ckpts/AE_snnl_dcpb200_150Ltnt_512_20240906_1745_BEST.pt - epoch: 9


18:17:53 10 /400  |  15.4822     0.3207    15.1616 |   1.000000   0.0000e+00 |   0.6140    10           |  15.5858     0.3163    15.2695 |   0.5936    10             |  1.000e-04   0.000e+00 |              


2024-09-06 18:17:53,690 - snnl.utils.utils_ptsnnl - INFO: -  Model exported to ckpts/AE_snnl_dcpb200_150Ltnt_512_20240906_1745_BEST.pt - epoch: 10


18:20:45 11 /400  |  15.4543     0.3190    15.1353 |   1.000000   0.0000e+00 |   0.6144    11           |  15.5718     0.3150    15.2569 |   0.5914    10             |  1.000e-04   0.000e+00 |              
18:23:35 12 /400  |  15.4294     0.3182    15.1112 |   1.000000   0.0000e+00 |   0.6146    12           |  15.5606     0.3145    15.2460 |   0.5932    10             |  1.000e-04   0.000e+00 |              
18:26:22 13 /400  |  15.4066     0.3177    15.0889 |   1.000000   0.0000e+00 |   0.6143    12           |  15.5520     0.3143    15.2377 |   0.5951    13             |  1.000e-04   0.000e+00 |              


2024-09-06 18:26:23,214 - snnl.utils.utils_ptsnnl - INFO: -  Model exported to ckpts/AE_snnl_dcpb200_150Ltnt_512_20240906_1745_BEST.pt - epoch: 13


18:29:14 14 /400  |  15.3857     0.3176    15.0681 |   1.000000   0.0000e+00 |   0.6145    12           |  15.5447     0.3141    15.2306 |   0.5956    14             |  1.000e-04   0.000e+00 |              


2024-09-06 18:29:14,853 - snnl.utils.utils_ptsnnl - INFO: -  Model exported to ckpts/AE_snnl_dcpb200_150Ltnt_512_20240906_1745_BEST.pt - epoch: 14


18:32:07 15 /400  |  15.3666     0.3178    15.0488 |   1.000000   0.0000e+00 |   0.6140    12           |  15.5400     0.3143    15.2257 |   0.5959    15             |  1.000e-04   0.000e+00 |              


2024-09-06 18:32:07,881 - snnl.utils.utils_ptsnnl - INFO: -  Model exported to ckpts/AE_snnl_dcpb200_150Ltnt_512_20240906_1745_BEST.pt - epoch: 15


 Trn 16/400:  10%|███████████                                                                                                         | 38/400 [00:16<02:25,  2.49it/s, Losses - Recon=0.28692, SNNL=15.04265]

In [None]:
# Second, identical fit cell: runs model.fit again with whatever
# model.starting_epoch / model.ending_epoch currently hold.
print(f" {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}  Start: epoch {model.starting_epoch+1:4d} of {model.ending_epoch:4d}")
model.fit(args, data_loaders)
print(f" {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}  End: epoch {model.epoch + 1:4d} of {model.ending_epoch:4d}")


In [None]:
# model.use_temp_optimizer
# take_checkpoint(model, args, epoch=306, update_best = True)
# model.use_temp_optimizer
# model.temp_params
# model.anneal_patience = 30
# model.optimizers['prim'].param_groups[0]
# torch.get_num_threads()
# torch.set_num_threads(3)

In [None]:
# Echo the model's epoch bookkeeping: current, start and end epoch.
model.epoch, model.starting_epoch, model.ending_epoch

In [None]:
# Re-print the per-epoch metrics line for each epoch index below model.ending_epoch.
# NOTE(review): if training stopped early (model.epoch < model.ending_epoch) this
# asks for epochs with no recorded history -- confirm display_epoch_metrics copes.
for epoch in range(model.ending_epoch):
    display_epoch_metrics(model, epoch)

In [None]:
# model.starting_epoch = 100
# model.ending_epoch = 200
# model.starting_epoch = 503
# Slide the training window forward: start where the previous leg ended and
# schedule 200 further epochs.
# NOTE: not idempotent -- every execution of this cell advances the window again.
model.starting_epoch = model.ending_epoch
model.ending_epoch += 200

print(f" {datetime.now().strftime('%Y%m%d_%H%M%S')}  epoch {model.starting_epoch+1:4d} of {model.ending_epoch:4d}")

### Close WandB Logging

In [None]:
# Finish the W&B run once and clear the flag so this cell is a no-op if re-run.
if args.WANDB_ACTIVE:
    wandb_run.finish()
    args.WANDB_ACTIVE = False

## Misc stuff

In [None]:
# for p in model.layers:
#     if hasattr(p, 'weight'):
#         p.weight.shape, p.bias.shape
# model.training_history['trn'].keys()
# # model.training_history['trn']['L04_W_grad']
# model.optimizers['prim']
# model.state_dict()['layers.0.weight']
# model.optimizers['prim'].zero_grad()
# layers.0.weight
# display_model_gradients(model, 'test')
# model.state_dict()['layers.0.weight'].requires_grad
# model.temperature.data
# model.snnl_criterion.temperature

In [None]:
# display_epoch_metrics(model, model.epoch, model.ending_epoch, header = True)

In [None]:
# model.starting_epoch, model.ending_epoch = 0, 2
# epoch = 1
# model.temperature = torch.abs(model.temperature)
# model.temperature = torch.nn.Parameter(torch.clip(model.temperature, 0.001, None))
# model.snnl_criterion.temperature= torch.nn.Parameter(torch.clip(model.snnl_criterion.temperature, 0.001, None))

In [None]:
# model.state_dict().keys()
# # for k in ['temperature', 'snnl_criterion.temperature'
# for k in ['temperature', 'snnl_criterion.temperature','layers.0.weight', 'layers.0.bias','layers.2.weight','layers.4.weight','layers.4.bias','layers.5.weight','layers.5.bias','layers.7.weight','layers.9.weight','layers.9.bias',]:
#     if model.state_dict()[k].ndim > 1:
#         print(f" {k+' - '+str(model.state_dict()[k].shape):45s} - {model.state_dict()[k][:3,:3].reshape((-1)).data}")
#     else:
#         print(f" {k+' - '+str(model.state_dict()[k].shape):45s} - {model.state_dict()[k][:9].data}")


# Plot losses, weights, biases and gradients

In [None]:
# if 0:
#     model_attributes = model.__dict__
#     model.training_history['train'].keys()
#     for key, value in model.training_history['train'].items():
#         if isinstance(value, List) or key in ["test_accuracy", "test_f1"]:
#             print(f"{key:25s} {type(value)}  {len(value):7d}  {value[-5:]}")
#     print()
#     for key, value in model.training_history['val'].items():
#         if isinstance(value, List) or key in ["test_accuracy", "test_f1"]:
#             print(f"{key:25s} {type(value)}  {len(value):7d}  {value[-5:]}")        

# tmp = np.array(model.training_history['train']['temp_grads'])
# for st in range(0,len(tmp), 1000):
#     end = st + 1000
#     print(f" {st:5d} - {end:5d}  min: {tmp[st:end].min():9e}   max: {tmp[st:end].max():9e}    avg: {tmp[st:end].mean():9e}   std: {tmp[st:end].std():9e}")

In [None]:
# Plot the recorded training history (project helper from snnl.utils).
# `start` presumably selects the first epoch plotted -- TODO confirm.
plot_train_history(model, start=0, n_bins = 25)

In [None]:
# Plot the regression metrics (project helper from snnl.utils), with start = 4.
# Presumably skips the first epochs -- TODO confirm.
plot_regression_metrics(model,start = 4, n_bins = 25)

In [None]:
# plot_train_metrics(model, n_bins = 25)
# plot_regression_metrics(model,start = 4, n_bins = 25)

In [None]:
# Training-history plot with start = 0 (same arguments as the cell further up).
plot_train_history(model, start= 0, n_bins = 25)

In [None]:
# Training-history plot with start = 0; exact repeat of the call in the preceding cell.
plot_train_history(model, start= 0, n_bins = 25)

In [None]:
# Training-history plot with start = 100.
plot_train_history(model, start= 100, n_bins = 25)

In [None]:
# Training-history plot with start = 10.
plot_train_history(model, start= 10, n_bins = 25)

In [None]:
# Training-history plot with start = 10; exact repeat of the call in the preceding cell.
plot_train_history(model, start= 10, n_bins = 25)

In [None]:
# Plot model parameters (project helper from snnl.utils), passing the model's current epoch counter.
plot_model_parms(model, epochs= model.epoch, n_bins = 15)

# Load model

In [None]:
# Components used by the following cells to build a checkpoint file name / glob pattern.
ex_name = 'AE'
ex_epoch = 50
ex_runmode = 'snnl'
ex_date = '20240718'
ex_time = '1956'
# Title and compounds-per-batch are taken from the current run's `args`.
ex_title = args.exp_title
compounds_per_batch = args.cpb
print(args.exp_title)
# runmode = 'snnl'
# ex_date = '20240516'
# ex_title = args.exp_title
# ex_epoch = 200

In [None]:
# filename = f"{model.name}_{args.runmode}_{exp_date}_{exp_title}_ep_{exp_epoch:03d}"
# filename = f"{model.name}_{ex_runmode}_{ex_date}_{ex_title}_{epochs:03d}_cpb_{ex_cpb}_factor_{ex_factor:d}.pt"
# file_pattern = f"{model.name}_{ex_runmode}_{ex_date}_{ex_title}_*_cpb_{ex_cpb}_factor_{ex_factor:d}.pt"
# Explicit per-epoch file name, built from the ex_* parts.
filename = f"{ex_name}_{ex_runmode}_{ex_title}_{ex_date}_{ex_time}_{ex_epoch:03d}.pt"
print(filename)
# Glob pattern for the run's BEST checkpoint(s); swap in the LAST variant below if needed.
file_pattern = f"{ex_name}_{ex_runmode}_{ex_title}_{ex_date}_{ex_time}_BEST*.pt"
# file_pattern = f"{ex_name}_{ex_runmode}_{ex_title}_{ex_date}_{ex_time}_LAST*.pt"
print(file_pattern)

In [None]:
import glob

# Checkpoint file names (relative to ./ckpts) matching file_pattern, in sorted order.
filelist = sorted(glob.glob(file_pattern, root_dir='./ckpts'))
filelist
# epochlist =sorted([int(x[-6:-3]) for x in filelist])
# epochlist

In [None]:
# runmode = 'snnl'
# ex_date = '20240516'
# ex_title = args.exp_title
ex_epoch = 200

# Use the first checkpoint found by the glob above.
# (A previous f-string assignment here referenced the undefined names
#  exp_date / exp_title / exp_epoch -- a NameError -- and its value was
#  overwritten on the very next line, so it has been dropped.)
if not filelist:
    raise FileNotFoundError(f"No checkpoint matching {file_pattern!r} found under ./ckpts")
filename = filelist[0]
if not filename.endswith('.pt'):
    filename += '.pt'
print(filename)

# Report whether the selected file is present in the checkpoint directory.
if os.path.exists(os.path.join('ckpts', filename)):
    print(f"\n *** Checkpoint EXISTS *** \n")
else:
    print(f"\n *** Checkpoint DOESNT EXIST *** \n")

In [None]:
# mdl , last_epoch = load_checkpoint_v2(model, filename)
# Load the selected checkpoint; the helper returns three values, unpacked as
# the model, the last epoch, and a third object bound to mdl_ckpt.
mdl, last_epoch, mdl_ckpt = utils.load_checkpoint_v5(model, filename)
print(f" last epoch : {last_epoch}")

In [None]:
# Compare the device selected for this session with the loaded model's device.
print(args.current_device)
print(mdl.device)
# model.device = current_device
# Same value printed again: the assignment above is commented out, so nothing changed in between.
print(mdl.device)

In [None]:
# Put the model in training mode and move it onto the session's CUDA device.
# The device handle is stored on `args` by the seed/device cell; the bare name
# `current_device` used here before is never assigned in this notebook.
print(args.current_device)
model.train()
model = model.cuda(device=args.current_device)

In [None]:
# Echo the notebook-level epoch range as a tuple.
# NOTE(review): starting_epoch / epochs are first assigned two cells further
# down, so on a top-to-bottom run this raises NameError -- confirm cell order.
starting_epoch, epochs,

In [None]:
# last_epoch is the second value returned by load_checkpoint_v5 in the load cell.
print(f" last epoch  {last_epoch}")

In [None]:
# starting_epoch = 20
# Plan to continue for 100 epochs after the checkpoint's last epoch.
# NOTE(review): these are plain notebook variables; the fit cells print
# model.starting_epoch / model.ending_epoch -- confirm whether they must be copied onto the model.
starting_epoch = last_epoch
epochs = last_epoch + 100
# starting_epoch = epoch + 1
print(f" run epochs {starting_epoch+1} to {epochs} ")

In [None]:
# Summarise the model's device, loss weighting and primary-optimizer settings.
print()
print(f" model device              : {model.device}")
print(f" model temperature         : {model.temperature}")
print(f" model use prim_optimizer  : {model.use_prim_optimizer}")
print(f" model use prim_scheduler  : {model.use_prim_scheduler}")
print()
print(f" loss_factor               : {model.loss_factor}")
print(f" monitor_grads_layer       : {model.monitor_grads_layer}")
print(f" Learning rate             : {model.optimizers['prim'].param_groups[0]['lr']}")
print(f" Optimizer ParamGroup[0]   : {model.optimizers['prim'].param_groups[0]}")
print()
# Label corrected: this line prints the optimizer's full state_dict, not
# param group 0 (it was a copy of the label two lines up).
print(f" Optimizer state_dict      : {model.optimizers['prim'].state_dict()}")

print(f" snnl_factor               : {model.snnl_factor}")
# if model.use_snnl:
#     print(f" temperature          : {model.temperature.item()}")
# if model.temp_optimizer is not None:
#     print(f" Temperature LR       : {model.temp_optimizer.param_groups[0]['lr']}") 
# print()

# for th_key in ['trn', 'val']:
#     for k,v in model.training_history[th_key].items():
#         if isinstance(v[-1],str):
#             print(f" {k:20s} : {v[-1]:s}  ")
#         else:
#             print(f" {k:20s} : {v[-1]:6f} ")
#     print()


In [None]:
# Show the tensors handled by the primary optimizer. An optimizer object is
# not subscriptable; its parameters live in param_groups[i]['params'] (the
# same access path the settings printout in this notebook uses for 'lr').
model.optimizers['prim'].param_groups[0]['params']
# model.use_prim_scheduler

In [None]:
# Echo the model's optimizer mapping (indexed elsewhere in this notebook by the key 'prim').
model.optimizers 
print()

In [None]:
# Dump the scheduler's attributes.
# NOTE(review): commented-out code elsewhere in this notebook uses model.schedulers['prim'];
# confirm a singular `model.scheduler` attribute still exists.
model.scheduler.__dict__

In [None]:
# Dump the temperature optimizer's attributes.
# NOTE(review): this run was parsed with use_temp_optimizer=False; confirm
# model.temp_optimizer exists and is not None in that configuration.
model.temp_optimizer.__dict__
print()

In [None]:
# Dump the temperature scheduler's attributes.
# NOTE(review): confirm model.temp_scheduler exists and is not None when the temperature optimizer is disabled.
model.temp_scheduler.__dict__