# Learn Gamma

The objective is to fit one gamma distribution per lat-lon to model the precipitation distribution of a tile.
First, we study the Gamma distribution object from PyTorch to learn how to use it.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pathlib
import torch
import scipy
import seaborn as sns
import xarray as xr

from crims2s.util import fix_dataset_dims

In [None]:
# Reference distribution whose parameters we try to recover below:
# a Gamma with concentration 5 and rate 1, batch shape (1,).
d = torch.distributions.Gamma(
    concentration=torch.tensor([5.0]),
    rate=torch.tensor([1.0]),
)

In [None]:
sample = d.sample((1000,))

In [None]:
df = pd.DataFrame({'value': sample.numpy()[:,0]})

In [None]:
sns.displot(data=df)

In [None]:
a_hat = sample.mean() ** 2 / sample.var()
b_hat = sample.var() / sample.mean()

In [None]:
a_hat

In [None]:
b_hat

In [None]:

# Trainable parameters, initialised from the moment estimates a_hat / b_hat.
a = torch.full((1,), a_hat, requires_grad=True)
b = torch.full((1,), b_hat, requires_grad=True)

# Alternative: random initialisation.
#a = torch.rand((1,), requires_grad=True)
#b = torch.rand((1,), requires_grad=True)

optimizer = torch.optim.SGD([a, b], lr=1e-2, momentum=0)

# Per-iteration history, plotted in the cells that follow.
losses, a_list, b_list, mean_lls, regs = [], [], [], [], []

# Weight of the regularisation term in the loss.
lambd = 1e-10

for _ in range(1000):
    # Clamp so the distribution always receives strictly positive parameters.
    concentration = a.clamp(min=1e-6)
    rate = b.clamp(min=1e-6)
    estimated_gamma = torch.distributions.Gamma(concentration, rate)

    mean_log_likelihood = (1.0 - lambd) * estimated_gamma.log_prob(sample).mean()
    regularization = lambd * (a + b).square()

    # Minimise the negative log-likelihood plus the penalty.
    loss = regularization - mean_log_likelihood

    # Everything recorded here reflects the parameters before this update.
    mean_lls.append(mean_log_likelihood.item())
    regs.append(regularization.item())
    a_list.append(a.item())
    b_list.append(b.item())

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    losses.append(loss.item())

print(a.item())
print(b.item())

In [None]:
# Total loss (negative weighted log-likelihood + regularisation) per iteration.
plt.plot(losses)

In [None]:
# Value of parameter a recorded at each iteration.
plt.plot(a_list)

In [None]:
# Value of parameter b recorded at each iteration.
plt.plot(b_list)

In [None]:
# (1 - lambd)-weighted mean log-likelihood per iteration.
plt.plot(mean_lls)

In [None]:
# Regularisation term per iteration.
plt.plot(regs)

## Do it for our real data

In [None]:
# Directory containing the reference files.
OBS_DIR = '***BASEDIR***training-output-reference/'
obs_path = pathlib.Path(OBS_DIR)

# Keep only the entries whose stem contains 'tp', in sorted order.
obs_files = sorted(entry for entry in obs_path.iterdir() if 'tp' in entry.stem)

In [None]:
tp = xr.open_mfdataset(obs_files, preprocess=fix_dataset_dims)

In [None]:
tp

In [None]:
tp_w34 = (tp.sel(lead_time='28D') - tp.sel(lead_time='14D')).sel(latitude=slice(50.0, 30.0), forecast_dayofyear=slice(60, 220), forecast_year=slice(2007, None))

In [None]:
tp_w34

In [None]:
tp_w34.isnull().sum(dim=['latitude', 'longitude']).tp.compute().plot()

In [None]:
tp_w34 = tp_w34.stack(station=('latitude', 'longitude'))

In [None]:
tp_w34.dims

In [None]:
station_ids = xr.DataArray(np.arange(tp_w34.dims['station']), dims='station_coords')

In [None]:
tp_w34 = tp_w34.rename(station='station_coords').assign_coords(station=station_ids).swap_dims(station_coords='station')

In [None]:
#tp_w34 = tp_w34.drop('station_coords')

In [None]:
tp_w34

In [None]:
station_mask = (tp_w34.isnull().sum(dim=['forecast_year', 'forecast_dayofyear']) == 0).compute()

In [None]:
station_mask

In [None]:
tp_w34_only_land = tp_w34.where(station_mask, drop=True)

In [None]:
tp_w34_only_land

In [None]:
tp_train = tp_w34_only_land.isel(forecast_year=slice(None, -3))
tp_val = tp_w34_only_land.isel(forecast_year=slice(-3, None))

In [None]:
tp_train

In [None]:
tp_val

In [None]:
a_hat_xarray = tp_train.mean(dim='forecast_year') ** 2 / (tp_train.var(dim='forecast_year') + 1e-6)
b_hat_xarray = (tp_train.mean(dim='forecast_year') + 1e-6) / (tp_train.var(dim='forecast_year') + 1e-6)

In [None]:
a_hat_xarray.isnull().compute().sum()

In [None]:
train_pytorch = torch.tensor(tp_train.tp.data.compute())

In [None]:
train_pytorch.shape

In [None]:
val_pytorch = torch.tensor(tp_val.tp.data.compute())

In [None]:
val_pytorch.shape

In [None]:
train_pytorch.min()

In [None]:
train_pytorch.shape

In [None]:
a_hat = torch.tensor(a_hat_xarray.tp.data.compute(), requires_grad=True, device='cuda')
b_hat = torch.tensor(b_hat_xarray.tp.data.compute(), requires_grad=True, device='cuda')

#a_hat = torch.rand(*train_pytorch.shape[1:], requires_grad=True)
#b_hat = torch.rand(*train_pytorch.shape[1:], requires_grad=True)

optimizer = torch.optim.SGD([a_hat,b_hat], lr=1e-2, momentum=0.0)

losses = []
a_list = []
b_list = []
mean_lls = []
regs = []
vals = []

true_train = []
true_val = []


train_pytorch = torch.tensor(tp_train.tp.data.compute()).cuda()
val_pytorch = torch.tensor(tp_val.tp.data.compute()).cuda()

In [None]:
lambd = 0.01
optimizer = torch.optim.SGD([a_hat,b_hat], lr=5.0, momentum=0.0)

In [None]:
# The small offset keeps zero-valued samples strictly positive. It does not
# change between iterations, so apply it once instead of on every call.
train_obs = train_pytorch + 1e-6
val_obs = val_pytorch + 1e-6

for i in range(2000):
    # Clamp so the distribution always receives strictly positive parameters.
    estimated_gamma = torch.distributions.Gamma(torch.clamp(a_hat, min=1e-6), torch.clamp(b_hat, min=1e-6))

    # Unweighted training log-likelihood: one forward pass, reused below for
    # both the loss and the 'true_train' history.
    train_ll = estimated_gamma.log_prob(train_obs).mean()
    mean_log_likelihood = (1.0 - lambd) * train_ll
    regularization = lambd * (torch.square(a_hat) + torch.square(b_hat)).mean()

    mean_lls.append(-mean_log_likelihood.detach().item())
    regs.append(regularization.detach().item())

    loss = -1.0 * mean_log_likelihood + regularization

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Validation is monitoring only, so skip autograd bookkeeping.
    # estimated_gamma was built from the clamped copies made before the
    # update, so train and validation numbers describe the same parameters.
    with torch.no_grad():
        val_ll = estimated_gamma.log_prob(val_obs).mean()
    val_mean_log_likelihood = (1.0 - lambd) * val_ll

    losses.append(loss.detach().item())
    vals.append(-val_mean_log_likelihood.item())

    # Mean parameter values (after the update), sampled every 10 iterations.
    if i % 10 == 0:
        a_list.append(a_hat.mean().detach().item())
        b_list.append(b_hat.mean().detach().item())

    true_train.append(train_ll.detach().item())
    true_val.append(val_ll.item())

In [None]:
# Log-likelihood of every entry along the first axis at index (0, 0) of the
# remaining axes, under the last distribution built by the training loop.
estimated_gamma.log_prob(train_pytorch + 1e-6)[:, 0, 0]

In [None]:
# Unweighted mean log-likelihood on the training and validation data.
fig, ax = plt.subplots()
plt.plot(true_train)
plt.plot(true_val)

In [None]:
# Window of iterations to display; end = -1 leaves out the last entry.
begin = 0
end = -1

# Negative weighted log-likelihoods (train / val) and the regularisation term.
fig, ax = plt.subplots()
ax.plot(mean_lls[begin:end], label='train')
ax.plot(vals[begin:end], label='val')
ax.plot(regs[begin:end], label='reg')
ax.legend()
plt.show()

In [None]:
# Mean of a_hat, recorded every 10 iterations.
plt.plot(a_list)

In [None]:
# Mean of b_hat, recorded every 10 iterations.
plt.plot(b_list)

In [None]:
# Regularisation term per iteration.
plt.plot(regs)

In [None]:
# Number of a_hat entries that became negative (the training loop only clamps
# the copies given to the distribution, not the parameters themselves).
(a_hat < 0.0).sum()

In [None]:
# Mean of the fitted a_hat values.
a_hat.mean()

In [None]:
# Mean of the fitted b_hat values.
b_hat.mean()

In [None]:
# Largest fitted a_hat value.
a_hat.max()

In [None]:
val_pytorch.shape

In [None]:
# Training series at index (0, 0) of the trailing axes.
train_pytorch[:, 0, 0]

In [None]:
# Mean over the whole training tensor.
train_pytorch.mean()

In [None]:
# Largest fitted a_hat value.
a_hat.max()

In [None]:
# Largest fitted b_hat value.
b_hat.max()

In [None]:
# Fitted b_hat at index (0, 0).
b_hat[0,0]

In [None]:
# Fitted a_hat at index (0, 0).
a_hat[0,0]

In [None]:
g = torch.distributions.Gamma(a_hat[0,0], b_hat[0,0])

In [None]:
g.log_prob(train_pytorch[:,0,0])

In [None]:
pdf = torch.exp(g.log_prob(torch.arange(1e-6, 50)))

In [None]:
plt.plot(pdf.detach().cpu().numpy())

In [None]:
a_hat[0,0]

In [None]:
scipy_g = scipy.stats.gamma(a=0.4462, scale=1 / 0.0194)

In [None]:
scipy_g

In [None]:
pdfs

In [None]:
fix, ax = plt.subplots()
ax.plot(pdfs)

In [None]:
a, loc, scale = scipy.stats.gamma.fit(train_pytorch[:, 0, 0].detach().cpu().numpy())
scipy_g = scipy.stats.gamma(a=1.5, scale=0.0681)
pdfs = scipy_g.pdf(np.arange(0.1, 50))

In [None]:
plt.plot(pdfs)

In [None]:
train_pytorch[:, 0, 0]

In [None]:
scipy_g.pdf(train_pytorch[:, 0, 0].detach().cpu().numpy())

In [None]:
a_hat[0,0]

## Do it on only one station

In [None]:
sample = tp_train.isel(station=0, forecast_dayofyear=0).compute()

In [None]:
sample

In [None]:
a_hat_xarray = (sample.mean(dim='forecast_year') ** 2 / (sample.var(dim='forecast_year') + 1e-6)).compute().tp.data
b_hat_xarray = ((sample.mean(dim='forecast_year') / sample.var(dim='forecast_year') + 1e-6)).compute().tp.data

In [None]:
a_hat_xarray

In [None]:
b_hat_xarray

In [None]:
a_hat = torch.tensor(a_hat_xarray, requires_grad=True)
b_hat =  torch.tensor(b_hat_xarray, requires_grad=True)

In [None]:
g = torch.distributions.Gamma(a_hat, b_hat)

In [None]:
g.log_prob(sample.tp.data)