# Corrosion in Industrial Complexes in Ostrava

In [None]:
# IPython autoreload extension; mode 2 re-imports modules before running code,
# so edits to the local `corrosion` package apply without a kernel restart
# (see the IPython `autoreload` documentation).
%load_ext autoreload
%autoreload 2

In [None]:
import os
import re
from pathlib import Path
from pprint import pprint

import albumentations as albu
import dagshub
import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import pytorch_lightning as L
import torch
from dagshub import get_repo_bucket_client

# Note: This does not recover the best weights as in Keras!
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# https://github.com/Lightning-AI/pytorch-lightning/discussions/10399,
# https://pytorch-lightning.readthedocs.io/en/1.5.10/extensions/generated/pytorch_lightning.callbacks.ModelCheckpoint.html
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

from corrosion import (
    CorrosionModel,
    SegmentationDatasetLoader,
    SegmentationDatasetSplit,
)
from corrosion.augmentation import (
    compose_transforms,
    hard_transforms,
    post_transforms,
    pre_transforms,
    resize_transforms,
)
from corrosion.git import get_commit_id, get_current_branch
from corrosion.plot import (
    plot_learning_curves,
    plot_predictions,
    plot_predictions_compact,
    show_examples,
    show_random,
)

## Configuration

In [None]:
plt.style.use('seaborn-v0_8-whitegrid')

# --- Experiment tracking: DagsHub repository with MLflow enabled ---
REPO_NAME = 'corrosion'
USER_NAME = 'matejfric'
dagshub.init(REPO_NAME, USER_NAME, mlflow=True)

# --- Reproducibility: global seeding (workers=True also covers dataloader workers) ---
# https://lightning.ai/docs/pytorch/stable/common/trainer.html#reproducibility
SEED = 42
L.seed_everything(SEED, workers=True)

# Report the library versions and whether a CUDA device is visible.
_version_report = (
    f'torch: {torch.__version__}, '
    f'cuda: {torch.cuda.is_available()}, '
    f'lightning: {L.__version__}'
)
print(_version_report)

In [None]:
# --- Model architecture ---
DECODER = 'unet'  # 'unet', 'unetplusplus', ...
ENCODER = 'resnet18'  # 'resnet18', 'resnet34', 'resnet50', 'resnet101', ...
FREEZE_ENCODER = True

# --- Training schedule ---
MAX_EPOCHS = 500
MONITOR = 'val_loss'  # metric watched by early stopping and checkpointing
PATIENCE = 50  # validation checks without improvement before stopping

In [None]:
# --- Where this run writes its logs, checkpoints and artifacts ---
LOG_DIR = Path('logs')
VERSION = 0
EXPERIMENT_NAME = '-'.join((DECODER, ENCODER))
EXPERIMENT_DIR = LOG_DIR.joinpath(EXPERIMENT_NAME, f'version_{VERSION}')
DATASET_URL = 'https://doi.org/10.5281/zenodo.10732179'

# Artifact file names, each stored inside EXPERIMENT_DIR.
METRICS_CSV_NAME = 'metrics.csv'
LEARNING_CURVES_PNG_NAME = 'learning_curves.png'
PREDICTIONS_PNG_NAME = 'predictions.png'
TRAIN_TRANSFORMS_JSON_NAME = 'train_transforms.json'

In [None]:
IMAGE_SIZE = 256
BATCH_SIZE = 4

DATASET_DIR = Path('data')  # 'Corrosion_in_Industrial_Complexes_in_Ostrava'

# Dataset layout: DATASET_DIR/<split>/<images|masks>/
TRAIN_SET_DIR, VALID_SET_DIR, TEST_SET_DIR = 'train', 'validation', 'test'
IMAGES_DIR, MASKS_DIR = 'images', 'masks'


def _split_files(split_dir: str, sub_dir: str, pattern: str) -> list[Path]:
    """Return the files in DATASET_DIR/split_dir/sub_dir matching *pattern*, sorted."""
    return sorted(DATASET_DIR.joinpath(split_dir, sub_dir).glob(pattern))


# Images are .jpg and masks are .png. Sorting presumably aligns each image
# with its mask by file name -- verify that the names match.
TRAIN_IMAGES = _split_files(TRAIN_SET_DIR, IMAGES_DIR, '*.jpg')
TRAIN_MASKS = _split_files(TRAIN_SET_DIR, MASKS_DIR, '*.png')

VALID_IMAGES = _split_files(VALID_SET_DIR, IMAGES_DIR, '*.jpg')
VALID_MASKS = _split_files(VALID_SET_DIR, MASKS_DIR, '*.png')

TEST_IMAGES = _split_files(TEST_SET_DIR, IMAGES_DIR, '*.jpg')
TEST_MASKS = _split_files(TEST_SET_DIR, MASKS_DIR, '*.png')

In [None]:
# show_random(TRAIN_IMAGES, TRAIN_MASKS)

In [None]:
# show_random(VALID_IMAGES, VALID_MASKS)

In [None]:
# show_random(TEST_IMAGES, TEST_MASKS)

## Augmentations and Transforms

In [None]:
# Training pipeline: resize + hard_transforms augmentation + post-processing.
train_transforms = compose_transforms(
    [resize_transforms(image_size=IMAGE_SIZE), hard_transforms(), post_transforms()]
)

# Validation and test use the same recipe, but each split gets its own
# separately composed pipeline object (valid first, then test).
valid_transforms, test_transforms = (
    compose_transforms([pre_transforms(image_size=IMAGE_SIZE), post_transforms()])
    for _split in ('valid', 'test')
)

# Training-style augmentation without post_transforms, used for visual checks
# with show_random(..., transforms=show_transforms).
show_transforms = compose_transforms([resize_transforms(), hard_transforms()])

In [None]:
# show_random(TRAIN_IMAGES, TRAIN_MASKS, transforms=show_transforms)

In [None]:
# train_transforms.transforms

## Loaders

In [None]:
# Pair each split's image paths with its mask paths.
_dataset_splits = {
    'train': SegmentationDatasetSplit(images=TRAIN_IMAGES, masks=TRAIN_MASKS),
    'valid': SegmentationDatasetSplit(images=VALID_IMAGES, masks=VALID_MASKS),
    'test': SegmentationDatasetSplit(images=TEST_IMAGES, masks=TEST_MASKS),
}
dataset_loader = SegmentationDatasetLoader(**_dataset_splits)

In [None]:
# os.cpu_count() returns None when the CPU count cannot be determined, in which
# case int(os.cpu_count()) would raise TypeError. Fall back to 0 workers,
# which presumably means loading in the main process (PyTorch DataLoader
# semantics). Also set this to 0 manually if you hit
# "RuntimeError: Trying to resize storage that is not resizable".
NUM_WORKERS = os.cpu_count() or 0

loaders = dataset_loader.get_loaders(
    num_workers=NUM_WORKERS,
    batch_size={'train': BATCH_SIZE, 'valid': 1, 'test': 1},
    train_transforms=train_transforms,
    valid_transforms=valid_transforms,
    test_transforms=test_transforms,
)

train_dataloader = loaders['train']
valid_dataloader = loaders['valid']
test_dataloader = loaders['test']

## Training

In [None]:
# 3 input channels (RGB .jpg images) and a single output class (corrosion mask).
model = CorrosionModel(
    DECODER,
    ENCODER,
    in_channels=3,
    out_classes=1,
    freeze_encoder=FREEZE_ENCODER,
)

In [None]:
# Log to TensorBoard and CSV under LOG_DIR/EXPERIMENT_NAME/version_{VERSION}.
# plot_learning_curves later reads EXPERIMENT_DIR / METRICS_CSV_NAME.
tb_logger = TensorBoardLogger(LOG_DIR, name=EXPERIMENT_NAME, version=VERSION)
csv_logger = CSVLogger(LOG_DIR, name=EXPERIMENT_NAME, version=VERSION)

# Stop training after PATIENCE validation checks without MONITOR decreasing.
early_stopping = EarlyStopping(
    monitor=MONITOR,
    mode='min',
    patience=PATIENCE,
)

# Keep only the best checkpoint by MONITOR. The filename template follows
# MONITOR and uses '.3f' (3 decimal places), e.g. with MONITOR='val_loss'
# Lightning writes 'epoch=12-val_loss=0.123.ckpt'. (A bare ':3f' would mean a
# minimum width of 3 with the default 6 decimals.)
model_checkpoint = ModelCheckpoint(
    dirpath=EXPERIMENT_DIR,
    filename=f'{{epoch}}-{{{MONITOR}:.3f}}',
    monitor=MONITOR,
    save_top_k=1,  # save only the best model
    mode='min',
)

In [None]:
trainer = L.Trainer(
    max_epochs=MAX_EPOCHS,
    callbacks=[model_checkpoint, early_stopping],
    logger=[tb_logger, csv_logger],
    # Emit logged values on every training step (i.e. every batch).
    log_every_n_steps=1,
    # Deterministic mode for reproducible runs, see
    # https://lightning.ai/docs/pytorch/stable/common/trainer.html#reproducibility
    deterministic=True,
)

In [None]:
# Train; early stopping and best-checkpoint saving come from the trainer callbacks.
trainer.fit(
    model,
    val_dataloaders=valid_dataloader,
    train_dataloaders=train_dataloader,
)

## Validation and Testing

In [None]:
# Prefer the best checkpoint tracked by the ModelCheckpoint callback. If
# training did not run in this session (best_model_path is empty), fall back
# to the most recently written *.ckpt in EXPERIMENT_DIR. Plain glob() order is
# arbitrary and could pick a stale checkpoint from an earlier run.
model_checkpoint_path = model_checkpoint.best_model_path or max(
    EXPERIMENT_DIR.glob('*.ckpt'),
    key=lambda path: path.stat().st_mtime,
    default=None,
)
if not model_checkpoint_path:
    raise FileNotFoundError(f'No checkpoint (*.ckpt) found in {EXPERIMENT_DIR}')
model_ = CorrosionModel.load_from_checkpoint(model_checkpoint_path)

In [None]:
# validate() returns a list of metric dicts (one per dataloader); there is
# only one validation dataloader, so keep its entry.
_valid_results = trainer.validate(model_, dataloaders=valid_dataloader, verbose=False)
valid_metrics = _valid_results[0]
pprint(valid_metrics)

In [None]:
# test() returns a list of metric dicts (one per dataloader); keep the single entry.
_test_results = trainer.test(model_, dataloaders=test_dataloader, verbose=False)
test_metrics = _test_results[0]
pprint(test_metrics)

In [None]:
# Plot predictions of the restored model on the test set and save the figure.
predictions_png_path = EXPERIMENT_DIR / PREDICTIONS_PNG_NAME
plot_predictions_compact(model_, test_dataloader, save_path=predictions_png_path)

In [None]:
# plot_predictions(model_, test_dataloader)

In [None]:
# Plot learning curves from the CSV logger's metrics file and save the figure.
metrics_csv_path = EXPERIMENT_DIR / METRICS_CSV_NAME
learning_curves_png_path = EXPERIMENT_DIR / LEARNING_CURVES_PNG_NAME
plot_learning_curves(metrics_csv_path, save_path=learning_curves_png_path)

## Logging

In [None]:
# Serialize the training augmentation pipeline so it can be logged with the run.
train_transforms_json_path = EXPERIMENT_DIR / TRAIN_TRANSFORMS_JSON_NAME
albu.save(train_transforms, train_transforms_json_path)

In [None]:
def get_early_stopping_epoch() -> int | None:
    checkpoint = list(EXPERIMENT_DIR.glob('*.ckpt'))[0].stem
    pattern = r'epoch=(\d+)'
    match = re.search(pattern, checkpoint)
    if match:
        return int(match.group(1))
    else:
        return None

In [None]:
def log_dict_to_mlflow(dictionary: dict[str, float]) -> None:
    """Log every ``name -> value`` pair of *dictionary* as an MLflow metric.

    Uses one batched ``mlflow.log_metrics`` call rather than a separate
    ``mlflow.log_metric`` request per entry. That saves round trips when the
    tracking server is remote.

    Args:
        dictionary: Metric names mapped to numeric values.
    """
    mlflow.log_metrics(dictionary)

In [None]:
with mlflow.start_run(run_name=f'{EXPERIMENT_NAME}-v{VERSION}') as run:
    # The dataset tag does not depend on git, so record it unconditionally.
    mlflow.set_tag('Dataset', DATASET_URL)

    # Git metadata is best-effort (e.g. presumably fails outside a git
    # checkout); each tag is attempted independently so one failure does not
    # drop the other.
    for tag_name, read_tag_value in (
        ('Branch', get_current_branch),
        ('Commit ID', get_commit_id),
    ):
        try:
            mlflow.set_tag(tag_name, read_tag_value())
        except Exception as e:
            print(f'Could not set MLflow tag {tag_name!r}: {e}')

    log_dict_to_mlflow(dict(valid_metrics))
    log_dict_to_mlflow(dict(test_metrics))

    # One batched request instead of one request per parameter.
    mlflow.log_params(
        {
            'encoder': ENCODER,
            'decoder': DECODER,
            'batch_size': BATCH_SIZE,
            'max_epochs': trainer.max_epochs,
            # Epoch parsed from the saved (best) checkpoint's file name.
            'early_stopping': get_early_stopping_epoch(),
            'monitor': MONITOR,
            'patience': PATIENCE,
            'image_size': IMAGE_SIZE,
            'frozen_encoder': FREEZE_ENCODER,
        }
    )

    # Models are versioned by default
    mlflow.pytorch.log_model(
        pytorch_model=model_,
        artifact_path='model',
        registered_model_name=f'pytorch-{DECODER}-{ENCODER}',
    )

In [None]:
# Get a boto3.client object for the DagsHub repository bucket.
s3 = get_repo_bucket_client(f'{USER_NAME}/{REPO_NAME}')

# Upload the run artifacts, mirroring their local layout in the bucket.
# S3 object keys are '/'-separated, so build them with as_posix(); str() would
# produce backslash-separated keys on Windows.
for artifact_name in (
    METRICS_CSV_NAME,
    PREDICTIONS_PNG_NAME,
    LEARNING_CURVES_PNG_NAME,
    TRAIN_TRANSFORMS_JSON_NAME,
):
    artifact_path = EXPERIMENT_DIR / artifact_name
    s3.upload_file(
        Bucket=REPO_NAME,
        Filename=str(artifact_path),
        Key=artifact_path.as_posix(),
    )