In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributions
import scipy.stats as stats
import xarray as xr

from crims2s.dataset import S2SDataset, TransformedDataset
from crims2s.transform import AddBiweeklyDimTransform

In [None]:
# Location handed to S2SDataset when the dataset is built in the next cell.
DATASET = '***BASEDIR***/mlready/2021-08-21-test/'

In [None]:
# Build the raw dataset, restricted by `name_filter` to names ending in
# '0102.nc' (presumably file names -- confirm against S2SDataset), and wrap it
# so AddBiweeklyDimTransform is applied to the examples.
dataset = TransformedDataset(
    S2SDataset(
        DATASET,
        include_features=False,
        name_filter=lambda x: x.endswith('0102.nc'),
    ),
    AddBiweeklyDimTransform(),
)

In [None]:
# List the keys available on the first example.
dataset[0].keys()

In [None]:
# Number of examples left after the name filter.
len(dataset)

In [None]:
# Take the 'model' entry of the first example; the rest of the notebook works on it.
model = dataset[0]['model']

In [None]:
# Display the `tp` variable, which is the one fitted below.
model.tp

In [None]:
# Display the whole `model` object for reference.
model

In [None]:
def fit_wrapper(*args, **kwargs):
    """Run ``scipy.stats.gamma.fit`` and return the fitted parameters as an array.

    Every positional and keyword argument is forwarded unchanged to
    ``stats.gamma.fit``. The tuple it returns is packed into a NumPy array so
    the result can be used as an array-valued output (e.g. by
    ``xr.apply_ufunc``).
    """
    return np.array(stats.gamma.fit(*args, **kwargs))

In [None]:
# Reference fit with scipy on a 10x10 corner of the grid at the last lead time.
# 'realization' is consumed as the sample axis and each fit produces a new
# 'parameter' axis; vectorize=True loops fit_wrapper over the remaining dims.
scipy_fit = xr.apply_ufunc(
    fit_wrapper,
    model.tp.isel(lead_time=-1, latitude=slice(0, 10), longitude=slice(0, 10)),
    input_core_dims=[['realization']],
    output_core_dims=[['parameter']],
    vectorize=True,
)

In [None]:
# Display the scipy result; the fitted values sit along the 'parameter' dimension.
scipy_fit

# Fit a gamma distribution using PyTorch

Fitting with `scipy.stats` is too slow, so we reimplement the fitting logic in PyTorch.

In [None]:
# Small positive floor: added to the moment estimates to avoid dividing by
# zero, and used as the lower clamp for the data and the gamma parameters.
REG = 1e-9

In [None]:
# Keep only the last entry along `lead_time` of the `tp` variable.
weekly_total = model.tp.isel(lead_time=-1)

In [None]:
# Initial estimate using the method of moments.
# For Gamma(shape=a, rate=b): mean = a / b and variance = a / b**2, hence
# a = mean**2 / var and b = mean / var. REG keeps both ratios finite where the
# variance over 'realization' is zero.

# Reduce over 'realization' once and reuse the results for both estimates.
weekly_mean = weekly_total.mean(dim='realization')
weekly_var_reg = weekly_total.var(dim='realization') + REG

a_hat_xarray = weekly_mean ** 2 / weekly_var_reg
b_hat_xarray = (weekly_mean + REG) / weekly_var_reg

In [None]:
# Sanity check: for Gamma(shape=a, rate=b) the variance is a / b**2.
# Average the variance implied by the moment estimates over all points.
(a_hat_xarray / b_hat_xarray**2).mean()

In [None]:
# Mean implied by the moment estimates: for Gamma(shape=a, rate=b) it is a / b.
mean_tp = a_hat_xarray / b_hat_xarray

In [None]:
# Plot the implied mean for index 1 along `biweekly_forecast`.
mean_tp.isel(biweekly_forecast=1).plot()

In [None]:
# Copy the moment estimates into leaf tensors that the optimizer can update.
a_hat = torch.tensor(a_hat_xarray.data).requires_grad_(True)
b_hat = torch.tensor(b_hat_xarray.data).requires_grad_(True)

In [None]:
# Move 'realization' to the leading axis, then clamp to REG so that no value
# handed to the gamma log-density is below that floor.
weekly_total_torch = torch.from_numpy(
    weekly_total.transpose('realization', 'biweekly_forecast', 'latitude', 'longitude').data
).clamp(min=REG)

In [None]:
# Check the layout: (realization, biweekly_forecast, latitude, longitude).
weekly_total_torch.shape

In [None]:
# A single Adam optimizer updates both parameter tensors jointly.
optimizer = torch.optim.Adam([a_hat, b_hat], lr=1e-2)

In [None]:
# Same shape check as in an earlier cell, repeated before the optimization loop.
weekly_total_torch.shape

In [None]:
# Loss history filled in by the optimization loop. Despite the name, the loop
# appends the *negative* mean log-likelihood (i.e. the loss) for each epoch.
mean_lls = []

In [None]:
# Maximum-likelihood refinement: 50 Adam steps on the negative mean
# log-likelihood. Re-running this cell keeps training from the current
# parameters and keeps appending to `mean_lls`.
for epoch in range(50):
    # The clamp keeps both parameters at or above REG when building the distribution.
    estimated_gamma = torch.distributions.Gamma(
        a_hat.clamp(min=REG),
        b_hat.clamp(min=REG),
    )
    mean_log_likelihood = estimated_gamma.log_prob(weekly_total_torch).mean()
    loss = -mean_log_likelihood

    # Record the loss as a plain float before taking the optimization step.
    mean_lls.append(loss.item())

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

In [None]:
# Loss curve: negative mean log-likelihood recorded at each epoch.
plt.plot(mean_lls)

In [None]:
# Inspect the optimized b_hat. This is the raw tensor: the REG clamp is only
# applied when the distribution is built inside the loop.
b_hat

In [None]:
# Show the fitted gamma mean (a / b) as an image for index 0 of the leading
# axis (presumably `biweekly_forecast` -- confirm the dimension order of a_hat).
plt.imshow((a_hat.detach() / b_hat.detach())[0])

In [None]:
# Largest fitted a_hat value, to spot parameters that blew up.
a_hat.max()

In [None]:
# Smallest fitted b_hat value; the tensor is unclamped, so it can fall below REG.
b_hat.min()

In [None]:
# Average fitted gamma mean (a / b) over all points, using the raw parameters.
(a_hat / b_hat).mean()

In [None]:
def fit_gamma_xarray(array: xr.DataArray, dim=None, regularization=1e-9, **kwargs):
    """Fit a gamma distribution along one dimension of a DataArray.

    The method of moments provides the starting point (alpha = mean**2 / var,
    beta = mean / var, both regularized) and `fit_gamma_pytorch` refines it by
    maximizing the likelihood.

    Args:
        array: Samples to fit. The gamma log-density is evaluated on these
            values as they are, so the caller is responsible for clipping
            them into the support of the distribution.
        dim: Name of the dimension that holds the samples. It is reduced away
            in the result. Required, despite the `None` default.
        regularization: Small positive constant added to the moment estimates
            to avoid division by zero, and forwarded to `fit_gamma_pytorch` as
            the lower bound of the parameters.
        **kwargs: Forwarded to `fit_gamma_pytorch` (e.g. `max_epochs`, `lr`,
            `tol`, `patience`). Do not pass `return_losses=True`: exactly two
            return values are unpacked here.

    Returns:
        An `xr.Dataset` with two variables named `<name>_alpha` and
        `<name>_beta`, where `<name>` is the name carried by the moment
        estimates. Both have the dimensions of `array` without `dim`.

    Raises:
        ValueError: If `dim` is None.
    """
    if dim is None:
        # Without this check the failure only shows up later, inside
        # `transpose`, with a message that does not point at the missing argument.
        raise ValueError('fit_gamma_xarray requires `dim`, the dimension that holds the samples.')

    # Reduce over the sample dimension once; both estimates share the results.
    sample_mean = array.mean(dim=dim)
    sample_var_reg = array.var(dim=dim) + regularization

    # Use method of moments for initial estimate.
    alpha_init = sample_mean ** 2 / sample_var_reg
    beta_init = (sample_mean + regularization) / sample_var_reg

    # Samples go on the leading axis so the parameters broadcast against them.
    samples_first = array.transpose(dim, ...)

    alpha, beta = fit_gamma_pytorch(
        samples_first.data,
        alpha_init.data,
        beta_init.data,
        regularization=regularization,
        **kwargs,
    )

    # Reuse the coordinates of the moment estimates for the fitted parameters.
    base_name = alpha_init.name
    alpha_xarray = xr.zeros_like(alpha_init).rename(f'{base_name}_alpha')
    beta_xarray = xr.zeros_like(beta_init).rename(f'{base_name}_beta')

    alpha_xarray.data = alpha.numpy()
    beta_xarray.data = beta.numpy()

    return xr.merge([alpha_xarray, beta_xarray])
    

In [None]:
def fit_gamma_pytorch(data, a_hat, b_hat, regularization=1e-9, max_epochs=500, lr=1e-2, tol=1e-5, patience=5, return_losses=False):
    """Refine gamma parameters by maximum likelihood with Adam.

    The loss is the negative log-likelihood averaged over *all* elements of
    `data`, so there is one scalar loss and early stopping is decided globally,
    not per fitted distribution.

    Args:
        data: Samples (array-like or tensor). The parameters must broadcast
            against it, e.g. samples on the leading axis. Values must lie in
            the support of the gamma distribution.
        a_hat: Initial concentration (shape) parameter(s).
        b_hat: Initial rate parameter(s).
        regularization: Lower bound applied to both parameters whenever the
            distribution is built and on the returned values.
        max_epochs: Maximum number of optimizer steps.
        lr: Adam learning rate.
        tol: Minimum decrease of the loss, relative to the best loss seen so
            far, that counts as an improvement.
        patience: Number of consecutive epochs without improvement after
            which the optimization stops. A NaN loss never counts as an
            improvement, so a diverged fit also stops after `patience` epochs.
        return_losses: If True, also return the per-epoch loss history.

    Returns:
        `(alpha, beta)` as detached tensors clamped to `regularization`, or
        `(alpha, beta, losses)` when `return_losses` is True. `losses` holds
        the negative mean log-likelihood of every epoch that was run, as
        Python floats.
    """
    def _as_parameter(initial):
        # Work on a copy so the caller's array or tensor is never modified.
        if isinstance(initial, torch.Tensor):
            parameter = initial.detach().clone()
        else:
            parameter = torch.tensor(initial)
        # Integer initial guesses cannot carry gradients; promote them.
        if not parameter.is_floating_point():
            parameter = parameter.to(torch.get_default_dtype())
        return parameter.requires_grad_(True)

    alpha = _as_parameter(a_hat)
    beta = _as_parameter(b_hat)
    data = data.detach() if isinstance(data, torch.Tensor) else torch.tensor(data)

    optimizer = torch.optim.Adam([alpha, beta], lr=lr)

    losses = []
    best_loss = float('inf')
    epochs_without_improvement = 0

    for _ in range(max_epochs):
        clamped_alpha = torch.clamp(alpha, min=regularization)
        clamped_beta = torch.clamp(beta, min=regularization)

        estimated_gamma = torch.distributions.Gamma(clamped_alpha, clamped_beta)

        loss = -estimated_gamma.log_prob(data).mean()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Plain float: keeps the history free of autograd graphs and avoids
        # mixing numpy scalars with tensors in the comparison below.
        current_loss = loss.item()
        losses.append(current_loss)

        if current_loss < best_loss - tol:
            best_loss = current_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break

    alpha = torch.clamp(alpha, min=regularization).detach()
    beta = torch.clamp(beta, min=regularization).detach()

    if return_losses:
        return alpha, beta, losses
    else:
        return alpha, beta

In [None]:
# Fit every grid point at the last lead time. The data is clipped to a small
# positive floor first, because fit_gamma_xarray evaluates the gamma
# log-density on the values it is given.
gamma_params = fit_gamma_xarray(
    model.tp.isel(lead_time=-1).clip(min=1e-9),
    dim='realization',
    tol=1e-5,
)

In [None]:
# Plot the fitted beta (rate) parameter for index 1 along `biweekly_forecast`.
gamma_params.tp_beta.isel(biweekly_forecast=1).plot()