In [1]:
import sys
sys.path.append('../src')
sys.path.append('../src/pgm')
from typing import List, Dict, Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.transforms import Transform
from torch.distributions import constraints
from torch import Tensor

import pyro
import pyro.distributions as dist
from pyro.distributions.conditional import (
    ConditionalTransformModule,
    ConditionalTransformedDistribution
)

from flow_pgm import BasePGM

### Discrete Causal Mechanisms

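This notebook presents a discrete causal mechanism (i.e. continuous cause(s) and a discrete effect) built as a Gumbel-Softmax/Concrete `TransformedDistribution`. It works by drawing noise from a Gumbel base distribution, applying an affine transform whose parameters are conditioned on the cause, and then applying a special softmax bijector (`SoftmaxCentered`) that maps the result onto the probability simplex.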

In [2]:
from sklearn.datasets import load_breast_cancer

# Load the breast-cancer dataset bundled with scikit-learn.
data = load_breast_cancer()
# Features as float32 so they match the default dtype of the model parameters.
X = torch.tensor(data["data"]).float()
# F.one_hot (used below) only accepts int64 index tensors, while the dtype of
# the NumPy target array follows the platform's default integer. Cast
# explicitly; this is a no-op when the array is already int64.
Y = torch.tensor(data["target"]).long()

def min_max_scale(X: Tensor):
    """
    Linearly rescale every column of ``X`` so that its values span [-1, 1].

    The minimum and maximum are taken over dim 0 (one pair per column).

    Args:
        X: tensor whose first dimension indexes samples.

    Returns:
        Tensor with the same shape as ``X``. A column whose values are all
        equal (zero span) is mapped to -1 everywhere instead of NaN.
    """
    X_min = X.min(dim=0).values
    span = X.max(dim=0).values - X_min
    # A constant column has zero span and would yield 0 / 0 = NaN. Only
    # exactly-zero spans are replaced, so every other column is unchanged.
    span = torch.where(span > 0, span, torch.ones_like(span))
    return 2 * (X - X_min) / span - 1  # [-1, 1]
# Rescale the features. NOTE(review): the scaling statistics are computed on
# the full dataset, i.e. before the train/validation split below.
X = min_max_scale(X)

N, M = X.shape  # number of rows, number of feature columns
K = 2  # simple binary classification task
split = int(0.8 * N)  # first 80% of the rows train, the remainder validates

# NOTE: targets are one-hot encoded with K + 1 slots because the softmax
# bijection used for training (see SoftmaxCentered) outputs K + 1 coordinates.
x_train = X[:split]
x_valid = X[split:]
y_train = F.one_hot(Y[:split], num_classes=K + 1).float()
y_valid = F.one_hot(Y[split:], num_classes=K + 1).float()

# Optional relaxation of the discrete labels: mix each one-hot vector with the
# uniform distribution over the K + 1 slots so that no entry is exactly 0 or 1.
eps = 1e-3
y_train = (1 - eps) * y_train + eps / (K + 1)
y_valid = (1 - eps) * y_valid + eps / (K + 1)

In [3]:
class SoftmaxCentered(Transform):
    """
    Implements softmax as a bijection from K real values onto K + 1 simplex
    coordinates. The forward transformation appends a zero to the input (the
    pivot logit), divides by the temperature and applies softmax; the inverse
    recovers the K values from the log-ratios against the last coordinate,
    e.g., softmax(x) = exp(x-c) / sum(exp(x-c)) where c is the implicit last
    coordinate.

    Adapted from a Tensorflow implementation: https://tinyurl.com/48vuh7yw

    Args:
        temperature: positive scalar dividing the padded logits before the
            softmax.
    """
    domain = constraints.real_vector
    codomain = constraints.simplex

    def __init__(self, temperature: float = 1.):
        super().__init__()
        self.temperature = temperature

    def _call(self, x: Tensor):
        """Map ``x`` of shape (..., K) to a point on the simplex of shape (..., K + 1)."""
        # new_zeros keeps the dtype and device of the input.
        zero_pad = x.new_zeros((*x.shape[:-1], 1))
        x_padded = torch.cat([x, zero_pad], dim=-1)
        return (x_padded / self.temperature).softmax(dim=-1)

    def _inverse(self, y: Tensor):
        """Map ``y`` of shape (..., K + 1) back to unnormalised logits of shape (..., K)."""
        # Clamp so that exact zeros in y do not produce -inf.
        log_y = torch.log(y.clamp(min=1e-12))
        unorm_log_probs = log_y[..., :-1] - log_y[..., -1:]
        return unorm_log_probs * self.temperature

    def log_abs_det_jacobian(self, x: Tensor, y: Tensor):
        """
        log|det(dy/dx)| for y = softmax(pad(x) / temperature).

        At temperature 1 this is 0.5 * log(K + 1) + sum(log y). Dividing the K
        inputs by the temperature scales the Jacobian by 1 / temperature along
        each of the K input directions, which contributes -K * log(temperature).
        """
        num_inputs = y.size(-1) - 1  # K
        Kplus1 = torch.tensor(y.size(-1), dtype=y.dtype, device=y.device)
        temperature = torch.as_tensor(self.temperature, dtype=y.dtype, device=y.device)
        log_y = torch.log(y.clamp(min=1e-12))
        return 0.5 * Kplus1.log() + torch.sum(log_y, dim=-1) - num_inputs * temperature.log()

    def forward_shape(self, shape: torch.Size):
        return shape[:-1] + (shape[-1] + 1,)  # forward appends one dim

    def inverse_shape(self, shape: torch.Size):
        if shape[-1] <= 1:
            raise ValueError(
                f"inverse of SoftmaxCentered needs a last dimension of size >= 2, got {shape[-1]}"
            )
        return shape[:-1] + (shape[-1] - 1,)  # inverse removes last dim


class ConditionalAffineTransform(ConditionalTransformModule):
    """
    Affine transform whose location and scale are produced from a context tensor.

    Args:
        context_nn: module that maps a context tensor to a pair of tensors; the
            first is used as the location, the second is passed through
            softplus and used as the scale.
        event_dim: forwarded to the ``AffineTransform`` built in ``condition``.
        **kwargs: forwarded to the ``ConditionalTransformModule`` constructor.
    """

    def __init__(self, context_nn: nn.Module, event_dim: int = 0, **kwargs: Any):
        super().__init__(**kwargs)
        self.context_nn = context_nn
        self.event_dim = event_dim

    def condition(self, context: Tensor):
        """Return the ``AffineTransform`` parameterised by ``context``."""
        shift, raw_scale = self.context_nn(context)
        # softplus maps the unconstrained network output to a positive scale.
        positive_scale = F.softplus(raw_scale)
        return torch.distributions.transforms.AffineTransform(
            loc=shift, scale=positive_scale, event_dim=self.event_dim
        )

class BreastPGM(BasePGM):
    """
    Two-site Pyro model with a continuous cause "x" and a relaxed discrete effect "y".

    "x" is drawn from a diagonal Normal with learnable location and scale. "y"
    is drawn by pushing Gumbel noise through an affine transform conditioned on
    "x" and then through ``SoftmaxCentered``, so each sample of "y" has
    ``num_classes + 1`` coordinates.

    Args:
        hidden_dims: hidden layer sizes of the network that predicts the affine
            parameters from "x".
        num_inputs: dimensionality of "x".
        num_classes: dimensionality of the Gumbel noise and of the affine
            parameters.
        temperature: temperature passed to ``SoftmaxCentered``.
    """

    def __init__(
        self,
        hidden_dims: List[int] = [128, 128],
        num_inputs: int = 30,
        num_classes: int = 2,
        temperature: float = 1.
    ):
        super().__init__()
        self.variables = {"x": "continuous", "y": "binary"}

        # Network producing (loc, unconstrained scale) of the affine transform.
        context_net = pyro.nn.DenseNN(
            num_inputs,
            hidden_dims,
            param_dims=[num_classes, num_classes],
            nonlinearity=nn.ReLU(),
        )
        self.affine_transform = ConditionalAffineTransform(context_net, event_dim=1)
        self.softmax_transform = SoftmaxCentered(temperature)

        # Learnable parameters of the Normal over "x".
        self.x_loc = nn.Parameter(torch.zeros(num_inputs))
        self.x_logscale = nn.Parameter(torch.zeros(num_inputs))

        # Fixed parameters of the Gumbel base distribution; registered as
        # buffers so they follow the module across devices.
        self.register_buffer("y_base_loc", torch.zeros(num_classes))
        self.register_buffer("y_base_scale", torch.ones(num_classes))

    def model(self):
        """Sample "x" and then "y" given "x"; return both in a dict."""
        pyro.module("PGM", self)

        # softplus with beta = log(2) evaluates to exactly 1 at 0, so the
        # zero-initialised x_logscale starts from unit scale.
        x_scale = F.softplus(self.x_logscale, beta=np.log(2))
        x_prior = dist.Normal(self.x_loc, x_scale).to_event(1)
        cause = pyro.sample("x", x_prior)

        gumbel_noise = dist.Gumbel(self.y_base_loc, self.y_base_scale).to_event(1)
        y_given_x = ConditionalTransformedDistribution(
            gumbel_noise, [self.affine_transform, self.softmax_transform],
        ).condition(cause)
        effect = pyro.sample("y", y_given_x)

        return {"x": cause, "y": effect}

    def cond_model(self, obs: Dict[str, Tensor]):
        """Run ``model`` conditioned on ``obs`` inside a plate sized by ``obs["x"].shape[0]``."""
        batch_size = obs["x"].shape[0]
        conditioned = pyro.condition(self.model, data=obs)
        with pyro.plate("obs", batch_size):
            conditioned()

    def guide_pass(self, obs: Any):
        """Guide that declares no sample sites and ignores ``obs``."""
        return None

    def predict_y(self, obs: Dict[str, Tensor]):
        """
        Trace ``model`` conditioned on ``obs`` and return the argmax over the
        last dimension of the value recorded at site "y".

        When "y" is not part of ``obs`` that value is a sample from the flow
        distribution, so repeated calls may return different labels.
        """
        traced = pyro.poutine.trace(pyro.condition(self.model, data=obs))
        y_node = traced.get_trace().nodes["y"]
        return y_node["value"].argmax(-1)

# Seed Pyro's RNGs before any weights are initialised or samples are drawn so
# that re-running the notebook from a fresh kernel is repeatable.
pyro.set_rng_seed(0)

# Use the GPU when one is present, otherwise fall back to the CPU instead of
# crashing on the unconditional .cuda() calls.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = BreastPGM(hidden_dims=[128, 128], num_inputs=M, num_classes=K, temperature=1)
model.to(device)

optimizer = torch.optim.Adam(model.parameters())
elbo_fn = pyro.infer.Trace_ELBO()
pyro.clear_param_store()
x_train, y_train = x_train.to(device), y_train.to(device)
x_valid, y_valid = x_valid.to(device), y_valid.to(device)

for epoch in range(500):
    optimizer.zero_grad()
    # Both sites are observed and the guide is empty; the loss is divided by
    # the number of training rows to report a per-example value.
    loss = elbo_fn.differentiable_loss(
        model.cond_model,
        model.guide_pass,
        {"x": x_train, "y": y_train}
    ) / x_train.shape[0]
    loss.backward()
    optimizer.step()
    if epoch % 50 == 0:
        print(f'epoch {epoch}: loss = {loss.item():.4f}')
        with torch.no_grad():
            # predict_y returns the argmax of a "y" value sampled from the
            # model, so these accuracies fluctuate between evaluations.
            train_acc = (model.predict_y({"x": x_train}) == y_train.argmax(-1)).float().mean()
            valid_acc = (model.predict_y({"x": x_valid}) == y_valid.argmax(-1)).float().mean()
            print(f'train acc: {train_acc.item():.4f}, valid acc: {valid_acc.item():.4f}')

epoch 0: loss = 17.3793
train acc: 0.4615, valid acc: 0.4386
epoch 50: loss = -14.4376
train acc: 0.5934, valid acc: 0.4649
epoch 100: loss = -16.4228
train acc: 0.7055, valid acc: 0.6842
epoch 150: loss = -19.1945
train acc: 0.9231, valid acc: 0.8333
epoch 200: loss = -21.5080
train acc: 0.9297, valid acc: 0.9561
epoch 250: loss = -22.3139
train acc: 0.9560, valid acc: 0.9386
epoch 300: loss = -22.2545
train acc: 0.9451, valid acc: 0.9211
epoch 350: loss = -24.6088
train acc: 0.9802, valid acc: 0.9649
epoch 400: loss = -26.7565
train acc: 0.9736, valid acc: 0.9561
epoch 450: loss = -27.4272
train acc: 0.9692, valid acc: 0.9737


In [4]:
N_valid = x_valid.shape[0]
observed_labels = y_valid.argmax(-1)

# Case 1: replace the cause with uniform noise in [-1, 1). At least one
# counterfactual label is expected to differ from the observed label.
random_cause = 2 * torch.rand_like(x_valid) - 1
counterfactuals = model.counterfactual(
    obs={"x": x_valid, "y": y_valid},
    intervention={"x": random_cause},
)
num_unchanged = (counterfactuals["y"].argmax(-1) == observed_labels).sum().item()
assert N_valid != num_unchanged

# Case 2: "do nothing" - intervene on the effect with its own observed value.
# Every counterfactual label must then equal the observed one.
counterfactuals = model.counterfactual(
    obs={"x": x_valid, "y": y_valid},
    intervention={"y": y_valid},
)
num_unchanged = (counterfactuals["y"].argmax(-1) == observed_labels).sum().item()
assert N_valid == num_unchanged

In [5]:
# Round-trip check of the flow: sample y, invert each transform back to the
# base noise, then re-apply the transforms and confirm the sample is recovered.
# NOTE(review): this rebinds the global N (number of dataset rows in the
# data-loading cell) to 1.
N = 1
# Alias so that the lines below mirror the body of BreastPGM.model.
# NOTE(review): leaves a global named `self` in the notebook namespace.
self = model
base_dist = dist.Gumbel(self.y_base_loc, self.y_base_scale).to_event(1)
flow_dist = ConditionalTransformedDistribution(
    base_dist, [self.affine_transform, self.softmax_transform],
).condition(x_valid[:N])

with torch.no_grad():
    y = flow_dist.sample()
    init_y = y.clone()  # keep the original sample for the final comparison
    print(f'y: {init_y}')
    # Invert the transforms last-to-first; after the loop x holds the base noise.
    for fn in reversed(flow_dist.transforms):
        x = fn.inv(y)
        print(f'f_inv: {fn.inv}, x: {x}')
        y = x

    print(f'x: {x}')
    # Re-apply the transforms first-to-last starting from the recovered noise.
    for fn in flow_dist.transforms:
        y = fn(x)
        print(f'f: {fn}, y: {y}')
        x = y
    # The reconstruction must match the original sample up to float tolerance.
    assert torch.allclose(init_y, y)

y: tensor([[1.3999e-02, 9.8567e-01, 3.2696e-04]], device='cuda:0')
f_inv: _InverseTransform(SoftmaxCentered()), x: tensor([[3.7569, 8.0112]], device='cuda:0')
f_inv: _InverseTransform(AffineTransform()), x: tensor([[-0.2828,  0.6130]], device='cuda:0')
x: tensor([[-0.2828,  0.6130]], device='cuda:0')
f: AffineTransform(), y: tensor([[3.7569, 8.0112]], device='cuda:0')
f: SoftmaxCentered(), y: tensor([[1.3999e-02, 9.8567e-01, 3.2696e-04]], device='cuda:0')
