In [1]:
from itertools import cycle
import numpy as np
import torch

import matplotlib.pyplot as plt
import plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from tqdm.notebook import tqdm

from CP import *
from INIT import *
from Logger import *
from LoadSynthetic import *
from Load import init
from Load import *

# Notebook rendering setup: render matplotlib text through LaTeX and
# initialise plotly's offline notebook mode.
plt.rcParams["text.usetex"] = True
plotly.offline.init_notebook_mode(connected=True)

In [2]:
# Seeds for the repeated experiments: the full pool of 50 is kept for reference,
# but only the first ten are actually run.
seeds = [
    13, 2, 47, 1, 15, 31, 89, 666, 3, 43,
    5197, 558213, 4236, 410, 74888, 1563, 1794, 711489, 156874, 123,
    744741, 442262, 53750, 550509, 751836, 73427, 316551, 321489, 264450, 691340,
    256212, 248195, 322953, 469088, 515349, 717046, 904096, 600809, 566875, 335409,
    128274, 11299, 40740, 135231, 78276, 268373, 148066, 569507, 212406, 584182,
][:10]

alpha = 0.1    # miscoverage level; prediction intervals target 1 - alpha coverage
var_bins = 3   # number of predicted-variance bins used for conditional evaluation

plot_clusters = False

In [3]:
def _lookup_split(data, X):
    """Return the stored ground truth ``(mean, var, split)`` for the feature matrix ``X``.

    ``X`` must be ``data["X_val"]`` or ``data["X_test"]`` (the same object, or a tensor
    with identical shape and values). ``split`` is ``"val"`` or ``"test"`` accordingly,
    and the validation split is checked first.

    Raises
    ------
    ValueError
        If ``X`` matches neither stored feature matrix. Previously a partial row match
        silently returned fewer rows than ``X`` has.
    """
    for split in ("val", "test"):
        X_ref = data["X_" + split]
        # The identity check is cheap and also works for NaN-containing features,
        # which torch.equal treats as unequal.
        if X is X_ref or torch.equal(X_ref, X):
            return data["mean_" + split], data["var_" + split], split
    raise ValueError("X must be data['X_val'] or data['X_test']; these models only replay stored ground truth.")


def _gaussian_output(mean, var, std, alpha):
    """Package a Gaussian prediction the way every model in this cell returns it.

    Returns ``torch.stack([mean, sqrt(var) if std else var], dim=-1)``. If ``alpha`` is
    truthy, it also returns ``torch.stack([lower, upper], dim=-1)`` for the central
    ``1 - alpha`` normal interval ``mean -/+ z * sqrt(var)``.

    Note that ``alpha=0`` is falsy, so it behaves like ``None`` (no interval is returned).
    """
    sqrt = torch.sqrt(var)
    stacked = torch.stack([mean, sqrt if std else var], dim = -1)
    if not alpha:
        return stacked
    z = st.norm.ppf((1 - alpha) + (alpha / 2))
    return stacked, torch.stack([mean - z * sqrt, mean + z * sqrt], dim = -1)


class Oracle(UQModel):
    """Toy UQ model that reports the stored ground-truth mean and variance unchanged."""

    def __init__(self, data):
        super().__init__(r"$\text{Oracle}$")
        self.data = data

    def predict(self, X, std = False, alpha = None):
        """Return the stored (mean, var-or-std) for X, plus the 1 - alpha interval if alpha is truthy."""
        mean, var, _ = _lookup_split(self.data, X)
        return _gaussian_output(mean, var, std, alpha)

class SigmaShifter(UQModel):
    """Ground-truth mean with a randomly perturbed (but non-negative) variance.

    The perturbation is Gaussian noise with standard deviation ``param``. Fixed per-split
    seeds (5 for validation, 17 for test) make it reproducible. ``param == 0`` is labelled
    as the oracle.
    """

    def __init__(self, data, param = 0):
        super().__init__(rf"$\sigma\text{{-shifted ({param})}}$" if param != 0 else r"$\text{Oracle}$")
        self.data = data
        self.param = param

    def predict(self, X, std = False, alpha = None):
        """Return the perturbed (mean, var-or-std) for X, plus the 1 - alpha interval if alpha is truthy."""
        mean, var, split = _lookup_split(self.data, X)
        # A private legacy RandomState yields exactly the numbers that
        # np.random.seed(seed); np.random.normal(...) would, without resetting the
        # global NumPy RNG every time predict() is called.
        rng = np.random.RandomState(5 if split == "val" else 17)
        rand = torch.as_tensor(rng.normal(0, self.param, var.shape[0]), dtype = var.dtype, device = var.device)
        # The noise is reflected: var + rand if that is >= 0, otherwise var - rand,
        # so the result is never negative.
        var = var + torch.sign(var + rand) * rand
        return _gaussian_output(mean, var, std, alpha)

class Scaler(UQModel):
    """Ground-truth mean with the variance multiplied by ``param``."""

    def __init__(self, data, param = 1):
        super().__init__(rf"$\sigma\text{{-scaled ({param})}}$")
        self.data = data
        self.param = param

    def predict(self, X, std = False, alpha = None):
        """Return (mean, param * var or its sqrt) for X, plus the 1 - alpha interval if alpha is truthy."""
        mean, var, _ = _lookup_split(self.data, X)
        return _gaussian_output(mean, var * self.param, std, alpha)

class Shifter(UQModel):
    """Ground-truth variance with the mean shifted.

    The shift is ``param`` itself, or ``param * sqrt(var)`` when ``var_scale`` is set.
    """

    def __init__(self, data, param = 1, var_scale = False):
        super().__init__(rf"$\mu\text{{-shifted ({param})}}$" if not var_scale else r"$\mu\text{-shifted (}\sigma\text{)}$")
        self.data = data
        self.param = param
        self.var_scale = var_scale

    def predict(self, X, std = False, alpha = None):
        """Return (shifted mean, var-or-std) for X, plus the 1 - alpha interval if alpha is truthy."""
        mean, var, _ = _lookup_split(self.data, X)
        shift = self.param * torch.sqrt(var) if self.var_scale else self.param
        return _gaussian_output(mean + shift, var, std, alpha)
        

In [4]:
def routine(datasource, feature_choice, seed, folder):
    """Conformalize every toy UQ model with three nonconformity scores on one dataset/seed.

    For each model, a split-conformal critical value is computed once on the whole
    validation split for each score and then applied to the test split:

    * ``PointCP``: score ``|y - mu|``; interval ``mu -/+ q``.
    * ``IntCP``: score ``max(lower - y, y - upper)`` w.r.t. the model's own ``1 - alpha``
      interval; that interval is widened by ``q`` on both sides.
    * ``NCP``: score ``|y - mu| / sqrt(var)``; interval ``mu -/+ q * sqrt(var)``.

    Test points are also split into ``var_bins`` bins by predicted variance. The bin
    edges are empirical quantiles of the validation predicted variances. Per-bin R^2,
    coverage and width are reported as conditional results.

    Parameters
    ----------
    datasource : str
        Dataset name passed to ``init``. It is prefixed with ``"synth"`` when
        ``feature_choice`` is truthy.
    feature_choice : str or None
        Only its truthiness is used here.
    seed : int
        Forwarded to ``init``.
    folder : str
        Unused; kept for call-site compatibility.

    Returns
    -------
    marginal : tuple
        ``(r2s, coverages, widths, keys)``: 1-D arrays with one entry per key
        ``"<model>#<score>"``.
    conditional : tuple
        ``(r2s, coverages, widths, keys)``: 2-D arrays of shape ``(len(keys), var_bins)``,
        with keys ``"<model>#m<score>"``.
    """
    import math

    data = init(("synth" + datasource) if feature_choice else datasource, seed = seed, to_torch = True)

    candidates = [Oracle(data)]
    candidates += [SigmaShifter(data, param = param) for param in (1e-2, .1, 1)]
    candidates += [Scaler(data, 5), Shifter(data), Shifter(data, var_scale = True)]
    models = [(m.name, m) for m in candidates]

    r2s_dict, covs_dict, widths_dict = {}, {}, {}
    conditional_r2s, conditional_covs, conditional_widths = {}, {}, {}

    def CP(scores):
        """Split-conformal quantile: the ceil((n + 1)(1 - alpha))-th smallest score, clamped to [1, n]."""
        n = scores.shape[0]
        # Use (n + 1) * (1 - alpha) directly, with a tiny tolerance, so that float
        # round-off cannot push the rank one past the intended integer.
        rank = min(max(math.ceil((n + 1) * (1 - alpha) - 1e-9), 1), n)
        return torch.sort(scores).values[rank - 1].item()

    def summarize(covered, widths):
        """Return (mean of the boolean coverage tensor, mean width); a scalar width is passed through."""
        coverage = np.mean(covered.numpy())
        width = np.mean(widths.numpy()) if torch.is_tensor(widths) else widths
        return coverage, width

    def save(name, r2, covs, widths):
        """Record marginal R^2, coverage and width under `name`."""
        r2s_dict[name] = r2
        covs_dict[name], widths_dict[name] = summarize(covs, widths)

    def save_conditional(name, r2, covs, widths):
        """Append one variance bin's R^2, coverage and width under `name`."""
        coverage, width = summarize(covs, widths)
        conditional_r2s.setdefault(name, []).append(r2)
        conditional_covs.setdefault(name, []).append(coverage)
        conditional_widths.setdefault(name, []).append(width)

    with torch.no_grad():
        y_val, y_test = data["y_val"], data["y_test"]
        for name, model in models:
            val_preds, val_pi = model.predict(data["X_val"], alpha = alpha)
            preds, pi = model.predict(data["X_test"], alpha = alpha)
            mu, sigma2 = preds[:, 0], preds[:, 1]

            # Bin edges: (var_bins - 1) empirical quantiles of the validation predicted
            # variances, padded with -/+1e10 sentinels.
            val_var_sorted = np.sort(val_preds[:, 1])
            step = len(val_var_sorted) // var_bins
            splits = [-1e10] + [val_var_sorted[(k + 1) * step] for k in range(var_bins - 1)] + [1e10]
            # Coinciding inner edges (e.g. a constant predicted variance) make the binning
            # degenerate; in that case every bin uses all test points.
            override = any(splits[k + 1] == splits[k + 2] for k in range(var_bins - 1))

            r2 = r2_score(y_test, mu)

            # Critical values are calibrated once on the full validation set, not per bin.
            mse_crit = CP(torch.abs(val_preds[:, 0] - y_val))
            int_crit = CP(torch.maximum(val_pi[:, 0] - y_val, y_val - val_pi[:, 1]))
            nmse_crit = CP(torch.abs(val_preds[:, 0] - y_val) / torch.sqrt(val_preds[:, 1])) * torch.sqrt(sigma2)

            # Marginal results do not depend on the bin, so record them once per model.
            save(name + "#PointCP", r2, (y_test >= mu - mse_crit) & (y_test <= mu + mse_crit), 2 * mse_crit)
            save(name + "#IntCP", r2, (y_test >= pi[:, 0] - int_crit) & (y_test <= pi[:, 1] + int_crit), pi[:, 1] - pi[:, 0] + 2 * int_crit)
            save(name + "#NCP", r2, (y_test >= mu - nmse_crit) & (y_test <= mu + nmse_crit), 2 * nmse_crit)

            for k in range(var_bins):
                if override:
                    idx = slice(None)
                else:
                    # Bins are closed on both sides: a test point exactly on an inner
                    # edge is counted in both neighbouring bins.
                    idx = (sigma2 >= splits[k]) & (sigma2 <= splits[k + 1])

                y_bin, mu_bin = y_test[idx], mu[idx]
                lo_bin, hi_bin = pi[idx, 0], pi[idx, 1]
                ncrit_bin = nmse_crit[idx]
                r2_bin = r2_score(y_bin, mu_bin)

                save_conditional(name + "#mPointCP", r2_bin, (y_bin >= mu_bin - mse_crit) & (y_bin <= mu_bin + mse_crit), 2 * mse_crit)
                save_conditional(name + "#mIntCP", r2_bin, (y_bin >= lo_bin - int_crit) & (y_bin <= hi_bin + int_crit), hi_bin - lo_bin + 2 * int_crit)
                save_conditional(name + "#mNCP", r2_bin, (y_bin >= mu_bin - ncrit_bin) & (y_bin <= mu_bin + ncrit_bin), 2 * ncrit_bin)

    keys = list(covs_dict)
    marginal = (
        np.stack([r2s_dict[key] for key in keys], axis = -1),
        np.stack([covs_dict[key] for key in keys], axis = -1),
        np.stack([widths_dict[key] for key in keys], axis = -1),
        keys,
    )

    cond_keys = list(conditional_covs)
    conditional = (
        np.array([conditional_r2s[key] for key in cond_keys], dtype = float),
        np.array([conditional_covs[key] for key in cond_keys], dtype = float),
        np.array([conditional_widths[key] for key in cond_keys], dtype = float),
        cond_keys,
    )

    return marginal, conditional

In [5]:
import os

folder = "./PLOTS/1000/"
# Make sure the output directory exists before any write_image call below.
os.makedirs(folder, exist_ok = True)

# One entry per experiment: (data-generating source, feature distribution, plot title).
# A feature distribution of None means the dataset is addressed by its bare source name;
# routine() then calls init() without the "synth" prefix.
experiments = [
    ("constant", "bimodal", "Homoskedastic"),
    ("cm", "normal", "Type 1"),
    ("parametric", "normal", "Type 2"),
    ("dim1", "uniform", "Type 3"),
    ("bimodal", None, "Type 4"),
]
# One subplot title per predicted-variance bin (written for var_bins == 3).
bin_titles = ["Low variance", "Medium variance", "High variance"]

for combi in experiments:

    title = combi[2]
    var_control = .1

    # Assigned on every iteration so the None case cannot inherit the previous
    # experiment's feature distribution through leftover loop state.
    source, feature_choice = combi[0], combi[1]
    if feature_choice is None:
        datasource = source
    else:
        datasource = source + "_" + feature_choice + "_" + ("high_coupling" if var_control >= 1 else "low_coupling")

    r2s_marginal, coverages_marginal, widths_marginal = [], [], []
    columns_marginal = None
    r2s_conditional, coverages_conditional, widths_conditional = [], [], []
    columns_conditional = []

    for s in tqdm(seeds):
        result, result_conditional = routine(datasource, feature_choice, s, folder)

        r2s_marginal.append(result[0])
        coverages_marginal.append(result[1])
        widths_marginal.append(result[2])
        columns_marginal = result[3]

        r2s_conditional.append(result_conditional[0])
        coverages_conditional.append(result_conditional[1])
        widths_conditional.append(result_conditional[2])
        columns_conditional = result_conditional[3]

    # Stack per-seed results along a new leading seed axis.
    r2s_marginal = np.stack(r2s_marginal, axis = 0)
    coverages_marginal = np.stack(coverages_marginal, axis = 0)
    widths_marginal = np.stack(widths_marginal, axis = 0)

    r2s_conditional = np.stack(r2s_conditional, axis = 0)
    coverages_conditional = np.stack(coverages_conditional, axis = 0)
    widths_conditional = np.stack(widths_conditional, axis = 0)

    mergers = ["QR", "QRF", "ClusterQR"]
    # Number of consecutive columns per model in routine()'s conditional output
    # (mPointCP, mIntCP, mNCP). Each such group shares a background band and
    # restarts the line-color cycle.
    conditional_scores = 3

    fig = make_subplots(rows = var_bins, cols = 1, subplot_titles = bin_titles)
    fig2 = make_subplots(rows = var_bins, cols = 1, subplot_titles = bin_titles)
    fig3 = make_subplots(rows = var_bins, cols = 1, subplot_titles = bin_titles)
    figs = (fig, fig2, fig3)

    # Keep "<merger>*#base" entries plus every entry without "*" that is not "<merger>#base".
    index = [
        j for j, name in enumerate(columns_conditional)
        if name in [n + "*#base" for n in mergers]
        or ("*" not in name and name not in [n + "#base" for n in mergers])
    ]

    columns = [columns_conditional[k].replace("*", "") for k in index]
    # (model label, score label), split on "#".
    columntext = []
    for c in columns:
        parts = c.split("#")
        columntext.append((parts[0], parts[1]))

    r2s = r2s_conditional[:, index, :]
    coverages = coverages_conditional[:, index, :]
    widths = widths_conditional[:, index, :]

    # x-axis ticks: one per plotted column (score label), padded by empty labels at both ends.
    text = [""] + [c[1] for c in columntext] + [""]
    vals = [-.5] + list(range(len(columns))) + [len(columns) - 0.5]

    for i in range(coverages.shape[-1]):

        bg_colors = cycle(plotly.colors.qualitative.Pastel2)

        r2_data = pd.DataFrame(r2s[:, :, i], columns = columns)
        cov_data = pd.DataFrame(coverages[:, :, i], columns = columns)
        width_data = pd.DataFrame(widths[:, :, i], columns = columns)

        for j, model in enumerate(columns):

            group_start = j % conditional_scores == 0
            if group_start:
                colors = cycle(plotly.colors.DEFAULT_PLOTLY_COLORS)
            color = next(colors)

            fig.add_trace(go.Box(y = cov_data[model], x0 = j, line_color = color, showlegend = False), row = i + 1, col = 1)
            fig2.add_trace(go.Violin(y = width_data[model], x0 = j, spanmode = "hard", line_color = color, showlegend = False), row = i + 1, col = 1)
            fig3.add_trace(go.Violin(y = r2_data[model], x0 = j, spanmode = "hard", line_color = color, showlegend = False), row = i + 1, col = 1)

            if group_start:
                bg_color = next(bg_colors)
                for f in figs:
                    f.add_vrect(x0 = j-0.5, x1 = j+(conditional_scores-0.5), fillcolor = bg_color, layer = "below", line_width = 0, opacity = 0.5, row = i + 1, col = 1)
                    # Legend entry (one per model group) is added only once, in the first row.
                    if i == 0:
                        f.add_trace(go.Scatter(x=[None], y=[None], mode="markers", name = columntext[j][0], marker = dict(size = 7, color = bg_color, symbol = 'square')), row = 1, col = 1)

        fig.add_trace(go.Scatter(x = [-.5, len(columns)-.5], y = [1 - alpha, 1 - alpha], line = dict(dash = 'dot', color = "red"), name = "Target coverage", marker = {"opacity": 0}, showlegend = False), row = i+1, col = 1)

    figparams = {"yaxis" + str(i+1): {"range": [np.floor(np.min(coverages[:, :, i]) * 10) / 10, np.ceil(np.max(coverages[:, :, i]) * 10) / 10]} for i in range(var_bins)}
    figparams.update({"xaxis" + str(i+1): {"range": [vals[0], vals[-1]], "tickvals": vals, "ticktext": text, "showticklabels": False} for i in range(var_bins-1)})
    figparams.update({"xaxis" + str(var_bins): {"tickangle": 45, "range": [vals[0], vals[-1]], "tickvals": vals, "ticktext": text}})
    figparams.update(dict(font = {"size": 15}, title_font = {"size": 20}, title_x = .45, title_y = .99, width = 1000, margin = dict(l = 0, r = 0, t = 50, b = 0), title_automargin = True))
    fig.update_layout(figparams)
    fig.update_layout(title_text = f'{title}: Conditional PI coverage')

    # NOTE(review): fig2/fig3 only receive the x-axis settings, not the font/size/margin
    # settings applied to fig above; confirm whether that difference is intended.
    figparams = {"xaxis" + str(i+1): {"range": [vals[0], vals[-1]], "tickvals": vals, "ticktext": text, "showticklabels": False} for i in range(var_bins-1)}
    figparams.update({"xaxis" + str(var_bins): {"tickangle": 45, "range": [vals[0], vals[-1]], "tickvals": vals, "ticktext": text}})
    fig2.update_layout(figparams)
    fig2.update_layout(title_text = f'{title}: Conditional PI width')
    fig3.update_layout(figparams)
    fig3.update_layout(title_text = rf'$\text{{{title}: Conditional }}R^2\text{{-value}}$')

    for f, kind in ((fig, "coverage"), (fig2, "widths"), (fig3, "r2s")):
        for ext in ("svg", "png"):
            f.write_image(folder + datasource + "_conditional_" + kind + "-" + str(alpha) + "_synth." + ext, scale = 3)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]