In [None]:
# torch related imports
import torch
from lightning.pytorch.loggers import TensorBoardLogger

# imports for hyperparam tuning with Ray
from ray import tune
from ray.train.torch import TorchTrainer
from ray.train import RunConfig, ScalingConfig, CheckpointConfig

# Lightning
import lightning as L

from utils import loadData, plotExamples, set_reproducibility
from models import train_func, tuning, ConvNet, Classificator

## Hyperparameters

In [None]:
# Dataloader params
NUM_WORKERS = 7  # value suggested by a DataLoader warning message on this machine
PERSISTENT_WORKERS = True  # suggested by a warning message for faster worker start-up
USE_AUGMENT = False  # mutually exclusive with CUSTOM_TRAIN_VAL_SPLIT
CUSTOM_TRAIN_VAL_SPLIT = True  # mutually exclusive with USE_AUGMENT
PROJECT_DATA_DIR = "/Project"  # project root holding the dataset, logs and checkpoints; adjust per machine
DATA_SET_DIR = PROJECT_DATA_DIR + "/chest_xray"  # the chest_xray dataset folder
LIGHTNING_LOGS_DIR = PROJECT_DATA_DIR + "/lightning_logs"
# NOTE(review): PERSISTENT_WORKERS, USE_SAMPLER and SHOW_ANALYTICS are not passed
# to any call in this notebook's visible cells — confirm whether they are still needed.
USE_SAMPLER = False
SHOW_ANALYTICS = False

# Fail fast on contradictory settings instead of silently running with both enabled.
if USE_AUGMENT and CUSTOM_TRAIN_VAL_SPLIT:
    raise ValueError("USE_AUGMENT and CUSTOM_TRAIN_VAL_SPLIT are mutually exclusive")

# Lightning module params
EPOCHS = 50
CLASS_LABELS = ["Normal", "Pneumonia"]
NUM_CLASSES = len(CLASS_LABELS)  # derived so the label list and class count cannot disagree

# Tuning params
NUM_SAMPLES = 20  # number of samples drawn from the parameter space

# Ray Tune search space; fixed entries are passed through unchanged to every trial.
search_space = {
    "reproducibility_active": True,
    "epochs": EPOCHS,
    #"seed": tune.randint(0, 10000),
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([16,32, 64]),
    "loss" : tune.choice(["BCEwLogits", "CrossEntropyLoss"]),
    "dropout": tune.choice([0.2, 0.5, 0.8]), # https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10518240/pdf/cureus-0015-00000044130.pdf
    "project_data_dir": PROJECT_DATA_DIR,
    "data_set_dir": DATA_SET_DIR,
    "lightning_logs": LIGHTNING_LOGS_DIR,
}

## Ensuring reproducibility

In [None]:
# Seed the random number generators through the project helper
# (presumably python/numpy/torch — see utils.set_reproducibility).
set_reproducibility(42)

# Make cuDNN pick deterministic convolution algorithms and switch off its
# benchmark auto-tuner, which can select different kernels from run to run.
# These flags only affect cuDNN; other CUDA ops may still be non-deterministic.
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

## Load Data and show Analytics

In [None]:
# Build loaders only to show dataset analytics and example images; they are
# rebuilt with the tuned batch size before final training. Pass DATA_SET_DIR
# explicitly so these analytics describe the same folder the final run uses,
# rather than whatever default loadData falls back to.
train_loader, val_loader, test_loader = loadData(
    dataDir=DATA_SET_DIR,
    numWorkers=NUM_WORKERS,
    showAnalytics=True,
    batchSize=32,
)

In [None]:
# Visualise samples from the training loader (project helper from utils;
# presumably plots a few example X-ray images with their labels — confirm).
plotExamples(train_loader)

## Hyperparameter Tuning with Ray Tune

In [None]:
# Resources for each Ray trial: a single training worker. Request a GPU only
# when PyTorch can see one. With use_gpu=True on a CPU-only machine, Ray keeps
# trials pending, waiting for a GPU resource that never appears.
# NOTE(review): on a multi-node cluster the driver may lack a GPU while workers
# have one — set use_gpu explicitly in that case.
scaling_config = ScalingConfig(
    num_workers=1,
    use_gpu=torch.cuda.is_available(),
)

# Keep at most two checkpoints per trial, ranked by validation accuracy.
# The score attribute must name a metric that train_func reports to Ray.
run_config = RunConfig(
    checkpoint_config=CheckpointConfig(
        num_to_keep=2,
        checkpoint_score_attribute="val_BinaryAccuracy",
        checkpoint_score_order="max",
    ),
)

In [None]:
# Ray Train trainer around the project's per-worker training loop. No
# train_loop_config is given here: the hyper-parameters for each trial come
# from the Tuner (presumably set up inside `tuning` from search_space).
ray_trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    run_config=run_config,
    scaling_config=scaling_config,
)

In [None]:
# Run the hyper-parameter search. `tuning` is a project helper (models.py):
# presumably it runs NUM_SAMPLES trials sampled from search_space using
# ray_trainer (num_epochs=EPOCHS likely bounds a scheduler — confirm). It
# presumably returns a Ray Tune ResultGrid, since get_best_result is called on it next.
results = tuning(ray_trainer, search_space, num_samples=NUM_SAMPLES, num_epochs=EPOCHS)

In [None]:
# Pick the trial with the highest val_BinaryAccuracy. get_best_result defaults
# to scope="last": trials are compared on their last reported value, not
# their per-trial maximum.
best_result = results.get_best_result(mode="max", metric="val_BinaryAccuracy")

# Config of the winning trial; the sampled hyper-parameters sit under the
# 'train_loop_config' key.
best_config = best_result.config
print(best_config)

# Metrics from the winning trial's last report.
best_metrics = best_result.metrics
print(best_metrics)

# Full per-report metric history of the winning trial, as a pandas DataFrame.
best_result_df = best_result.metrics_dataframe

In [None]:
# Show only the tuned hyper-parameters: the trial config nests the sampled
# search-space values under 'train_loop_config'.
print(best_config['train_loop_config'])

## Running a CNN with the best found hyperparams

In [None]:
# Re-seed with the same seed used at the top, so the final training run starts
# from a known RNG state whatever the earlier cells consumed (assuming
# set_reproducibility seeds the relevant RNGs — see utils).
set_reproducibility(42)

In [None]:
# Presumably the result of an earlier tuning run, kept as a fallback: uncomment
# to skip the Ray Tune cells and train directly with these values (same
# {'train_loop_config': {...}} shape the cells below index into).
#best_config = {'train_loop_config': {'reproducibility_active': True, 'epochs': 50, 'lr': 0.00035213424594870914, 'batch_size': 32, 'loss': 'CrossEntropyLoss', 'dropout': 0.5}}

In [None]:
# Rebuild the loaders for the final run, using the tuned batch size. The worker
# count comes from the NUM_WORKERS config constant instead of a second
# hard-coded 7.
train_loader, val_loader, test_loader = loadData(
    dataDir=DATA_SET_DIR,
    numWorkers=NUM_WORKERS,
    batchSize=best_config['train_loop_config']["batch_size"],
)

In [None]:
# Stop once 'Validation loss' fails to improve by more than 1e-6 for 10
# consecutive validation checks.
# NOTE(review): the monitor key must exactly match the name Classificator logs.
early_stopping = L.pytorch.callbacks.EarlyStopping(
    monitor='Validation loss',
    min_delta=1e-6,
    patience=10,
)

# Save the checkpoint with the highest validation binary accuracy.
checkpoint = L.pytorch.callbacks.ModelCheckpoint(
    dirpath=PROJECT_DATA_DIR + '/pneumonia_model/',
    mode='max',
    monitor="val_BinaryAccuracy",
)

callbacks = [early_stopping, checkpoint]

# TensorBoard runs are grouped by whether augmentation was enabled.
logger = TensorBoardLogger(
    LIGHTNING_LOGS_DIR,
    name="simpleCNN/" + ('augment' if USE_AUGMENT else 'original'),
)

In [None]:
# Lightning trainer for the final run: one device on an automatically chosen
# accelerator, at most EPOCHS epochs, with TensorBoard logging and the
# early-stopping / checkpoint callbacks.
trainer = L.Trainer(
    max_epochs=EPOCHS,
    accelerator='auto',
    devices=1,
    callbacks=callbacks,
    logger=logger,
)

In [None]:
# Hyper-parameters of the best tuning trial.
tuned_params = best_config['train_loop_config']

# Build the CNN and wrap it in the Lightning module.
cnn = ConvNet(dropout=tuned_params['dropout'], num_classes=NUM_CLASSES)
classifier = Classificator(cnn, CLASS_LABELS, tuned_params, NUM_CLASSES)

# Let float32 matmuls use faster, lower-precision (bfloat16-based) internal
# math where available, then train.
torch.set_float32_matmul_precision('medium')
trainer.fit(classifier, train_dataloaders=train_loader, val_dataloaders=val_loader)

## Testing final Model

In [None]:
# Evaluate the best checkpoint kept by the ModelCheckpoint callback (highest
# val_BinaryAccuracy), not the weights left in `classifier` after the last
# epoch. With EarlyStopping(patience=10) monitoring a different metric, those
# last-epoch weights can be several epochs past the best one.
trainer.test(model=classifier, dataloaders=test_loader, ckpt_path="best")