# Train a model on the deep-snow dataset

In [1]:
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import torch
from glob import glob
import seaborn as sns
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import pickle
import random

import deep_snow.models
import deep_snow.dataset

## Prepare dataloader

In [2]:
# Locate the ASO 50 m snow-depth tiles for each data split.
# NOTE(review): the dataset root is an absolute, machine-specific path. It is
# defined once here so it only needs editing in one place on another machine.
data_root = '/mnt/Backups/gbrench/repos/deep-snow/data/subsets_v4'
tile_pattern = 'ASO_50M_SD*.nc'

train_data_dir = f'{data_root}/train'
val_data_dir = f'{data_root}/val'
test_data_dir = f'{data_root}/test'

# glob() returns files in arbitrary (filesystem-dependent) order; sorting makes
# the path lists identical from run to run.
train_path_list = sorted(glob(f'{train_data_dir}/{tile_pattern}'))
val_path_list = sorted(glob(f'{val_data_dir}/{tile_pattern}'))
test_path_list = sorted(glob(f'{test_data_dir}/{tile_pattern}'))

# Fail here with a clear message rather than later inside a DataLoader when a
# directory is missing or empty.
for split_name, split_dir, split_paths in (
    ('train', train_data_dir, train_path_list),
    ('val', val_data_dir, val_path_list),
    ('test', test_data_dir, test_path_list),
):
    if not split_paths:
        raise FileNotFoundError(f'no {tile_pattern} files found for the {split_name} split in {split_dir}')

# consolidate for final run: the validation tiles are added to the training
# list. val_path_list itself is left unchanged.
train_path_list = train_path_list + val_path_list

In [3]:
# Optional smoke test: uncomment the lines below to run the rest of the notebook
# on a small random sample of tiles instead of the full lists.
# NOTE(review): `random` is already imported in the first cell, and at this point
# train_path_list already contains the validation tiles (they are merged in the
# cell above), so the training sample is drawn from both.
# # to test code with a small sample of the data
# import random
# n_imgs = 16

# train_path_list = random.sample(train_path_list, n_imgs )
# val_path_list = random.sample(val_path_list, n_imgs)

In [3]:
# Channels returned by the dataloader, grouped by data source. The training and
# plotting cells pair these names with the tensors in each batch by position
# (zip(selected_channels, data_tuple)), so the order below must not change.

# ASO products
aso_channels = [
    'aso_sd',       # ASO lidar snow depth (target dataset)
    'aso_gap_map',  # gaps in ASO data
]

# Sentinel-1 products (backscatter in dB)
sentinel1_channels = [
    'snowon_vv',         # snow on VV polarization, closest acquisition to ASO acquisition
    'snowon_vh',         # snow on VH polarization, closest acquisition to ASO acquisition
    'snowoff_vv',        # snow off VV polarization, closest acquisition to ASO acquisition
    'snowoff_vh',        # snow off VH polarization, closest acquisition to ASO acquisition
    'snowon_vv_mean',    # snow on VV, mean of acquisitions in 4 week period around ASO acquisition
    'snowon_vh_mean',    # snow on VH, mean of acquisitions in 4 week period around ASO acquisition
    'snowoff_vv_mean',   # snow off VV, mean of acquisitions in 4 week period around ASO acquisition
    'snowoff_vh_mean',   # snow off VH, mean of acquisitions in 4 week period around ASO acquisition
    'snowon_cr',         # cross ratio, snowon_vh - snowon_vv
    'snowoff_cr',        # cross ratio, snowoff_vh - snowoff_vv
    'delta_cr',          # change in cross ratio, snowon_cr - snowoff_cr
    'rtc_gap_map',       # gaps in Sentinel-1 data
    'rtc_mean_gap_map',  # gaps in Sentinel-1 mean data
]

# Sentinel-2 products (snow on)
sentinel2_channels = [
    'aerosol_optical_thickness',  # aerosol optical thickness band
    'coastal_aerosol',            # coastal aerosol band
    'blue',                       # blue band
    'green',                      # green band
    'red',                        # red band
    'red_edge1',                  # red edge 1 band
    'red_edge2',                  # red edge 2 band
    'red_edge3',                  # red edge 3 band
    'nir',                        # near infrared band
    'water_vapor',                # water vapor
    'swir1',                      # shortwave infrared band 1
    'swir2',                      # shortwave infrared band 2
    'scene_class_map',            # scene classification product
    'water_vapor_product',        # water vapor product
    'ndvi',                       # Normalized Difference Vegetation Index
    'ndsi',                       # Normalized Difference Snow Index
    'ndwi',                       # Normalized Difference Water Index
    's2_gap_map',                 # gaps in Sentinel-2 data
]

# PROBA-V global land cover dataset (Buchhorn et al., 2020)
land_cover_channels = [
    'fcf',  # fractional forest cover
]

# COP30 digital elevation model
terrain_channels = ['elevation', 'slope', 'aspect', 'curvature', 'tpi', 'tri']

# latitude and longitude
location_channels = ['latitude', 'longitude']

# day of water year
time_channels = ['dowy']

selected_channels = (
    aso_channels
    + sentinel1_channels
    + sentinel2_channels
    + land_cover_channels
    + terrain_channels
    + location_channels
    + time_channels
)

# Prepare the training and test dataloaders.
# Settings shared by both loaders.
loader_kwargs = dict(batch_size=16, num_workers=16)

# Training data: normalized, no augmentation, reshuffled every epoch.
train_data = deep_snow.dataset.Dataset(train_path_list, selected_channels, augment=False, norm=True)
train_loader = torch.utils.data.DataLoader(dataset=train_data, shuffle=True, **loader_kwargs)

# The validation split is not loaded separately: its tiles are merged into
# train_path_list for the final run.
# val_data = deep_snow.dataset.Dataset(val_path_list, selected_channels, norm=True, augment=False)
# val_loader = torch.utils.data.DataLoader(dataset=val_data, batch_size=16, shuffle=True, num_workers=16)

# Evaluation data.
# augment=False is passed explicitly, like the other Dataset calls, so evaluation
# does not depend on the Dataset default (not visible from this notebook).
# shuffle=False keeps the batches identical from epoch to epoch; the training
# cell averages per-batch losses, so a shuffled loader would make the epoch
# test loss vary with batch composition rather than only with the model.
test_data = deep_snow.dataset.Dataset(test_path_list, selected_channels, augment=False, norm=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, shuffle=False, **loader_kwargs)

In [4]:
# Subset of selected_channels fed to the model. The tensors are concatenated
# along the channel dimension in exactly this order.
input_channels = [
    # Sentinel-1
    'snowon_vv',
    'delta_cr',
    # Sentinel-2
    'green',
    'swir2',
    'ndsi',
    'ndwi',
    # terrain and location
    'elevation',
    'latitude',
    'longitude',
]

## Train model

In [6]:
# Architectures that can be swapped in to train from scratch:
# model = deep_snow.models.SimpleCNN(n_input_channels=len(input_channels))
# model = deep_snow.models.UNet(n_input_channels=len(input_channels))
# model = deep_snow.models.ResUNet(n_input_channels=len(input_channels))
# model = deep_snow.models.ResDepth(n_input_channels=len(input_channels))
# model = deep_snow.models.VisionTransformer(n_input_channels=len(input_channels))

# model = deep_snow.models.ResDepth(n_input_channels=len(input_channels), depth=5)
# model.to('cuda');  # Run on GPU

# Name of this run; used in the file names of the saved weights and loss histories.
model_name = 'quinn_ResDepth_v11'

# Resume from a previous checkpoint: build the same architecture, restore its
# weights, then move the model to the GPU.
previous_weights = '../../weights/quinn_ResDepth_v10_256epochs'
model = deep_snow.models.ResDepth(n_input_channels=len(input_channels), depth=5)
previous_state = torch.load(previous_weights)
model.load_state_dict(previous_state)
model.to('cuda');

  model.load_state_dict(torch.load('../../weights/quinn_ResDepth_v10_256epochs'))


In [6]:
# Define optimizer, learning-rate scheduler and loss function
optimizer = optim.AdamW(model.parameters(), lr=0.0003, weight_decay=0.0001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
loss_fn = nn.MSELoss()
epochs = 50

# Offset added to the epoch number in checkpoint file names.
# NOTE(review): presumably the 256 epochs already trained into the
# quinn_ResDepth_v10_256epochs weights this run resumes from -- confirm.
prior_epochs = 256

# Best test loss seen so far. Starting at +inf guarantees the first evaluated
# epoch produces a checkpoint regardless of the scale of the loss.
min_test_loss = float('inf')

train_loss = []
test_loss = []


def masked_batch_loss(data_tuple):
    """Run the model on one batch and return the MSE loss over valid pixels.

    data_tuple holds one tensor per entry of selected_channels, in the same
    order. Pixels flagged in any of the ASO, Sentinel-1 (rtc) or Sentinel-2 gap
    maps are set to zero in both the prediction and the target, so they
    contribute zero error. The mean is still taken over every pixel in the batch.
    """
    # read data into dictionary
    data_dict = dict(zip(selected_channels, data_tuple))

    # prepare inputs by concatenating along channel dimension
    inputs = torch.cat([data_dict[channel] for channel in input_channels], dim=1).to('cuda')

    # generate prediction
    pred_sd = model(inputs)

    # a pixel is valid only where all three gap maps are zero; build the mask once
    gap_sum = (
        data_dict['aso_gap_map'].to('cuda')
        + data_dict['rtc_gap_map'].to('cuda')
        + data_dict['s2_gap_map'].to('cuda')
    )
    valid = gap_sum == 0

    # Limit prediction and target to areas with valid data
    zeros = torch.zeros_like(pred_sd)
    pred_sd = torch.where(valid, pred_sd, zeros)
    aso_sd = torch.where(valid, data_dict['aso_sd'].to('cuda'), zeros)

    return loss_fn(pred_sd, aso_sd)


# training and testing loop
for epoch in range(epochs):
    print(f'\nStarting epoch {epoch+1}')
    train_epoch_loss = []
    test_epoch_loss = []

    # Loop through training data with tqdm progress bar
    model.train()
    train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch", ncols=130)
    for data_tuple in train_pbar:
        optimizer.zero_grad()

        train_batch_loss = masked_batch_loss(data_tuple)
        train_batch_loss.backward()  # Propagate the gradients in backward pass
        optimizer.step()

        train_epoch_loss.append(train_batch_loss.item())

        # Update tqdm progress bar with batch loss
        train_pbar.set_postfix({'batch loss': train_batch_loss.item(), 'mean epoch loss': np.mean(train_epoch_loss)})

    mean_train_loss = np.mean(train_epoch_loss)
    train_loss.append(mean_train_loss)
    print(f'Training loss: {mean_train_loss}')
    # the learning rate is reduced when the *training* loss plateaus
    scheduler.step(mean_train_loss)

    # Run model on test data with tqdm progress bar (no gradients, inference mode)
    model.eval()
    test_pbar = tqdm(test_loader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch", ncols=130)
    with torch.no_grad():
        for data_tuple in test_pbar:
            test_batch_loss = masked_batch_loss(data_tuple)
            test_epoch_loss.append(test_batch_loss.item())

            # Update tqdm progress bar with batch loss
            test_pbar.set_postfix({'batch loss': test_batch_loss.item(), 'mean epoch loss': np.mean(test_epoch_loss)})

    mean_test_loss = np.mean(test_epoch_loss)

    # Save a checkpoint whenever the test loss improves on the best so far.
    # NOTE(review): the test split is used here to select the checkpoint, so the
    # reported test loss is not an independent estimate for the saved weights.
    if mean_test_loss < min_test_loss:
        #if epoch > 30:
        min_test_loss = mean_test_loss
        torch.save(model.state_dict(), f'../../weights/{model_name}_{epoch+1+prior_epochs}epochs')

    # if epoch == 200:
    #     # fine-tune with no augmentation
    #     train_data = deep_snow.dataset.Dataset(train_path_list, selected_channels, augment=False, norm=True)
    #     train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=16, shuffle=True, num_workers=16)

    # # calculate loss over previous 10 epochs for early stopping later
    # if epoch > 20:
    #     past_loss = np.mean(test_loss[-20:-10])

    test_loss.append(mean_test_loss)
    print(f'test loss: {mean_test_loss}')

    # save loss histories after every epoch so an interrupted run keeps its curves
    with open(f'../../loss/{model_name}_test_loss.pkl', 'wb') as f:
        pickle.dump(test_loss, f)

    with open(f'../../loss/{model_name}_train_loss.pkl', 'wb') as f:
        pickle.dump(train_loss, f)

    # # implement early stopping
    # if epoch > 20:
    #     current_loss = np.mean(test_loss[-10:-1])
    #     if current_loss > past_loss:
    #         counter +=1
    #         if counter >= 10:
    #             print('early stopping triggered')
    #             # save model
    #             torch.save(model.state_dict(), f'../../weights/{model_name}_{epoch}epochs')
    #             break
    #     else:
    #         counter = 0


Starting epoch 1


Epoch 1/50: 100%|██████████████████████████████| 909/909 [09:11<00:00,  1.65batch/s, batch loss=7.58e-5, mean epoch loss=0.000207]


Training loss: 0.0002070649558482419


Epoch 1/50: 100%|███████████████████████████████| 95/95 [01:49<00:00,  1.15s/batch, batch loss=0.000401, mean epoch loss=0.000398]


test loss: 0.0003982331841191473

Starting epoch 2


Epoch 2/50: 100%|█████████████████████████████| 909/909 [08:24<00:00,  1.80batch/s, batch loss=0.000115, mean epoch loss=0.000198]


Training loss: 0.00019756872686307094


Epoch 2/50: 100%|████████████████████████████████| 95/95 [02:13<00:00,  1.40s/batch, batch loss=9.92e-5, mean epoch loss=0.000404]


test loss: 0.00040364447088091096

Starting epoch 3


Epoch 3/50: 100%|█████████████████████████████| 909/909 [10:22<00:00,  1.46batch/s, batch loss=0.000151, mean epoch loss=0.000188]


Training loss: 0.00018825009426996802


Epoch 3/50: 100%|███████████████████████████████| 95/95 [01:23<00:00,  1.13batch/s, batch loss=0.000165, mean epoch loss=0.000406]


test loss: 0.00040641652387995764

Starting epoch 4


Epoch 4/50: 100%|██████████████████████████████| 909/909 [08:30<00:00,  1.78batch/s, batch loss=0.00015, mean epoch loss=0.000194]


Training loss: 0.0001943977100127606


Epoch 4/50: 100%|███████████████████████████████| 95/95 [01:37<00:00,  1.02s/batch, batch loss=0.000369, mean epoch loss=0.000399]


test loss: 0.0003985414626903979

Starting epoch 5


Epoch 5/50: 100%|█████████████████████████████| 909/909 [09:07<00:00,  1.66batch/s, batch loss=0.000218, mean epoch loss=0.000181]


Training loss: 0.00018073983395243276


Epoch 5/50: 100%|███████████████████████████████| 95/95 [01:50<00:00,  1.16s/batch, batch loss=0.000138, mean epoch loss=0.000446]


test loss: 0.0004463806966285981

Starting epoch 6


Epoch 6/50: 100%|█████████████████████████████| 909/909 [08:44<00:00,  1.73batch/s, batch loss=0.000167, mean epoch loss=0.000175]


Training loss: 0.00017456283663611056


Epoch 6/50: 100%|████████████████████████████████| 95/95 [02:40<00:00,  1.69s/batch, batch loss=0.000383, mean epoch loss=0.00042]


test loss: 0.00041984574238181506

Starting epoch 7


Epoch 7/50: 100%|█████████████████████████████| 909/909 [08:24<00:00,  1.80batch/s, batch loss=0.000147, mean epoch loss=0.000179]


Training loss: 0.0001789766661479489


Epoch 7/50: 100%|███████████████████████████████| 95/95 [00:50<00:00,  1.89batch/s, batch loss=0.000394, mean epoch loss=0.000418]


test loss: 0.0004177193391737283

Starting epoch 8


Epoch 8/50: 100%|█████████████████████████████| 909/909 [08:51<00:00,  1.71batch/s, batch loss=0.000245, mean epoch loss=0.000172]


Training loss: 0.0001717051139049806


Epoch 8/50: 100%|████████████████████████████████| 95/95 [02:08<00:00,  1.35s/batch, batch loss=0.00226, mean epoch loss=0.000468]


test loss: 0.00046846889599692074

Starting epoch 9


Epoch 9/50: 100%|███████████████████████████████| 909/909 [08:59<00:00,  1.69batch/s, batch loss=0.00011, mean epoch loss=0.00017]


Training loss: 0.0001698988678599342


Epoch 9/50: 100%|███████████████████████████████| 95/95 [01:50<00:00,  1.16s/batch, batch loss=0.000331, mean epoch loss=0.000411]


test loss: 0.0004111231586552764

Starting epoch 10


Epoch 10/50: 100%|████████████████████████████| 909/909 [09:22<00:00,  1.62batch/s, batch loss=0.000265, mean epoch loss=0.000174]


Training loss: 0.0001739540689098672


Epoch 10/50: 100%|███████████████████████████████| 95/95 [02:48<00:00,  1.78s/batch, batch loss=0.00096, mean epoch loss=0.000456]


test loss: 0.0004557612853731323

Starting epoch 11


Epoch 11/50: 100%|████████████████████████████| 909/909 [09:06<00:00,  1.66batch/s, batch loss=0.000144, mean epoch loss=0.000163]


Training loss: 0.00016332135911629051


Epoch 11/50: 100%|██████████████████████████████| 95/95 [01:50<00:00,  1.16s/batch, batch loss=0.000738, mean epoch loss=0.000434]


test loss: 0.000433908727657246

Starting epoch 12


Epoch 12/50: 100%|████████████████████████████| 909/909 [09:10<00:00,  1.65batch/s, batch loss=0.000238, mean epoch loss=0.000166]


Training loss: 0.00016620643534982066


Epoch 12/50: 100%|███████████████████████████████| 95/95 [02:49<00:00,  1.78s/batch, batch loss=0.00055, mean epoch loss=0.000418]


test loss: 0.00041774656243720335

Starting epoch 13


Epoch 13/50: 100%|████████████████████████████| 909/909 [08:49<00:00,  1.72batch/s, batch loss=0.000193, mean epoch loss=0.000164]


Training loss: 0.00016414264428822289


Epoch 13/50: 100%|██████████████████████████████| 95/95 [01:49<00:00,  1.16s/batch, batch loss=0.000169, mean epoch loss=0.000412]


test loss: 0.0004124392882467395

Starting epoch 14


Epoch 14/50: 100%|████████████████████████████| 909/909 [09:00<00:00,  1.68batch/s, batch loss=0.000298, mean epoch loss=0.000157]


Training loss: 0.00015691481452423793


Epoch 14/50: 100%|██████████████████████████████| 95/95 [02:48<00:00,  1.77s/batch, batch loss=0.000184, mean epoch loss=0.000416]


test loss: 0.00041569653181604255

Starting epoch 15


Epoch 15/50: 100%|█████████████████████████████| 909/909 [09:07<00:00,  1.66batch/s, batch loss=0.00028, mean epoch loss=0.000154]


Training loss: 0.0001544298149315357


Epoch 15/50: 100%|██████████████████████████████| 95/95 [01:51<00:00,  1.17s/batch, batch loss=0.000727, mean epoch loss=0.000431]


test loss: 0.00043070760638281506

Starting epoch 16


Epoch 16/50: 100%|████████████████████████████| 909/909 [09:00<00:00,  1.68batch/s, batch loss=0.000154, mean epoch loss=0.000157]


Training loss: 0.00015673376288575654


Epoch 16/50: 100%|████████████████████████████████| 95/95 [02:48<00:00,  1.78s/batch, batch loss=0.00082, mean epoch loss=0.00055]


test loss: 0.0005502022527109243

Starting epoch 17


Epoch 17/50: 100%|████████████████████████████| 909/909 [09:10<00:00,  1.65batch/s, batch loss=0.000124, mean epoch loss=0.000158]


Training loss: 0.00015766473594926232


Epoch 17/50: 100%|███████████████████████████████| 95/95 [01:51<00:00,  1.17s/batch, batch loss=0.00041, mean epoch loss=0.000414]


test loss: 0.0004138240105117132

Starting epoch 18


Epoch 18/50: 100%|████████████████████████████| 909/909 [09:03<00:00,  1.67batch/s, batch loss=0.000221, mean epoch loss=0.000152]


Training loss: 0.00015180652501452875


Epoch 18/50: 100%|██████████████████████████████| 95/95 [02:46<00:00,  1.75s/batch, batch loss=0.000255, mean epoch loss=0.000447]


test loss: 0.000447231913858559

Starting epoch 19


Epoch 19/50: 100%|█████████████████████████████| 909/909 [09:11<00:00,  1.65batch/s, batch loss=0.00018, mean epoch loss=0.000149]


Training loss: 0.00014925590951734454


Epoch 19/50: 100%|██████████████████████████████| 95/95 [01:52<00:00,  1.19s/batch, batch loss=0.000709, mean epoch loss=0.000434]


test loss: 0.00043375402226382377

Starting epoch 20


Epoch 20/50: 100%|████████████████████████████| 909/909 [08:48<00:00,  1.72batch/s, batch loss=0.000133, mean epoch loss=0.000148]


Training loss: 0.00014832123819724476


Epoch 20/50: 100%|██████████████████████████████| 95/95 [02:48<00:00,  1.77s/batch, batch loss=0.000372, mean epoch loss=0.000401]


test loss: 0.00040120093096782897

Starting epoch 21


Epoch 21/50: 100%|████████████████████████████| 909/909 [08:48<00:00,  1.72batch/s, batch loss=0.000177, mean epoch loss=0.000148]


Training loss: 0.0001476741177087699


Epoch 21/50: 100%|███████████████████████████████| 95/95 [01:49<00:00,  1.16s/batch, batch loss=0.00134, mean epoch loss=0.000434]


test loss: 0.0004337755132981233

Starting epoch 22


Epoch 22/50: 100%|████████████████████████████| 909/909 [08:58<00:00,  1.69batch/s, batch loss=0.000234, mean epoch loss=0.000143]


Training loss: 0.00014321003304863109


Epoch 22/50: 100%|██████████████████████████████| 95/95 [02:45<00:00,  1.75s/batch, batch loss=0.000269, mean epoch loss=0.000449]


test loss: 0.00044890669692234183

Starting epoch 23


Epoch 23/50: 100%|████████████████████████████| 909/909 [09:05<00:00,  1.67batch/s, batch loss=0.000114, mean epoch loss=0.000143]


Training loss: 0.00014286857210165814


Epoch 23/50: 100%|██████████████████████████████| 95/95 [01:50<00:00,  1.16s/batch, batch loss=0.000291, mean epoch loss=0.000435]


test loss: 0.0004353799703255247

Starting epoch 24


Epoch 24/50:  23%|██████▍                     | 207/909 [01:16<02:16,  5.15batch/s, batch loss=0.000105, mean epoch loss=0.000144]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 25/50: 100%|███████████████████████████████| 909/909 [09:02<00:00,  1.68batch/s, batch loss=4.9e-5, mean epoch loss=0.00014]


Training loss: 0.0001404549269658107


Epoch 25/50: 100%|██████████████████████████████| 95/95 [02:08<00:00,  1.35s/batch, batch loss=0.000503, mean epoch loss=0.000434]


test loss: 0.0004337812037096898

Starting epoch 26


Epoch 26/50: 100%|████████████████████████████| 909/909 [08:27<00:00,  1.79batch/s, batch loss=0.000176, mean epoch loss=0.000135]


Training loss: 0.00013501167664174802


Epoch 26/50: 100%|██████████████████████████████| 95/95 [02:49<00:00,  1.78s/batch, batch loss=0.000469, mean epoch loss=0.000448]


test loss: 0.00044822920784721835

Starting epoch 27


Epoch 27/50: 100%|████████████████████████████| 909/909 [08:52<00:00,  1.71batch/s, batch loss=0.000142, mean epoch loss=0.000146]


Training loss: 0.00014606603590326802


Epoch 27/50: 100%|███████████████████████████████| 95/95 [01:51<00:00,  1.17s/batch, batch loss=0.00023, mean epoch loss=0.000433]


test loss: 0.00043301497177086084

Starting epoch 28


Epoch 28/50: 100%|████████████████████████████| 909/909 [08:45<00:00,  1.73batch/s, batch loss=0.000153, mean epoch loss=0.000136]


Training loss: 0.00013565243435031374


Epoch 28/50: 100%|██████████████████████████████| 95/95 [02:49<00:00,  1.78s/batch, batch loss=0.000584, mean epoch loss=0.000484]


test loss: 0.0004840376662813421

Starting epoch 29


Epoch 29/50: 100%|█████████████████████████████| 909/909 [08:47<00:00,  1.72batch/s, batch loss=3.14e-5, mean epoch loss=0.000134]


Training loss: 0.0001340568393667733


Epoch 29/50: 100%|██████████████████████████████| 95/95 [01:50<00:00,  1.16s/batch, batch loss=0.000272, mean epoch loss=0.000423]


test loss: 0.0004232123584523307

Starting epoch 30


Epoch 30/50: 100%|████████████████████████████| 909/909 [09:01<00:00,  1.68batch/s, batch loss=0.000157, mean epoch loss=0.000143]


Training loss: 0.00014333870343736087


Epoch 30/50: 100%|██████████████████████████████| 95/95 [02:49<00:00,  1.78s/batch, batch loss=0.000516, mean epoch loss=0.000468]


test loss: 0.000468186423332602

Starting epoch 31


Epoch 31/50: 100%|█████████████████████████████| 909/909 [08:56<00:00,  1.69batch/s, batch loss=8.01e-5, mean epoch loss=0.000136]


Training loss: 0.0001364299973034042


Epoch 31/50: 100%|██████████████████████████████| 95/95 [01:51<00:00,  1.18s/batch, batch loss=0.000486, mean epoch loss=0.000471]


test loss: 0.0004705078891435589

Starting epoch 32


Epoch 32/50: 100%|████████████████████████████| 909/909 [08:29<00:00,  1.78batch/s, batch loss=0.000132, mean epoch loss=0.000145]


Training loss: 0.0001452473504393958


Epoch 32/50: 100%|███████████████████████████████| 95/95 [02:47<00:00,  1.76s/batch, batch loss=0.000582, mean epoch loss=0.00043]


test loss: 0.00042992160206746407

Starting epoch 33


Epoch 33/50: 100%|█████████████████████████████| 909/909 [08:32<00:00,  1.77batch/s, batch loss=5.15e-5, mean epoch loss=0.000129]


Training loss: 0.00012870360402459133


Epoch 33/50: 100%|███████████████████████████████| 95/95 [01:50<00:00,  1.17s/batch, batch loss=6.77e-5, mean epoch loss=0.000458]


test loss: 0.0004578636607651501

Starting epoch 34


Epoch 34/50: 100%|█████████████████████████████| 909/909 [08:44<00:00,  1.73batch/s, batch loss=0.000163, mean epoch loss=0.00013]


Training loss: 0.00012966846780128465


Epoch 34/50: 100%|██████████████████████████████| 95/95 [02:48<00:00,  1.78s/batch, batch loss=0.000736, mean epoch loss=0.000451]


test loss: 0.00045072164962460336

Starting epoch 35


Epoch 35/50: 100%|████████████████████████████| 909/909 [08:37<00:00,  1.76batch/s, batch loss=0.000294, mean epoch loss=0.000131]


Training loss: 0.00013131122756217904


Epoch 35/50: 100%|███████████████████████████████| 95/95 [01:50<00:00,  1.17s/batch, batch loss=0.00027, mean epoch loss=0.000455]


test loss: 0.0004552878547416951

Starting epoch 36


Epoch 36/50: 100%|████████████████████████████| 909/909 [09:01<00:00,  1.68batch/s, batch loss=0.000112, mean epoch loss=0.000128]


Training loss: 0.00012782599411892678


Epoch 36/50: 100%|██████████████████████████████| 95/95 [02:47<00:00,  1.77s/batch, batch loss=0.000137, mean epoch loss=0.000463]


test loss: 0.0004631673825267506

Starting epoch 37


Epoch 37/50: 100%|█████████████████████████████| 909/909 [08:50<00:00,  1.71batch/s, batch loss=5.99e-5, mean epoch loss=0.000129]


Training loss: 0.00012865180255219825


Epoch 37/50:  34%|██████████                    | 32/95 [00:52<01:43,  1.65s/batch, batch loss=0.000449, mean epoch loss=0.000597]Exception ignored in: Exception ignored in: Exception ignored in: Exception ignored in: Exception ignored in: Exception ignored in: <function CachingFileManager.__del__ at 0x7fcfc7a8f600>Exception ignored in: <function CachingFileManager.__del__ at 0x7fcfc7a8f600>Exception ignored in: Exception ignored in: <function CachingFileManager.__del__ at 0x7fcfc7a8f600><function CachingFileManager.__del__ at 0x7fcfc7a8f600><function CachingFileManager.__del__ at 0x7fcfc7a8f600>

<function CachingFileManager.__del__ at 0x7fcfc7a8f600><function File.close at 0x7fcfb95b7b00>

<function File.close at 0x7fcfb95be840><function CachingFileManager.__del__ at 0x7fcfc7a8f600>Traceback (most recent call last):

Traceback (most recent call last):

Traceback (most recent call last):
Traceback (most recent call last):

  File "/mnt/Backups/gbrench/sw/miniconda3/envs/deep-snow/lib/

KeyboardInterrupt: 

## Examine results

In [None]:
# Load a previously trained model for inspection.
# NOTE(review): this restores the quinn_ResDepth_v4_74epochs weights, while the
# loss curves plotted in the next cell come from quinn_ResDepth_v11 -- confirm
# this is the checkpoint you intend to examine.
model = deep_snow.models.ResDepth(n_input_channels=len(input_channels), depth=5)
model.load_state_dict(torch.load('../../weights/quinn_ResDepth_v4_74epochs'))
model.to('cuda')
# A freshly constructed nn.Module is in training mode. Switch to inference mode
# so layers such as dropout and batch norm behave deterministically when the
# model is used for prediction below.
model.eval();

In [None]:
# Read back the per-epoch loss histories written by the training cell.
# NOTE(review): pickle.load can execute arbitrary code; only open files that
# were produced by this project.
with open('../../loss/quinn_ResDepth_v11_train_loss.pkl', 'rb') as train_file:
    train_loss = pickle.load(train_file)

with open('../../loss/quinn_ResDepth_v11_test_loss.pkl', 'rb') as test_file:
    test_loss = pickle.load(test_file)

# plot both curves against epoch on a single axis
fig, loss_ax = plt.subplots(figsize=(10, 5))
for history, label in ((train_loss, 'training'), (test_loss, 'testing')):
    loss_ax.plot(history, label=label)
loss_ax.set_xlabel('epoch')
loss_ax.set_ylabel('MSE loss')
loss_ax.set_title('Loss')
loss_ax.legend()

# write the figure to disk
fig.savefig('../../figs/quinn_ResDepth_v11_finetune_loss.png', dpi=300)

In [None]:
# Single-tile loader used by the visualization cells below.
# val_data is built here because its construction is commented out in the
# dataloader cell; without this line the cell fails with a NameError on a
# fresh kernel.
# NOTE(review): the validation tiles are also merged into train_path_list, so
# these samples were seen during training. Point this at test_data instead to
# inspect held-out tiles.
val_data = deep_snow.dataset.Dataset(val_path_list, selected_channels, norm=True, augment=False)
val_loader = torch.utils.data.DataLoader(dataset=val_data, batch_size=1, shuffle=True)

In [None]:
# visualize model predictions next to the target and a selection of input layers
sns.set_theme()
num_samples = 1
rgb_bands = ('red', 'green', 'blue')

# make sure the model is in inference mode before predicting
model.eval()

for i, data_tuple in enumerate(val_loader):
    if i >= num_samples:
        break

    # read data into dictionary
    data_dict = dict(zip(selected_channels, data_tuple))

    with torch.no_grad():
        # Concatenate input feature channels, make prediction
        inputs = torch.cat([data_dict[channel] for channel in input_channels], dim=1).to('cuda')
        pred_sd = model(inputs).to('cpu')

    # true color composite: stack the bands as the last axis and scale all three
    # by their shared maximum
    norm_max = torch.stack([data_dict[band].max() for band in rgb_bands]).max()
    true_color = torch.stack([data_dict[band].squeeze() for band in rgb_bands], dim=-1) / norm_max

    # pixels flagged in any of the three gap maps
    gap_sum = data_dict['aso_gap_map'].squeeze() + data_dict['rtc_gap_map'].squeeze() + data_dict['s2_gap_map'].squeeze()

    # (image, title, extra imshow arguments) for each panel, row by row
    panels = [
        (pred_sd.squeeze(), 'Predicted Snow Depth', dict(cmap='Blues', vmin=0, vmax=0.4)),
        (data_dict['aso_sd'].squeeze(), 'ASO Lidar Snow Depth', dict(cmap='Blues', vmin=0, vmax=0.4)),
        (data_dict['elevation'].squeeze(), 'Copernicus DEM', dict(cmap='viridis')),
        (data_dict['fcf'].squeeze(), 'Fractional Forest Cover', dict(cmap='Greens')),
        (true_color, 'true color image', dict()),
        (gap_sum, 'ASO, RTC, and S2 Gaps', dict(cmap='Purples')),
        (data_dict['ndvi'].squeeze(), 'NDVI', dict(cmap='YlGn')),
        (data_dict['ndsi'].squeeze(), 'NDSI', dict(cmap='BuPu')),
        (data_dict['ndwi'].squeeze(), 'NDWI', dict(cmap='YlGnBu')),
    ]

    f, ax = plt.subplots(3, 3, figsize=(15, 15), sharex=True, sharey=True)
    n_rows, n_cols = data_dict['aso_sd'].squeeze().shape
    for a, (image, title, imshow_kwargs) in zip(ax.flat, panels):
        # interpolation='none' (the string) disables resampling on every panel;
        # interpolation=None would fall back to the matplotlib default instead
        a.imshow(image, interpolation='none', **imshow_kwargs)
        a.set_title(title)

        # modify plot style: square pixels, grid line every 43 pixels
        a.set_aspect('equal')
        a.set_xticks(np.arange(0, n_cols, 43))
        a.set_yticks(np.arange(0, n_rows, 43))
        a.grid(True, linewidth=1, alpha=0.5)

    f.tight_layout()

In [None]:
# visualize prediction error (predicted vs. ASO snow depth, normalization undone)
sns.set_theme()
num_samples = 1
norm_dict = deep_snow.dataset.norm_dict
rgb_bands = ('red', 'green', 'blue')

# make sure the model is in inference mode before predicting
model.eval()

for i, data_tuple in enumerate(val_loader):
    if i >= num_samples:
        break

    # read data into dictionary
    data_dict = dict(zip(selected_channels, data_tuple))

    with torch.no_grad():
        # Concatenate input feature channels, make prediction
        inputs = torch.cat([data_dict[channel] for channel in input_channels], dim=1).to('cuda')
        pred_sd = model(inputs).to('cpu')

    # mask nodata areas: zero out pixels flagged in any of the three gap maps
    valid = data_dict['aso_gap_map'] + data_dict['rtc_gap_map'] + data_dict['s2_gap_map'] == 0
    zeros = torch.zeros_like(pred_sd)
    pred_sd = torch.where(valid, pred_sd, zeros)
    aso_sd = torch.where(valid, data_dict['aso_sd'], zeros)

    # undo normalization
    pred_sd = deep_snow.dataset.undo_norm(pred_sd, norm_dict['aso_sd']).squeeze()
    aso_sd = deep_snow.dataset.undo_norm(aso_sd, norm_dict['aso_sd']).squeeze()

    # replace negative predicted depths with zero
    pred_sd = torch.where(pred_sd >= 0, pred_sd, torch.zeros_like(pred_sd))

    # true color composite scaled by the shared maximum of the three bands
    norm_max = torch.stack([data_dict[band].max() for band in rgb_bands]).max()
    true_color = torch.stack([data_dict[band].squeeze() for band in rgb_bands], dim=-1) / norm_max

    f, ax = plt.subplots(2, 2, figsize=(10, 10), sharex=True, sharey=True)

    im0 = ax[0, 0].imshow(pred_sd, cmap='Blues', vmin=0, vmax=2, interpolation='none')
    ax[0, 0].set_title('predicted snow depth')
    f.colorbar(im0, ax=ax[0, 0], shrink=0.5)

    im1 = ax[0, 1].imshow(aso_sd, cmap='Blues', vmin=0, vmax=2, interpolation='none')
    ax[0, 1].set_title('ASO lidar snow depth')
    f.colorbar(im1, ax=ax[0, 1], shrink=0.5)

    im2 = ax[1, 0].imshow(aso_sd - pred_sd, cmap='RdBu', vmin=-2, vmax=2, interpolation='none')
    ax[1, 0].set_title('ASO snow depth - predicted snow depth')
    f.colorbar(im2, ax=ax[1, 0], shrink=0.5)

    im3 = ax[1, 1].imshow(true_color, interpolation='none')
    ax[1, 1].set_title('true color image')
    # NOTE(review): a colorbar carries no information for an RGB image; it is
    # kept so the figure layout is unchanged.
    f.colorbar(im3, ax=ax[1, 1], shrink=0.5)

    # modify plot style: square pixels, grid line every 43 pixels
    n_rows, n_cols = data_dict['aso_sd'].squeeze().shape
    for a in ax.flat:
        a.set_aspect('equal')
        a.set_xticks(np.arange(0, n_cols, 43))
        a.set_yticks(np.arange(0, n_rows, 43))
        a.grid(True, linewidth=1, alpha=0.5)

    f.tight_layout()