## R `to_theta` → Python `transform_intercepts_continous`

In [None]:
import math

import torch
torch.set_printoptions(precision=10)

def transform_intercepts_continous(theta_tilde:torch.Tensor) -> torch.Tensor:
    """
    Map unconstrained ``theta_tilde`` to increasing ``theta`` values (e.g. the
    coefficients of a Bernstein polynomial), mirroring the R ``to_theta`` function.

    Along the last dimension (size K), with shift = log(2) * K / 2:
        theta_1 = theta_tilde_1 - shift
        theta_k = theta_tilde_1 + sum_{j=2..k} softplus(theta_tilde_j) - shift,  k >= 2

    Since softplus(x) > 0, theta increases along the last dimension. In floating
    point it can be merely non-decreasing when softplus underflows to 0 for very
    negative inputs.

    :param theta_tilde: Unordered values; the transform acts on the last dimension,
        and all leading dimensions are treated as batch dimensions.
    :return: Ordered theta values with the same shape, dtype and device as ``theta_tilde``.
    """
    last_dim_size = theta_tilde.shape[-1]
    # Python float: exact for every dtype/device (no float32 tensor created per call).
    # softplus(0) == log(2), so an all-zero theta_tilde yields
    # theta_k = (k - K/2) * log(2) for k = 0..K-1, i.e. values spread around zero.
    shift = math.log(2.0) * last_dim_size / 2

    # The first entry is kept raw; all later entries become positive increments.
    increments = torch.nn.functional.softplus(theta_tilde[..., 1:])
    widths = torch.cat([theta_tilde[..., :1], increments], dim=-1)

    # The cumulative sum of positive increments gives the ordering; then centre by the shift.
    return torch.cumsum(widths, dim=-1) - shift


In [10]:
# Same tensor as in the R code (2 samples)

# theta_tilde
# tf.Tensor(
# [[[ 0.01346988 -0.03040301 -0.0009475   0.03353934  0.03301353
#    -0.04421922  0.02690415  0.0534528  -0.07279044  0.06620763
#    -0.00113248  0.09650994  0.02934606  0.00773158 -0.10510534
#    -0.01168713 -0.03555111  0.019867   -0.09727976 -0.0474655 ]
#   [ 0.14461783 -0.00913084 -0.02329957 -0.01995979 -0.07978858
#     0.08354648  0.07621053  0.03968899  0.04010997 -0.07187632
#    -0.02174324 -0.02492543  0.04793847  0.03261854 -0.01283982
#    -0.01570366 -0.01853073  0.04250432  0.0073978   0.03714988]
#   [ 0.00787929 -0.05027757 -0.01628618 -0.01388468  0.03728567
#    -0.02157371  0.00681018 -0.01150865 -0.08478781 -0.03085445
#     0.03231757  0.05626081 -0.0752314  -0.01583911  0.06897556
#    -0.06667089 -0.04729394 -0.02368151  0.03042696 -0.00568471]]

#  [[ 0.01346988 -0.03040301 -0.0009475   0.03353934  0.03301353
#    -0.04421922  0.02690415  0.0534528  -0.07279044  0.06620763
#    -0.00113248  0.09650994  0.02934606  0.00773158 -0.10510534
#    -0.01168713 -0.03555111  0.019867   -0.09727976 -0.0474655 ]
#   [ 0.14461783 -0.00913084 -0.02329957 -0.01995979 -0.07978858
#     0.08354648  0.07621053  0.03968899  0.04010997 -0.07187632
#    -0.02174324 -0.02492543  0.04793847  0.03261854 -0.01283982
#    -0.01570366 -0.01853073  0.04250432  0.0073978   0.03714988]
#   [ 0.00787929 -0.05027757 -0.01628618 -0.01388468  0.03728567
#    -0.02157371  0.00681018 -0.01150865 -0.08478781 -0.03085445
#     0.03231757  0.05626081 -0.0752314  -0.01583911  0.06897556
#    -0.06667089 -0.04729394 -0.02368151  0.03042696 -0.00568471]]], shape=(2, 3, 20), dtype=float32)

# Only the first target is used below, since this function accepts only one target.

In [11]:
# theta_tilde of the first target, copied from the R/TensorFlow output above.
# It is identical in both samples, so the row is written once and repeated.
r_theta_tilde_row = torch.tensor(
    [ 0.01346988, -0.03040301, -0.0009475,   0.03353934,  0.03301353,
     -0.04421922,  0.02690415,  0.0534528,  -0.07279044,  0.06620763,
     -0.00113248,  0.09650994,  0.02934606,  0.00773158, -0.10510534,
     -0.01168713, -0.03555111,  0.019867,   -0.09727976, -0.0474655 ],
    dtype=torch.float32,
)
r_vgl = r_theta_tilde_row.repeat(2, 1)  # shape (2, 20): the two samples are identical
r_vgl

tensor([[ 0.0134698804, -0.0304030105, -0.0009475000,  0.0335393399,
          0.0330135301, -0.0442192182,  0.0269041508,  0.0534528010,
         -0.0727904364,  0.0662076324, -0.0011324800,  0.0965099409,
          0.0293460600,  0.0077315802, -0.1051053405, -0.0116871297,
         -0.0355511084,  0.0198669992, -0.0972797573, -0.0474654995],
        [ 0.0134698804, -0.0304030105, -0.0009475000,  0.0335393399,
          0.0330135301, -0.0442192182,  0.0269041508,  0.0534528010,
         -0.0727904364,  0.0662076324, -0.0011324800,  0.0965099409,
          0.0293460600,  0.0077315802, -0.1051053405, -0.0116871297,
         -0.0355511084,  0.0198669992, -0.0972797573, -0.0474654995]])

In [12]:
theta_python = transform_intercepts_continous(r_vgl)

# First row of the R to_theta() output printed further below (identical for both samples).
theta_r_reference = torch.tensor([
    -6.918002,  -6.2399406, -5.547267,  -4.8372097, -4.1274195,
    -3.4561374, -2.7494478, -2.0292172, -1.3718033, -0.6450043,
     0.0475769,  0.790143,   1.4980707,  2.1950912,  2.8370657,
     3.5243864,  4.199916,   4.9030457,  5.5487356,  6.2184315,
])
# TF and PyTorch round float32 slightly differently; observed agreement is ~1e-6.
torch.testing.assert_close(
    theta_python, theta_r_reference.expand_as(theta_python), atol=1e-5, rtol=0
)
theta_python

tensor([[-6.9180021286, -6.2399406433, -5.5472669601, -4.8372097015,
         -4.1274194717, -3.4561374187, -2.7494478226, -2.0292172432,
         -1.3718028069, -0.6450042725,  0.0475769043,  0.7901430130,
          1.4980707169,  2.1950912476,  2.8370656967,  3.5243864059,
          4.1999168396,  4.9030466080,  5.5487365723,  6.2184324265],
        [-6.9180021286, -6.2399406433, -5.5472669601, -4.8372097015,
         -4.1274194717, -3.4561374187, -2.7494478226, -2.0292172432,
         -1.3718028069, -0.6450042725,  0.0475769043,  0.7901430130,
          1.4980707169,  2.1950912476,  2.8370656967,  3.5243864059,
          4.1999168396,  4.9030466080,  5.5487365723,  6.2184324265]])

In [None]:

# R reference: theta after to_theta()
# tf.Tensor(
# [[[-6.918002   -6.2399406  -5.547267   -4.8372097  -4.1274195
#    -3.4561374  -2.7494478  -2.0292172  -1.3718033  -0.6450043
#     0.0475769   0.790143    1.4980707   2.1950912   2.8370657
#     3.5243864   4.199916    4.9030457   5.5487356   6.2184315 ]
#   [-6.786854   -6.098262   -5.4166965  -4.7334795  -4.079431
#    -3.3436384  -2.61166    -1.8984714  -1.1850681  -0.5272136
#     0.15512085  0.83588314  1.5532866   2.2628756   2.949623
#     3.6349497   4.3188744   5.0334997   5.7303524   6.4422474 ]
#   [-6.9235926  -6.255268   -5.570231   -4.8840017  -4.172038
#    -3.4896195  -2.7930613  -2.1056519  -1.454      -0.7761612
#    -0.06672478  0.65494823  1.3111868   1.9964457   2.7246752
#     3.3850422   4.054822    4.7361984   5.4446745   6.134983  ]]

#  [[-6.918002   -6.2399406  -5.547267   -4.8372097  -4.1274195
#    -3.4561374  -2.7494478  -2.0292172  -1.3718033  -0.6450043
#     0.0475769   0.790143    1.4980707   2.1950912   2.8370657
#     3.5243864   4.199916    4.9030457   5.5487356   6.2184315 ]
#   [-6.786854   -6.098262   -5.4166965  -4.7334795  -4.079431
#    -3.3436384  -2.61166    -1.8984714  -1.1850681  -0.5272136
#     0.15512085  0.83588314  1.5532866   2.2628756   2.949623
#     3.6349497   4.3188744   5.0334997   5.7303524   6.4422474 ]
#   [-6.9235926  -6.255268   -5.570231   -4.8840017  -4.172038
#    -3.4896195  -2.7930613  -2.1056519  -1.454      -0.7761612
#    -0.06672478  0.65494823  1.3111868   1.9964457   2.7246752
#     3.3850422   4.054822    4.7361984   5.4446745   6.134983  ]]], shape=(2, 3, 20), dtype=float32)

- Works: the Python output matches the R code (differences are at most ~1e-6, i.e. float32 rounding).

**TODO:** `hi` — presumably the intercept term $h_I$ of the transformation $h$ (to be confirmed).

In [None]:
def contram_nll(outputs,targets):
    """
    Negative log-likelihood for the continuous model — UNFINISHED stub.

    :param outputs: Mapping with the model outputs under the keys ``'int_out'``
        and ``'shift_out'`` (presumably the intercept and shift parts — confirm).
    :param targets: Observed targets; not used yet.
    :raises KeyError: If ``outputs`` lacks ``'int_out'`` or ``'shift_out'``.
    :raises NotImplementedError: Always, until the NLL is implemented. Previously
        the stub silently returned ``None``, which would only fail later, far from
        the cause.
    """
    # Looked up now so that a malformed ``outputs`` fails fast with KeyError.
    int_in = outputs['int_out']
    shift_in = outputs['shift_out']
    # TODO: presumably theta = transform_intercepts_continous(int_in), combined with
    # shift_in and evaluated at `targets` — confirm against the R implementation.
    del int_in, shift_in
    raise NotImplementedError("contram_nll is not implemented yet")

In [None]:
# TODO(review): note on the transformation function. This was previously written as
# bare code, which was a syntax error that broke Restart & Run All. Presumably:
#   h = h_I + h_LS + h_CS   (intercept + linear-shift + complex-shift terms) — confirm.
# Original note, kept verbatim: "h = hi + H_ls t hcs"