# Apply encoder to morphological profiles to get latent space representations

# Setup

In [1]:
# Load the autoreload extension and reload all imported modules before each
# cell execution, so edits to the project's helper modules are picked up live.
%load_ext autoreload  
%autoreload 2
from IPython.display import display, HTML, Image
from IPython.core.interactiveshell import InteractiveShell
# Inject CSS that widens the notebook's `.container` element to 98% width.
display(HTML("<style>.container { width:98% !important; }</style>"))
# Display the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

In [2]:
import os
import sys
import random
from typing import List, Tuple
from types import SimpleNamespace
from functools import partial
import pprint
import logging
from datetime import datetime
# Prepend the project source trees to the import search path. Entries that are
# already present are skipped so re-running the cell does not duplicate them.
for p in ('./src', '../pt-snnl', '../..'):
    if p in sys.path:
        continue
    print(f"insert {p}")
    sys.path.insert(0, p)
print(sys.path)

import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sb
import pandas as pd

import scipy
import scipy.stats as sps
import sklearn.metrics as skm
from scipy.spatial.distance import pdist, squareform, euclidean

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt  # for making figures
from torchinfo import summary

# --- Display options ---------------------------------------------------------
pp = pprint.PrettyPrinter(indent=4)
torch.set_printoptions(linewidth=180)   # all other print options stay at their defaults
np.set_printoptions(linewidth=150, edgeitems=3, nanstr='nan', infstr='inf')
pd.set_option("display.width", 132)

# --- Reproducibility ---------------------------------------------------------
torch.manual_seed(42);   # seeds the torch RNG only (numpy / random are not seeded here)

# --- Process environment -----------------------------------------------------
os.environ.update({
    "WANDB_NOTEBOOK_NAME": "AE-MAIN-SNNL.ipynb",
    # "CUDA_LAUNCH_BLOCKING": "1",
    "CUDA_VISIBLE_DEVICES": "0,1,2",
})

# --- CPU threading -----------------------------------------------------------
torch.set_num_threads(4)   # cap the number of CPU threads torch uses
torch.get_num_threads()

insert ./src
insert ../pt-snnl
insert ../..
['../..', '../pt-snnl', './src', '/home/kevin/WSL-shared/cellpainting/cj-datasets', '/home/kevin/miniforge3/envs/cp311/lib/python311.zip', '/home/kevin/miniforge3/envs/cp311/lib/python3.11', '/home/kevin/miniforge3/envs/cp311/lib/python3.11/lib-dynload', '', '/home/kevin/miniforge3/envs/cp311/lib/python3.11/site-packages', '/home/kevin/miniforge3/envs/cp311/lib/python3.11/site-packages/huggingface_hub-0.20.3-py3.8.egg']


<torch._C.Generator at 0x7f492ad07f10>

4

In [3]:
# from KevinsRoutines.utils.utils_general import list_namespace, save_to_pickle, load_from_pickle, get_device
import KevinsRoutines.utils as myutils
# import snnl.utils as utils
# from utils.utils_ptsnnl import display_cellpainting_batch, get_device
from utils.utils_cellpainting import label_counts, balance_datasets,save_checkpoint, load_checkpoint
from utils.dataloader import custom_collate_fn, dynamic_collate_fn, CellpaintingDataset, InfiniteDataLoader
from utils.utils_notebooks import plot_cls_metrics, compute_classification_metrics, run_model_on_test_data,\
                                train, validation, accuracy_fn, fit, build_model, define_datasets



In [4]:
# Report the CUDA devices visible to this process; per the captured output it
# prints a per-device memory table and evaluates to the current device string.
myutils.get_device(verbose = True)

Dev Id   Device Name                    Total Memory                     InUse                            Free Memory 
   0     Quadro GV100                   34,069,872,640 B/ (31.73 GB)  	 645,922,816 B / (0.60 GB)  	 33,423,949,824 B / (31.13 GB)   *** CURRENT DEVICE *** 
   1     Quadro GV100                   34,069,872,640 B/ (31.73 GB)  	 645,922,816 B / (0.60 GB)  	 33,423,949,824 B / (31.13 GB)  
   2     NVIDIA TITAN Xp                12,774,539,264 B/ (11.90 GB)  	 392,298,496 B / (0.37 GB)  	 12,382,240,768 B / (11.53 GB)  

 Current CUDA Device is:  "cuda:0"  Device Name: Quadro GV100


'cuda:0'

In [5]:
# Configure notebook-wide logging and record the library versions of this run.
timestamp = datetime.now().strftime('%Y_%m_%d_%H:%M:%S')
logger = logging.getLogger(__name__)

# Verbosity comes from the LOG_LEVEL environment variable (default: INFO).
# A name the logging module does not recognise falls back to INFO instead of
# making basicConfig() raise.
logLevel = os.environ.get('LOG_LEVEL', 'INFO').upper()
if not isinstance(logging.getLevelName(logLevel), int):
    logLevel = 'INFO'

FORMAT = '%(asctime)s - %(name)s - %(levelname)s: - %(message)s'
# Previously the level was hard-coded to "INFO", so LOG_LEVEL had no effect.
logging.basicConfig(level=logLevel, format= FORMAT)

logger.info(f" Excution started : {timestamp} ")
logger.info(f" Pytorch version  : {torch.__version__}")
logger.info(f" Scipy version    : {scipy.__version__}  \t\t Numpy version : {np.__version__}")
logger.info(f" Pandas version   : {pd.__version__}  ")

2024-10-04 00:06:45,177 - __main__ - INFO: -  Excution started : 2024_10_04_00:06:45 
2024-10-04 00:06:45,178 - __main__ - INFO: -  Pytorch version  : 2.2.0
2024-10-04 00:06:45,179 - __main__ - INFO: -  Scipy version    : 1.11.4  		 Numpy version : 1.26.2
2024-10-04 00:06:45,180 - __main__ - INFO: -  Pandas version   : 2.2.0  


In [6]:
# Set visible GPU device 
# ----------------------------------------------
# os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# Run on the first visible GPU when CUDA is usable, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Discard a model left over from an earlier run of the notebook so it gets
# rebuilt below. On a fresh kernel the name does not exist yet; that NameError
# is the only failure `del` can produce here, so it is the only one ignored.
try:
    del model
except NameError:
    pass

## Helper routines

# Args 

In [7]:
# --- Embedding / labelling parameters ----------------------------------------
LATENT_DIM = 250            # width of the autoencoder embedding
COMPOUNDS_PER_BATCH = 600
TPSA_THRESHOLD = 70         # forwarded to define_datasets() as tpsa_threshold

# --- Classifier architecture (keep exactly one MODEL_TYPE active) ------------
# MODEL_TYPE = 'batch_norm'
MODEL_TYPE = 'single_layer'
# MODEL_TYPE = 'relu'
n_input = LATENT_DIM        # classifier input size == embedding dimensionality
n_hidden_1 = 512            # hidden layer sizes handed to build_model()
n_hidden_2 = 512
n_hidden_3 = 128

# --- Layout of the embedding CSV files ----------------------------------------
# Column names are kept verbatim (including the spelling of
# 'Metadata_Permiation') -- presumably they must match the CSV header; verify
# before renaming.
METADATA_COLS = [
    'Metadata_Source',
    'Metadata_Batch',
    'Metadata_Plate',
    'Metadata_Well',
    'Metadata_JCP2022',
    'Metadata_Hash',
    'Metadata_Bin',
    'Metadata_TPSA',
    'Metadata_lnTPSA',
    'Metadata_log10TPSA',
    'Metadata_Permiation',
]
# METADATA_COLS += [f'Feature_{x:03d}' for x in range(LATENT_DIM)]
input_cols = len(METADATA_COLS) + LATENT_DIM   # metadata columns + latent features
print(len(METADATA_COLS), input_cols, sep="\n")


# --- Locations ------------------------------------------------------------------
INPUT_PATH = "/home/kevin/WSL-shared/cellpainting/cj-datasets/output_11102023/3_sample_embeddings/"
CKPT_PATH = "./saved_models/embedding_models"

11
261


In [8]:
# Identifier of this classifier-training run; it becomes part of the checkpoint
# file name. Earlier runs are kept below as a commented-out history.
# RUN_DATETIME = datetime.now().strftime('%Y%m%d_%H%M')

# RUN_DATETIME = '20240926_1900'   ## Baseline CPB 600, Latent 150  - Single layer 256
# RUN_DATETIME = '20240927_2300'   ## Baseline CPB 600, Latent 150  - Single layer 512
# RUN_DATETIME = '20240929_2000'   ## Baseline CPB 600, Latent 150  - Batch Norm 256/256/128
# RUN_DATETIME = '20240929_1900'   ## Baseline CPB 600, Latent 150  - Batch Norm 512/512/128

# RUN_DATETIME = '20240930_2100'   ## Baseline CPB 600, Latent 250  - Balanced TPSA labels - Batch Norm 256/256/128
# RUN_DATETIME = '20241001_2100'   ## Baseline CPB 600, Latent 250  - Balanced TPSA labels - Batch Norm 512/512/256

# RUN_DATETIME = '20241002_1915'   ## Baseline CPB 600, Latent 150  - Single layer 256
# RUN_DATETIME = '20241002_1930'   ## Baseline CPB 600, Latent 150  - Single layer 512
# RUN_DATETIME = '20241002_1945'   ## Baseline CPB 600, Latent 150  - Batch Norm 256/256/128
# RUN_DATETIME = '20241002_2000'   ## Baseline CPB 600, Latent 150  - Batch Norm 512/512/128

# Active run. LATENT_DIM is 250 in this notebook (checkpoint name contains "250Ltnt").
RUN_DATETIME = '20241003_2330'   ## Baseline CPB 600, Latent 250, Single layer 512
print(RUN_DATETIME)

20241003_2330


In [9]:
# SNNL AUTOENCODERS 
# Select which autoencoder run produced the embeddings used as classifier input.
# AE_RUNMODE / AE_DATETIME / AE_CKPTTYPE are passed to define_datasets() and are
# embedded in the checkpoint file name.
# AE_RUNMODE = "snnl"
# AE_DATETIME = "20240718_1956"
# AE_DATETIME = "20240906_2201"     # Autoencoder training - SNNL, CPB = 600, Latent 150, WD = 0.001, SNN Factor 3
# AE_DATETIME = "20240917_2004"     # Autoencoder training - SNNL, CPB = 600, Latent 250, WD = 0.001, SNN Factor 3

## BASELINE AUTOENCODERS 
AE_RUNMODE = 'base'
# AE_DATETIME = "20240923_1943"     # Autoencoder training - Baseline, CPB = 600, Latent 150, WD = 0.001 (SNN Factor 0)
AE_DATETIME = "20240917_2017"     # Autoencoder training - Baseline, CPB = 600, Latent 250, WD = 0.001 (SNN Factor 0)

# Which autoencoder checkpoint the embeddings were generated from.
# AE_CKPTTYPE = "BEST"
AE_CKPTTYPE = "LAST"

In [10]:
# Checkpoint file-name template. The doubled braces `{{ep}}` survive the f-string
# as a literal `{ep}` placeholder (see the printed value) for the epoch number.
# NOTE(review): "embd600" and "512" are hard-coded; they presumably mirror
# COMPOUNDS_PER_BATCH and n_hidden_1 -- confirm and keep them in sync by hand.
CKPT_FILE = f"NN_{AE_RUNMODE.lower()}_embd600_{LATENT_DIM}Ltnt_512_{AE_DATETIME}_{AE_CKPTTYPE}_{RUN_DATETIME}_ep_{{ep}}"
print(CKPT_FILE)

NN_base_embd600_250Ltnt_512_20240917_2017_LAST_20241003_2330_ep_{ep}


In [11]:
# Row ranges of the train / validation / test splits handed to define_datasets().
## total rows = 346,542
## Trn file sz: 312,000 
## Train      : 277,200    (312,000 - (21,600 + 12,600 + 600)) = 277,200
## Validation :  21,600
## Test       :  12,600
## Leftover   :     600
# NOTE(review): 'tpsa_threshold' is 100 here, but define_datasets() is also called
# with tpsa_threshold=TPSA_THRESHOLD (70) and the dataset log reports 70 -- this
# entry looks overridden/stale; confirm which value is intended.
cellpainting_args = {'compounds_per_batch': COMPOUNDS_PER_BATCH,
                     'train_start'        : 0,
                     'train_end'          : 277_200,
                     'val_start'          : 0,
                     'val_end'            : 21_600,
                     'test_start'         : 0,
                     'test_end'           : 12_600,
                     'tpsa_threshold'     : 100
                    }

In [12]:
# Build the data loaders from the embedding CSVs under INPUT_PATH. Judging by the
# log output there is one dataset each for train, val and test.
data_loader = define_datasets(cellpainting_args, AE_RUNMODE, AE_DATETIME, input_cols, AE_CKPTTYPE, INPUT_PATH, tpsa_threshold=TPSA_THRESHOLD)

2024-10-04 00:08:10,526 - utils.dataloader - INFO: -  Building CellPantingDataset for train
2024-10-04 00:08:10,527 - utils.dataloader - INFO: -  filename:  /home/kevin/WSL-shared/cellpainting/cj-datasets/output_11102023/3_sample_embeddings/3smpl_prfl_embedding_261_HashOrder_base_20240917_2017_LAST_train.csv
2024-10-04 00:08:10,528 - utils.dataloader - INFO: -  type    :  train
2024-10-04 00:08:10,529 - utils.dataloader - INFO: -  start   :  0
2024-10-04 00:08:10,530 - utils.dataloader - INFO: -  end     :  277200
2024-10-04 00:08:10,531 - utils.dataloader - INFO: -  numrows :  277200
2024-10-04 00:08:10,531 - utils.dataloader - INFO: -  names   :  None     usecols :  None
2024-10-04 00:08:10,532 - utils.dataloader - INFO: -  batch_size  :  1
2024-10-04 00:08:10,533 - utils.dataloader - INFO: -  sample_size :  3
2024-10-04 00:08:10,533 - utils.dataloader - INFO: -  compounds_per_batch :  600
2024-10-04 00:08:10,534 - utils.dataloader - INFO: -  rows per batch (chunksize) :  1800
2024-1

 TRAIN_INPUT:  /home/kevin/WSL-shared/cellpainting/cj-datasets/output_11102023/3_sample_embeddings/3smpl_prfl_embedding_261_HashOrder_base_20240917_2017_LAST_train.csv
 TEST_INPUT :  /home/kevin/WSL-shared/cellpainting/cj-datasets/output_11102023/3_sample_embeddings/3smpl_prfl_embedding_261_HashOrder_base_20240917_2017_LAST_train_sub_test.csv
 ALL_INPUT  :  /home/kevin/WSL-shared/cellpainting/cj-datasets/output_11102023/3_sample_embeddings/3smpl_prfl_embedding_261_HashOrder_base_20240917_2017_LAST_train_sub_val.csv
 load {}
 Dataset size: 277200   rows per batch: 1800  tpsa_threshold: 70
 Dataset size: 21600   rows per batch: 1800  tpsa_threshold: 70
 Dataset size: 12600   rows per batch: 1800  tpsa_threshold: 70


### Dataloader

In [16]:
# TRAIN_INPUT_FILE = f"3smpl_prfl_embedding_{input_cols}_HashOrder_{AE_RUNMODE}_{AE_DATETIME}_{AE_CKPTTYPE}_train.csv"
# TEST_INPUT_FILE  = f"3smpl_prfl_embedding_{input_cols}_HashOrder_{AE_RUNMODE}_{AE_DATETIME}_{AE_CKPTTYPE}_train_sub_test.csv"
# VAL_INPUT_FILE   = f"3smpl_prfl_embedding_{input_cols}_HashOrder_{AE_RUNMODE}_{AE_DATETIME}_{AE_CKPTTYPE}_train_sub_val.csv"
# # ALL_INPUT_FILE   = f"3smpl_prfl_embedding_{num_cols}_HashOrder_{AE_RUNMODE}_{AE_DATETIME}_{AE_CKPTTYPE}_sub_val.csv"

# print(TRAIN_INPUT_FILE)
# print(TEST_INPUT_FILE)
# print(VAL_INPUT_FILE)

# TRAIN_INPUT = os.path.join(INPUT_PATH, TRAIN_INPUT_FILE)
# TEST_INPUT  = os.path.join(INPUT_PATH, TEST_INPUT_FILE)
# VAL_INPUT   = os.path.join(INPUT_PATH, VAL_INPUT_FILE)

# print(f" TRAIN_INPUT:  {TRAIN_INPUT}")
# print(f" TEST_INPUT :  {TEST_INPUT }")
# print(f" ALL_INPUT  :  {VAL_INPUT }")

In [17]:
## total rows = 346,542
## Trn file sz: 312,000 
## Train      : 277,200    (312_000 - (21,600 + 12,600 + 600) = 277,200
## Validation :  21,600
## Test       :  12,600
## Leftover   :     600
# cellpainting_args = {'sample_size': 3,
#                      'batch_size': 1,
#                      'compounds_per_batch': 600,
#                      'training_path'  : TRAIN_INPUT,
#                      'validation_path': TRAIN_INPUT,
#                      'test_path'      : TRAIN_INPUT,
#                      'train_start'    : 0,
#                      'train_end'      : 277_200,  # 277,200 samples
#                      'val_start'      : 277_200,  # 
#                      'val_end'        : 298_800,  # 21_600 samples
#                      'test_start'     : 298_800,  # 
#                      'test_end'       : 311_400,  # 12_600 samples
#                     }

# cellpainting_args = {'compounds_per_batch': COMPOUNDS_PER_BATCH,
#                      'training_path'      : TRAIN_INPUT,
#                      'validation_path'    :  VAL_INPUT,
#                      'test_path'          : TEST_INPUT,
#                      'train_start'        : 0,
#                      'train_end'          : 277_200,
#                      'val_start'          : 0,
#                      'val_end'            : 21_600,
#                      'test_start'         : 0,
#                      'test_end'           : 12_600,
#                      'tpsa_threshold'     : 100
#                     }

In [18]:
# cellpainting_args

In [19]:
#### Load CellPainting Dataset
# data : keys to the dataset settings (and resulting keys in output dictionary)
# dataset = dict()
# data_loader = dict()

# for datatype in ['train', 'val', 'test']:
#     dataset[datatype] = CellpaintingDataset(type = datatype, **cellpainting_args)
#     data_loader[datatype] = InfiniteDataLoader(dataset = dataset[datatype], batch_size=1, shuffle = False, num_workers = 0, 
#                                                collate_fn = partial(dynamic_collate_fn, tpsa_threshold = dataset[datatype].tpsa_threshold) )

In [20]:
# def display_cellpainting_batch(batch_id, batch):
#     # data, labels, plates, compounds, cmphash, other, labels_2
#     features, label, well_ids, compound_id, cmphash, tpsa, label_2 = batch
#     # label_2 = np.zeros_like(label)
#     print("-"*135)
#     print(f"  Batch Id: {batch_id}   {type(batch)}  Rows returned {len(batch[0])} features: {features.shape}  ")
#     print(f"+-----+------------------------------------------+----------------+--------------------------+------------------------------+-----+-----+--------------------------------------------------------+")
#     print(f"| idx |   batch[2]                               |    batch[3]    |      batch[2]            |          batch[5]            | [1] | [1] |     batch[0]                                           | ") 
#     print(f"|     | SRCE      BATCH     PLATE     WELL       |   COMPOUND_ID  |       CMPHASH / BIN      |  TPSA / Ln(TPSA) / Log(TPSA) | LBL |LBL2 |     FEATURES                                           | ")
#     print(f"+-----+------------------------------------------+----------------+--------------------------+------------------------------+-----+-----+--------------------------------------------------------+")
#          ###    0 | source_11 Batch2    EC000046  K04      | JCP2022_009278 |  7406361908543180200 -  8  |   0   |   62.78000    4.13964   1.79782 | [-0.4377299 -0.4474466  1.1898487  0.2051901]
#          # "  1 | source_10    | JCP2022_006020 | -9223347314827979542 |   10 |  0 | tensor([-0.6346, -0.6232, -1.6046])"
    
#     for i in range(len(label)):
#         print(f"| {i:3d} | {well_ids[i,0][:9]:9s} {well_ids[i,1][:12]:12s}  {well_ids[i,2][:10]:10s}  {well_ids[i,3]:4s} |"\
#               f" {compound_id[i]:14s} | {cmphash[i,0]:20d}  {cmphash[i,1]:2d} |"\
#               f" {tpsa[i,0]:7.3f}  {tpsa[i,1]:8.5f}  {tpsa[i,2]:8.5f}  |"
#               f" {int(label[i]):2d}  | {int(label_2[i]):2d}  |"\
#               f" {features[i,:4].detach().cpu().numpy()}")
#         # print(f"| {i:3d} | {batch[2][i,0]:9s} {batch[2][i,1][:9]:9s} {str(batch[2][i,2])[:9]:9s} {batch[2][i,3]:>4s}       "\
#         #       f"|{batch[3][i]:12s} | {batch[4][i,0]:20d} - {batch[4][i,1]:2d}  "\
#         #       f"{batch[5][i,0]:11.5f}   {batch[5][i,1]:8.5f}  {batch[5][i,2]:8.5f} "\
#         #       f"|  {int(batch[1][i]):1d}  | {batch[0][i,:4].detach().cpu().numpy()}")


In [21]:
# # %%timeit
# # for dataset in ['train', 'val', 'test']:
# for dataset in ['test']:
#     for idx, batch in enumerate(data_loader[dataset]):
#         for b in batch :
#             print(b.shape)
#         display_cellpainting_batch(idx, batch)
#         if idx == 0:
#             break

In [22]:
# # -----------------------------------------
# #  Count pos/neg labels in each dataset
# # -----------------------------------------
# for datatype in ['train', 'val', 'test']:
#     MINIBATCH_SIZE = data_loader[datatype].dataset.sample_size * data_loader[datatype].dataset.compounds_per_batch
#     print(f" {datatype.capitalize()} Minibatch size : {MINIBATCH_SIZE}") 
# print()

# for datatype in ['train', 'val', 'test']:
#     minibatches = len(data_loader[datatype]) // MINIBATCH_SIZE
#     ttl_rows, ttl_rows_2 = 0, 0
#     ttl_pos_labels, ttl_pos_labels_2 = 0, 0
#     with tqdm.tqdm(enumerate(data_loader[datatype]), initial=0, total = minibatches, position=0, file=sys.stdout,
#                    leave= False, desc=f" Count labels ") as t_warmup:
#         for batch_count, (_, batch_labels, _, _, _, _, batch_labels_2) in t_warmup:
#             ttl_rows += batch_labels.shape[0]
#             ttl_rows_2 += batch_labels_2.shape[0]
#             ttl_pos_labels += batch_labels.sum()
#             ttl_pos_labels_2 += batch_labels_2.sum()
#     ttl_neg_labels = ttl_rows - ttl_pos_labels
#     ttl_neg_labels_2 = ttl_rows_2 - ttl_pos_labels_2
#     ttl = f"\n Dataset: {datatype} -  len of {datatype} data loader: {len(data_loader[datatype])}   number of batches: {minibatches}"
#     print(ttl)
#     print('-'*len(ttl))
#     print(f" total rows     : {ttl_rows:7d}")
#     print(f" total pos rows : {ttl_pos_labels:7.0f} - {ttl_pos_labels*100.0/ttl_rows:5.2f}%         alternative pos rows : {ttl_pos_labels_2:7.0f} - {ttl_pos_labels_2*100.0/ttl_rows:5.2f}%      ")
#     print(f" total neg rows : {ttl_neg_labels:7.0f} - {ttl_neg_labels*100.0/ttl_rows:5.2f}%         alternative neg rows : {ttl_neg_labels_2:7.0f} - {ttl_neg_labels_2*100.0/ttl_rows:5.2f}%")
#     print()

     Minibatch size : 1800 
                                                                                                 
     Dataset: train - len of train data loader: 277200   number of batches: 154  
    ------------------------------
     total rows     :  277200
     total pos rows :   33129 - 11.95%
     total neg rows :  244071 - 88.05%

     Dataset: val - len of val data loader: 21600   number of batches: 12
    ------------------------------
     total rows     :   21600
     total pos rows :    2532 - 11.72%
     total neg rows :   19068 - 88.28%
    
     Dataset: test - len of test data loader: 12600   number of batches: 7
    ------------------------------
     total rows     :   12600
     total pos rows :    1431 - 11.36%
     total neg rows :   11169 - 88.64%

# Define Neural Net Model 

- **4 layer model :**

    Input --> Hidden1 --> (BN/NL) ---> Hidden2 ---> (BN/NL) ---> Hidden3 --->  (BN/NL) ---> 1
   
    -  **20240909_1800** : Run on 4 FC layers model (includes final layer), model configuration UNKNOWN
    -  **20240909_1801** : Run on 4 FC layers model (includes final layer), Relu non linearities (NO Batch Norm)
    -  **20240909_2100** : Run on 4 FC layers model (includes final layer), with BATCH NORM and tanh non linearities

      
 - **Single Hidden Layer - 256**

   Input --> Hidden1 --> (Tanh) --->  1
    -  **20240916_1830** : Run on 1 FC layers model (includes final layer), Input --> 256 --> Tanh --> 1 ,  Read from 20240906_2201 (SNNL - CPB 600, LAT 150, SNN Factor 3)
    -  **20240926_1900** : Run on 1 FC layers model (includes final layer), Input --> 256 --> Tanh --> 1 ,  Read from 20240917_2017 (BASELINE - CPB 600, LAT 250, SNN Factor 0)
    -  **20240926_1930** : Run on 1 FC layers model (includes final layer), Input --> 256 --> Tanh --> 1 ,  Read from 20240917_2004 (SNNL - CPB 600, LAT 250, SNN Factor 3)
    -  **20240926_2000** : Run on 1 FC layers model (includes final layer), Input --> 256 --> Tanh --> 1 ,  Read from 20240924_0146 (SNNL - CPB 600, LAT 250, SNN Factor 30)
<br>

 - **Single Hidden Layer - 512**

    -  **20240921_0700** : Run on 1 FC layers model (includes final layer), Input --> 512 --> Tanh --> 1 ,  Read from 20240906_2201 (SNNL - CPB 600, LAT 150, SNN Factor 3)    


In [13]:

# Instantiate the classifier selected by MODEL_TYPE on `device`. For
# 'single_layer' the summary below shows Linear(250->512) -> Tanh -> Linear(512->1).
model = build_model(MODEL_TYPE, input = n_input, hidden_1 = n_hidden_1, hidden_2 = n_hidden_2, hidden_3=n_hidden_3, device = device)
 

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Param %                   Mult-Adds                 Trainable
Sequential                               [30, 250]                 [30, 1]                   --                             --                   --                        True
├─Linear: 1-1                            [30, 250]                 [30, 512]                 128,512                    99.60%                   3,855,360                 True
│    └─weight                                                                                ├─128,000
│    └─bias                                                                                  └─512
├─Tanh: 1-2                              [30, 512]                 [30, 512]                 --                             --                   --                        --
├─Linear: 1-3                            [30, 512]                 [30, 1]                 

In [14]:
# Per-epoch training history; handed to fit() and replaced by its return value.
metrics = {name: [] for name in ('loss_trn', 'acc_trn', 'loss_val', 'acc_val')}

# Epoch counters: both start at zero and are advanced by the training cell.
start_epoch = end_epoch = 0
init_LR = 1.0e-3
# curr_LR = init_LR

# AdamW over all model parameters, starting at init_LR.
optimizer = torch.optim.AdamW(model.parameters(), lr=init_LR)

# Halve the learning rate when the monitored quantity stops decreasing for 50
# epochs, then wait 10 cooldown epochs before counting bad epochs again.
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode = 'min', factor = 0.3 , patience=20, cooldown=10,)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = step_size, gamma=0.1, last_epoch =-1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=50,
    cooldown=10,
    threshold=1.0e-06,
)

### Read checkpoint

In [15]:
# loaded_epoch
# optimizer.state_dict()
# scheduler.state_dict()

In [16]:
# model, optimizer, scheduler, end_epoch = load_checkpoint(model, optimizer, scheduler, CKPT_FILE.format(ep=100), ckpt_path = CKPT_PATH)
# model = model.to(device)

In [17]:
# end_epoch
# optimizer.state_dict()
# scheduler.state_dict()

# Run Training

In [18]:
# Define the epoch window for the next fit() call: resume where the previous
# window ended and train 1200 more epochs.
# NOTE: this cell is not idempotent -- each re-run shifts the window by 1200.
# start_epoch = 0
# start_epoch = loaded_epoch
start_epoch = end_epoch
end_epoch += 1200
# start_epoch, end_epoch = 0,100
print(start_epoch, end_epoch)
# Switch the model to training mode; assigning the result to `_` keeps the
# returned value out of the cell output.
_ = model.train()

0 1200


In [19]:

# Train from start_epoch to end_epoch. The log below shows one line per epoch
# and a checkpoint export every 100 epochs, named from CKPT_FILE.
metrics = fit(model, optimizer, scheduler, data_loader, metrics, start_epoch, end_epoch, device, CKPT_FILE, CKPT_PATH )


 00:08:59 | Ep:   1/1200 | Trn loss:  0.693587 - Acc: 52.4481 | Val loss:  0.691070 - Acc: 52.5648 | last_lr: 1.00000e-03  bad_ep: 0  cdwn: 0                              
 00:09:21 | Ep:   2/1200 | Trn loss:  0.690575 - Acc: 53.1176 | Val loss:  0.691569 - Acc: 52.4213 | last_lr: 1.00000e-03  bad_ep: 1  cdwn: 0                              
 00:09:43 | Ep:   3/1200 | Trn loss:  0.689776 - Acc: 53.3770 | Val loss:  0.692922 - Acc: 51.8194 | last_lr: 1.00000e-03  bad_ep: 2  cdwn: 0                              
 00:10:05 | Ep:   4/1200 | Trn loss:  0.689032 - Acc: 53.5599 | Val loss:  0.693175 - Acc: 51.6111 | last_lr: 1.00000e-03  bad_ep: 3  cdwn: 0                              
 00:10:28 | Ep:   5/1200 | Trn loss:  0.688526 - Acc: 53.7529 | Val loss:  0.693327 - Acc: 51.6065 | last_lr: 1.00000e-03  bad_ep: 4  cdwn: 0                              
 00:10:50 | Ep:   6/1200 | Trn loss:  0.688111 - Acc: 53.8470 | Val loss:  0.693266 - Acc: 51.6898 | last_lr: 1.00000e-03  bad_ep: 5  cdwn: 

2024-10-04 00:45:13,787 - utils.utils_cellpainting - INFO: -  Model exported to NN_base_embd600_250Ltnt_512_20240917_2017_LAST_20241003_2330_ep_100.pt - epoch: 100


 00:45:13 | Ep: 100/1200 | Trn loss:  0.640428 - Acc: 62.6465 | Val loss:  0.727999 - Acc: 51.7454 | last_lr: 5.00000e-04  bad_ep: 19  cdwn: 0 
 00:45:35 | Ep: 101/1200 | Trn loss:  0.640290 - Acc: 62.6613 | Val loss:  0.728224 - Acc: 51.7222 | last_lr: 5.00000e-04  bad_ep: 20  cdwn: 0                             
 00:45:57 | Ep: 102/1200 | Trn loss:  0.640151 - Acc: 62.6595 | Val loss:  0.728439 - Acc: 51.6667 | last_lr: 5.00000e-04  bad_ep: 21  cdwn: 0                             
 00:46:19 | Ep: 103/1200 | Trn loss:  0.640015 - Acc: 62.6670 | Val loss:  0.728664 - Acc: 51.6528 | last_lr: 5.00000e-04  bad_ep: 22  cdwn: 0                             
 00:46:41 | Ep: 104/1200 | Trn loss:  0.639881 - Acc: 62.6894 | Val loss:  0.728880 - Acc: 51.6528 | last_lr: 5.00000e-04  bad_ep: 23  cdwn: 0                             
 00:47:03 | Ep: 105/1200 | Trn loss:  0.639748 - Acc: 62.7100 | Val loss:  0.729096 - Acc: 51.6250 | last_lr: 5.00000e-04  bad_ep: 24  cdwn: 0                          

2024-10-04 01:22:07,543 - utils.utils_cellpainting - INFO: -  Model exported to NN_base_embd600_250Ltnt_512_20240917_2017_LAST_20241003_2330_ep_200.pt - epoch: 200


 01:22:07 | Ep: 200/1200 | Trn loss:  0.629026 - Acc: 63.9372 | Val loss:  0.733161 - Acc: 52.4491 | last_lr: 1.25000e-04  bad_ep: 0  cdwn: 3 
 01:22:29 | Ep: 201/1200 | Trn loss:  0.628999 - Acc: 63.9358 | Val loss:  0.733231 - Acc: 52.4537 | last_lr: 1.25000e-04  bad_ep: 0  cdwn: 2                              
 01:22:52 | Ep: 202/1200 | Trn loss:  0.628973 - Acc: 63.9405 | Val loss:  0.733299 - Acc: 52.4583 | last_lr: 1.25000e-04  bad_ep: 0  cdwn: 1                              
 01:23:14 | Ep: 203/1200 | Trn loss:  0.628949 - Acc: 63.9441 | Val loss:  0.733365 - Acc: 52.4444 | last_lr: 1.25000e-04  bad_ep: 0  cdwn: 0                              
 01:23:36 | Ep: 204/1200 | Trn loss:  0.628924 - Acc: 63.9398 | Val loss:  0.733430 - Acc: 52.4352 | last_lr: 1.25000e-04  bad_ep: 1  cdwn: 0                              
 01:23:58 | Ep: 205/1200 | Trn loss:  0.628901 - Acc: 63.9462 | Val loss:  0.733492 - Acc: 52.4074 | last_lr: 1.25000e-04  bad_ep: 2  cdwn: 0                            

2024-10-04 01:59:03,095 - utils.utils_cellpainting - INFO: -  Model exported to NN_base_embd600_250Ltnt_512_20240917_2017_LAST_20241003_2330_ep_300.pt - epoch: 300


 01:59:03 | Ep: 300/1200 | Trn loss:  0.626415 - Acc: 64.2605 | Val loss:  0.738077 - Acc: 52.2778 | last_lr: 6.25000e-05  bad_ep: 36  cdwn: 0 
 01:59:25 | Ep: 301/1200 | Trn loss:  0.626406 - Acc: 64.2644 | Val loss:  0.738105 - Acc: 52.2685 | last_lr: 6.25000e-05  bad_ep: 37  cdwn: 0                             
 01:59:47 | Ep: 302/1200 | Trn loss:  0.626397 - Acc: 64.2626 | Val loss:  0.738133 - Acc: 52.2546 | last_lr: 6.25000e-05  bad_ep: 38  cdwn: 0                             
 02:00:10 | Ep: 303/1200 | Trn loss:  0.626389 - Acc: 64.2655 | Val loss:  0.738161 - Acc: 52.2454 | last_lr: 6.25000e-05  bad_ep: 39  cdwn: 0                             
 02:00:32 | Ep: 304/1200 | Trn loss:  0.626381 - Acc: 64.2652 | Val loss:  0.738188 - Acc: 52.2407 | last_lr: 6.25000e-05  bad_ep: 40  cdwn: 0                             
 02:00:54 | Ep: 305/1200 | Trn loss:  0.626372 - Acc: 64.2655 | Val loss:  0.738216 - Acc: 52.2222 | last_lr: 6.25000e-05  bad_ep: 41  cdwn: 0                          

2024-10-04 02:36:08,947 - utils.utils_cellpainting - INFO: -  Model exported to NN_base_embd600_250Ltnt_512_20240917_2017_LAST_20241003_2330_ep_400.pt - epoch: 400


 02:36:08 | Ep: 400/1200 | Trn loss:  0.625142 - Acc: 64.4037 | Val loss:  0.738624 - Acc: 52.3519 | last_lr: 1.56250e-05  bad_ep: 14  cdwn: 0 
 02:36:31 | Ep: 401/1200 | Trn loss:  0.625140 - Acc: 64.4048 | Val loss:  0.738631 - Acc: 52.3519 | last_lr: 1.56250e-05  bad_ep: 15  cdwn: 0                             
 02:36:53 | Ep: 402/1200 | Trn loss:  0.625138 - Acc: 64.4051 | Val loss:  0.738638 - Acc: 52.3519 | last_lr: 1.56250e-05  bad_ep: 16  cdwn: 0                             
 02:37:16 | Ep: 403/1200 | Trn loss:  0.625136 - Acc: 64.4044 | Val loss:  0.738645 - Acc: 52.3565 | last_lr: 1.56250e-05  bad_ep: 17  cdwn: 0                             
 02:37:38 | Ep: 404/1200 | Trn loss:  0.625134 - Acc: 64.4058 | Val loss:  0.738652 - Acc: 52.3565 | last_lr: 1.56250e-05  bad_ep: 18  cdwn: 0                             
 02:38:00 | Ep: 405/1200 | Trn loss:  0.625132 - Acc: 64.4040 | Val loss:  0.738659 - Acc: 52.3611 | last_lr: 1.56250e-05  bad_ep: 19  cdwn: 0                          

2024-10-04 03:13:30,721 - utils.utils_cellpainting - INFO: -  Model exported to NN_base_embd600_250Ltnt_512_20240917_2017_LAST_20241003_2330_ep_500.pt - epoch: 500


 03:13:30 | Ep: 500/1200 | Trn loss:  0.624746 - Acc: 64.4380 | Val loss:  0.739106 - Acc: 52.3843 | last_lr: 3.90625e-06  bad_ep: 0  cdwn: 8 
 03:13:52 | Ep: 501/1200 | Trn loss:  0.624745 - Acc: 64.4372 | Val loss:  0.739108 - Acc: 52.3935 | last_lr: 3.90625e-06  bad_ep: 0  cdwn: 7                              
 03:14:14 | Ep: 502/1200 | Trn loss:  0.624745 - Acc: 64.4369 | Val loss:  0.739109 - Acc: 52.3843 | last_lr: 3.90625e-06  bad_ep: 0  cdwn: 6                              
 03:14:37 | Ep: 503/1200 | Trn loss:  0.624745 - Acc: 64.4394 | Val loss:  0.739111 - Acc: 52.3843 | last_lr: 3.90625e-06  bad_ep: 0  cdwn: 5                              
 03:14:59 | Ep: 504/1200 | Trn loss:  0.624744 - Acc: 64.4394 | Val loss:  0.739112 - Acc: 52.3843 | last_lr: 3.90625e-06  bad_ep: 0  cdwn: 4                              
 03:15:21 | Ep: 505/1200 | Trn loss:  0.624744 - Acc: 64.4401 | Val loss:  0.739113 - Acc: 52.3796 | last_lr: 3.90625e-06  bad_ep: 0  cdwn: 3                            

2024-10-04 03:50:51,701 - utils.utils_cellpainting - INFO: -  Model exported to NN_base_embd600_250Ltnt_512_20240917_2017_LAST_20241003_2330_ep_600.pt - epoch: 600


 03:50:51 | Ep: 600/1200 | Trn loss:  0.624666 - Acc: 64.4408 | Val loss:  0.739221 - Acc: 52.3796 | last_lr: 1.95313e-06  bad_ep: 31  cdwn: 0 
 03:51:14 | Ep: 601/1200 | Trn loss:  0.624665 - Acc: 64.4405 | Val loss:  0.739223 - Acc: 52.3796 | last_lr: 1.95313e-06  bad_ep: 32  cdwn: 0                             
 03:51:37 | Ep: 602/1200 | Trn loss:  0.624665 - Acc: 64.4408 | Val loss:  0.739224 - Acc: 52.3796 | last_lr: 1.95313e-06  bad_ep: 33  cdwn: 0                             
 03:51:59 | Ep: 603/1200 | Trn loss:  0.624665 - Acc: 64.4394 | Val loss:  0.739225 - Acc: 52.3796 | last_lr: 1.95313e-06  bad_ep: 34  cdwn: 0                             
 03:52:21 | Ep: 604/1200 | Trn loss:  0.624664 - Acc: 64.4394 | Val loss:  0.739227 - Acc: 52.3796 | last_lr: 1.95313e-06  bad_ep: 35  cdwn: 0                             
 03:52:44 | Ep: 605/1200 | Trn loss:  0.624664 - Acc: 64.4394 | Val loss:  0.739228 - Acc: 52.3796 | last_lr: 1.95313e-06  bad_ep: 36  cdwn: 0                          

2024-10-04 04:28:29,899 - utils.utils_cellpainting - INFO: -  Model exported to NN_base_embd600_250Ltnt_512_20240917_2017_LAST_20241003_2330_ep_700.pt - epoch: 700


 04:28:29 | Ep: 700/1200 | Trn loss:  0.624614 - Acc: 64.4372 | Val loss:  0.739293 - Acc: 52.3704 | last_lr: 4.88281e-07  bad_ep: 9  cdwn: 0 
 04:28:52 | Ep: 701/1200 | Trn loss:  0.624613 - Acc: 64.4376 | Val loss:  0.739294 - Acc: 52.3704 | last_lr: 4.88281e-07  bad_ep: 10  cdwn: 0                             
 04:29:15 | Ep: 702/1200 | Trn loss:  0.624613 - Acc: 64.4376 | Val loss:  0.739294 - Acc: 52.3704 | last_lr: 4.88281e-07  bad_ep: 11  cdwn: 0                             
 04:29:38 | Ep: 703/1200 | Trn loss:  0.624613 - Acc: 64.4376 | Val loss:  0.739294 - Acc: 52.3704 | last_lr: 4.88281e-07  bad_ep: 12  cdwn: 0                             
 04:30:00 | Ep: 704/1200 | Trn loss:  0.624613 - Acc: 64.4376 | Val loss:  0.739295 - Acc: 52.3704 | last_lr: 4.88281e-07  bad_ep: 13  cdwn: 0                             
 04:30:23 | Ep: 705/1200 | Trn loss:  0.624613 - Acc: 64.4376 | Val loss:  0.739295 - Acc: 52.3704 | last_lr: 4.88281e-07  bad_ep: 14  cdwn: 0                           

2024-10-04 06:22:22,536 - utils.utils_cellpainting - INFO: -  Model exported to NN_base_embd600_250Ltnt_512_20240917_2017_LAST_20241003_2330_ep_1000.pt - epoch: 1000


 06:22:22 | Ep: 1000/1200 | Trn loss:  0.624595 - Acc: 64.4419 | Val loss:  0.739330 - Acc: 52.3611 | last_lr: 1.52588e-08  bad_ep: 4  cdwn: 0 
 06:22:45 | Ep: 1001/1200 | Trn loss:  0.624595 - Acc: 64.4419 | Val loss:  0.739330 - Acc: 52.3611 | last_lr: 1.52588e-08  bad_ep: 5  cdwn: 0                             
 06:23:08 | Ep: 1002/1200 | Trn loss:  0.624595 - Acc: 64.4419 | Val loss:  0.739330 - Acc: 52.3611 | last_lr: 1.52588e-08  bad_ep: 6  cdwn: 0                             
 06:23:31 | Ep: 1003/1200 | Trn loss:  0.624594 - Acc: 64.4416 | Val loss:  0.739330 - Acc: 52.3611 | last_lr: 1.52588e-08  bad_ep: 7  cdwn: 0                             
 06:23:54 | Ep: 1004/1200 | Trn loss:  0.624595 - Acc: 64.4419 | Val loss:  0.739331 - Acc: 52.3611 | last_lr: 1.52588e-08  bad_ep: 8  cdwn: 0                             
 06:26:35 | Ep: 1011/1200 | Trn loss:  0.624595 - Acc: 64.4426 | Val loss:  0.739331 - Acc: 52.3565 | last_lr: 1.52588e-08  bad_ep: 15  cdwn: 0                         

2024-10-04 07:00:49,762 - utils.utils_cellpainting - INFO: -  Model exported to NN_base_embd600_250Ltnt_512_20240917_2017_LAST_20241003_2330_ep_1100.pt - epoch: 1100


 07:00:49 | Ep: 1100/1200 | Trn loss:  0.624594 - Acc: 64.4452 | Val loss:  0.739335 - Acc: 52.3519 | last_lr: 1.52588e-08  bad_ep: 43  cdwn: 0 
 07:01:12 | Ep: 1101/1200 | Trn loss:  0.624594 - Acc: 64.4455 | Val loss:  0.739335 - Acc: 52.3519 | last_lr: 1.52588e-08  bad_ep: 44  cdwn: 0                            
 07:01:35 | Ep: 1102/1200 | Trn loss:  0.624594 - Acc: 64.4455 | Val loss:  0.739335 - Acc: 52.3519 | last_lr: 1.52588e-08  bad_ep: 45  cdwn: 0                            
 07:04:39 | Ep: 1110/1200 | Trn loss:  0.624594 - Acc: 64.4462 | Val loss:  0.739335 - Acc: 52.3519 | last_lr: 1.52588e-08  bad_ep: 0  cdwn: 8                             
 07:05:03 | Ep: 1111/1200 | Trn loss:  0.624594 - Acc: 64.4462 | Val loss:  0.739335 - Acc: 52.3519 | last_lr: 1.52588e-08  bad_ep: 0  cdwn: 7                             
 07:05:26 | Ep: 1112/1200 | Trn loss:  0.624594 - Acc: 64.4462 | Val loss:  0.739335 - Acc: 52.3472 | last_lr: 1.52588e-08  bad_ep: 0  cdwn: 6                         

2024-10-04 07:39:11,959 - utils.utils_cellpainting - INFO: -  Model exported to NN_base_embd600_250Ltnt_512_20240917_2017_LAST_20241003_2330_ep_1200.pt - epoch: 1200


 07:39:11 | Ep: 1200/1200 | Trn loss:  0.624594 - Acc: 64.4477 | Val loss:  0.739336 - Acc: 52.3519 | last_lr: 1.52588e-08  bad_ep: 21  cdwn: 0 


In [56]:
# Scratch cell: its only live statement displays the (start_epoch, end_epoch) pair.
# The commented-out snippets are disabled one-off helpers kept for reference.
# print(filename)
 
# Disabled: manual checkpoint save for `end_epoch`.
# save_checkpoint(end_epoch, model, optimizer, scheduler, metrics = metrics,
#                 filename = CKPT_FILE.format(ep=end_epoch),
#                 ckpt_path = CKPT_PATH, verbose = True)

# Bare expression -> rendered as the cell output.
start_epoch, end_epoch

# Disabled: in-place replacement of every stored 'loss_trn' / 'loss_val' entry with
# its `.item()` value (presumably tensor -> plain Python number — TODO confirm).
# for mtrc in ['loss_trn', 'loss_val']:
#     for i in range(len(metrics[mtrc])):
#         # print(i)
#         metrics[mtrc][i] = metrics[mtrc][i].item()

(600, 1200)

In [60]:
# Replay the stored metric history: one line per recorded epoch, laid out like the
# training log (minus the lr / bad_ep / cdwn columns).
#
# NOTE: the leading timestamp is the moment the line is printed, not the time at
# which that epoch was originally trained.
#
# The four lists are presumably appended to once per epoch by the training loop,
# along the lines of (TODO confirm against the trainer):
#     metrics['loss_trn'].append(trn_loss.item());  metrics['acc_trn'].append(trn_acc)
#     metrics['loss_val'].append(val_loss.item());  metrics['acc_val'].append(val_acc)
#
# `idx` is the 0-based position in the history; zip() stops at the shortest list.
for idx, (trn_loss, trn_acc, val_loss, val_acc) in enumerate(
        zip(*(metrics[key] for key in ('loss_trn', 'acc_trn', 'loss_val', 'acc_val')))):
    # print() joins its arguments with a single space, reproducing the original line layout.
    print(f" {datetime.now():%X} | Ep: {idx:3d}/{end_epoch:4d} |",
          f"Trn loss: {trn_loss:9.6f} - Acc: {trn_acc:.4f} |",
          f"Val loss: {val_loss:9.6f} - Acc: {val_acc:.4f} | ")

 22:18:45 | Ep:   0/1200 | Trn loss:  0.693725 - Acc: 52.5841 | Val loss:  0.690645 - Acc: 53.0417 | 
 22:18:45 | Ep:   1/1200 | Trn loss:  0.689889 - Acc: 53.3045 | Val loss:  0.690656 - Acc: 53.0139 | 
 22:18:45 | Ep:   2/1200 | Trn loss:  0.689043 - Acc: 53.6216 | Val loss:  0.690491 - Acc: 53.0833 | 
 22:18:45 | Ep:   3/1200 | Trn loss:  0.688419 - Acc: 53.7720 | Val loss:  0.690325 - Acc: 52.9306 | 
 22:18:45 | Ep:   4/1200 | Trn loss:  0.687879 - Acc: 53.9459 | Val loss:  0.690166 - Acc: 52.8472 | 
 22:18:45 | Ep:   5/1200 | Trn loss:  0.687373 - Acc: 54.0350 | Val loss:  0.690060 - Acc: 53.0231 | 
 22:18:45 | Ep:   6/1200 | Trn loss:  0.686883 - Acc: 54.2623 | Val loss:  0.690065 - Acc: 53.2083 | 
 22:18:45 | Ep:   7/1200 | Trn loss:  0.686383 - Acc: 54.3932 | Val loss:  0.690107 - Acc: 53.1852 | 
 22:18:45 | Ep:   8/1200 | Trn loss:  0.685847 - Acc: 54.5220 | Val loss:  0.690244 - Acc: 53.1343 | 
 22:18:45 | Ep:   9/1200 | Trn loss:  0.685288 - Acc: 54.7027 | Val loss:  0.69035