## Imports

In [None]:
import sys
import os
import glob
from pathlib import Path
from datetime import datetime
import re
import argparse

import pandas as pd
import numpy as np

import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision
from sklearn.model_selection import train_test_split

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, Callback
from pytorch_lightning.loggers import NeptuneLogger
from pytorch_lightning.utilities.cloud_io import load as pl_load

from utils import get_BUSI_dataset, get_dataloaders_imagedata, get_dataloaders_clf
from models import IntelligentMaskModelRL, NNClassifier, CNNClassifier, ResNetClassifier, EncoderClassifier

## Setup

In [None]:
# --- Paths -------------------------------------------------------------------
# Root directory of the BUSI dataset (read by get_BUSI_dataset below).
data_path = 'Dataset/Dataset_BUSI_with_GT/'
# Directory containing the pre-trained unsupervised model checkpoints.
load_checkpoint_path = Path('model_checkpoint/UnsupervisedModel/')
# Classifier checkpoints are written into a sub-folder of this directory.
save_checkpoint_path = "model_checkpoint/Classifier/"

In [None]:
# Seeds: `random_state` drives model/training randomness,
# `data_random_state` drives the train/val/test splits.
random_state = 42
data_random_state = 42

# Class-imbalance options.
# use_weight: pass per-class weights derived from the class sample counts to the
#             classifier (intended as inverse-frequency weighting).
use_weight = True
# make_balance: intended to make train/test sets class-balanced.
# NOTE(review): not referenced by any cell in this section — confirm it is used.
make_balance = True

image_size = 300  # image size handed to get_dataloaders_imagedata

# Checkpoint of the whole pre-trained IntelligentMaskModelRL (a pathlib.Path).
pretrained_model_path = (
    load_checkpoint_path
    / 'IntelligentMaskModelRL-05-10-2022-15-42-34/epoch=01.ckpt'
)

pl.seed_everything(random_state, workers=True)
# Timestamp of this run, used to name the classifier checkpoints.
start_time_str = datetime.now().strftime("%d-%m-%Y-%H-%M-%S")

In [None]:
# Select the compute device; `gpus` is the GPU count later handed to pl.Trainer.
if torch.cuda.is_available():
    device = 'cuda'
    print(torch.cuda.get_device_name(0))
    gpus = 1
else:
    device = 'cpu'
    gpus = 0
device = torch.device(device)
print(f'device: {device}')

## Data Config

In [None]:
# Load the BUSI dataframe. `label_cat` is assumed to hold integer-coded class
# labels (TODO confirm in utils.get_BUSI_dataset); the number of classes is the
# size of the contiguous range between the smallest and largest label.
df = get_BUSI_dataset(data_path)
nclass = 1 + df.label_cat.values.max() - df.label_cat.values.min()
single_channel = True  # forwarded to the dataloader helper

In [None]:
# Stratified three-way split driven by absolute sample counts:
#   1. hold out `test_size` rows as the test set,
#   2. hold out `val_size` rows of the remainder as the validation set,
#   3. keep `train_size` rows of what is left for training (the rest is dropped).
split_config = dict(
    train_size=100,
    val_size=100,
    test_size=100,
)

df_train, df_test = train_test_split(
    df,
    test_size=split_config['test_size'],
    random_state=data_random_state,
    stratify=df.label_cat,
)
df_train, df_val = train_test_split(
    df_train,
    test_size=split_config['val_size'],
    random_state=data_random_state,
    stratify=df_train.label_cat,
)
df_train, _ = train_test_split(
    df_train,
    train_size=split_config['train_size'],
    random_state=data_random_state,
    stratify=df_train.label_cat,
)

In [None]:
# Build the image dataloaders. The helper returns a (train, eval) pair, so it is
# called once with the test split and once with the validation split; the train
# loader from the second call is discarded.
dataloader_params = {
    'batch_size': 16,
    'num_workers': 2,
    'pin_memory': True,
    'shuffle_train': True,
    'shuffle_test': False,
}
clf_dataloader_train, clf_dataloader_test = get_dataloaders_imagedata(
    df_train, df_test, image_size, dataloader_params, single_channel,
    random_transform_train=False,
)
_, clf_dataloader_val = get_dataloaders_imagedata(
    df_train, df_val, image_size, dataloader_params, single_channel,
    random_transform_train=False,
)

## Model Config

In [None]:
# Config for the classifier (EncoderClassifier) built on the pre-trained encoder.
config_clf = {
    'hidden' : [16], # list of neurons in each hidden layer of FC classifier
    'conv_channel' : 32, # the size of the convolutional channel before the classifier
    'nlayer_unfreeze': 0, # 'all' or an integer controlling how many encoder layers are unfrozen during classifier training
    'nclass' : nclass,
    'lr' : 1e-3,
    'weight_decay' : 0.1,
    'use_scheduler' : True,
    'dropout' : 0.3,
    'use_weight' : use_weight,
    'milestones' : [150,175] # milestones are used to change the learning rate in MultiStepLR scheduler with gamma=0.33
}

if config_clf['use_weight']:
    # Per-class sample counts of the training split, ordered by class id so that
    # entry i belongs to class i. `value_counts()` on its own sorts by frequency,
    # which would pair counts with the wrong classes. Reindexing over the full
    # class range also keeps the tensor length equal to `nclass` (absent classes
    # get a count of 0).
    # NOTE(review): these are raw counts; presumably EncoderClassifier converts
    # them to inverse-frequency weights — confirm in models.py.
    class_ids = range(df.label_cat.values.min(), df.label_cat.values.max() + 1)
    class_counts = df_train.label_cat.value_counts().reindex(class_ids, fill_value=0)
    config_clf['weight'] = torch.tensor(class_counts.to_numpy())

# Location of the pre-trained IntelligentMaskModelRL checkpoint.
pretrained_config = {
    'checkpoint_name': pretrained_model_path
}

extra_config = {
    'random_state': random_state,
    'every_n_epochs' : 1, # NOTE(review): not read in this section; checkpointing below uses max_epochs instead
    'max_epochs' : 200,
    'run_classifier_from_saved': False
}

In [None]:
# Restore the whole pre-trained IntelligentMaskModelRL and move it to `device`.
pretrained_model = IntelligentMaskModelRL.load_from_checkpoint(
    pretrained_config['checkpoint_name']
).to(device)

# Hand its reconstruction encoder (and that encoder's output channel count)
# to the classifier config.
config_clf['encoder'] = pretrained_model.recon_model.encoder
config_clf['encoder_last_channel'] = config_clf['encoder'].last_channel

In [None]:
# Instantiate the encoder-based classifier from `config_clf`, then move it to `device`.
model_clf = EncoderClassifier(**config_clf)
model_clf = model_clf.to(device)

## Trainer Config

In [None]:
# Run name: "<classifier class>-<start time>".
model_name_str = model_clf.__class__.__name__ + '-' + start_time_str

# `checkpoint_name` is a pathlib.Path (load_checkpoint_path / '<run dir>/<file>.ckpt').
# Path has no `.split()` and `re.split()` rejects Path objects, so take the pieces
# from the path components instead:
#   - the pre-trained run directory name -> sub-folder for the classifier checkpoints
#   - the checkpoint file stem (e.g. 'epoch=01') -> part of the checkpoint filename
pretrained_ckpt = Path(pretrained_config['checkpoint_name'])
save_model_path = str(Path(save_checkpoint_path) / pretrained_ckpt.parent.name)
file_name = pretrained_ckpt.stem
Path(save_model_path).mkdir(parents=True, exist_ok=True)

lr_monitor = LearningRateMonitor(logging_interval='epoch')

checkpoint_callback = ModelCheckpoint(
    dirpath=save_model_path,
    filename= model_name_str + '_' + file_name + "-{epoch:02d}",
    # Checkpoint only every `max_epochs` epochs, i.e. once at the end of training.
    every_n_epochs=extra_config['max_epochs'],
    save_on_train_epoch_end=False,
    save_top_k = -1
)

trainer = pl.Trainer(gpus=gpus, max_epochs=extra_config['max_epochs'], callbacks=[checkpoint_callback, lr_monitor], log_every_n_steps=5)

## Training

In [None]:
# Train the classifier. Both the validation and the test loaders are passed as
# validation dataloaders, so metrics for both are computed during training.
trainer.fit(
    model_clf,
    clf_dataloader_train,
    [clf_dataloader_val, clf_dataloader_test],
)

## Test

In [None]:
# Evaluate the model as it stands after training on the held-out test loader.
trainer.test(
    model_clf,
    clf_dataloader_test,
)