In [2]:
import torch
import torch.nn as nn

@torch.no_grad()
def init_last_layer_increasing(module: nn.Module, start: float = -2.0, end: float = 2.0):
    """
    Initialize the weights of the last nn.Linear in `module` so that the outputs
    are linearly increasing between [start, end].

    Assumes the last layer is nn.Linear with shape (n_thetas, in_features).
    """
    # Find last Linear layer
    last_linear = None
    for m in reversed(list(module.modules())):
        if isinstance(m, nn.Linear):
            last_linear = m
            break
    if last_linear is None:
        raise ValueError("No nn.Linear layer found in module.")

    n_thetas = last_linear.out_features
    in_features = last_linear.in_features

    # Desired increasing values
    values = torch.linspace(start, end, steps=n_thetas)

    # If in_features > 1, just repeat values / distribute across input dims
    w = torch.zeros((n_thetas, in_features))
    w[:, 0] = values  # put increasing sequence in first input channel

    last_linear.weight.copy_(w)

    if last_linear.bias is not None:
        last_linear.bias.zero_()  # optional: reset bias to 0

    return last_linear


class SimpleIntercept(nn.Module):
    """
    Intercept term , hI()
    Attributes:
        n_thetas (int): how many output thetas, for ordinal target this is the number of classes - 1, thetas are order of bernsteinpol() in continous case
    """
    def __init__(self, n_thetas=20):
        super(SimpleIntercept, self).__init__()  
        self.fc = nn.Linear(1,n_thetas, bias=False)

    def forward(self, x):
        return self.fc(x)
    

In [2]:
simple = SimpleIntercept(n_thetas=20)

last_layer = init_last_layer_increasing(simple, start=-3.0, end=3.0)

x = torch.ones(1, 1)
out = simple(x)

print("Weights:", last_layer.weight.squeeze())
print("Output :", out.squeeze())


Weights: tensor([-3.0000, -2.6842, -2.3684, -2.0526, -1.7368, -1.4211, -1.1053, -0.7895,
        -0.4737, -0.1579,  0.1579,  0.4737,  0.7895,  1.1053,  1.4211,  1.7368,
         2.0526,  2.3684,  2.6842,  3.0000], grad_fn=<SqueezeBackward0>)
Output : tensor([-3.0000, -2.6842, -2.3684, -2.0526, -1.7368, -1.4211, -1.1053, -0.7895,
        -0.4737, -0.1579,  0.1579,  0.4737,  0.7895,  1.1053,  1.4211,  1.7368,
         2.0526,  2.3684,  2.6842,  3.0000], grad_fn=<SqueezeBackward0>)


In [8]:
class ComplexInterceptDefaultTabular(nn.Module):
    """
    Complex shift term for tabular data. Can be any neural network architecture
    Attributes:
        n_thetas (int): number of features/predictors
    """
    def __init__(self, n_features=1,n_thetas=20):
        super(ComplexInterceptDefaultTabular, self).__init__()
        # Define the layers
        self.fc1 = nn.Linear(n_features, 8)  # First hidden layer (X_i -> 8)
        self.relu1 = nn.ReLU()               # ReLU activation
        self.fc2 = nn.Linear(8, 8)           # Second hidden layer (8 -> 8)
        self.relu2 = nn.ReLU()               # ReLU activation
        self.fc3 = nn.Linear(8, n_thetas, bias=False)  # Output layer (8 -> n_thetas, no bias)
        
    def forward(self, x):
        # Forward pass through the network
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        return x


In [9]:
complex = ComplexInterceptDefaultTabular(n_thetas=20)

last_layer = init_last_layer_increasing(simple, start=-1.0, end=1.0)

x = torch.ones(1, 1)
out = complex(x)

print("Weights:", last_layer.weight.squeeze())
print("Output :", out.squeeze())

Weights: tensor([-1.0000, -0.8947, -0.7895, -0.6842, -0.5789, -0.4737, -0.3684, -0.2632,
        -0.1579, -0.0526,  0.0526,  0.1579,  0.2632,  0.3684,  0.4737,  0.5789,
         0.6842,  0.7895,  0.8947,  1.0000], grad_fn=<SqueezeBackward0>)
Output : tensor([ 0.0335,  0.0188, -0.0706,  0.0505, -0.0428, -0.0543,  0.0060,  0.0718,
        -0.0494,  0.0033, -0.0310, -0.0633,  0.0071,  0.0415,  0.0021, -0.0073,
        -0.0649,  0.0008,  0.0810, -0.0247], grad_fn=<SqueezeBackward0>)


## inverser transform of thetas continous

In [14]:
def transform_intercepts_continous(theta_tilde:torch.Tensor) -> torch.Tensor:
    
    """
    Transforms the unordered theta_tilde to ordered theta values for the bernstein polynomial
    E.G: 
    theta_1 = theta_tilde_1
    theta_2 = theta_tilde_1 + exp(theta_tilde_2)
    ..
    :param theta_tilde: The unordered theta_tilde values
    :return: The ordered theta values
    """

    # Compute the shift based on the last dimension size
    last_dim_size = theta_tilde.shape[-1]
    shift = torch.log(torch.tensor(2.0)) * last_dim_size / 2

    # Get the width values by applying softplus from the second position onward
    widths = torch.nn.functional.softplus(theta_tilde[..., 1:])

    # Concatenate the first value (raw) with the softplus-transformed widths
    widths = torch.cat([theta_tilde[..., [0]], widths], dim=-1)

    # Return the cumulative sum minus the shift
    return torch.cumsum(widths, dim=-1) - shift


In [None]:

def inverse_transform_intercepts_continous(theta: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """
    Inverse of transform_intercepts_continous.
    Recovers theta_tilde from theta with numerical stability.
    Uses float64 internally, but returns the same dtype as input.
    """
    # Store original dtype and cast to float64 for precision
    orig_dtype = theta.dtype
    theta = theta.to(torch.float64)

    last_dim_size = theta.shape[-1]
    shift = torch.log(torch.tensor(2.0, dtype=torch.float64)) * last_dim_size / 2

    # Undo shift
    theta_shifted = theta + shift

    # Recover widths
    widths = torch.empty_like(theta_shifted)
    widths[..., 0] = theta_shifted[..., 0]
    widths[..., 1:] = theta_shifted[..., 1:] - theta_shifted[..., :-1]

    # Invert softplus (with clamp for numerical stability)
    theta_tilde = torch.empty_like(theta_shifted)
    theta_tilde[..., 0] = widths[..., 0]
    theta_tilde[..., 1:] = torch.log(torch.expm1(torch.clamp(widths[..., 1:], min=eps)))

    return theta_tilde.to(orig_dtype)



 thetas from COLR MODEL in R:

```
library(tram)
weight_init <- read_csv("weight_init.csv")
weight_init$ones <- rep(1, nrow(weight_init))
colr=Colr(x1~ones, data = weight_init,order=19)
colr$weights
colr$theta
```



```
thetas = torch.tensor([
    -2.03300268, -1.43547929, -1.14005127, -1.14005125, -0.13044674,
     0.04618960,  0.04618962,  0.04618963,  0.16484659,  0.16484661,
     0.16484662,  0.16484664,  0.16484665,  0.16484667,  0.16484668,
     0.16484670,  0.16484671,  0.40602064,  1.31082429,  2.38310636
], dtype=torch.float64)
```

In [None]:

## TODO initializet the correct theta_tilde vals from COLR

thetas = torch.tensor([
    -2.03300268, -1.43547929, -1.14005127, -1.14005125, -0.13044674,
     0.04618960,  0.04618962,  0.04618963,  0.16484659,  0.16484661,
     0.16484662,  0.16484664,  0.16484665,  0.16484667,  0.16484668,
     0.16484670,  0.16484671,  0.40602064,  1.31082429,  2.38310636
], dtype=torch.float64)

theta_tilde_recovered = inverse_transform_intercepts_continous(thetas)


thetas_recovered = transform_intercepts_continous(theta_tilde_recovered)

print("Recovered theta_tilde:\n", theta_tilde_recovered)

print("Recovered thetas_recovered:\n", thetas_recovered)


print("Round-trip error:", torch.max(torch.abs(thetas - thetas_recovered)))


Recovered theta_tilde:
 tensor([  4.8985,  -0.2014,  -1.0680, -17.7275,   0.5565,  -1.6440, -17.7275,
        -18.4207,  -2.0716, -17.7275, -18.4207, -17.7275, -18.4207, -17.7275,
        -18.4207, -17.7275, -18.4207,  -1.2992,   0.3862,   0.6534],
       dtype=torch.float64)
Recovered thetas_recovered:
 tensor([-2.0330, -1.4355, -1.1401, -1.1401, -0.1304,  0.0462,  0.0462,  0.0462,
         0.1648,  0.1648,  0.1648,  0.1648,  0.1648,  0.1648,  0.1648,  0.1648,
         0.1648,  0.4060,  1.3108,  2.3831], dtype=torch.float64)
Round-trip error: tensor(1.9047e-08, dtype=torch.float64)


## Inverse Transform ordinal

In [48]:
import torch
def transform_intercepts_ordinal(int_in):
    # get batch size
    bs = int_in.shape[0]

    # Initialize class 0 and K as constants (on same device as input)
    int0 = torch.full((bs, 1), -float('inf'), device=int_in.device)
    intK = torch.full((bs, 1), float('inf'), device=int_in.device)

    # Reshape to match the batch size
    int1 = int_in[:, 0].reshape(bs, 1)

    # Exponentiate and accumulate the values for the transformation
    intk = torch.cumsum(torch.exp(int_in[:, 1:]), dim=1)
    # intk = torch.cumsum(torch.square(int_in[:, 1:]), dim=1)

    # Concatenate intercepts along the second axis (columns)
    int_out = torch.cat([int0, int1, int1 + intk, intK], dim=1)

    return int_out

import torch

def inverse_transform_intercepts_ordinal(int_out: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """
    Inverse of transform_intercepts_ordinal (numerically stable).
    Recovers int_in from int_out.

    Parameters
    ----------
    int_out : torch.Tensor
        Tensor of shape (B, K+1) with -inf at [:,0] and +inf at [:,-1].
    eps : float, optional
        Small value to clamp diffs, avoids log(0) -> -inf.

    Returns
    -------
    int_in : torch.Tensor
        Tensor of shape (B, K-1).
    """
    bs, Kp1 = int_out.shape
    # int_out = [-inf, int1, ..., +inf]
    # Kp1 = K+1 → int_in has K-1 free parameters
    # shape = (B, K-1)
    
    # First parameter is copied
    first = int_out[:, 1:2]

    # Consecutive differences (skip -inf at [:,0] and +inf at [:,-1])
    diffs = int_out[:, 2:-1] - int_out[:, 1:-2]

    # Clamp for stability and invert exp
    rest = torch.log(torch.clamp(diffs, min=eps))

    return torch.cat([first, rest], dim=1)



int_in = torch.randn(4,1)  # (batch=1, K-1=4)

int_out = transform_intercepts_ordinal(int_in)

print("Transformed:\n", int_out)


int_in_recovered = inverse_transform_intercepts_ordinal(int_out)

print("Original:\n", int_in)
print("Recovered:\n", int_in_recovered)
print("Diff:\n", int_in - int_in_recovered)

Transformed:
 tensor([[   -inf,  0.5643,     inf],
        [   -inf,  0.7754,     inf],
        [   -inf, -0.5024,     inf],
        [   -inf, -0.5175,     inf]])
Original:
 tensor([[ 0.5643],
        [ 0.7754],
        [-0.5024],
        [-0.5175]])
Recovered:
 tensor([[ 0.5643],
        [ 0.7754],
        [-0.5024],
        [-0.5175]])
Diff:
 tensor([[0.],
        [0.],
        [0.],
        [0.]])


```
library(MASS)
library(readr)
data2 <- read_csv("tramdag_paper_test_3_9_25.csv")
data2$ones <- rep(1, nrow(data2))
View(data2)

data2$x3 <- factor(data2$x3, ordered = TRUE)


polr_fit=polr(x3~x1+x2, data2,method="logistic")
polr_fit$zeta
```
```
       0|1        1|2        2|3 
-2.0078939  0.4263169  1.0312569 

```

In [52]:
thetas = torch.tensor([[-2.0078939, 0.4263169, 1.0312569]], dtype=torch.float64)

# Build int_out (with -inf and +inf)
int_out = torch.cat([
    torch.full((1,1), -float("inf"), dtype=thetas.dtype),
    thetas,
    torch.full((1,1), float("inf"), dtype=thetas.dtype)
], dim=1)

# Inverse
theta_tilde = inverse_transform_intercepts_ordinal(int_out)
print("Recovered theta_tilde:", theta_tilde)

# Forward
thetas_recovered = transform_intercepts_ordinal(theta_tilde)
print("Recovered thetas:", thetas_recovered)

print("Diff:\n", thetas - thetas_recovered[:, 1:-1])


Recovered theta_tilde: tensor([[-2.0079,  0.8896, -0.5026]], dtype=torch.float64)
Recovered thetas: tensor([[   -inf, -2.0079,  0.4263,  1.0313,     inf]], dtype=torch.float64)
Diff:
 tensor([[0.0000e+00, 2.2204e-16, 2.2204e-16]], dtype=torch.float64)
