In [1]:
from torch.utils.data import DataLoader
import numpy as np
import torch
from glob import glob

import crunchy_snow.dataset

In [2]:
# Root folder of the pre-split dataset; the train/ and val/ folders sit beneath it.
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
DATA_ROOT = '/mnt/Backups/gbrench/repos/crunchy-snow/data/subsets_v2'

train_data_dir = f'{DATA_ROOT}/train'
# glob() returns matches in arbitrary order; sort so the file list handed to
# the Dataset is the same on every run and every machine.
train_path_list = sorted(glob(f'{train_data_dir}/ASO_50M_SD*.nc'))

val_data_dir = f'{DATA_ROOT}/val'
val_path_list = sorted(glob(f'{val_data_dir}/ASO_50M_SD*.nc'))

In [3]:
# Channels to be returned by the dataloader, declared per data source and then
# concatenated. Later cells look names up in `selected_channels` by position,
# so the concatenation order below must not change.

# ASO products
_aso_channels = [
    'aso_sd',       # ASO lidar snow depth (target dataset)
    'aso_gap_map',  # gaps in ASO data
]

# Sentinel-1 products (backscatter in dB)
_sentinel1_channels = [
    'snowon_vv',         # snow on, VV polarization, closest acquisition to ASO acquisition
    'snowon_vh',         # snow on, VH polarization, closest acquisition to ASO acquisition
    'snowoff_vv',        # snow off, VV polarization, closest acquisition to ASO acquisition
    'snowoff_vh',        # snow off, VH polarization, closest acquisition to ASO acquisition
    'snowon_vv_mean',    # snow on, VV polarization, mean of acquisitions in 4 week period around ASO acquisition
    'snowon_vh_mean',    # snow on, VH polarization, mean of acquisitions in 4 week period around ASO acquisition
    'snowoff_vv_mean',   # snow off, VV polarization, mean of acquisitions in 4 week period around ASO acquisition
    'snowoff_vh_mean',   # snow off, VH polarization, mean of acquisitions in 4 week period around ASO acquisition
    'snowon_cr',         # cross ratio, snowon_vh - snowon_vv
    'snowoff_cr',        # cross ratio, snowoff_vh - snowoff_vv
    'delta_cr',          # change in cross ratio, snowon_cr - snowoff_cr
    'rtc_gap_map',       # gaps in Sentinel-1 data
    'rtc_mean_gap_map',  # gaps in Sentinel-1 mean data
]

# Sentinel-2 products (snow on)
_sentinel2_channels = [
    'aerosol_optical_thickness',  # aerosol optical thickness band
    'coastal_aerosol',            # coastal aerosol band
    'blue',                       # blue band
    'green',                      # green band
    'red',                        # red band
    'red_edge1',                  # red edge 1 band
    'red_edge2',                  # red edge 2 band
    'red_edge3',                  # red edge 3 band
    'nir',                        # near infrared band
    'water_vapor',                # water vapor
    'swir1',                      # shortwave infrared band 1
    'swir2',                      # shortwave infrared band 2
    'scene_class_map',            # scene classification product
    'water_vapor_product',        # water vapor product
    'ndvi',                       # Normalized Difference Vegetation Index
    'ndsi',                       # Normalized Difference Snow Index
    'ndwi',                       # Normalized Difference Water Index
    's2_gap_map',                 # gaps in Sentinel-2 data
]

# Static / ancillary layers
_ancillary_channels = [
    'fcf',        # fractional forest cover, PROBA-V global land cover dataset (Buchhorn et al., 2020)
    'elevation',  # COP30 digital elevation model
    'latitude',
    'longitude',
    'dowy',       # day of water year
]

selected_channels = [
    *_aso_channels,
    *_sentinel1_channels,
    *_sentinel2_channels,
    *_ancillary_channels,
]

# Build one dataset + dataloader per split. Both splits use the same channel
# list and the same loader settings, so the settings are declared once.
# NOTE(review): norm=False presumably makes the Dataset return un-normalized
# values (needed to measure raw min/max) -- confirm in crunchy_snow.dataset.
_loader_kwargs = dict(batch_size=2000, shuffle=False)

# training split
train_data = crunchy_snow.dataset.Dataset(train_path_list, selected_channels, norm=False)
train_loader = DataLoader(dataset=train_data, **_loader_kwargs)

# validation split
val_data = crunchy_snow.dataset.Dataset(val_path_list, selected_channels, norm=False)
val_loader = DataLoader(dataset=val_data, **_loader_kwargs)

In [4]:
# find dataset min and max for normalization
def compute_channel_min_max(loader, channel_names, skip_zero_min=True):
    """Scan ``loader`` once and collect a running ``[min, max]`` per channel.

    Args:
        loader: iterable of batches. Each batch is a sequence whose k-th item
            exposes ``.min()`` / ``.max()`` returning objects with ``.item()``
            (e.g. torch tensors) and is labelled with ``channel_names[k]``.
        channel_names: channel names, positionally matching the batch items.
        skip_zero_min: when True (the default, matching the original cell), a
            batch after the first whose minimum is exactly 0 never lowers the
            running minimum. The first batch's minimum is always taken as is,
            so the result can depend on batch order.
            NOTE(review): presumably meant to ignore 0-valued fill -- confirm,
            or pass False to get the true minimum.

    Returns:
        dict mapping channel name -> ``[min, max]`` as Python scalars.

    Raises:
        ValueError: if a batch holds more items than ``channel_names``.
    """
    stats = {}
    for batch_idx, outputs in enumerate(loader):
        print(f'loop {batch_idx+1}')
        if len(outputs) > len(channel_names):
            raise ValueError(
                f'batch has {len(outputs)} items but only '
                f'{len(channel_names)} channel names were given'
            )
        for data_name, item in zip(channel_names, outputs):
            # reduce each tensor exactly once per batch
            batch_min = item.min().item()
            batch_max = item.max().item()
            if batch_idx == 0:
                # first batch seeds the running bounds unconditionally
                stats[data_name] = [batch_min, batch_max]
                continue
            bounds = stats[data_name]
            if batch_max > bounds[1]:
                bounds[1] = batch_max
            if batch_min < bounds[0] and not (skip_zero_min and batch_min == 0):
                bounds[0] = batch_min
    return stats


norm_dict = compute_channel_min_max(train_loader, selected_channels)

loop 1
loop 2
loop 3
loop 4
loop 5
loop 6
loop 7
loop 8
loop 9
loop 10
loop 11
loop 12
loop 13
loop 14
loop 15
loop 16


In [5]:
# display the per-channel [min, max] pairs collected by the preceding cell
norm_dict

{'aso_sd': [0.0, 77.68216705322266],
 'aso_gap_map': [0.0, 1.0],
 'snowon_vv': [-58.50587463378906, 39.43202590942383],
 'snowon_vh': [-67.34005737304688, 18.352039337158203],
 'snowoff_vv': [-52.92261505126953, 41.31097412109375],
 'snowoff_vh': [-61.13330841064453, 15.569765090942383],
 'snowon_vv_mean': [-58.50587463378906, 36.8072624206543],
 'snowon_vh_mean': [-67.34005737304688, 17.918773651123047],
 'snowoff_vv_mean': [-57.858699798583984, 40.318572998046875],
 'snowoff_vh_mean': [-69.16107940673828, 14.829648971557617],
 'snowon_cr': [-40.59846878051758, 15.802289962768555],
 'snowoff_cr': [-42.27442169189453, 12.083983421325684],
 'delta_cr': [-32.43565368652344, 26.25350570678711],
 'rtc_gap_map': [0.0, 1.0],
 'rtc_mean_gap_map': [0.0, 1.0],
 'aerosol_optical_thickness': [0.0, 572.0],
 'coastal_aerosol': [0.0, 23459.0],
 'blue': [0.0, 23004.0],
 'green': [0.0, 26440.0],
 'red': [0.0, 21576.0],
 'red_edge1': [0.0, 20796.0],
 'red_edge2': [0.0, 20432.0],
 'red_edge3': [0.0, 201

In [4]:
# Running per-channel [min, max] over the validation loader.
# NOTE: this rebinds `norm_dict`, so bounds from any earlier pass are discarded.
norm_dict = {}
for batch_no, batch in enumerate(val_loader):
    print(f'loop {batch_no+1}')
    for position, tensor in enumerate(batch):
        channel = selected_channels[position]
        lowest = tensor.min().item()
        highest = tensor.max().item()
        if batch_no == 0:
            # the first batch seeds the bounds as-is
            norm_dict[channel] = [lowest, highest]
        bounds = norm_dict[channel]
        if highest > bounds[1]:
            bounds[1] = highest
        # a batch minimum of exactly 0 is never allowed to lower the bound
        if lowest != 0 and lowest < bounds[0]:
            bounds[0] = lowest

loop 1
loop 2


In [5]:
# display the per-channel [min, max] pairs collected by the preceding cell
norm_dict

{'aso_sd': [0.0, 397.2589111328125],
 'aso_gap_map': [0.0, 1.0],
 'snowon_vv': [-58.84242630004883, 23.79745101928711],
 'snowon_vh': [-65.97496795654297, 9.849296569824219],
 'snowoff_vv': [-55.99445724487305, 20.7916259765625],
 'snowoff_vh': [-62.95058822631836, 10.464592933654785],
 'snowon_vv_mean': [-58.84242630004883, 21.081655502319336],
 'snowon_vh_mean': [-65.97496795654297, 9.639721870422363],
 'snowoff_vv_mean': [-55.639896392822266, 19.098800659179688],
 'snowoff_vh_mean': [-65.64103698730469, 10.436249732971191],
 'snowon_cr': [-35.89558029174805, 10.695364952087402],
 'snowoff_cr': [-27.902101516723633, 11.517260551452637],
 'delta_cr': [-28.254453659057617, 23.295503616333008],
 'rtc_gap_map': [0.0, 1.0],
 'rtc_mean_gap_map': [0.0, 1.0],
 'aerosol_optical_thickness': [0.0, 457.0],
 'coastal_aerosol': [0.0, 24304.0],
 'blue': [0.0, 23371.0],
 'green': [0.0, 21459.0],
 'red': [0.0, 20357.0],
 'red_edge1': [0.0, 19776.0],
 'red_edge2': [0.0, 18810.0],
 'red_edge3': [0.0, 1