In [1]:
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import torch
from glob import glob
import seaborn as sns
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import pickle
import random
import pandas as pd

import crunchy_snow.models
import crunchy_snow.dataset

In [2]:
# get paths to data
train_data_dir = '/mnt/Backups/gbrench/repos/crunchy-snow/data/subsets_v3/train'
train_path_list = glob(f'{train_data_dir}/ASO_50M_SD*.nc')

val_data_dir = '/mnt/Backups/gbrench/repos/crunchy-snow/data/subsets_v3/val'
val_path_list = glob(f'{val_data_dir}/ASO_50M_SD*.nc')

In [3]:
# define data to be returned by dataloader
selected_channels = [
    # ASO products
    'aso_sd', # ASO lidar snow depth (target dataset)
    'aso_gap_map', # gaps in ASO data
    
    # Sentinel-1 products
    'snowon_vv', # snow on Sentinel-1 VV polarization backscatter in dB, closest acquisition to ASO acquisition
    'snowon_vh', # snow on Sentinel-1 VH polarization backscatter in dB, closest acquisition to ASO acquisition
    'snowoff_vv', # snow off Sentinel-1 VV polarization backscatter in dB, closest acquisition to ASO acquisition
    'snowoff_vh', # snow off Sentinel-1 VH polarization backscatter in dB, closest acquisition to ASO acquisition
    'snowon_vv_mean', # snow on Sentinel-1 VV polarization backscatter in dB, mean of acquisition in 4 week period around ASO acquisition
    'snowon_vh_mean', # snow on Sentinel-1 VH polarization backscatter in dB, mean of acquisition in 4 week period around ASO acquisition
    'snowoff_vv_mean', # snow off Sentinel-1 VV polarization backscatter in dB, mean of acquisition in 4 week period around ASO acquisition
    'snowoff_vh_mean', # snow off Sentinel-1 VH polarization backscatter in dB, mean of acquisition in 4 week period around ASO acquisition
    'snowon_cr', # cross ratio, snowon_vh - snowon_vv
    'snowoff_cr', # cross ratio, snowoff_vh - snowoff_vv
    'delta_cr', # change in cross ratio, snowon_cr - snowoff_cr
    'rtc_gap_map', # gaps in Sentinel-1 data
    'rtc_mean_gap_map', # gaps in Sentinel-1 mean data
    
    # Sentinel-2 products 
    'aerosol_optical_thickness', # snow on Sentinel-2 aerosol optical thickness band 
    'coastal_aerosol', # snow on Sentinel-2 coastal aerosol band
    'blue', # snow on Sentinel-2 blue band
    'green', # snow on Sentinel-2 green band
    'red', # snow on Sentinel-2 red band
    'red_edge1', # snow on Sentinel-2 red edge 1 band
    'red_edge2', # snow on Sentinel-2 red edge 2 band
    'red_edge3', # snow on Sentinel-2 red edge 3 band
    'nir', # snow on Sentinel-2 near infrared band
    'water_vapor', # snow on Sentinel-2 water vapor
    'swir1', # snow on Sentinel-2 shortwave infrared band 1
    'swir2', # snow on Sentinel-2 shortwave infrared band 2
    'scene_class_map', # snow on Sentinel-2 scene classification product
    'water_vapor_product', # snow on Sentinel-2 water vapor product
    'ndvi', # Normalized Difference Vegetation Index from Sentinel-2
    'ndsi', # Normalized Difference Snow Index from Sentinel-2
    'ndwi', # Normalized Difference Water Index from Sentinel-2
    's2_gap_map', # gaps in Sentinel-2 data

    # PROBA-V global land cover dataset (Buchhorn et al., 2020)
    'fcf', # fractional forest cover
    
    # COP30 digital elevation model      
    'elevation',

    # latitude and longitude
    'latitude',
    'longitude',

    # day of water year
    'dowy'
                    ]

In [4]:
def train_model(input_channels, epochs, n_layers, max_kernel):
    model = crunchy_snow.models.ResDepth(n_input_channels=len(input_channels), depth=n_layers, max_filter_depth=max_kernel)
    model.to('cuda');  # Run on GPU
    # Define optimizer and loss function
    optimizer = optim.AdamW(model.parameters(), lr=0.0003)
    loss_fn = nn.MSELoss()
    epochs = epochs
    
    train_loss = []
    val_loss = []
    
    # training and validation loop
    for epoch in range(epochs):
        print(f'\nStarting epoch {epoch+1}')
        epoch_loss = []
        val_temp_loss = []
    
        # Loop through training data with tqdm progress bar
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch", ncols=130)
        for data_tuple in pbar:
            model.train()
            optimizer.zero_grad()
    
            # read data into dictionary
            data_dict = {name: tensor for name, tensor in zip(selected_channels, data_tuple)}
            # prepare inputs by concatenating along channel dimension
            inputs = torch.cat([data_dict[channel] for channel in input_channels], dim=1).to('cuda')
    
            # generate prediction
            pred_sd = model(inputs)
    
            # Limit prediction to areas with valid data
            pred_sd = torch.where(data_dict['aso_gap_map'].to('cuda') + data_dict['rtc_gap_map'].to('cuda') + data_dict['s2_gap_map'].to('cuda') == 0, pred_sd, torch.zeros_like(pred_sd).to('cuda'))
            aso_sd = torch.where(data_dict['aso_gap_map'].to('cuda') + data_dict['rtc_gap_map'].to('cuda') + data_dict['s2_gap_map'].to('cuda')== 0, data_dict['aso_sd'].to('cuda'), torch.zeros_like(pred_sd).to('cuda'))
    
            # Calculate loss
            loss = loss_fn(pred_sd, aso_sd.to('cuda'))
            epoch_loss.append(loss.item())
    
            # Update tqdm progress bar with batch loss
            pbar.set_postfix({'batch loss': loss.item(), 'mean epoch loss': np.mean(epoch_loss)})
    
    
            loss.backward()  # Propagate the gradients in backward pass
            optimizer.step()
    
        train_loss.append(np.mean(epoch_loss))
        print(f'Training loss: {np.mean(epoch_loss)}')
    
        # Run model on validation data with tqdm progress bar
        for data_tuple in tqdm(val_loader, desc="Validation", unit="batch"):
            with torch.no_grad():
                model.eval()
                
                # read data into dictionary
                data_dict = {name: tensor for name, tensor in zip(selected_channels, data_tuple)}
                # prepare inputs by concatenating along channel dimension
                inputs = torch.cat([data_dict[channel] for channel in input_channels], dim=1).to('cuda')
        
                # generate prediction
                pred_sd = model(inputs)
        
                # Limit prediction to areas with valid data
                pred_sd = torch.where(data_dict['aso_gap_map'].to('cuda') + data_dict['rtc_gap_map'].to('cuda') == 0, pred_sd, torch.zeros_like(pred_sd).to('cuda'))
                aso_sd = torch.where(data_dict['aso_gap_map'].to('cuda') + data_dict['rtc_gap_map'].to('cuda') == 0, data_dict['aso_sd'].to('cuda'), torch.zeros_like(pred_sd).to('cuda'))
        
                # Calculate loss
                loss = loss_fn(pred_sd, aso_sd.to('cuda'))
                val_temp_loss.append(loss.item())
    
        val_loss.append(np.mean(val_temp_loss))
        print(f'Validation loss: {np.mean(val_temp_loss)}')

    return np.min(train_loss), np.min(val_loss)

In [5]:
# define input channels for model
all_input_channels = [
    'snowon_vv',
    'snowon_vh',
    'snowoff_vv',
    'snowoff_vh',
    'snowon_cr',
    'snowoff_cr',
    'delta_cr',
    'blue',
    'green',
    'red',
    'nir',
    'swir1',
    'swir2',
    'scene_class_map',
    'ndvi',
    'ndsi',
    'ndwi',
    'fcf',
    'elevation',
    'latitude',
    'longitude',
    'dowy']

In [7]:
# subsample dataset
n_imgs = 16
train_path_list = random.sample(train_path_list, n_imgs)
val_path_list = random.sample(val_path_list, n_imgs)
# prepare training and validation dataloaders
train_data = crunchy_snow.dataset.Dataset(train_path_list, selected_channels, norm=True)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=16, shuffle=True)
val_data = crunchy_snow.dataset.Dataset(val_path_list, selected_channels, norm=True)
val_loader = torch.utils.data.DataLoader(dataset=val_data, batch_size=16, shuffle=True)

In [None]:
num_trials = 100
exp_dict = {}
for trial in range(num_trials):
    print('---------------------------------------------------------')
    print(f'trial {trial+1}/{num_trials}')
    input_channels = random.sample(all_input_channels, 11)
    print(f'trial {trial+1} input channels: {input_channels}')
    # # subsample dataset
    # train_path_list = random.sample(train_path_list, n_imgs)
    # val_path_list = random.sample(val_path_list, n_imgs)
    # # prepare training and validation dataloaders
    # train_data = crunchy_snow.dataset.Dataset(train_path_list, selected_channels, norm=True)
    # train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=16, shuffle=True)
    # val_data = crunchy_snow.dataset.Dataset(val_path_list, selected_channels, norm=True)
    # val_loader = torch.utils.data.DataLoader(dataset=val_data, batch_size=16, shuffle=True)
    # train model
    final_train_loss, final_val_loss = train_model(input_channels, epochs=20, n_layers=5, max_kernel=1024)
    print(f'trial {trial+1} final train loss: {final_train_loss}, final val loss: {final_val_loss}')
    exp_dict[trial+1] = [input_channels, final_train_loss, final_val_loss]
    # save experiments 
    with open(f'../../loss/ResDepth_feature_sel_loss_v1.pkl', 'wb') as f:
        pickle.dump(exp_dict, f)

---------------------------------------------------------
trial 1/100
trial 1 input channels: ['fcf', 'swir1', 'snowoff_vh', 'ndwi', 'snowoff_cr', 'snowon_vh', 'snowoff_vv', 'ndsi', 'snowon_cr', 'elevation', 'swir2']

Starting epoch 1


Epoch 1/20: 100%|█████████████████████████████████████| 1/1 [00:04<00:00,  4.96s/batch, batch loss=0.0139, mean epoch loss=0.0139]


Training loss: 0.01386793702840805


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.48s/batch]


Validation loss: 0.006122227292507887

Starting epoch 2


Epoch 2/20: 100%|███████████████████████████████████| 1/1 [00:03<00:00,  3.42s/batch, batch loss=0.00775, mean epoch loss=0.00775]


Training loss: 0.007752183824777603


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.88s/batch]


Validation loss: 0.003418389242142439

Starting epoch 3


Epoch 3/20: 100%|███████████████████████████████████| 1/1 [00:03<00:00,  3.43s/batch, batch loss=0.00495, mean epoch loss=0.00495]


Training loss: 0.004946465604007244


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.44s/batch]


Validation loss: 0.003097961423918605

Starting epoch 4


Epoch 4/20: 100%|███████████████████████████████████| 1/1 [00:03<00:00,  3.43s/batch, batch loss=0.00483, mean epoch loss=0.00483]


Training loss: 0.004831929225474596


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.44s/batch]


Validation loss: 0.0038175168447196484

Starting epoch 5


Epoch 5/20: 100%|███████████████████████████████████| 1/1 [00:03<00:00,  3.45s/batch, batch loss=0.00579, mean epoch loss=0.00579]


Training loss: 0.005793564021587372


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.43s/batch]


Validation loss: 0.004116851836442947

Starting epoch 6


Epoch 6/20: 100%|███████████████████████████████████| 1/1 [00:03<00:00,  3.45s/batch, batch loss=0.00615, mean epoch loss=0.00615]


Training loss: 0.006153309252113104


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.43s/batch]


Validation loss: 0.003818458877503872

Starting epoch 7


Epoch 7/20: 100%|███████████████████████████████████| 1/1 [00:03<00:00,  3.44s/batch, batch loss=0.00574, mean epoch loss=0.00574]


Training loss: 0.005737831816077232


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.44s/batch]


Validation loss: 0.003265687730163336

Starting epoch 8


Epoch 8/20: 100%|███████████████████████████████████| 1/1 [00:03<00:00,  3.43s/batch, batch loss=0.00498, mean epoch loss=0.00498]


Training loss: 0.00497992942109704


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.42s/batch]


Validation loss: 0.002762323711067438

Starting epoch 9


Epoch 9/20: 100%|███████████████████████████████████| 1/1 [00:03<00:00,  3.45s/batch, batch loss=0.00425, mean epoch loss=0.00425]


Training loss: 0.004253816790878773


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.43s/batch]


Validation loss: 0.0024532070383429527

Starting epoch 10


Epoch 10/20: 100%|██████████████████████████████████| 1/1 [00:03<00:00,  3.44s/batch, batch loss=0.00374, mean epoch loss=0.00374]


Training loss: 0.003738784696906805


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.43s/batch]


Validation loss: 0.0023637106642127037

Starting epoch 11


Epoch 11/20: 100%|██████████████████████████████████| 1/1 [00:03<00:00,  3.50s/batch, batch loss=0.00348, mean epoch loss=0.00348]


Training loss: 0.003479021368548274


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.15s/batch]


Validation loss: 0.002428196370601654

Starting epoch 12


Epoch 12/20: 100%|██████████████████████████████████| 1/1 [00:02<00:00,  2.74s/batch, batch loss=0.00341, mean epoch loss=0.00341]


Training loss: 0.003411351703107357


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.45s/batch]


Validation loss: 0.0025427136570215225

Starting epoch 13


Epoch 13/20: 100%|██████████████████████████████████| 1/1 [00:03<00:00,  3.43s/batch, batch loss=0.00342, mean epoch loss=0.00342]


Training loss: 0.003424044931307435


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.66s/batch]


Validation loss: 0.002614496275782585

Starting epoch 14


Epoch 14/20: 100%|██████████████████████████████████| 1/1 [00:03<00:00,  3.44s/batch, batch loss=0.00342, mean epoch loss=0.00342]


Training loss: 0.0034181978553533554


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.42s/batch]


Validation loss: 0.002589217619970441

Starting epoch 15


Epoch 15/20: 100%|██████████████████████████████████| 1/1 [00:03<00:00,  3.44s/batch, batch loss=0.00334, mean epoch loss=0.00334]


Training loss: 0.003335281740874052


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.43s/batch]


Validation loss: 0.0024556387215852737

Starting epoch 16


Epoch 16/20: 100%|██████████████████████████████████| 1/1 [00:03<00:00,  3.43s/batch, batch loss=0.00316, mean epoch loss=0.00316]


Training loss: 0.0031632566824555397


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.44s/batch]


Validation loss: 0.002239635679870844

Starting epoch 17


Epoch 17/20: 100%|██████████████████████████████████| 1/1 [00:03<00:00,  3.42s/batch, batch loss=0.00293, mean epoch loss=0.00293]


Training loss: 0.0029324369970709085


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.42s/batch]


Validation loss: 0.001994190737605095

Starting epoch 18


Epoch 18/20: 100%|████████████████████████████████████| 1/1 [00:03<00:00,  3.44s/batch, batch loss=0.0027, mean epoch loss=0.0027]


Training loss: 0.0026992796920239925


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.43s/batch]


Validation loss: 0.0017766420496627688

Starting epoch 19


Epoch 19/20: 100%|██████████████████████████████████| 1/1 [00:03<00:00,  3.44s/batch, batch loss=0.00252, mean epoch loss=0.00252]


Training loss: 0.0025195854250341654


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.44s/batch]


Validation loss: 0.0016262579010799527

Starting epoch 20


Epoch 20/20: 100%|██████████████████████████████████| 1/1 [00:03<00:00,  3.44s/batch, batch loss=0.00242, mean epoch loss=0.00242]


Training loss: 0.002424437552690506


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.43s/batch]


Validation loss: 0.0015477908309549093
trial 1 final train loss: 0.002424437552690506, final val loss: 0.0015477908309549093
---------------------------------------------------------
trial 2/100
trial 2 input channels: ['blue', 'red', 'elevation', 'snowoff_vv', 'dowy', 'snowoff_cr', 'snowon_cr', 'snowon_vh', 'scene_class_map', 'fcf', 'swir2']

Starting epoch 1


Epoch 1/20: 100%|███████████████████████████████████| 1/1 [00:03<00:00,  3.44s/batch, batch loss=0.00914, mean epoch loss=0.00914]


Training loss: 0.009136082604527473


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.43s/batch]


Validation loss: 0.0027455275412648916

Starting epoch 2


Epoch 2/20: 100%|███████████████████████████████████| 1/1 [00:03<00:00,  3.46s/batch, batch loss=0.00392, mean epoch loss=0.00392]


Training loss: 0.003918749745935202


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.42s/batch]


Validation loss: 0.002608755137771368

Starting epoch 3


Epoch 3/20: 100%|███████████████████████████████████| 1/1 [00:03<00:00,  3.44s/batch, batch loss=0.00374, mean epoch loss=0.00374]


Training loss: 0.0037361502181738615


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.43s/batch]


Validation loss: 0.003588136751204729

Starting epoch 4


Epoch 4/20: 100%|███████████████████████████████████| 1/1 [00:03<00:00,  3.45s/batch, batch loss=0.00482, mean epoch loss=0.00482]


Training loss: 0.004823969677090645


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.43s/batch]


Validation loss: 0.0034565534442663193

Starting epoch 5


Epoch 5/20: 100%|███████████████████████████████████| 1/1 [00:03<00:00,  3.46s/batch, batch loss=0.00465, mean epoch loss=0.00465]


Training loss: 0.004645613022148609


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.43s/batch]


Validation loss: 0.002666451735422015

Starting epoch 6


Epoch 6/20: 100%|███████████████████████████████████| 1/1 [00:03<00:00,  3.45s/batch, batch loss=0.00368, mean epoch loss=0.00368]


Training loss: 0.0036839160602539778


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.44s/batch]


Validation loss: 0.001977604813873768

Starting epoch 7


Epoch 7/20: 100%|█████████████████████████████████████| 1/1 [00:03<00:00,  3.46s/batch, batch loss=0.0028, mean epoch loss=0.0028]


Training loss: 0.002796629909425974


Validation: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.44s/batch]


Validation loss: 0.0017495134379714727

Starting epoch 8


Epoch 8/20: 100%|█████████████████████████████████████| 1/1 [00:03<00:00,  3.47s/batch, batch loss=0.0024, mean epoch loss=0.0024]


Training loss: 0.0024011300411075354


Validation:   0%|                                                                              | 0/1 [00:00<?, ?batch/s]

In [None]:
with open(f'../../loss/ResDepth_feature_sel_loss.pkl', 'rb') as f:
        exp_dict = pickle.load(f)

In [None]:
train_channel_performance = {}
val_channel_performance = {}
all_values = {}
for channel in all_input_channels:
    train_loss_w_channel = []
    val_loss_w_channel = []
    train_loss_wo_channel = []
    val_loss_wo_channel = []
    for trial in exp_dict.keys():
        if channel in exp_dict[trial][0]:
            train_loss_w_channel.append(exp_dict[trial][1])
            val_loss_w_channel.append(exp_dict[trial][2])
        else:
            train_loss_wo_channel.append(exp_dict[trial][1])
            val_loss_wo_channel.append(exp_dict[trial][2])
    train_loss_diff = np.mean(train_loss_wo_channel) - np.mean(train_loss_w_channel)
    val_loss_diff = np.mean(val_loss_wo_channel) - np.mean(val_loss_w_channel)
    train_channel_performance[channel] = train_loss_diff
    val_channel_performance[channel] = val_loss_diff
    all_values[channel] = [train_loss_w_channel, train_loss_wo_channel, val_loss_w_channel, val_loss_wo_channel]

df = pd.DataFrame({
    'channels': train_channel_performance.keys(),
    'train_loss_diff': train_channel_performance.values(),
    'val_loss_diff': val_channel_performance.values()
})

df = df.sort_values('val_loss_diff', ascending=False)

In [None]:
# Find the maximum length of the lists
max_len = max(max(len(lst) for lst in metrics) for metrics in all_values.values())

# Convert dictionary to DataFrame
rows = []
for channel, metrics in all_values.items():
    for i in range(max_len):
        row = [channel]
        for metric in metrics:
            if i < len(metric):
                row.append(metric[i])
            else:
                row.append(np.nan)
        rows.append(row)

df = pd.DataFrame(rows, columns=['channel', 'train_mse_w', 'train_mse_wo', 'val_mse_w', 'val_mse_wo'])

In [None]:
from scipy.stats import mannwhitneyu
results = []

# Perform Mann–Whitney U tests for each channel
for channel in df['channel'].unique():
    channel_data = df[df['channel'] == channel]
    metric_1 = channel_data['train_mse_w'].dropna()
    metric_2 = channel_data['train_mse_wo'].dropna()
    metric_3 = channel_data['val_mse_w'].dropna()
    metric_4 = channel_data['val_mse_wo'].dropna()
    
    # Mann–Whitney U test for metrics 1 and 2
    stat_train, p_train = mannwhitneyu(metric_1, metric_2)
    
    # Mann–Whitney U test for metrics 3 and 4
    stat_val, p_val = mannwhitneyu(metric_3, metric_4)
    
    # Append results to the list
    results.append([channel, p_train, p_val])

# Convert results to DataFrame
results_df = pd.DataFrame(results, columns=['channel', 'train', 'val'])
print(results_df)

In [None]:
# Create subplots
f, ax = plt.subplots(nrows=2, ncols=1, figsize=(10, 8))

# Top plot: Performance metrics 1 and 2
sns.boxplot(x='channel', y='value', hue='metric', data=pd.melt(df, id_vars=['channel'], value_vars=['train_mse_w', 'train_mse_wo'], var_name='metric'), ax=ax[0])
ax[0].set_title('training performance with and without channel')
ax[0].set_ylabel('MSE')
ax[0].set_xlabel('')
ax[0].tick_params(axis='x', rotation=90)
ax[0].set_ylim(None, 0.0010)

# Bottom plot: Performance metrics 3 and 4
sns.boxplot(x='channel', y='value', hue='metric', data=pd.melt(df, id_vars=['channel'], value_vars=['val_mse_w', 'val_mse_wo'], var_name='metric'), ax=ax[1])
ax[1].set_title('validation performance with and without channel')
ax[1].set_ylabel('MSE')
ax[1].set_xlabel('')
ax[1].tick_params(axis='x', rotation=90)
ax[1].set_ylim(None, 0.0010)

# Adjust layout
plt.tight_layout()

In [None]:
# Set the position of the bars on the x-axis
bar_width = 0.35
r1 = np.arange(len(df['channels']))
r2 = [x + bar_width for x in r1]

# Create the plot
fig, ax = plt.subplots(figsize=(10, 3))
ax.bar(r1, df['train_loss_diff'], color='blue', width=bar_width, edgecolor='grey', label='change in training MSE')
ax.bar(r2, df['val_loss_diff'], color='orange', width=bar_width, edgecolor='grey', label='change in validation MSE')
# Add labels
ax.set_xlabel('channels')
ax.set_xticks([r + bar_width/2 for r in range(len(df['channels']))])
ax.set_xticklabels(df['channels'], rotation=90)
ax.set_ylabel('change in MSE loss', fontweight='bold')
ax.set_title('mean change in final MSE when channels are included')
# Add legend
ax.legend()

In [None]:
# Create the plot
# Determine the min and max values for the axes
min_value = min(min(df['val_loss_diff']), min(df['train_loss_diff']))
max_value = max(max(df['val_loss_diff']), max(df['train_loss_diff']))

fig, ax = plt.subplots()
ax.axvline(0, alpha=0.3)
ax.axhline(0, alpha=0.3)
ax.plot([min_value, max_value], [min_value, max_value], color='grey', linestyle='--', linewidth=1)

for channel in df['channels']:
    ax.scatter(df.loc[df['channels'] == channel, 'val_loss_diff'], df.loc[df['channels'] == channel, 'train_loss_diff'], label=channel)

# Add labels
ax.set_xlabel('change in validation MSE', fontweight='bold')
ax.set_ylabel('change in training MSE', fontweight='bold')
ax.set_title('mean improvement in final MSE when channels are included')
ax.set_aspect('equal')
padding = 0.00001
ax.set_xlim(min_value-padding, max_value+padding)
ax.set_ylim(min_value-padding, max_value+padding)

# Add legend
ax.legend(title='channels', bbox_to_anchor=(1.05, 1), loc='upper left')