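# Train a model on the deep-snow dataset

Train a ResDepth model to predict ASO lidar snow depth (`aso_sd`) from Sentinel-1, Sentinel-2, COP30 DEM, and location inputs, then plot the loss curves and example predictions on the validation set.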

In [1]:
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import torch
from glob import glob
import seaborn as sns
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import pickle
import random

import deep_snow.models
import deep_snow.dataset

## Prepare dataloaders

In [2]:
# get paths to data
# glob returns paths in arbitrary (filesystem) order, so sort them to make the
# file order -- and any seeded subsampling of these lists -- reproducible
data_root = '/mnt/Backups/gbrench/repos/deep-snow/data/subsets_v4'

train_data_dir = f'{data_root}/train'
train_path_list = sorted(glob(f'{train_data_dir}/ASO_50M_SD*.nc'))

val_data_dir = f'{data_root}/val'
val_path_list = sorted(glob(f'{val_data_dir}/ASO_50M_SD*.nc'))

In [None]:
# Optionally subsample the data to smoke-test the pipeline quickly.
# Set N_TEST_IMGS to an int (e.g. 16) to enable; None keeps the full dataset.
# A seeded RNG makes the subsample reproducible.
N_TEST_IMGS = None
TEST_SEED = 0

if N_TEST_IMGS is not None:
    _subsample_rng = random.Random(TEST_SEED)
    train_path_list = _subsample_rng.sample(train_path_list, N_TEST_IMGS)
    val_path_list = _subsample_rng.sample(val_path_list, N_TEST_IMGS)

In [3]:
# define data to be returned by dataloader
# (order matters: each batch is unpacked by zipping it with this list)
selected_channels = [
    # ASO products
    'aso_sd', # ASO lidar snow depth (target dataset)
    'aso_gap_map', # gaps in ASO data

    # Sentinel-1 products
    'snowon_vv', # snow on Sentinel-1 VV polarization backscatter in dB, closest acquisition to ASO acquisition
    'snowon_vh', # snow on Sentinel-1 VH polarization backscatter in dB, closest acquisition to ASO acquisition
    'snowoff_vv', # snow off Sentinel-1 VV polarization backscatter in dB, closest acquisition to ASO acquisition
    'snowoff_vh', # snow off Sentinel-1 VH polarization backscatter in dB, closest acquisition to ASO acquisition
    'snowon_vv_mean', # snow on Sentinel-1 VV polarization backscatter in dB, mean of acquisitions in the 4-week period around ASO acquisition
    'snowon_vh_mean', # snow on Sentinel-1 VH polarization backscatter in dB, mean of acquisitions in the 4-week period around ASO acquisition
    'snowoff_vv_mean', # snow off Sentinel-1 VV polarization backscatter in dB, mean of acquisitions in the 4-week period around ASO acquisition
    'snowoff_vh_mean', # snow off Sentinel-1 VH polarization backscatter in dB, mean of acquisitions in the 4-week period around ASO acquisition
    'snowon_cr', # cross ratio, snowon_vh - snowon_vv
    'snowoff_cr', # cross ratio, snowoff_vh - snowoff_vv
    'delta_cr', # change in cross ratio, snowon_cr - snowoff_cr
    'rtc_gap_map', # gaps in Sentinel-1 data
    'rtc_mean_gap_map', # gaps in Sentinel-1 mean data

    # Sentinel-2 products
    'aerosol_optical_thickness', # snow on Sentinel-2 aerosol optical thickness band
    'coastal_aerosol', # snow on Sentinel-2 coastal aerosol band
    'blue', # snow on Sentinel-2 blue band
    'green', # snow on Sentinel-2 green band
    'red', # snow on Sentinel-2 red band
    'red_edge1', # snow on Sentinel-2 red edge 1 band
    'red_edge2', # snow on Sentinel-2 red edge 2 band
    'red_edge3', # snow on Sentinel-2 red edge 3 band
    'nir', # snow on Sentinel-2 near infrared band
    'water_vapor', # snow on Sentinel-2 water vapor
    'swir1', # snow on Sentinel-2 shortwave infrared band 1
    'swir2', # snow on Sentinel-2 shortwave infrared band 2
    'scene_class_map', # snow on Sentinel-2 scene classification product
    'water_vapor_product', # snow on Sentinel-2 water vapor product
    'ndvi', # Normalized Difference Vegetation Index from Sentinel-2
    'ndsi', # Normalized Difference Snow Index from Sentinel-2
    'ndwi', # Normalized Difference Water Index from Sentinel-2
    's2_gap_map', # gaps in Sentinel-2 data

    # PROBA-V global land cover dataset (Buchhorn et al., 2020)
    'fcf', # fractional forest cover

    # COP30 digital elevation model
    'elevation',
    'slope',
    'aspect',
    'curvature',
    'tpi',
    'tri',

    # latitude and longitude
    'latitude',
    'longitude',

    # day of water year
    'dowy',
]

# prepare training and validation dataloaders
train_data = deep_snow.dataset.Dataset(train_path_list, selected_channels, norm=True)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=16, shuffle=True, num_workers=16)
val_data = deep_snow.dataset.Dataset(val_path_list, selected_channels, norm=True, augment=False)
# Validation needs no shuffling. The epoch val loss is a mean of per-batch means,
# and the last batch is smaller, so a shuffled order makes the value wobble between
# epochs even for identical weights. A fixed order keeps it comparable.
val_loader = torch.utils.data.DataLoader(dataset=val_data, batch_size=16, shuffle=False, num_workers=16)

In [4]:
# Channels (a subset of `selected_channels`) stacked, in this order, as the model input.
# The model's input width is taken from len(input_channels).
input_channels = (
    # Sentinel-1
    ['snowon_vv', 'delta_cr']
    # Sentinel-2
    + ['green', 'swir2', 'ndsi', 'ndwi']
    # COP30 DEM
    + ['elevation']
    # location
    + ['latitude', 'longitude']
)

## Train model

In [5]:
# Build the model. Other architectures in deep_snow.models that take the same
# n_input_channels argument: SimpleCNN, UNet, ResUNet, ResDepth, VisionTransformer.
n_input = len(input_channels)
model = deep_snow.models.ResDepth(n_input_channels=n_input, depth=5)
model.to('cuda');  # Run on GPU

# name your model (used in the weight and loss file names written during training)
model_name = 'quinn_ResDepth_v9'

In [None]:
# Define optimizer, learning-rate scheduler and loss function
optimizer = optim.AdamW(model.parameters(), lr=0.0003, weight_decay=0.0001)
# ReduceLROnPlateau ('min' mode) lowers the LR when the metric it is stepped with
# stops decreasing. It is stepped with the validation loss at the end of each epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
loss_fn = nn.MSELoss()
epochs = 250
device = 'cuda'

# Checkpointing: save weights whenever the val loss beats min_val_loss, but only
# once the 0-based epoch index exceeds min_checkpoint_epoch.
min_val_loss = 1
min_checkpoint_epoch = 30

# Optional early stopping: stop after this many consecutive epochs without a new
# best val loss. None disables it, so all `epochs` are run.
early_stop_patience = None
best_val_loss = float('inf')
epochs_without_improvement = 0

train_loss = []
val_loss = []


def _valid_mask(data_dict):
    """Return a boolean mask (on `device`) of pixels with no ASO, Sentinel-1 RTC or Sentinel-2 gaps."""
    gaps = data_dict['aso_gap_map'] + data_dict['rtc_gap_map'] + data_dict['s2_gap_map']
    return (gaps == 0).to(device)


def _masked_batch_loss(data_tuple):
    """Predict snow depth for one batch and return the MSE against ASO, with gap pixels zeroed in both."""
    # batches are tuples ordered like `selected_channels`
    data_dict = {name: tensor for name, tensor in zip(selected_channels, data_tuple)}

    # prepare inputs by concatenating along channel dimension
    inputs = torch.cat([data_dict[channel] for channel in input_channels], dim=1).to(device)
    pred_sd = model(inputs)

    # Limit prediction and target to areas with valid data
    valid = _valid_mask(data_dict)
    zeros = torch.zeros_like(pred_sd)
    pred_sd = torch.where(valid, pred_sd, zeros)
    aso_sd = torch.where(valid, data_dict['aso_sd'].to(device), zeros)

    # NOTE(review): nn.MSELoss() averages over every pixel, including the zeroed
    # gaps, so the loss is diluted by the fraction of gap pixels in the batch
    return loss_fn(pred_sd, aso_sd)


# training and validation loop
for epoch in range(epochs):
    print(f'\nStarting epoch {epoch+1}')
    train_epoch_loss = []
    val_epoch_loss = []

    # --- training pass ---
    model.train()
    train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch", ncols=130)
    for data_tuple in train_pbar:
        optimizer.zero_grad()
        train_batch_loss = _masked_batch_loss(data_tuple)
        train_epoch_loss.append(train_batch_loss.item())

        # Update tqdm progress bar with batch loss
        train_pbar.set_postfix({'batch loss': train_batch_loss.item(), 'mean epoch loss': np.mean(train_epoch_loss)})

        train_batch_loss.backward()  # Propagate the gradients in backward pass
        optimizer.step()

    mean_train_loss = np.mean(train_epoch_loss)
    train_loss.append(mean_train_loss)
    print(f'Training loss: {mean_train_loss}')

    # --- validation pass (eval-mode layers, no gradients) ---
    model.eval()
    val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch", ncols=130)
    with torch.no_grad():
        for data_tuple in val_pbar:
            val_batch_loss = _masked_batch_loss(data_tuple)
            val_epoch_loss.append(val_batch_loss.item())

            # Update tqdm progress bar with batch loss
            val_pbar.set_postfix({'batch loss': val_batch_loss.item(), 'mean epoch loss': np.mean(val_epoch_loss)})

    mean_val_loss = np.mean(val_epoch_loss)

    # Step the plateau scheduler with the held-out metric. It was previously stepped
    # with the training loss, which can keep falling after val loss plateaus.
    scheduler.step(mean_val_loss)

    # checkpoint on a new best validation loss
    if mean_val_loss < min_val_loss and epoch > min_checkpoint_epoch:
        min_val_loss = mean_val_loss
        torch.save(model.state_dict(), f'../../weights/{model_name}_{epoch}epochs')

    val_loss.append(mean_val_loss)
    print(f'Validation loss: {mean_val_loss}')

    # save loss histories every epoch so an interrupted run keeps its curves
    with open(f'../../loss/{model_name}_val_loss.pkl', 'wb') as f:
        pickle.dump(val_loss, f)

    with open(f'../../loss/{model_name}_train_loss.pkl', 'wb') as f:
        pickle.dump(train_loss, f)

    # optional early stopping
    if early_stop_patience is not None:
        if mean_val_loss < best_val_loss:
            best_val_loss = mean_val_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stop_patience:
                print('early stopping triggered')
                torch.save(model.state_dict(), f'../../weights/{model_name}_{epoch}epochs')
                break


Starting epoch 1


Epoch 1/250: 100%|████████████████████████████| 774/774 [05:05<00:00,  2.53batch/s, batch loss=0.000368, mean epoch loss=0.000972]


Training loss: 0.0009719300401476465


Epoch 1/250: 100%|███████████████████████████████| 135/135 [00:46<00:00,  2.92batch/s, batch loss=0.0015, mean epoch loss=0.00104]


Validation loss: 0.0010410517371686487

Starting epoch 2


Epoch 2/250: 100%|█████████████████████████████| 774/774 [05:08<00:00,  2.51batch/s, batch loss=0.00132, mean epoch loss=0.000862]


Training loss: 0.0008623384945623268


Epoch 2/250: 100%|██████████████████████████████| 135/135 [00:43<00:00,  3.11batch/s, batch loss=0.0013, mean epoch loss=0.000764]


Validation loss: 0.0007642924222939959

Starting epoch 3


Epoch 3/250: 100%|█████████████████████████████| 774/774 [05:01<00:00,  2.56batch/s, batch loss=0.00123, mean epoch loss=0.000769]


Training loss: 0.0007688479193228174


Epoch 3/250: 100%|████████████████████████████| 135/135 [00:45<00:00,  2.95batch/s, batch loss=0.000762, mean epoch loss=0.000572]


Validation loss: 0.0005718261083260316

Starting epoch 4


Epoch 4/250: 100%|████████████████████████████| 774/774 [05:32<00:00,  2.33batch/s, batch loss=0.000504, mean epoch loss=0.000648]


Training loss: 0.0006480104495931864


Epoch 4/250: 100%|█████████████████████████████| 135/135 [00:44<00:00,  3.07batch/s, batch loss=0.00109, mean epoch loss=0.000488]


Validation loss: 0.0004876671502845258

Starting epoch 5


Epoch 5/250: 100%|█████████████████████████████| 774/774 [06:35<00:00,  1.96batch/s, batch loss=0.00043, mean epoch loss=0.000611]


Training loss: 0.0006114542817059804


Epoch 5/250: 100%|█████████████████████████████| 135/135 [01:02<00:00,  2.17batch/s, batch loss=0.00064, mean epoch loss=0.000487]


Validation loss: 0.00048677260694805427

Starting epoch 6


Epoch 6/250: 100%|█████████████████████████████| 774/774 [05:14<00:00,  2.46batch/s, batch loss=0.00109, mean epoch loss=0.000581]


Training loss: 0.0005814883995889048


Epoch 6/250: 100%|████████████████████████████| 135/135 [01:02<00:00,  2.14batch/s, batch loss=0.000212, mean epoch loss=0.000415]


Validation loss: 0.00041497855498972864

Starting epoch 7


Epoch 7/250: 100%|█████████████████████████████| 774/774 [04:34<00:00,  2.81batch/s, batch loss=0.00142, mean epoch loss=0.000568]


Training loss: 0.000567808455473022


Epoch 7/250: 100%|████████████████████████████| 135/135 [00:43<00:00,  3.12batch/s, batch loss=0.000511, mean epoch loss=0.000495]


Validation loss: 0.0004954833617106218

Starting epoch 8


Epoch 8/250: 100%|████████████████████████████| 774/774 [05:07<00:00,  2.52batch/s, batch loss=0.000876, mean epoch loss=0.000546]


Training loss: 0.0005457216672604504


Epoch 8/250: 100%|████████████████████████████| 135/135 [00:43<00:00,  3.08batch/s, batch loss=0.000226, mean epoch loss=0.000452]


Validation loss: 0.00045243169890319995

Starting epoch 9


Epoch 9/250: 100%|████████████████████████████| 774/774 [05:26<00:00,  2.37batch/s, batch loss=0.000338, mean epoch loss=0.000537]


Training loss: 0.0005368942760756222


Epoch 9/250: 100%|████████████████████████████| 135/135 [00:51<00:00,  2.62batch/s, batch loss=0.000604, mean epoch loss=0.000389]


Validation loss: 0.0003886306728182481

Starting epoch 10


Epoch 10/250: 100%|███████████████████████████| 774/774 [04:51<00:00,  2.66batch/s, batch loss=0.000549, mean epoch loss=0.000518]


Training loss: 0.0005178669528972216


Epoch 10/250: 100%|███████████████████████████| 135/135 [00:42<00:00,  3.15batch/s, batch loss=0.000328, mean epoch loss=0.000416]


Validation loss: 0.0004157710361781668

Starting epoch 11


Epoch 11/250: 100%|███████████████████████████| 774/774 [05:02<00:00,  2.56batch/s, batch loss=0.000279, mean epoch loss=0.000509]


Training loss: 0.0005085212142618544


Epoch 11/250: 100%|████████████████████████████| 135/135 [00:42<00:00,  3.17batch/s, batch loss=0.00042, mean epoch loss=0.000541]


Validation loss: 0.0005407413978698767

Starting epoch 12


Epoch 12/250: 100%|████████████████████████████| 774/774 [05:17<00:00,  2.44batch/s, batch loss=8.01e-5, mean epoch loss=0.000503]


Training loss: 0.0005025990532688361


Epoch 12/250: 100%|███████████████████████████| 135/135 [00:42<00:00,  3.14batch/s, batch loss=0.000224, mean epoch loss=0.000467]


Validation loss: 0.00046734494686394035

Starting epoch 13


Epoch 13/250: 100%|███████████████████████████| 774/774 [05:15<00:00,  2.45batch/s, batch loss=0.000395, mean epoch loss=0.000496]


Training loss: 0.0004955977836063483


Epoch 13/250: 100%|████████████████████████████| 135/135 [00:42<00:00,  3.16batch/s, batch loss=0.00075, mean epoch loss=0.000363]


Validation loss: 0.00036274847373500015

Starting epoch 14


Epoch 14/250: 100%|████████████████████████████| 774/774 [05:08<00:00,  2.51batch/s, batch loss=6.79e-5, mean epoch loss=0.000495]


Training loss: 0.0004948767023674695


Epoch 14/250: 100%|████████████████████████████| 135/135 [00:48<00:00,  2.79batch/s, batch loss=6.88e-5, mean epoch loss=0.000383]


Validation loss: 0.00038266283429240705

Starting epoch 15


Epoch 15/250: 100%|████████████████████████████| 774/774 [04:54<00:00,  2.63batch/s, batch loss=0.00027, mean epoch loss=0.000474]


Training loss: 0.00047433720726148843


Epoch 15/250: 100%|███████████████████████████| 135/135 [01:38<00:00,  1.37batch/s, batch loss=0.000531, mean epoch loss=0.000396]


Validation loss: 0.000396299381073159

Starting epoch 16


Epoch 16/250: 100%|███████████████████████████| 774/774 [05:51<00:00,  2.20batch/s, batch loss=0.000266, mean epoch loss=0.000485]


Training loss: 0.0004848413916974547


Epoch 16/250: 100%|███████████████████████████| 135/135 [00:50<00:00,  2.67batch/s, batch loss=0.000393, mean epoch loss=0.000365]


Validation loss: 0.0003648443363932462

Starting epoch 17


Epoch 17/250: 100%|███████████████████████████| 774/774 [05:17<00:00,  2.44batch/s, batch loss=0.000972, mean epoch loss=0.000469]


Training loss: 0.00046932624361247817


Epoch 17/250: 100%|███████████████████████████| 135/135 [00:42<00:00,  3.14batch/s, batch loss=0.000367, mean epoch loss=0.000359]


Validation loss: 0.00035945277405618174

Starting epoch 18


Epoch 18/250: 100%|███████████████████████████| 774/774 [05:22<00:00,  2.40batch/s, batch loss=0.000217, mean epoch loss=0.000469]


Training loss: 0.0004692704891572401


Epoch 18/250: 100%|███████████████████████████| 135/135 [00:48<00:00,  2.77batch/s, batch loss=0.000616, mean epoch loss=0.000468]


Validation loss: 0.00046806411073366353

Starting epoch 19


Epoch 19/250: 100%|███████████████████████████| 774/774 [05:56<00:00,  2.17batch/s, batch loss=0.000156, mean epoch loss=0.000463]


Training loss: 0.00046277624793121026


Epoch 19/250: 100%|████████████████████████████| 135/135 [00:43<00:00,  3.14batch/s, batch loss=9.82e-5, mean epoch loss=0.000395]


Validation loss: 0.000394657797490557

Starting epoch 20


Epoch 20/250: 100%|███████████████████████████| 774/774 [05:29<00:00,  2.35batch/s, batch loss=0.000499, mean epoch loss=0.000451]


Training loss: 0.00045112468202905824


Epoch 20/250: 100%|███████████████████████████| 135/135 [00:42<00:00,  3.16batch/s, batch loss=0.000352, mean epoch loss=0.000387]


Validation loss: 0.00038715991406072

Starting epoch 21


Epoch 21/250: 100%|███████████████████████████| 774/774 [05:26<00:00,  2.37batch/s, batch loss=0.000134, mean epoch loss=0.000442]


Training loss: 0.0004417000176784828


Epoch 21/250: 100%|███████████████████████████| 135/135 [00:42<00:00,  3.15batch/s, batch loss=0.000519, mean epoch loss=0.000406]


Validation loss: 0.0004063733203414207

Starting epoch 22


Epoch 22/250: 100%|███████████████████████████| 774/774 [05:45<00:00,  2.24batch/s, batch loss=0.000588, mean epoch loss=0.000454]


Training loss: 0.0004537237439784127


Epoch 22/250: 100%|███████████████████████████| 135/135 [00:42<00:00,  3.20batch/s, batch loss=0.000209, mean epoch loss=0.000355]


Validation loss: 0.00035486534411406694

Starting epoch 23


Epoch 23/250: 100%|███████████████████████████| 774/774 [04:57<00:00,  2.61batch/s, batch loss=0.000595, mean epoch loss=0.000446]


Training loss: 0.00044572779192216666


Epoch 23/250: 100%|███████████████████████████| 135/135 [00:42<00:00,  3.17batch/s, batch loss=0.000509, mean epoch loss=0.000334]


Validation loss: 0.00033360472053000533

Starting epoch 24


Epoch 24/250: 100%|███████████████████████████| 774/774 [05:02<00:00,  2.56batch/s, batch loss=0.000582, mean epoch loss=0.000441]


Training loss: 0.0004413668259677705


Epoch 24/250: 100%|███████████████████████████| 135/135 [00:42<00:00,  3.15batch/s, batch loss=0.000361, mean epoch loss=0.000353]


Validation loss: 0.00035283196099313767

Starting epoch 25


Epoch 25/250: 100%|███████████████████████████| 774/774 [05:22<00:00,  2.40batch/s, batch loss=0.000242, mean epoch loss=0.000438]


Training loss: 0.00043770811143821557


Epoch 25/250: 100%|███████████████████████████| 135/135 [00:42<00:00,  3.19batch/s, batch loss=0.000152, mean epoch loss=0.000344]


Validation loss: 0.0003436808631738165

Starting epoch 26


Epoch 26/250: 100%|████████████████████████████| 774/774 [06:05<00:00,  2.12batch/s, batch loss=0.00011, mean epoch loss=0.000442]


Training loss: 0.0004417624160277691


Epoch 26/250: 100%|███████████████████████████| 135/135 [00:49<00:00,  2.71batch/s, batch loss=0.000112, mean epoch loss=0.000397]


Validation loss: 0.0003967241284381426

Starting epoch 27


Epoch 27/250: 100%|████████████████████████████| 774/774 [05:33<00:00,  2.32batch/s, batch loss=0.000686, mean epoch loss=0.00043]


Training loss: 0.00043004211256065757


Epoch 27/250: 100%|███████████████████████████| 135/135 [00:59<00:00,  2.26batch/s, batch loss=0.000101, mean epoch loss=0.000349]


Validation loss: 0.00034882194730151373

Starting epoch 28


Epoch 28/250: 100%|███████████████████████████| 774/774 [05:19<00:00,  2.42batch/s, batch loss=0.000472, mean epoch loss=0.000422]


Training loss: 0.0004215050505706439


Epoch 28/250: 100%|███████████████████████████| 135/135 [01:01<00:00,  2.19batch/s, batch loss=0.000429, mean epoch loss=0.000363]


Validation loss: 0.00036306800059349

Starting epoch 29


Epoch 29/250: 100%|████████████████████████████| 774/774 [05:19<00:00,  2.43batch/s, batch loss=0.000201, mean epoch loss=0.00041]


Training loss: 0.0004100247161807411


Epoch 29/250: 100%|███████████████████████████| 135/135 [01:32<00:00,  1.45batch/s, batch loss=0.000521, mean epoch loss=0.000356]


Validation loss: 0.0003560472720925679

Starting epoch 30


Epoch 30/250: 100%|████████████████████████████| 774/774 [05:30<00:00,  2.34batch/s, batch loss=0.000344, mean epoch loss=0.00042]


Training loss: 0.0004196141976946998


Epoch 30/250: 100%|███████████████████████████| 135/135 [00:52<00:00,  2.56batch/s, batch loss=0.000533, mean epoch loss=0.000353]


Validation loss: 0.00035298068719890176

Starting epoch 31


Epoch 31/250: 100%|███████████████████████████| 774/774 [05:26<00:00,  2.37batch/s, batch loss=0.000311, mean epoch loss=0.000411]


Training loss: 0.0004107398823551438


Epoch 31/250: 100%|████████████████████████████| 135/135 [00:43<00:00,  3.11batch/s, batch loss=0.00032, mean epoch loss=0.000351]


Validation loss: 0.0003511630961266174

Starting epoch 32


Epoch 32/250: 100%|███████████████████████████| 774/774 [04:39<00:00,  2.77batch/s, batch loss=0.000253, mean epoch loss=0.000419]


Training loss: 0.00041889174467694794


Epoch 32/250: 100%|███████████████████████████| 135/135 [00:43<00:00,  3.13batch/s, batch loss=0.000572, mean epoch loss=0.000363]


Validation loss: 0.00036300388963826224

Starting epoch 33


Epoch 33/250: 100%|█████████████████████████████| 774/774 [04:33<00:00,  2.83batch/s, batch loss=8.5e-5, mean epoch loss=0.000398]


Training loss: 0.000397885453288324


Epoch 33/250: 100%|███████████████████████████| 135/135 [00:42<00:00,  3.19batch/s, batch loss=0.000479, mean epoch loss=0.000322]


Validation loss: 0.00032207364678874407

Starting epoch 34


Epoch 34/250: 100%|████████████████████████████| 774/774 [05:08<00:00,  2.51batch/s, batch loss=0.00044, mean epoch loss=0.000411]


Training loss: 0.00041100482934832


Epoch 34/250: 100%|███████████████████████████| 135/135 [00:41<00:00,  3.23batch/s, batch loss=0.000479, mean epoch loss=0.000345]


Validation loss: 0.0003447781326098333

Starting epoch 35


Epoch 35/250: 100%|████████████████████████████| 774/774 [04:53<00:00,  2.64batch/s, batch loss=0.00111, mean epoch loss=0.000396]


Training loss: 0.0003962335760323101


Epoch 35/250: 100%|████████████████████████████| 135/135 [00:42<00:00,  3.15batch/s, batch loss=0.000192, mean epoch loss=0.00032]


Validation loss: 0.00031987599830716606

Starting epoch 36


Epoch 36/250: 100%|███████████████████████████| 774/774 [04:59<00:00,  2.58batch/s, batch loss=0.000494, mean epoch loss=0.000397]


Training loss: 0.00039682081395807405


Epoch 36/250: 100%|████████████████████████████| 135/135 [00:41<00:00,  3.22batch/s, batch loss=7.47e-5, mean epoch loss=0.000309]


Validation loss: 0.00030890777731353106

Starting epoch 37


Epoch 37/250: 100%|████████████████████████████| 774/774 [05:06<00:00,  2.52batch/s, batch loss=4.71e-5, mean epoch loss=0.000396]


Training loss: 0.0003958952016988213


Epoch 37/250: 100%|███████████████████████████| 135/135 [00:42<00:00,  3.15batch/s, batch loss=0.000241, mean epoch loss=0.000317]


Validation loss: 0.000317196977765114

Starting epoch 38


Epoch 38/250: 100%|███████████████████████████| 774/774 [05:23<00:00,  2.39batch/s, batch loss=0.000395, mean epoch loss=0.000392]


Training loss: 0.0003922297397408794


Epoch 38/250: 100%|███████████████████████████| 135/135 [00:42<00:00,  3.19batch/s, batch loss=0.000243, mean epoch loss=0.000318]


Validation loss: 0.0003181608947839036

Starting epoch 39


Epoch 39/250: 100%|███████████████████████████| 774/774 [05:24<00:00,  2.38batch/s, batch loss=0.000568, mean epoch loss=0.000386]


Training loss: 0.0003864018748553761


Epoch 39/250: 100%|███████████████████████████| 135/135 [00:54<00:00,  2.46batch/s, batch loss=0.000349, mean epoch loss=0.000336]


Validation loss: 0.0003359609727609765

Starting epoch 40


Epoch 40/250: 100%|███████████████████████████| 774/774 [05:17<00:00,  2.44batch/s, batch loss=0.000564, mean epoch loss=0.000385]


Training loss: 0.0003852518504633902


Epoch 40/250: 100%|███████████████████████████| 135/135 [01:06<00:00,  2.02batch/s, batch loss=0.000351, mean epoch loss=0.000342]


Validation loss: 0.00034165986742462135

Starting epoch 41


Epoch 41/250: 100%|████████████████████████████| 774/774 [05:14<00:00,  2.46batch/s, batch loss=0.00063, mean epoch loss=0.000374]


Training loss: 0.00037419712875805163


Epoch 41/250: 100%|███████████████████████████| 135/135 [00:58<00:00,  2.30batch/s, batch loss=0.000682, mean epoch loss=0.000341]


Validation loss: 0.00034146831479760025

Starting epoch 42


Epoch 42/250: 100%|███████████████████████████| 774/774 [04:45<00:00,  2.71batch/s, batch loss=0.000443, mean epoch loss=0.000381]


Training loss: 0.0003809004685808524


Epoch 42/250: 100%|███████████████████████████| 135/135 [01:10<00:00,  1.92batch/s, batch loss=0.000241, mean epoch loss=0.000308]


Validation loss: 0.00030776945746361484

Starting epoch 43


Epoch 43/250: 100%|███████████████████████████| 774/774 [04:57<00:00,  2.60batch/s, batch loss=0.000619, mean epoch loss=0.000382]


Training loss: 0.00038248549870952204


Epoch 43/250: 100%|███████████████████████████| 135/135 [00:42<00:00,  3.16batch/s, batch loss=0.000233, mean epoch loss=0.000322]


Validation loss: 0.0003216907376621815

Starting epoch 44


Epoch 44/250: 100%|███████████████████████████| 774/774 [05:15<00:00,  2.45batch/s, batch loss=0.000334, mean epoch loss=0.000377]


Training loss: 0.00037745108555626585


Epoch 44/250: 100%|███████████████████████████| 135/135 [00:42<00:00,  3.18batch/s, batch loss=0.000331, mean epoch loss=0.000329]


Validation loss: 0.0003293276241122469

Starting epoch 45


Epoch 45/250: 100%|███████████████████████████| 774/774 [05:07<00:00,  2.52batch/s, batch loss=0.000206, mean epoch loss=0.000377]


Training loss: 0.0003765075089600138


Epoch 45/250: 100%|████████████████████████████| 135/135 [00:56<00:00,  2.38batch/s, batch loss=7.74e-5, mean epoch loss=0.000354]


Validation loss: 0.0003538472970292248

Starting epoch 46


Epoch 46/250: 100%|███████████████████████████| 774/774 [05:02<00:00,  2.56batch/s, batch loss=0.000269, mean epoch loss=0.000381]


Training loss: 0.0003812176787799215


Epoch 46/250: 100%|███████████████████████████| 135/135 [01:07<00:00,  1.99batch/s, batch loss=0.000144, mean epoch loss=0.000307]


Validation loss: 0.00030718670604983344

Starting epoch 47


Epoch 47/250: 100%|███████████████████████████| 774/774 [05:09<00:00,  2.50batch/s, batch loss=0.000246, mean epoch loss=0.000371]


Training loss: 0.0003714855224326277


Epoch 47/250: 100%|███████████████████████████| 135/135 [01:15<00:00,  1.79batch/s, batch loss=0.000446, mean epoch loss=0.000306]


Validation loss: 0.00030621297353516436

Starting epoch 48


Epoch 48/250: 100%|███████████████████████████| 774/774 [05:34<00:00,  2.32batch/s, batch loss=0.000243, mean epoch loss=0.000365]


Training loss: 0.0003648650613015745


Epoch 48/250: 100%|████████████████████████████| 135/135 [00:52<00:00,  2.56batch/s, batch loss=0.00014, mean epoch loss=0.000318]


Validation loss: 0.00031820364583162936

Starting epoch 49


Epoch 49/250: 100%|███████████████████████████| 774/774 [05:00<00:00,  2.58batch/s, batch loss=0.000721, mean epoch loss=0.000366]


Training loss: 0.00036644883765302085


Epoch 49/250: 100%|████████████████████████████| 135/135 [01:10<00:00,  1.92batch/s, batch loss=4.93e-5, mean epoch loss=0.000333]


Validation loss: 0.0003334821375944928

Starting epoch 50


Epoch 50/250: 100%|███████████████████████████| 774/774 [05:02<00:00,  2.56batch/s, batch loss=0.000649, mean epoch loss=0.000369]


Training loss: 0.00036930993686348784


Epoch 50/250: 100%|████████████████████████████| 135/135 [01:07<00:00,  1.99batch/s, batch loss=0.000549, mean epoch loss=0.00043]


Validation loss: 0.0004304002541040838

Starting epoch 51


Epoch 51/250:  91%|████████████████████████▋  | 706/774 [04:32<00:42,  1.61batch/s, batch loss=0.000437, mean epoch loss=0.000368]

## Examine results

In [None]:
# load weights from a previous training run for inspection
# NOTE(review): this loads v4 weights, the loss cell reads v8 losses, and the
# training cell names the run v9. Confirm which run is meant to be examined.
model = deep_snow.models.ResDepth(n_input_channels=len(input_channels), depth=5)
model.load_state_dict(torch.load('../../weights/quinn_ResDepth_v4_74epochs'))
model.to('cuda');
# A newly constructed nn.Module starts in training mode. Switch to eval so layers
# such as batch norm / dropout (if ResDepth uses them) behave as at inference.
model.eval();

In [None]:
# Plot the training/validation loss curves pickled by a previous training run.
# (pickle files are local outputs of this notebook; don't load untrusted ones.)
run_name = 'quinn_ResDepth_v8'

loss_curves = {}
for split in ('val', 'train'):
    with open(f'../../loss/{run_name}_{split}_loss.pkl', 'rb') as fh:
        loss_curves[split] = pickle.load(fh)
val_loss = loss_curves['val']
train_loss = loss_curves['train']

# plot loss over all epochs
fig, ax = plt.subplots(figsize=(10, 5))
for series, label in ((train_loss, 'training'), (val_loss, 'validation')):
    ax.plot(series, label=label)
ax.set(xlabel='epoch', ylabel='MSE loss', title='Loss')
ax.legend()

# save figure
fig.savefig(f'../../figs/{run_name}_loss.png', dpi=300)

In [None]:
# Rebuild the validation loader with one scene per batch (shuffled) for the plots below.
# Note: this replaces the batch-16 val_loader used during training.
val_loader = torch.utils.data.DataLoader(val_data, batch_size=1, shuffle=True)

In [None]:
# visualize model predictions alongside the ASO target and several input layers
sns.set_theme()
num_samples = 1
tick_step = 43  # spacing (in pixels) of the overlaid grid lines

# Use inference behavior: a freshly constructed/loaded model is still in train mode.
model.eval()

for i, data_tuple in enumerate(val_loader):
    if i >= num_samples:
        break

    # read data into dictionary
    data_dict = {name: tensor for name, tensor in zip(selected_channels, data_tuple)}

    with torch.no_grad():
        # Concatenate input feature channels, make prediction
        inputs = torch.cat([data_dict[channel] for channel in input_channels], dim=1).to('cuda')
        pred_sd = model(inputs).to('cpu')  # Generate predictions using the model

    f, ax = plt.subplots(3, 3, figsize=(15, 15), sharex=True, sharey=True)
    # interpolation='none' (the string) disables resampling. interpolation=None would
    # fall back to rcParams['image.interpolation'] and could smooth the rasters.
    ax[0, 0].imshow(pred_sd.squeeze(), cmap='Blues', vmin=0, vmax=0.4, interpolation='none')
    ax[0, 0].set_title('Predicted Snow Depth')
    ax[0, 1].imshow(data_dict['aso_sd'].squeeze(), cmap='Blues', vmin=0, vmax=0.4, interpolation='none')
    ax[0, 1].set_title('ASO Lidar Snow Depth')
    ax[0, 2].imshow(data_dict['elevation'].squeeze(), cmap='viridis', interpolation='none')
    ax[0, 2].set_title('Copernicus DEM')
    ax[1, 0].imshow(data_dict['fcf'].squeeze(), cmap='Greens', interpolation='none')
    ax[1, 0].set_title('Fractional Forest Cover')

    # true-color composite, scaled by the brightest of the three bands
    norm_max = np.max([data_dict['green'].max(), data_dict['red'].max(), data_dict['blue'].max()]) # there are better ways to do this
    rgb = torch.stack([data_dict[band].squeeze() for band in ('red', 'green', 'blue')], dim=-1) / norm_max
    ax[1, 1].imshow(rgb, interpolation='none')
    ax[1, 1].set_title('true color image')

    # the map sums ASO, Sentinel-1 (RTC) and Sentinel-2 gaps
    ax[1, 2].imshow(data_dict['aso_gap_map'].squeeze() + data_dict['rtc_gap_map'].squeeze() + data_dict['s2_gap_map'].squeeze(), cmap='Purples', interpolation='none')
    ax[1, 2].set_title('ASO, RTC, and S2 Gaps')
    ax[2, 0].imshow(data_dict['ndvi'].squeeze(), cmap='YlGn', interpolation='none')
    ax[2, 0].set_title('NDVI')
    ax[2, 1].imshow(data_dict['ndsi'].squeeze(), cmap='BuPu', interpolation='none')
    ax[2, 1].set_title('NDSI')
    ax[2, 2].imshow(data_dict['ndwi'].squeeze(), cmap='YlGnBu', interpolation='none')
    ax[2, 2].set_title('NDWI')

    # modify plot style
    height, width = data_dict['aso_sd'].squeeze().shape
    for a in ax.flat:
        a.set_aspect('equal')
        a.set_xticks(np.arange(0, width, tick_step))
        a.set_yticks(np.arange(0, height, tick_step))
        a.grid(True, linewidth=1, alpha=0.5)

    f.tight_layout()

In [None]:
# visualize prediction error in un-normalized snow-depth units
sns.set_theme()
num_samples = 1
norm_dict = deep_snow.dataset.norm_dict
tick_step = 43  # spacing (in pixels) of the overlaid grid lines

# Use inference behavior: a freshly constructed/loaded model is still in train mode.
model.eval()

for i, data_tuple in enumerate(val_loader):
    if i >= num_samples:
        break

    # read data into dictionary
    data_dict = {name: tensor for name, tensor in zip(selected_channels, data_tuple)}

    with torch.no_grad():
        # Concatenate input feature channels, make prediction
        inputs = torch.cat([data_dict[channel] for channel in input_channels], dim=1).to('cuda')
        pred_sd = model(inputs).to('cpu')  # Generate predictions using the model

    # pixels with no gap in ASO, Sentinel-1 (RTC) or Sentinel-2 data
    valid = data_dict['aso_gap_map'] + data_dict['rtc_gap_map'] + data_dict['s2_gap_map'] == 0

    # Undo normalization first, then zero the gaps, so masked pixels are 0 in physical units.
    # Masking before undo_norm leaves gaps at undo_norm(0), which is 0 only if the
    # aso_sd normalization maps 0 to 0 (depends on norm_dict['aso_sd']).
    pred_sd = deep_snow.dataset.undo_norm(pred_sd, norm_dict['aso_sd'])
    aso_sd = deep_snow.dataset.undo_norm(data_dict['aso_sd'], norm_dict['aso_sd'])
    pred_sd = torch.where(valid, pred_sd, torch.zeros_like(pred_sd)).squeeze()
    aso_sd = torch.where(valid, aso_sd, torch.zeros_like(aso_sd)).squeeze()

    # set negative predicted depths to 0
    pred_sd = torch.where(pred_sd >= 0, pred_sd, torch.zeros_like(pred_sd))

    f, ax = plt.subplots(2, 2, figsize=(10,10), sharex=True, sharey=True)
    im0 = ax[0, 0].imshow(pred_sd, cmap='Blues', vmin=0, vmax=2, interpolation='none')
    ax[0, 0].set_title('predicted snow depth')
    f.colorbar(im0, shrink=0.5)
    im1 = ax[0, 1].imshow(aso_sd, cmap='Blues', vmin=0, vmax=2, interpolation='none')
    ax[0, 1].set_title('ASO lidar snow depth')
    f.colorbar(im1, shrink=0.5)

    im2 = ax[1, 0].imshow(aso_sd-pred_sd, cmap='RdBu', vmin=-2, vmax=2, interpolation='none')
    ax[1, 0].set_title('ASO snow depth - predicted snow depth')
    f.colorbar(im2, shrink=0.5)

    # true-color composite, scaled by the brightest of the three bands
    # (no colorbar: an RGB image has no colormap for a colorbar to describe)
    norm_max = np.max([data_dict['green'].max(), data_dict['red'].max(), data_dict['blue'].max()]) # there are better ways to do this
    rgb = torch.stack([data_dict[band].squeeze() for band in ('red', 'green', 'blue')], dim=-1) / norm_max
    ax[1, 1].imshow(rgb, interpolation='none')
    ax[1, 1].set_title('true color image')

    # modify plot style
    height, width = data_dict['aso_sd'].squeeze().shape
    for a in ax.flat:
        a.set_aspect('equal')
        a.set_xticks(np.arange(0, width, tick_step))
        a.set_yticks(np.arange(0, height, tick_step))
        a.grid(True, linewidth=1, alpha=0.5)

    plt.tight_layout()