In [None]:
import os,sys
# Force synchronous CUDA kernel launches so that device errors are reported at the
# offending call. This is set before torch is imported further down in this cell.
os.environ.update(CUDA_LAUNCH_BLOCKING='1')

# The project root is the parent of the current working directory; put it first on
# the import path (once) so the local `utils` package can be imported.
proj_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if sys.path.count(proj_root) == 0:
    sys.path.insert(0, proj_root)

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from scipy.stats import logistic
from scipy.special import logit


from sklearn.datasets import make_blobs

import torch
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset
from torch.cuda.amp import autocast, GradScaler

# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Train with GPU support." if device.type == 'cuda' else "No GPU found, train with CPU support.")

import numpy as np
import json
import pandas as pd
import matplotlib.pyplot as plt


# own utils
from utils.graph import *
from utils.tram_models import *
from utils.tram_model_helpers import *
from utils.tram_data import *
from utils.continous import *
from utils.sampling_tram_data import *

Train with GPU support.


In [None]:
def dgp_sklearn(nobs=1000, nvars=7, seed=42):
    """Sample ``nobs`` rows of ``nvars`` features from a single Gaussian blob.

    Thin wrapper around ``sklearn.datasets.make_blobs`` with ``centers=1`` and
    ``cluster_std=1.0``; the cluster labels it returns are discarded.

    Parameters
    ----------
    nobs : int
        Number of samples (rows).
    nvars : int
        Number of features (columns).
    seed : int
        Passed to ``make_blobs`` as ``random_state``.

    Returns
    -------
    pandas.DataFrame
        Columns named ``x1`` ... ``x{nvars}``.
    """
    samples = make_blobs(
        n_samples=nobs,
        n_features=nvars,
        centers=1,
        cluster_std=1.0,
        random_state=seed,
    )[0]
    column_names = ['x' + str(k) for k in range(1, nvars + 1)]
    return pd.DataFrame(samples, columns=column_names)



# Testing 2 variables

In [None]:
df = dgp_sklearn(nobs=1000, nvars=2, seed=42)
# DataFrame.info() prints its summary itself and returns None, so wrapping it
# in print() would emit an extra "None" line.
df.info()

# Per-variable data type: 'cont' = continuous, 'ord' = ordinal,
# 'other' = anything except images.
data_type = {'x1': 'cont', 'x2': 'cont'}

# 80 / 10 / 10 train / validation / test split with fixed seeds.
train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

# 5th / 95th percentiles per column, computed on the training split only.
quantiles = train_df.quantile([0.05, 0.95])
min_vals = quantiles.loc[0.05]
max_vals = quantiles.loc[0.95]

In [None]:
# Node configuration for the two-variable DAG x1 -> x2.
# x1 has no parents (source); x2 depends on x1 through an 'ls' term whose
# network is configured as 'LinearShift'.
target_nodes = {
    'x1': {
        'Modelnr': 0,
        'data_type': 'cont',
        'node_type': 'source',
        'parents': [],
        'parents_datatype': {},
        'transformation_terms_in_h()': {},
        'min': None,
        'max': None,
        'transformation_term_nn_models_in_h()': {},
    },
    'x2': {
        'Modelnr': 1,
        'data_type': 'cont',
        'node_type': 'sink',
        'parents': ['x1'],
        'parents_datatype': {'x1': 'cont'},
        'transformation_terms_in_h()': {'x1': 'ls'},
        'min': None,
        'max': None,
        'transformation_term_nn_models_in_h()': {'x1': 'LinearShift'},
    },
}

In [None]:
# Snapshot the repr() of each node's TRAM model for `target_nodes` and save it as JSON.
# `test_key` and `all_outputs` are created here because they were never defined
# before this cell, so the cell previously failed with NameError on a fresh kernel.
test_key = "test_1"
all_outputs = {test_key: {}}
for node in target_nodes:
    model = get_fully_specified_tram_model(node, target_nodes, verbose=False).to(device)
    # capture the representation directly
    all_outputs[test_key][node] = repr(model)

with open("tram_model_repr_outputs.json", "w") as f:
    json.dump(all_outputs, f, indent=2)

In [5]:
# Ground-truth fixture for the two-variable model-loader test.
# "input" holds the node configs; "output" starts empty and is filled with
# model reprs by a later cell.
twovars_model_loader_test_dict = {
    "test_1": {
        "input": {
            'x1': {
                'data_type': 'cont',
                'node_type': 'source',
                'parents': [],
                'parents_datatype': {},
                'transformation_terms_in_h()': {},
                'transformation_term_nn_models_in_h()': {},
            },
            'x2': {
                'data_type': 'cont',
                'node_type': 'sink',
                'parents': ['x1'],
                'parents_datatype': {'x1': 'cont'},
                'transformation_terms_in_h()': {'x1': 'ls'},
                'transformation_term_nn_models_in_h()': {'x1': 'LinearShift'},
            },
        },
        "output": {},
    }
}

# Persist the fixture as pretty-printed JSON.
file_path = 'twovars_model_loader_test_dict.json'
with open(file_path, 'w') as f:
    f.write(json.dumps(twovars_model_loader_test_dict, indent=4))

print(f"Dictionary has been saved to {file_path}")

Dictionary has been saved to twovars_model_loader_test_dict.json


In [6]:
# Fill the fixture's "output" section with the repr() of each freshly built model.
file_path = 'twovars_model_loader_test_dict.json'

with open(file_path, 'r') as f:
    data = json.load(f)

inputs = data["test_1"]["input"]

# Same device rule as the setup cell.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create the "output" section if it is missing, then record one repr per node.
outputs_section = data["test_1"].setdefault("output", {})
for node_name in inputs:
    model = get_fully_specified_tram_model(node_name, inputs, verbose=False).to(device)
    outputs_section[node_name] = repr(model)

# Write the updated fixture back in place.
with open(file_path, 'w') as f:
    f.write(json.dumps(data, indent=4))

print(f"Updated outputs written to {file_path}")

Updated outputs written to twovars_model_loader_test_dict.json


In [7]:
# Re-read the fixture so `data` reflects exactly what was written to disk.
with open(file_path) as f:
    data = json.load(f)

In [8]:
# Display the reloaded fixture: node configs under "input", model reprs under "output".
data

{'test_1': {'input': {'x1': {'data_type': 'cont',
    'node_type': 'source',
    'parents': [],
    'parents_datatype': {},
    'transformation_terms_in_h()': {},
    'transformation_term_nn_models_in_h()': {}},
   'x2': {'data_type': 'cont',
    'node_type': 'sink',
    'parents': ['x1'],
    'parents_datatype': {'x1': 'cont'},
    'transformation_terms_in_h()': {'x1': 'ls'},
    'transformation_term_nn_models_in_h()': {'x1': 'LinearShift'}}},
  'output': {'x1': 'TramModel(\n  (nn_int): SimpleIntercept(\n    (fc): Linear(in_features=1, out_features=20, bias=False)\n  )\n)',
   'x2': 'TramModel(\n  (nn_int): SimpleIntercept(\n    (fc): Linear(in_features=1, out_features=20, bias=False)\n  )\n  (nn_shift): ModuleList(\n    (0): LinearShift(\n      (fc): Linear(in_features=1, out_features=1, bias=False)\n    )\n  )\n)'}}}

In [None]:
# Load the ground-truth fixture: node configs ("input") and expected model reprs ("output").
with open('twovars_model_loader_test_dict.json', 'r') as f:
    test_data = json.load(f)

inputs, expected_outputs = test_data["test_1"]["input"], test_data["test_1"]["output"]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def test_twovars_model_loader_ground_truth():
    """Compare every node's freshly built TRAM model with its stored repr() snapshot.

    Iterates over the module-level ``inputs`` and looks each node up in the
    module-level ``expected_outputs``.

    Raises
    ------
    AssertionError
        If a node has no stored snapshot, or if its current repr() differs from it.
    """
    for node_name in inputs:
        built_model = get_fully_specified_tram_model(node_name, inputs, verbose=False).to(device)
        actual_repr = repr(built_model)
        expected_repr = expected_outputs.get(node_name)

        assert expected_repr is not None, (
            f"No expected output found for node '{node_name}' in ground-truth JSON."
        )
        assert actual_repr == expected_repr, (
            f"Mismatch for node '{node_name}':\nExpected: {expected_repr}\nActual:   {actual_repr}"
        )


test_twovars_model_loader_ground_truth()

In [None]:

# Side-by-side visual check: freshly built model vs. stored ground-truth repr.
# Uses the `test_data` fixture loaded above; `model_loader_test_dict` does not
# exist in this notebook and would raise NameError.
case_inputs = test_data["test_1"]['input']
case_outputs = test_data["test_1"]['output']
for node in case_inputs:
    print(get_fully_specified_tram_model(node, case_inputs, verbose=False).to(device))
    print(case_outputs[node])


In [None]:
# `tram_model` only exists as a local inside get_fully_specified_tram_model, so it
# is undefined at notebook level; build the sink node's model explicitly to show it.
tram_model = get_fully_specified_tram_model('x2', inputs, verbose=False).to(device)
tram_model

In [None]:
def get_base_model_class(class_name: str):
    """Return ``class_name`` truncated just before its first digit character.

    Used to map numbered model names (e.g. ``"ComplexShift12"``) back to the name
    of the class to instantiate (``"ComplexShift"``). Names without any digit are
    returned unchanged, and a name that starts with a digit yields ``""``.
    """
    cut = next((pos for pos, ch in enumerate(class_name) if ch.isdigit()), len(class_name))
    return class_name[:cut]

# --------- Group features by h_term base ---------
def group_by_base(term_dict, prefixes):
    """Group transformation terms by the base of their ``h_term`` label.

    Parameters
    ----------
    term_dict : dict
        Mapping ``feature -> conf`` where each ``conf`` has an ``'h_term'`` string.
    prefixes : str or tuple of str
        Term prefixes to keep (e.g. ``'ci'`` or ``('cs', 'ls')``). Terms whose
        ``h_term`` starts with none of them are skipped. The first matching
        prefix wins.

    Returns
    -------
    collections.defaultdict(list)
        The key is the prefix plus the first following character when that character
        is a digit (``'cs11'`` and ``'cs12'`` both map to ``'cs1'``); otherwise it is
        the full ``h_term`` (``'ls'`` maps to ``'ls'``). Values are
        ``(feature, conf)`` pairs in ``term_dict`` iteration order.
    """
    # Imported locally because the notebook never imports defaultdict itself; the
    # name would otherwise only resolve if it leaked in via a `utils` star import.
    from collections import defaultdict

    if isinstance(prefixes, str):
        prefixes = (prefixes,)
    groups = defaultdict(list)
    for feat, conf in term_dict.items():
        h_term = conf['h_term']
        prefix = next((p for p in prefixes if h_term.startswith(p)), None)
        if prefix is None:
            continue
        n = len(prefix)
        key = h_term[:n + 1] if len(h_term) > n and h_term[n].isdigit() else h_term
        groups[key].append((feat, conf))
    return groups


def _build_group_model(group):
    """Instantiate the NN class named by a term group, sized to the group's feature count."""
    class_name = group[0][1]['class_name']
    base_class_name = get_base_model_class(class_name)
    # The class is resolved by name from this module's namespace.
    try:
        model_cls = globals()[base_class_name]
    except KeyError:
        raise KeyError(
            f"Model class '{base_class_name}' (derived from '{class_name}') is not defined; "
            "make sure it is imported into this notebook."
        ) from None
    return model_cls(n_features=len(group))


def get_fully_specified_tram_model(node: str, target_nodes: dict, verbose=True):
    """Build the TramModel for ``node`` from the node configuration dict.

    Parameters
    ----------
    node : str
        Key of the node in ``target_nodes``.
    target_nodes : dict
        Node configuration. ``target_nodes[node]['node_type']`` is read here; the
        transformation terms and NN model names are obtained via ``ordered_parents``
        and ``merge_transformation_dicts``.
    verbose : bool
        If True, print progress messages and the constructed model.

    Returns
    -------
    TramModel
        Source nodes get ``TramModel(SimpleIntercept(), None)``. All other nodes get
        an intercept network (``SimpleIntercept`` unless a ``ci`` term group is
        configured) plus a list with one shift network per ``cs``/``ls`` group.

    Raises
    ------
    ValueError
        If the ``ci`` terms form more than one intercept group.
    KeyError
        If a configured model class is not defined in this module's namespace.
    """
    # Source node: no parents, modelled by a simple intercept only.
    if target_nodes[node]['node_type'] == 'source':
        tram_model = TramModel(SimpleIntercept(), None)
        if verbose:
            print('>>>>>>>>>>>>  source node --> only  modelled only  by si')
            print(tram_model)
        return tram_model

    # Read terms and model names from the config and merge them per parent.
    _, terms_dict, model_names_dict = ordered_parents(node, target_nodes)
    model_dict = merge_transformation_dicts(terms_dict, model_names_dict)

    # Intercept terms contain 'ci' or 'si'; every other term is a shift term.
    def _is_intercept(h_term):
        return 'ci' in h_term or 'si' in h_term

    intercepts_dict = {k: v for k, v in model_dict.items() if _is_intercept(v['h_term'])}
    shifts_dict = {k: v for k, v in model_dict.items() if not _is_intercept(v['h_term'])}

    # --------- INTERCEPT TERM ---------
    intercept_groups = group_by_base(intercepts_dict, 'ci')
    if not intercept_groups:
        if verbose:
            print('>>>>>>>>>>>> No ci detected --> intercept defaults to si')
        nn_int = SimpleIntercept()
    elif len(intercept_groups) > 1:
        raise ValueError("Multiple intercept models detected; only one is currently supported.")
    else:
        (group,) = intercept_groups.values()
        nn_int = _build_group_model(group)

    # --------- SHIFT TERMS ---------
    shift_groups = group_by_base(shifts_dict, prefixes=('cs', 'ls'))
    nn_shifts_list = [_build_group_model(group) for group in shift_groups.values()]

    # --------- COMBINE TO NN CLASS ---------
    tram_model = TramModel(nn_int, nn_shifts_list)
    if verbose:
        print('>>> TRAM MODEL:\n', tram_model)
    return tram_model

In [None]:
# Eight continuous variables.
# NOTE(review): keys run x2..x9 while the rest of the notebook starts at x1 — confirm
# this offset is intended.
data_type = {f'x{k}': 'cont' for k in range(2, 10)}

# 8x8 adjacency matrix of transformation-term labels ("0" = no edge). Only the last
# column is populated: each of the first seven variables has one term pointing into
# the eighth variable. Presumably row = parent and column = child; confirm against plot_dag.
adj_matrix = np.array(
    [["0"] * 7 + [term] for term in ("ls", "cs12", "cs21", "cs22", "ci12", "ci11", "cs11", "0")],
    dtype=object,
)

plot_seed = 42
plot_dag(adj_matrix, data_type, seed=plot_seed)

In [None]:
# Derive the NN model name for each edge of `adj_matrix`. This is intended to name
# models where complex-intercept (ci*) or complex-shift (cs*) terms appear; the
# exact naming logic lives in `create_nn_model_names`. The result is then plotted.
nn_names_matrix = create_nn_model_names(adj_matrix, data_type)
plot_nn_names_matrix(nn_names_matrix, data_type)

In [None]:
# Build (and, with verbose=True, print) the TRAM model for every configured node.
# NOTE(review): `conf_dict` is not defined in this notebook; it presumably is the
# per-node configuration derived from `adj_matrix` / `nn_names_matrix` above (same
# schema as `target_nodes`). Define it before running this cell.
for node in conf_dict:
    tram_model = get_fully_specified_tram_model(node, conf_dict, verbose=True).to(device)
