In [None]:
!pip install wandb timm h5py

In [None]:
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
#import geopandas as gpd
import wandb
#import matplotlib.pyplot as plt
import h5py
import time
import copy

In [None]:
# Sanity check: the training and evaluation cells move tensors with .cuda(),
# so this cell should display True before continuing.
torch.cuda.is_available()

In [None]:
# Authenticate with Weights & Biases; trainModel only logs to wandb when its
# log_to_wandb argument is True.
wandb.login()

In [None]:
# Statistics used by HDF5Dataset when standardization is enabled. The dataset
# reshapes each array to (181, 1, 1), so each file must hold 181 values.
means_np = np.load('sen2_65k_181b_means.npy')
stds_np = np.load('sen2_65k_181b_stds.npy')

In [None]:
# Channel indices read from every HDF5 crop: channels 0-179 followed by
# channel 285 (181 channels in total).
bands_idxs = [*range(180), 285]

In [None]:
#WITHOUT LOADING IN MEMORY
class HDF5Dataset(Dataset):
    """Dataset that reads (crop, label) pairs from an HDF5 file on demand.

    The file must contain a ``'crops'`` dataset indexed as
    ``[sample, channel, row, col]`` and a ``'labels'`` dataset indexed by
    sample. Only the channels listed in the module-level ``bands_idxs`` are read.

    Args:
        hdf5_path: Path of the HDF5 file.
        transform: Optional callable applied to the crop and the label stacked
            along the channel axis, so both receive the same spatial transform.
        standardization: If True, values are clipped at 10000 and standardized
            with the module-level ``means_np`` / ``stds_np``. If False, values
            are divided by 10000 and clipped at 1.

    Notes:
        The HDF5 handle is opened lazily, once per process. Sharing a handle
        that was opened in the parent process with forked ``DataLoader``
        workers is not safe with h5py, and an open handle cannot be pickled
        for spawned workers.
    """

    def __init__(self, hdf5_path, transform=None, standardization=True):
        self.hdf5_path = hdf5_path
        self.transform = transform
        self.standardization = standardization

        # Per-process file handle, created on first access (see `h5file`).
        self._h5file = None
        self._h5file_pid = None

        # Only the sample count is needed up front; close the file right away.
        with h5py.File(hdf5_path, 'r') as h5file:
            self.size = h5file['labels'].shape[0]

        # One mean/std per selected band, broadcastable over (C, H, W) crops.
        n_bands = len(bands_idxs)
        self.means = torch.from_numpy(means_np.astype(np.float32)).view(n_bands, 1, 1)
        self.stds = torch.from_numpy(stds_np.astype(np.float32)).view(n_bands, 1, 1)

    @property
    def h5file(self):
        """Read-only HDF5 handle owned by the current process."""
        import os

        pid = os.getpid()
        # Re-open when no handle exists yet or when it was inherited from
        # another process (e.g. a forked DataLoader worker).
        if self._h5file is None or self._h5file_pid != pid:
            self._h5file = h5py.File(self.hdf5_path, 'r')
            self._h5file_pid = pid
        return self._h5file

    def __getstate__(self):
        # Open h5py handles are not picklable; each process re-opens its own.
        state = self.__dict__.copy()
        state['_h5file'] = None
        state['_h5file_pid'] = None
        return state

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return ``(crop, label)`` for sample ``idx``.

        ``crop`` is a float32 tensor with one channel per entry of
        ``bands_idxs``; ``label`` is an int64 tensor in which the value 255
        has been remapped to 20.
        """
        raw = self.h5file['crops'][idx, bands_idxs, :, :].astype(np.float32)
        if self.standardization:
            crop = torch.from_numpy(raw).clamp(max=10000)  # clip before standardizing
            crop = (crop - self.means) / self.stds
        else:
            crop = torch.from_numpy(raw / 10000.0).clamp(max=1.0)  # scale, then clip at 1

        label = torch.as_tensor(self.h5file['labels'][idx], dtype=torch.long)
        # 255 is remapped to class index 20, the index that the training and
        # evaluation code in this notebook excludes from the accuracy metrics.
        label[label == 255] = 20

        if self.transform:
            # Stack the label as an extra channel so that random spatial
            # transforms are applied identically to the crop and the label.
            combined = torch.cat((crop, label.unsqueeze(0).to(crop.dtype)), dim=0)
            combined = self.transform(combined)

            # Split the crop and label back into separate tensors.
            crop = combined[:-1]
            label = combined[-1].long()

        return crop, label

In [None]:
from torchvision.transforms import v2

# Augmentations for the training set: random horizontal and vertical flips.
# HDF5Dataset applies them to the crop and label stacked together, so both
# are flipped the same way.
# Disabled alternatives: v2.RandomCrop(size=(56,56)), v2.RandomSolarize(0.05),
# v2.ToTensor(). Set train_transforms = None to train without augmentation.
_flip_augmentations = [
    v2.RandomHorizontalFlip(),
    v2.RandomVerticalFlip(),
]
train_transforms = v2.Compose(_flip_augmentations)


In [None]:
# The training set is augmented with train_transforms; the test set is read
# without any transform. Both use standardized inputs.
train_set = HDF5Dataset(
    hdf5_path="crops_train_seg_all_augmented.hdf5",
    transform=train_transforms,
    standardization=True,
)
test_set = HDF5Dataset(
    hdf5_path="crops_test_seg_all.hdf5",
    transform=None,
    standardization=True,
)

In [None]:
def trainModel(model, train_loader, test_loader, test_eval, optimizer, criterion, log_to_wandb=True, config_wandb=None):
    """Train ``model`` and save the best (or, failing that, the final) model to disk.

    Args:
        model: Network to train. Batches are moved to the device of its first
            parameter (CUDA if the model has no parameters).
        train_loader: Iterable with ``len()`` yielding ``(inputs, labels)`` batches.
        test_loader: Loader passed to ``evaluate`` when ``test_eval`` is True.
        test_eval: If True, evaluate after every epoch and keep a copy of the
            model with the highest validation accuracy.
        optimizer: Optimizer already bound to the parameters of ``model``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        log_to_wandb: If True, metrics are logged to a Weights & Biases run.
        config_wandb: Required dict with the keys ``'epochs'``,
            ``'architecture'``, ``'depth'``, ``'dim'``, ``'batch_size'``,
            ``'learning_rate'``, ``'augmentations'``, ``'optimizer'`` and
            ``'criterion'``. It sets the number of epochs and the name of the
            saved file, and is passed to ``wandb.init`` as the run config.

    Raises:
        ValueError: If ``config_wandb`` is None.

    Notes:
        Pixels labelled 20 are excluded from the reported accuracy. If no
        epoch improves on a validation accuracy of 0 (or ``test_eval`` is
        False) the model from the last epoch is saved instead.
    """
    if config_wandb is None:
        raise ValueError("config_wandb is required: it provides 'epochs' and the fields used to name the saved model")

    num_epochs = config_wandb['epochs']

    # Follow the device of the model instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    if log_to_wandb:
        wandb.init(project="ifn-weakly-supervised-seg-v5", config=config_wandb)

    best_val_acc = 0.0
    best_model = None
    try:
        for epoch in range(num_epochs):
            tstart = time.time()
            model.train()
            running_loss = 0.0
            correct = 0
            total = 0

            # Use `for inputs, _, labels in train_loader:` for a dataset that
            # also returns the new labels.
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # Accuracy over valid pixels only (label 20 is excluded).
                _, predicted = torch.max(outputs, dim=1)
                mask = labels != 20
                labels_selected = labels[mask]
                predicted_selected = predicted[mask]
                correct += (predicted_selected == labels_selected).sum().item()
                total += labels_selected.size(0)

                running_loss += loss.item()

            # Guard against an empty loader / an epoch without valid pixels.
            num_batches = len(train_loader)
            avg_loss = running_loss / num_batches if num_batches else 0.0
            train_acc = correct / total if total else 0.0

            if test_eval:
                val_loss, val_accuracy = evaluate(model, test_loader, criterion)
            else:
                val_loss, val_accuracy = (0, 0)
            if test_eval and val_accuracy > best_val_acc:
                best_val_acc = val_accuracy
                best_model = copy.deepcopy(model)
            if log_to_wandb:
                wandb.log({"epoch": epoch, "train_loss": avg_loss, "train_acc": train_acc, "val_loss": val_loss, "val_acc": val_accuracy})
            print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}, Acc: {train_acc:.4f}, Test_eval: {str(test_eval)}, Test Acc: {val_accuracy:.4f}, Time/epoch: {round((time.time()-tstart)/60,2)}min")

        model_name = f"Model_{config_wandb['architecture']}_depth{config_wandb['depth']}_dim{config_wandb['dim']}_batch{config_wandb['batch_size']}_lr{str(config_wandb['learning_rate'])[2:]}_Aug{config_wandb['augmentations']}_{config_wandb['optimizer']}_{config_wandb['criterion']}.pt"
        # Without a validation-selected checkpoint, fall back to the final model
        # so the training run is never lost.
        torch.save(best_model if best_model is not None else model, model_name)
    finally:
        # Close the wandb run even when training is interrupted by an error.
        if log_to_wandb:
            wandb.finish()


In [None]:
def evaluate(model, dataloader, criterion):
    """Compute the average loss and the pixel accuracy of ``model`` on ``dataloader``.

    Args:
        model: Network to evaluate; it is switched to eval mode. Batches are
            moved to the device of its first parameter (CUDA if the model has
            no parameters).
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.

    Returns:
        Tuple ``(avg_loss, accuracy)``: the mean of the per-batch losses and
        the fraction of correctly predicted pixels among those whose label is
        not 20. Both are 0.0 when there is nothing to average over.
    """
    model.eval()  # Set the model to evaluation mode

    # Follow the device of the model instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():  # Disable gradient calculation
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            num_batches += 1

            # Accuracy over valid pixels only (label 20 is excluded).
            _, predicted = torch.max(outputs, dim=1)
            mask = labels != 20
            labels_selected = labels[mask]
            predicted_selected = predicted[mask]
            correct += (predicted_selected == labels_selected).sum().item()
            total += labels_selected.size(0)

    # Guard against an empty loader or batches without any valid pixel.
    avg_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = correct / total if total else 0.0
    return avg_loss, accuracy

In [None]:
# Load the self-supervised checkpoint on the CPU: load_state_dict copies the
# weights onto the model's own device, so keeping the checkpoint tensors on
# the GPU would only waste memory.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
pretrained = torch.load(
    'Model_ConvNextV2_depth[2, 2, 6, 2]_dim[40, 80, 160, 320]_batch64_lr0001_AugH&V_Flip_Adam_MSE_STAND.pt',
    map_location='cpu',
)
# remap_checkpoint_keys iterates over .items(), so hand it a state dict even
# if the file contains a whole pickled module.
ssl_model = pretrained.state_dict() if isinstance(pretrained, nn.Module) else pretrained

In [None]:
#fixed_names = {k.replace('encoder.',''):v for k, v in ssl_model.items()}
from collections import OrderedDict
import math
def remap_checkpoint_keys(ckpt):
    """Rename and reshape the entries of a checkpoint mapping.

    Rules applied to every ``(name, tensor)`` pair, in order:

    * a name starting with ``encoder`` loses its first dot-separated component;
    * a name ending in ``kernel`` has its last component replaced by
      ``weight``; a 3-D tensor ``(k*k, in, out)`` becomes ``(out, in, k, k)``
      and a 2-D tensor ``(k*k, dim)`` becomes ``(dim, 1, k, k)``, each with the
      last two axes swapped. Kernels of any other rank are dropped;
    * a name containing ``ln`` or ``linear`` loses its second-to-last component;
    * a name containing ``backbone.resnet`` keeps only what follows
      ``backbone.resnet.``;
    * any other name is kept as is.

    Afterwards every ``...bias`` entry that is not 1-D is flattened, and every
    other entry whose name contains ``grn`` gets two leading singleton axes.

    Returns:
        A new ``OrderedDict`` with the remapped entries, in input order.
    """
    remapped = OrderedDict()

    for name, tensor in ckpt.items():
        if name.startswith("encoder"):
            # keep everything after the first dot
            name = name.partition(".")[2]

        if name.endswith("kernel"):
            # swap the trailing ".kernel" component for ".weight"
            target = name.rpartition(".")[0] + ".weight"
            rank = len(tensor.shape)
            if rank == 3:  # standard convolution
                taps, c_in, c_out = tensor.shape
                side = int(math.sqrt(taps))
                reshaped = tensor.permute(2, 1, 0).reshape(c_out, c_in, side, side)
                remapped[target] = reshaped.transpose(3, 2)
            elif rank == 2:  # depthwise convolution
                taps, channels = tensor.shape
                side = int(math.sqrt(taps))
                reshaped = tensor.permute(1, 0).reshape(channels, 1, side, side)
                remapped[target] = reshaped.transpose(3, 2)
            continue

        if "ln" in name or "linear" in name:
            # drop the wrapper component that precedes the parameter name
            parts = name.split(".")
            del parts[-2]
            target = ".".join(parts)
        elif "backbone.resnet" in name:
            # some resnet checkpoints are saved with a backbone.resnet prefix
            target = name.split("backbone.resnet.")[1]
        else:
            target = name
        remapped[target] = tensor

    # Second pass: flatten biases and expand grn affine parameters.
    for name, tensor in list(remapped.items()):
        if name.endswith("bias") and len(tensor.shape) != 1:
            remapped[name] = tensor.reshape(-1)
        elif "grn" in name:
            remapped[name] = tensor.unsqueeze(0).unsqueeze(1)

    return remapped

# Remap the checkpoint's parameter names and shapes; the result is passed to
# model.load_state_dict(..., strict=False) in the training cell.
fixed_names = remap_checkpoint_keys(ssl_model)

In [None]:
from itertools import product
from convnextv2_unet import ConvNeXtV2_unet

num_classes = 21  # the last index (20) is the label that HDF5Dataset maps 255 to

# Per-class loss weights: every class counts equally except the last one,
# whose weight is 0 so that it does not contribute to the loss.
ws = torch.ones(num_classes)
ws[-1] = 0
ws = ws.float().cuda()

num_epochs = 50
criterion = nn.CrossEntropyLoss(weight=ws)#FocalLoss(gamma=2.5, alpha=ws)#
test_eval = True
log_to_wandb = False

# Hyper-parameter grid; every combination is trained once.
batch_size_list = [32]
lr_list = [0.0001]
depths_list = [
    [2, 2, 6, 2],#[3, 3, 9, 3],
]
dims_list = [
    [40, 80, 160, 320],#[96, 192, 384, 768],
]

for batch_size, lr, depth, dim in product(batch_size_list, lr_list, depths_list, dims_list):
    #set data loaders, which will vary according to the batch_size
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8,pin_memory=True,persistent_workers=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=8,pin_memory=True)

    #set model, which will vary according to the depth and dim
    model = None #first make sure no previous model was initialized
    #model = UNet(in_channels=181, out_channels=num_classes, init_features=64)
    #model.load_state_dict(pretrained , strict=False)
    model = ConvNeXtV2_unet(img_size=64, patch_size=8, in_chans=181, num_classes=num_classes, depths=depth, dims=dim)#, use_orig_stem=False)

    # strict=False silently skips keys that do not match, so report how many
    # were skipped: a large number means the pretrained weights were not
    # actually loaded.
    load_result = model.load_state_dict(fixed_names, strict=False)
    print(f"Pretrained weights loaded: {len(load_result.missing_keys)} missing keys, {len(load_result.unexpected_keys)} unexpected keys")

    # Move the model to the GPU before building the optimizer, as recommended
    # by the torch.optim documentation.
    model = model.cuda()

    #set optimizer, which will vary according to the learning rate
    optimizer = optim.Adam(model.parameters(), lr=lr)

    #set wandb config dictionary, which will vary depending on the parameters
    config_wandb = {
    "optimizer": "Adam", #fixed
    "criterion": "CE_ws_64_181_aug_STAND_pretrainedSSL_2", #fixed
    "learning_rate": lr,
    "epochs": num_epochs, #fixed
    "batch_size": batch_size,
    "augmentations":"H&V_Flip", #fixed
    "architecture":"ConvNextV2_mod",#"ConvNextV2_mod", "SimpleUNet" #fixed
    "depth": depth,
    "dim": dim
    }

    #train model
    trainModel(model, train_loader, test_loader, test_eval, optimizer, criterion, log_to_wandb, config_wandb)