# Журнал тренировки сети

Данный журнал предназначен для обучения сети в системе Yandex DataSphere. В блоке входных параметров задаются параметры обучения:
- Путь до директории с датасетом
- Число эпох
- Скорость обучения (learning rate)
- Размер батча (batch size)
- Имя файла для сохранения весов модели
- Описание эксперимента обучения
- Путь для сохранения результатов обучения (графики, метрики)
- Путь для сохранения визуализации работы сети на тренировочном и тестовом датасете (None, если визуализацию производить не нужно)

## Блок входных параметров

In [1]:
# --- Training configuration ---------------------------------------------------
# Dataset directory (alternative dataset: "./../DataGeneration/green_data_3").
DATASET_PATH = "./../DataGeneration/inflated_new_vital"
EPOCHS = 150             # number of passes over the training loader
LR = 0.0003              # learning rate handed to the model builder
BATCH_SIZE = 15          # batch size used by the DatasetSequence loaders
MODEL_NAME = "model.h5"  # base file name for saved weights ("best_"/"last_" prefixes)

## Секция проведения обучения

### 0. Импорт необходимых классов и объектов

In [None]:
from dataset_reader import DatasetReader
from dataset_loader import DatasetSequence

### 1. Загрузка и проверка данных

In [None]:
# Read the whole dataset and split it into (shuffled) train/validation parts.
dr = DatasetReader(DATASET_PATH)
X_train, Y_train, X_val, Y_val = dr.read_and_split_all_data(need_shuffle=True)

# Sanity check: one shape per line, in train-X, train-Y, val-X, val-Y order.
print(*(part.shape for part in (X_train, Y_train, X_val, Y_val)), sep="\n")

# Wrap the arrays into batch loaders consumed by model.fit().
train_loader = DatasetSequence(X_train, Y_train, BATCH_SIZE)
valid_loader = DatasetSequence(X_val, Y_val, BATCH_SIZE)

# Number of batches per epoch for each loader.
print(*(len(loader) for loader in (train_loader, valid_loader)), sep="\n")

In [None]:
import os
from random import randint
from utils import *
from plotting_utils import *
import matplotlib.cm as cm

import random

import numpy as np  # used for the slice centre below; imported here so the cell runs standalone

# Output directory for the sample dumps; create it so a fresh checkout does not fail.
vis_dir = "./data/data_vis"
os.makedirs(vis_dir, exist_ok=True)

# Export a few random training samples (input and target) for visual inspection.
data_to_save_cnt = 10
n_train = X_train.shape[0]
# random.sample draws distinct indices from [0, n_train - 1]. randint(0, n_train)
# is inclusive of n_train and could index past the end of X_train. Never request
# more samples than the dataset holds.
for idx in random.sample(range(n_train), min(data_to_save_cnt, n_train)):
    print(idx)
    # Last-axis index 0 of the sample (presumably its single channel).
    x_data = X_train[idx, ..., 0]
    y_data = Y_train[idx, ..., 0]
    #x_inted_data = ((x_data * 255).astype("uint8")).astype("float32") / 255 
    save_tiff(x_data, f"{vis_dir}/x{idx}.tiff")
    save_tiff(y_data, f"{vis_dir}/y{idx}.tiff")
    # NOTE(review): 0.022 / 0.1 are passed straight to save_image_slices; their
    # meaning (likely colour-scale limits) is defined in plotting_utils — confirm.
    # The last argument is the per-axis centre index of the volume.
    save_image_slices(x_data, f"{vis_dir}/x{idx}.png", cm.jet, 0.022, 0.1, np.array(x_data.shape) // 2)
    #save_image_slices(x_inted_data, f"./data/data_vis/x_inted_{idx}.png", cm.jet, 0.022, 0.1, np.array(x_data.shape) // 2)
    save_image_slices(y_data, f"{vis_dir}/y{idx}.png", cm.jet, 0.022, 0.1, np.array(y_data.shape) // 2)

### 2. Инициализация и компиляция модели глубокого обучения 

In [None]:
import numpy as np
#from CNNModels.cnn_deconv_unet import CNNDeconvUNet
from CNNModels.cnn_deconv_unet_exp import CNNDeconvUNet
from CNNModels.cnn_deconv_rescoder import CNNDeconvRescoder

from tensorflow.keras.utils import plot_model

# Network input: the sample shape reported by the dataset reader plus a
# trailing axis of size 1.
model_input_shape = (*dr.shape, 1)
model = CNNDeconvUNet.build_model(model_input_shape, LR)

# Save a diagram of the architecture to model.png.
plot_model(model, to_file='model.png')

### 3. Обучение модели

In [None]:
import os
from datetime import datetime
import tensorflow as tf
import keras

import wandb
from wandb.integration.keras import (
   WandbMetricsLogger,
   WandbModelCheckpoint,
)
import json

# W&B credentials come from the `wandbai_logins` environment variable, which holds
# a JSON object; the entry under 'sachuk' is used as the API key. Never hard-code
# API keys in the notebook: they leak into version control and shared outputs.
wanbai_keys_dict = json.loads(os.environ['wandbai_logins'])
os.environ['WANDB_API_KEY'] = wanbai_keys_dict['sachuk']


# Start a W&B run in the RuDeconv3D project that records the hyperparameters.
run = wandb.init(
    project="RuDeconv3D",
    config={
        "learning_rate": LR,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
    },
    settings=wandb.Settings(init_timeout=180),
)


# Per-run output directory ./train_logs/<YYYY-mm-dd_HH-MM-SS>. Seconds avoid a
# FileExistsError when two runs start in the same minute; makedirs also creates
# ./train_logs/ itself when it does not exist yet.
result_path = os.path.join("./train_logs/", datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
os.makedirs(result_path)
model_saving_path = os.path.join(result_path, "best_" + MODEL_NAME)


# Keep only the weights of the epoch with the lowest validation loss.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=model_saving_path,
    save_weights_only=True,
    monitor='val_loss',
    mode='min',
    verbose=1,
    save_best_only=True,
)

# TensorBoard logs inside the run directory. Use tf.keras, like the model and the
# checkpoint callback, so every callback comes from the same Keras implementation.
tb_path = os.path.join(result_path, "tensorboard_log")
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=tb_path,
    histogram_freq=1,
)


# Train. The DatasetSequence loaders already produce batches of BATCH_SIZE, so
# `batch_size` is not passed to fit(): Keras states it must not be specified for
# Sequence/generator inputs.
hist = model.fit(
    x=train_loader,
    validation_data=valid_loader,
    epochs=EPOCHS,
    shuffle=True,
    callbacks=[
        model_checkpoint_callback,
        tensorboard_callback,
        WandbMetricsLogger(log_freq=5),
    ],
)


In [None]:
# Also keep the weights from the final epoch, next to the best checkpoint.
model_saving_path = os.path.join(result_path, f"last_{MODEL_NAME}")
model.save_weights(filepath=model_saving_path)