In [None]:
# Import libraries
import datetime as dt
import os
from argparse import ArgumentParser

import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers import TensorBoardLogger

from sslcd import SEN12MSDataModule, CloudDetectionDataModule
from sslcd import DeepCluster, MoCo, ResNet
from sslcd import seed_all
from sslcd import DelayedUnfreeze
# Trade some float32 matmul precision for speed on hardware that supports faster
# internal kernels; see the torch.set_float32_matmul_precision docs for what
# 'medium' permits compared with 'high' / 'highest'.
torch.set_float32_matmul_precision('medium')

In [2]:
# Root folder containing the dataset directories used below (SEN12MS, WHUS2CD,
# CloudSen12). Set the SSLCD_DATASET_PATH environment variable to run the notebook
# on another machine; otherwise the original cluster location is used.
dataset_path = os.environ.get(
    "SSLCD_DATASET_PATH",
    "/projects/sampeo/RepreSentCCN/represent_uc3_cloud_detection/Datasets",
)

## MoCo: Pretraining on SEN12MS

In [None]:
# Run settings for MoCo pretraining on SEN12MS. These entries are written onto the
# argparse namespace below, overriding the defaults registered by the Trainer, MoCo
# and SEN12MSDataModule.
config = {
    # DataModule settings
    "data_dir": f"{dataset_path}/SEN12MS",
    "seed": 42,
    "batch_size": 64,
    "num_workers": 8,
    "patch_size": 256,
    # Trainer settings
    "gpus": 1,
    "accelerator": "gpu",
    # Model parameters
    "input_ch": 13,
    "band_set": "s2-all",
    "max_epochs": 100,
}

# Register every option the Trainer, model and datamodule understand, then parse an
# empty argv so `args` starts from their defaults (leftover strings go to `arg_strings`).
parser = ArgumentParser()
for register_args in (
    pl.Trainer.add_argparse_args,
    MoCo.add_model_specific_args,
    SEN12MSDataModule.add_model_specific_args,
):
    parser = register_args(parser)

args, arg_strings = parser.parse_known_args(args=[], namespace=None)
vars(args).update(config)  # equivalent to setattr(args, key, value) per config entry

# Seed before the datamodule and model are constructed.
seed_all(config["seed"])

datamodule = SEN12MSDataModule.from_argparse_args(args)
model = MoCo(**vars(args))

# Checkpoints are tracked against val_loss; save_last=True also keeps the latest one.
checkpointer = pl.callbacks.ModelCheckpoint(
    dirpath="./experiments/sen12ms/moco/pretraining",
    filename="{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    save_last=True,
)
callbacks = [checkpointer]

In [None]:
# Timestamp the run so each launch writes to its own TensorBoard version directory.
current_datetime = dt.datetime.now().strftime("%Y%m%d-%H%M%S")
run_version = "sen12ms_moco_pretraining_" + current_datetime

logger = TensorBoardLogger(
    save_dir="experiments/logs",
    name="tensorboard/",
    version=run_version,
)

# Extra Trainer options layered on top of the argparse namespace built for MoCo.
trainer_options = {
    "enable_checkpointing": True,
    "logger": logger,
    "callbacks": callbacks,
}
trainer = pl.Trainer.from_argparse_args(args, **trainer_options)

trainer.fit(model, datamodule=datamodule)

## DeepCluster: Pretraining on WHUS2-CD+/CloudSEN12

In [4]:
# Maps each selectable dataset key to the identifier handed to
# CloudDetectionDataModule through config["dataset"]. The key itself is also used
# as the on-disk folder name under `dataset_path` and (lower-cased) in experiment
# and log paths.
dataset_name = {"WHUS2CD": "WHUS2-CD+",
                "CloudSen12": "CloudSEN12"}

# Dataset to pretrain on: "WHUS2CD" or "CloudSen12".
dataset = "WHUS2CD"

# Fail fast with a readable message instead of a bare KeyError in the config cell.
if dataset not in dataset_name:
    raise ValueError(
        f"Unknown dataset {dataset!r}; expected one of {sorted(dataset_name)}"
    )

In [None]:
# Run settings for DeepCluster pretraining on the dataset selected above. These
# entries are written onto the argparse namespace below, overriding the defaults
# registered by the Trainer, DeepCluster and CloudDetectionDataModule.
config = {
    # DataModule settings
    "data_dir": f"{dataset_path}/{dataset}",
    "seed": 42,
    "batch_size": 128,
    "num_workers": 8,
    "patch_size": 256,
    "dataset": dataset_name[dataset],
    "pretraining": True,
    # Trainer settings
    "gpus": 1,
    "accelerator": "gpu",
    # Model parameters
    "use_mlp": False,
    "input_ch": 13,
    "num_classes": 21,
    "backbone": "resnet18",
    "proj_hidden_dim": 2048,
    "proj_output_dim": 128,
    "temperature": 0.1,
    "kmeans_iters": 10,
    # Optimizer parameters
    # NOTE(review): learning_rate=0.6 with momentum=0.9 looks like an SGD/LARS-style
    # setting; paired with optimizer="Adam" it is unusually large. Confirm how
    # DeepCluster consumes these values before a long run.
    "optimizer": "Adam",
    "scheduler": "CosineAnnealingLR",
    "momentum": 0.9,
    "max_epochs": 100,
    "learning_rate": 0.6,
    "classifier_lr": 0.1,
}

# Register every option the Trainer, model and datamodule understand, then parse an
# empty argv so `args` starts from their defaults (leftover strings go to `arg_strings`).
parser = ArgumentParser()
for register_args in (
    pl.Trainer.add_argparse_args,
    DeepCluster.add_model_specific_args,
    CloudDetectionDataModule.add_model_specific_args,
):
    parser = register_args(parser)

args, arg_strings = parser.parse_known_args(args=[], namespace=None)
vars(args).update(config)  # equivalent to setattr(args, key, value) per config entry

# Seed before the datamodule and model are constructed.
seed_all(config["seed"])

datamodule = CloudDetectionDataModule.from_argparse_args(args)
model = DeepCluster(**vars(args))

# NOTE(review): unlike the MoCo checkpointer, no `monitor` is passed here, so
# val_acc1 only appears in the filename and is not used to rank checkpoints.
# Confirm whether monitor="val_acc1", mode="max" was intended.
checkpointer = pl.callbacks.ModelCheckpoint(
    dirpath=f"./experiments/{dataset.lower()}/deepcluster/pretraining",
    filename="{epoch}-{val_acc1:.2f}",
    save_last=True,
)
callbacks = [checkpointer]

In [None]:
# Timestamped TensorBoard version per launch, prefixed with the selected dataset.
current_datetime = dt.datetime.now().strftime("%Y%m%d-%H%M%S")
run_version = "_".join((dataset.lower(), "deepcluster_pretraining", current_datetime))

logger = TensorBoardLogger(
    save_dir="./experiments/logs",
    name="tensorboard/",
    version=run_version,
)

# Extra Trainer options layered on top of the argparse namespace built for DeepCluster.
trainer_options = {
    "enable_checkpointing": True,
    "logger": logger,
    "callbacks": callbacks,
    "auto_lr_find": False,
}
trainer = pl.Trainer.from_argparse_args(args, **trainer_options)

trainer.fit(model, datamodule=datamodule)