In [1]:
import xarray as xr
import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm

sns.set_theme('notebook')

In [2]:
data = xr.open_dataset('data/oaflux_air_sea_fluxes_train.nc', engine='netcdf4')
labels = xr.open_dataset('data/marine_heatwave_labels_train.nc', engine='netcdf4')

In [3]:
df = None
variables = list(data.variables)
for var in tqdm(variables[3:]): 
    
    stacked = data[var].stack(dim=["lon", "lat", "time"]).to_pandas().T
    if df is None:
        df = stacked
    else:
        df = pd.concat([df, stacked], axis=1)

# df = df.dropna(0, how='all')
df.reset_index(inplace=True)
df.columns = variables

  0%|          | 0/10 [00:00<?, ?it/s]

In [4]:
# sort variables and remove rows where ALL variables are NaN
df = df.sort_values(by=["lon", "lat", "time"], ascending=[True] * 3)

In [5]:
df["month"] = df.time.dt.month
df["day"] = df.time.dt.day
df = df.drop(["time"], axis=1)
df = df.fillna(-1)

In [6]:
targets = labels.mhw_label \
    .stack(dim=["lon", "lat", "time"]) \
    .to_pandas().T.reset_index() 

targets = targets \
    .sort_values(by=["lon", "lat", "time"]) \
    .drop(["lon", "lat", "time"], axis=1)

In [7]:
targets

Unnamed: 0,0
0,0.0
1,0.0
2,0.0
3,0.0
4,0.0
...,...
26697595,0.0
26697596,0.0
26697597,0.0
26697598,0.0


In [8]:
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, cross_validate
from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import make_pipeline

In [9]:
X_train, X_test, y_train, y_test = \
    train_test_split(df, targets, test_size=0.3, shuffle=True, stratify=targets, random_state=1)

In [11]:
scoring = ['accuracy', 'precision', 'recall']
pipe = make_pipeline(StandardScaler(), SGDClassifier(loss='log', verbose=10))
scores = cross_validate(pipe, X_train, y_train, scoring=scoring, verbose=10, n_jobs=-1)

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 12 concurrent workers.
[Parallel(n_jobs=-1)]: Done   2 out of   5 | elapsed:  2.9min remaining:  4.3min
[Parallel(n_jobs=-1)]: Done   3 out of   5 | elapsed:  2.9min remaining:  1.9min
[Parallel(n_jobs=-1)]: Done   5 out of   5 | elapsed:  3.1min remaining:    0.0s
[Parallel(n_jobs=-1)]: Done   5 out of   5 | elapsed:  3.1min finished


In [12]:
scores

{'fit_time': array([172.24347663, 156.41606903, 154.59365749, 164.7169404 ,
        156.47908425]),
 'score_time': array([4.29234409, 5.24218154, 4.99512601, 4.46340895, 5.31519651]),
 'test_accuracy': array([0.94494609, 0.9449469 , 0.94494556, 0.94495894, 0.9449485 ]),
 'test_precision': array([0.16363636, 0.11827957, 0.16666667, 0.20930233, 0.11904762]),
 'test_recall': array([8.75065022e-05, 5.34761958e-05, 9.23679746e-05, 4.37530384e-05,
        4.86144871e-05])}

In [15]:
from sklearn.metrics import classification_report

pipe.fit(X_train, y_train)
y_pred = pipe.predict(X_test)

  return f(*args, **kwargs)


-- Epoch 1
Norm: 1.86, NNZs: 13, Bias: -3.026502, T: 18688320, Avg. loss: 0.203085
Total training time: 5.68 seconds.
-- Epoch 2
Norm: 1.86, NNZs: 13, Bias: -3.046322, T: 37376640, Avg. loss: 0.201135
Total training time: 11.30 seconds.
-- Epoch 3
Norm: 1.84, NNZs: 13, Bias: -3.059706, T: 56064960, Avg. loss: 0.201102
Total training time: 16.91 seconds.
-- Epoch 4
Norm: 1.84, NNZs: 13, Bias: -3.051522, T: 74753280, Avg. loss: 0.201090
Total training time: 22.53 seconds.
-- Epoch 5
Norm: 1.84, NNZs: 13, Bias: -3.035042, T: 93441600, Avg. loss: 0.201085
Total training time: 28.14 seconds.
-- Epoch 6
Norm: 1.85, NNZs: 13, Bias: -3.046700, T: 112129920, Avg. loss: 0.201080
Total training time: 33.76 seconds.
-- Epoch 7
Norm: 1.85, NNZs: 13, Bias: -3.046522, T: 130818240, Avg. loss: 0.201075
Total training time: 39.40 seconds.
Convergence after 7 epochs took 39.40 seconds


In [17]:
report = classification_report(y_test, y_pred)
print(report)

              precision    recall  f1-score   support

         0.0       0.94      1.00      0.97   7568495
         1.0       0.15      0.00      0.00    440785

    accuracy                           0.94   8009280
   macro avg       0.55      0.50      0.49   8009280
weighted avg       0.90      0.94      0.92   8009280

