In [None]:
%load_ext autoreload

In [None]:
import collections
import hydra
import itertools
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pathlib
import torch
import torch.nn as nn

from pandas.api.types import CategoricalDtype

from smc01.postprocessing.util import concat_collate_fn

In [None]:
# Compose the 'train' config from the config module; the Hydra context only
# needs to stay open for the duration of the compose call.
with hydra.initialize_config_module(config_module='smc01.postprocessing.conf'):
    cfg = hydra.compose(config_name='train')

In [None]:
# Display the composed configuration object.
cfg

In [None]:
# Instantiate the dataset described by the `experiment.dataset` node of the config.
dataset = hydra.utils.instantiate(cfg.experiment.dataset)

In [None]:
# Inspect the first dataset item.
# NOTE(review): presumably a dict with the 'features', 'obs', 'station_id',
# 'forecast_id' and 'step_id' entries used by the training loop — confirm.
dataset[0]

## First try with dataloader

In [None]:
# Shuffled loader over the dataset. A seeded generator makes the shuffle order
# (and the base seed handed to the worker processes) reproducible, so re-running
# the notebook yields the same sequence of batches.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    collate_fn=concat_collate_fn,
    shuffle=True,
    num_workers=4,
    generator=torch.Generator().manual_seed(0),
)

In [None]:
# Learnable linear correction with one weight vector (5 features) and one bias
# for every (station, forecast_id, step_id) cell; the axis sizes 730 and 81 are
# the extents of the forecast_id and step_id index axes used in the training loop.
#
# float32 is requested explicitly: np.zeros defaults to float64, which made
# `weights` double precision while `biases` (torch.zeros) was single precision,
# doubling the memory of the largest tensor and mixing parameter dtypes.
weights_np = np.zeros((len(dataset.stations), 730, 81, 5), dtype=np.float32)
# Start from the identity on feature 0: with zero bias, the initial prediction
# equals feature 0.
weights_np[..., 0] = 1
weights = torch.tensor(weights_np, requires_grad=True)

biases = torch.zeros((len(dataset.stations), 730, 81), requires_grad=True)

optimizer = torch.optim.Adam([weights, biases], lr=1e-3)

In [None]:
# Take ten optimisation steps on batches drawn from the loader, printing the
# batch RMSE before each update.
for b in itertools.islice(loader, 10):
    features, obs = b['features'], b['obs']

    # Per-sample indices used to look up each sample's parameters.
    station_ids = b['station_id']
    forecast_id = b['forecast_id']
    step_id = b['step_id']

    gathered_weights = weights[station_ids, forecast_id, step_id]
    gathered_biases = biases[station_ids, forecast_id, step_id]

    # Linear model: weighted sum of the features plus a bias.
    pred = (gathered_weights * features).sum(dim=1) + gathered_biases
    rmse = (obs - pred).square().mean().sqrt()

    print(rmse)

    # Backpropagate, update, then clear gradients for the next batch.
    rmse.backward()
    optimizer.step()
    optimizer.zero_grad()

In [None]:
# Average the weights over the first three axes in turn (station, forecast_id,
# step_id), leaving one mean weight per feature.
weights.mean(dim=0).mean(dim=0).mean(dim=0)

In [None]:
# Largest bias value; biases start at zero, so this shows how far any bias moved upward.
biases.max()

In [None]:
# Smallest bias value; biases start at zero, so this shows how far any bias moved downward.
biases.min()

In [None]:
# Largest weight across all stations, forecasts, steps and features.
weights.max()

In [None]:
# Smallest weight across all stations, forecasts, steps and features.
weights.min()

In [None]:
# Smallest weight on feature 0, the only weight initialised to 1 rather than 0.
weights[..., 0].min()

In [None]:
# Full parameter shape: (number of stations, 730, 81, 5).
weights.shape

In [None]:
# Shape of the weight slice for station index 437: (730, 81, 5).
weights[437].shape

In [None]:
# For station index 437, take the weight of feature index 2, average it over the
# step_id axis (dim=1 of the sliced tensor) and plot it along the forecast_id axis.
# Labels make the figure readable on its own; the trailing semicolon suppresses
# the return-value repr in the cell output.
fig, ax = plt.subplots()
ax.plot(weights[437, ..., 2].mean(dim=1).detach().numpy())
ax.set(
    xlabel='forecast_id index',
    ylabel='weight of feature 2 (mean over step_id)',
    title='Station index 437',
);