In [None]:
# Show the kernel's working directory: the dataset path ('fastdata/...') and the
# checkpoint directory (os.getcwd()/chkpts) used below are resolved relative to it.
!pwd

In [None]:
# Load paths for using psana
# These %env magics set environment variables for the kernel process.
# NOTE(review): the paths are site-specific absolute paths; adjust when running elsewhere.
%env SIT_ROOT=/reg/g/psdm/
%env SIT_DATA=/cds/group/psdm/data/
%env SIT_PSDM_DATA=/cds/data/psdm/

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import random

from functools   import reduce
from collections import OrderedDict

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import socket
import pickle
import tqdm
import logging

In [None]:
from deepprojection.model import Shi2019Model
from deepprojection.encoders.convnet import Shi2019
from deepprojection.trainer          import SimpleTrainer      , ConfigTrainer
from deepprojection.validator        import SimpleValidator, ConfigValidator

In [None]:
from deepprojection.datasets.lite import SPIDataset, SPIOnlineDataset
from deepprojection.utils import MetaLog, init_logger, split_dataset, set_seed, NNSize, TorchModelAttributeParser, Config, EpochManager

In [None]:
from datetime import datetime
from image_preprocess_faulty import DatasetPreprocess

In [None]:
# [[[ SEED ]]]
# Seed used for this run; it is also recorded in the metalog comments below.
# set_seed comes from deepprojection.utils -- presumably it seeds the global
# Python / NumPy / PyTorch RNGs (TODO confirm).  The dataset split and the
# datasets below are created with seed = None.
seed = 0
set_seed(seed)

In [None]:
# [[[ CONFIG ]]]
# Timestamp of a previous run, passed to model.init_params(from_timestamp = ...)
# further down.  None presumably means "start from fresh parameters" (TODO confirm).
timestamp_prev = None
## timestamp_prev = "2022_1129_2150_15"

# Fractions handed to split_dataset: frac_train is applied to the full list,
# frac_validate to the remainder returned by the first split.
frac_train     = 0.5
frac_validate  = 0.5

# NOTE(review): logs_triplets is not referenced anywhere else in this notebook.
logs_triplets = True

# Learning rate, forwarded to both ConfigTrainer and ConfigValidator.
lr = 1e-3
## lr = 5e-4

# alpha is only recorded in the metalog within this notebook.  The commented
# values form a log-spaced sweep from 0.02 to 2.0 (10 points); keep exactly
# one line active.
## alpha = 0.02
## alpha = 0.03336201
alpha = 0.05565119
## alpha = 0.09283178
## alpha = 0.15485274
## alpha = 0.25830993
## alpha = 0.43088694
## alpha = 0.71876273
## alpha = 1.1989685
## alpha = 2.0

# Sampling sizes.  The validation sizes are derived as half of the training sizes.
# NOTE(review): the factor 100 is a hard-coded multiplier whose meaning is not
# shown in this notebook -- TODO document it.
size_sample_per_class_train    = 60
## size_sample_per_class_train    = 10
## size_sample_per_class_train    = 20
## size_sample_per_class_train    = 40
## size_sample_per_class_train    = 60
size_sample_train              = size_sample_per_class_train * 100
size_sample_validate           = size_sample_train // 2
size_sample_per_class_validate = size_sample_per_class_train // 2
size_batch                     = 20
# No transform yet: `trans` is replaced in the "Preprocess" section once a
# sample image is available.
trans                          = None

# [[[ LOGGING ]]]
# init_logger returns a timestamp for this run; the same value is reused below
# for the checkpoint file name and is passed to EpochManager.
timestamp = init_logger(log_name = 'train', returns_timestamp = True, saves_log = True)
print(timestamp)

# Clarify the purpose of this experiment...
# The hyper-parameters from the CONFIG section are embedded in a free-text
# comment so they are stored together with the log of this run.
hostname = socket.gethostname()
comments = f"""
            Hostname: {hostname}.

            Sample size (train)               : {size_sample_train}
            Sample size (validate)            : {size_sample_validate}
            Sample size (per class, train)    : {size_sample_per_class_train}
            Sample size (per class, validate) : {size_sample_per_class_validate}
            Batch  size                       : {size_batch}
            Alpha                             : {alpha}
            lr                                : {lr}
            seed                              : {seed}

            """

# Create a metalog to the log file, explaining the purpose of this run...
metalog = MetaLog( comments = comments )
metalog.report()


# [[[ DATASET ]]]
# Set up parameters for an experiment...
# The path is relative to the kernel's working directory.
drc_dataset   = 'fastdata'
fl_dataset    = '0000.binary.fastdata'    # Raw, just give it a try
path_dataset  = os.path.join(drc_dataset, fl_dataset)

# Load raw data...
# SECURITY: pickle.load can execute arbitrary code while unpickling.  Only
# load dataset files that come from a trusted source.
with open(path_dataset, 'rb') as fh:
    dataset_list = pickle.load(fh)

# Split data...
# First split the full list by frac_train, then split the remainder by
# frac_validate (assuming split_dataset returns (selected, remainder) -- TODO
# confirm).  data_test is kept aside and is not used in this notebook section.
data_train   , data_val_and_test = split_dataset(dataset_list     , frac_train   , seed = None)
data_validate, data_test         = split_dataset(data_val_and_test, frac_validate, seed = None)

# Define the training set
dataset_train = SPIOnlineDataset( dataset_list          = data_train,
                                  size_sample           = size_sample_train,
                                  size_sample_per_class = size_sample_per_class_train,
                                  trans                 = trans,
                                  seed                  = None, )
dataset_train.report()

# Define the validation set
# Uses size_sample_validate -- the value reported in the metalog -- rather
# than the training sample size.
dataset_validate = SPIOnlineDataset( dataset_list          = data_validate,
                                     size_sample           = size_sample_validate,
                                     size_sample_per_class = size_sample_per_class_validate,
                                     trans                 = trans,
                                     seed                  = None, )
dataset_validate.report()

In [None]:
# Peek at the first raw record to see the structure of a dataset entry.
dataset_list[0]

### Preprocess

In [None]:
# Preprocess dataset...
# Data preprocessing can be lengthy; DatasetPreprocess is imported from the
# image_preprocess_faulty module.
# Statement order matters: img_orig is fetched while the datasets still have
# trans = None, img_trans is fetched after the transform has been installed.
img_orig            = dataset_train[0][0][0]   # idx, fetch img (assumes [0][0] of an item is the image -- TODO confirm)
dataset_preproc     = DatasetPreprocess(img_orig)
trans               = dataset_preproc.config_trans()
# Install the same transform on both datasets.
dataset_train.trans = trans
dataset_validate.trans = trans
# Transformed sample; its shape defines size_y / size_x for the network below.
img_trans           = dataset_train[0][0][0]

In [None]:
# Cache both datasets after the transform has been installed.
# NOTE(review): presumably this materializes the (transformed) samples up
# front so they are not recomputed every epoch -- confirm in SPIOnlineDataset.
dataset_train.cache_dataset()
dataset_validate.cache_dataset()

### Define model

In [None]:
class Shi2019(nn.Module):
    """Small convolutional network mapping a single-channel image to one sigmoid output.

    The feature extractor is three consecutive "motif" stacks, each being
    Conv2d -> BatchNorm2d -> PReLU -> Dropout(0.2) -> MaxPool2d(2, 2).
    The flattened features go through Linear -> PReLU -> Dropout(0.2) ->
    Linear -> Sigmoid, giving one value in (0, 1) per input image.

    Note: this notebook-local class shadows the ``Shi2019`` imported from
    ``deepprojection.encoders.convnet`` earlier in the notebook.

    Parameters
    ----------
    config : object
        Must expose ``size_y`` and ``size_x`` (input image height and width,
        used to compute the flattened feature size) and ``isbias`` (passed as
        ``bias`` to every Conv2d and Linear layer).
    """

    def __init__(self, config):
        super().__init__()

        # Read all configuration up front.
        size_y   = config.size_y
        size_x   = config.size_x
        use_bias = config.isbias

        in_channels = 1

        # One (in_channels, out_channels, kernel_size) triple per motif.
        motif_specs = (
            (in_channels, 5, 5),    # Motif network 1
            (5,           3, 3),    # Motif network 2
            (3,           2, 3),    # Motif network 3
        )

        # Build the layers in the same order as a hand-written Sequential so
        # that layer indices (and hence state_dict keys) stay 0..14.
        layers = []
        for ch_in, ch_out, ksize in motif_specs:
            layers.extend([
                nn.Conv2d( in_channels  = ch_in,
                           out_channels = ch_out,
                           kernel_size  = ksize,
                           stride       = 1,
                           padding      = 0,
                           bias         = use_bias, ),
                nn.BatchNorm2d( num_features = ch_out ),
                nn.PReLU(),
                nn.Dropout(0.2),
                nn.MaxPool2d( kernel_size = 2,
                              stride      = 2 ),
            ])
        self.feature_extractor = nn.Sequential(*layers)

        # Collect the constructor attributes of every layer, keyed by its name
        # within the Sequential...
        parser    = TorchModelAttributeParser()
        conv_dict = { name : parser.parse(layer)
                      for name, layer in self.feature_extractor.named_children() }

        # ...and let NNSize work out the output shape, whose element count is
        # the input width of the first Linear layer.
        out_shape = NNSize(size_y, size_x, in_channels, conv_dict).shape()
        self.feature_size = reduce(lambda acc, dim: acc * dim, out_shape)

        self.squash_to_prob = nn.Sequential(
            nn.Linear( in_features  = self.feature_size,
                       out_features = 2,
                       bias         = use_bias ),
            nn.PReLU(),
            nn.Dropout(0.2),
            nn.Linear( in_features  = 2,
                       out_features = 1,
                       bias         = use_bias ),
            nn.Sigmoid(),
        )


    def forward(self, x):
        """Run the feature extractor, flatten to (-1, feature_size), apply the head."""
        features = self.feature_extractor(x)
        flat     = features.view(-1, self.feature_size)

        return self.squash_to_prob(flat)

In [None]:
class ConfigModel:
    """Plain attribute container: every keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        # Attribute names are not known in advance, so attach them dynamically.
        for name in kwargs:
            setattr(self, name, kwargs[name])

In [None]:
# Input size for the network, taken from the transformed sample image.
# The two-name unpacking requires img_trans to be exactly 2-D.
size_y, size_x = img_trans.shape
config_model = ConfigModel(size_y = size_y, size_x = size_x, isbias = True)

In [None]:
# Use the current CUDA device (returned as an integer index) when available,
# otherwise fall back to the CPU.
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

In [None]:
# Build the notebook-local Shi2019 network for a quick sanity check.
# Note: `model` (and `config_model`) are rebound in the "Load model" section.
model = Shi2019(config_model)
model.to(device)

In [None]:
# Smoke-test the network on the single preprocessed image.
# img_trans[None, None] prepends a batch and a channel axis to the 2-D image.
# Call the module itself instead of .forward(): nn.Module.__call__ is the
# supported entry point and also runs any registered hooks.
model(torch.tensor(img_trans[None, None]).to(device = device, dtype = torch.float))

### Load model

In [None]:
# [[[ IMAGE ENCODER ]]]
# Config the encoder...
# The last two axes of the transformed image give the input height and width.
size_y, size_x = img_trans.shape[-2:]
config_encoder = Config( name   = "Shi2019",
                         size_y = size_y,
                         size_x = size_x,
                         isbias = True )
# NOTE: Shi2019 here resolves to the class defined in this notebook, which
# shadows the one imported from deepprojection.encoders.convnet.
encoder = Shi2019(config_encoder)


# [[[ MODEL ]]]
# Config the model...
# Rebinds `config_model` and `model` from the sanity-check cells above.
config_model = Config( name = "Model", encoder = encoder, )
model = Shi2019Model(config_model)
# timestamp_prev comes from the CONFIG section (None in this run).
model.init_params(from_timestamp = timestamp_prev)

### Config trainer and validator

In [None]:
# [[[ CHECKPOINT ]]]
# Checkpoint file: <cwd>/chkpts/<timestamp>.train.chkpt
# NOTE(review): the "chkpts" directory is not created here -- confirm that the
# trainer creates it before saving.
drc_cwd          = os.getcwd()
DRCCHKPT         = "chkpts"
prefixpath_chkpt = os.path.join(drc_cwd, DRCCHKPT)
fl_chkpt         = f"{timestamp}.train.chkpt"
path_chkpt       = os.path.join(prefixpath_chkpt, fl_chkpt)


# [[[ TRAINER ]]]
# Config the trainer...
# NOTE(review): shuffle = False -- presumably the dataset already samples
# randomly; confirm this is intended for training.
config_train = ConfigTrainer( path_chkpt        = path_chkpt,
                              num_workers       = 1,
                              batch_size        = size_batch,
                              pin_memory        = True,
                              shuffle           = False,
                              lr                = lr, 
                              tqdm_disable      = True)
trainer = SimpleTrainer(model, dataset_train, config_train)


# [[[ VALIDATOR ]]]
# Same loader settings as the trainer, but without a checkpoint path.
config_validator = ConfigValidator( path_chkpt        = None,
                                    num_workers       = 1,
                                    batch_size        = size_batch,
                                    pin_memory        = True,
                                    shuffle           = False,
                                    lr                = lr,
                                    tqdm_disable      = True)
validator = SimpleValidator(model, dataset_validate, config_validator)

### Training epochs

In [None]:
# Per-epoch loss histories, filled by the training loop below.
# Re-run this cell before re-running the loop to start from empty histories.
loss_train_hist = []
loss_validate_hist = []
loss_min_hist = []

# [[[ EPOCH MANAGER ]]]
# Ties the trainer and validator together; run_one_epoch() is called once per
# epoch in the training loop.
epoch_manager = EpochManager( trainer   = trainer,
                              validator = validator,
                              timestamp = timestamp, )

# Optional activation capture (disabled).
# epoch_manager.set_layer_to_capture(
#     module_name_capture_list  = ["final_conv"],
#     module_layer_capture_list = [torch.nn.ReLU],
# )

In [None]:
# Reset the epoch manager's best-loss tracker before training.
# NOTE(review): presumably any first-epoch loss then counts as an improvement -- confirm in EpochManager.
epoch_manager.loss_min = float('inf')

In [None]:
# Run the training / validation epochs and record the losses of each epoch.
max_epochs = 1000
# NOTE: freq_save is only used by the commented-out saving block below.
freq_save = 5
for epoch in tqdm.tqdm(range(max_epochs), disable=False):
    # With returns_loss = True, run_one_epoch returns three values, unpacked
    # here as training loss, validation loss and the minimum loss.
    loss_train, loss_validate, loss_min = epoch_manager.run_one_epoch(epoch = epoch, returns_loss = True)
    
    # Histories are defined in an earlier cell; re-running this cell alone
    # appends to the existing lists.
    loss_train_hist.append(loss_train)
    loss_validate_hist.append(loss_validate)
    loss_min_hist.append(loss_min)

    # Optional periodic saving (disabled).
    # if epoch % freq_save == 0: 
    #     epoch_manager.save_model_parameters()
    #     epoch_manager.save_model_gradients()
    #     epoch_manager.save_state_dict()

In [None]:
# Inspect the first item of the training set after training.
dataset_train[0]