## Imports

In [None]:
import os
from datetime import datetime
from pathlib import Path

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

from torch.utils.data import DataLoader
import torch
from torch import nn
import torchvision
import torch.nn.functional as F

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

from utils import get_BUSI_dataset, get_MRI_dataset, get_dataloaders_imagedata
from models import IntelligentMaskModelRL

## Setup

In [None]:
# Architecture of each sub-network; both accept 'UNet' or 'AE'.
recon_model_structure = 'UNet'  # options: 'UNet' or 'AE'
masker_model_structure = 'AE'  # options: 'UNet' or 'AE'

# Input resolution and training length shared by the cells below.
image_size = 300
max_epochs = 300

# Global seed; workers=True additionally seeds dataloader worker processes
# (per the Lightning seed_everything documentation).
random_state = 42
pl.seed_everything(random_state, workers=True)

# Run timestamp, later embedded in the checkpoint directory name.
start_time_str = datetime.now().strftime("%d-%m-%Y-%H-%M-%S")

In [None]:
# Decide once whether CUDA is usable and derive every device setting from it.
use_cuda = torch.cuda.is_available()
if use_cuda:
    # Report which GPU will be used (device index 0).
    print(torch.cuda.get_device_name(0))

# Number of GPUs handed to the Lightning Trainer: one if CUDA is present, else none.
gpus = 1 if use_cuda else 0
device = torch.device('cuda' if use_cuda else 'cpu')
print(f'device: {device}')

## Dataset Config

In [None]:
# Root folder of the BUSI image dataset on disk.
data_path = 'Dataset/Dataset_BUSI_with_GT'
# Index of the dataset; expected to expose a `label_cat` column
# (presumably integer class codes -- confirm in utils.get_BUSI_dataset).
df = get_BUSI_dataset(data_path)

# Class count computed as the span of the label codes. This assumes the codes
# are contiguous integers -- TODO confirm. Not referenced in the cells shown here.
label_codes = df.label_cat.values
nclass = label_codes.max() - label_codes.min() + 1

# Passed to get_dataloaders_imagedata so images are loaded with one channel.
single_channel = True

In [None]:
# Loader options forwarded to get_dataloaders_imagedata.
dataloader_params = dict(
    batch_size=16,
    num_workers=4,
    pin_memory=True,
    shuffle_train=True,
    shuffle_test=False,
)

# Hold out 10% of the rows for validation.
# NOTE(review): the split uses its own fixed seed (14), independent of the global
# `random_state` (42) -- presumably to keep the split stable; confirm intent.
df_train, df_test = train_test_split(df, test_size=0.1, random_state=14)

dataloader_train, dataloader_test = get_dataloaders_imagedata(
    df_train=df_train,
    df_test=df_test,
    image_size=image_size,
    dataloader_params=dataloader_params,
    single_channel=single_channel,
)

## Model Config

In [None]:
# Fail fast on a misspelled architecture name. Only `== 'UNet'` is checked when
# building the configs, so any other value (e.g. 'unet') would silently select AE.
_valid_structures = ('UNet', 'AE')
if recon_model_structure not in _valid_structures:
    raise ValueError(
        f"recon_model_structure must be one of {_valid_structures}, got {recon_model_structure!r}")
if masker_model_structure not in _valid_structures:
    raise ValueError(
        f"masker_model_structure must be one of {_valid_structures}, got {masker_model_structure!r}")

# Reconstruction network configuration.
# NOTE(review): the leading 1 in enc_chs presumably matches single_channel=True
# input; update it if the dataloaders start producing more channels.
recon_model_config = {
    'enc_chs': (1, 16, 32, 64, 64, 128, 128),
    'dec_chs': (128, 128, 64, 64, 32, 16),
    'retain_dim': True,
    'image_size': image_size,
    'unet_flag': recon_model_structure == 'UNet',
}

# Masker network configuration: same encoder widths, a single decoder stage,
# and retain_dim disabled (semantics defined in models -- confirm there).
masker_model_config = {
    'enc_chs': (1, 16, 32, 64, 64, 128, 128),
    'dec_chs': (128,),
    'retain_dim': False,
    'image_size': image_size,
    'unet_flag': masker_model_structure == 'UNet',
}

training_config = {
    'recon_model_config': recon_model_config,
    'masker_model_config': masker_model_config,
    'loss_on_mask': True,
    'lr_recon_model': 1e-3,
    'lr_masker_model': 1e-3,
    'use_scheduler': True,
    # Presumably scheduler milestone epochs: one-third and two-thirds of training.
    'milestones': [max_epochs // 3, 2 * max_epochs // 3],
}

model = IntelligentMaskModelRL(**training_config)
# The Lightning Trainer also manages device placement; the explicit move is kept
# so the model is on `device` for any use before trainer.fit.
model = model.to(device)
print('Number of params:', sum(p.numel() for p in model.parameters()))

## Trainer Config

In [None]:
# Log the learning rate once per epoch.
lr_monitor = LearningRateMonitor(logging_interval='epoch')

# One checkpoint directory per run, keyed by model class name and start time.
save_model_path = f"model_checkpoint/UnsupervisedModel/{model.__class__.__name__}-{start_time_str}"
# exist_ok=True already tolerates an existing directory; no separate exists() check needed.
Path(save_model_path).mkdir(parents=True, exist_ok=True)

# Checkpoint every 50 epochs and keep all of them (save_top_k=-1). Full checkpoints
# are written; pass save_weights_only=True to store weights only.
checkpoint_callback = ModelCheckpoint(
    dirpath=save_model_path,
    filename="{epoch:02d}",
    every_n_epochs=50,
    # Run the checkpoint check at the end of validation instead of the train epoch.
    save_on_train_epoch_end=False,
    save_top_k=-1,
)

# NOTE(review): `gpus=` is deprecated since Lightning 1.7 and removed in 2.0;
# switch to accelerator=/devices= when upgrading.
# log_every_n_steps equals the number of training batches -> one log per epoch.
trainer = pl.Trainer(
    gpus=gpus,
    max_epochs=max_epochs,
    callbacks=[checkpoint_callback, lr_monitor],
    log_every_n_steps=len(dataloader_train),
)

## Training

In [None]:
# Train on dataloader_train; dataloader_test is passed as the validation loader.
trainer.fit(
    model,
    dataloader_train,
    dataloader_test,
)