### Quick Note
I originally wanted to publish this while the competition was running, but I only joined in the last week and got a warning that notebooks can no longer be made public this late.
I hope it still provides some interesting information.

My own submissions used the recommended band 08 in addition to the Ash channels, but I did not have time to verify whether it gave a substantial advantage over Ash alone. I am now somewhat skeptical of that, since none of the top solutions posted so far appear to use it.

# Summary and Key Takeaway
Using the Ash false color image as the base input seems like a good choice, but I expect most people are doing that already.
Simply adding bands 11 to 16 on top of that base seems like a bad idea, as their information is almost completely contained in the blue channel already.
I would recommend trying band 08 (normalized) as a 4th input channel: it is much less correlated with the Ash colors and fairly well correlated with the contrails, so it could add new and useful information. An implementation is here: [Contrails Dataset (Ash+08)](https://www.kaggle.com/code/raki21/contrails-dataset-ash-band08-with-soft-mask). I would not add all three water vapor bands (08, 09 and 10), as they are highly correlated with each other, and channels that contribute no additional information may increase the opportunity for overfitting.

# Information Bands
Much of the information about the bands is already covered in other EDAs, especially [this EDA](https://www.kaggle.com/code/pranavnadimpali/comprehensive-eda-submission). A short summary:

The Advanced Baseline Imager (ABI) is a key instrument on the Geostationary Operational Environmental Satellite (GOES) series. It captures images of the Earth using 16 spectral bands, with each band focusing on a specific wavelength. This provides a wealth of information about the Earth's atmosphere, clouds, land, and water, significantly improving weather analysis and forecasting. In this dataset, you have access to 9 bands for each example.

Each band provides a series of 8 images taken at 10-minute intervals, so the frames of one band cover 70 minutes from the first to the last. This temporal data captures how contrails evolve over time. Two types of segmentation masks are provided: `human_pixel_masks`, the consolidated ground truth, and `human_individual_masks`, the annotations of the individual labellers (available for the training set only). The ground truth corresponds to the 5th image of each band.
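To make the layout concrete, here is a small sketch with synthetic arrays. It assumes the shapes this notebook relies on later: each band file is `(H, W, T)` with the 8 frames on the last axis, and `human_individual_masks` is `(H, W, 1, R)` with one slice per labeller. The 256x256 size, `R = 4` and all values are placeholders, not real data.

```python
import numpy as np

# Placeholder sizes: 256x256 pixels, 8 frames, 4 labellers (R varies per record)
H, W, T, R = 256, 256, 8, 4
rng = np.random.default_rng(0)

# One band: brightness temperatures (K) for all 8 frames, time on the last axis
band = rng.normal(250.0, 10.0, size=(H, W, T))

# Individual masks: (H, W, 1, R); here two of the four labellers mark one patch
individual = np.zeros((H, W, 1, R), dtype=np.uint8)
individual[100:110, 50:60, 0, :2] = 1

labelled_frame = band[..., 4]        # the 5th frame, the one the masks refer to
soft_mask = individual.mean(axis=3)  # fraction of labellers per pixel, shape (H, W, 1)
```

Averaging over the labeller axis is what this notebook's code does for the training masks: a pixel marked by half of the labellers gets the value 0.5.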

The image the labellers annotated is a false color image, which is not provided directly but can be generated from the spectral bands. It uses the [Ash color scheme](https://eumetrain.org/sites/default/files/2020-05/RGB_recipes.pdf) built from bands 15, 14 and 11, which makes contrails appear darker and thus easier to detect. The scheme maps brightness temperature differences and values to the three color channels: red is band 15 minus band 14, green is band 14 minus band 11, and blue is band 14 itself.
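The Ash recipe amounts to three normalizations of brightness temperatures (in K). The bounds below are the ones this notebook's code uses further down; the sample pixel values are made up to show the arithmetic, and, as in the notebook's code, nothing is clipped to [0, 1].

```python
import numpy as np

# Normalization bounds as used in this notebook's Ash code
TDIFF_BOUNDS = (-4, 2)            # red:   band 15 - band 14
CLOUD_TOP_TDIFF_BOUNDS = (-4, 5)  # green: band 14 - band 11
T11_BOUNDS = (243, 303)           # blue:  band 14

def normalize_range(data, bounds):
    """Linearly map bounds[0] -> 0 and bounds[1] -> 1 (no clipping)."""
    return (data - bounds[0]) / (bounds[1] - bounds[0])

def ash_rgb(b11, b14, b15):
    """Return an (..., 3) Ash false color array from three brightness-temperature arrays."""
    r = normalize_range(b15 - b14, TDIFF_BOUNDS)
    g = normalize_range(b14 - b11, CLOUD_TOP_TDIFF_BOUNDS)
    b = normalize_range(b14, T11_BOUNDS)
    return np.stack([r, g, b], axis=-1)

# Hypothetical single pixel: band 11 = 275 K, band 14 = 273 K, band 15 = 272 K
pixel = ash_rgb(np.array(275.0), np.array(273.0), np.array(272.0))
# r = (-1 + 4) / 6 = 0.5, g = (-2 + 4) / 9 = 0.222..., b = (273 - 243) / 60 = 0.5
```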


## Correlation Between Bands and Contrails
The contribution of this notebook is an attempt to find bands that are useful for contrail detection.
To do so, we compute the correlation matrix between the bands and the contrail ground truth for each sampled record, then average these matrices over the records.
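A minimal sketch of the per-record computation and of what the averaging does, using hypothetical toy arrays. A record without any contrail pixels has a constant mask, so its band-mask correlations are undefined (NaN), and `np.nan_to_num` turns them into 0 before the average. Because the same records are empty for every band, this scales every band-mask correlation by the same factor (the fraction of records that contain contrails, assuming no band image is constant), so the ranking of the bands is unaffected. `np.nanmean` is shown as one way to average only the defined values.

```python
import numpy as np

def record_corr(band_images, mask):
    """Pearson correlation matrix of flattened band images, with the mask as the last row."""
    rows = [img.ravel() for img in band_images] + [mask.ravel().astype(float)]
    with np.errstate(invalid="ignore", divide="ignore"):
        return np.corrcoef(np.stack(rows))  # a constant (empty) mask gives NaN entries

# Hypothetical 16x16 "band" and two records: one with contrail pixels, one without
band = np.arange(256, dtype=float).reshape(16, 16)
c_with = record_corr([band], band > 200)
c_empty = record_corr([band], np.zeros((16, 16), dtype=bool))

# Averaging as in the notebook: NaN -> 0, then divide by the number of records
zero_filled = (np.nan_to_num(c_with) + np.nan_to_num(c_empty)) / 2
# Alternative: average only the defined entries
nan_ignored = np.nanmean(np.stack([c_with, c_empty]), axis=0)
```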

In [None]:
import os
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.notebook import tqdm

In [None]:
data_dir = '/kaggle/input/google-research-identify-contrails-reduce-global-warming/'

train_rs = os.listdir(data_dir + 'train')
valid_rs = os.listdir(data_dir + 'validation')

train_df = pd.DataFrame(train_rs, columns=['record_id'])
valid_df = pd.DataFrame(valid_rs, columns=['record_id'])

train_df['train'] = 'train'
valid_df['train'] = 'valid'

In [None]:
def plot_correlation_matrix(correlation_matrix, band_labels, band_labels_nickname):
    plt.figure(figsize=(10, 10))
    sns.set(font_scale=0.7)  # small font so the annotated values fit

    # Hide the upper triangle (including the diagonal), which mirrors the lower one
    mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))

    # Diverging colormap from blue (low values) to red (high values)
    cmap = sns.diverging_palette(230, 20, as_cmap=True)

    sns.heatmap(correlation_matrix, annot=True, fmt=".3f", mask=mask, cmap=cmap,
                cbar_kws={"shrink": .5}, xticklabels=band_labels, yticklabels=band_labels_nickname)
    plt.xlabel('Band')
    plt.ylabel('Band Nickname')
    plt.title('Correlation matrix of Band Data')
    plt.show()

In [None]:
def correlation_create(record_id, directory):
    """Per-record correlation matrix of bands 08-16 (labelled frame) and the contrail mask."""
    record_data = {}
    for band in range(8, 17):  # bands 08 through 16
        band_key = f"band_{band:02d}"
        band_data = np.load(os.path.join(directory, record_id, band_key + ".npy"))
        record_data[band_key] = band_data[:, :, 4]  # only the labelled (5th) frame

    # `directory` is a full path such as data_dir + 'train', so compare its last component
    if os.path.basename(os.path.normpath(directory)) == 'train':
        individual = np.load(os.path.join(directory, record_id, 'human_individual_masks.npy'))
        record_data['mean_mask'] = individual.mean(axis=3)  # fraction of labellers per pixel
    else:
        record_data['mask'] = np.load(os.path.join(directory, record_id, 'human_pixel_masks.npy')).squeeze()

    band_values = np.array([values.flatten() for values in record_data.values()])
    correlation_matrix = np.corrcoef(band_values)
    correlation_matrix = np.nan_to_num(correlation_matrix)  # constant (empty) masks give NaN
    return correlation_matrix

In [None]:
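# Average the per-record correlation matrices over every 10th record
# among the first 10,000 training records (1,000 records in total)
sample_ids = train_rs[:10000:10]
total_cor = None
for record_id in tqdm(sample_ids):
    cor = correlation_create(str(record_id), data_dir + 'train')
    total_cor = cor if total_cor is None else total_cor + cor
correlation_matrix = total_cor / len(sample_ids)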

In [None]:
band_labels = ["band_08", "band_09", "band_10", "band_11", "band_12", "band_13", "band_14", "band_15", "band_16","mask"]
band_labels_nickname = ["Upper-Level Tropospheric Water Vapor", 
                       "Mid-Level Tropospheric Water Vapor", 
                       "Lower-level Water Vapor", 
                       "Cloud-Top Phase", 
                       "Ozone", 
                       "Clean IR Longwave Window", 
                       "IR Longwave Window", 
                       "Dirty Longwave Window", 
                       "CO2 Longwave Infrared",
                       "Contrails"]
plot_correlation_matrix(correlation_matrix, band_labels, band_labels_nickname)

Bands 8 to 10 form one group of highly correlated bands and bands 11 to 16 another; bands 8 and 16 appear to be the most correlated with contrails.
I made a few small training runs using each band on its own as input, and the resulting validation Dice scores align well with the correlation scores:

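| Band | Dice  | Val Dice |
|------|-------|----------|
| 8    | 0.5   | 0.377    |
| 9    | 0.487 | 0.372    |
| 10   | 0.47  | 0.346    |
| 11   | 0.349 | 0.195    |
| 12   | 0.323 | 0.165    |
| 13   | 0.342 | 0.192    |
| 14   | 0.454 | 0.297    |
| 15   | 0.484 | 0.354    |
| 16   | 0.518 | 0.396    |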
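
The table does not say exactly how Dice was computed. For reference, the competition scored submissions with a Dice coefficient pooled over all pixels; a minimal sketch for binary masks follows, where the `eps` smoothing term is an assumption added here to avoid dividing by zero when both masks are empty.

```python
import numpy as np

def dice(pred, target, eps=1e-7):
    """Dice coefficient 2|A and B| / (|A| + |B|) for binary masks, pooled over all pixels."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    intersection = np.logical_and(pred, target).sum()
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)

# One overlapping pixel, three marked pixels in total: 2 * 1 / 3
score = dice([1, 1, 0, 0], [1, 0, 0, 0])
```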

We now look at the correlation with the Ash color scheme.

## Correlation Ash

In [None]:
def normalize_range(data, bounds):
    return (data - bounds[0]) / (bounds[1] - bounds[0])

In [None]:
def correlation_create_ash(record_id, directory):
    """Like correlation_create, but also adds the Ash false color channels r, g, b."""
    _T11_BOUNDS = (243, 303)
    _CLOUD_TOP_TDIFF_BOUNDS = (-4, 5)
    _TDIFF_BOUNDS = (-4, 2)

    record_data = {}
    for band in range(8, 17):  # bands 08 through 16
        band_key = f"band_{band:02d}"
        band_data = np.load(os.path.join(directory, record_id, band_key + ".npy"))
        record_data[band_key] = band_data[:, :, 4]  # only the labelled (5th) frame

    # Ash false color channels (the image the labellers saw)
    record_data['r'] = normalize_range(record_data["band_15"] - record_data["band_14"], _TDIFF_BOUNDS)
    record_data['g'] = normalize_range(record_data["band_14"] - record_data["band_11"], _CLOUD_TOP_TDIFF_BOUNDS)
    record_data['b'] = normalize_range(record_data["band_14"], _T11_BOUNDS)

    # `directory` is a full path such as data_dir + 'train', so compare its last component
    if os.path.basename(os.path.normpath(directory)) == 'train':
        individual = np.load(os.path.join(directory, record_id, 'human_individual_masks.npy'))
        record_data['mean_mask'] = individual.mean(axis=3)  # fraction of labellers per pixel
    else:
        record_data['mask'] = np.load(os.path.join(directory, record_id, 'human_pixel_masks.npy')).squeeze()

    band_values = np.array([values.flatten() for values in record_data.values()])
    correlation_matrix = np.corrcoef(band_values)
    correlation_matrix = np.nan_to_num(correlation_matrix)  # constant (empty) masks give NaN
    return correlation_matrix

In [None]:
# Same sampling as above: every 10th record among the first 10,000 (1,000 records)
sample_ids = train_rs[:10000:10]
total_cor = None
for record_id in tqdm(sample_ids):
    cor = correlation_create_ash(str(record_id), data_dir + 'train')
    total_cor = cor if total_cor is None else total_cor + cor
correlation_matrix = total_cor / len(sample_ids)

In [None]:
band_labels = ["band_08", "band_09", "band_10", "band_11", "band_12", "band_13", "band_14", "band_15", "band_16", "r", "g", "b", "mask"]
band_labels_nickname = ["Upper-Level Tropospheric Water Vapor", 
                       "Mid-Level Tropospheric Water Vapor", 
                       "Lower-level Water Vapor", 
                       "Cloud-Top Phase", 
                       "Ozone", 
                       "Clean IR Longwave Window", 
                       "IR Longwave Window", 
                       "Dirty Longwave Window", 
                       "CO2 Longwave Infrared",
                       "Red", "Green", "Blue",
                       "Contrails"]
plot_correlation_matrix(correlation_matrix, band_labels, band_labels_nickname)

The red and green Ash channels are very highly correlated with the contrails, while not being highly correlated with any of the individual bands.
I expect this means they are genuinely useful, which is plausible since this is what the labellers saw. The exception is blue, which is just a linear rescaling of band 14 and therefore perfectly correlated with it.
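
Pearson correlation is unchanged by a positive scale and shift, so a channel computed as `(band_14 - 243) / 60` carries exactly the same correlation information as band 14 itself. A quick check on synthetic data (the arrays and the toy mask are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(42)
# Hypothetical band 14 brightness temperatures (K) and a toy "contrail" mask on cold pixels
band_14 = rng.normal(270.0, 15.0, size=(256, 256))
toy_mask = (band_14 < 255).astype(float)

# Ash blue channel, computed as in this notebook: a positive affine rescaling of band 14
blue = (band_14 - 243) / (303 - 243)

def corr(a, b):
    """Pearson correlation of two arrays after flattening."""
    return np.corrcoef(a.ravel(), b.ravel())[0, 1]

r_blue_band14 = corr(blue, band_14)  # 1 up to floating-point error
r_band14_mask = corr(band_14, toy_mask)
r_blue_mask = corr(blue, toy_mask)   # identical to r_band14_mask
```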

In short: using the Ash false color image as the base seems like a good choice, while adding bands 11 to 16 seems like a bad idea, as their information is almost completely contained in blue already. I would recommend trying band 08 as a 4th input channel, since it is much less correlated with the Ash colors and fairly well correlated with the contrails, so it could add new and useful information. I would not add all three of bands 08, 09 and 10, as they are highly correlated with each other.

# Dataset Creation

In [None]:
train_df.shape, valid_df.shape

## Save CSVs

In [None]:
# Save the csvs
train_df.to_csv('train_df.csv', index=False)
valid_df.to_csv('valid_df.csv', index=False)

## Save Images as Numpy

In [None]:
def read_record(record_id, directory, mode):
    record_data = {}
    read = ["band_08", "band_11", "band_14", "band_15"]
    if mode == 'train':
        read.append('human_individual_masks')
    elif mode == 'val':
        read.append('human_pixel_masks')
    for x in read:
        if x == 'human_individual_masks':
            individual = np.load(os.path.join(directory, record_id, x + ".npy"))
            record_data['human_pixel_masks'] = individual.sum(axis=3) / individual.shape[3]
        else:
            record_data[x] = np.load(os.path.join(directory, record_id, x + ".npy"))
    return record_data

In [None]:
def normalize_range(data, bounds):
    return (data - bounds[0]) / (bounds[1] - bounds[0])

def get_false_color(record_data):
    """Build a 4-channel image (Ash R, G, B + normalized band 08) for the labelled frame."""
    _T11_BOUNDS = (243, 303)
    _CLOUD_TOP_TDIFF_BOUNDS = (-4, 5)
    _TDIFF_BOUNDS = (-4, 2)
    N_TIMES_BEFORE = 4  # index of the labelled (5th) frame

    r = normalize_range(record_data["band_15"] - record_data["band_14"], _TDIFF_BOUNDS)
    g = normalize_range(record_data["band_14"] - record_data["band_11"], _CLOUD_TOP_TDIFF_BOUNDS)
    b = normalize_range(record_data["band_14"], _T11_BOUNDS)
    n = (record_data["band_08"] - 230) / 20  # fixed shift and scale for band 08 (K)

    # Each channel is (H, W, T); stacking on axis 2 gives (H, W, 4, T)
    false_color = np.stack([r, g, b, n], axis=2)
    img = false_color[..., N_TIMES_BEFORE]  # (H, W, 4)
    return img
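
The stacking in `get_false_color` puts the new channel axis before the time axis, so selecting the labelled frame afterwards leaves an `(H, W, 4)` image. A shape check with random placeholder channels:

```python
import numpy as np

H, W, T = 256, 256, 8
rng = np.random.default_rng(0)
# Placeholders for the four (H, W, T) channels r, g, b and normalized band 08
channels = [rng.normal(size=(H, W, T)) for _ in range(4)]

stacked = np.stack(channels, axis=2)  # (H, W, 4, T): new channel axis before time
frame = stacked[..., 4]               # (H, W, 4): the labelled 5th frame
```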

In [None]:
path = Path('contrails')
path.mkdir(exist_ok=True, parents=True)

In [None]:
#Val
for i in tqdm(valid_rs):
    data = read_record(str(i), data_dir+'validation', mode='val')
    img = get_false_color(data)
    final = np.dstack([img, data['human_pixel_masks']])
    final = final.astype(np.float16)
    
    pathc = path/f"{i}.npy"
    np.save(str(pathc), final)

In [None]:
#Train
for i in tqdm(train_rs):
    data = read_record(str(i), data_dir+'train', mode='train')
    img = get_false_color(data)
    final = np.dstack([img, data['human_pixel_masks']])
    final = final.astype(np.float16)

    pathc = path/f"{i}.npy"
    np.save(str(pathc), final)
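
Each saved file therefore holds five float16 channels: the four image channels (Ash red, green, blue and normalized band 08) followed by the mask (the labeller-averaged soft mask for training records, the binary ground truth for validation records). A sketch of reading one back and splitting it, using a synthetic record in a temporary directory; the file name and the 0.5 threshold for a hard mask are assumptions, not part of the original pipeline.

```python
import os
import tempfile
import numpy as np

# Hypothetical saved record: 4 image channels + 1 mask channel, as written above
record = np.zeros((256, 256, 5), dtype=np.float16)
record[120:130, 40:200, 4] = 0.75  # a soft-mask patch marked by 3 of 4 labellers

with tempfile.TemporaryDirectory() as tmp:
    tmp_file = os.path.join(tmp, "example_record.npy")
    np.save(tmp_file, record)
    loaded = np.load(tmp_file)

image = loaded[..., :4].astype(np.float32)    # model input, (H, W, 4)
soft_mask = loaded[..., 4].astype(np.float32) # target in [0, 1], (H, W)
hard_mask = soft_mask > 0.5                   # one possible binarization (assumed threshold)
```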
