In [2]:
import sys
import os
sys.path.append('..')
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import v2

from utils.data_loading import SSL_Dataset
from training.train_model import trainMAE
from models.fcmae import FCMAE

import wandb

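# Train a Self-Supervised Fully Convolutional Masked Autoencoder (FCMAE)

This notebook pretrains an FCMAE on unlabeled multi-band image crops. A fraction (`mask_ratio`) of each input is masked, and the model learns to reconstruct it under an MSE loss.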

## Load training data

In [3]:
# Precomputed statistics handed to SSL_Dataset for input standardization
# (loaded in the same order as before: means first, then stds).
means_np, stds_np = (
    np.load('../data/sen2_65k_181b_means.npy'),
    np.load('../data/sen2_65k_181b_stds.npy'),
)

In [None]:
# HDF5 file with the training crops used for self-supervised pretraining
ssl_train_set_path: str = "../data/crops_train_all_SSL.hdf5"

In [None]:
# Training-time augmentations: independent random horizontal and vertical flips,
# each with torchvision's default probability p=0.5 (spelled out here for clarity).
train_transforms = v2.Compose(
    transforms=[
        v2.RandomHorizontalFlip(p=0.5),
        v2.RandomVerticalFlip(p=0.5),
    ]
)

In [None]:
# Training dataset: SSL_Dataset receives the precomputed means/stds (standardize=True)
# and the random-flip augmentations defined above.
ssl_train_set = SSL_Dataset(
    ssl_train_set_path,
    standardize=True,
    means_np=means_np,
    stds_np=stds_np,
    transform=train_transforms,
)

## Set hyperparameters and load model (FCMAE)

In [None]:
# --- Optimization ---
num_epochs = 1000
batch_size = 128
lr = 0.0015  # base Adam learning rate (the scheduler below scales this value)

# --- FCMAE encoder configuration (passed to FCMAE as depths/dims) ---
depths = [2, 2, 6, 2]
dims = [40, 80, 160, 320]

# --- Input and masking ---
img_size = 56  # NxN pixels
patch_size = 8  # NxN pixels; img_size should be divisible by it
in_chans = 181  # bands
mask_ratio = 0.6

# --- Reproducibility ---
# Seed torch (weight init, DataLoader shuffling, v2 random flips and, presumably,
# the model's random masking) and numpy so runs can be repeated.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Build the FCMAE. patch_size is forwarded explicitly (and img_size is passed
# only once) so the masking patch size matches the configured patch_size.
# NOTE(review): assumes FCMAE accepts a `patch_size` kwarg, as the ConvNeXt V2
# FCMAE does; confirm against models/fcmae.py.
model = FCMAE(
    img_size=img_size,
    patch_size=patch_size,
    in_chans=in_chans,
    mask_ratio=mask_ratio,
    depths=depths,
    dims=dims,
)

## Optimizer & Scheduler

In [None]:
# Adam over all model parameters, starting from the base learning rate `lr`
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [None]:
# Step LR schedule: run the first 700 epochs at the base lr (0.0015), then
# drop to 0.0015 * 0.1 = 0.00015 for the rest of training.
# LambdaLR *multiplies* each param group's initial lr by lr_lambda(epoch), so the
# function must return a scale factor, not an absolute learning rate.
# NOTE(review): `epoch` is LambdaLR's step counter; this assumes trainMAE calls
# scheduler.step() once per epoch. Confirm in training/train_model.py.
lr_decay_epoch = 700
lr_decay_factor = 0.1


def lr_lambda(epoch):
    """Return the multiplicative lr factor for the given scheduler step.

    Args:
        epoch: Number of scheduler.step() calls made so far.

    Returns:
        1.0 before ``lr_decay_epoch``, otherwise ``lr_decay_factor``.
    """
    return 1.0 if epoch < lr_decay_epoch else lr_decay_factor


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

## Create DataLoader

In [None]:
# Shuffled mini-batches of the SSL training set, loaded by 8 worker processes
ssl_train_loader = DataLoader(
    dataset=ssl_train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)

## Model training

In [None]:
# Weights & Biases switch: set log_to_wandb = True to log in before training.
# wandb_proj is the project name handed to trainMAE.
wandb_proj = 'ifn-ssl-mae'
log_to_wandb = False
if log_to_wandb:
    wandb.login()

In [None]:
# Run settings passed to trainMAE: save_model is forwarded as `save=`, and
# run_config presumably becomes the W&B run config when logging is enabled.
# All hyperparameters that affect the model/masking are recorded so a logged
# run can be reproduced from its config alone.
save_model = False
run_config = {
    "epochs": num_epochs,
    "batch_size": batch_size,
    "learning_rate": lr,
    "optimizer": "Adam",
    "scheduler": "LambdaLR",
    "criterion": "MSE",  # Mean Squared Error (computed internally by the model)
    "augmentations": "H&V_Flip",
    "architecture": "FCMAE",
    "depths": depths,
    "dims": dims,
    "img_size": img_size,
    "patch_size": patch_size,
    "in_chans": in_chans,
    "mask_ratio": mask_ratio,
}

In [None]:
# Launch pretraining. mask_ratio is passed to trainMAE alongside the model;
# logging, run config and checkpoint saving are controlled by the cells above.
trainMAE(
    model, ssl_train_loader, optimizer, mask_ratio, scheduler,
    log_to_wandb=log_to_wandb,
    wandb_proj=wandb_proj,
    run_config=run_config,
    save=save_model,
)