In [None]:
import math
from enum import Enum, auto

import torch
import torch.nn as nn
import torch.nn.functional as F


class NoiseType(Enum):
    """How the per-component standard deviations sigma of the mixture are parameterised.

    DIAGONAL: one sigma per component and output dimension, Sigma_k = diag(sigma^(k))^2.
    ISOTROPIC: one sigma per component, shared by its output dimensions, Sigma_k = sigma_k^2 I.
    ISOTROPIC_ACROSS_CLUSTERS: a single sigma shared by every component and dimension.
    FIXED: sigma is the constant ``fixed_noise_level`` and is not learned.
    """
    DIAGONAL = auto()
    ISOTROPIC = auto()
    ISOTROPIC_ACROSS_CLUSTERS = auto()
    FIXED = auto()


class MixtureDensityNetwork(nn.Module):
    """
    Mixture density network.

    [ Bishop, 1994 ]

    Models p(y | x) as a mixture of ``n_components`` Gaussians with diagonal covariance.
    The mixture weights come from ``pi_network``; the means and standard deviations come
    from ``normal_network``. Each is a 2-hidden-layer ReLU MLP on the covariates.

    Parameters
    ----------
    dim_in: int; dimensionality of the covariates
    dim_out: int; dimensionality of the response variable
    n_components: int; number of components in the mixture model
    hidden_dim: int; width of both hidden layers of each MLP
    noise_type: NoiseType; how the standard deviations are parameterised
    fixed_noise_level: float or None; the constant sigma, given iff noise_type is NoiseType.FIXED
    """
    def __init__(self, dim_in, dim_out, n_components, hidden_dim, noise_type=NoiseType.DIAGONAL, fixed_noise_level=None):
        super().__init__()
        assert (fixed_noise_level is not None) == (noise_type is NoiseType.FIXED), (
            "fixed_noise_level must be given exactly when noise_type is NoiseType.FIXED"
        )
        # number of raw network outputs used for the standard deviations
        num_sigma_channels = {
            NoiseType.DIAGONAL: dim_out * n_components,
            NoiseType.ISOTROPIC: n_components,
            NoiseType.ISOTROPIC_ACROSS_CLUSTERS: 1,
            NoiseType.FIXED: 0,
        }[noise_type]
        self.dim_in, self.dim_out, self.n_components = dim_in, dim_out, n_components
        self.noise_type, self.fixed_noise_level = noise_type, fixed_noise_level
        self.pi_network = nn.Sequential(
            nn.Linear(dim_in, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_components),
        )
        self.normal_network = nn.Sequential(
            nn.Linear(dim_in, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim_out * n_components + num_sigma_channels)
        )

    def forward(self, x, eps=1e-6):
        """Compute the mixture parameters for covariates ``x``.

        Returns
        -------
        log_pi: (bsz, n_components) log mixture weights
        mu: (bsz, n_components, dim_out) component means
        sigma: (bsz, n_components, dim_out) component standard deviations, each >= eps
            (except for NoiseType.FIXED, where it equals fixed_noise_level)
        """
        log_pi = torch.log_softmax(self.pi_network(x), dim=-1)
        normal_params = self.normal_network(x)
        n_mean = self.dim_out * self.n_components
        mu = normal_params[..., :n_mean]
        raw_sigma = normal_params[..., n_mean:]
        if self.noise_type is NoiseType.DIAGONAL:
            # eps is added after exp so sigma stays strictly positive even if exp underflows
            sigma = torch.exp(raw_sigma) + eps
        elif self.noise_type is NoiseType.ISOTROPIC:
            # one value per component -> repeat it for each of that component's dim_out
            # coordinates, laid out so the reshape below assigns it to the right component
            sigma = (torch.exp(raw_sigma) + eps).repeat_interleave(self.dim_out, dim=-1)
        elif self.noise_type is NoiseType.ISOTROPIC_ACROSS_CLUSTERS:
            sigma = (torch.exp(raw_sigma) + eps).expand_as(mu)
        else:  # NoiseType.FIXED
            sigma = torch.full_like(mu, fill_value=self.fixed_noise_level)
        mu = mu.reshape(-1, self.n_components, self.dim_out)
        sigma = sigma.reshape(-1, self.n_components, self.dim_out)
        return log_pi, mu, sigma

    def _log_density(self, x, y):
        """log p(y | x) for each row, including the Gaussian normalising constant.

        y is expected as (bsz, dim_out); returns a (bsz,) tensor.
        """
        log_pi, mu, sigma = self.forward(x)
        z_score = (y.unsqueeze(1) - mu) / sigma
        normal_loglik = (
            -0.5 * torch.einsum("bij,bij->bi", z_score, z_score)
            - torch.sum(torch.log(sigma), dim=-1)
            - 0.5 * self.dim_out * math.log(2 * math.pi)
        )
        # log sum_k pi_k N(y | mu_k, sigma_k), computed stably in log space
        return torch.logsumexp(log_pi + normal_loglik, dim=-1)

    def loss(self, x, y):
        """Per-row negative log-likelihood -log p(y | x), shape (bsz,)."""
        return -self._log_density(x, y)

    def sample(self, x):
        """Draw one y ~ p(. | x) per row of ``x``; returns (bsz, dim_out)."""
        log_pi, mu, sigma = self.forward(x)
        cum_pi = torch.cumsum(torch.exp(log_pi), dim=-1)
        rvs = torch.rand(len(x), 1).to(x)
        # inverse-CDF component choice; float rounding can leave cum_pi[..., -1] slightly
        # below 1, so clamp to keep a draw above it on the last component instead of
        # indexing past the end
        rand_pi = torch.searchsorted(cum_pi, rvs).clamp(max=self.n_components - 1)
        rand_normal = torch.randn_like(mu) * sigma + mu
        samples = torch.take_along_dim(rand_normal, indices=rand_pi.unsqueeze(-1), dim=1).squeeze(dim=1)
        return samples

    def dens(self, x, y):
        """Mixture density p(y | x) for each row, shape (bsz,)."""
        return torch.exp(self._log_density(x, y))

In [None]:
import pandas as pd

## load in the training data (space-separated text files, no header row)
DATA_DIR = "/content/sample_data"  # Colab sample-data folder; change when running elsewhere
covs = pd.read_csv(f"{DATA_DIR}/covs.txt", sep=" ", header=None)
resp = pd.read_csv(f"{DATA_DIR}/resp.txt", sep=" ", header=None)

# each covariate row must have a matching response row
assert len(covs) == len(resp), "covs.txt and resp.txt have different numbers of rows"
covs.shape, resp.shape

In [None]:
import logging

import numpy as np
import torch.optim as optim

# training configuration
SEED = 0          # fixes weight initialisation so the run is reproducible
N_ITERS = 2000    # full-batch Adam steps (also the cosine-annealing period)
LOG_EVERY = 100

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# basicConfig does nothing if the kernel already configured the root logger, so set the
# level on this logger directly to make sure the INFO progress lines are emitted
logger.setLevel(logging.INFO)

torch.manual_seed(SEED)

x = torch.Tensor(covs.values)
y = torch.Tensor(resp.values)

# input/output sizes come from the data (previously hard-coded as 6 and 1)
model = MixtureDensityNetwork(x.shape[1], y.shape[1], n_components=4, hidden_dim=30, noise_type=NoiseType.DIAGONAL)
optimizer = optim.Adam(model.parameters(), lr=0.005)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, N_ITERS)

for i in range(N_ITERS):
    optimizer.zero_grad()
    loss = model.loss(x, y).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()
    if i % LOG_EVERY == 0:
        # reuse the loss already computed for this step; no extra forward pass
        logger.info(f"Iter: {i}\tLoss: {loss.item():.2f}")



In [None]:
## calibration step: conformal scores from the estimated HPD cutoffs
ALPHA = 0.1   # level used both for the HPD cutoff and for the conformal quantile
N_MC = 2530   # Monte Carlo draws per point for the HPD cutoff; arbitrary, larger = more accurate

cals_x = pd.read_csv('/content/sample_data/cals_x.txt', sep = ' ', header = None)
cals_y = pd.read_csv('/content/sample_data/cals_y.txt', sep = ' ', header = None)

cals_x = torch.Tensor(cals_x.values)
cals_y = torch.Tensor(cals_y.values)

scores = []
with torch.no_grad():  # inference only; no autograd graph needed
    f_hat = model.dens(cals_x, cals_y)  ## the estimated values of \hat{f}(y_i | x_i)
    for i in range(len(cals_x)):
        x_rep = cals_x[i].expand(N_MC, -1)  # N_MC copies of the i-th covariate row
        y_hat = model.sample(x_rep)
        preds_cal = model.dens(x_rep, y_hat)
        # lower ALPHA-quantile of the density at sampled y's = HPD cutoff at x_i
        cutoff = torch.quantile(preds_cal, ALPHA, interpolation='lower')
        scores.append(f_hat[i] / cutoff)
        if i % 500 == 0:
            print(i)

score = torch.stack(scores).unsqueeze(1)  # (n_cal, 1)
qhat = torch.quantile(score, ALPHA, interpolation='lower')  ## final conformal adjustment


In [None]:
## final predictions
out_x = pd.read_csv('/content/sample_data/out_x.txt', sep = ' ', header = None)
out_y = pd.read_csv('/content/sample_data/out_y.txt', sep = ' ', header = None)

out_x = torch.Tensor(out_x.values)
out_y = torch.Tensor(out_y.values)


def _hpd_region_length(y_samples, dens_samples, cutoff):
    """Estimate the total length of {y : f(y | x) >= cutoff} from 1-D samples.

    Samples are sorted by y. In-region samples separated by at least one out-of-region
    sample belong to different intervals, so the gaps between them are subtracted from
    the overall span. A multimodal region is thus measured as the sum of its disjoint
    pieces rather than its convex hull. Returns 0 if no sample is in the region.
    """
    order = torch.argsort(y_samples)
    inside = torch.where(dens_samples[order] >= cutoff)[0]
    if len(inside) == 0:
        return y_samples.new_zeros(())
    vals = y_samples[order][inside]      # in-region y's, ascending
    breaks = torch.diff(inside) > 1      # True where an out-of-region sample lies between
    gap_total = torch.diff(vals)[breaks].sum()
    return (vals[-1] - vals[0]) - gap_total


cov_rows, len_rows = [], []
with torch.no_grad():  # inference only; no autograd graph needed
    fhat_out = model.dens(out_x, out_y)
    for i in range(len(out_x)):
        x_rep = out_x[i].expand(2530, -1)  # same Monte Carlo size as the calibration step
        y_hat = model.sample(x_rep)
        preds_cal = model.dens(x_rep, y_hat)
        # HPD cutoff at the calibration level (0.1), rescaled by the conformal qhat
        cutoff = torch.quantile(preds_cal, 0.1, interpolation='lower') * qhat
        # y_hat[:, 0]: interval length is only defined for a 1-D response
        len_rows.append(_hpd_region_length(y_hat[:, 0], preds_cal, cutoff))
        cov_rows.append((fhat_out[i] >= cutoff).float())
        if i % 500 == 0:
            print(i)

temp_len = torch.stack(len_rows).reshape(-1, 1)  # (n_out, 1) region lengths
temp_cov = torch.stack(cov_rows).reshape(-1, 1)  # (n_out, 1) 1.0 if y_i is covered

def _mean_and_se(values):
    """Mean of ``values`` and its standard error std / sqrt(n), as Python floats."""
    flat = values.float().flatten()
    return flat.mean().item(), (flat.std() / len(flat) ** 0.5).item()


# split by the median of covariate column 0; "bright" = below the median
# NOTE(review): assumes column 0 is a magnitude (smaller = brighter) — confirm
median = torch.median(out_x[:, 0])
bright = torch.where(out_x[:, 0] < median)[0]
faint = torch.where(out_x[:, 0] >= median)[0]

summary_rows = {}
for group, idx in {"all": slice(None), "bright": bright, "faint": faint}.items():
    cov_mean, cov_se = _mean_and_se(temp_cov[idx])
    len_mean, len_se = _mean_and_se(temp_len[idx])
    summary_rows[group] = {
        "n": len(temp_cov[idx]),
        "coverage": cov_mean,
        "coverage_se": cov_se,
        "length": len_mean,
        "length_se": len_se,
    }
# one table instead of bare expressions (a cell only displays its last expression)
summary = pd.DataFrame.from_dict(summary_rows, orient="index")
summary

## getting model parameters now
#from torch.nn.utils import parameters_to_vector as p2v
#p2v(model.parameters()).numel()
#params = list(model.parameters())
#print(len(params))
