In [1]:
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import torch
from glob import glob
import seaborn as sns
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import pickle
import random
import time

import deep_snow.models
import deep_snow.dataset

In [2]:
# def set_seed(seed: int = 43):
#     random.seed(seed)
#     np.random.seed(seed)
#     torch.manual_seed(seed)
#     torch.cuda.manual_seed(seed)
#     torch.backends.cudnn.deterministic = True
#     torch.backends.cudnn.benchmark = False

#     set_seed()

In [3]:
def sample_lognormal(center, sigma):
    """Draw a single value from a lognormal distribution with median `center`.

    The underlying normal distribution has mean ``log(center)`` and standard
    deviation ``sigma``, so exponentiating it yields samples whose median is
    `center`. The draw uses NumPy's global random state.

    Args:
        center: Median of the lognormal distribution (passed through ``np.log``).
        sigma: Standard deviation of the underlying normal distribution.

    Returns:
        One sample drawn with ``np.random.lognormal``.
    """
    log_center = np.log(center)
    draw = np.random.lognormal(log_center, sigma)
    return draw

# def visualize_lognormal(center, sigmas, n_samples=10000):
#     f, ax = plt.subplots(figsize=(10, 6))
#     for sigma in sigmas:
#         mu = np.log(center)
#         samples = np.random.lognormal(mean=mu, sigma=sigma, size=n_samples)
#         sns.kdeplot(samples, label=f'sigma={sigma}', ax=ax)
#     ax.set_xlabel('Sampled Value')
#     ax.set_ylabel('Density')
#     ax.set_title(f'Lognormal Samples Centered on {center}')
#     f.legend()
#     ax.grid(True, which='both', linestyle='--')
#     ax.set_xlim(0, 0.0005)
#     ax.set_ylim(0)

# # Example usage
# visualize_lognormal(center=1e-4, sigmas=[1.0])

In [4]:
# get paths to data: collect every file matching ASO_50M_SD*.nc in the
# training and validation subset directories
train_data_dir = '/mnt/working/brencher/repos/deep-snow/data/subsets_v4/train'
val_data_dir = '/mnt/working/brencher/repos/deep-snow/data/subsets_v4/val'

train_path_list = glob(train_data_dir + '/ASO_50M_SD*.nc')
val_path_list = glob(val_data_dir + '/ASO_50M_SD*.nc')

In [5]:
def _masked_batch_loss(model, loss_fn, data_tuple, input_channels, return_channels, device='cuda'):
    """Run `model` on one batch and return the loss over gap-free pixels.

    Pixels where any of the ASO, Sentinel-1 (rtc) or Sentinel-2 gap maps is
    non-zero are set to zero in both the prediction and the target before the
    loss is computed, so they contribute zero error.

    Args:
        model: Network called on the concatenated input tensor.
        loss_fn: Callable taking ``(prediction, target)`` and returning a scalar tensor.
        data_tuple: One batch from the dataloader; tensors ordered like `return_channels`.
        input_channels: Names of the channels concatenated (dim=1) as model input.
        return_channels: Channel names, in the order they appear in `data_tuple`.
            Must include 'aso_sd', 'aso_gap_map', 'rtc_gap_map' and 's2_gap_map'.
        device: Device the tensors are moved to before the forward pass.

    Returns:
        The scalar loss tensor for this batch.
    """
    # read data into dictionary
    data_dict = dict(zip(return_channels, data_tuple))
    # prepare inputs by concatenating along channel dimension
    inputs = torch.cat([data_dict[channel] for channel in input_channels], dim=1).to(device)

    # generate prediction
    pred_sd = model(inputs)

    # A pixel is valid only where all three gap maps are zero. Build the mask
    # once and reuse it for both prediction and target.
    valid = (data_dict['aso_gap_map'].to(device)
             + data_dict['rtc_gap_map'].to(device)
             + data_dict['s2_gap_map'].to(device)) == 0
    zeros = torch.zeros_like(pred_sd)

    # Limit prediction and target to areas with valid data
    pred_sd = torch.where(valid, pred_sd, zeros)
    aso_sd = torch.where(valid, data_dict['aso_sd'].to(device), zeros)

    return loss_fn(pred_sd, aso_sd)


def train_model(input_channels, return_channels, epochs, lr, weight_decay, n_layers=5):
    """Train a ResDepth model and report its best validation loss.

    The model is optimised with AdamW and an MSE loss; the learning rate is
    reduced by ``ReduceLROnPlateau`` driven by the mean validation loss.
    Batches are read from the module-level ``train_loader`` and ``val_loader``,
    which must exist before this function is called.

    Side effects (paths are relative to the working directory):
        * After every epoch the running train/val loss histories are pickled to
          ``../../../loss/{model_name}_train_loss.pkl`` and ``..._val_loss.pkl``.
        * Whenever the validation loss improves and ``epoch > 30``, the model's
          ``state_dict`` is saved under ``../../../weights/``.

    Training stops early once the validation loss has not improved for 30
    consecutive epochs.

    Args:
        input_channels: Names of the channels fed to the model; each must be
            present in `return_channels`.
        return_channels: Channel names in the order the dataloader yields them.
        epochs: Maximum number of epochs to train.
        lr: Learning rate for AdamW.
        weight_decay: Weight decay for AdamW.
        n_layers: Passed to ``ResDepth`` as ``depth``.

    Returns:
        ``[min_val_loss_epoch, min_val_loss]``: the zero-based epoch index with
        the lowest mean validation loss and that loss. If no epoch produced a
        finite comparison (e.g. ``epochs == 0``), this is ``[None, inf]``.
    """
    device = 'cuda'  # Run on GPU
    model = deep_snow.models.ResDepth(n_input_channels=len(input_channels), depth=n_layers)
    model_name = f'ResDepth_lr{lr}_weightdecay{weight_decay}'
    model.to(device)
    # Define optimizer and loss function
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
    loss_fn = nn.MSELoss()

    train_loss = []
    val_loss = []
    # Start from +inf so the first epoch always registers as the best so far.
    # A finite starting value (previously 1) left min_val_loss_epoch unassigned
    # whenever the validation loss never dropped below it, crashing at return.
    min_val_loss = float('inf')
    min_val_loss_epoch = None
    patience = 0
    patience_limit = 30

    # training and validation loop
    for epoch in range(epochs):
        epoch_start_time = time.time()
        print(f'\nStarting epoch {epoch+1}')
        train_epoch_loss = []
        val_epoch_loss = []

        # Loop through training data with tqdm progress bar
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch", ncols=130)
        for data_tuple in pbar:
            optimizer.zero_grad()

            # Calculate loss
            train_batch_loss = _masked_batch_loss(
                model, loss_fn, data_tuple, input_channels, return_channels, device
            )
            batch_loss_value = train_batch_loss.item()
            train_epoch_loss.append(batch_loss_value)

            # Update tqdm progress bar with batch loss
            pbar.set_postfix({'batch loss': batch_loss_value, 'mean epoch loss': np.mean(train_epoch_loss)})

            train_batch_loss.backward()  # Propagate the gradients in backward pass
            optimizer.step()

        mean_train_loss = np.mean(train_epoch_loss)
        train_loss.append(mean_train_loss)
        print(f'Training loss: {mean_train_loss}')

        # Run model on validation data with tqdm progress bar
        model.eval()
        with torch.no_grad():
            for data_tuple in tqdm(val_loader, desc="Validation", unit="batch"):
                val_batch_loss = _masked_batch_loss(
                    model, loss_fn, data_tuple, input_channels, return_channels, device
                )
                val_epoch_loss.append(val_batch_loss.item())

        mean_val_loss = np.mean(val_epoch_loss)
        val_loss.append(mean_val_loss)
        print(f'Validation loss: {mean_val_loss}')
        scheduler.step(mean_val_loss)

        # save loss
        with open(f'../../../loss/{model_name}_val_loss.pkl', 'wb') as f:
            pickle.dump(val_loss, f)

        with open(f'../../../loss/{model_name}_train_loss.pkl', 'wb') as f:
            pickle.dump(train_loss, f)

        # Early stopping check (start saving after 30 epochs)
        if mean_val_loss < min_val_loss:
            min_val_loss = mean_val_loss
            min_val_loss_epoch = epoch
            patience = 0
            if epoch > 30:
                torch.save(model.state_dict(), f'../../../weights/{model_name}_epochs{epoch}_minvalloss{min_val_loss:.5f}')
        else:
            patience += 1

        if patience >= patience_limit:
            print(f"\nEarly stopping at epoch {epoch + 1}. No improvement in validation loss for {patience_limit} epochs.")
            break

        epoch_end_time = time.time()
        print(f'epoch time: {epoch_end_time - epoch_start_time:.4f} seconds')

    #plot_loss(train_loss, val_loss)
    return [min_val_loss_epoch, min_val_loss]

In [6]:
# define data to be returned by dataloader
# (the order here is the order of the tensors in each batch tuple)
return_channels = [
    # ASO products: lidar snow depth (target dataset) and gaps in ASO data
    'aso_sd', 'aso_gap_map',

    # Sentinel-1 products: change in cross ratio (snowon_cr - snowoff_cr)
    # and gaps in Sentinel-1 data
    'delta_cr', 'rtc_gap_map',

    # Sentinel-2 products: snow-on blue band, shortwave infrared band 1,
    # Normalized Difference Snow Index, and gaps in Sentinel-2 data
    'blue', 'swir1', 'ndsi', 's2_gap_map',

    # SNODAS dataset: snow depth
    'snodas_sd',

    # PROBA-V global land cover dataset (Buchhorn et al., 2020):
    # fractional forest cover
    'fcf',

    # COP30 digital elevation model
    'elevation', 'slope', 'northness', 'curvature',

    # day of water year
    'dowy',
]

# prepare training and validation dataloaders
# NOTE(review): `augment` is left at Datasetv2's default for the training set
# and explicitly disabled for validation — confirm the default is what you want.
train_data = deep_snow.dataset.Datasetv2(train_path_list, return_channels, norm=True, cache_data=True)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=16, shuffle=True)
val_data = deep_snow.dataset.Datasetv2(val_path_list, return_channels, norm=True, augment=False, cache_data=True)
# Validation batches are kept in a fixed order. The validation loss is a mean
# of per-batch means, so when the dataset size is not a multiple of the batch
# size, shuffling changes which samples land in the smaller final batch and
# adds epoch-to-epoch noise to the metric driving the scheduler and early stopping.
val_loader = torch.utils.data.DataLoader(dataset=val_data, batch_size=16, shuffle=False)

In [7]:
# define input channels for model
# (concatenated in this order along the channel dimension before the forward pass)
input_channels = [
    'snodas_sd',
    'blue', 'swir1', 'ndsi',
    'elevation', 'northness', 'slope', 'curvature',
    'dowy',
    'delta_cr',
    'fcf',
]

In [None]:
# Random search over learning rate and weight decay: every trial draws both
# hyperparameters from lognormal distributions, trains a fresh model, and
# records the epoch and value of its lowest validation loss.
num_trials = 20
epochs=500
exp_dict = {}

for trial in range(num_trials):

    print('---------------------------------------------------------')
    print(f'starting trial {trial}')

    # sample this trial's hyperparameters (lr first, then weight decay)
    lr = sample_lognormal(center=3e-4, sigma=1.0)
    weight_decay = sample_lognormal(center=1e-4, sigma=1.0)
    print(f'lr: {lr}, weight decay: {weight_decay}')

    trial_result = train_model(input_channels, return_channels, epochs=epochs, lr=lr, weight_decay=weight_decay)
    min_val_loss_epoch, min_val_loss = trial_result
    print(f'lr: {lr}, weight decay: {weight_decay}, final epoch: {min_val_loss_epoch}, final val loss: {min_val_loss}')

    # exp_dict maps trial index -> [lr, weight_decay, best epoch, best val loss]
    exp_dict[trial] = [lr, weight_decay, min_val_loss_epoch, min_val_loss]

    # save experiments (rewritten after every trial so partial results survive an interruption)
    with open('../../../loss/ResDepth_lr_tuning_loss_v4.pkl', 'wb') as f:
        pickle.dump(exp_dict, f)