In [1]:
import hypothesis as h
import torch
import numpy as np
import matplotlib.pyplot as plt

from hypothesis.train import RatioEstimatorTrainer as Trainer
from hypothesis.nn import build_ratio_estimator
from hypothesis.stat import highest_density_level

from sklearn.metrics import roc_auc_score

from tqdm import tqdm

In [2]:
# Switch on the plotting helpers of the `hypothesis` package.
# NOTE(review): presumably this applies the library's matplotlib defaults — confirm.
h.plot.activate()

### Utilities

In [3]:
@torch.no_grad()
def coverage(estimator, cl=0.95, resolution=100, n=10000):
    """Estimate how often the estimator's credible regions contain the truth.

    ``n`` (parameter, observation) pairs are drawn with ``simulate_joint``. For
    every pair the posterior density estimate (ratio times prior) is tabulated
    on ``resolution`` evenly spaced parameter values in [-15, 15], and
    ``highest_density_level`` is asked for the density level at ``1 - cl``.
    The pair counts as covered when the density at the generating parameter
    is at least that level.

    Args:
        estimator: object exposing ``log_ratio(inputs=..., outputs=...)``.
        cl: credibility level of the region being checked.
        resolution: number of grid points used to tabulate the density.
        n: number of simulated pairs.

    Returns:
        ``(has_coverage, empirical_coverage)``: whether the covered fraction is
        at least ``cl``, followed by the covered fraction itself.
    """
    prior = torch.distributions.uniform.Uniform(-15, 15)
    grid = torch.linspace(-15, 15, resolution)
    thetas, observations = simulate_joint(n)
    hits = 0
    for i in range(n):
        theta = thetas[i].view(-1, 1)
        x = observations[i].view(-1, 1)
        # NOTE: the prior is evaluated at the generating parameter and that
        # same value is reused for every grid point. This matches a per-grid
        # evaluation only because the prior is uniform on the grid's range.
        log_prior = prior.log_prob(theta)
        theta_density = (estimator.log_ratio(inputs=theta, outputs=x) + log_prior).exp().item()
        # Score the whole grid against the same observation in one call.
        grid_log_ratios = estimator.log_ratio(inputs=grid, outputs=x.repeat(resolution, 1))
        grid_density = (grid_log_ratios + log_prior).squeeze().exp()
        threshold = highest_density_level(grid_density, 1 - cl)
        if theta_density >= threshold:
            hits += 1
    empirical_coverage = hits / n

    return empirical_coverage >= cl, empirical_coverage

In [4]:
def _jsd_batch_loss(classifier, criterion, batch_a, batch_b, ones, zeros):
    """Summed BCE of ``classifier`` on a batch of A (label 1) and a batch of B (label 0)."""
    z_a = torch.cat([batch_a["inputs"], batch_a["outputs"]], dim=1)
    z_b = torch.cat([batch_b["inputs"], batch_b["outputs"]], dim=1)
    return criterion(classifier(z_a), ones) + criterion(classifier(z_b), zeros)


def jsd(dataset_a, dataset_b, epochs=1, discriminator=None):
    """Discriminator-based divergence estimate between two datasets.

    A binary classifier separates samples of ``dataset_a`` (label 1) from
    samples of ``dataset_b`` (label 0). The reported divergence is
    ``log(4) - L`` where ``L`` is an exponential moving average, initialised
    at ``log(4)``, of the per-batch summed binary cross-entropy.

    Args:
        dataset_a: dataset whose batches expose ``"inputs"`` and ``"outputs"``.
        dataset_b: dataset with the same structure as ``dataset_a``.
        epochs: number of training passes; only used when no discriminator
            is supplied.
        discriminator: optional trained classifier. When given, no training
            happens and the datasets are only scored.

    Returns:
        ``(divergence, discriminator)`` when ``discriminator`` is ``None``
        (training mode, smoothing factor 0.99), otherwise ``divergence``
        alone (evaluation mode, smoothing factor 0.8).
    """
    batch_size = 128
    # Targets have a fixed batch dimension, hence drop_last=True on the loaders.
    ones = torch.ones(batch_size, 1)
    zeros = torch.zeros(batch_size, 1)
    criterion = torch.nn.BCELoss()
    # The loaders do not shuffle, so they can be built once and re-iterated.
    loader_a = torch.utils.data.DataLoader(dataset_a, batch_size=batch_size, drop_last=True)
    loader_b = torch.utils.data.DataLoader(dataset_b, batch_size=batch_size, drop_last=True)
    final_loss = np.log(4)

    if discriminator is None:
        # The model and optimizer are only needed when training.
        model = h.nn.MLP(shape_xs=(2,), shape_ys=(1,))
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
        for _ in range(epochs):
            for in_a, in_b in zip(loader_a, loader_b):
                optimizer.zero_grad()
                loss = _jsd_batch_loss(model, criterion, in_a, in_b, ones, zeros)
                loss.backward()
                optimizer.step()
                final_loss = 0.99 * final_loss + 0.01 * loss.item()
        discriminator = model.eval()
        divergence = np.log(4) - final_loss
        return divergence, discriminator

    discriminator.eval()
    # Pure scoring: no gradients are required.
    with torch.no_grad():
        for in_a, in_b in zip(loader_a, loader_b):
            loss = _jsd_batch_loss(discriminator, criterion, in_a, in_b, ones, zeros)
            final_loss = 0.8 * final_loss + 0.2 * loss.item()
    divergence = np.log(4) - final_loss
    return divergence

In [5]:
class RatioEstimator(h.nn.ratio_estimation.BaseRatioEstimator):
    """Ratio estimator over a 1-dimensional input and a 1-dimensional output.

    Wraps the estimator class returned by ``build_ratio_estimator("mlp", ...)``,
    instantiated with ``torch.nn.SELU`` as activation and ``[128, 128, 128]``
    as trunk, and forwards ``log_ratio`` calls to that instance.
    """

    def __init__(self):
        estimator_cls = build_ratio_estimator(
            "mlp", {"inputs": (1,), "outputs": (1,)})
        wrapped = estimator_cls(activation=torch.nn.SELU, trunk=[128, 128, 128])
        super().__init__(r=wrapped)
        # Keep a handle on the wrapped estimator for delegation.
        self._r = wrapped

    def log_ratio(self, **kwargs):
        """Return the wrapped estimator's log-ratio for the given keyword tensors."""
        return self._r.log_ratio(**kwargs)

In [6]:
@torch.no_grad()
def weights(estimator, inputs, outputs):
    """Normalised importance weights proportional to the estimated ratio.

    Args:
        estimator: object exposing ``log_ratio(inputs=..., outputs=...)``.
        inputs: tensor forwarded to the estimator as ``inputs``.
        outputs: tensor forwarded to the estimator as ``outputs``.

    Returns:
        1-D float64 numpy array with one non-negative weight per sample,
        summing to one.
    """
    log_ratios = estimator.log_ratio(inputs=inputs, outputs=outputs)
    # softmax(log r) equals r / sum(r) but never forms exp(log r) directly, so
    # large log-ratios cannot overflow to inf and poison the weights with NaN.
    # Float64 keeps the sum close enough to one for use as a probability
    # vector, and flattening (unlike squeeze) keeps a single sample 1-D.
    normalized = torch.softmax(log_ratios.reshape(-1).double(), dim=0)
    return normalized.cpu().numpy()

In [7]:
def simulate_joint(n=1000000):
    """Draw ``n`` paired (inputs, outputs) samples by delegating to ``simulate``."""
    inputs, outputs = simulate(n)
    return inputs, outputs

In [8]:
def simulate_marginals(n=1000000):
    """Draw ``n`` samples from the product of the marginals.

    Pairs are simulated with ``simulate`` and the rows of ``inputs`` are then
    randomly permuted, which breaks the pairing with ``outputs``.

    Returns:
        ``(inputs, outputs)`` with the inputs shuffled along the first axis.
    """
    paired_inputs, outputs = simulate(n)
    shuffled_inputs = paired_inputs[torch.randperm(n), :]

    return shuffled_inputs, outputs

In [9]:
def simulate(n):
    """Simulate ``n`` (inputs, outputs) pairs.

    Inputs are drawn from ``Uniform(-15, 15)`` with torch; each output is its
    input plus noise drawn by ``np.random.random`` (numpy's global RNG).

    Returns:
        Two float32 tensors of shape ``(n, 1)``: inputs and outputs.
    """
    theta = torch.distributions.uniform.Uniform(-15, 15).sample((n, 1))
    noise = torch.from_numpy(np.random.random(size=(n, 1)))
    # The sum is formed before casting down to float32.
    x = noise + theta

    return theta.float(), x.float()

In [10]:
def train(epochs=1, n_train=1000000, conservative=False):
    """Train a ``RatioEstimator`` on freshly simulated pairs.

    Args:
        epochs: number of epochs handed to the trainer.
        n_train: number of pairs drawn with ``simulate``.
        conservative: when true, optimise ``BalancedCriterion`` with
            ``gamma=100.0``; otherwise optimise ``BaseCriterion``.

    Returns:
        The trained estimator, switched to evaluation mode.
    """
    batch_size = 1024
    # Training data, exposed to the trainer under the names the estimator expects.
    sim_inputs, sim_outputs = simulate(n_train)
    dataset = h.util.data.NamedDataset(
        inputs=torch.utils.data.TensorDataset(sim_inputs),
        outputs=torch.utils.data.TensorDataset(sim_outputs))
    # Model and optimizer.
    model = RatioEstimator()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)
    # Training objective.
    criteria = h.nn.ratio_estimation
    criterion = (
        criteria.BalancedCriterion(model, batch_size, gamma=100.0)
        if conservative
        else criteria.BaseCriterion(model, batch_size))
    # Run the optimisation.
    Trainer(
        batch_size=batch_size,
        criterion=criterion,
        dataset_train=dataset,
        epochs=epochs,
        optimizer=optimizer,
        shuffle=True,
        workers=2).fit()

    return model.eval()

### Reweighting the product of the marginals

In [11]:
# Fit the ratio estimator (default, non-conservative criterion) for two epochs
# on one million simulated pairs.
model = train(epochs=2, n_train=1000000)

KeyboardInterrupt: 

In [None]:
# Draw a pool of samples from the product of the marginals and compute, for
# each of them, a normalised weight proportional to the estimated ratio.
n = 1000000
inputs, outputs = simulate_marginals(n)
w = weights(model, inputs, outputs)

In [None]:
# Number of samples kept after reweighting.
n_reweighted = 10000

# Pick pool indices without replacement, with probability given by the weights.
reweighted_indices = np.random.choice(torch.arange(n), size=n_reweighted, replace=False, p=w)
rw_inputs = inputs[reweighted_indices, :]
rw_outputs = outputs[reweighted_indices, :]

# Package the resampled pairs under the names "inputs" and "outputs".
d_reweighted = h.util.data.NamedDataset(
    inputs=torch.utils.data.TensorDataset(rw_inputs),
    outputs=torch.utils.data.TensorDataset(rw_outputs))

In [None]:
# Reference set: the same number of pairs drawn directly from the joint.
j_inputs, j_outputs = simulate_joint(n_reweighted)
d_joint = h.util.data.NamedDataset(
    inputs=torch.utils.data.TensorDataset(j_inputs),
    outputs=torch.utils.data.TensorDataset(j_outputs))

In [None]:
# Train a fresh discriminator for two epochs to separate joint samples from
# reweighted ones; jsd returns its divergence estimate and the discriminator.
divergence, discriminator = jsd(d_joint, d_reweighted, epochs=2)

print("Train divergence:", divergence)

In [None]:
# Held-out check: draw a new pool from the product of the marginals and
# reweight it with the same estimator. This overwrites the globals set by the
# cells above (n, inputs, outputs, w, rw_*, j_*, d_*, divergence).
n = 1000000
inputs, outputs = simulate_marginals(n)
w = weights(model, inputs, outputs)

n_reweighted = 25000

reweighted_indices = np.random.choice(torch.arange(n), size=n_reweighted, replace=False, p=w)
rw_inputs = inputs[reweighted_indices, :]
rw_outputs = outputs[reweighted_indices, :]

d_reweighted = h.util.data.NamedDataset(
    inputs=torch.utils.data.TensorDataset(rw_inputs),
    outputs=torch.utils.data.TensorDataset(rw_outputs))

# Fresh joint samples of the same size as the reweighted set.
j_inputs, j_outputs = simulate_joint(n_reweighted)
d_joint = h.util.data.NamedDataset(
    inputs=torch.utils.data.TensorDataset(j_inputs),
    outputs=torch.utils.data.TensorDataset(j_outputs))

# Passing the existing discriminator makes jsd score only (no training) and
# return the divergence alone.
divergence = jsd(d_joint, d_reweighted, discriminator=discriminator)

print("Test divergence:", divergence)

#### AUC with respect to the reweighted product of the marginals

In [None]:
# Labels: 1 for joint samples, 0 for reweighted samples, in that order.
ones = np.ones((n_reweighted))
zeros = np.zeros((n_reweighted))
y = np.hstack((ones, zeros))

# Discriminator scores for the joint samples.
z_joint = torch.cat([j_inputs, j_outputs], dim=1)
y_ones = discriminator(z_joint).detach().numpy()

# Discriminator scores for the reweighted samples.
z_rw = torch.cat([rw_inputs, rw_outputs], dim=1)
y_zeros = discriminator(z_rw).detach().numpy()

# Scores are stacked in the same order as the labels.
x = np.vstack((y_ones, y_zeros))
roc_auc_score(y, x)

### Coverage

In [None]:
# Coverage check with the defaults (cl=0.95, 100 grid points, 10000 simulated
# pairs); displays (has_coverage, empirical coverage probability).
coverage(model)

## Merging for evaluation

In [None]:
def _as_named_dataset(inputs, outputs):
    """Package paired tensors under the names ``inputs`` and ``outputs``."""
    return h.util.data.NamedDataset(
        inputs=torch.utils.data.TensorDataset(inputs),
        outputs=torch.utils.data.TensorDataset(outputs))


def _reweighted_marginals(model, n_pool, n_keep):
    """Resample ``n_keep`` of ``n_pool`` marginal draws with ratio-proportional weights."""
    inputs, outputs = simulate_marginals(n_pool)
    w = weights(model, inputs, outputs)
    indices = np.random.choice(torch.arange(n_pool), size=n_keep, replace=False, p=w)
    return inputs[indices, :], outputs[indices, :]


def evaluate(epochs=1, n_train=500000, n_reweighted=10000, conservative=False):
    """Train a ratio estimator and report its diagnostics.

    Steps:
      1. Train the estimator with ``train``.
      2. Train a discriminator (two epochs) to separate joint samples from
         reweighted product-of-marginals samples.
      3. On freshly simulated data, score that discriminator to obtain the
         divergence estimate and the ROC AUC.
      4. Run ``coverage`` on the estimator with its defaults.

    Args:
        epochs: training epochs for the ratio estimator.
        n_train: number of training pairs, also the size of the marginal pool
            that is reweighted.
        n_reweighted: number of samples kept after reweighting; must be
            smaller than ``n_train``.
        conservative: forwarded to ``train``.

    Returns:
        ``(divergence, auc, coverage_probability)`` where
        ``coverage_probability`` is the pair returned by ``coverage``:
        ``(has_coverage, empirical coverage probability)``.
    """
    assert n_reweighted < n_train

    # Train model
    model = train(epochs=epochs, n_train=n_train, conservative=conservative)

    # Compute JSD
    ## Train: only the discriminator is kept from this stage.
    rw_inputs, rw_outputs = _reweighted_marginals(model, n_train, n_reweighted)
    j_inputs, j_outputs = simulate_joint(n_reweighted)
    _, discriminator = jsd(
        _as_named_dataset(j_inputs, j_outputs),
        _as_named_dataset(rw_inputs, rw_outputs),
        epochs=2)

    ## Test: fresh samples, scored with the trained discriminator.
    rw_inputs, rw_outputs = _reweighted_marginals(model, n_train, n_reweighted)
    j_inputs, j_outputs = simulate_joint(n_reweighted)
    divergence = jsd(
        _as_named_dataset(j_inputs, j_outputs),
        _as_named_dataset(rw_inputs, rw_outputs),
        discriminator=discriminator)

    # Compute AUC (joint samples are the positive class)
    labels = np.hstack((np.ones(n_reweighted), np.zeros(n_reweighted)))
    with torch.no_grad():
        # Scoring only, so no autograd graph is needed.
        scores_joint = discriminator(torch.cat([j_inputs, j_outputs], dim=1)).numpy()
        scores_rw = discriminator(torch.cat([rw_inputs, rw_outputs], dim=1)).numpy()
    # roc_auc_score documents a 1-D score vector for the binary case.
    scores = np.vstack((scores_joint, scores_rw)).ravel()
    auc = roc_auc_score(labels, scores)

    # Compute coverage
    coverage_probability = coverage(model)

    return divergence, auc, coverage_probability

In [None]:
# One evaluation per training budget of 1..5 epochs.
epochs = np.arange(5) + 1
results = []
# The loop variable is deliberately not called `n`: that name is the global
# sample-pool size used by the reweighting cells, and overwriting it would
# break them when they are re-run. int() hands `evaluate` a plain Python
# integer instead of a numpy scalar.
for n_epochs in tqdm(epochs):
    results.append(evaluate(epochs=int(n_epochs)))

In [None]:
# One tuple per epoch budget: (JSD estimate between the joint and the
# reweighted product of the marginals, AUC, coverage result at the 95% CL).
# The coverage result is itself a (has_coverage, empirical probability) pair.
results

## Conservative evaluation

In [None]:
# Same sweep over 1..5 training epochs, using the conservative criterion.
epochs = np.arange(5) + 1
results_conservative = []
# The loop variable is deliberately not called `n`: that name is the global
# sample-pool size used by the reweighting cells, and overwriting it would
# break them when they are re-run. int() hands `evaluate` a plain Python
# integer instead of a numpy scalar.
for n_epochs in tqdm(epochs):
    results_conservative.append(evaluate(epochs=int(n_epochs), conservative=True))

In [None]:
# One tuple per epoch budget: (JSD estimate between the joint and the
# reweighted product of the marginals, AUC, coverage result at the 95% CL).
# The coverage result is itself a (has_coverage, empirical probability) pair.
results_conservative