In [None]:
# Import libraries
import datetime as dt
import os
from argparse import ArgumentParser

import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers import TensorBoardLogger

from sslcd import CloudDetectionDataModule
from sslcd import DelayedUnfreeze
from sslcd import ResNet
from sslcd import seed_all
# Trade float32 matmul precision for speed: per the torch docs, 'medium' lets
# float32 matrix multiplications use reduced-precision internal computation
# on hardware that supports it.
torch.set_float32_matmul_precision('medium')

In [None]:
# Root directory containing one sub-folder per dataset; fill in before running.
dataset_path = ""

# Notebook dataset key -> dataset identifier handed to the datamodule config.
dataset_name = dict(
    WHUS2CD="WHUS2-CD+",
    CloudSen12="CloudSEN12",
)

In [None]:
# Which dataset to fine-tune on; must be a key of `dataset_name` ("WHUS2CD" or "CloudSen12").
dataset: str = "WHUS2CD"

In [None]:
# Run configuration. Every key is later copied onto the argparse namespace that
# builds the datamodule, the model and the Trainer.
config = dict(
    # DataModule Settings
    # os.path.join (instead of `dataset_path + "/" + dataset`) keeps the path
    # relative when dataset_path is left empty rather than pointing at "/<dataset>",
    # and tolerates a trailing slash on dataset_path.
    data_dir = os.path.join(dataset_path, dataset),
    seed = 42,
    batch_size = 16,
    num_workers = 8,
    limit_dataset = 0.25, # Fraction of the dataset to use for finetuning (between 0 and 1)
    patch_size = 256,
    dataset = dataset_name[dataset],
    task = "cloud",
    
    # Trainer Settings
    devices = 1,
    accelerator="gpu",
    
    # Model Parameters
    use_mlp = False,
    num_classes = 1,
    input_ch = 13,  # NOTE(review): presumably the 13 Sentinel-2 bands — confirm
    backbone = "resnet18",
    segmentation = True,
    # Pretraining checkpoint from MoCo or DeepCluster. The path must contain
    # "moco" or "deepcluster": that substring selects how the weights are remapped.
    checkpoint= "",
    
    # Optimizer Parameters
    optimizer = "Adam",
    scheduler = "CosineAnnealingLR",
    momentum = 0.9,  # NOTE(review): presumably only read by momentum-based optimizers — confirm
    max_epochs = 50,
    learning_rate = 0.04  # initial value; the LR-finder cell later replaces it with its suggestion
)

# Gather the Trainer, model and datamodule CLI defaults into one namespace.
parser = ArgumentParser()
for register_args in (
    pl.Trainer.add_argparse_args,
    ResNet.add_model_specific_args,
    CloudDetectionDataModule.add_model_specific_args,
):
    parser = register_args(parser)

# Parse an explicit empty argv instead of sys.argv, so that only the registered
# defaults are used. Then overlay the notebook's `config` values on top of them.
args, arg_strings = parser.parse_known_args(args=[], namespace=None)
vars(args).update(config)

# Seed before the datamodule and model are built.
seed_all(config['seed'])

datamodule = CloudDetectionDataModule.from_argparse_args(args)

# Segmentation head passed to the model through the argparse namespace.
# NOTE(review): 512 input channels presumably match the feature width of the
# "resnet18" backbone selected in `config` — adjust if the backbone changes.
head_layers = [
    torch.nn.Conv2d(512, 64, kernel_size=3, padding=1),
    torch.nn.InstanceNorm2d(64),
    torch.nn.ReLU(),
    torch.nn.Conv2d(64, config['num_classes'], kernel_size=1),
]
args.classification_head = torch.nn.Sequential(*head_layers)

# Build the model from every namespace entry (defaults, config overrides and head).
model = ResNet(**vars(args))

# Infer the SSL framework that produced the checkpoint from its path, together
# with the `filter_and_remap` key passed to the loader. The key is presumably the
# prefix under which that framework stored the backbone weights — confirm in sslcd.
checkpoint_path = config['checkpoint']
checkpoint_lower = checkpoint_path.lower()
if "moco" in checkpoint_lower:
    filter_and_remap = "encoder_q"
    ssl_framework = "moco"
elif "deepcluster" in checkpoint_lower:
    filter_and_remap = "backbone"
    ssl_framework = "deepcluster"
else:
    # Without this branch both names stay unbound (e.g. for the default empty
    # checkpoint), and the failure only surfaces later as a confusing NameError.
    raise ValueError(
        f"Cannot infer the SSL framework from checkpoint path {checkpoint_path!r}; "
        "expected a path containing 'moco' or 'deepcluster'."
    )

# NOTE(review): in stock Lightning, `load_from_checkpoint` is a classmethod that
# returns a *new* module, and its result is discarded here. This only loads weights
# into `model` if sslcd's ResNet overrides it to load in place (the non-standard
# `filter_and_remap` kwarg suggests it does) — confirm.
model.load_from_checkpoint(checkpoint_path, filter_and_remap=filter_and_remap)
print ("Weights loaded from pretraining")

In [None]:
# Tag shared by the checkpoint directory and the TensorBoard version,
# e.g. "ft_25" when fine-tuning on 25% of the dataset (limit_dataset=0.25).
finetune_tag = f"ft_{config['limit_dataset']*100:.0f}"

# Keep the checkpoints with the best validation accuracy, plus the last epoch.
# The f-string uses double quotes outside and single quotes inside, because
# reusing the same quote type inside an f-string is a SyntaxError before
# Python 3.12 (PEP 701).
checkpointer = pl.callbacks.ModelCheckpoint(
    dirpath=f"experiments/{dataset.lower()}/{ssl_framework}/{config['task']}/{finetune_tag}",
    filename="{epoch}-{val_acc:.2f}",  # Lightning filename template, not an f-string
    monitor="val_acc",
    mode="max",
    save_last=True
)

# Stop after 15 consecutive validation checks without improvement in val_loss.
# NOTE(review): this monitors val_loss while the checkpointer ranks by val_acc — confirm intended.
early_stopping_callback = pl.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=15)

# Delayed unfreezing of the submodule with backbone_id "model" at epoch 5.
# NOTE(review): the semantics of train_frozen_bn / reset_lr are defined in sslcd;
# presumably reset_lr is the learning rate applied once unfrozen — confirm.
unfreezer = DelayedUnfreeze(backbone_id="model", unfreeze_at_epoch=5, train_frozen_bn=True, reset_lr=7e-3)

callbacks = [checkpointer, early_stopping_callback, unfreezer]

# The timestamp suffix keeps repeated runs of the same configuration in separate log versions.
current_datetime = dt.datetime.now().strftime("%Y%m%d-%H%M%S")
logger = TensorBoardLogger(
                    save_dir="experiments/logs",
                    name = "tensorboard/",
                    version=f"{dataset.lower()}_{ssl_framework}_{config['task']}_{finetune_tag}_{current_datetime}"
                    )

In [None]:
# Find learning rate with Lightning's LR finder on a throwaway Trainer
# (no checkpointing, no logger, no fine-tuning callbacks).
trainer_lr = pl.Trainer.from_argparse_args(args, enable_checkpointing=False, logger=None, auto_lr_find=True)
lr_finder = trainer_lr.tune(model, datamodule=datamodule, lr_find_kwargs={"min_lr": 1e-7, "max_lr": 1e-1})

# suggestion() returns None when it cannot compute a suggestion. In that case keep
# the configured learning rate instead of overwriting it with None.
# NOTE(review): `model` was built from `args` earlier, so updating args.learning_rate
# here does not reach it. With auto_lr_find=True, tune() presumably already wrote the
# suggestion onto the model — confirm.
suggested_lr = lr_finder['lr_find'].suggestion()
if suggested_lr is None:
    print("LR finder returned no suggestion; keeping learning rate", args.learning_rate)
else:
    args.learning_rate = suggested_lr
    print("Suggested learning rate:", args.learning_rate)

In [None]:
# Training: fine-tune with checkpointing, early stopping, delayed unfreezing and
# TensorBoard logging. The LR finder is disabled because it already ran above.
trainer_kwargs = dict(
    enable_checkpointing=True,
    logger=logger,
    callbacks=callbacks,
    auto_lr_find=False,
)
trainer = pl.Trainer.from_argparse_args(args, **trainer_kwargs)

# NOTE(review): Trainer.fit returns None in Lightning 1.x, so `run` holds None.
run = trainer.fit(model, datamodule=datamodule)