In [1]:
import os
import pickle
import numpy as np
import torch
import random
import tqdm

In [2]:
import matplotlib              as mpl
import matplotlib.pyplot       as plt
import matplotlib.colors       as mcolors
import matplotlib.patches      as mpatches
import matplotlib.transforms   as mtransforms
import matplotlib.font_manager as font_manager
%matplotlib inline

In [3]:
from aug import PadBottomRight, \
                Crop,           \
                RandomCrop,     \
                Resize,         \
                RandomShift,    \
                RandomRotate,   \
                RandomPatch

In [4]:
from behenate_net.dataset   import BeHenateDataset
from behenate_net.model     import ConfigModel, BeHenataNet
from behenate_net.trainer   import ConfigTrainer, Trainer
from behenate_net.validator import ConfigValidator, Validator
from behenate_net.utils     import EpochManager, split_dataset, set_seed, init_logger

In [5]:
# Seed value handed to the project's set_seed helper before any data
# splitting, augmentation or model initialisation happens in later cells.
# NOTE(review): which RNGs set_seed actually covers (random / numpy / torch)
# is not visible in this notebook — confirm in behenate_net.utils.
seed = 0
set_seed(seed)

In [6]:
# Optimisation settings; both values are forwarded to ConfigTrainer and
# ConfigValidator further down.
batch_size = 200
lr         = 1e-3

# Split fractions passed to split_dataset: frac_train is applied to the full
# data list, frac_validate to whatever is left over after that first split.
frac_train    = 0.8
frac_validate = 0.5

In [7]:
# Start the 'train' logger and keep the timestamp it returns; the same
# timestamp is later passed to ConfigTrainer.
timestamp = init_logger(
    log_name          = 'train',
    returns_timestamp = True,
)
print(timestamp)

2023_0313_1707_09


In [8]:
# Optional resume-from-checkpoint settings (disabled for this run).
# To continue from an earlier run, uncomment these lines and swap the
# model.init_params call further down for its commented-out variant that
# takes fl_chkpt.
# timestamp_prev = "2023_0313_1255_23"
# epoch = 360
# fl_chkpt = f"{timestamp_prev}.epoch_{epoch}.chkpt"

In [9]:
# Load the pickled sample list that the datasets below are built from.
# NOTE(review): pickle.load can execute arbitrary code while unpickling —
# only point path_pickle at a file from a trusted source.
path_pickle = "beam_center.pickle"
with open(path_pickle, 'rb') as handle:
    data_list = pickle.load(handle)

# Split data...
# Two-stage split: first separate the training portion, then divide the
# remainder into validation and test portions.
# NOTE(review): seed = None is passed here even though `seed` is defined
# earlier in the notebook; whether the split is reproducible depends on how
# split_dataset treats None — confirm.
data_train, data_val_and_test = split_dataset(
    data_list,
    frac_train,
    seed = None,
)
data_validate, data_test = split_dataset(
    data_val_and_test,
    frac_validate,
    seed = None,
)

In [10]:
# Dataset settings.
size_sample     = 4000    # forwarded to BeHenateDataset; validation gets half of it
size_img_y      = 64      # y-size given to Resize and to ConfigModel
size_img_x      = 64      # x-size given to Resize and to ConfigModel
normalizes_data = True

# Transform pipeline shared by the training and validation datasets.
# NOTE(review): because the same tuple is reused, the validation set also
# goes through the random transforms (RandomCrop / RandomShift /
# RandomRotate / RandomPatch) — confirm this is intended.
trans_list = (
    PadBottomRight(size_y = 2000, size_x = 2000),
    # Crop((940, 960), (200, 200)),
    RandomCrop(
        center_shift_max = (20, 20),
        crop_window_size = (1000, 1000),
    ),
    Resize(size_img_y, size_img_x),
    RandomShift(0.1, 0.1),
    RandomRotate(angle_max = 90),
    RandomPatch(
        num_patch    = 10,
        size_patch_y = 10,
        size_patch_x = 10,
        var_patch_y  = 0.2,
        var_patch_x  = 0.2,
    ),
)

# Training dataset, cached right after construction.
dataset_train = BeHenateDataset(
    data_list          = data_train,
    size_sample        = size_sample,
    trans_list         = trans_list,
    normalizes_data    = normalizes_data,
    prints_cache_state = False,
)
dataset_train.cache_dataset()

# Validation dataset: same settings, built from data_validate with half the
# sample size.
dataset_validate = BeHenateDataset(
    data_list          = data_validate,
    size_sample        = size_sample // 2,
    trans_list         = trans_list,
    normalizes_data    = normalizes_data,
    prints_cache_state = False,
)
dataset_validate.cache_dataset()

In [11]:
# Optional visual sanity check (disabled): shows one item of dataset_train
# with a red circle drawn at its `center` value. Uncomment to inspect a sample.
# img, center, _ = dataset_train[23]
# img = img[0]

# radius = 10
# fig = plt.figure(figsize = (8, 8))
# circle = mpatches.Circle(center[::-1], radius=radius, color='red', fill=False)    # Err..., matplotlib uses (x, y) not (y, x)
# vmin = np.nanmedian(img)
# vmax = np.nanmedian(img) + 8 * np.nanstd(img)
# plt.imshow(img, vmin = vmin, vmax = vmax)
# plt.gca().add_patch(circle)
# plt.title(f"{img.shape}, {center}")

In [12]:
# Pick the current CUDA device index when a GPU is available, otherwise the
# string 'cpu'.
# NOTE(review): `device` is not referenced by any other cell shown in this
# notebook — confirm whether the trainer selects its device on its own.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
else:
    device = 'cpu'

# Build the network for the resized image dimensions and initialise its
# parameters without a checkpoint file.
config_model = ConfigModel(
    size_y = size_img_y,
    size_x = size_img_x,
    isbias = True,
)
model = BeHenataNet(config_model)
model.init_params(fl_chkpt = None)
# model.init_params(fl_chkpt = fl_chkpt)

In [13]:
# [[[ TRAINER ]]]
# Config the trainer...
# [[[ TRAINER ]]]
# Config the trainer...
# NOTE(review): shuffle is False for training as well; confirm that sample
# order is randomised elsewhere if that is required.
config_train = ConfigTrainer(
    timestamp    = timestamp,
    num_workers  = 1,
    batch_size   = batch_size,
    pin_memory   = True,
    shuffle      = False,
    lr           = lr,
    tqdm_disable = True,
)
trainer = Trainer(model, dataset_train, config_train)

# [[[ VALIDATOR ]]]
# Config the validator...
# NOTE(review): lr is passed to the validator config too — confirm that
# ConfigValidator actually uses it.
config_validator = ConfigValidator(
    num_workers  = 1,
    batch_size   = batch_size,
    pin_memory   = True,
    shuffle      = False,
    lr           = lr,
    tqdm_disable = True,
)
validator = Validator(model, dataset_validate, config_validator)

In [14]:
# [[[ TRAIN EPOCHS ]]]
loss_train_hist    = []
loss_validate_hist = []
loss_min_hist      = []

# [[[ EPOCH MANAGER ]]]
epoch_manager = EpochManager( trainer   = trainer,
                              validator = validator, )
max_epochs = 1000
freq_save = 5
for epoch in tqdm.tqdm(range(max_epochs), disable=False):
    loss_train, loss_validate, loss_min = epoch_manager.run_one_epoch(epoch = epoch, returns_loss = True)

    loss_train_hist.append(loss_train)
    loss_validate_hist.append(loss_validate)
    loss_min_hist.append(loss_min)

    # if epoch % freq_save == 0: 
    #     epoch_manager.save_model_parameters()
    #     epoch_manager.save_model_gradients()
    #     epoch_manager.save_state_dict()

100%|██████████| 1000/1000 [16:36<00:00,  1.00it/s]
