# Train a model on the deep-snow dataset

In [1]:
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import torch
from glob import glob
import seaborn as sns
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import pickle
import random

import deep_snow.models
import deep_snow.dataset

## Prepare dataloader

In [2]:
# get paths to data
train_data_dir = '/mnt/Backups/gbrench/repos/deep-snow/data/subsets_v4/train'
train_path_list = glob(f'{train_data_dir}/ASO_50M_SD*.nc')

val_data_dir = '/mnt/Backups/gbrench/repos/deep-snow/data/subsets_v4/val'
val_path_list = glob(f'{val_data_dir}/ASO_50M_SD*.nc')

In [3]:
# # to test code with a small sample of the data
# import random
# n_imgs = 16

# train_path_list = random.sample(train_path_list, n_imgs )
# val_path_list = random.sample(val_path_list, n_imgs)

In [4]:
# define data to be returned by dataloader
selected_channels = [
    # ASO products
    'aso_sd', # ASO lidar snow depth (target dataset)
    'aso_gap_map', # gaps in ASO data
    
    # Sentinel-1 products
    'snowon_vv', # snow on Sentinel-1 VV polarization backscatter in dB, closest acquisition to ASO acquisition
    'snowon_vh', # snow on Sentinel-1 VH polarization backscatter in dB, closest acquisition to ASO acquisition
    'snowoff_vv', # snow off Sentinel-1 VV polarization backscatter in dB, closest acquisition to ASO acquisition
    'snowoff_vh', # snow off Sentinel-1 VH polarization backscatter in dB, closest acquisition to ASO acquisition
    'snowon_vv_mean', # snow on Sentinel-1 VV polarization backscatter in dB, mean of acquisition in 4 week period around ASO acquisition
    'snowon_vh_mean', # snow on Sentinel-1 VH polarization backscatter in dB, mean of acquisition in 4 week period around ASO acquisition
    'snowoff_vv_mean', # snow off Sentinel-1 VV polarization backscatter in dB, mean of acquisition in 4 week period around ASO acquisition
    'snowoff_vh_mean', # snow off Sentinel-1 VH polarization backscatter in dB, mean of acquisition in 4 week period around ASO acquisition
    'snowon_cr', # cross ratio, snowon_vh - snowon_vv
    'snowoff_cr', # cross ratio, snowoff_vh - snowoff_vv
    'delta_cr', # change in cross ratio, snowon_cr - snowoff_cr
    'rtc_gap_map', # gaps in Sentinel-1 data
    'rtc_mean_gap_map', # gaps in Sentinel-1 mean data
    
    # Sentinel-2 products 
    'aerosol_optical_thickness', # snow on Sentinel-2 aerosol optical thickness band 
    'coastal_aerosol', # snow on Sentinel-2 coastal aerosol band
    'blue', # snow on Sentinel-2 blue band
    'green', # snow on Sentinel-2 green band
    'red', # snow on Sentinel-2 red band
    'red_edge1', # snow on Sentinel-2 red edge 1 band
    'red_edge2', # snow on Sentinel-2 red edge 2 band
    'red_edge3', # snow on Sentinel-2 red edge 3 band
    'nir', # snow on Sentinel-2 near infrared band
    'water_vapor', # snow on Sentinel-2 water vapor
    'swir1', # snow on Sentinel-2 shortwave infrared band 1
    'swir2', # snow on Sentinel-2 shortwave infrared band 2
    'scene_class_map', # snow on Sentinel-2 scene classification product
    'water_vapor_product', # snow on Sentinel-2 water vapor product
    'ndvi', # Normalized Difference Vegetation Index from Sentinel-2
    'ndsi', # Normalized Difference Snow Index from Sentinel-2
    'ndwi', # Normalized Difference Water Index from Sentinel-2
    's2_gap_map', # gaps in Sentinel-2 data

    # PROBA-V global land cover dataset (Buchhorn et al., 2020)
    'fcf', # fractional forest cover
    
    # COP30 digital elevation model      
    'elevation',
    'slope',
    'aspect',
    'curvature',
    'tpi',
    'tri',

    # latitude and longitude
    'latitude',
    'longitude',

    # day of water year
    'dowy'
                    ]

# prepare training and validation dataloaders
train_data = deep_snow.dataset.Dataset(train_path_list, selected_channels, norm=True)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=16, shuffle=True, num_workers=16)
val_data = deep_snow.dataset.Dataset(val_path_list, selected_channels, norm=True, augment=False)
val_loader = torch.utils.data.DataLoader(dataset=val_data, batch_size=16, shuffle=True, num_workers=16)

In [5]:
# define input channels for model
input_channels = [
    'snowon_vv',
    'delta_cr',
    'green',
    'swir2',
    'ndsi',
    'ndwi',
    'elevation',
    'latitude',
    'longitude']

## Train model

In [6]:
# import model
# model = deep_snow.models.SimpleCNN(n_input_channels=len(input_channels))
# model = deep_snow.models.UNet(n_input_channels=len(input_channels))
# model = deep_snow.models.ResUNet(n_input_channels=len(input_channels))
# model = deep_snow.models.ResDepth(n_input_channels=len(input_channels))
# model = deep_snow.models.VisionTransformer(n_input_channels=len(input_channels))

model = deep_snow.models.ResDepth(n_input_channels=len(input_channels), depth=5)
model.to('cuda');  # Run on GPU

# name your model
model_name = 'quinn_ResDepth_v9'

In [None]:
# Define optimizer and loss function
optimizer = optim.AdamW(model.parameters(), lr=0.0003, weight_decay=0.0001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
loss_fn = nn.MSELoss()
epochs = 500
min_val_loss = 1

train_loss = []
val_loss = []

# training and validation loop
for epoch in range(epochs):
    print(f'\nStarting epoch {epoch+1}')
    train_epoch_loss = []
    val_epoch_loss = []

    # Loop through training data with tqdm progress bar
    train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch", ncols=130)
    for data_tuple in train_pbar:
        model.train()
        optimizer.zero_grad()

        # read data into dictionary
        data_dict = {name: tensor for name, tensor in zip(selected_channels, data_tuple)}
        
        # prepare inputs by concatenating along channel dimension
        inputs = torch.cat([data_dict[channel] for channel in input_channels], dim=1).to('cuda')

        # generate prediction
        pred_sd = model(inputs)

        # Limit prediction to areas with valid data
        pred_sd = torch.where(data_dict['aso_gap_map'].to('cuda') + data_dict['rtc_gap_map'].to('cuda') + data_dict['s2_gap_map'].to('cuda') == 0, pred_sd, torch.zeros_like(pred_sd).to('cuda'))
        aso_sd = torch.where(data_dict['aso_gap_map'].to('cuda') + data_dict['rtc_gap_map'].to('cuda') + data_dict['s2_gap_map'].to('cuda')== 0, data_dict['aso_sd'].to('cuda'), torch.zeros_like(pred_sd).to('cuda'))

        # Calculate loss
        train_batch_loss = loss_fn(pred_sd, aso_sd.to('cuda'))
        train_epoch_loss.append(train_batch_loss.item())

        # Update tqdm progress bar with batch loss
        train_pbar.set_postfix({'batch loss': train_batch_loss.item(), 'mean epoch loss': np.mean(train_epoch_loss)})

        train_batch_loss.backward()  # Propagate the gradients in backward pass
        optimizer.step()

    train_loss.append(np.mean(train_epoch_loss))
    print(f'Training loss: {np.mean(train_epoch_loss)}')
    scheduler.step(np.mean(train_epoch_loss))

    # Run model on validation data with tqdm progress bar
    val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch", ncols=130)
    for data_tuple in val_pbar:
        with torch.no_grad():
            model.eval()
            
            # read data into dictionary
            data_dict = {name: tensor for name, tensor in zip(selected_channels, data_tuple)}
            # prepare inputs by concatenating along channel dimension
            inputs = torch.cat([data_dict[channel] for channel in input_channels], dim=1).to('cuda')
    
            # generate prediction
            pred_sd = model(inputs)
    
            # Limit prediction to areas with valid data
            pred_sd = torch.where(data_dict['aso_gap_map'].to('cuda') + data_dict['rtc_gap_map'].to('cuda') + data_dict['s2_gap_map'].to('cuda') == 0, pred_sd, torch.zeros_like(pred_sd).to('cuda'))
            aso_sd = torch.where(data_dict['aso_gap_map'].to('cuda') + data_dict['rtc_gap_map'].to('cuda') + data_dict['s2_gap_map'].to('cuda') == 0, data_dict['aso_sd'].to('cuda'), torch.zeros_like(pred_sd).to('cuda'))
    
            # Calculate loss
            val_batch_loss = loss_fn(pred_sd, aso_sd.to('cuda'))
            val_epoch_loss.append(val_batch_loss.item())

            # Update tqdm progress bar with batch loss
            val_pbar.set_postfix({'batch loss': val_batch_loss.item(), 'mean epoch loss': np.mean(val_epoch_loss)})

    if np.mean(val_epoch_loss) < min_val_loss:
        if epoch > 30:
            min_val_loss = np.mean(val_epoch_loss)
            torch.save(model.state_dict(), f'../../weights/{model_name}_{epoch}epochs')
                    
    # # calculate loss over previous 10 epochs for early stopping later
    # if epoch > 20:
    #     past_loss = np.mean(val_loss[-20:-10])

    val_loss.append(np.mean(val_epoch_loss))
    print(f'Validation loss: {np.mean(val_epoch_loss)}')
    
    # save loss 
    with open(f'../../loss/{model_name}_val_loss.pkl', 'wb') as f:
        pickle.dump(val_loss, f)
        
    with open(f'../../loss/{model_name}_train_loss.pkl', 'wb') as f:
        pickle.dump(train_loss, f)

    # # implement early stopping
    # if epoch > 20:
    #     current_loss = np.mean(val_loss[-10:-1])
    #     if current_loss > past_loss:
    #         counter +=1
    #         if counter >= 10:
    #             print('early stopping triggered')
    #             # save model
    #             torch.save(model.state_dict(), f'../../weights/{model_name}_{epoch}epochs')
    #             break
    #     else:
    #         counter = 0


Starting epoch 1


Epoch 1/500: 100%|████████████████████████████| 774/774 [27:59<00:00,  2.17s/batch, batch loss=0.000921, mean epoch loss=0.000995]


Training loss: 0.0009945781601942445


Epoch 1/500: 100%|████████████████████████████| 135/135 [01:44<00:00,  1.30batch/s, batch loss=0.000721, mean epoch loss=0.000835]


Validation loss: 0.0008352722571645346

Starting epoch 2


Epoch 2/500: 100%|████████████████████████████| 774/774 [05:10<00:00,  2.49batch/s, batch loss=0.000254, mean epoch loss=0.000877]


Training loss: 0.0008768303138777723


Epoch 2/500: 100%|█████████████████████████████| 135/135 [00:44<00:00,  3.04batch/s, batch loss=0.00032, mean epoch loss=0.000909]


Validation loss: 0.0009085515623963956

Starting epoch 3


Epoch 3/500:  40%|███████████▋                 | 312/774 [01:43<02:28,  3.10batch/s, batch loss=0.00111, mean epoch loss=0.000799]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 47/500: 100%|████████████████████████████| 774/774 [04:07<00:00,  3.13batch/s, batch loss=0.000209, mean epoch loss=0.00039]


Training loss: 0.00039014903989268405


Epoch 47/500: 100%|███████████████████████████| 135/135 [00:42<00:00,  3.15batch/s, batch loss=0.000776, mean epoch loss=0.000323]


Validation loss: 0.0003229957584591358

Starting epoch 48


Epoch 48/500: 100%|███████████████████████████| 774/774 [04:15<00:00,  3.02batch/s, batch loss=0.000384, mean epoch loss=0.000396]


Training loss: 0.00039600748736712526


Epoch 48/500: 100%|███████████████████████████| 135/135 [00:42<00:00,  3.17batch/s, batch loss=0.000657, mean epoch loss=0.000382]


Validation loss: 0.0003823265849388446

Starting epoch 49


Epoch 49/500:  48%|█████████████▍              | 373/774 [01:59<02:02,  3.28batch/s, batch loss=0.00023, mean epoch loss=0.000388]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 92/500: 100%|███████████████████████████| 774/774 [04:14<00:00,  3.04batch/s, batch loss=0.000467, mean epoch loss=0.000316]


Training loss: 0.0003160885320380011


Epoch 92/500: 100%|███████████████████████████| 135/135 [00:42<00:00,  3.18batch/s, batch loss=0.000123, mean epoch loss=0.000306]


Validation loss: 0.00030573596668546086

Starting epoch 93


Epoch 93/500: 100%|███████████████████████████| 774/774 [04:17<00:00,  3.01batch/s, batch loss=0.000132, mean epoch loss=0.000312]


Training loss: 0.0003123681050779729


Epoch 93/500: 100%|████████████████████████████| 135/135 [00:42<00:00,  3.15batch/s, batch loss=0.000155, mean epoch loss=0.00031]


Validation loss: 0.0003098306365750937

Starting epoch 94


Epoch 94/500:  52%|██████████████             | 404/774 [02:13<03:26,  1.79batch/s, batch loss=0.000275, mean epoch loss=0.000334]

## Examine results

In [None]:
#load previous model
model = deep_snow.models.ResDepth(n_input_channels=len(input_channels), depth=5)
model.load_state_dict(torch.load('../../weights/quinn_ResDepth_v4_74epochs'))
model.to('cuda');

In [None]:
with open(f'../../loss/quinn_ResDepth_v9_val_loss.pkl', 'rb') as f:
        val_loss = pickle.load(f)

with open(f'../../loss/quinn_ResDepth_v9_train_loss.pkl', 'rb') as f:
        train_loss = pickle.load(f)


# plot loss over all epochs
f, ax = plt.subplots(figsize=(10,5))
ax.plot(train_loss, label='training')
ax.plot(val_loss, label='validation')
ax.set_xlabel('epoch')
ax.set_ylabel('MSE loss')
ax.set_title('Loss')
ax.legend()

# save figure
plt.savefig(f'../../figs/quinn_ResDepth_v9_loss.png', dpi=300)

In [None]:
val_loader = torch.utils.data.DataLoader(dataset=val_data, batch_size=1, shuffle=True)

In [None]:
# visualize model predictions
sns.set_theme()
num_samples = 1

for i, data_tuple in enumerate(val_loader):
    if i < num_samples:
        # read data into dictionary
        data_dict = {name: tensor for name, tensor in zip(selected_channels, data_tuple)}

        with torch.no_grad():
            # Concatenate input feature channels, make prediction
            inputs = torch.cat([data_dict[channel] for channel in input_channels], dim=1).to('cuda')
            pred_sd = model(inputs)  # Generate predictions using the model
            pred_sd = pred_sd.to('cpu')
        
        f, ax = plt.subplots(3, 3, figsize=(15, 15), sharex=True, sharey=True)
        ax[0, 0].imshow(pred_sd.squeeze(), cmap='Blues', vmin=0, vmax=0.4, interpolation=None)
        ax[0, 0].set_title('Predicted Snow Depth')
        ax[0, 1].imshow(data_dict['aso_sd'].squeeze(), cmap='Blues', vmin=0, vmax=0.4, interpolation=None)
        ax[0, 1].set_title('ASO Lidar Snow Depth')
        ax[0, 2].imshow(data_dict['elevation'].squeeze(), cmap='viridis', interpolation='none')
        ax[0, 2].set_title('Copernicus DEM')
        ax[1, 0].imshow(data_dict['fcf'].squeeze(), cmap='Greens', interpolation='none')
        ax[1, 0].set_title('Fractional Forest Cover')
        norm_max = np.max([data_dict['green'].max(), data_dict['red'].max(), data_dict['blue'].max()]) # there are better ways to do this
        ax[1, 1].imshow(torch.cat((data_dict['red'].squeeze()[:, :, None]/norm_max, data_dict['green'].squeeze()[:, :, None]/norm_max, data_dict['blue'].squeeze()[:, :, None]/norm_max), 2).squeeze(), interpolation='none')
        ax[1, 1].set_title('true color image')
        ax[1, 2].imshow(data_dict['aso_gap_map'].squeeze() + data_dict['rtc_gap_map'].squeeze() + data_dict['s2_gap_map'].squeeze(), cmap='Purples', interpolation='none')
        ax[1, 2].set_title('ASO and RTC Gaps')
        ax[2, 0].imshow(data_dict['ndvi'].squeeze(), cmap='YlGn', interpolation='none')
        ax[2, 0].set_title('NDVI')
        ax[2, 1].imshow(data_dict['ndsi'].squeeze(), cmap='BuPu', interpolation='none')
        ax[2, 1].set_title('NDSI')
        ax[2, 2].imshow(data_dict['ndwi'].squeeze(), cmap='YlGnBu', interpolation='none')
        ax[2, 2].set_title('NDWI')
        
        # modify plot style
        for a in ax.flat:
            a.set_aspect('equal')
            a.set_xticks(np.arange(0, data_dict['aso_sd'].squeeze().shape[1], 43))
            a.set_yticks(np.arange(0, data_dict['aso_sd'].squeeze().shape[0], 43))
            a.grid(True, linewidth=1, alpha=0.5)
        
        f.tight_layout()
    else:
        break

In [None]:
# visualize prediction error
sns.set_theme()
num_samples = 1
norm_dict = deep_snow.dataset.norm_dict

for i, data_tuple in enumerate(val_loader):
    if i < num_samples:
        # read data into dictionary
        data_dict = {name: tensor for name, tensor in zip(selected_channels, data_tuple)}

        with torch.no_grad():
            # Concatenate input feature channels, make prediction
            inputs = torch.cat([data_dict[channel] for channel in input_channels], dim=1).to('cuda')
            pred_sd = model(inputs)  # Generate predictions using the model
            pred_sd = pred_sd.to('cpu')

            # mask nodata areas
            pred_sd = torch.where(data_dict['aso_gap_map']+data_dict['rtc_gap_map'] + data_dict['s2_gap_map'] == 0, pred_sd, torch.zeros_like(pred_sd))
            aso_sd = torch.where(data_dict['aso_gap_map']+data_dict['rtc_gap_map'] + data_dict['s2_gap_map'] == 0, data_dict['aso_sd'], torch.zeros_like(pred_sd))

            # undo normalization
            pred_sd = deep_snow.dataset.undo_norm(pred_sd, deep_snow.dataset.norm_dict['aso_sd']).squeeze()
            aso_sd = deep_snow.dataset.undo_norm(aso_sd, deep_snow.dataset.norm_dict['aso_sd']).squeeze()
            
            # mask values above 0
            pred_sd = torch.where(pred_sd >= 0, pred_sd, torch.zeros_like(pred_sd))
            
            f, ax = plt.subplots(2, 2, figsize=(10,10), sharex=True, sharey=True)
            im0 = ax[0, 0].imshow(pred_sd, cmap='Blues', vmin=0, vmax=2, interpolation='none') 
            ax[0, 0].set_title('predicted snow depth')
            f.colorbar(im0, shrink=0.5)
            im1 = ax[0, 1].imshow(aso_sd, cmap='Blues', vmin=0, vmax=2, interpolation='none')
            ax[0, 1].set_title('ASO lidar snow depth')
            f.colorbar(im1, shrink=0.5)

            im2 = ax[1, 0].imshow(aso_sd-pred_sd, cmap='RdBu', vmin=-2, vmax=2, interpolation='none') 
            ax[1, 0].set_title('ASO snow depth - predicted snow depth')
            f.colorbar(im2, shrink=0.5)
            norm_max = np.max([data_dict['green'].max(), data_dict['red'].max(), data_dict['blue'].max()]) # there are better ways to do this
            im3 = ax[1, 1].imshow(torch.cat((data_dict['red'].squeeze()[:, :, None]/norm_max, data_dict['green'].squeeze()[:, :, None]/norm_max, data_dict['blue'].squeeze()[:, :, None]/norm_max), 2).squeeze(), interpolation='none')
            ax[1, 1].set_title('true color image')
            f.colorbar(im3, shrink=0.5)

            # modify plot style
            for a in ax.flat:
                a.set_aspect('equal')
                a.set_xticks(np.arange(0, data_dict['aso_sd'].squeeze().shape[1], 43))
                a.set_yticks(np.arange(0, data_dict['aso_sd'].squeeze().shape[0], 43))
                a.grid(True, linewidth=1, alpha=0.5)

            plt.tight_layout()
    else: 
        break