# minimal example of ontram implementation
- anlogous to https://github.com/liherz/ontram_pytorch.git

In [1]:
# Load dependencies
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
import os
from sklearn.datasets import load_wine
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn 
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from utils.graph import *


if torch.cuda.is_available():
    device = torch.device('cuda')
    print("Train with GPU support.")
else:
    device = torch.device('cpu')
    print("No GPU found, train with CPU support.")

Train with GPU support.


In [2]:
experiment_name = "ordinal_example"   ## <--- set experiment name
seed=42
np.random.seed(seed)

LOG_DIR="/home/bule/TramDag/dev_experiment_logs"
EXPERIMENT_DIR = os.path.join(LOG_DIR, experiment_name)
DATA_PATH = EXPERIMENT_DIR # <----------- change to different source if needed
CONF_DICT_PATH = os.path.join(EXPERIMENT_DIR, f"configuration.json")

In [3]:
# Load the dataset
wine = load_wine()
wine

{'data': array([[1.423e+01, 1.710e+00, 2.430e+00, ..., 1.040e+00, 3.920e+00,
         1.065e+03],
        [1.320e+01, 1.780e+00, 2.140e+00, ..., 1.050e+00, 3.400e+00,
         1.050e+03],
        [1.316e+01, 2.360e+00, 2.670e+00, ..., 1.030e+00, 3.170e+00,
         1.185e+03],
        ...,
        [1.327e+01, 4.280e+00, 2.260e+00, ..., 5.900e-01, 1.560e+00,
         8.350e+02],
        [1.317e+01, 2.590e+00, 2.370e+00, ..., 6.000e-01, 1.620e+00,
         8.400e+02],
        [1.413e+01, 4.100e+00, 2.740e+00, ..., 6.100e-01, 1.600e+00,
         5.600e+02]]),
 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [4]:
df=pd.DataFrame(wine['data'], columns=wine['feature_names'])
df['target']=wine['target']

In [5]:
train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

In [6]:
train_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 142 entries, 158 to 102
Data columns (total 14 columns):
 #   Column                        Non-Null Count  Dtype  
---  ------                        --------------  -----  
 0   alcohol                       142 non-null    float64
 1   malic_acid                    142 non-null    float64
 2   ash                           142 non-null    float64
 3   alcalinity_of_ash             142 non-null    float64
 4   magnesium                     142 non-null    float64
 5   total_phenols                 142 non-null    float64
 6   flavanoids                    142 non-null    float64
 7   nonflavanoid_phenols          142 non-null    float64
 8   proanthocyanins               142 non-null    float64
 9   color_intensity               142 non-null    float64
 10  hue                           142 non-null    float64
 11  od280/od315_of_diluted_wines  142 non-null    float64
 12  proline                       142 non-null    float64
 13  target  

In [7]:
data_type={key:value for key, value in zip(train_df.columns, ['cont']*13+['ord'])}
data_type

{'alcohol': 'cont',
 'malic_acid': 'cont',
 'ash': 'cont',
 'alcalinity_of_ash': 'cont',
 'magnesium': 'cont',
 'total_phenols': 'cont',
 'flavanoids': 'cont',
 'nonflavanoid_phenols': 'cont',
 'proanthocyanins': 'cont',
 'color_intensity': 'cont',
 'hue': 'cont',
 'od280/od315_of_diluted_wines': 'cont',
 'proline': 'cont',
 'target': 'ord'}

In [8]:
configuration_dict=new_conf_dict(experiment_name,EXPERIMENT_DIR,DATA_PATH,LOG_DIR)
configuration_dict

{'date_of_creation': '2025-07-15 14:19:09',
 'experiment_name': 'ordinal_example',
 'PATHS': {'DATA_PATH': '/home/bule/TramDag/dev_experiment_logs/ordinal_example',
  'LOG_DIR': '/home/bule/TramDag/dev_experiment_logs',
  'EXPERIMENT_DIR': '/home/bule/TramDag/dev_experiment_logs/ordinal_example'},
 'data_type': None,
 'adj_matrix': None,
 'model_names': None,
 'seed': None,
 'nodes': None}

In [9]:
data_type={key:value for key, value in zip(train_df.columns, ['cont']*13+['ord'])}

levels_dict={'target':len(np.unique(df['target']))}  

columns = [
    'alcohol', 'malic_acid', 'ash', 'alcalinity_of_ash', 'magnesium',
    'total_phenols', 'flavanoids', 'nonflavanoid_phenols', 'proanthocyanins',
    'color_intensity', 'hue', 'od280/od315_of_diluted_wines', 'proline', 'target'
]

adj_matrix = np.full((len(columns), len(columns)), "0", dtype=object)

# Set last column (edges *to* 'target') as "ls", excluding self-loop
for i in range(len(columns) - 1):
    adj_matrix[i, -1] = "ls"
    
nn_names_matrix= create_nn_model_names(adj_matrix,data_type)
nn_names_matrix

array([['0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0',
        'LinearShift'],
       ['0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0',
        'LinearShift'],
       ['0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0',
        'LinearShift'],
       ['0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0',
        'LinearShift'],
       ['0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0',
        'LinearShift'],
       ['0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0',
        'LinearShift'],
       ['0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0',
        'LinearShift'],
       ['0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0',
        'LinearShift'],
       ['0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0',
        'LinearShift'],
       ['0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0',
        'LinearShift'],
       ['0', '0', '0', '0', '0

In [10]:
quantiles = train_df.quantile([0.05, 0.95])
min_vals = quantiles.loc[0.05]
max_vals = quantiles.loc[0.95]


In [11]:
def create_node_dict_v2(adj_matrix, nn_names_matrix, data_type, min_vals, max_vals,levels_dict=None):
    """
    Creates a configuration dictionary for TRAMADAG based on an adjacency matrix,
    a neural network names matrix, and a data type dictionary.
    """
    if not validate_adj_matrix(adj_matrix):
        raise ValueError("Invalid adjacency matrix. Please check the criteria.")
    
    if len(data_type) != adj_matrix.shape[0]:
        raise ValueError("Data type dictionary should have the same length as the adjacency matrix.")
    
    target_nodes = {}
    G, edge_labels = create_nx_graph(adj_matrix, node_labels=list(data_type.keys()))
    
    sources = [node for node in G.nodes if G.in_degree(node) == 0]
    sinks = [node for node in G.nodes if G.out_degree(node) == 0]
    
    for i, node in enumerate(G.nodes):
        parents = list(G.predecessors(node))
        target_nodes[node] = {}
        target_nodes[node]['Modelnr'] = i
        target_nodes[node]['data_type'] = data_type[node]
        
        # write the levels of the ordinal outcome
        if data_type[node]=='ord':
            if levels_dict is None:
                print('provide levels_dict e.g. {"x3":3}')
            else:
                target_nodes[node]['levels'] = levels_dict[node]
            
    
        target_nodes[node]['node_type'] = "source" if node in sources else "sink" if node in sinks else "internal"
        target_nodes[node]['parents'] = parents
        target_nodes[node]['parents_datatype'] = {parent:data_type[parent] for parent in parents}
        target_nodes[node]['transformation_terms_in_h()'] = {parent: edge_labels[(parent, node)] for parent in parents if (parent, node) in edge_labels}
        target_nodes[node]['min'] = min_vals.iloc[i].tolist()   
        target_nodes[node]['max'] = max_vals.iloc[i].tolist()

        
        transformation_term_nn_models = {}
        for parent in parents:
            parent_idx = list(data_type.keys()).index(parent)  
            child_idx = list(data_type.keys()).index(node) 
            
            if nn_names_matrix[parent_idx, child_idx] != "0":
                transformation_term_nn_models[parent] = nn_names_matrix[parent_idx, child_idx]
        target_nodes[node]['transformation_term_nn_models_in_h()'] = transformation_term_nn_models
    return target_nodes

In [12]:
target_nodes=create_node_dict_v2(adj_matrix, nn_names_matrix, data_type, min_vals, max_vals,levels_dict)
target_nodes

{'alcohol': {'Modelnr': 0,
  'data_type': 'cont',
  'node_type': 'source',
  'parents': [],
  'parents_datatype': {},
  'transformation_terms_in_h()': {},
  'min': 11.665000000000001,
  'max': 14.2295,
  'transformation_term_nn_models_in_h()': {}},
 'malic_acid': {'Modelnr': 1,
  'data_type': 'cont',
  'node_type': 'source',
  'parents': [],
  'parents_datatype': {},
  'transformation_terms_in_h()': {},
  'min': 1.0710000000000002,
  'max': 4.600999999999998,
  'transformation_term_nn_models_in_h()': {}},
 'ash': {'Modelnr': 2,
  'data_type': 'cont',
  'node_type': 'source',
  'parents': [],
  'parents_datatype': {},
  'transformation_terms_in_h()': {},
  'min': 1.92,
  'max': 2.7495,
  'transformation_term_nn_models_in_h()': {}},
 'alcalinity_of_ash': {'Modelnr': 3,
  'data_type': 'cont',
  'node_type': 'source',
  'parents': [],
  'parents_datatype': {},
  'transformation_terms_in_h()': {},
  'min': 14.030000000000001,
  'max': 25.0,
  'transformation_term_nn_models_in_h()': {}},
 'm

In [13]:
from utils.tram_model_helpers import *
from utils.tram_models import *

In [14]:
# TODO : Where a SimpleIntercept model is there must be C-1 thetas 


def get_fully_specified_tram_model_v2(node: str, target_nodes: dict, verbose=True):
    # Source nodes get a simple intercept only
    if target_nodes[node]['node_type'] == 'source':
        # if target node is ordinal we only need c-1 thetas
        if target_nodes[node]['data_type']=='ord':
            nn_int = SimpleIntercept(n_thetas=target_nodes[node]['levels']-1)
        else:    
            nn_int = SimpleIntercept()
            
        model = TramModel(nn_int, None)
        if verbose:
            print("Source → SimpleIntercept only")
        return model

    # Otherwise gather terms and model names
    _, terms_dict, model_names_dict = ordered_parents(node, target_nodes)
    model_dict = merge_transformation_dicts(terms_dict, model_names_dict)

    # Split intercepts vs. shifts
    intercepts_dict = {
        k: v for k, v in model_dict.items()
        if "ci" in v['h_term'] or "si" in v['h_term']
    }
    shifts_dict = {
        k: v for k, v in model_dict.items()
        if "ci" not in v['h_term'] and "si" not in v['h_term']
    }

    # Build intercept network
    intercept_groups = group_by_base(intercepts_dict, prefixes=("ci", "si"))
    if not intercept_groups:

        if target_nodes[node]['data_type']=='ord':
            nn_int = SimpleIntercept(n_thetas=int(target_nodes[node]['levels'])-1)
        else:    
            nn_int = SimpleIntercept()

    else:
        if len(intercept_groups) > 1:
            raise ValueError("Multiple intercept models detected; only one is supported.")
        feats = next(iter(intercept_groups.values()))
        cls_name = feats[0][1]['class_name']
        base = get_base_model_class(cls_name)
        
        
        if target_nodes[node]['data_type']=='ord':
            nn_int = globals()[base](n_features=len(feats),n_thetas=int(target_nodes[node]['levels'])-1)
        else:    
            nn_int = globals()[base](n_features=len(feats))
        

    # Build shift networks (handles both "cs" and "ls")
    shift_groups = group_by_base(shifts_dict, prefixes=("cs", "ls"))
    nn_shifts = []
    for feats in shift_groups.values():
        cls_name = feats[0][1]['class_name']
        base = get_base_model_class(cls_name)
        nn_shifts.append(globals()[base](n_features=len(feats)))

    # Combine into TramModel
    tram_model = TramModel(nn_int, nn_shifts)
    if verbose:
        print("Constructed TRAM model:", tram_model)
    return tram_model

In [15]:
get_fully_specified_tram_model_v2('target', target_nodes, verbose=True)

Constructed TRAM model: TramModel(
  (nn_int): SimpleIntercept(
    (fc): Linear(in_features=1, out_features=2, bias=False)
  )
  (nn_shift): ModuleList(
    (0-12): 13 x LinearShift(
      (fc): Linear(in_features=1, out_features=1, bias=False)
    )
  )
)


TramModel(
  (nn_int): SimpleIntercept(
    (fc): Linear(in_features=1, out_features=2, bias=False)
  )
  (nn_shift): ModuleList(
    (0-12): 13 x LinearShift(
      (fc): Linear(in_features=1, out_features=1, bias=False)
    )
  )
)

In [16]:
from utils.loss_ordinal import *

In [17]:
def train_val_loop_v3(
                   node,
                   target_nodes,
                   NODE_DIR,
                   tram_model,
                   train_loader,
                   val_loader,
                   epochs,
                   optimizer,
                   use_scheduler,
                   scheduler,
                   save_linear_shifts=False,
                   verbose=1,
                   device='cpu'):
    
        # get all paths  for this training run
        MODEL_PATH,LAST_MODEL_PATH,TRAIN_HIST_PATH,VAL_HIST_PATH=model_train_val_paths(NODE_DIR)
        
        # this is needed for the preprocessing of the inputs such that they are in the correct order
        _, ordered_transformation_terms_in_h, _=ordered_parents(node, target_nodes)
        
        # this is needed for the scaling if there is a bernstein polynomial for contionous outcomes
        min_vals = torch.tensor(target_nodes[node]['min'], dtype=torch.float32).to(device)
        max_vals = torch.tensor(target_nodes[node]['max'], dtype=torch.float32).to(device)
        min_max = torch.stack([min_vals, max_vals], dim=0)
        print(f"Min-Max values for {node}: {min_max}")
        print(min_max.shape)
        
        
        ###### Load Model & History #####
        if os.path.exists(MODEL_PATH) and os.path.exists(TRAIN_HIST_PATH) and os.path.exists(VAL_HIST_PATH):
            print("Existing model found. Loading weights and history...")
            tram_model.load_state_dict(torch.load(MODEL_PATH))

            with open(TRAIN_HIST_PATH, 'r') as f:
                train_loss_hist = json.load(f)
            with open(VAL_HIST_PATH, 'r') as f:
                val_loss_hist = json.load(f)

            start_epoch = len(train_loss_hist)
            best_val_loss = min(val_loss_hist)
        else:
            print("No existing model found. Starting fresh...")
            train_loss_hist, val_loss_hist = [], []
            start_epoch = 0
            best_val_loss = float('inf')
        
        ##### Training and Validation loop
        for epoch in range(start_epoch, epochs):
            epoch_start = time.time()

            #####  Training #####
            train_start = time.time()
            train_loss = 0.0
            tram_model.train()
            for x, y in train_loader:
                optimizer.zero_grad()
                y = y.to(device)
                int_input, shift_list = preprocess_inputs(x, ordered_transformation_terms_in_h.values(), device=device)
                y_pred = tram_model(int_input=int_input, shift_input=shift_list)
                
                
                # print(f"y_pred shape: {y_pred}, y shape: {y}")
                # print(f'min_max:{min_max}')

                if target_nodes[node]['data_type'] == 'ord':
                    loss = contram_nll(y_pred, y, min_max=min_max)
                else:
                    loss = ontram_nll(y_pred, y)
                
                
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
                
            if use_scheduler:
                scheduler.step()
                
            train_time = time.time() - train_start

            avg_train_loss = train_loss / len(train_loader)
            train_loss_hist.append(avg_train_loss)

            ##### Validation #####
            
            val_start = time.time()
            val_loss = 0.0
            tram_model.eval()
            
            with torch.no_grad():
                for x, y in val_loader:
                    y = y.to(device)
                    int_input, shift_list = preprocess_inputs(x, ordered_transformation_terms_in_h.values(), device=device)
                    y_pred = tram_model(int_input=int_input, shift_input=shift_list)
                    loss = contram_nll(y_pred, y, min_max=min_max)
                    val_loss += loss.item()
            val_time = time.time() - val_start

            avg_val_loss = val_loss / len(val_loader)
            val_loss_hist.append(avg_val_loss)

            ##### Save linear shift weights #####
            
            if save_linear_shifts and tram_model.nn_shift is not None:
                # Define the path for the cumulative JSON file
                shift_path = os.path.join(NODE_DIR, "linear_shifts_all_epochs.json")

                # Load existing data if the file exists
                if os.path.exists(shift_path):
                    with open(shift_path, 'r') as f:
                        all_shift_weights = json.load(f)
                else:
                    all_shift_weights = {}

                # Prepare current epoch's shift weights
                epoch_weights = {}
                for i in range(len(tram_model.nn_shift)):
                    shift_layer = tram_model.nn_shift[i]
                    
                    if hasattr(shift_layer, 'fc') and hasattr(shift_layer.fc, 'weight'):
                        epoch_weights[f"shift_{i}"] = shift_layer.fc.weight.detach().cpu().tolist()
                    else:
                        print(f"shift_{i}: 'fc' or 'weight' layer does not exist.")
                
                # Add to the dictionary under current epoch
                all_shift_weights[f"epoch_{epoch+1}"] = epoch_weights
                
                # Write back the updated dictionary
                with open(shift_path, 'w') as f:
                    json.dump(all_shift_weights, f)
                if verbose > 1:
                    print(f'shift weights: {epoch_weights}')
                    print(f"Appended linear shift weights for epoch {epoch+1} to: {shift_path}")

            ##### Saving #####
            save_start = time.time()
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save(tram_model.state_dict(), MODEL_PATH)
                if verbose > 0:
                    print("Saved new best model.")

            torch.save(tram_model.state_dict(), LAST_MODEL_PATH)

            with open(TRAIN_HIST_PATH, 'w') as f:
                json.dump(train_loss_hist, f)
            with open(VAL_HIST_PATH, 'w') as f:
                json.dump(val_loss_hist, f)
            save_time = time.time() - save_start

            epoch_total = time.time() - epoch_start

            ##### Epoch Summary #####
            if verbose>0:
                print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
                print(f"  [Train: {train_time:.2f}s | Val: {val_time:.2f}s | Save: {save_time:.2f}s | Total: {epoch_total:.2f}s]")


In [18]:
from utils.tram_data import get_dataloader
node='target'


epochs =1000

NODE_DIR = os.path.join(EXPERIMENT_DIR, f'{node}')
os.makedirs(NODE_DIR, exist_ok=True)

tram_model =get_fully_specified_tram_model_v2(node, target_nodes, verbose=True).to(device)
optimizer =torch.optim.Adam(tram_model.parameters(), lr=0.1)

train_loader, val_loader = get_dataloader(node, target_nodes, train_df, val_df, batch_size=32, verbose=True)

Constructed TRAM model: TramModel(
  (nn_int): SimpleIntercept(
    (fc): Linear(in_features=1, out_features=2, bias=False)
  )
  (nn_shift): ModuleList(
    (0-12): 13 x LinearShift(
      (fc): Linear(in_features=1, out_features=1, bias=False)
    )
  )
)


In [19]:
train_val_loop_v3(
            node,
            target_nodes,
            NODE_DIR,
            tram_model,
            train_loader,
            val_loader,
            epochs,
            optimizer,
            use_scheduler=False,
            scheduler=False,
            save_linear_shifts=False,
            verbose=1,
            device=device)


Min-Max values for target: tensor([0., 2.], device='cuda:0')
torch.Size([2])
No existing model found. Starting fresh...
Saved new best model.
Epoch 1/1000 | Train Loss: 346.7604 | Val Loss: 124.8182
  [Train: 0.79s | Val: 0.25s | Save: 0.01s | Total: 1.04s]
Epoch 2/1000 | Train Loss: 108.9699 | Val Loss: 157.0490
  [Train: 0.36s | Val: 0.26s | Save: 0.00s | Total: 0.62s]
Saved new best model.
Epoch 3/1000 | Train Loss: 122.5723 | Val Loss: 56.0741
  [Train: 0.35s | Val: 0.26s | Save: 0.01s | Total: 0.62s]
Epoch 4/1000 | Train Loss: 75.2929 | Val Loss: 128.4721
  [Train: 0.34s | Val: 0.24s | Save: 0.00s | Total: 0.59s]
Saved new best model.
Epoch 5/1000 | Train Loss: 70.4277 | Val Loss: 48.2570
  [Train: 0.35s | Val: 0.25s | Save: 0.01s | Total: 0.61s]
Saved new best model.
Epoch 6/1000 | Train Loss: 55.7168 | Val Loss: 44.1450
  [Train: 0.34s | Val: 0.25s | Save: 0.01s | Total: 0.60s]
Epoch 7/1000 | Train Loss: 39.9371 | Val Loss: 59.8834
  [Train: 0.35s | Val: 0.25s | Save: 0.00s | To

In [23]:
MODEL_PATH,LAST_MODEL_PATH,TRAIN_HIST_PATH,VAL_HIST_PATH=model_train_val_paths(NODE_DIR)


if os.path.exists(MODEL_PATH) and os.path.exists(TRAIN_HIST_PATH) and os.path.exists(VAL_HIST_PATH):
    print("Existing model found. Loading weights and history...")
    tram_model.load_state_dict(torch.load(MODEL_PATH))

Existing model found. Loading weights and history...


In [24]:
tram_model.eval()

_, ordered_transformation_terms_in_h, _=ordered_parents(node, target_nodes)

min_vals = torch.tensor(target_nodes[node]['min'], dtype=torch.float32).to(device)
max_vals = torch.tensor(target_nodes[node]['max'], dtype=torch.float32).to(device)
min_max = torch.stack([min_vals, max_vals], dim=0)

with torch.no_grad():
    for x, y in val_loader:
        y = y.to(device)
        int_input, shift_list = preprocess_inputs(x, ordered_transformation_terms_in_h.values(), device=device)
        y_pred = tram_model(int_input=int_input, shift_input=shift_list)
        # loss = contram_nll(y_pred, y, min_max=min_max)
        print(transform_intercepts_ordinal(y_pred['int_out']))

tensor([[       -inf, -1.5784e+00,  1.7363e+14,         inf],
        [       -inf, -1.5784e+00,  1.7363e+14,         inf],
        [       -inf, -1.5784e+00,  1.7363e+14,         inf],
        [       -inf, -1.5784e+00,  1.7363e+14,         inf],
        [       -inf, -1.5784e+00,  1.7363e+14,         inf],
        [       -inf, -1.5784e+00,  1.7363e+14,         inf],
        [       -inf, -1.5784e+00,  1.7363e+14,         inf],
        [       -inf, -1.5784e+00,  1.7363e+14,         inf],
        [       -inf, -1.5784e+00,  1.7363e+14,         inf],
        [       -inf, -1.5784e+00,  1.7363e+14,         inf],
        [       -inf, -1.5784e+00,  1.7363e+14,         inf],
        [       -inf, -1.5784e+00,  1.7363e+14,         inf],
        [       -inf, -1.5784e+00,  1.7363e+14,         inf],
        [       -inf, -1.5784e+00,  1.7363e+14,         inf],
        [       -inf, -1.5784e+00,  1.7363e+14,         inf],
        [       -inf, -1.5784e+00,  1.7363e+14,         inf],
        

In [27]:
get_pdf(get_cdf(y_pred))

tensor([[1.0000e+00, 7.1526e-07, 0.0000e+00],
        [6.9653e-01, 3.0347e-01, 0.0000e+00],
        [1.0000e+00, 5.9605e-07, 0.0000e+00],
        [9.9997e-01, 3.1114e-05, 0.0000e+00],
        [3.5395e-02, 9.6461e-01, 0.0000e+00],
        [3.1105e-01, 6.8895e-01, 0.0000e+00],
        [9.9996e-01, 4.2558e-05, 0.0000e+00],
        [1.0000e+00, 3.5763e-07, 0.0000e+00],
        [9.2137e-01, 7.8635e-02, 0.0000e+00],
        [4.9480e-01, 5.0520e-01, 0.0000e+00],
        [9.9753e-01, 2.4696e-03, 0.0000e+00],
        [1.4545e-01, 8.5455e-01, 0.0000e+00],
        [1.0000e+00, 4.7684e-07, 0.0000e+00],
        [9.8231e-01, 1.7689e-02, 0.0000e+00],
        [1.8965e-03, 9.9810e-01, 0.0000e+00],
        [1.0000e+00, 1.7881e-06, 0.0000e+00],
        [1.6537e-01, 8.3463e-01, 0.0000e+00],
        [9.9620e-01, 3.8044e-03, 0.0000e+00]], device='cuda:0')

In [28]:
y

tensor([2., 0., 2., 2., 0., 0., 2., 2., 1., 0., 1., 0., 2., 1., 0., 2., 0., 1.],
       device='cuda:0')

In [29]:
def accuracy(y_pred, y_true):
    predicted_labels = torch.argmax(y_pred, dim=1)
    return (predicted_labels == y_true).float().mean().item()

accuracy(get_pdf(get_cdf(y_pred)), y)

0.0555555559694767

In [26]:


def get_cdf(outputs):
    """"
    Get cumulative distribution function
    
    Args:
        outputs: output of a model of class OntramModel
    """
    int_in = outputs['int_out']
    shift_in = outputs['shift_out']
    
    # transform intercepts
    int = transform_intercepts_ordinal(int_in)

    if shift_in is not None:
        shift = torch.stack(shift_in, dim=1).sum(dim=1)    
        cdf = torch.sigmoid(torch.sub(int, shift))
    else:
        cdf = torch.sigmoid(int)
    return cdf

def get_pdf(cdf):
    """"
    Get probability density function
    
    Args:
        cdf: cumulative distirbution function returning from get_cdf
    """
    return torch.sub(cdf[:,1:], cdf[:,:-1])

def pred_proba(pdf, targets):
    """"
    Get probability for the true class
    
    Args:
        pdf: probability density function returning from get_pdf
        targets: Outcome classes, one hot encoded
    """
    target_class = torch.argmax(targets, dim=1)
    proba = pdf[torch.arange(pdf.shape[0]), target_class]
    return proba

def pred_class(pdf):
    """"
    Get the predicted class.
    
    Args:
        pdf: probability density function returning from get_pdf
    """
    return torch.argmax(pdf, dim=1)

def get_parameters_si(model, ls_pos=0):
    """"
    Get parameters of the intercept and the linear shift term for interpretability
    
    Args:
        model: a model of class OntramModel
        ls_pos: Position of the linear shift term in the list of nn_shift
    """
    params_int = []
    for l in model.nn_int.children():
        for p in l.parameters():
            params_int.append(p.detach().clone())
    
    if model.nn_shift is not None:
        print("Give parameters for nn_shift at position: ", ls_pos)
        params_ls = []
        for l in model.nn_shift[ls_pos].children():
            for p in l.parameters():
                params_ls.append(p.detach().clone())
        return {"params_int": params_int, "params_shift": params_ls}
    else:
        return {"params_int": params_int, "params_shift": None}