# Imports:

In [1]:
import pickle
from collections import defaultdict
import numpy as np
import pandas as pd
from scipy.special import expit, logit
from tqdm import tqdm, trange

import warnings
# Ignore all FutureWarnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import sys
sys.path.append("../")
from data_generating_utils import generate_population_data
from ddc_utils import (
    compute_all_three_logistic_models,
    compute_average_jn,
    get_pop_gs_for_binary_y,
    is_binomial_data_seperable,
)

# Hyperparams:

In [3]:
# Number of units in the simulated population (first argument passed to
# generate_population_data).
population_size = 100000

# Number of feature columns to generate; they are named x_0, x_1, ...
number_of_coefficients = 1

# Number of repeated sample draws performed for every sample size.
num_iters_per_population = 10000
# Coefficient forwarded to generate_population_data as `true_beta`.
true_beta = 1

In [4]:
# First-stage sample sizes to simulate; the outer loop runs once per entry.
ALL_SAMPLE_SIZES = [1_000]

In [5]:
# Parameters of the second-stage inclusion probability:
#   expit(logit(centering) + bias_factor * (2 * y - 1) * x_0)
sample_probability_centering, sample_probability_bias_factor = 0.77, 1

In [6]:
# Link name handed to generate_population_data; also embedded in the names of
# the pickle files written by this notebook.
link = "Logit"

In [7]:
# Indices used for seeding (+ saving): the generator seed is
# 333 * pop_index + iter_val.
pop_index, iter_val = 1, 0
rand_generator = np.random.default_rng(333 * pop_index + iter_val)

# Run:

In [8]:
# Column names of the generated features: x_0, ..., x_{k-1}, where k is
# number_of_coefficients.
feature_cols = [
    f"x_{coef_index}" for coef_index in range(number_of_coefficients)
]

In [9]:
# Result accumulators. Each one maps a sample size to a list that receives one
# entry per simulation iteration.

# Results computed on the biased (second-stage) sample:
(
    all_jns_per_sample_size_biased,
    all_ddc_per_sample_size_biased,
    all_sample_beta_per_sample_size_biased,
) = (defaultdict(list) for _ in range(3))

# Results computed on the full (first-stage) sample:
(
    all_jns_per_sample_size_full,
    all_ddc_per_sample_size_full,
    all_sample_beta_per_sample_size_full,
) = (defaultdict(list) for _ in range(3))

# Bookkeeping shared by both variants:
all_realized_sample_sizes_per_sample_size = defaultdict(list)
all_pop_beta_per_sample_size = defaultdict(list)

# Sample size -> number of draws rejected by the separability check.
sample_specific_non_separable_count = dict()

In [None]:
# Build the synthetic population with the project helper. The seeded generator
# is passed in, so this step consumes draws from rand_generator.
pop_data = generate_population_data(
    population_size,
    number_of_coefficients,
    rand_generator,
    feature_cols=feature_cols,
    true_beta=true_beta,
    link=link,
)

# Per-unit probability used later for the second-stage (biased) selection:
#   expit(logit(centering) + bias_factor * (2 * y - 1) * x_0)
# (2 * y - 1) maps y in {0, 1} to {-1, +1} -- assumes y is coded 0/1; TODO
# confirm against generate_population_data. Only x_0 enters the formula,
# regardless of number_of_coefficients.
pop_data["marginal_probs"] = expit(
    logit(sample_probability_centering)
    + sample_probability_bias_factor
    * (2 * pop_data["y"] - 1)
    * pop_data["x_0"]
)

In [12]:
# Persist the population frame (as it exists at this point) next to the
# notebook; the file name encodes the link and the true beta.
pickle_filename = f'base_population_data_{link}_{true_beta}.pickle'
with open(pickle_filename, mode='wb') as handle:
    pickle.dump(pop_data, handle)

In [13]:
# (position, link name) pairs; the position indexes pop_betas / sample_betas
# and the name is used as the link_fn / pop_gs key.
model_iteration = list(enumerate(("Logit", "Probit", "CLogLog")))

## get population-level statistics:

In [14]:
# Population-level design columns and response.
pop_x, pop_y = pop_data[feature_cols], pop_data["y"]

In [15]:
# Fit the three models on the whole population and keep their coefficient
# vectors as plain arrays, in the order logit, probit, cloglog.
population_models = compute_all_three_logistic_models(pop_x, pop_y)

pop_logit_model, pop_probit_model, pop_cloglog_model = population_models

pop_betas = [
    np.array(fitted_model.params)
    for fitted_model in (pop_logit_model, pop_probit_model, pop_cloglog_model)
]

Note that for the probit link, the link function is $g(\mu) = \Phi^{-1}(\mu)$, and so: 
$$\frac{dg}{d\mu} = \frac{1}{\phi(\Phi^{-1}(\mu))} = \frac{1}{\phi(x'\beta)} \implies \frac{d\mu}{dg} = \phi(x'\beta).$$

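Additionally, for the cloglog link, the link function is $g(\mu) = \ln(-\ln(1-\mu))$, and so:
$$\frac{dg}{d\mu} = \frac{1}{\ln(1 - \mu)(\mu - 1)} \implies \frac{d\mu}{dg} = \ln(1 - \mu)(\mu - 1) = -(1 - \mu)\ln(1 - \mu) = \exp\left(x'\beta - e^{x'\beta}\right),$$
where the last equality uses $1 - \mu = \exp\left(-e^{x'\beta}\right)$; in particular $\frac{d\mu}{dg} > 0$ for $\mu \in (0, 1)$.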



In [16]:
# Output of get_pop_gs_for_binary_y for the three population models. In the
# simulation loop it is indexed by link name (pop_gs[link_fn]) and correlated
# column-wise with the sample-inclusion indicators via .corrwith(...).
pop_gs = get_pop_gs_for_binary_y(population_models, pop_x, pop_y, population_size)

## Iterate per sample size:

In [17]:
# Simulation loop. For every target sample size, repeatedly:
#   1. draw a first-stage simple random sample (indicator column "r0"),
#   2. thin it with Bernoulli(marginal_probs) draws to get the biased
#      second-stage sample (indicator column "r"),
#   3. fit the three models on both samples and store, per model, the fitted
#      coefficients, the correlation of pop_gs with the inclusion indicator,
#      and the output of compute_average_jn.
for temp_sample_size in tqdm(ALL_SAMPLE_SIZES):
    # NOTE: despite its name, this counts draws that were REJECTED because the
    # biased sample was flagged as separable. The name is kept because it
    # feeds sample_specific_non_separable_count, which is read after the loop.
    non_separable_count = 0

    for _ in trange(num_iters_per_population, mininterval=10):
        # use sampling scheme to sample data:
        obtained_valid_sample = False

        while not obtained_valid_sample:
            # Reset both inclusion indicators before every (re)draw.
            pop_data["r0"] = 0
            pop_data["r"] = 0

            # Stage 1: simple random sample without replacement. Drawn from
            # the seeded generator rather than the global np.random state so
            # the whole simulation is reproducible from the seed.
            first_stage_indices = rand_generator.choice(
                pop_data.index, size=temp_sample_size, replace=False
            )
            pop_data.loc[first_stage_indices, "r0"] = 1
            full_sampled_data = pop_data[pop_data["r0"] == 1]

            # Stage 2: keep each first-stage unit independently with its own
            # marginal probability.
            other_sample_indices = full_sampled_data.index[
                rand_generator.binomial(n=1, p=full_sampled_data["marginal_probs"]) == 1
            ]
            pop_data.loc[other_sample_indices, "r"] = 1

            # sample_data here means the biased sample data.
            sample_data = pop_data[pop_data["r"] == 1]
            realised_sample_size = len(other_sample_indices)

            # Reject and redraw when the biased sample is flagged as
            # separable. NOTE(review): the check is skipped for realised
            # sizes >= 1_000; the threshold is hard-coded rather than tied to
            # temp_sample_size -- confirm this is intended.
            if realised_sample_size < 1_000:
                if is_binomial_data_seperable(sample_data, "y", "x_0"):
                    non_separable_count = non_separable_count + 1
                    continue
            obtained_valid_sample = True

        # compute biased x, y, model, beta
        sample_x = sample_data[feature_cols]
        sample_y = sample_data["y"]
        sample_models = compute_all_three_logistic_models(sample_x, sample_y)
        sample_betas = [sample_model.params for sample_model in sample_models]
        # Population-length 0/1 indicator of membership in the biased sample.
        sample_r = pop_data["r"]

        # compute full x, y, model, beta
        sample_x_full = full_sampled_data[feature_cols]
        sample_y_full = full_sampled_data["y"]
        sample_models_full = compute_all_three_logistic_models(
            sample_x_full, sample_y_full
        )
        sample_betas_full = [
            sample_model_full.params for sample_model_full in sample_models_full
        ]
        # Population-length 0/1 indicator of membership in the full sample.
        sample_r_full = pop_data["r0"]

        # compute biased versions of things:
        all_sample_beta_per_sample_size_biased[temp_sample_size].append(
            [pd.Series(sample_beta) for sample_beta in sample_betas]
        )
        all_ddc_per_sample_size_biased[temp_sample_size].append(
            [
                pop_gs[link_fn].corrwith(sample_r)[["x_0"]]
                for _, link_fn in model_iteration
            ]
        )
        all_jns_per_sample_size_biased[temp_sample_size].append(
            [
                compute_average_jn(
                    pop_betas[model_index],
                    sample_betas[model_index],
                    sample_x,
                    sample_y,
                    link_fn=link_fn,
                )
                for model_index, link_fn in model_iteration
            ]
        )

        all_realized_sample_sizes_per_sample_size[temp_sample_size].append(
            realised_sample_size
        )

        # compute full versions of things:
        all_sample_beta_per_sample_size_full[temp_sample_size].append(
            [pd.Series(sample_beta) for sample_beta in sample_betas_full]
        )
        all_ddc_per_sample_size_full[temp_sample_size].append(
            [
                pop_gs[link_fn].corrwith(sample_r_full)[["x_0"]]
                for _, link_fn in model_iteration
            ]
        )
        all_jns_per_sample_size_full[temp_sample_size].append(
            [
                compute_average_jn(
                    pop_betas[model_index],
                    sample_betas_full[model_index],
                    sample_x_full,
                    sample_y_full,
                    link_fn=link_fn,
                )
                for model_index, link_fn in model_iteration
            ]
        )

    sample_specific_non_separable_count[temp_sample_size] = non_separable_count

  0%|                                                                                            | 0/1 [00:00<?, ?it/s]
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)

  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)

[A%|██▋                                                                           | 349/10000 [00:20<04:36, 34.86it/s]
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getat

  result = getattr(ufunc, method)(*inputs, **kwargs)

  result = getattr(ufunc, method)(*inputs, **kwargs)

  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)

 25%|███████████████████▎                                                         | 2502/10000 [03:25<10:46, 11.60it/s][A
  result = getattr(ufunc, method)(*inputs, **kwargs)

 27%|█████████████████████                                                        | 2731/10000 [03:47<10:57, 11.05it/s][A
 28%|█████████████████████▊                                                       | 2839/10000 [03:57<10:52, 10.97it/s][A
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)

  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)

 32%|████████████████████████▌                                                    | 3188/10000 [04:27<09:58, 11.38it/s][A
  result = getattr(ufunc,

  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)

 47%|████████████████████████████████████▎                                        | 4717/10000 [06:41<07:53, 11.15it/s][A
  result = getattr(ufunc, method)(*inputs, **kwargs)

  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)

  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)

  result = getattr(ufunc, method)(*inputs, **kwargs)

  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)

 54%|█████████████████████████████████████████▍                                   | 5376/10000 [07:43<07:09, 10.75it/s][A
 55%|██████████████████████████████████████████▏                                  | 5482/10000 [07:53<07:02, 10.69it/s][A
  result = getattr(ufunc, method)(*inputs,


  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)

 76%|██████████████████████████████████████████████████████████▋                  | 7625/10000 [11:10<03:30, 11.30it/s][A
 76%|██████████████████████████████████████████████████████████▋                  | 7625/10000 [11:21<03:30, 11.30it/s][A
 77%|███████████████████████████████████████████████████████████▌                 | 7737/10000 [11:21<03:24, 11.07it/s][A
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)

  result = getattr(ufunc, method)(*inputs, **kwargs)

  result = getattr(ufunc, method)(*inputs, **kwargs)

  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)

 83%|███████████████████████████████████████████████████████████████▋             | 8277/10000 

## combine data for each sample size:

In [18]:
# Display the per-sample-size rejection counts. Despite the name, the
# simulation loop increments this counter each time a draw is flagged as
# separable and redrawn.
sample_specific_non_separable_count

{1000: 0}

In [19]:
# Empty container for combined per-sample-size results.
all_data_per_ss = list()

In [20]:
# Bundle every accumulator into one list for pickling. The position in the
# list is the only identifier once saved, so the order must stay fixed:
# three biased-sample entries, three full-sample entries, then bookkeeping.
# NOTE(review): nothing in the visible cells appends to
# all_pop_beta_per_sample_size, so it appears to be saved empty -- confirm.
all_raw_data = [
    all_jns_per_sample_size_biased,
    all_ddc_per_sample_size_biased,
    all_sample_beta_per_sample_size_biased,
    all_jns_per_sample_size_full,
    all_ddc_per_sample_size_full,
    all_sample_beta_per_sample_size_full,
    all_realized_sample_sizes_per_sample_size,
    all_pop_beta_per_sample_size,
]

In [21]:
# Write the bundled simulation results to disk; the file name encodes the
# link and the true beta.
pickle_filename = f'all_raw_data_{link}_{true_beta}.pickle'
with open(pickle_filename, mode='wb') as handle:
    pickle.dump(all_raw_data, handle)

In [22]:
# Write the population-level coefficient arrays (logit, probit, cloglog order)
# to disk; the file name encodes the link and the true beta.
pickle_filename = f'pop_betas_{link}_{true_beta}.pickle'
with open(pickle_filename, mode='wb') as handle:
    pickle.dump(pop_betas, handle)