In [1]:
import sys
import os
sys.path.append('..')
import numpy as np
from torchvision.transforms import v2
from torch.utils.data import DataLoader
import torch

from training.train_model import trainWSLModel
from models.convnextv2_unet import ConvNeXtV2_unet
from utils.data_loading import WSL_Dataset
from utils.helper import remap_checkpoint_keys

import wandb



# Load WSL train and test sets

In [8]:
# Per-band mean / standard-deviation arrays that are handed to WSL_Dataset
# for standardization
means_path = '../data/sen2_65k_181b_means.npy'
stds_path = '../data/sen2_65k_181b_stds.npy'
means_np = np.load(means_path)
stds_np = np.load(stds_path)

In [9]:
# Locations of the HDF5 files for the WSL training and test splits
wsl_train_set_path, wsl_test_set_path = (
    "../data/crops_train_seg_all_64x64_181b_augmented.hdf5",
    "../data/crops_test_seg_all_64x64_181b.hdf5",
)

In [10]:
# Augmentation pipeline for the training data: a random horizontal flip
# followed by a random vertical flip
flip_augmentations = [v2.RandomHorizontalFlip(), v2.RandomVerticalFlip()]
train_transforms = v2.Compose(flip_augmentations)

In [None]:
# Training split: receives the flip augmentations plus the loaded means/stds
# with standardize=True
# NOTE(review): downsample_classes presumably lists class ids to subsample - confirm in WSL_Dataset
train_set_options = dict(
    transform=train_transforms,
    standardize=True,
    means_np=means_np,
    stds_np=stds_np,
    downsample_classes=[5, 10, 14, 16],
)
wsl_train_set = WSL_Dataset(wsl_train_set_path, **train_set_options)

In [None]:
# Test split: same standardization arguments as the training split, but WITHOUT
# the random flips. Reusing train_transforms here would randomly augment the
# evaluation samples and make the test metrics differ from run to run, so a
# no-op transform is passed instead.
test_transforms = v2.Identity()

# NOTE(review): downsample_classes is kept as in the original call; confirm
# that subsampling classes is also intended for the evaluation data.
wsl_test_set = WSL_Dataset(wsl_test_set_path,
                        transform=test_transforms,
                        standardize=True,
                        means_np=means_np,
                        stds_np=stds_np,
                        downsample_classes=[5,10,14,16])

# Train Weakly Supervised Baseline Model

## Set hyperparameters and create model (ConvNext-V2 U-Net)

In [3]:
# --- optimisation settings ---
lr = 1e-4
batch_size = 32
num_epochs = 50

# --- input geometry ---
in_chans = 181  # bands
img_size = 64   # NxN pixels
patch_size = 8  # NxN pixels

# --- network layout: one depth / dim entry per stage ---
dims = [40, 80, 160, 320]
depths = [2, 2, 6, 2]

# 20 classes + 1 additional class for unlabeled data
num_classes = 21

In [4]:
# Build the ConvNeXtV2 U-Net from the hyperparameters defined in the previous cell
model = ConvNeXtV2_unet(
    img_size=img_size,
    patch_size=patch_size,
    in_chans=in_chans,
    num_classes=num_classes,
    depths=depths,
    dims=dims,
)

## Loss and optimizer

In [5]:
#define loss criterion
#one weight per class, sized from num_classes so the weight vector always has
#the same number of entries as the model was built with (it was hard-coded to 21)
#the last class (unlabeled pixels) gets zero weight so it does not contribute to the loss
ws = [1.0] * num_classes
ws[-1] = 0.0
#NOTE(review): .cuda() requires a CUDA device; this cell fails on a CPU-only machine
ws = torch.tensor(ws).float().cuda()
criterion = torch.nn.CrossEntropyLoss(weight=ws)

#optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

## Create DataLoaders

In [6]:
# Both loaders share the batch size and worker count; only the training loader shuffles
shared_loader_args = {"batch_size": batch_size, "num_workers": 8}
wsl_train_loader = DataLoader(wsl_train_set, shuffle=True, **shared_loader_args)
wsl_test_loader = DataLoader(wsl_test_set, shuffle=False, **shared_loader_args)

## Model training

In [7]:
# wandb project name used for this run
wandb_proj = 'ifn-wsl'

# Switch for logging model statistics to wandb; the login only happens when it is on
log_to_wandb = False
if log_to_wandb:
    wandb.login()

In [8]:
# Switches forwarded to the training call
save_model = False
mask_pixel = 20   # mask pixels equal to 20 (unlabeled)
test_eval = True  # compute statistics for the test set

# Description of this run, assembled from the hyperparameters set earlier
run_config = dict(
    epochs=num_epochs,
    batch_size=batch_size,
    learning_rate=lr,
    optimizer="Adam",
    criterion="WCE",  # weighted Cross-Entropy
    augmentations="H&V_Flip",
    architecture="ConvNextV2_UNet",
    depths=depths,
    dims=dims,
)

In [None]:
# Train the baseline model; evaluation, masking, logging and saving options are
# collected first and then forwarded as keyword arguments
training_options = {
    "test_eval": test_eval,
    "mask_pixel": mask_pixel,
    "log_to_wandb": log_to_wandb,
    "wandb_proj": wandb_proj,
    "run_config": run_config,
    "save": save_model,
}
trainWSLModel(model, wsl_train_loader, wsl_test_loader, optimizer, criterion,
              **training_options)

# Finetune Self-Supervised pretrained MAE model

## Hyperparameters

In [13]:
# Fine-tuning schedule
num_epochs, batch_size = 50, 32
lr = 1e-4

# Stage layout of the network: number of blocks and channel width per stage
depths = [2, 2, 6, 2]
dims = [40, 80, 160, 320]

# Input description
img_size = 64    # NxN pixels
patch_size = 8   # NxN pixels
in_chans = 181   # bands

# 20 classes + 1 additional class for unlabeled data
num_classes = 21

## Load MAE model weights

*Refer to the README.md in models/saved_models to download our pretrained model.*

*Alternatively, load your own saved model.*

In [None]:
# Checkpoint produced by the MAE pre-training
checkpoint_path = "../models/saved_models/MAEModel_FCMAE_depths[2-2-6-2]_dim[40-80-160-320]_batch128_lr00015_AugH&V_Flip_Adam_MSE.pt"

# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
# map_location="cpu" lets a checkpoint that was saved on a GPU be read on any machine;
# the values are copied into the model's own parameters later by load_state_dict.
pretrained_model_weights = torch.load(checkpoint_path, map_location="cpu")

#harmonize dict keys to facilitate weight transfer
pretrained_model_weights = remap_checkpoint_keys(pretrained_model_weights)

## Create ConvNext-V2 U-Net model and transfer weights

In [15]:
#create new model
model = ConvNeXtV2_unet(img_size=img_size, patch_size=patch_size, in_chans=in_chans, num_classes=num_classes, depths=depths, dims=dims)

#transfer weights from pretrained model
#strict=False tolerates keys that exist on only one side (e.g. layers that are
#not part of the pretrained checkpoint), but it also stays silent when *no* key matches
load_result = model.load_state_dict(pretrained_model_weights, strict=False)

#guard against a silent no-op transfer: at least one checkpoint key must have
#been accepted by the model, otherwise fine-tuning would start from random weights
transferred_keys = set(pretrained_model_weights) - set(load_result.unexpected_keys)
assert transferred_keys, "no pretrained weights matched the model - check the checkpoint and the key remapping"

#display the missing / unexpected keys as the cell output
load_result

_IncompatibleKeys(missing_keys=['norm.weight', 'norm.bias', 'head.weight', 'head.bias', 'upsample_layers.0.conv.weight', 'upsample_layers.0.conv.bias', 'upsample_layers.0.norm.weight', 'upsample_layers.0.norm.bias', 'upsample_layers.1.conv.weight', 'upsample_layers.1.conv.bias', 'upsample_layers.1.norm.weight', 'upsample_layers.1.norm.bias', 'upsample_layers.2.conv.weight', 'upsample_layers.2.conv.bias', 'upsample_layers.2.norm.weight', 'upsample_layers.2.norm.bias', 'upsample_layers.3.conv.weight', 'upsample_layers.3.conv.bias', 'upsample_layers.3.norm.weight', 'upsample_layers.3.norm.bias', 'initial_conv_upsample.0.weight', 'initial_conv_upsample.0.bias', 'initial_conv_upsample.1.weight', 'initial_conv_upsample.1.bias'], unexpected_keys=['mask_token'])

## Loss and optimizer

In [16]:
#define loss criterion
#the weight vector is sized from num_classes (previously hard-coded to 21) so it
#cannot get out of sync with the number of classes the model was created with
#the last class (unlabeled pixels) is weighted zero and therefore ignored by the loss
ws = [1.0] * num_classes
ws[-1] = 0.0
#NOTE(review): .cuda() requires a CUDA device; this cell fails on a CPU-only machine
ws = torch.tensor(ws).float().cuda()
criterion = torch.nn.CrossEntropyLoss(weight=ws)

#optimizer (bound to the parameters of the model that received the pretrained weights)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

## DataLoaders

In [17]:
# Loaders for fine-tuning: shuffled batches for training, fixed order for testing
common_loader_args = dict(batch_size=batch_size, num_workers=8)
wsl_train_loader = DataLoader(wsl_train_set, shuffle=True, **common_loader_args)
wsl_test_loader = DataLoader(wsl_test_set, shuffle=False, **common_loader_args)

## Model finetuning

In [18]:
# Name of the wandb project for the fine-tuning run
wandb_proj = 'ifn-wsl'

# Set to True to log model statistics to wandb (triggers a login)
log_to_wandb = False
if log_to_wandb:
    wandb.login()

In [19]:
# Options passed on to the fine-tuning call
test_eval = True   # compute statistics for the test set
save_model = False
mask_pixel = 20    # mask pixels equal to 20 (unlabeled)

# Run description built from the fine-tuning hyperparameters
run_config = dict(
    epochs=num_epochs,
    batch_size=batch_size,
    learning_rate=lr,
    optimizer="Adam",
    criterion="WCE",  # weighted Cross-Entropy
    augmentations="H&V_Flip",
    architecture="ConvNextV2_UNet",
    depths=depths,
    dims=dims,
)

In [None]:
# Fine-tune the model that received the pretrained weights; the option flags
# are gathered into one mapping and expanded into keyword arguments
finetune_options = dict(
    test_eval=test_eval,
    mask_pixel=mask_pixel,
    log_to_wandb=log_to_wandb,
    wandb_proj=wandb_proj,
    run_config=run_config,
    save=save_model,
)
trainWSLModel(model, wsl_train_loader, wsl_test_loader, optimizer, criterion,
              **finetune_options)