# Analyzing the ClimateNet Dataset

In this notebook we analyze the ClimateNet dataset by computing several useful statistics and visualizing interesting examples.

In [None]:
from pathlib import Path

# Data and output directories, all located under one project root on the data mount.
project_root = Path('/mnt/data/ai4good')
out_dir = project_root / 'out'
# train_dir and test_dir point at the 'train' and 'test' subfolders of the dataset folder.
data_dir = project_root / 'climatenet_new'
train_dir, test_dir = data_dir / 'train', data_dir / 'test'

## Data Exploration

First, we load the available training data and explore it a bit.

In [None]:
from utils.data import ClimateNetDataset
import xarray as xr
import numpy as np

# Fixed seed so the randomly drawn samples are the same on every Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)

ds = ClimateNetDataset(train_dir)

# Draw three random indices (with replacement) and join the samples along 'time'.
sample_indices = rng.integers(0, len(ds), 3)
concat_ds = xr.concat([ds[i] for i in sample_indices], dim='time')

# Only the last expression of a cell is displayed, so show both lengths together
# (previously `len(ds)` was evaluated and silently discarded).
# NOTE(review): if concat_ds is an xarray.Dataset, len() counts its data variables,
# not its time steps — confirm this is the number you want to see.
len(ds), len(concat_ds)

In [None]:
# Take the first item of the training dataset as the running example for the
# exploration cells below; the bare name on the last line displays it.
sample = ds[0]
sample

In [None]:
# Display the dimensions of the example sample.
sample.dims

In [None]:
# Display the coordinates attached to the example sample.
sample.coords

In [None]:
# One summary line per variable: name, description, units and dtype.
# 'LABELS' is printed without units, since the units attribute is not read for it.
for var in sample.data_vars:
    field = sample[var]
    summary = f'{var}: {field.attrs["description"]}'
    if var != 'LABELS':
        summary += f' ({field.attrs["units"]})'
    print(f'{summary} ({field.dtype})')

In [None]:
import ipywidgets as widgets
def plot_sample(var):
    """Plot the variable `var` of the example `sample` on a 10x5 figure."""
    selected = sample[var]
    selected.plot(figsize=(10, 5))


# Dropdown listing every variable of the sample; the first one is preselected.
variables = [*sample.data_vars]
var_dropdown = widgets.Dropdown(
    options=variables,
    value=variables[0],
    description='Variable',
)

# Re-draw the plot whenever a different variable is chosen.
widgets.interact(plot_sample, var=var_dropdown);

## Data Analysis

In [None]:
from utils.stats import Stats

# Statistics helper built on the training dataset.
# NOTE(review): num_samples=3 presumably limits how many samples the statistics
# are computed from — confirm against utils.stats.Stats.
stats = Stats(ds=ds, num_samples=3)

In [None]:
# Label distribution, returned as three objects: background, tropical cyclone,
# atmospheric river. Each is reduced to a single total with .sum().
bg, tc, ar = stats.get_label_distribution()
bg = bg.sum()
tc = tc.sum()
ar = ar.sum()
# Named `total` rather than `sum`: assigning to `sum` would shadow the built-in
# for every later cell in this kernel session.
total = bg + tc + ar

print(f'Background: {bg.values} ({bg/total*100:.2f}%)')
print(f'Tropical Cyclone: {tc.values} ({tc/total*100:.2f}%)')
print(f'Atmospheric Rivers: {ar.values} ({ar/total*100:.2f}%)')

In [None]:
# Correlation matrix from the Stats helper; the heatmap cell below labels its
# axes with stats.data_vars.
# NOTE(review): presumably pairwise correlations between those variables — confirm.
cm = stats.get_corr_matrix()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Heatmap of the absolute values of `cm`, with both axes labelled by stats.data_vars.
axis_labels = stats.data_vars
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(
    abs(cm),
    ax=ax,
    xticklabels=axis_labels,
    yticklabels=axis_labels,
    cmap='RdBu_r',
    center=0,
    vmin=0,
    vmax=1,
)
# Slant the x labels so long names do not overlap; keep the y labels horizontal.
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
ax.set_yticklabels(ax.get_yticklabels(), rotation=0, ha='right')
plt.show()

In [None]:
import ipywidgets as widgets
def print_var_stats(var):
    """Print each statistic that `stats.get_stats` reports for variable `var`.

    Named for what it does: the earlier `plot_sample` in this notebook draws a
    plot, whereas this callback only prints, so reusing that name was
    misleading and silently replaced the plotting function.
    """
    var_stats = stats.get_stats(var)
    for stat in var_stats:
        print(f'{stat}: {var_stats[stat]}')


# Separate names from the first widget cell, so re-running cells in any order
# never mixes the sample's variables with the variables tracked by `stats`.
stats_variables = list(stats.data_vars)
stats_var_dropdown = widgets.Dropdown(
    options=stats_variables,
    value=stats_variables[0],
    description='Variable',
)

widgets.interact(print_var_stats, var=stats_var_dropdown);

In [None]:
# Variable names passed to stats.get_corr_matrix_vars and used as the axis
# labels of the second heatmap.
features = ['PSL', 'TMQ', 'U850', 'V850', 'T500', 'ZBOT']

In [None]:
# Correlation matrix computed for the variables listed in `features`.
# NOTE(review): presumably rows/columns follow the order of `features` — confirm.
cm_features = stats.get_corr_matrix_vars(features)

In [None]:
# Heatmap of the absolute values of `cm_features`, with both axes labelled by `features`.
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(
    abs(cm_features),
    ax=ax,
    xticklabels=features,
    yticklabels=features,
    cmap='RdBu_r',
    center=0,
    vmin=0,
    vmax=1,
)
# Slant the x labels so they do not overlap; keep the y labels horizontal.
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
ax.set_yticklabels(ax.get_yticklabels(), rotation=0, ha='right')
plt.show()