In [1]:
import torch
import numpy as np
import torch
import torch
import torch.nn.functional as F

torch.set_printoptions(precision=8)

## R -> to_Theta, Python transform_intercepts_continpus

In [2]:


def transform_intercepts_continous(theta_tilde:torch.Tensor) -> torch.Tensor:
    
    """
    Transforms the unordered theta_tilde to ordered theta values for the bernstein polynomial
    E.G: 
    theta_1 = theta_tilde_1
    theta_2 = theta_tilde_1 + exp(theta_tilde_2)
    ..
    :param theta_tilde: The unordered theta_tilde values
    :return: The ordered theta values
    """

    # Compute the shift based on the last dimension size
    last_dim_size = theta_tilde.shape[-1]
    shift = torch.log(torch.tensor(2.0)) * last_dim_size / 2

    # Get the width values by applying softplus from the second position onward
    widths = torch.nn.functional.softplus(theta_tilde[..., 1:])

    # Concatenate the first value (raw) with the softplus-transformed widths
    widths = torch.cat([theta_tilde[..., [0]], widths], dim=-1)

    # Return the cumulative sum minus the shift
    return torch.cumsum(widths, dim=-1) - shift


In [3]:
# same tensor as in R code 2 samples

# train  data 2 samples 3 vars

# > t_i
# tf.Tensor(
# [[ 0.28106958 -0.33505845  2.945311  ]
#  [ 0.19878516 -0.23470472  0.7153738 ]], shape=(2, 3), dtype=float32)

# h_params
# tf.Tensor(
# [[[ 0.          0.          0.01346988 -0.03040301 -0.0009475
#     0.03353934  0.03301353 -0.04421922  0.02690415  0.0534528
#    -0.07279044  0.06620763 -0.00113248  0.09650994  0.02934606
#     0.00773158 -0.10510534 -0.01168713 -0.03555111  0.019867
#    -0.09727976 -0.0474655 ]
#   [ 0.         -0.00938768  0.14461783 -0.00913084 -0.02329957
#    -0.01995979 -0.07978858  0.08354648  0.07621053  0.03968899
#     0.04010997 -0.07187632 -0.02174324 -0.02492543  0.04793847
#     0.03261854 -0.01283982 -0.01570366 -0.01853073  0.04250432
#     0.0073978   0.03714988]
#   [-0.00955794  0.00252406  0.00787929 -0.05027757 -0.01628618
#    -0.01388468  0.03728567 -0.02157371  0.00681018 -0.01150865
#    -0.08478781 -0.03085445  0.03231757  0.05626081 -0.0752314
#    -0.01583911  0.06897556 -0.06667089 -0.04729394 -0.02368151
#     0.03042696 -0.00568471]]

#  [[ 0.          0.          0.01346988 -0.03040301 -0.0009475
#     0.03353934  0.03301353 -0.04421922  0.02690415  0.0534528
#    -0.07279044  0.06620763 -0.00113248  0.09650994  0.02934606
#     0.00773158 -0.10510534 -0.01168713 -0.03555111  0.019867
#    -0.09727976 -0.0474655 ]
#   [ 0.         -0.00663939  0.14461783 -0.00913084 -0.02329957
#    -0.01995979 -0.07978858  0.08354648  0.07621053  0.03968899
#     0.04010997 -0.07187632 -0.02174324 -0.02492543  0.04793847
#     0.03261854 -0.01283982 -0.01570366 -0.01853073  0.04250432
#     0.0073978   0.03714988]
#   [-0.00955794  0.00178513  0.00787929 -0.05027757 -0.01628618
#    -0.01388468  0.03728567 -0.02157371  0.00681018 -0.01150865
#    -0.08478781 -0.03085445  0.03231757  0.05626081 -0.0752314
#    -0.01583911  0.06897556 -0.06667089 -0.04729394 -0.02368151
#     0.03042696 -0.00568471]]], shape=(2, 3, 22), dtype=float32)

# kmin:
# tf.Tensor([ 0.12124275 -0.80636215 -4.9503837 ], shape=(3), dtype=float32)

# kmax:
# tf.Tensor([0.79330695 0.41867393 4.934168  ], shape=(3), dtype=float32)


#theat_thilde  2 samples 3 vars 20 features
# tf.Tensor(
# [[[ 0.01346988 -0.03040301 -0.0009475   0.03353934  0.03301353
#    -0.04421922  0.02690415  0.0534528  -0.07279044  0.06620763
#    -0.00113248  0.09650994  0.02934606  0.00773158 -0.10510534
#    -0.01168713 -0.03555111  0.019867   -0.09727976 -0.0474655 ]
#   [ 0.14461783 -0.00913084 -0.02329957 -0.01995979 -0.07978858
#     0.08354648  0.07621053  0.03968899  0.04010997 -0.07187632
#    -0.02174324 -0.02492543  0.04793847  0.03261854 -0.01283982
#    -0.01570366 -0.01853073  0.04250432  0.0073978   0.03714988]
#   [ 0.00787929 -0.05027757 -0.01628618 -0.01388468  0.03728567
#    -0.02157371  0.00681018 -0.01150865 -0.08478781 -0.03085445
#     0.03231757  0.05626081 -0.0752314  -0.01583911  0.06897556
#    -0.06667089 -0.04729394 -0.02368151  0.03042696 -0.00568471]]

#  [[ 0.01346988 -0.03040301 -0.0009475   0.03353934  0.03301353
#    -0.04421922  0.02690415  0.0534528  -0.07279044  0.06620763
#    -0.00113248  0.09650994  0.02934606  0.00773158 -0.10510534
#    -0.01168713 -0.03555111  0.019867   -0.09727976 -0.0474655 ]
#   [ 0.14461783 -0.00913084 -0.02329957 -0.01995979 -0.07978858
#     0.08354648  0.07621053  0.03968899  0.04010997 -0.07187632
#    -0.02174324 -0.02492543  0.04793847  0.03261854 -0.01283982
#    -0.01570366 -0.01853073  0.04250432  0.0073978   0.03714988]
#   [ 0.00787929 -0.05027757 -0.01628618 -0.01388468  0.03728567
#    -0.02157371  0.00681018 -0.01150865 -0.08478781 -0.03085445
#     0.03231757  0.05626081 -0.0752314  -0.01583911  0.06897556
#    -0.06667089 -0.04729394 -0.02368151  0.03042696 -0.00568471]]], shape=(2, 3, 20), dtype=float32)

# only the first since this funciton accepts only 1  target


#output h_dag extra

# > h_I
# tf.Tensor(
# [[-0.18897358 -0.08517991  0.1750973 ]
#  [-0.2702095  -0.0309604   0.02766471]], shape=(2, 3), dtype=float32)


### h(0)   and dh(0) 

# > h_dag(R_START, theta)
# tf.Tensor(
# [[0.31085795 0.32204473 0.30668354]
#  [0.31085795 0.32204473 0.30668354]], shape=(2, 3), dtype=float32)

# > h_dag_dash(R_START, theta)
# tf.Tensor(
# [[0.66965276 0.7118679  0.69034123]
#  [0.66965276 0.7118679  0.69034123]], shape=(2, 3), dtype=float32)

#### right h(1)   and dh(1) 

# > h_dag(L_START, theta)
# tf.Tensor(
# [[-0.34583557 -0.33927712 -0.346116  ]
#  [-0.34583557 -0.33927712 -0.346116  ]], shape=(2, 3), dtype=float32)

# > h_dag_dash(L_START, theta)
# tf.Tensor(
# [[0.6780876  0.68857914 0.66835433]
#  [0.6780876  0.68857914 0.66835433]], shape=(2, 3), dtype=float32)


In [4]:
# thetas from Model 

r_vgl=torch.Tensor([[ 0.01346988, -0.03040301, -0.0009475,   0.03353934,  0.03301353, 
              -0.04421922,  0.02690415,  0.0534528,  -0.07279044,  0.06620763,
   -0.00113248,  0.09650994,  0.02934606,  0.00773158, -0.10510534,
   -0.01168713, -0.03555111,  0.019867,   -0.09727976, -0.0474655 ],
                    
    [ 0.01346988, -0.03040301, -0.0009475,   0.03353934,  0.03301353,
   -0.04421922,  0.02690415,  0.0534528,  -0.07279044,  0.06620763,
   -0.00113248,  0.09650994,  0.02934606,  0.00773158, -0.10510534,
   -0.01168713, -0.03555111,  0.019867,   -0.09727976, -0.0474655 ]])
r_vgl   # same data see the to samples are the same

tensor([[ 0.01346988, -0.03040301, -0.00094750,  0.03353934,  0.03301353,
         -0.04421922,  0.02690415,  0.05345280, -0.07279044,  0.06620763,
         -0.00113248,  0.09650994,  0.02934606,  0.00773158, -0.10510534,
         -0.01168713, -0.03555111,  0.01986700, -0.09727976, -0.04746550],
        [ 0.01346988, -0.03040301, -0.00094750,  0.03353934,  0.03301353,
         -0.04421922,  0.02690415,  0.05345280, -0.07279044,  0.06620763,
         -0.00113248,  0.09650994,  0.02934606,  0.00773158, -0.10510534,
         -0.01168713, -0.03555111,  0.01986700, -0.09727976, -0.04746550]])

In [5]:
thetas=transform_intercepts_continous(r_vgl)
thetas

tensor([[-6.91800213, -6.23994064, -5.54726696, -4.83720970, -4.12741947,
         -3.45613742, -2.74944782, -2.02921724, -1.37180281, -0.64500427,
          0.04757690,  0.79014301,  1.49807072,  2.19509125,  2.83706570,
          3.52438641,  4.19991684,  4.90304661,  5.54873657,  6.21843243],
        [-6.91800213, -6.23994064, -5.54726696, -4.83720970, -4.12741947,
         -3.45613742, -2.74944782, -2.02921724, -1.37180281, -0.64500427,
          0.04757690,  0.79014301,  1.49807072,  2.19509125,  2.83706570,
          3.52438641,  4.19991684,  4.90304661,  5.54873657,  6.21843243]])

In [6]:

# theta after to_theta()
# tf.Tensor(
# [[[-6.918002   -6.2399406  -5.547267   -4.8372097  -4.1274195
#    -3.4561374  -2.7494478  -2.0292172  -1.3718033  -0.6450043
#     0.0475769   0.790143    1.4980707   2.1950912   2.8370657
#     3.5243864   4.199916    4.9030457   5.5487356   6.2184315 ]
#   [-6.786854   -6.098262   -5.4166965  -4.7334795  -4.079431
#    -3.3436384  -2.61166    -1.8984714  -1.1850681  -0.5272136
#     0.15512085  0.83588314  1.5532866   2.2628756   2.949623
#     3.6349497   4.3188744   5.0334997   5.7303524   6.4422474 ]
#   [-6.9235926  -6.255268   -5.570231   -4.8840017  -4.172038
#    -3.4896195  -2.7930613  -2.1056519  -1.454      -0.7761612
#    -0.06672478  0.65494823  1.3111868   1.9964457   2.7246752
#     3.3850422   4.054822    4.7361984   5.4446745   6.134983  ]]

#  [[-6.918002   -6.2399406  -5.547267   -4.8372097  -4.1274195
#    -3.4561374  -2.7494478  -2.0292172  -1.3718033  -0.6450043
#     0.0475769   0.790143    1.4980707   2.1950912   2.8370657
#     3.5243864   4.199916    4.9030457   5.5487356   6.2184315 ]
#   [-6.786854   -6.098262   -5.4166965  -4.7334795  -4.079431
#    -3.3436384  -2.61166    -1.8984714  -1.1850681  -0.5272136
#     0.15512085  0.83588314  1.5532866   2.2628756   2.949623
#     3.6349497   4.3188744   5.0334997   5.7303524   6.4422474 ]
#   [-6.9235926  -6.255268   -5.570231   -4.8840017  -4.172038
#    -3.4896195  -2.7930613  -2.1056519  -1.454      -0.7761612
#    -0.06672478  0.65494823  1.3111868   1.9964457   2.7246752
#     3.3850422   4.054822    4.7361984   5.4446745   6.134983  ]]], shape=(2, 3, 20), dtype=float32)

- works Same output as R CODE 

## r -> h_dag extra python  

In [7]:
# same tensor as in R code 2 samples

# train  data 2 samples 3 vars
# tf.Tensor(
# [[ 0.28106958 -0.33505845  2.945311  ]
#  [ 0.19878516 -0.23470472  0.7153738 ]], shape=(2, 3), dtype=float32)

# h_params
# tf.Tensor(
# [[[ 0.          0.          0.01346988 -0.03040301 -0.0009475
#     0.03353934  0.03301353 -0.04421922  0.02690415  0.0534528
#    -0.07279044  0.06620763 -0.00113248  0.09650994  0.02934606
#     0.00773158 -0.10510534 -0.01168713 -0.03555111  0.019867
#    -0.09727976 -0.0474655 ]
#   [ 0.         -0.00938768  0.14461783 -0.00913084 -0.02329957
#    -0.01995979 -0.07978858  0.08354648  0.07621053  0.03968899
#     0.04010997 -0.07187632 -0.02174324 -0.02492543  0.04793847
#     0.03261854 -0.01283982 -0.01570366 -0.01853073  0.04250432
#     0.0073978   0.03714988]
#   [-0.00955794  0.00252406  0.00787929 -0.05027757 -0.01628618
#    -0.01388468  0.03728567 -0.02157371  0.00681018 -0.01150865
#    -0.08478781 -0.03085445  0.03231757  0.05626081 -0.0752314
#    -0.01583911  0.06897556 -0.06667089 -0.04729394 -0.02368151
#     0.03042696 -0.00568471]]

#  [[ 0.          0.          0.01346988 -0.03040301 -0.0009475
#     0.03353934  0.03301353 -0.04421922  0.02690415  0.0534528
#    -0.07279044  0.06620763 -0.00113248  0.09650994  0.02934606
#     0.00773158 -0.10510534 -0.01168713 -0.03555111  0.019867
#    -0.09727976 -0.0474655 ]
#   [ 0.         -0.00663939  0.14461783 -0.00913084 -0.02329957
#    -0.01995979 -0.07978858  0.08354648  0.07621053  0.03968899
#     0.04010997 -0.07187632 -0.02174324 -0.02492543  0.04793847
#     0.03261854 -0.01283982 -0.01570366 -0.01853073  0.04250432
#     0.0073978   0.03714988]
#   [-0.00955794  0.00178513  0.00787929 -0.05027757 -0.01628618
#    -0.01388468  0.03728567 -0.02157371  0.00681018 -0.01150865
#    -0.08478781 -0.03085445  0.03231757  0.05626081 -0.0752314
#    -0.01583911  0.06897556 -0.06667089 -0.04729394 -0.02368151
#     0.03042696 -0.00568471]]], shape=(2, 3, 22), dtype=float32)

# kmin:
# tf.Tensor([ 0.12124275 -0.80636215 -4.9503837 ], shape=(3), dtype=float32)

# kmax:
# tf.Tensor([0.79330695 0.41867393 4.934168  ], shape=(3), dtype=float32)


#theat_thilde  2 samples 3 vars 20 features
# tf.Tensor(
# [[[ 0.01346988 -0.03040301 -0.0009475   0.03353934  0.03301353
#    -0.04421922  0.02690415  0.0534528  -0.07279044  0.06620763
#    -0.00113248  0.09650994  0.02934606  0.00773158 -0.10510534
#    -0.01168713 -0.03555111  0.019867   -0.09727976 -0.0474655 ]
#   [ 0.14461783 -0.00913084 -0.02329957 -0.01995979 -0.07978858
#     0.08354648  0.07621053  0.03968899  0.04010997 -0.07187632
#    -0.02174324 -0.02492543  0.04793847  0.03261854 -0.01283982
#    -0.01570366 -0.01853073  0.04250432  0.0073978   0.03714988]
#   [ 0.00787929 -0.05027757 -0.01628618 -0.01388468  0.03728567
#    -0.02157371  0.00681018 -0.01150865 -0.08478781 -0.03085445
#     0.03231757  0.05626081 -0.0752314  -0.01583911  0.06897556
#    -0.06667089 -0.04729394 -0.02368151  0.03042696 -0.00568471]]

#  [[ 0.01346988 -0.03040301 -0.0009475   0.03353934  0.03301353
#    -0.04421922  0.02690415  0.0534528  -0.07279044  0.06620763
#    -0.00113248  0.09650994  0.02934606  0.00773158 -0.10510534
#    -0.01168713 -0.03555111  0.019867   -0.09727976 -0.0474655 ]
#   [ 0.14461783 -0.00913084 -0.02329957 -0.01995979 -0.07978858
#     0.08354648  0.07621053  0.03968899  0.04010997 -0.07187632
#    -0.02174324 -0.02492543  0.04793847  0.03261854 -0.01283982
#    -0.01570366 -0.01853073  0.04250432  0.0073978   0.03714988]
#   [ 0.00787929 -0.05027757 -0.01628618 -0.01388468  0.03728567
#    -0.02157371  0.00681018 -0.01150865 -0.08478781 -0.03085445
#     0.03231757  0.05626081 -0.0752314  -0.01583911  0.06897556
#    -0.06667089 -0.04729394 -0.02368151  0.03042696 -0.00568471]]], shape=(2, 3, 20), dtype=float32)

# only the first since this funciton accepts only 1  target


#output h_dag extra

# > h_I
# tf.Tensor(
# [[-0.18897358 -0.08517991  0.1750973 ]
#  [-0.2702095  -0.0309604   0.02766471]], shape=(2, 3), dtype=float32)


### h(0)   and dh(0) 

# > h_dag(R_START, theta)
# tf.Tensor(
# [[0.31085795 0.32204473 0.30668354]
#  [0.31085795 0.32204473 0.30668354]], shape=(2, 3), dtype=float32)

# > h_dag_dash(R_START, theta)
# tf.Tensor(
# [[0.66965276 0.7118679  0.69034123]
#  [0.66965276 0.7118679  0.69034123]], shape=(2, 3), dtype=float32)

#### right h(1)   and dh(1) 

# > h_dag(L_START, theta)
# tf.Tensor(
# [[-0.34583557 -0.33927712 -0.346116  ]
#  [-0.34583557 -0.33927712 -0.346116  ]], shape=(2, 3), dtype=float32)

# > h_dag_dash(L_START, theta)
# tf.Tensor(
# [[0.6780876  0.68857914 0.66835433]
#  [0.6780876  0.68857914 0.66835433]], shape=(2, 3), dtype=float32)


## h_dag_extra calls hdag  and hdag dash which call bernstein poly

In [8]:
L_START=0.0001
R_START=1-L_START



### Bernstein basis 

In [9]:
def bernstein_basis(tensor, M):
    """
    Compute the Bernstein basis polynomials for a given input tensor.
    Args:
        tensor (torch.Tensor): Input tensor of shape (n_samples).
        M (int): Degree of the Bernstein polynomial.
    Returns:
        torch.Tensor: Tensor of shape (B, Nodes, M+1) with the Bernstein basis.
    """
    tensor = torch.as_tensor(tensor)
    dtype = tensor.dtype
    M = torch.tensor(M, dtype=dtype, device=tensor.device)

    # Expand dims to allow broadcasting
    tensor_expanded = tensor.unsqueeze(-1)  # shape (B, Nodes, 1)

    # Clip values to avoid log(0)
    eps = torch.finfo(dtype).eps
    tensor_expanded = torch.clamp(tensor_expanded, min=eps, max=1 - eps)

    k_values = torch.arange(M + 1, dtype=dtype, device=tensor.device)  # shape (M+1,)
    
    # Log binomial coefficient: log(M choose k)
    log_binomial_coeff = (
        torch.lgamma(M + 1) 
        - torch.lgamma(k_values + 1) 
        - torch.lgamma(M - k_values + 1)
    )

    # Log powers
    log_powers = (
        k_values * torch.log(tensor_expanded)
        + (M - k_values) * torch.log(1 - tensor_expanded)
    )

    # Bernstein basis in log space
    log_bernstein = log_binomial_coeff + log_powers  # shape (B, Nodes, M+1)

    return torch.exp(log_bernstein)


In [10]:
def test_bernstein_basis_1d():
    samples = torch.tensor([0.2, 0.4, 0.4], dtype=torch.float32)
    M = 4
    output = bernstein_basis(samples, M)

    # Should return shape (N, M+1)
    assert output.shape == (len(samples), M + 1), f"Expected shape {(len(samples), M + 1)}, got {output.shape}"

    # All values should be >= 0
    assert torch.all(output >= 0), "All Bernstein basis values should be non-negative"

    # Each row should sum to ~1
    row_sums = output.sum(dim=1)
    assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-6), f"Sums were not 1: {row_sums}"

test_bernstein_basis_1d()

### hdag

In [11]:
def h_dag(targets: torch.Tensor, thetas: torch.Tensor) -> torch.Tensor:
    """
    Args:
        targets: shape (n,)
        thetas: shape (n, b)
    Returns:
        Tensor: shape (n,)
    """
    _, b = thetas.shape
    B = bernstein_basis(targets, b - 1)  # shape (n, b)
    return torch.mean(B * thetas, dim=1)



print(f'h_dag: {h_dag(L_START, thetas)}')

# > h_dag(L_START, theta)
# tf.Tensor(
# [[-0.34583557 -0.33927712 -0.346116  ]
#  [-0.34583557 -0.33927712 -0.346116  ]], shape=(2, 3), dtype=float32)

print(f'h_dag: {h_dag(R_START, thetas)}')

# > h_dag(R_START, theta)
# tf.Tensor(
# [[0.31085795 0.32204473 0.30668354]
#  [0.31085795 0.32204473 0.30668354]], shape=(2, 3), dtype=float32)

h_dag: tensor([-0.34583560, -0.34583560])
h_dag: tensor([0.31085801, 0.31085801])


###  h_dag_dash

In [12]:
def h_dag_dash(targets: torch.Tensor, thetas: torch.Tensor) -> torch.Tensor:
    """
    Args:
        targets: shape (n,)
        thetas: shape (n, b)
    Returns:
        Tensor: shape (n,)
    """
    _, b = thetas.shape
    dtheta = thetas[:, 1:] - thetas[:, :-1]         # shape (n, b-1)
    B_dash = bernstein_basis(targets, b - 2)        # shape (n, b-1)
    return torch.sum(B_dash * dtheta, dim=1)




print(f'h_dag_dash: {h_dag_dash(L_START, thetas)}')

# > h_dag_dash(L_START, theta)
# tf.Tensor(
# [[0.6780876  0.68857914 0.66835433]
#  [0.6780876  0.68857914 0.66835433]], shape=(2, 3), dtype=float32)



print(f'h_dag_dash: {h_dag_dash(R_START, thetas)}')

# > h_dag_dash(R_START, theta)
# tf.Tensor(
# [[0.66965276 0.7118679  0.69034123]
#  [0.66965276 0.7118679  0.69034123]], shape=(2, 3), dtype=float32)


h_dag_dash: tensor([0.67808759, 0.67808759])
h_dag_dash: tensor([0.66965276, 0.66965276])


###  h_extarpolated

In [13]:
# kmin:
# tf.Tensor([ 0.12124275 -0.80636215 -4.9503837 ], shape=(3), dtype=float32)

# kmax:
# tf.Tensor([0.79330695 0.41867393 4.934168  ], shape=(3), dtype=float32)

In [14]:
targets=torch.tensor([ 0.28106958,0.19878516 ])

# tf.Tensor(
# [[ 0.28106958 -0.33505845  2.945311  ]
#  [ 0.19878516 -0.23470472  0.7153738 ]], shape=(2, 3), dtype=float32)

In [15]:
def h_extrapolated(thetas: torch.Tensor, targets: torch.Tensor, k_min: float, k_max: float) -> torch.Tensor:
    """
    Args:
        thetas: shape (n, b)
        targets: shape (n,)
        k_min: float, lower bound of scaling (not tracked in graph)
        k_max: float, upper bound of scaling (not tracked in graph)
    Returns:
        Tensor of shape (n,)
    """
    # Constants (not part of the graph)
    L_START = 0.0
    R_START = 1.0

    # Detach constants from graph
    L_tensor = torch.tensor(L_START, dtype=targets.dtype, device=targets.device)
    R_tensor = torch.tensor(R_START, dtype=targets.dtype, device=targets.device)

    # Scale targets
    t_i = (targets - k_min) / (k_max - k_min)  # shape (n,)
    t_i_exp = t_i.unsqueeze(-1)  # shape (n, 1)

    # Extrapolation at left (t_i < 0)
    b0 = h_dag(L_tensor.expand_as(targets), thetas).unsqueeze(-1)     # (n, 1)
    slope0 = h_dag_dash(L_tensor.expand_as(targets), thetas).unsqueeze(-1)  # (n, 1)
    h_left = slope0 * (t_i_exp - L_tensor) + b0

    # Start with placeholder
    h = h_left.clone()

    # Mask for left extrapolation
    mask0 = t_i_exp < L_tensor
    h = torch.where(mask0, h_left, t_i_exp)  # placeholder fill

    # Extrapolation at right (t_i > 1)
    b1 = h_dag(R_tensor.expand_as(targets), thetas).unsqueeze(-1)
    slope1 = h_dag_dash(R_tensor.expand_as(targets), thetas).unsqueeze(-1)
    h_right = slope1 * (t_i_exp - R_tensor) + b1

    mask1 = t_i_exp > R_tensor
    h = torch.where(mask1, h_right, h)

    # In-domain: t_i ∈ [0,1]
    mask_mid = (t_i_exp >= L_tensor) & (t_i_exp <= R_tensor)
    h_center = h_dag(t_i, thetas).unsqueeze(-1)
    h = torch.where(mask_mid, h_center, h)

    return h.squeeze(-1)



#output h_dag extra

# > h_I
# tf.Tensor(
# [[-0.18897358 -0.08517991  0.1750973 ]
#  [-0.2702095  -0.0309604   0.02766471]], shape=(2, 3), dtype=float32)




print(f'h_extrapolated: {h_extrapolated(thetas, targets, 0.12124275, 0.79330695)}')

h_extrapolated: tensor([-0.18897405, -0.27021015])


### h_dag_dash_extrapolated

In [16]:
def h_dash_extrapolated(thetas: torch.Tensor, targets: torch.Tensor, k_min: float, k_max: float) -> torch.Tensor:
    """
    Extrapolated version of h_dag_dash for out-of-domain values.
    
    Args:
        t_targetsi: shape (n,)
        thetas: shape (n, b)
        k_min: float (not tracked by autograd)
        k_max: float (not tracked by autograd)
    
    Returns:
        Tensor: shape (n,)
    """
    # Constants
    L_START = 0.0001
    R_START = 1.0-L_START

    # Detach constants from graph
    L_tensor = torch.tensor(L_START, dtype=targets.dtype, device=targets.device)
    R_tensor = torch.tensor(R_START, dtype=targets.dtype, device=targets.device)

    # Scale input
    t_scaled = (targets - k_min) / (k_max - k_min)
    t_exp = t_scaled.unsqueeze(-1)  # shape (n, 1)

    # Left extrapolation: constant slope at L_START
    slope0 = h_dag_dash(L_tensor.expand_as(targets), thetas).unsqueeze(-1)  # (n, 1)
    mask0 = t_exp < L_tensor
    h_dash = torch.where(mask0, slope0, t_exp)  # placeholder init

    # Right extrapolation: constant slope at R_START
    slope1 = h_dag_dash(R_tensor.expand_as(targets), thetas).unsqueeze(-1)  # (n, 1)
    mask1 = t_exp > R_tensor
    h_dash = torch.where(mask1, slope1, h_dash)

    # In-domain interpolation
    mask_mid = (t_exp >= L_tensor) & (t_exp <= R_tensor)
    h_center = h_dag_dash(t_scaled, thetas).unsqueeze(-1)  # (n, 1)
    h_dash = torch.where(mask_mid, h_center, h_dash)

    return h_dash.squeeze(-1)  # shape (n,)



print(f'h_dag_dash_extra: {h_dash_extrapolated(thetas, targets, 0.12124275, 0.79330695)}')


h_dag_dash_extra: tensor([0.69799966, 0.69808316])


## wrap all in loss fct

In [17]:
# skeleton for loss analog to R CODE

def contram_nll(outputs, targets, min_max):
    """
    Args:
        outputs: dict with keys 'int_out' and 'shift_out'
        targets: shape (n,)
        min_max: tuple of two floats or tensors (min, max)
    Returns:
        scalar NLL
    """
    # Ensure min and max are not part of graph
    min_val = torch.tensor(min_max[0], dtype=targets.dtype, device=targets.device)
    max_val = torch.tensor(min_max[1], dtype=targets.dtype, device=targets.device)

    thetas_tilde = outputs['int_out']  # shape (n, b)
    thetas = transform_intercepts_continous(thetas_tilde)

    Shifts = outputs['shift_out']  # shape (n,)
    

    # Compute h
    h_I = h_extrapolated(thetas, targets, min_val, max_val)  # shape (n,)
    h = h_I + torch.sum(Shifts)  # shape (n,)

    # Latent logistic density log-prob
    log_latent_density = -h - 2 * torch.nn.functional.softplus(-h)  # shape (n,)

    # Derivative term (log |h'| - log(scale))
    h_dash = h_dash_extrapolated(thetas, targets, min_val, max_val)  # shape (n,)
    log_hdash = torch.log(torch.abs(h_dash)) - torch.log(max_val - min_val)  # shape (n,)

    # Final NLL
    nll = -torch.mean(log_latent_density + log_hdash)

    return nll


# testit

In [18]:
# targets
# tf.Tensor(
# [[ 0.28106958 -0.33505845  2.945311  ]
#  [ 0.19878516 -0.23470472  0.7153738 ]], shape=(2, 3), dtype=float32)

#outputs from model
#   [-0.00955794  0.00252406  0.00787929 -0.05027757 -0.01628618
#    -0.01388468  0.03728567 -0.02157371  0.00681018 -0.01150865
#    -0.08478781 -0.03085445  0.03231757  0.05626081 -0.0752314
#    -0.01583911  0.06897556 -0.06667089 -0.04729394 -0.02368151
#     0.03042696 -0.00568471]]


In [19]:
# h_params
# tf.Tensor(
# [[[ 0.          0.          0.01346988 -0.03040301 -0.0009475
#     0.03353934  0.03301353 -0.04421922  0.02690415  0.0534528
#    -0.07279044  0.06620763 -0.00113248  0.09650994  0.02934606
#     0.00773158 -0.10510534 -0.01168713 -0.03555111  0.019867
#    -0.09727976 -0.0474655 ]
#   [ 0.         -0.00938768  0.14461783 -0.00913084 -0.02329957
#    -0.01995979 -0.07978858  0.08354648  0.07621053  0.03968899
#     0.04010997 -0.07187632 -0.02174324 -0.02492543  0.04793847
#     0.03261854 -0.01283982 -0.01570366 -0.01853073  0.04250432
#     0.0073978   0.03714988]
#   [-0.00955794  0.00252406  0.00787929 -0.05027757 -0.01628618
#    -0.01388468  0.03728567 -0.02157371  0.00681018 -0.01150865
#    -0.08478781 -0.03085445  0.03231757  0.05626081 -0.0752314
#    -0.01583911  0.06897556 -0.06667089 -0.04729394 -0.02368151
#     0.03042696 -0.00568471]]

#  [[ 0.          0.          0.01346988 -0.03040301 -0.0009475
#     0.03353934  0.03301353 -0.04421922  0.02690415  0.0534528
#    -0.07279044  0.06620763 -0.00113248  0.09650994  0.02934606
#     0.00773158 -0.10510534 -0.01168713 -0.03555111  0.019867
#    -0.09727976 -0.0474655 ]
#   [ 0.         -0.00663939  0.14461783 -0.00913084 -0.02329957
#    -0.01995979 -0.07978858  0.08354648  0.07621053  0.03968899
#     0.04010997 -0.07187632 -0.02174324 -0.02492543  0.04793847
#     0.03261854 -0.01283982 -0.01570366 -0.01853073  0.04250432
#     0.0073978   0.03714988]
#   [-0.00955794  0.00178513  0.00787929 -0.05027757 -0.01628618
#    -0.01388468  0.03728567 -0.02157371  0.00681018 -0.01150865
#    -0.08478781 -0.03085445  0.03231757  0.05626081 -0.0752314
#    -0.01583911  0.06897556 -0.06667089 -0.04729394 -0.02368151
#     0.03042696 -0.00568471]]], shape=(2, 3, 22), dtype=float32)

# kmin:
# tf.Tensor([ 0.12124275 -0.80636215 -4.9503837 ], shape=(3), dtype=float32)

# kmax:
# tf.Tensor([0.79330695 0.41867393 4.934168  ], shape=(3), dtype=float32)

In [None]:
int_out=torch.tensor([[ 0.00787929, -0.05027757, -0.01628618,
   -0.01388468,  0.03728567, -0.02157371,  0.00681018, -0.01150865,
   -0.08478781, -0.03085445,  0.03231757,  0.05626081, -0.0752314,
   -0.01583911,  0.06897556, -0.06667089, -0.04729394, -0.02368151,
    0.03042696, -0.00568471],
                      [  0.00787929, -0.05027757, -0.01628618,
   -0.01388468,  0.03728567, -0.02157371,  0.00681018, -0.01150865,
   -0.08478781, -0.03085445,  0.03231757,  0.05626081, -0.0752314,
   -0.01583911,  0.06897556, -0.06667089, -0.04729394, -0.02368151,
    0.03042696, -0.00568471]])


shifts_out=torch.tensor([[-0.00955794,  0.00252406 ]])


outputs={'int_out':int_out,'shift_out':shifts_out}


min_max=torch.Tensor([-4.9503837, 4.934168])



contram_nll(outputs, targets, min_max)

  min_val = torch.tensor(min_max[0], dtype=targets.dtype, device=targets.device)
  max_val = torch.tensor(min_max[1], dtype=targets.dtype, device=targets.device)


tensor(4.04906702)

In [24]:
int_out=torch.tensor([[ 1.00787929, -0.05027757, -0.01628618,
   -0.01388468,  0.03728567, -0.02157371,  0.00681018, -0.01150865,
   -0.08478781, -0.03085445,  0.3231757,  0.95626081, -0.0752314,
   -0.01583911,  0.07897556, -0.06667089, -0.04729394, -0.92368151,
    0.03042696, -0.0088471],[  0.00787929, -0.05027757, -0.01628618,
   -0.01388468,  0.03728567, -0.02157371,  0.00681018, -0.01150865,
   -0.08478781, -0.03085445,  0.03231757,  0.05626081, -0.0752314,
   -0.01583911,  0.06897556, -0.06667089, -0.04729394, -0.02368151,
    0.03042696, -0.00568471]])


shifts_out=None


outputs={'int_out':int_out,'shift_out':shifts_out}


min_max=torch.Tensor([-4.9503837, 4.934168])



contram_nll(outputs, targets, min_max)

  min_val = torch.tensor(min_max[0], dtype=targets.dtype, device=targets.device)
  max_val = torch.tensor(min_max[1], dtype=targets.dtype, device=targets.device)


TypeError: sum(): argument 'input' (position 1) must be Tensor, not NoneType

In [25]:
int_out.shape

torch.Size([2, 20])

In [22]:
#TODO test fucnitno with r values before mean reduction