In [1]:
import torch
import torch.nn as nn

@torch.no_grad()
def init_last_layer_increasing(module: nn.Module, start: float = -2.0, end: float = 2.0):
    """
    Initialize the weights of the last nn.Linear in `module` so that the outputs
    are linearly increasing between [start, end].

    Assumes the last layer is nn.Linear with shape (n_thetas, in_features).
    """
    # Find last Linear layer
    last_linear = None
    for m in reversed(list(module.modules())):
        if isinstance(m, nn.Linear):
            last_linear = m
            break
    if last_linear is None:
        raise ValueError("No nn.Linear layer found in module.")

    n_thetas = last_linear.out_features
    in_features = last_linear.in_features

    # Desired increasing values
    values = torch.linspace(start, end, steps=n_thetas)

    # If in_features > 1, just repeat values / distribute across input dims
    w = torch.zeros((n_thetas, in_features))
    w[:, 0] = values  # put increasing sequence in first input channel

    last_linear.weight.copy_(w)

    if last_linear.bias is not None:
        last_linear.bias.zero_()  # optional: reset bias to 0

    return last_linear


class SimpleIntercept(nn.Module):
    """
    Intercept term , hI()
    Attributes:
        n_thetas (int): how many output thetas, for ordinal target this is the number of classes - 1, thetas are order of bernsteinpol() in continous case
    """
    def __init__(self, n_thetas=20):
        super(SimpleIntercept, self).__init__()  
        self.fc = nn.Linear(1,n_thetas, bias=False)

    def forward(self, x):
        return self.fc(x)
    

In [2]:
simple = SimpleIntercept(n_thetas=20)

last_layer = init_last_layer_increasing(simple, start=-3.0, end=3.0)

x = torch.ones(1, 1)
out = simple(x)

print("Weights:", last_layer.weight.squeeze())
print("Output :", out.squeeze())


Weights: tensor([-3.0000, -2.6842, -2.3684, -2.0526, -1.7368, -1.4211, -1.1053, -0.7895,
        -0.4737, -0.1579,  0.1579,  0.4737,  0.7895,  1.1053,  1.4211,  1.7368,
         2.0526,  2.3684,  2.6842,  3.0000], grad_fn=<SqueezeBackward0>)
Output : tensor([-3.0000, -2.6842, -2.3684, -2.0526, -1.7368, -1.4211, -1.1053, -0.7895,
        -0.4737, -0.1579,  0.1579,  0.4737,  0.7895,  1.1053,  1.4211,  1.7368,
         2.0526,  2.3684,  2.6842,  3.0000], grad_fn=<SqueezeBackward0>)


In [8]:
class ComplexInterceptDefaultTabular(nn.Module):
    """
    Complex shift term for tabular data. Can be any neural network architecture
    Attributes:
        n_thetas (int): number of features/predictors
    """
    def __init__(self, n_features=1,n_thetas=20):
        super(ComplexInterceptDefaultTabular, self).__init__()
        # Define the layers
        self.fc1 = nn.Linear(n_features, 8)  # First hidden layer (X_i -> 8)
        self.relu1 = nn.ReLU()               # ReLU activation
        self.fc2 = nn.Linear(8, 8)           # Second hidden layer (8 -> 8)
        self.relu2 = nn.ReLU()               # ReLU activation
        self.fc3 = nn.Linear(8, n_thetas, bias=False)  # Output layer (8 -> n_thetas, no bias)
        
    def forward(self, x):
        # Forward pass through the network
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        return x


In [9]:
complex = ComplexInterceptDefaultTabular(n_thetas=20)

last_layer = init_last_layer_increasing(simple, start=-1.0, end=1.0)

x = torch.ones(1, 1)
out = complex(x)

print("Weights:", last_layer.weight.squeeze())
print("Output :", out.squeeze())

Weights: tensor([-1.0000, -0.8947, -0.7895, -0.6842, -0.5789, -0.4737, -0.3684, -0.2632,
        -0.1579, -0.0526,  0.0526,  0.1579,  0.2632,  0.3684,  0.4737,  0.5789,
         0.6842,  0.7895,  0.8947,  1.0000], grad_fn=<SqueezeBackward0>)
Output : tensor([ 0.0335,  0.0188, -0.0706,  0.0505, -0.0428, -0.0543,  0.0060,  0.0718,
        -0.0494,  0.0033, -0.0310, -0.0633,  0.0071,  0.0415,  0.0021, -0.0073,
        -0.0649,  0.0008,  0.0810, -0.0247], grad_fn=<SqueezeBackward0>)
