# Data Augmentation with Hyperparameter Tuning

In [None]:
# basics
import os
import utilities.utils as utils
import numpy as np
from tqdm.notebook import tqdm


# torch
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import lightning as L

from data_preperation.dataset import CityDataset
import os

from config import PATH, CITIES, MIN_LABELS, PATCH_SIZE, LOGLEVEL, TEST_CITY

# custom modules
from data_acquisition.datahandler import DataHandler


# Configure logging for the pipeline: one shared logger created through the
# project's utils helper, with its level taken from LOGLEVEL in config.
logger = utils.setup_logger(level=LOGLEVEL)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# from torch.utils.tensorboard import SummaryWriter
import lightning as L
from typing import Any
from torch.utils.data import DataLoader, Dataset
from lightning import seed_everything
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from models.lightning_utils import LitModule
from models.baseconvnet import ConvNetSimple

from lightning.pytorch.tuner.tuning import Tuner

# model
# convmodel = LitModule(ConvNetSimple())


# trainer
def get_trainer(dirname):
    """Build the Lightning trainer used for one augmentation experiment.

    Both the trainer's root directory and the checkpoint directory are set to
    ``model_experiments/augmentation/<dirname>``.

    Args:
        dirname: Run identifier appended to the experiment root path.

    Returns:
        An ``L.Trainer`` capped at 100 epochs, with early stopping
        (patience 5) and single-best checkpointing, both monitoring
        ``val_loss`` in ``min`` mode.
    """
    run_dir = f"model_experiments/augmentation/{dirname}"

    # Stop once the validation loss has not improved for 5 checks.
    stop_on_plateau = EarlyStopping(monitor="val_loss", mode="min", patience=5)

    # Keep only the checkpoint with the lowest validation loss.
    keep_best = ModelCheckpoint(
        dirpath=run_dir,
        filename="best_model",
        monitor="val_loss",
        mode="min",
        save_top_k=1,
    )

    return L.Trainer(
        default_root_dir=run_dir,
        callbacks=[stop_on_plateau, keep_best],
        fast_dev_run=False,
        num_sanity_val_steps=2,
        max_epochs=100,
        log_every_n_steps=20,
    )

#### Here we import the data augmentation transformations

In [None]:
from data_augmentation.transformations import (
    ColorJitterCustom,
    random_rotation,
    random_horizontal_flip,
    random_vertical_flip,
)


from torchvision.transforms import (
    Compose,
    Lambda,
    RandomHorizontalFlip,
    RandomVerticalFlip,
    RandomChoice,
)

# transforms_train = Compose([
#     random_rotation,
#     random_horizontal_flip,
#     random_vertical_flip,
#     ColorJitterCustom(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
# ])


# Augmentation pipelines under comparison:
#   1 -> rotation + both flips + colour jitter
#   2 -> rotation + both flips only
#   3 -> colour jitter only
_geometric_ops = [random_rotation, random_horizontal_flip, random_vertical_flip]
_jitter_settings = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.1}

custom_transforms1 = Compose([*_geometric_ops, ColorJitterCustom(**_jitter_settings)])

custom_transforms2 = Compose(list(_geometric_ops))

# A fresh ColorJitterCustom instance is created for each pipeline.
custom_transforms3 = Compose([ColorJitterCustom(**_jitter_settings)])


def no_op(x):
    """Identity transform: return the input unchanged.

    Serves as the "do nothing" candidate when a transform is picked at random.
    """
    return x


# Pipeline 4: two independent random picks among rotation / flips / identity,
# followed by colour jitter.
_random_geometric_candidates = (
    random_rotation,
    random_horizontal_flip,
    random_vertical_flip,
    no_op,
)

custom_transforms4 = Compose(
    [
        # Each RandomChoice gets its own candidate list.
        RandomChoice(list(_random_geometric_candidates)),
        RandomChoice(list(_random_geometric_candidates)),
        ColorJitterCustom(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    ]
)

# Every pipeline that the experiment loop iterates over, in run order.
transformations_list = [
    custom_transforms1,
    custom_transforms2,
    custom_transforms3,
    custom_transforms4,
]
# transformations_list = [custom_transforms3]
# transformations_list = [custom_transforms4]

In [None]:
torch.set_float32_matmul_precision("high")  # for tensor cores

# --- legacy in-memory loading via DataHandler (kept for reference) ---------
# from data_acquisition import DataHandler
#
# datahandler = DataHandler(logger, path_to_data_directory="data")
#
# load images and mask for all specified cities
#
# import os
# images = []
# dense_masks=[]
# boundary_masks=[]
#
# for city in tqdm(cities):
#     buildings = None
#     if not os.path.exists(os.path.join(datahandler.path_to_data_directory,city,'building_mask_dense.tif')):
#         print("loading local buildings")
#         buildings = datahandler.get_buildings(city)
#     images.append(datahandler.get_satellite_image(city).transpose(2, 0, 1))
#     dense_masks.append(datahandler.get_building_mask(city, all_touched=True, loaded_buildings=buildings))
#     boundary_masks.append(datahandler.get_boundaries_mask(city))
#
# print(f"Data len: {len(images)}, {len(dense_masks)}, {len(boundary_masks)}")
# ---------------------------------------------------------------------------

# File names and band selection shared by the training and the test dataset.
_IMAGE_BANDS = (1, 2, 3, 4, 5, 6)
_file_kwargs = {
    "data_name": "openEO.tif",
    "labels_name": "building_mask_dense.tif",
}
# Batch size and worker count shared by all three dataloaders.
_loader_kwargs = {"batch_size": 32, "num_workers": 20}

# Training dataset over CITIES (pass devrun=True for a quick dry run).
dataset = CityDataset(
    PATH,
    patch_size=PATCH_SIZE,
    image_bands=list(_IMAGE_BANDS),
    min_labels=MIN_LABELS,
    cities=CITIES,
    train=True,
    **_file_kwargs,
)

# Hold out 10% of the training data for validation; only training is shuffled.
train_ds, val_ds = dataset.train_val_split(val_size=0.1, show_summary=False)
train_dl = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_dl = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

# Test dataset over TEST_CITY; built without patch_size / min_labels,
# exactly as in the original call.
dataset_test = CityDataset(
    PATH,
    image_bands=list(_IMAGE_BANDS),
    cities=TEST_CITY,
    train=False,
    **_file_kwargs,
)
test_dl = DataLoader(dataset_test, shuffle=False, **_loader_kwargs)

In [None]:
# Train and evaluate one ConvNetSimple model per augmentation pipeline.
#
# Fixes relative to the previous version of this cell:
#   * the loop index is no longer overwritten with a constant, so each
#     pipeline gets its own output directory instead of all runs sharing one;
#   * the pipeline is actually assigned to the dataset (it was forced to None,
#     so every run trained without augmentation);
#   * seeding happens before the split, the loaders and the model are built,
#     so every run starts from the same random state;
#   * the best checkpoint (lowest val_loss) is the model that gets tested,
#     rather than the weights from the last training epoch.
for idx, transforms in enumerate(transformations_list):
    seed_everything(49)

    # NOTE(review): assumes CityDataset reads `.transform` when serving
    # patches and that the subsets returned by train_val_split honour it.
    # Confirm, and confirm whether validation patches should stay
    # un-augmented.
    dataset.transform = transforms
    train_ds, val_ds = dataset.train_val_split(val_size=0.1, show_summary=False)
    train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=20)
    val_dl = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=20)

    # 6 matches the six image bands selected when the dataset was created.
    model = LitModule(ConvNetSimple(6), learning_rate=0.001, optimizer="adam")
    trainer = get_trainer(f"convSimple/_{idx}_transformation")
    trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)

    # Reload the weights with the lowest validation loss and report test
    # metrics for those. `model` still holds the last-epoch weights.
    best_model = LitModule.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path
    )
    trainer.test(model=best_model, dataloaders=test_dl)

In [None]:
# Run inference on the test dataloader with the `trainer` and `model` left
# over from the final iteration of the experiment loop.
prediction = trainer.predict(model=model, dataloaders=test_dl)

In [None]:
from utilities.plot_utils import (
    plot_output,
    plot_prediction_with_thresholds,
    plot_random_patch,
)

# Only the first element of `prediction` is visualised: detach it from the
# autograd graph, convert it to NumPy and drop singleton dimensions.
output = prediction[0].detach().numpy().squeeze()

plot_output(output)
plot_random_patch(output, patch_len=6)
plot_prediction_with_thresholds(output)