In [None]:
import os
import pandas as pd

from elevation_aware_ssl.classification.dataset import CustomDataset
from elevation_aware_ssl.classification.utils import visualize_augmented_images, generate_metadata_train_test_stratified_cv
from elevation_aware_ssl.trainer import Trainer
from sklearn.model_selection import train_test_split

# Experiment tracking with Weights & Biases (wandb)
import wandb

# Silence wandb console output; set before login so it applies from the start.
os.environ["WANDB_SILENT"] = "true"
# Never hard-code the API key in the notebook: read it from the environment.
# If WANDB_API_KEY is unset, key=None lets wandb use its own credential lookup
# (stored netrc credentials or an interactive prompt, per the wandb docs).
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
# Root of the synced Google Drive, without a trailing slash (later paths append "/...").
# Override with the DRIVE_PATH environment variable, e.g. "/content/drive/MyDrive" on Colab.
drive_path = os.environ.get("DRIVE_PATH", "/media/omar/storage/gdrive")

In [None]:
# pandas query string that keeps rows outside legal exclusions and flagged as
# exactly one of non_agricultural / agricultural_frontier. It is kept as a
# one-element list because the loading cell reads query[0].
# Commented-out third clause (to also keep legal_exclusions rows):
# "| ((legal_exclusions == 1) and (non_agricultural == 0) and (agricultural_frontier == 0))"
query = [
    " | ".join(
        (
            "((legal_exclusions == 0) and (non_agricultural == 1) and (agricultural_frontier == 0))",
            "((legal_exclusions == 0) and (non_agricultural == 0) and (agricultural_frontier == 1))",
        )
    )
]

In [None]:
# Class flag columns used for the binary task. Their order fixes the integer
# labels (index 0 -> "non_agricultural", index 1 -> "agricultural_frontier").
# 3-class variant: ["non_agricultural", "legal_exclusions", "agricultural_frontier"]
select_classes = [
    "non_agricultural",
    "agricultural_frontier",
]

In [None]:
# Load the tile metadata, subsample 70% with a fixed seed, keep only rows that
# satisfy `query`, then derive each row's class name and integer label.
path_to_metadata = f"{drive_path}/Maestria/Datasets/GeoDataset/metadata_v2/metadata.csv"
class_to_label = {name: label for label, name in enumerate(select_classes)}

metadata = (
    pd.read_csv(path_to_metadata)
    .sample(frac=0.7, random_state=42)
    .query(query[0])
    # A single chain: no intermediate frame is mutated in place.
    .assign(
        Classes=lambda df: df[select_classes].idxmax(axis=1),
        Labels=lambda df: df["Classes"].map(class_to_label),
    )
)

# idxmax silently returns the first column for all-zero or tied rows, which
# would mislabel them; make sure each kept row has exactly one class flag set.
assert metadata[select_classes].sum(axis=1).eq(1).all(), "rows must be one-hot over select_classes"

print(metadata.shape)
metadata.head()

In [None]:
# Root folder of the image data handed to CustomDataset and the Trainer
# (the trailing slash is part of the value).
path_to_images = "/".join([drive_path, "Maestria", "Datasets", "GeoDataset", "Dataset", ""])

In [None]:
# Strong augmentation used only to preview the augmentation pipeline on a few
# samples. It has its own name so it cannot be mistaken for the training
# `augment` dict defined in a later cell, which run_experiment reads.
augment_preview = {
    "horizontal_flip_prob": 0.5,
    "vertical_flip_prob": 0.5,
    "resize_scale": (0.8, 1.0),
    "resize_prob": 1.0,
    "brightness": 0.4,
    "contrast": 0.4,
    "saturation": 0.4,
    "hue": 0.2,
    "color_jitter_prob": 0.5,
    "gray_scale_prob": 0.2,
}

# Fixed random_state so the same 10 samples are previewed on every run.
preview_ds = CustomDataset(
    path_to_images,
    metadata.sample(10, random_state=42),
    return_original=True,
    augment=augment_preview,
)
visualize_augmented_images(preview_ds, class_names=select_classes, brightness=0.0)

In [None]:
# Prefix for saving classifier checkpoints (given without a .pt/.pth extension).
# NOTE: this name is reassigned by a later path cell.
path_to_save_model = "/".join([drive_path, "Maestria", "Theses", "pruebas", "models", "resnet18"])

In [None]:
# Sanity-check the stratified CV split helper on a small training size.
# NOTE(review): n_split=4 here vs n_split=5 inside run_experiment, and these
# module-level splits are not used by run_experiment (it re-splits) — confirm intent.
cv_split = generate_metadata_train_test_stratified_cv(metadata, n_split=4, train_size=10)
metadata_train, metadata_test, metadata_valid = cv_split

In [None]:
# Checkpoint save prefix and pretrained-backbone checkpoints.
# drive_path is defined without a trailing slash, so each path needs an explicit
# "/" after it; otherwise they resolve to ".../gdriveMaestria/...".
# NOTE(review): this save path omits the "Theses" folder used by the earlier
# path_to_save_model cell — confirm which location is intended.
path_to_save_model = f"{drive_path}/Maestria/pruebas/models/resnet18"
path_to_load_backbone_simclr = f"{drive_path}/Maestria/Theses/SSL/SimCLR/models/resnet18/model_SSL-SimCLR-v2.pth"
path_to_load_backbone_elevation_simclr = f"{drive_path}/Maestria/Theses/SSL/ElevationSSL/models/resnet18/Elevation-SimCLR/model_SimCLR-Elevation.pth"
path_to_load_backbone_elevation = f"{drive_path}/Maestria/Theses/SSL/ElevationSSL/models/resnet18/Elevation/model_Elevation.pth"

In [None]:
# Training-time augmentation; run_experiment passes it to the Trainer as
# "augment_train". Colour jitter is weaker than in the preview cell.
augment = dict(
    horizontal_flip_prob=0.5,
    vertical_flip_prob=0.5,
    resize_scale=(0.8, 1.0),
    resize_prob=1.0,
    brightness=0.1,
    contrast=0.1,
    saturation=0.1,
    hue=0.1,
    color_jitter_prob=0.2,
    gray_scale_prob=0.2,
)

In [None]:
def run_experiment(
    train_size,
    epochs,
    pretrained=False,
    version="RandomInit",
    path_to_load_backbone=None,
    fine_tune=False,
    ft_epoch=30,
    project="CLF-2_classes",
    train_batch_size=8,
    n_split=5,
):
    """Train the ResNet-18 classifier once per stratified CV fold.

    The notebook's metadata is split with ``generate_metadata_train_test_stratified_cv``.
    A new ``Trainer`` is then configured and fitted for every (train, test, valid) fold.

    Reads these notebook globals: ``metadata``, ``select_classes``,
    ``path_to_images`` and ``augment`` (training augmentation).

    Args:
        train_size: Forwarded as ``train_size`` to the CV split helper.
        epochs: Forwarded to the Trainer as ``epochs``.
        pretrained: Forwarded to the Trainer as ``pretrained``.
        version: Used both as the hyper-parameter ``version`` and the wandb run name.
        path_to_load_backbone: Backbone checkpoint path forwarded to the Trainer, or None.
        fine_tune: Forwarded to the Trainer as ``fine_tune``.
        ft_epoch: Forwarded to the Trainer as ``ft_epoch``.
        project: wandb project the runs are logged to.
        train_batch_size: Training batch size.
        n_split: Number of CV folds (defaults to the previously hard-coded 5).
    """
    metadata_train, metadata_test, metadata_valid = generate_metadata_train_test_stratified_cv(
        metadata, train_size=train_size, n_split=n_split
    )

    for train, test, valid in zip(metadata_train, metadata_test, metadata_valid):
        # Fresh dicts per fold, so nothing one Trainer does to them can leak
        # into the next fold's configuration.
        hypm_kwargs = {
            "version": version,
            "model_name": "Classifier",
            # Size of the training split actually used in this fold.
            "amount_of_ft_data": train.shape[0],
            "backbone": "resnet18",
            "pretrained": pretrained,
            "fine_tune": fine_tune,
            "ft_epoch": ft_epoch,
            "ft_lr": 0.000005,
            "in_channels": 3,
            # Derived from the class list instead of a hard-coded 2.
            "num_classes": len(select_classes),
            "class_names": select_classes,
            "normalizing_factor": 6000,
            "weight_decay": 0.00005,
            "learning_rate": 1e-3,
            "train_batch_size": train_batch_size,
            "test_batch_size": 128,
            "epochs": epochs,
            "augment_train": augment,
            "augment_test": None,
            "patient": 10,  # key name is what Trainer expects; keep as-is
            "eval_epoch": 1,
        }

        wandb_kwargs = {
            # Honour the `project` argument (it was previously ignored).
            "project": project,
            "entity": "omar-c",
            "id": None,
            "name": version,
            "resume": False,
        }

        metadata_kwargs = {
            "path_to_images": path_to_images,
            "path_to_save_model": None,  # path_to_save_model,  # Path to save the model that is being trained (do not include the extension .pt or .pth)
            "path_to_load_model": None,  # Path to load a model from a checkpoint (useful to handle notebook disconection)
            "path_to_load_backbone": path_to_load_backbone,
            "metadata_train": train,
            "metadata_test": test,
            "metadata_valid": valid,
            "num_workers": 6,
            "device": "cuda",
        }

        trainer = Trainer(CustomDataset, visualize_augmented_images, wandb_kwargs, hypm_kwargs, metadata_kwargs)
        trainer.configure_trainer()
        trainer.fit()

In [None]:
# Training-set sizes to sweep. A bare, undefined `i` was used here before,
# which raises NameError on a fresh kernel.
# TODO: confirm the intended sizes; 10 matches the CV split preview cell.
train_sizes = [10]

for train_size in train_sizes:
    run_experiment(
        train_size=train_size,
        epochs=100,
        pretrained=True,
        version="Elevation+SimCLR",
        path_to_load_backbone=path_to_load_backbone_elevation_simclr,
        fine_tune=True,
        ft_epoch=10,
        train_batch_size=5,
    )