In [1]:
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import torch
from glob import glob
from PIL import Image
import seaborn as sns
import math
import random
import xarray as xr
from torch.masked import masked_tensor, as_masked_tensor
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
train_data_dir = '/mnt/Backups/gbrench/repos/crunchy-snow/data/subsetsv1/train'
train_path_list = glob(f'{train_data_dir}/ASO_50M_SD*.nc') #[0:1000]

val_data_dir = '/mnt/Backups/gbrench/repos/crunchy-snow/data/subsetsv1/val'
val_path_list = glob(f'{val_data_dir}/ASO_50M_SD*.nc') #[0:1000]

In [3]:
# these are set by finding the min and max across the entire dataset
norm_dict = {'aso_sd':[0, 24.9],
             'vv':[0, 13523.8],
             'vh':[0, 43.2],
             'AOT':[0, 572.1],
             'coastal':[0, 23459.1],
             'blue':[0, 23004.1],
             'green':[0, 26440.1],
             'red':[0, 21576.1],
             'red_edge1':[0, 20796.1],
             'red_edge2':[0, 20432.1],
             'red_edge3':[0, 20149.1],
             'nir':[0, 21217.1],
             'water_vapor':[0, 18199.1],
             'swir1':[0, 17549.1],
             'swir2':[0, 17314.1],
             'scene_class_map':[0, 15],
             'water_vapor_product':[0, 6517.5],
             'elevation':[-100, 9000],
             'dowy':[0, 364],
             'lat': [-90, 90],
             'lon': [-180, 180]}

In [4]:
def calc_dowy(doy):
    if doy < 274:
        dowy = doy + (365-274)
    elif doy >= 274:
        dowy = doy-274
    return dowy

In [5]:
def calc_norm(tensor, minmax_list):
    '''
    normalize a tensor between 0 and 1 using a min and max value stored in a list
    '''
    normalized = (tensor-minmax_list[0])/(minmax_list[1]-minmax_list[0])
    return normalized

def undo_norm(tensor, minmax_list):
    original = (tensor*(minmax_list[1]-minmax_list[0]))+minmax_list[0]
    return original

In [6]:
# define dataset 
class dataset(torch.utils.data.Dataset):
    '''
    class that reads data from a netCDF and returns normalized tensors 
    '''
    def __init__(self, path_list, norm_dict, norm=True):
        self.path_list = path_list
        self.norm_dict = norm_dict
        self.norm = norm
        
    #dataset length
    def __len__(self):
        self.filelength = len(self.path_list)
        return self.filelength
    
    #load images
    def __getitem__(self,idx):
        path = self.path_list[idx]
        ds = xr.open_dataset(path)
        
        # convert to tensors
        aso_sd = torch.from_numpy(np.float32(ds.aso_sd.values))
        snowon_vv = torch.from_numpy(np.float32(ds.snowon_vv.values))
        snowon_vh = torch.from_numpy(np.float32(ds.snowon_vh.values))
        snowoff_vv = torch.from_numpy(np.float32(ds.snowoff_vv.values))
        snowoff_vh = torch.from_numpy(np.float32(ds.snowoff_vh.values))
        snowon_vv_mean = torch.from_numpy(np.float32(ds.snowon_vv_mean.values))
        snowon_vh_mean = torch.from_numpy(np.float32(ds.snowon_vh_mean.values))
        snowoff_vv_mean = torch.from_numpy(np.float32(ds.snowoff_vv_mean.values))
        snowoff_vh_mean = torch.from_numpy(np.float32(ds.snowoff_vh_mean.values))
        aerosol_optical_thickness = torch.from_numpy(np.float32(ds.AOT.values))
        coastal_aerosol = torch.from_numpy(np.float32(ds.B01.values))
        blue = torch.from_numpy(np.float32(ds.B02.values))
        green = torch.from_numpy(np.float32(ds.B03.values))
        red = torch.from_numpy(np.float32(ds.B04.values))
        red_edge1 = torch.from_numpy(np.float32(ds.B05.values))
        red_edge2 = torch.from_numpy(np.float32(ds.B06.values))
        red_edge3 = torch.from_numpy(np.float32(ds.B07.values))
        nir = torch.from_numpy(np.float32(ds.B08.values))
        water_vapor = torch.from_numpy(np.float32(ds.B09.values))
        swir1 = torch.from_numpy(np.float32(ds.B11.values))
        swir2 = torch.from_numpy(np.float32(ds.B12.values))
        scene_class_map = torch.from_numpy(np.float32(ds.SCL.values))
        water_vapor_product = torch.from_numpy(np.float32(ds.WVP.values))
        fcf = torch.from_numpy(np.float32(ds.fcf.values))
        elevation = torch.from_numpy(np.float32(ds.elevation.values))
        aso_gap_map = torch.from_numpy(np.float32(ds.aso_gap_map.values))
        rtc_gap_map = torch.from_numpy(np.float32(ds.rtc_gap_map.values))
        rtc_mean_gap_map = torch.from_numpy(np.float32(ds.rtc_mean_gap_map.values))
      
        # calculate some other inputs for our CNN
        ndvi = (nir - red)/(nir + red)
        ndsi = (green - swir1)/(green + swir1)
        ndwi = (green - nir)/(green + nir)

        # snowon_ratio = (snowon_vv - snowon_vh)/(snowon_vv + snowon_vh)
        # snowoff_ratio = (snowoff_vv - snowoff_vh)/(snowoff_vv + snowoff_vh)

        # fn = os.path.split(path)[-1]
        # dowy_1d = calc_dowy(pd.to_datetime(fn.split('_')[4]).dayofyear)
        # dowy = torch.full_like(aso_sd, dowy_1d)
        
        # normalize layers (except gap maps and fcf)
        if self.norm == True:
            aso_sd = torch.nan_to_num(calc_norm(aso_sd, self.norm_dict['aso_sd']), 0)
            snowon_vv = calc_norm(snowon_vv, self.norm_dict['vv'])
            snowon_vh = calc_norm(snowon_vh, self.norm_dict['vh'])
            snowoff_vv = calc_norm(snowoff_vv, self.norm_dict['vv'])
            snowoff_vh = calc_norm(snowoff_vh, self.norm_dict['vh'])
            snowon_vv_mean = calc_norm(snowon_vv_mean, self.norm_dict['vv'])
            snowon_vh_mean = calc_norm(snowon_vh_mean, self.norm_dict['vh'])
            snowoff_vv_mean = calc_norm(snowoff_vv_mean, self.norm_dict['vv'])
            snowoff_vh_mean = calc_norm(snowoff_vh_mean, self.norm_dict['vh'])
            aerosol_optical_thickness = calc_norm(aerosol_optical_thickness, self.norm_dict['AOT'])
            coastal_aerosol = calc_norm(coastal_aerosol, self.norm_dict['coastal'])
            blue = calc_norm(blue, self.norm_dict['blue'])
            green = calc_norm(green, self.norm_dict['green'])
            red = calc_norm(red, self.norm_dict['red'])
            red_edge1 = calc_norm(red_edge1, self.norm_dict['red_edge1'])
            red_edge2 = calc_norm(red_edge2, self.norm_dict['red_edge2'])
            red_edge3 = calc_norm(red_edge3, self.norm_dict['red_edge3'])
            nir = calc_norm(nir, self.norm_dict['nir'])
            water_vapor = calc_norm(water_vapor, self.norm_dict['water_vapor'])
            swir1 = calc_norm(swir1, self.norm_dict['swir1'])
            swir2 = calc_norm(swir2, self.norm_dict['swir2'])
            scene_class_map = calc_norm(scene_class_map, self.norm_dict['scene_class_map'])
            water_vapor_product = calc_norm(water_vapor_product, self.norm_dict['water_vapor_product'])
            elevation = calc_norm(elevation, self.norm_dict['elevation'])
            ndvi = torch.nan_to_num(calc_norm(ndvi, [-1, 1]), 0)
            ndsi = torch.nan_to_num(calc_norm(ndsi, [-1, 1]), 0)
            ndwi = torch.nan_to_num(calc_norm(ndwi, [-1, 1]), 0)
            # snowon_ratio = torch.nan_to_num(calc_norm(snowon_ratio, [-1, 1]), 0)
            # snowoff_ratio = torch.nan_to_num(calc_norm(snowoff_ratio, [-1, 1]), 0)
            #dowy = calc_norm(dowy, self.norm_dict['dowy'])
            
        
        # return only selected bands, for now
        return aso_sd[None, :, :], snowon_vv[None, :, :], snowon_vh[None, :, :], snowoff_vv[None, :, :], snowoff_vh[None, :, :], blue[None, :, :], green[None, :, :], red[None, :, :], fcf[None, :, :], elevation[None, :, :], aso_gap_map[None, :, :], rtc_gap_map[None, :, :], ndvi[None, :, :], ndsi[None, :, :], ndwi[None, :, :]

In [7]:
# create dataloaders
train_data = dataset(train_path_list, norm_dict, norm=True)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=16, shuffle=True)
val_data = dataset(val_path_list, norm_dict, norm=True)
val_loader = torch.utils.data.DataLoader(dataset=val_data, batch_size=1, shuffle=True)

In [8]:
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels, patch_size, emb_size):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)  # [B, emb_size, H/patch_size, W/patch_size]
        x = x.flatten(2)  # [B, emb_size, N]
        x = x.transpose(1, 2)  # [B, N, emb_size]
        return x

class TransformerEncoder(nn.Module):
    def __init__(self, emb_size, num_heads, ff_hidden_mult=4, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        self.ln1 = nn.LayerNorm(emb_size)
        self.mha = nn.MultiheadAttention(emb_size, num_heads, dropout=dropout)
        self.ln2 = nn.LayerNorm(emb_size)
        self.ff = nn.Sequential(
            nn.Linear(emb_size, ff_hidden_mult * emb_size),
            nn.GELU(),
            nn.Linear(ff_hidden_mult * emb_size, emb_size),
            nn.Dropout(dropout)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = x + self.dropout(self.mha(self.ln1(x), self.ln1(x), self.ln1(x))[0])
        x = x + self.dropout(self.ff(self.ln2(x)))
        return x

class VisionTransformer(nn.Module):
    def __init__(self, in_channels, patch_size, emb_size, num_layers, num_heads, img_size):
        super(VisionTransformer, self).__init__()
        self.patch_embedding = PatchEmbedding(in_channels, patch_size, emb_size)
        self.pos_embedding = nn.Parameter(torch.zeros(1, (img_size // patch_size) ** 2, emb_size))
        self.transformer_encoders = nn.ModuleList([
            TransformerEncoder(emb_size, num_heads) for _ in range(num_layers)
        ])
        self.ln = nn.LayerNorm(emb_size)
        self.fc = nn.Linear(emb_size, patch_size * patch_size)

    def forward(self, x):
        B, C, H, W = x.size()
        x = self.patch_embedding(x)
        x = x + self.pos_embedding
        for encoder in self.transformer_encoders:
            x = encoder(x)
        x = self.ln(x)
        x = self.fc(x)
        x = x.transpose(1, 2).reshape(B, -1, H, W)
        return x

In [8]:
model = VisionTransformer(in_channels=12, patch_size=16, emb_size=256, num_layers=6, num_heads=8, img_size=128)
model.to('cuda')  # Run on GPU

# Define optimizer and loss function
optimizer = optim.AdamW(model.parameters(), lr=0.0003)
loss_fn = nn.MSELoss()
epochs = 50

train_loss = []
val_loss = []

for epoch in range(epochs):
    print(f'\nStarting epoch {epoch}')
    epoch_loss = []
    val_temp_loss = []

    # Loop through training data
    for (aso_sd, snowon_vv, snowon_vh, snowoff_vv, snowoff_vh, blue, green, red, fcf, elevation, aso_gap_map, rtc_gap_map, ndvi, ndsi, ndwi) in train_loader:
        model.train()
        optimizer.zero_grad()
        
        # Concatenate all feature channels
        inputs = torch.cat((snowon_vv, snowon_vh, snowoff_vv, snowoff_vh, blue, green, red, fcf, elevation, ndvi, ndsi, ndwi), dim=1).to('cuda')
        pred_sd = torch.clamp(model(inputs), 0, 1)  # Generate predictions

        # Limit prediction to areas with valid data
        pred_sd = torch.where(aso_gap_map.to('cuda') + rtc_gap_map.to('cuda') == 0, pred_sd, torch.zeros_like(pred_sd).to('cuda'))
        aso_sd = torch.where(aso_gap_map.to('cuda') + rtc_gap_map.to('cuda') == 0, aso_sd.to('cuda'), torch.zeros_like(pred_sd).to('cuda'))

        # Calculate loss
        loss = loss_fn(pred_sd, aso_sd.to('cuda'))
        epoch_loss.append(loss.item())

        loss.backward()  # Propagate the gradients in backward pass
        optimizer.step()

    train_loss.append(np.mean(epoch_loss))
    print(f'Training loss: {np.mean(epoch_loss)}')

    # Run model on validation data
    for (aso_sd, snowon_vv, snowon_vh, snowoff_vv, snowoff_vh, blue, green, red, fcf, elevation, aso_gap_map, rtc_gap_map, ndvi, ndsi, ndwi) in val_loader:
        with torch.no_grad():
            model.eval()
            
            # Concatenate all feature channels
            inputs = torch.cat((snowon_vv, snowon_vh, snowoff_vv, snowoff_vh, blue, green, red, fcf, elevation, ndvi, ndsi, ndwi), dim=1).to('cuda')
            pred_sd = torch.clamp(model(inputs), 0, 1)  # Generate predictions

            pred_sd = torch.where(aso_gap_map.to('cuda') + rtc_gap_map.to('cuda') == 0, pred_sd, torch.zeros_like(pred_sd).to('cuda'))
            aso_sd = torch.where(aso_gap_map.to('cuda') + rtc_gap_map.to('cuda') == 0, aso_sd.to('cuda'), torch.zeros_like(pred_sd).to('cuda'))

            loss = loss_fn(pred_sd.to('cuda'), aso_sd.to('cuda'))
            val_temp_loss.append(loss.item())

    val_loss.append(np.mean(val_temp_loss))
    print(f'Validation loss: {np.mean(val_temp_loss)}')
    torch.save(model.state_dict(), f'../../weights/ViT_v0')


Starting epoch 0
Training loss: 0.0030534321591835665
Validation loss: 0.0027223536457131325

Starting epoch 1
Training loss: 0.0029332178519314066


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f580520e590>>
Traceback (most recent call last):
  File "/mnt/Backups/gbrench/sw/miniconda3/envs/crunchy-snow/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 
Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f580520e590>>
Traceback (most recent call last):
  File "/mnt/Backups/gbrench/sw/miniconda3/envs/crunchy-snow/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


KeyboardInterrupt: 

In [None]:
#plot loss over all epochs
f, ax = plt.subplots(figsize=(10,5))
ax.plot(train_loss, label='training')
ax.plot(val_loss, label='validation')
ax.set_xlabel('epoch')
ax.set_ylabel('L1 loss')
ax.set_title('Loss')
ax.legend()
plt.savefig('../../figs/Vit_v0_loss1.png', dpi=300)

In [None]:
# Visualize model outputs
num_images = 1

for i, (aso_sd, snowon_vv, snowon_vh, snowoff_vv, snowoff_vh, blue, green, red, fcf, elevation, aso_gap_map, rtc_gap_map, ndvi, ndsi, ndwi) in enumerate(val_loader):
    if i < num_images:
        with torch.no_grad():
            # Concatenate all feature channels
            inputs = torch.cat((snowon_vv, snowon_vh, snowoff_vv, snowoff_vh, blue, green, red, fcf, elevation, ndvi, ndsi, ndwi), dim=1).to('cuda')
            pred_sd = model(inputs)  # Generate predictions using the model
            pred_sd = pred_sd.to('cpu')
            
            f, ax = plt.subplots(3, 3, figsize=(15, 15))
            ax[0, 0].imshow(pred_sd.squeeze(), cmap='Blues', vmin=0, vmax=0.4, interpolation=None)
            ax[0, 0].set_title('Predicted Snow Depth')
            ax[0, 1].imshow(aso_sd.squeeze(), cmap='Blues', vmin=0, vmax=0.4, interpolation=None)
            ax[0, 1].set_title('ASO Lidar Snow Depth')
            ax[0, 2].imshow(elevation.squeeze(), cmap='viridis', interpolation='none')
            ax[0, 2].set_title('Copernicus DEM')
            ax[1, 0].imshow(fcf.squeeze(), cmap='Greens', interpolation='none')
            ax[1, 0].set_title('Fractional Forest Cover')
            norm_max = np.max([green.max(), red.max(), blue.max()])  # There are better ways to do this
            ax[1, 1].imshow(torch.cat((red.squeeze()[:, :, None] / norm_max, green.squeeze()[:, :, None] / norm_max, blue.squeeze()[:, :, None] / norm_max), 2).squeeze(), interpolation='none')
            ax[1, 1].set_title('True Color Image')
            ax[1, 2].imshow(aso_gap_map.squeeze() + rtc_gap_map.squeeze(), cmap='Purples', interpolation='none')
            ax[1, 2].set_title('ASO and RTC Gaps')
            ax[2, 0].imshow(ndvi.squeeze(), cmap='YlGn', interpolation='none')
            ax[2, 0].set_title('NDVI')
            ax[2, 1].imshow(ndsi.squeeze(), cmap='BuPu', interpolation='none')
            ax[2, 1].set_title('NDSI')
            ax[2, 2].imshow(ndwi.squeeze(), cmap='YlGnBu', interpolation='none')
            ax[2, 2].set_title('NDWI')

            plt.tight_layout()
            # plt.savefig(f'pred_raw{i}.png', dpi=300)
    else:
        break