# Minimal example of an ONTRAM implementation
- analogous to https://github.com/liherz/ontram_pytorch.git

In [1]:
# Load dependencies
import json
import os
import time

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.datasets import load_wine
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from utils.graph import *

In [2]:
experiment_name = "ordinal_example"   ## <--- set experiment name

# Seed every random number generator the notebook relies on. Seeding NumPy
# alone leaves PyTorch's generator (e.g. weight initialisation) unseeded.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# Directory layout: LOG_DIR/<experiment_name>/configuration.json
LOG_DIR = "/home/bule/TramDag/dev_experiment_logs"
EXPERIMENT_DIR = os.path.join(LOG_DIR, experiment_name)
DATA_PATH = EXPERIMENT_DIR  # <----------- change to different source if needed
CONF_DICT_PATH = os.path.join(EXPERIMENT_DIR, "configuration.json")

In [3]:
# Load the dataset
# The returned object is indexed by key in the following cells
# ('data', 'feature_names' and 'target').
wine = load_wine()
wine  # bare last expression: the notebook renders the whole object as cell output

{'data': array([[1.423e+01, 1.710e+00, 2.430e+00, ..., 1.040e+00, 3.920e+00,
         1.065e+03],
        [1.320e+01, 1.780e+00, 2.140e+00, ..., 1.050e+00, 3.400e+00,
         1.050e+03],
        [1.316e+01, 2.360e+00, 2.670e+00, ..., 1.030e+00, 3.170e+00,
         1.185e+03],
        ...,
        [1.327e+01, 4.280e+00, 2.260e+00, ..., 5.900e-01, 1.560e+00,
         8.350e+02],
        [1.317e+01, 2.590e+00, 2.370e+00, ..., 6.000e-01, 1.620e+00,
         8.400e+02],
        [1.413e+01, 4.100e+00, 2.740e+00, ..., 6.100e-01, 1.600e+00,
         5.600e+02]]),
 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [4]:
# Feature matrix with named columns, plus the class label as a trailing 'target' column
df = pd.DataFrame(
    wine['data'],
    columns=wine['feature_names'],
).assign(target=wine['target'])

In [5]:
# 80 % train, then the held-out 20 % is halved into validation and test.
# The split reuses the notebook-wide `seed` from the configuration cell so that
# changing the seed there changes every source of randomness consistently.
train_df, temp_df = train_test_split(df, test_size=0.2, random_state=seed)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=seed)

In [6]:
# Summarise the training split: column names, non-null counts and dtypes
train_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 142 entries, 158 to 102
Data columns (total 14 columns):
 #   Column                        Non-Null Count  Dtype  
---  ------                        --------------  -----  
 0   alcohol                       142 non-null    float64
 1   malic_acid                    142 non-null    float64
 2   ash                           142 non-null    float64
 3   alcalinity_of_ash             142 non-null    float64
 4   magnesium                     142 non-null    float64
 5   total_phenols                 142 non-null    float64
 6   flavanoids                    142 non-null    float64
 7   nonflavanoid_phenols          142 non-null    float64
 8   proanthocyanins               142 non-null    float64
 9   color_intensity               142 non-null    float64
 10  hue                           142 non-null    float64
 11  od280/od315_of_diluted_wines  142 non-null    float64
 12  proline                       142 non-null    float64
 13  target  

In [7]:
# Every feature column is continuous; only the class label 'target' is ordinal.
# Deriving the mapping from the column names avoids a hard-coded column count,
# which would silently truncate or mislabel columns if the frame changed.
data_type = {
    column: ('ord' if column == 'target' else 'cont')
    for column in train_df.columns
}
data_type

{'alcohol': 'cont',
 'malic_acid': 'cont',
 'ash': 'cont',
 'alcalinity_of_ash': 'cont',
 'magnesium': 'cont',
 'total_phenols': 'cont',
 'flavanoids': 'cont',
 'nonflavanoid_phenols': 'cont',
 'proanthocyanins': 'cont',
 'color_intensity': 'cont',
 'hue': 'cont',
 'od280/od315_of_diluted_wines': 'cont',
 'proline': 'cont',
 'target': 'ord'}

In [8]:
# Create a fresh configuration dictionary for this experiment. As the rendered
# output shows, only the creation date, the experiment name and the PATHS entry
# are filled in; the remaining keys start out as None.
configuration_dict=new_conf_dict(experiment_name,EXPERIMENT_DIR,DATA_PATH,LOG_DIR)
configuration_dict

{'date_of_creation': '2025-07-15 06:47:07',
 'experiment_name': 'ordinal_example',
 'PATHS': {'DATA_PATH': '/home/bule/TramDag/dev_experiment_logs/ordinal_example',
  'LOG_DIR': '/home/bule/TramDag/dev_experiment_logs',
  'EXPERIMENT_DIR': '/home/bule/TramDag/dev_experiment_logs/ordinal_example'},
 'data_type': None,
 'adj_matrix': None,
 'model_names': None,
 'seed': None,
 'nodes': None}

In [None]:
# Data type of every node in the graph (continuous, images, ordinal)
data_type = dict(x1='cont', x2='cont', target='ord')

# Number of distinct levels of each ordinal node
levels_dict = {'target': np.unique(df['target']).size}

# Edge codes between nodes; "0" marks the absence of an edge.
# Presumably row = parent and column = child (that is how nn_names_matrix is
# indexed in create_node_dict_v2) - TODO confirm for adj_matrix itself.
adj_matrix = np.array(
    [
        ["0", "0", "ls"],
        ["0", "0", "ci"],
        ["0", "0", "0"],
    ],
    dtype=object,
)

# Translate the edge codes into the names of the network classes to use
nn_names_matrix = create_nn_model_names(adj_matrix, data_type)
nn_names_matrix

*************
 Model has Complex intercepts and Complex shifts, please add your Model to the modelzoo 
*************


array([['0', '0', 'LinearShift'],
       ['0', '0', 'ComplexInterceptDefaultTabular'],
       ['0', '0', '0']], dtype=object)

In [36]:
# 5 % and 95 % quantiles of every training column; they are handed to
# create_node_dict_v2 as the per-node 'min' and 'max' values.
quantiles = train_df.quantile([0.05, 0.95])
min_vals, max_vals = (quantiles.loc[q] for q in (0.05, 0.95))


In [42]:
def _lookup_bound(bounds, node, position):
    """Return the entry of the pandas Series `bounds` that belongs to `node`.

    The label `node` is preferred, so the value cannot be taken from a different
    column when the Series is ordered differently from the graph nodes. When the
    node name is not a label of the Series, the positional entry is used instead
    (the previous behaviour).
    """
    value = bounds.loc[node] if node in bounds.index else bounds.iloc[position]
    # NumPy scalars are converted to plain Python numbers; other values pass through
    return value.tolist() if hasattr(value, "tolist") else value


def create_node_dict_v2(adj_matrix, nn_names_matrix, data_type, min_vals, max_vals,levels_dict=None):
    """
    Creates a configuration dictionary for TRAMADAG based on an adjacency matrix,
    a neural network names matrix, and a data type dictionary.

    Args:
        adj_matrix: square array of edge codes; it must pass validate_adj_matrix
            and have one row per entry of `data_type`.
        nn_names_matrix: array indexed as [parent, child] holding the model name
            of every edge, with "0" where there is none.
        data_type: dict mapping node name -> data type (e.g. 'cont', 'ord'); its
            key order defines the node order of both matrices.
        min_vals, max_vals: pandas Series with the lower / upper value per node.
            Entries are looked up by node name, falling back to node position
            when the name is not a label of the Series.
        levels_dict: dict mapping every ordinal node to its number of levels.

    Returns:
        dict mapping node name -> dict with the keys 'Modelnr', 'data_type',
        'levels' (ordinal nodes only), 'node_type', 'parents',
        'parents_datatype', 'transformation_terms_in_h()', 'min', 'max' and
        'transformation_term_nn_models_in_h()'.

    Raises:
        ValueError: if the adjacency matrix is invalid or its size does not
            match `data_type`.
    """
    if not validate_adj_matrix(adj_matrix):
        raise ValueError("Invalid adjacency matrix. Please check the criteria.")

    if len(data_type) != adj_matrix.shape[0]:
        raise ValueError("Data type dictionary should have the same length as the adjacency matrix.")

    # Node order and name -> matrix index, computed once instead of per parent
    node_names = list(data_type.keys())
    position_of = {name: idx for idx, name in enumerate(node_names)}

    target_nodes = {}
    G, edge_labels = create_nx_graph(adj_matrix, node_labels=node_names)

    sources = [node for node in G.nodes if G.in_degree(node) == 0]
    sinks = [node for node in G.nodes if G.out_degree(node) == 0]

    for i, node in enumerate(G.nodes):
        parents = list(G.predecessors(node))
        node_conf = {}
        node_conf['Modelnr'] = i
        node_conf['data_type'] = data_type[node]

        # write the levels of the ordinal outcome
        if data_type[node] == 'ord':
            if levels_dict is None:
                print('provide levels_dict e.g. {"x3":3}')
            else:
                node_conf['levels'] = levels_dict[node]

        # A node without parents counts as a source even if it also has no children
        if node in sources:
            node_conf['node_type'] = "source"
        elif node in sinks:
            node_conf['node_type'] = "sink"
        else:
            node_conf['node_type'] = "internal"

        node_conf['parents'] = parents
        node_conf['parents_datatype'] = {parent: data_type[parent] for parent in parents}
        node_conf['transformation_terms_in_h()'] = {
            parent: edge_labels[(parent, node)]
            for parent in parents
            if (parent, node) in edge_labels
        }

        # Bounds are read by node name, so they come from the node's own entry
        # rather than from whichever entry happens to sit at position i.
        node_conf['min'] = _lookup_bound(min_vals, node, i)
        node_conf['max'] = _lookup_bound(max_vals, node, i)

        # Model name per parent edge; "0" means no model on that edge
        child_idx = position_of[node]
        node_conf['transformation_term_nn_models_in_h()'] = {
            parent: nn_names_matrix[position_of[parent], child_idx]
            for parent in parents
            if nn_names_matrix[position_of[parent], child_idx] != "0"
        }

        target_nodes[node] = node_conf
    return target_nodes

In [43]:
# Build the per-node configuration for the graph defined above
target_nodes=create_node_dict_v2(adj_matrix, nn_names_matrix, data_type, min_vals, max_vals,levels_dict)

In [None]:
from utils.tram_model_helpers import *
from utils.tram_models import *

In [44]:
# TODO : Where a SimpleIntercept model is there must be C-1 thetas 


def get_fully_specified_tram_model_v2(node: str, target_nodes: dict, verbose=True):
    """Assemble the TramModel of `node` from its entry in `target_nodes`.

    Source nodes get a SimpleIntercept and no shift networks. For every other
    node the parent terms are split into intercept terms ("ci" / "si") and
    shift terms ("cs" / "ls"). At most one intercept network is built, and a
    SimpleIntercept is used when no parent contributes an intercept term. One
    shift network is built per shift group. If the node is ordinal, the
    intercept network is created with levels - 1 thetas.

    Args:
        node: name of the node whose model is built.
        target_nodes: node configuration as produced by create_node_dict_v2;
            ordinal nodes must carry a 'levels' entry.
        verbose: print a short description of the constructed model.

    Returns:
        The constructed TramModel.

    Raises:
        ValueError: if more than one intercept group is found.
    """
    node_conf = target_nodes[node]

    # An ordinal node with c levels only needs c-1 thetas. Every other data
    # type keeps the default of the intercept class. The int() cast is applied
    # on every path, including the source branch which previously omitted it.
    theta_kwargs = {}
    if node_conf['data_type'] == 'ord':
        theta_kwargs['n_thetas'] = int(node_conf['levels']) - 1

    # Source nodes get a simple intercept only
    if node_conf['node_type'] == 'source':
        model = TramModel(SimpleIntercept(**theta_kwargs), None)
        if verbose:
            print("Source → SimpleIntercept only")
        return model

    # Otherwise gather terms and model names
    _, terms_dict, model_names_dict = ordered_parents(node, target_nodes)
    model_dict = merge_transformation_dicts(terms_dict, model_names_dict)

    # Split intercepts vs. shifts
    def is_intercept_term(entry):
        return "ci" in entry['h_term'] or "si" in entry['h_term']

    intercepts_dict = {k: v for k, v in model_dict.items() if is_intercept_term(v)}
    shifts_dict = {k: v for k, v in model_dict.items() if not is_intercept_term(v)}

    # Build intercept network
    intercept_groups = group_by_base(intercepts_dict, prefixes=("ci", "si"))
    if not intercept_groups:
        nn_int = SimpleIntercept(**theta_kwargs)
    else:
        if len(intercept_groups) > 1:
            raise ValueError("Multiple intercept models detected; only one is supported.")
        feats = next(iter(intercept_groups.values()))
        cls_name = feats[0][1]['class_name']
        base = get_base_model_class(cls_name)
        # `base` is the key under which the model class is found in this
        # namespace. All features of the group feed one network.
        nn_int = globals()[base](n_features=len(feats), **theta_kwargs)

    # Build shift networks (handles both "cs" and "ls")
    shift_groups = group_by_base(shifts_dict, prefixes=("cs", "ls"))
    nn_shifts = []
    for feats in shift_groups.values():
        cls_name = feats[0][1]['class_name']
        base = get_base_model_class(cls_name)
        nn_shifts.append(globals()[base](n_features=len(feats)))

    # Combine into TramModel
    tram_model = TramModel(nn_int, nn_shifts)
    if verbose:
        print("Constructed TRAM model:", tram_model)
    return tram_model

In [45]:
# Sanity check: build the model of the ordinal 'target' node and display its architecture
get_fully_specified_tram_model_v2('target', target_nodes, verbose=True)

Constructed TRAM model: TramModel(
  (nn_int): ComplexInterceptDefaultTabular(
    (fc1): Linear(in_features=1, out_features=8, bias=True)
    (relu1): ReLU()
    (fc2): Linear(in_features=8, out_features=8, bias=True)
    (relu2): ReLU()
    (fc3): Linear(in_features=8, out_features=2, bias=False)
  )
  (nn_shift): ModuleList(
    (0): LinearShift(
      (fc): Linear(in_features=1, out_features=1, bias=False)
    )
  )
)


TramModel(
  (nn_int): ComplexInterceptDefaultTabular(
    (fc1): Linear(in_features=1, out_features=8, bias=True)
    (relu1): ReLU()
    (fc2): Linear(in_features=8, out_features=8, bias=True)
    (relu2): ReLU()
    (fc3): Linear(in_features=8, out_features=2, bias=False)
  )
  (nn_shift): ModuleList(
    (0): LinearShift(
      (fc): Linear(in_features=1, out_features=1, bias=False)
    )
  )
)

In [48]:
from utils.loss_ordinal import *

In [49]:
def train_val_loop_v3(
                   node,
                   target_nodes,
                   NODE_DIR,
                   tram_model,
                   train_loader,
                   val_loader,
                   epochs,
                   optimizer,
                   use_scheduler,
                   scheduler,
                   save_linear_shifts=False,
                   verbose=1,
                   device='cpu'):
    """Train and validate the TRAM model of one node, checkpointing every epoch.

    If a best-model checkpoint and both loss histories already exist in
    NODE_DIR, the weights and histories are loaded and training resumes at
    epoch len(train history). Otherwise training starts from scratch.

    After every epoch the last model and both loss histories are written. The
    best model is overwritten whenever the average validation loss improves.

    Ordinal nodes (data_type == 'ord') are trained and validated with
    ontram_nll. All other nodes use contram_nll with the node's min / max
    values.

    Args:
        node: name of the node being trained.
        target_nodes: node configuration; provides 'data_type', 'min', 'max'
            and the parent ordering of `node`.
        NODE_DIR: directory holding the checkpoints and histories of this node.
        tram_model: model to train; called with `int_input` and `shift_input`.
        train_loader, val_loader: iterables yielding (x, y) batches.
        epochs: total number of epochs to reach (including resumed ones).
        optimizer: optimizer bound to the parameters of `tram_model`.
        use_scheduler: step `scheduler` once per epoch when True.
        scheduler: learning-rate scheduler, or None.
        save_linear_shifts: append the `fc.weight` of every shift network to
            "linear_shifts_all_epochs.json" in NODE_DIR after each epoch.
        verbose: > 0 prints the epoch summary and best-model messages,
            > 1 additionally prints the saved shift weights.
        device: device used for targets, model inputs and checkpoint loading.

    Returns:
        None. Results are written to NODE_DIR.
    """
    # get all paths for this training run
    MODEL_PATH, LAST_MODEL_PATH, TRAIN_HIST_PATH, VAL_HIST_PATH = model_train_val_paths(NODE_DIR)

    # this is needed for the preprocessing of the inputs such that they are in the correct order
    _, ordered_transformation_terms_in_h, _ = ordered_parents(node, target_nodes)

    # min / max of the node, passed to the loss for continuous outcomes
    min_vals = torch.tensor(target_nodes[node]['min'], dtype=torch.float32).to(device)
    max_vals = torch.tensor(target_nodes[node]['max'], dtype=torch.float32).to(device)
    min_max = torch.stack([min_vals, max_vals], dim=0)
    print(f"Min-Max values for {node}: {min_max}")
    print(min_max.shape)

    is_ordinal = target_nodes[node]['data_type'] == 'ord'

    def batch_loss(y_pred, y):
        # One place decides the loss so training and validation always agree:
        # the ordinal loss for ordinal outcomes, the continuous one otherwise.
        if is_ordinal:
            return ontram_nll(y_pred, y)
        return contram_nll(y_pred, y, min_max=min_max)

    ###### Load Model & History #####
    if os.path.exists(MODEL_PATH) and os.path.exists(TRAIN_HIST_PATH) and os.path.exists(VAL_HIST_PATH):
        print("Existing model found. Loading weights and history...")
        # map_location lets a checkpoint written on another device be loaded here
        tram_model.load_state_dict(torch.load(MODEL_PATH, map_location=device))

        with open(TRAIN_HIST_PATH, 'r') as f:
            train_loss_hist = json.load(f)
        with open(VAL_HIST_PATH, 'r') as f:
            val_loss_hist = json.load(f)

        start_epoch = len(train_loss_hist)
        # An empty history must not crash the resume
        best_val_loss = min(val_loss_hist, default=float('inf'))
    else:
        print("No existing model found. Starting fresh...")
        train_loss_hist, val_loss_hist = [], []
        start_epoch = 0
        best_val_loss = float('inf')

    ##### Training and Validation loop
    for epoch in range(start_epoch, epochs):
        epoch_start = time.time()

        #####  Training #####
        train_start = time.time()
        train_loss = 0.0
        tram_model.train()
        for x, y in train_loader:
            optimizer.zero_grad()
            y = y.to(device)
            int_input, shift_list = preprocess_inputs(x, ordered_transformation_terms_in_h.values(), device=device)
            y_pred = tram_model(int_input=int_input, shift_input=shift_list)

            loss = batch_loss(y_pred, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        if use_scheduler and scheduler is not None:
            scheduler.step()

        train_time = time.time() - train_start

        avg_train_loss = train_loss / len(train_loader)
        train_loss_hist.append(avg_train_loss)

        ##### Validation #####
        val_start = time.time()
        val_loss = 0.0
        tram_model.eval()

        with torch.no_grad():
            for x, y in val_loader:
                y = y.to(device)
                int_input, shift_list = preprocess_inputs(x, ordered_transformation_terms_in_h.values(), device=device)
                y_pred = tram_model(int_input=int_input, shift_input=shift_list)
                val_loss += batch_loss(y_pred, y).item()
        val_time = time.time() - val_start

        avg_val_loss = val_loss / len(val_loader)
        val_loss_hist.append(avg_val_loss)

        ##### Save linear shift weights #####
        if save_linear_shifts and tram_model.nn_shift is not None:
            # Define the path for the cumulative JSON file
            shift_path = os.path.join(NODE_DIR, "linear_shifts_all_epochs.json")

            # Load existing data if the file exists
            if os.path.exists(shift_path):
                with open(shift_path, 'r') as f:
                    all_shift_weights = json.load(f)
            else:
                all_shift_weights = {}

            # Prepare current epoch's shift weights
            epoch_weights = {}
            for i in range(len(tram_model.nn_shift)):
                shift_layer = tram_model.nn_shift[i]

                if hasattr(shift_layer, 'fc') and hasattr(shift_layer.fc, 'weight'):
                    epoch_weights[f"shift_{i}"] = shift_layer.fc.weight.detach().cpu().tolist()
                else:
                    print(f"shift_{i}: 'fc' or 'weight' layer does not exist.")

            # Add to the dictionary under current epoch
            all_shift_weights[f"epoch_{epoch+1}"] = epoch_weights

            # Write back the updated dictionary
            with open(shift_path, 'w') as f:
                json.dump(all_shift_weights, f)
            if verbose > 1:
                print(f'shift weights: {epoch_weights}')
                print(f"Appended linear shift weights for epoch {epoch+1} to: {shift_path}")

        ##### Saving #####
        save_start = time.time()
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(tram_model.state_dict(), MODEL_PATH)
            if verbose > 0:
                print("Saved new best model.")

        torch.save(tram_model.state_dict(), LAST_MODEL_PATH)

        with open(TRAIN_HIST_PATH, 'w') as f:
            json.dump(train_loss_hist, f)
        with open(VAL_HIST_PATH, 'w') as f:
            json.dump(val_loss_hist, f)
        save_time = time.time() - save_start

        epoch_total = time.time() - epoch_start

        ##### Epoch Summary #####
        if verbose > 0:
            print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
            print(f"  [Train: {train_time:.2f}s | Val: {val_time:.2f}s | Save: {save_time:.2f}s | Total: {epoch_total:.2f}s]")


In [None]:
# Run training for a single node. Every argument is a notebook-level variable;
# in this notebook they are assigned in the cells further down, so those cells
# have to be executed before this one.
train_val_loop_v3(
    node=node,
    target_nodes=target_nodes,
    NODE_DIR=NODE_DIR,
    tram_model=tram_model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=epochs,
    optimizer=optimizer,
    use_scheduler=use_scheduler,
    scheduler=scheduler,
    save_linear_shifts=False,
    verbose=1,
    device='cpu',
)

In [None]:
# Nodes to train; useful if only some nodes need further training, otherwise
# list all variables. The names must match the nodes of the graph defined
# above (x1, x2, target), or the per-node loop skips them.
train_list = ['x1', 'x2', 'target']

batch_size = 32  # 4112
# To train longer, raise this number: the old model is loaded and training continues from there
epochs = 3
learning_rate = 0.1
use_scheduler = True

In [None]:
# For each NODE
configuration_dict = load_configuration_dict(CONF_DICT_PATH)
target_nodes = configuration_dict['nodes']

# NOTE(review): `device` is used below but is not assigned in any visible cell -
# confirm it is provided by one of the star imports or define it in the config cell.
for node in target_nodes:

    print(f'\n----*----------*-------------*--------------- Node: {node} ------------*-----------------*-------------------*--')
    ########################## 0. Skip nodes ###############################
    if node not in train_list:  # Skip if node is not in train_list
        print(f"Skipping node {node} as it's not in the training list.")
        continue
    # Skip unsupported types. The previous condition compared 'node_type' with
    # two different values at once and therefore could never be true.
    # NOTE(review): assumes the intent is "source node whose data_type is
    # 'other'" - confirm against the supported data types.
    if (target_nodes[node]['node_type'] == 'source') and (target_nodes[node]['data_type'] == 'other'):
        print(f"Node type : other , is not supported yet")
        continue

    ########################## 1. Setup Paths ###############################
    NODE_DIR = os.path.join(EXPERIMENT_DIR, f'{node}')
    os.makedirs(NODE_DIR, exist_ok=True)

    # Check if training is complete
    # NOTE(review): the node is skipped when the helper returns a falsy value -
    # confirm that this matches the helper's return convention.
    if not check_if_training_complete(node, NODE_DIR, epochs):
        continue

    ########################## 2. Create Model ##############################
    # Use the builder defined in this notebook: it sizes the intercept of
    # ordinal nodes with levels - 1 thetas.
    tram_model = get_fully_specified_tram_model_v2(node, target_nodes, verbose=True).to(device)

    ########################## 3. Create Dataloaders ########################
    train_loader, val_loader = get_dataloader(node, target_nodes, train_df, val_df, batch_size=batch_size, verbose=True)

    ########################## 4. Optimizer & Scheduler ######################
    optimizer = torch.optim.Adam(tram_model.parameters(), lr=learning_rate)

    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
    else:
        scheduler = None

In [None]:


# Build model and data loaders for the current `node` outside the loop.
# The builder defined in this notebook is used so that an ordinal node gets
# levels - 1 thetas.
tram_model = get_fully_specified_tram_model_v2(node, target_nodes, verbose=True).to(device)

train_loader, val_loader = get_dataloader(node, target_nodes, train_df, val_df, batch_size=batch_size, verbose=True)