In [19]:
import os,sys
# Run CUDA kernels synchronously so device-side errors surface at the failing call.
os.environ.update({'CUDA_LAUNCH_BLOCKING': '1'})

# Make the parent of the working directory importable (the `utils.*` imports below rely on it).
proj_root = os.path.dirname(os.path.abspath(os.getcwd()))
if proj_root not in sys.path:
    sys.path.insert(0, proj_root)

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from scipy.stats import logistic
from scipy.special import logit


import torch

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
use_gpu = torch.cuda.is_available()
device = torch.device('cuda' if use_gpu else 'cpu')
print("Train with GPU support." if use_gpu else "No GPU found, train with CPU support.")


# own utils
from utils.graph import *
from utils.tram_models import *
from utils.tram_model_helpers import *
from utils.tram_data import *
from utils.continous import *
from utils.sampling_tram_data import *

Train with GPU support.


In [2]:
import numpy as np

# Seed every RNG used in this notebook (numpy for data generation, torch for
# model initialisation) so a Restart & Run All is reproducible.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# Directory holding the ground-truth JSON files. It is resolved relative to
# proj_root (the parent of the working directory) instead of a user-specific
# absolute path.
# NOTE(review): assumes proj_root is the TramDag repository root — confirm.
TEST_DIR = os.path.join(proj_root, "testing", "model_tests")

In [3]:
def dgp_sklearn(nobs=1000, nvars=7, seed=42):
    """Sample a synthetic dataset from a single blob via sklearn's ``make_blobs``.

    Args:
        nobs: Number of rows (samples).
        nvars: Number of columns (features).
        seed: Random state passed to ``make_blobs``.

    Returns:
        DataFrame of shape (nobs, nvars) with columns x1 ... x{nvars}.
    """
    features, _labels = make_blobs(
        n_samples=nobs,
        n_features=nvars,
        centers=1,
        cluster_std=1.0,
        random_state=seed,
    )
    column_names = ['x' + str(k) for k in range(1, nvars + 1)]
    return pd.DataFrame(features, columns=column_names)

df = dgp_sklearn(nobs=1000, nvars=2, seed=seed)
# DataFrame.info() prints its summary itself and returns None, so it is not
# wrapped in print() (that used to add a stray "None" line to the output).
df.info()

# cont: continuous, ord: ordinal, other: everything other than images
data_type = {'x1': 'cont', 'x2': 'cont'}

# 80 / 10 / 10 train / validation / test split
train_df, temp_df = train_test_split(df, test_size=0.2, random_state=seed)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=seed)

# Per-column 5% and 95% quantiles of the training data
quantiles = train_df.quantile([0.05, 0.95])
min_vals = quantiles.loc[0.05]
max_vals = quantiles.loc[0.95]

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Data columns (total 2 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   x1      1000 non-null   float64
 1   x2      1000 non-null   float64
dtypes: float64(2)
memory usage: 15.8 KB
None


In [None]:

def load_and_write_test_dict_to_json(input: dict, test_name, file_path=None):
    """
    Store a test case and its expected model reprs in a ground-truth JSON file.

    For every node in ``input`` the model returned by
    ``get_fully_specified_tram_model`` is built, and its ``repr()`` is stored as
    the expected output. The file is read first so other test cases are kept.
    An existing entry named ``test_name`` is replaced. The stored entry is::

        {"input": input, "output": {node_name: repr(model), ...}}

    Args:
        input: Node configuration dictionary, one entry per node, e.g.::

            {'x1': {'data_type': 'cont',
                    'node_type': 'source',
                    'parents': [],
                    'parents_datatype': {},
                    'transformation_terms_in_h()': {},
                    'transformation_term_nn_models_in_h()': {}},
             ...}

        test_name: Key under which the test case is stored.
        file_path: Path of the JSON file. A missing or empty file is treated
            as containing no test cases yet.

    Raises:
        ValueError: If ``file_path`` is not given.
        json.JSONDecodeError: If an existing, non-empty file is not valid JSON.
            It is raised rather than overwriting the file and losing its
            contents.
    """
    import json

    if file_path is None:
        raise ValueError("file_path must be given")

    # Start from the existing ground truth so other test cases survive.
    # A file that does not exist yet (or is empty) starts a new collection.
    try:
        with open(file_path, 'r') as f:
            content = f.read()
    except FileNotFoundError:
        content = ''
    data = json.loads(content) if content.strip() else {}

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    outputs = {}
    for node_name in input:
        model = get_fully_specified_tram_model(node_name, input, verbose=False).to(device)
        outputs[node_name] = repr(model)

    data[test_name] = {"input": input, "output": outputs}

    try:
        with open(file_path, 'w') as f:
            json.dump(data, f, indent=4)
        print(f"Updated outputs written to {file_path}")
    except Exception as e:
        print(f"Error writing to JSON file: {e}")
        
        
def run_model_loader_test(test_name: str, testdict_path: str, device: torch.device = None):
    """
    General test for fully_specified_tram_model loader based on ground-truth JSON.

    Every node of the stored input configuration is rebuilt with
    ``get_fully_specified_tram_model`` and its ``repr()`` is compared with the
    stored expected output.

    Args:
        test_name: Key in the JSON file identifying the test case.
        testdict_path: Path to the JSON file containing input and expected output.
        device: Torch device to move models to; defaults to CUDA if available, else CPU.

    Raises:
        AssertionError: If a node has no expected output, or a model's repr()
            does not match it. These are explicit raises, not ``assert``
            statements, so the checks also run under ``python -O``.
        ValueError: If the test_name is not found in the JSON file.
    """
    import json

    # Load ground-truth data
    with open(testdict_path, 'r') as f:
        test_data = json.load(f)

    if test_name not in test_data:
        raise ValueError(f"Test name '{test_name}' not found in {testdict_path}")

    inputs = test_data[test_name]['input']
    expected_outputs = test_data[test_name].get('output', {})

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for node_name in inputs:
        expected_repr = expected_outputs.get(node_name)
        # Check for a missing expectation before paying for model construction.
        if expected_repr is None:
            raise AssertionError(
                f"No expected output found for node '{node_name}' "
                f"in '{testdict_path}'."
            )

        model = get_fully_specified_tram_model(node_name, inputs, verbose=False).to(device)
        actual_repr = repr(model)
        if actual_repr != expected_repr:
            raise AssertionError(
                f"Mismatch for node '{node_name}':\n"
                f"  Expected: {expected_repr}\n"
                f"  Actual:   {actual_repr}"
            )

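# Testing 2 variables
Tests for the SI, LS, CS and CI cases, with three test cases in total: x2 depends on x1 via LS (test_1), CS (test_2) and CI (test_3).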

In [None]:
# LS: sink x2 depends on its single parent x1 through a linear shift.
input = {}
input['x1'] = {
    'data_type': 'cont',
    'node_type': 'source',
    'parents': [],
    'parents_datatype': {},
    'transformation_terms_in_h()': {},
    'transformation_term_nn_models_in_h()': {},
}
input['x2'] = {
    'data_type': 'cont',
    'node_type': 'sink',
    'parents': ['x1'],
    'parents_datatype': {'x1': 'cont'},
    'transformation_terms_in_h()': {'x1': 'ls'},
    'transformation_term_nn_models_in_h()': {'x1': 'LinearShift'},
}

load_and_write_test_dict_to_json(
    input,
    "test_1",
    file_path=os.path.join(TEST_DIR, 'twovars_model_loader_test_dict.json'),
)

Updated outputs written to /home/bule/TramDag/testing/model_tests/twovars_model_loader_test_dict.json


In [6]:
# CS: sink x2 depends on its single parent x1 through a complex shift.
# The term key must be 'cs' to match the ComplexShiftDefaultTabular model,
# as in the three-variable CS cases.
input = {}
input['x1'] = {
    'data_type': 'cont',
    'node_type': 'source',
    'parents': [],
    'parents_datatype': {},
    'transformation_terms_in_h()': {},
    'transformation_term_nn_models_in_h()': {},
}
input['x2'] = {
    'data_type': 'cont',
    'node_type': 'sink',
    'parents': ['x1'],
    'parents_datatype': {'x1': 'cont'},
    'transformation_terms_in_h()': {'x1': 'cs'},
    'transformation_term_nn_models_in_h()': {'x1': 'ComplexShiftDefaultTabular'},
}

load_and_write_test_dict_to_json(
    input,
    "test_2",
    file_path=os.path.join(TEST_DIR, 'twovars_model_loader_test_dict.json'),
)

Updated outputs written to /home/bule/TramDag/testing/model_tests/twovars_model_loader_test_dict.json


In [7]:
# CI: sink x2 depends on its single parent x1 through a complex intercept.
input = {}
input['x1'] = {
    'data_type': 'cont',
    'node_type': 'source',
    'parents': [],
    'parents_datatype': {},
    'transformation_terms_in_h()': {},
    'transformation_term_nn_models_in_h()': {},
}
input['x2'] = {
    'data_type': 'cont',
    'node_type': 'sink',
    'parents': ['x1'],
    'parents_datatype': {'x1': 'cont'},
    'transformation_terms_in_h()': {'x1': 'ci'},
    'transformation_term_nn_models_in_h()': {'x1': 'ComplexInterceptDefaultTabular'},
}

load_and_write_test_dict_to_json(
    input,
    "test_3",
    file_path=os.path.join(TEST_DIR, 'twovars_model_loader_test_dict.json'),
)

Updated outputs written to /home/bule/TramDag/testing/model_tests/twovars_model_loader_test_dict.json


In [8]:
# Rebuild every two-variable model (LS, CS, CI cases) and compare its repr with the stored ground truth.
twovars_gt_path = os.path.join(TEST_DIR, 'twovars_model_loader_test_dict.json')
for case_name in ('test_1', 'test_2', 'test_3'):
    run_model_loader_test(case_name, twovars_gt_path)


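# Testing 3 variables
x3 is a sink with the two source parents x1 and x2; each test case (test_1 to test_6) combines a different pair of transformation terms.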

In [10]:
# TODO: there is a bug when two linear shifts are used in the same model, e.g. for a model like this:

#  X3 ~ LS(X1) + LS(X2)
input = {
    node: {
        'data_type': 'cont',
        'node_type': 'source',
        'parents': [],
        'parents_datatype': {},
        'transformation_terms_in_h()': {},
        'transformation_term_nn_models_in_h()': {},
    }
    for node in ('x1', 'x2')
}
input['x3'] = {
    'data_type': 'cont',
    'node_type': 'sink',
    'parents': ['x1', 'x2'],
    'parents_datatype': {'x1': 'cont', 'x2': 'cont'},
    'transformation_terms_in_h()': {'x1': 'ls', 'x2': 'ls'},
    'transformation_term_nn_models_in_h()': {'x1': 'LinearShift', 'x2': 'LinearShift'},
}

load_and_write_test_dict_to_json(
    input,
    "test_1",
    file_path=os.path.join(TEST_DIR, 'threevars_model_loader_test_dict.json'),
)

Updated outputs written to /home/bule/TramDag/testing/model_tests/threevars_model_loader_test_dict.json


In [11]:
#  X3 ~ CS(X1) + LS(X2)
input = {
    node: {
        'data_type': 'cont',
        'node_type': 'source',
        'parents': [],
        'parents_datatype': {},
        'transformation_terms_in_h()': {},
        'transformation_term_nn_models_in_h()': {},
    }
    for node in ('x1', 'x2')
}
input['x3'] = {
    'data_type': 'cont',
    'node_type': 'sink',
    'parents': ['x1', 'x2'],
    'parents_datatype': {'x1': 'cont', 'x2': 'cont'},
    'transformation_terms_in_h()': {'x1': 'cs', 'x2': 'ls'},
    'transformation_term_nn_models_in_h()': {'x1': 'ComplexShiftDefaultTabular', 'x2': 'LinearShift'},
}

load_and_write_test_dict_to_json(
    input,
    "test_2",
    file_path=os.path.join(TEST_DIR, 'threevars_model_loader_test_dict.json'),
)

Updated outputs written to /home/bule/TramDag/testing/model_tests/threevars_model_loader_test_dict.json


In [12]:
#  X3 ~ CS(X1) + CS(X2)
input = {
    node: {
        'data_type': 'cont',
        'node_type': 'source',
        'parents': [],
        'parents_datatype': {},
        'transformation_terms_in_h()': {},
        'transformation_term_nn_models_in_h()': {},
    }
    for node in ('x1', 'x2')
}
input['x3'] = {
    'data_type': 'cont',
    'node_type': 'sink',
    'parents': ['x1', 'x2'],
    'parents_datatype': {'x1': 'cont', 'x2': 'cont'},
    'transformation_terms_in_h()': {'x1': 'cs', 'x2': 'cs'},
    'transformation_term_nn_models_in_h()': {
        'x1': 'ComplexShiftDefaultTabular',
        'x2': 'ComplexShiftDefaultTabular',
    },
}

load_and_write_test_dict_to_json(
    input,
    "test_3",
    file_path=os.path.join(TEST_DIR, 'threevars_model_loader_test_dict.json'),
)

Updated outputs written to /home/bule/TramDag/testing/model_tests/threevars_model_loader_test_dict.json


In [13]:
#  X3 ~ CS(X1) + CI(X2)
input = {
    node: {
        'data_type': 'cont',
        'node_type': 'source',
        'parents': [],
        'parents_datatype': {},
        'transformation_terms_in_h()': {},
        'transformation_term_nn_models_in_h()': {},
    }
    for node in ('x1', 'x2')
}
input['x3'] = {
    'data_type': 'cont',
    'node_type': 'sink',
    'parents': ['x1', 'x2'],
    'parents_datatype': {'x1': 'cont', 'x2': 'cont'},
    'transformation_terms_in_h()': {'x1': 'cs', 'x2': 'ci'},
    'transformation_term_nn_models_in_h()': {
        'x1': 'ComplexShiftDefaultTabular',
        'x2': 'ComplexInterceptDefaultTabular',
    },
}

load_and_write_test_dict_to_json(
    input,
    "test_4",
    file_path=os.path.join(TEST_DIR, 'threevars_model_loader_test_dict.json'),
)

Updated outputs written to /home/bule/TramDag/testing/model_tests/threevars_model_loader_test_dict.json


In [14]:
#  X3 ~ CS11(X1) + CS12(X2)
input = {
    node: {
        'data_type': 'cont',
        'node_type': 'source',
        'parents': [],
        'parents_datatype': {},
        'transformation_terms_in_h()': {},
        'transformation_term_nn_models_in_h()': {},
    }
    for node in ('x1', 'x2')
}
input['x3'] = {
    'data_type': 'cont',
    'node_type': 'sink',
    'parents': ['x1', 'x2'],
    'parents_datatype': {'x1': 'cont', 'x2': 'cont'},
    'transformation_terms_in_h()': {'x1': 'cs11', 'x2': 'cs12'},
    'transformation_term_nn_models_in_h()': {
        'x1': 'ComplexShiftDefaultTabular11',
        'x2': 'ComplexShiftDefaultTabular12',
    },
}

load_and_write_test_dict_to_json(
    input,
    "test_5",
    file_path=os.path.join(TEST_DIR, 'threevars_model_loader_test_dict.json'),
)

Updated outputs written to /home/bule/TramDag/testing/model_tests/threevars_model_loader_test_dict.json


In [15]:
#  X3 ~ CI11(X1) + CI12(X2)
input = {
    node: {
        'data_type': 'cont',
        'node_type': 'source',
        'parents': [],
        'parents_datatype': {},
        'transformation_terms_in_h()': {},
        'transformation_term_nn_models_in_h()': {},
    }
    for node in ('x1', 'x2')
}
input['x3'] = {
    'data_type': 'cont',
    'node_type': 'sink',
    'parents': ['x1', 'x2'],
    'parents_datatype': {'x1': 'cont', 'x2': 'cont'},
    'transformation_terms_in_h()': {'x1': 'ci11', 'x2': 'ci12'},
    'transformation_term_nn_models_in_h()': {
        'x1': 'ComplexInterceptDefaultTabular11',
        'x2': 'ComplexInterceptDefaultTabular12',
    },
}

load_and_write_test_dict_to_json(
    input,
    "test_6",
    file_path=os.path.join(TEST_DIR, 'threevars_model_loader_test_dict.json'),
)

Updated outputs written to /home/bule/TramDag/testing/model_tests/threevars_model_loader_test_dict.json


In [16]:
# Rebuild every model of the six three-variable test cases (LS, CS, CI and the
# grouped CS11/CS12 and CI11/CI12 terms) and compare its repr with the stored ground truth.
threevars_gt_path = os.path.join(TEST_DIR, 'threevars_model_loader_test_dict.json')
for case_name in ('test_1', 'test_2', 'test_3', 'test_4', 'test_5', 'test_6'):
    run_model_loader_test(case_name, threevars_gt_path)


# Testing 10 variables

Test case for multiple term groups in a single model, e.g. ci11/ci12, cs11/cs12, cs21/cs22 and cs31/cs32, combined with plain ls and cs terms.

In [20]:
#  X10 ~ CI12(X1) + CS12(X2) + LS(X3) + CS11(X4) + CS(X5) + CS22(X6) + CS21(X7) + CS32(X8) + CI11(X9)
# x1 ... x9 are independent continuous source nodes; x10 is the sink.
input = {
    f'x{i}': {
        'data_type': 'cont',
        'node_type': 'source',
        'parents': [],
        'parents_datatype': {},
        'transformation_terms_in_h()': {},
        'transformation_term_nn_models_in_h()': {},
    }
    for i in range(1, 10)
}

# Transformation term of each parent of x10, matching the formula above.
x10_terms = {
    'x1': 'ci12',
    'x2': 'cs12',
    'x3': 'ls',
    'x4': 'cs11',
    'x5': 'cs',
    'x6': 'cs22',
    'x7': 'cs21',
    'x8': 'cs32',
    'x9': 'ci11',
}

# NN model used for each parent's term.
# NOTE(review): x8 has the grouped term 'cs32' but the ungrouped model name
# 'ComplexShiftDefaultTabular', while every other grouped term uses the matching
# suffix (e.g. cs11 -> ComplexShiftDefaultTabular11) — confirm this is intended.
x10_models = {
    'x1': 'ComplexInterceptDefaultTabular12',
    'x2': 'ComplexShiftDefaultTabular12',
    'x3': 'LinearShift',
    'x4': 'ComplexShiftDefaultTabular11',
    'x5': 'ComplexShiftDefaultTabular',
    'x6': 'ComplexShiftDefaultTabular22',
    'x7': 'ComplexShiftDefaultTabular21',
    'x8': 'ComplexShiftDefaultTabular',
    'x9': 'ComplexInterceptDefaultTabular11',
}

# 'parents' and 'parents_datatype' are derived from the term mapping so that
# every parent with a term (including x6) is listed exactly once.
input['x10'] = {
    'data_type': 'cont',
    'node_type': 'sink',
    'parents': list(x10_terms),
    'parents_datatype': {parent: 'cont' for parent in x10_terms},
    'transformation_terms_in_h()': x10_terms,
    'transformation_term_nn_models_in_h()': x10_models,
}

load_and_write_test_dict_to_json(
    input,
    "test_1",
    file_path=os.path.join(TEST_DIR, 'tenvars_model_loader_test_dict.json'),
)


Updated outputs written to /home/bule/TramDag/testing/model_tests/tenvars_model_loader_test_dict.json
