In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributions
import scipy.stats as stats
import seaborn as sns
import xarray as xr

from crims2s.dataset import S2SDataset, TransformedDataset
from crims2s.transform import AddBiweeklyDimTransform
from crims2s.distribution import std_estimator

In [None]:
DATASET = '***BASEDIR***/mlready/2021-08-28-test/'

In [None]:
# Read only the files whose name ends in '0102.nc' (without the input features),
# then wrap the result so AddBiweeklyDimTransform is applied to it.
raw_dataset = S2SDataset(
    DATASET,
    include_features=False,
    name_filter=lambda name: name.endswith('0102.nc'),
)
dataset = TransformedDataset(raw_dataset, AddBiweeklyDimTransform())

In [None]:
dataset[10].keys()

In [None]:
len(dataset)

In [None]:
model = dataset[10]['model']

In [None]:
model.tp

In [None]:
model

In [None]:
model = xr.concat([dataset[i]['model'] for i in range(4)], dim='forecast_time')

In [None]:
model

# Fit gamma using pytorch

Fitting with `scipy.stats` is too slow, so we implement the gamma fit in PyTorch instead.

In [None]:
REG = 1e-9

In [None]:
weekly_total = model.tp.isel(lead_time=-1)

In [None]:
# Method-of-moments starting point for the gamma fit:
# shape = mean^2 / var and rate = mean / var, each regularized with REG.
ensemble_mean = weekly_total.mean(dim='realization')
ensemble_var = weekly_total.var(dim='realization')

a_hat_xarray = ensemble_mean ** 2 / (ensemble_var + REG)
b_hat_xarray = (ensemble_mean + REG) / (ensemble_var + REG)

In [None]:
(a_hat_xarray / b_hat_xarray**2).mean()

In [None]:
mean_tp = a_hat_xarray / b_hat_xarray

In [None]:
mean_tp.isel(biweekly_forecast=1).plot()

In [None]:
a_hat =torch.tensor(a_hat_xarray.data, requires_grad=True)
b_hat = torch.tensor(b_hat_xarray.data, requires_grad=True)

In [None]:
# Put the ensemble members on the leading axis, then floor every value at REG
# before the tensor is fed to the gamma log-likelihood.
realization_first = weekly_total.transpose('realization', 'biweekly_forecast', 'latitude', 'longitude')
weekly_total_torch = torch.from_numpy(realization_first.data).clamp(min=REG)

In [None]:
weekly_total_torch.shape

In [None]:
optimizer = torch.optim.Adam([a_hat, b_hat], lr=1e-2)

In [None]:
weekly_total_torch.shape

In [None]:
mean_lls = []

In [None]:
for epoch in range(50):
    # Both gamma parameters are floored at REG before building the distribution.
    clamped_a = a_hat.clamp(min=REG)
    clamped_b = b_hat.clamp(min=REG)
    estimated_gamma = torch.distributions.Gamma(clamped_a, clamped_b)

    # Minimize the negative mean log-likelihood of the observed totals.
    loss = -estimated_gamma.log_prob(weekly_total_torch).mean()

    # Keep the loss history for the convergence plot.
    mean_lls.append(loss.item())

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

In [None]:
plt.plot(mean_lls)

In [None]:
b_hat

In [None]:
plt.imshow((a_hat.detach() / b_hat.detach())[0])

In [None]:
a_hat.max()

In [None]:
b_hat.min()

In [None]:
(a_hat / b_hat).mean()

In [None]:
def fit_gamma_xarray(array: xr.DataArray, dim=None, regularization=1e-9, **kwargs):
    """Fit gamma shape/rate parameters along one dimension of a DataArray.

    A regularized method-of-moments estimate (``alpha = mean**2 / var``,
    ``beta = mean / var``) is used as the starting point of
    `fit_gamma_pytorch`, which refines it by maximizing the gamma
    log-likelihood.

    Args:
        array: Samples to fit. The values are passed to a gamma log-density
            as they are, so the caller is expected to make them strictly
            positive beforehand (e.g. by clipping).
        dim: Name of the dimension that holds the samples. It is reduced away
            for the initial estimate and moved to the front before the data
            is handed to PyTorch.
        regularization: Small constant added in the moment estimates and used
            as the lower bound of both parameters during the fit.
        **kwargs: Forwarded to `fit_gamma_pytorch` (``max_epochs``, ``lr``,
            ``tol``, ``patience``). ``return_losses`` is dropped because this
            function only returns the fitted parameters.

    Returns:
        The result of merging two DataArrays named ``<name>_alpha`` and
        ``<name>_beta``, where ``<name>`` is the name carried by the moment
        estimates.
    """
    # This wrapper unpacks exactly (alpha, beta). Forwarding return_losses=True
    # would make fit_gamma_pytorch return three values and fail the unpacking
    # only after the whole (expensive) fit has run.
    kwargs.pop('return_losses', None)

    # Use method of moments for initial estimate; reduce only once per moment.
    sample_mean = array.mean(dim=dim)
    sample_var = array.var(dim=dim)
    a_hat_xarray = sample_mean ** 2 / (sample_var + regularization)
    b_hat_xarray = (sample_mean + regularization) / (sample_var + regularization)

    # Sample dimension first, so the remaining axes line up with the estimates.
    transposed = array.transpose(dim, ...)

    alpha, beta = fit_gamma_pytorch(
        transposed.data,
        a_hat_xarray.data,
        b_hat_xarray.data,
        regularization=regularization,
        **kwargs,
    )

    # Reuse the coordinates of the moment estimates for the fitted values.
    alpha_xarray = xr.zeros_like(a_hat_xarray).rename(f'{a_hat_xarray.name}_alpha')
    beta_xarray = xr.zeros_like(b_hat_xarray).rename(f'{a_hat_xarray.name}_beta')

    alpha_xarray.data = alpha.numpy()
    beta_xarray.data = beta.numpy()

    return xr.merge([alpha_xarray, beta_xarray])
    

In [None]:
def fit_gamma_pytorch(data, a_hat, b_hat, regularization=1e-9, max_epochs=500, lr=1e-2, tol=1e-5, patience=5, return_losses=False):
    """Refine gamma parameters with Adam on the negative mean log-likelihood.

    Args:
        data: Samples, in any form accepted by ``torch.tensor``. They must
            broadcast against the parameters inside ``Gamma.log_prob``.
        a_hat: Initial concentration (alpha). Copied, so the caller's array is
            not modified by the optimizer.
        b_hat: Initial rate (beta). Copied as well.
        regularization: Lower bound applied to both parameters before the
            distribution is built and before they are returned.
        max_epochs: Maximum number of optimizer steps.
        lr: Adam learning rate.
        tol: An epoch counts towards `patience` when its loss is within `tol`
            of the best loss recorded so far.
        patience: Number of such epochs (cumulative, not necessarily
            consecutive) after which the fit stops.
        return_losses: When true, also return the list of recorded losses.

    Returns:
        ``(alpha, beta)`` as detached tensors clamped at `regularization`, or
        ``(alpha, beta, losses)`` when `return_losses` is true. The loss of the
        epoch that triggers the early stop is not recorded.
    """
    n_iter_waited = 0

    alpha = torch.tensor(a_hat, requires_grad=True)
    beta = torch.tensor(b_hat, requires_grad=True)
    data = torch.tensor(data)

    optimizer = torch.optim.Adam([alpha, beta], lr=lr)
    log_likelihoods = []  # holds the loss, i.e. the NEGATIVE mean log-likelihood
    best_loss = float('inf')  # running minimum of the recorded losses
    for _ in range(max_epochs):
        clamped_alpha = torch.clamp(alpha, min=regularization)
        clamped_beta = torch.clamp(beta, min=regularization)

        estimated_gamma = torch.distributions.Gamma(clamped_alpha, clamped_beta)

        loss = -estimated_gamma.log_prob(data).mean()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Plain Python float: keeps the stopping test out of autograd and
        # avoids mixing numpy scalars with tensors.
        current_loss = loss.item()

        # On the first epoch best_loss is inf, so the test cannot trigger.
        if abs(best_loss - current_loss) < tol:
            n_iter_waited += 1

            if n_iter_waited >= patience:
                break

        log_likelihoods.append(current_loss)
        best_loss = min(best_loss, current_loss)

    alpha = torch.clamp(alpha, min=regularization).detach()
    beta = torch.clamp(beta, min=regularization).detach()

    if return_losses:
        return alpha, beta, log_likelihoods
    else:
        return alpha, beta

In [None]:
# Fit the gamma parameters along the realization dimension on the last lead
# time, with the values floored at 1e-9 beforehand.
last_lead_tp = model.tp.isel(lead_time=-1).clip(min=1e-9)
gamma_params = fit_gamma_xarray(last_lead_tp, dim='realization', tol=1e-5)

In [None]:
gamma_params.tp_beta.isel(biweekly_forecast=1).plot()

# Zero-inflated Gaussian

In [None]:
class NormalExpMixture:
    """Two-component mixture of a normal and an exponential distribution.

    The density is ``(1 - mix) * Normal(loc, scale) + mix * Exponential(rate)``.

    Args:
        loc: Mean of the normal component.
        scale: Standard deviation of the normal component.
        rate: Rate of the exponential component.
        mix: Weight of the exponential component, in [0, 1]. May be a Python
            float or a tensor.
    """

    # Offset added to x before it reaches the exponential component, so that
    # x == 0 is evaluated slightly inside the exponential's support.
    _EPS = 1e-9

    def __init__(self, loc, scale, rate, mix):
        self.normal = torch.distributions.Normal(loc, scale)
        self.exponential = torch.distributions.Exponential(rate)
        self.mix = mix

    def log_prob(self, x):
        """Return the log of the mixture density at `x`.

        Computed as ``logaddexp(log(1 - mix) + log N(x), log(mix) + log Exp(x))``
        so that ``exp(log_prob(x))`` is the weighted sum of the component
        densities, consistent with `cdf`. A weighted sum of the component
        log-densities would not be the mixture log-density.
        """
        mix = torch.as_tensor(self.mix)

        # log(0) = -inf at mix == 0 or mix == 1 is absorbed by logaddexp.
        normal_term = torch.log1p(-mix) + self.normal.log_prob(x)
        exponential_term = torch.log(mix) + self.exponential.log_prob(x + self._EPS)

        return torch.logaddexp(normal_term, exponential_term)

    def cdf(self, x):
        """Return the mixture CDF at `x`: the weighted sum of the component CDFs."""
        return (1.0 - self.mix) * self.normal.cdf(x) + self.mix * self.exponential.cdf(x + self._EPS)

In [None]:
class GaussianMixtureModel:
    """Point mass at zero mixed with a normal distribution.

    With probability ``zero`` the value is exactly 0; otherwise it follows
    ``Normal(loc, scale)``. Only non-negative inputs are accepted.

    Args:
        loc: Mean of the normal component.
        scale: Standard deviation of the normal component.
        zero: Probability of the point mass at zero. A scalar or a tensor that
            broadcasts against the inputs.
    """

    def __init__(self, loc, scale, zero):
        self.normal = torch.distributions.Normal(loc, scale)
        self.zero = torch.as_tensor(zero)

    def __check_input(self, x):
        """Raise ValueError if any element of `x` is negative."""
        if (x < 0.0).any():
            raise ValueError('Values must be whithin support')

    def log_prob(self, x):
        """Return the log-likelihood of `x`.

        ``log(zero)`` where ``x == 0`` (log of the point mass) and
        ``log(1 - zero) + log N(x)`` elsewhere.

        Raises:
            ValueError: If any element of `x` is negative.
        """
        self.__check_input(x)

        normal_log_prob = self.normal.log_prob(x)
        # Match dtypes so torch.where does not have to promote.
        zero = self.zero.to(normal_log_prob.dtype)

        at_zero = torch.log(zero)
        elsewhere = torch.log1p(-zero) + normal_log_prob

        # torch.where broadcasts, so `zero` may be a scalar or a full tensor.
        return torch.where(x == 0.0, at_zero, elsewhere)

    def cdf(self, x):
        """Return the cumulative probability at `x`.

        ``zero`` where ``x == 0`` and ``zero + (1 - zero) * Phi(x)`` elsewhere,
        where ``Phi`` is the CDF of the normal component. Weighting the normal
        part by ``1 - zero`` keeps the result within [0, 1].

        Raises:
            ValueError: If any element of `x` is negative.
        """
        self.__check_input(x)

        normal_cdf = self.normal.cdf(x)
        zero = self.zero.to(normal_cdf.dtype)

        return torch.where(x == 0.0, zero, zero + (1.0 - zero) * normal_cdf)

In [None]:
class ZeroInflatedGaussian:
    """Normal distribution whose mass at or below zero is collapsed onto zero.

    For ``x > 0`` the likelihood is the normal density; at ``x == 0`` it is
    the probability that the normal falls at or below zero. Only non-negative
    inputs are accepted.

    Args:
        loc: Mean of the underlying normal.
        scale: Standard deviation of the underlying normal.
    """

    def __init__(self, loc, scale):
        self.normal = torch.distributions.Normal(loc, scale)

    def log_prob(self, x):
        """Return the log-likelihood of `x`.

        ``log N(x)`` where ``x > 0`` and ``log Phi(0)`` where ``x == 0``, with
        ``Phi`` the CDF of the underlying normal.

        Raises:
            ValueError: If any element of `x` is negative.
        """
        if (x < 0.0).any():
            raise ValueError('Values must be whithin support')

        log_density = self.normal.log_prob(x)

        # The mass at zero is a probability, so take its log to stay on the
        # same scale as the density branch. The floor avoids log(0) = -inf
        # when the CDF underflows.
        mass = self.normal.cdf(x)
        log_mass = torch.log(torch.clamp(mass, min=torch.finfo(mass.dtype).tiny))

        ll = torch.zeros_like(log_density)
        ll = torch.where(x > 0., log_density, ll)
        ll = torch.where(x == 0., log_mass, ll)

        return ll

    def cdf(self, x):
        """Return the CDF at `x`, which equals the underlying normal CDF for x >= 0.

        Raises:
            ValueError: If any element of `x` is negative.
        """
        if (x < 0.0).any():
            raise ValueError('Values must be whithin support')

        return self.normal.cdf(x)

In [None]:
class CensoredNormal:
    """Normal distribution renormalized over the non-negative half-line.

    The density is the normal density divided by the normal mass above zero,
    and the CDF is rescaled the same way. Only non-negative inputs are
    accepted.

    Args:
        loc: Mean of the underlying normal.
        scale: Standard deviation of the underlying normal.
    """

    # Added to the mass above zero so that neither the log nor the division
    # blows up when (numerically) all the normal mass lies below zero.
    _EPS = 1e-6

    def __init__(self, loc, scale):
        self.normal = torch.distributions.Normal(loc, scale)

    def __check_input(self, x):
        """Raise ValueError if any element of `x` is negative."""
        if (x < 0.0).any():
            raise ValueError('Values must be whithin support')

    def log_prob(self, x):
        """Return ``log N(x) - log(1 - Phi(0) + eps)``.

        Raises:
            ValueError: If any element of `x` is negative.
        """
        self.__check_input(x)

        normal_log_prob = self.normal.log_prob(x)
        zero_cdf = self.normal.cdf(torch.zeros_like(x))
        denominator = torch.log(1.0 - zero_cdf + self._EPS)

        return normal_log_prob - denominator

    def cdf(self, x):
        """Return ``(Phi(x) - Phi(0)) / (1 - Phi(0) + eps)``.

        The denominator carries the same epsilon as `log_prob`, so the result
        stays finite (0 instead of NaN) when ``Phi(0)`` rounds to 1.

        Raises:
            ValueError: If any element of `x` is negative.
        """
        self.__check_input(x)
        zero_cdf = self.normal.cdf(torch.zeros_like(x))

        return (self.normal.cdf(x) - zero_cdf) / (1.0 - zero_cdf + self._EPS)

In [None]:

def fit_zero_inflated_normal_pytorch(data, mu_hat, theta_hat, regularization=1e-9, max_epochs=1, lr=1e-4, tol=1e-5, patience=5, log_likelihoods=None):
    """Fit the location and scale of a `CensoredNormal` with Adam.

    The objective is the negative mean log-likelihood of `data` under
    ``CensoredNormal(mu, theta)`` plus the mean of ``mu ** 2`` as a penalty.

    NOTE(review): a later cell in this notebook defines another function with
    the same name and a different signature, which shadows this one once run.

    Args:
        data: Samples, in any form accepted by ``torch.tensor``.
        mu_hat: Initial location. Copied.
        theta_hat: Initial scale. Copied.
        regularization: Lower bound applied to the scale.
        max_epochs: Maximum number of optimizer steps.
        lr: Adam learning rate.
        tol: An epoch counts towards `patience` when its loss is within `tol`
            of the best loss recorded so far.
        patience: Number of such epochs (cumulative) after which the fit stops.
        log_likelihoods: Optional list that receives the loss of every
            recorded epoch (appended in place). A fresh list is used when
            omitted; the default no longer depends on a global ``losses``
            existing when the function is defined.

    Returns:
        ``(mu, theta)`` as detached tensors, with theta clamped at
        `regularization`.
    """
    if log_likelihoods is None:
        log_likelihoods = []

    n_iter_waited = 0

    mu = torch.tensor(mu_hat, requires_grad=True)
    theta = torch.tensor(theta_hat, requires_grad=True)

    data = torch.tensor(data)

    optimizer = torch.optim.Adam([mu, theta], lr=lr)

    # Entries already present in the caller's list take part in the stopping
    # test, as they did when the minimum was recomputed from the list.
    best_loss = min(log_likelihoods, default=float('inf'))
    for _ in range(max_epochs):
        clamped_theta = torch.clamp(theta, min=regularization)

        estimated_distribution = CensoredNormal(mu, clamped_theta)

        loss = -estimated_distribution.log_prob(data).mean() + torch.square(mu).mean()
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        # Plain Python float: keeps the stopping test out of autograd.
        current_loss = loss.item()

        if abs(best_loss - current_loss) < tol:
            n_iter_waited += 1

            if n_iter_waited >= patience:
                break

        log_likelihoods.append(current_loss)
        best_loss = min(best_loss, current_loss)

    mu, theta = mu.detach(), theta.detach()

    return mu, torch.clamp(theta, min=regularization)

In [None]:
# Initial location: mean over both the ensemble members and the forecast times.
mu_hat = weekly_total.mean(dim=['realization', 'forecast_time'])
# Initial scale: std_estimator over the same dimensions, shifted up by 1.0
# (presumably to start the fit away from a near-zero scale — TODO confirm).
theta_hat = std_estimator(weekly_total, dim=['realization', 'forecast_time']) + 1.0

# Move the two sample dimensions to the front and take the underlying array.
tp_data = weekly_total.transpose('realization', 'forecast_time', ...).data

In [None]:
losses = []
mu, theta = fit_zero_inflated_normal_pytorch(tp_data, mu_hat.data, theta_hat.data, log_likelihoods=losses, lr=1e-1, tol=1e-4, max_epochs=1000, regularization=1e-3)

In [None]:
theta.shape

In [None]:
sns.histplot(data=theta[mu < 0].detach().numpy().flatten())

In [None]:
# Boolean-mask indexing (`theta[mu < 0]`) flattens the result to 1-D, so taking
# `[0]` of it yields a single scalar that imshow cannot draw. Keep the array
# shape instead: blank out (NaN) the points where mu >= 0 and show the first
# slice along the leading dimension.
# NOTE(review): assumes `mu` and `theta` have the same shape — confirm.
theta_where_mu_negative = torch.where(mu < 0, theta, torch.full_like(theta, float('nan')))
plt.imshow(theta_where_mu_negative[0])

In [None]:
plt.plot(losses)

In [None]:
sns.histplot(data=mu.detach().numpy().flatten())

In [None]:
sns.histplot(data=mu[(mu > 0) & (mu < 1)].detach().numpy().flatten())

In [None]:
sns.histplot(data=mu[mu < 0].detach().numpy().flatten())

In [None]:
sns.histplot(data=theta.detach().numpy().flatten())

In [None]:
# NOTE(review): no earlier cell visible in this notebook assigns `rate` or `mix`
# at top level, so this cell (and the later ones that use them) presumably relies
# on leftover kernel state and would raise NameError on a fresh run — confirm.
rate[mix > 0.95].histc()

In [None]:
sns.histplot(data=rate[mix > 0.8].detach().numpy().flatten(), bins=20)

In [None]:
sns.histplot(data=mu[mix < 0.1].detach().numpy().flatten())

In [None]:
plt.imshow(mu[0].detach().numpy())

In [None]:
rate[0].max()

In [None]:
weekly_total.isel(biweekly_forecast=0, realization=1).plot()

In [None]:
losses

In [None]:
distr = NormalExpMixture(4.0, 3.0, 1.0, 0.5)

In [None]:
plt.plot(torch.exp(distr.log_prob(torch.arange(0.0, 10.0, step=1e-2))).numpy())

In [None]:
def fit_zero_inflated_normal_pytorch(data, mu_hat, theta_hat, rate_hat, regularization=1e-9, max_epochs=1, lr=1e-2, tol=1e-5, patience=5, return_losses=False):
    """Fit the parameters of a `NormalExpMixture` with Adam.

    The objective is the negative mean of ``NormalExpMixture.log_prob(data)``.
    The mixing weight is optimized as an unconstrained logit (initialized at
    0, i.e. a weight of 0.5) and mapped to (0, 1) with a sigmoid.

    NOTE(review): this reuses the name of a function defined in an earlier
    cell with a different signature, and shadows it once this cell is run.

    Args:
        data: Samples, in any form accepted by ``torch.tensor``.
        mu_hat: Initial location of the normal component. Copied.
        theta_hat: Initial scale of the normal component. Copied.
        rate_hat: Initial rate of the exponential component. Copied; it also
            fixes the shape of the mixing weight.
        regularization: Lower bound applied to the scale and the rate.
        max_epochs: Maximum number of optimizer steps.
        lr: Adam learning rate.
        tol: An epoch counts towards `patience` when its loss is within `tol`
            of the best loss recorded so far.
        patience: Number of such epochs (cumulative) after which the fit stops.
        return_losses: When true, also return the list of recorded losses.

    Returns:
        ``(mu, theta, rate, mix)`` — all detached from the autograd graph,
        with theta and rate clamped at `regularization` and mix passed through
        the sigmoid — followed by the loss list when `return_losses` is true.
    """
    n_iter_waited = 0

    mu = torch.tensor(mu_hat, requires_grad=True)
    theta = torch.tensor(theta_hat, requires_grad=True)
    rate = torch.tensor(rate_hat, requires_grad=True)
    mix = torch.zeros_like(rate, requires_grad=True)

    data = torch.tensor(data)

    optimizer = torch.optim.Adam([mu, theta, rate, mix], lr=lr)
    log_likelihoods = []  # holds the loss, i.e. the NEGATIVE mean log_prob
    best_loss = float('inf')  # running minimum of the recorded losses
    for _ in range(max_epochs):
        clamped_theta = torch.clamp(theta, min=regularization)
        clamped_rate = torch.clamp(rate, min=regularization)
        clamped_mix = torch.sigmoid(mix)

        estimated_distribution = NormalExpMixture(mu, clamped_theta, clamped_rate, clamped_mix)

        loss = -estimated_distribution.log_prob(data).mean()

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        # Plain Python float: keeps the stopping test out of autograd.
        current_loss = loss.item()

        # On the first epoch best_loss is inf, so the test cannot trigger.
        if abs(best_loss - current_loss) < tol:
            n_iter_waited += 1

            if n_iter_waited >= patience:
                break

        log_likelihoods.append(current_loss)
        best_loss = min(best_loss, current_loss)

    # Detach every output, not only mu and theta, so the caller receives plain
    # values that no longer hold on to the optimization graph.
    fitted_mu = mu.detach()
    fitted_theta = torch.clamp(theta.detach(), min=regularization)
    fitted_rate = torch.clamp(rate.detach(), min=regularization)
    fitted_mix = torch.sigmoid(mix.detach())

    if return_losses:
        return fitted_mu, fitted_theta, fitted_rate, fitted_mix, log_likelihoods
    else:
        return fitted_mu, fitted_theta, fitted_rate, fitted_mix