In this file, a potential API is developed for TramDag.

## TRAM Config

In [1]:
"config.json" #has to be created either manually or by running createw_config.ipynb first!

'config.json'

In [2]:
# cfg = TramDagConfig.load(CONF_DICT_PATH="config.json")  # basically just loads a JSON file into a dictionary

# cfg.compute_scaling(df_train, write=True) # computes min max levels from training data and writes to cfg
# # checks all specifications and throws warnings
# # returns a dict

In [3]:
from utils.configuration import *


class TramDagConfig:
    """
    Wrapper around a TRAM-DAG configuration dictionary and its file path.

    Attributes:
        conf_dict: the configuration as a plain dict.
        CONF_DICT_PATH: path of the configuration file, or None if unknown.
    """

    def __init__(self, conf_dict: dict = None, CONF_DICT_PATH: str = None):
        """
        Initialize TramDagConfig.

        Args:
            conf_dict: optional dict with configuration. If None, starts empty.
            CONF_DICT_PATH: optional path to config file.
        """
        # `is None` instead of `or`, so a caller-supplied (empty) dict is kept
        # as the very same object rather than replaced by a fresh one.
        self.conf_dict = {} if conf_dict is None else conf_dict
        self.CONF_DICT_PATH = CONF_DICT_PATH
        # TODO write each configuration as an attribute? Or keep as dict?

    @classmethod
    def load(cls, CONF_DICT_PATH: str):
        """
        Alternative constructor: load config directly from a file.

        Args:
            CONF_DICT_PATH: path handed to load_configuration_dict; it is also
                stored on the instance so later save() calls can reuse it.

        Returns:
            A new instance holding the loaded dict and the path.
        """
        conf = load_configuration_dict(CONF_DICT_PATH)
        return cls(conf, CONF_DICT_PATH=CONF_DICT_PATH)

    def save(self, CONF_DICT_PATH: str = None):
        """
        Save config to file. If path is not provided, fall back to stored path.

        Raises:
            ValueError: if neither an explicit nor a stored path is available.
        """
        path = CONF_DICT_PATH or self.CONF_DICT_PATH
        if path is None:
            raise ValueError("No CONF_DICT_PATH provided to save config.")
        write_configuration_dict(self.conf_dict, path)

    def compute_scaling(self, df: pd.DataFrame, write: bool = True):
        """
        Derive scaling information from data and store it under conf_dict['nodes'].

        The "min"/"max" values are the 5% and 95% quantiles of each column of
        `df`; levels are obtained from create_levels_dict. USE training data.

        Args:
            df: data to derive the scaling from (training data only).
            write: if True, also write the updated configuration back to
                CONF_DICT_PATH. A failing write is reported, not raised.

        Raises:
            ValueError: if no CONF_DICT_PATH is stored. The adjacency matrix
                and the nn-names matrix are read from that file, so the
                method cannot work without it.
        """
        # Fail early with a clear message instead of somewhere inside the
        # file-based helpers below.
        if self.CONF_DICT_PATH is None:
            raise ValueError(
                "compute_scaling requires CONF_DICT_PATH to be set "
                "(use TramDagConfig.load(path) or pass CONF_DICT_PATH)."
            )

        print("[INFO] Make sure to provide only training data to compute_scaling!")
        # calculate 5% and 95% quantiles for min and max values
        quantiles = df.quantile([0.05, 0.95])
        min_vals = quantiles.loc[0.05]
        max_vals = quantiles.loc[0.95]

        # calculate levels for categorical variables
        levels_dict = create_levels_dict(df, self.conf_dict['data_type'])

        # TODO remove outer dependency of these functions (re-loading conf dict)
        adj_matrix = read_adj_matrix_from_configuration(self.CONF_DICT_PATH)
        nn_names_matrix = read_nn_names_matrix_from_configuration(self.CONF_DICT_PATH)

        node_dict = create_node_dict(
            adj_matrix,
            nn_names_matrix,
            self.conf_dict['data_type'],
            min_vals=min_vals,
            max_vals=max_vals,
            levels_dict=levels_dict
        )
        # NOTE: the configuration is re-read from disk here, so the in-memory
        # dict is replaced by the file content plus the new 'nodes' entry.
        conf_dict = load_configuration_dict(self.CONF_DICT_PATH)
        conf_dict['nodes'] = node_dict
        self.conf_dict = conf_dict  # keep it in memory too

        if write:
            # Best effort: the in-memory config is already updated, so a
            # failed write is only reported.
            try:
                write_configuration_dict(conf_dict, self.CONF_DICT_PATH)
                print(f'[INFO] Configuration with updated scaling saved to {self.CONF_DICT_PATH}')
            except Exception as e:
                print(f'[ERROR] Failed to save configuration: {e}')

### test it --> works!

In [12]:
# Load the training and validation splits of experiment 6_2 (absolute, machine-specific paths).
train_df=pd.read_csv('/home/bule/TramDag/dev_experiment_logs/exp_6_2/exp_6_2_train.csv')
val_df=pd.read_csv('/home/bule/TramDag/dev_experiment_logs/exp_6_2/exp_6_2_val.csv')

# Load the experiment's configuration file into a TramDagConfig.
cfg = TramDagConfig.load("/home/bule/TramDag/dev_experiment_logs/exp_6_2/configuration.json")


# Uses the 5%/95% quantiles of the training data as min/max, derives the levels,
# and (write defaults to True) writes the updated configuration back to the file.
cfg.compute_scaling(train_df)

[INFO] Make sure to provide only training data to compute_scaling!
[INFO] Configuration with updated scaling saved to /home/bule/TramDag/dev_experiment_logs/exp_6_2/configuration.json


In [5]:
# TODO: add a verifier so that nothing needed for later training is missing, e.g. the experiment name

# TramDagDataset

In [14]:
import inspect
from utils.tram_data import GenericDataset
from torch.utils.data import Dataset, DataLoader

class TramDagDataset(Dataset):
    """
    Holds a DataFrame plus one DataLoader per node of the configuration.

    Instances are meant to be created through `from_dataframe`. Every entry of
    DEFAULTS can be overridden by keyword, either with a single value used for
    all nodes or with a dict {node: value} that covers every node.

    Attributes (set by `from_dataframe`):
        cfg: the configuration object (its `conf_dict["nodes"]` is used).
        df: a copy of the DataFrame that was passed in.
        nodes_dict: `cfg.conf_dict["nodes"]`.
        loaders: dict {node: DataLoader}.
        plus one attribute per key of DEFAULTS.
    """

    DEFAULTS = {
        "batch_size": 32,
        "shuffle": True,
        "num_workers": 4,
        "pin_memory": False,
        "return_intercept_shift": True,
        "debug": False,
        "transform": None,
    }

    def __init__(self):
        """Empty init. Use classmethods like .from_dataframe()."""
        super().__init__()

    @classmethod
    def from_dataframe(cls, df, cfg, **kwargs):
        """
        Build a dataset (and its per-node DataLoaders) from a DataFrame.

        Args:
            df: pandas DataFrame with the data; it is copied.
            cfg: configuration object exposing `conf_dict["nodes"]`.
            **kwargs: overrides for the entries of DEFAULTS (scalar or
                {node: value} dict). Unknown names are reported and ignored.

        Returns:
            The fully initialised TramDagDataset.

        Raises:
            TypeError: if `df` is not a pandas DataFrame.
            ValueError: if a dict-valued setting does not cover exactly the
                configured nodes.
            KeyError: if `cfg.conf_dict` has no 'nodes' entry.
        """
        if not isinstance(df, pd.DataFrame):
            raise TypeError(f"[ERROR] df must be a pandas DataFrame, but got {type(df)}")

        # A misspelled override would otherwise silently fall back to its default.
        unknown = sorted(set(kwargs) - set(cls.DEFAULTS))
        if unknown:
            print(
                f"[WARNING] Ignoring unknown settings {unknown}. "
                f"Valid settings are: {sorted(cls.DEFAULTS)}"
            )

        # merge defaults with overrides
        settings = {**cls.DEFAULTS, **kwargs}

        # Infer the caller's variable name for df (only used for the warning
        # below). currentframe() may return None on some interpreters.
        df_name = "dataframe"
        frame = inspect.currentframe()
        caller = frame.f_back if frame is not None else None
        if caller is not None:
            for var_name, var_val in caller.f_locals.items():
                if var_val is df:
                    df_name = var_name
                    break

        # With a per-node dict, warn only if at least one node actually shuffles.
        shuffle = settings["shuffle"]
        shuffle_requested = any(shuffle.values()) if isinstance(shuffle, dict) else bool(shuffle)
        if shuffle_requested:
            if any(x in df_name.lower() for x in ["val", "validation", "test"]):
                print(f"[WARNING] DataFrame '{df_name}' looks like a validation/test set → shuffle=True. Are you sure?")

        self = cls()
        self.cfg = cfg
        self.df = df.copy()
        self._apply_settings(settings)
        self._build_dataloaders()
        return self

    def _apply_settings(self, settings: dict):
        """
        Store every DEFAULTS entry from `settings` as an attribute.

        Dict-valued settings are validated against the configured nodes.

        Raises:
            KeyError: if `cfg.conf_dict` has no 'nodes' entry.
            ValueError: if a dict-valued setting has the wrong keys.
        """
        # nodes dict (needed first: the key validation compares against it)
        try:
            self.nodes_dict = self.cfg.conf_dict["nodes"]
        except KeyError:
            raise KeyError(
                "[ERROR] cfg.conf_dict has no 'nodes' entry; create it first, "
                "e.g. with TramDagConfig.compute_scaling(train_df)."
            ) from None

        for name in self.DEFAULTS:
            value = settings[name]
            self._check_keys(name, value)
            setattr(self, name, value)

    def _resolve(self, name, node):
        """Return the value of setting `name` for `node` (scalar or per-node dict)."""
        value = getattr(self, name)
        return value[node] if isinstance(value, dict) else value

    def _build_dataloaders(self):
        """Build node-specific dataloaders from df + settings."""
        self.loaders = {}
        for node in self.nodes_dict:
            ds = GenericDataset(
                self.df,
                target_col=node,
                target_nodes=self.nodes_dict,
                transform=self._resolve("transform", node),
                return_intercept_shift=self._resolve("return_intercept_shift", node),
                debug=self._resolve("debug", node),
            )

            self.loaders[node] = DataLoader(
                ds,
                batch_size=self._resolve("batch_size", node),
                shuffle=bool(self._resolve("shuffle", node)),
                num_workers=self._resolve("num_workers", node),
                pin_memory=self._resolve("pin_memory", node),
            )

    def _check_keys(self, attr_name, attr_value):
        """
        Check if dict keys match cfg.conf_dict['nodes'].keys().

        Non-dict values are accepted as-is.

        Raises:
            ValueError: if `attr_value` is a dict whose keys differ from the nodes.
        """
        if isinstance(attr_value, dict):
            expected_keys = set(self.nodes_dict.keys())
            given_keys = set(attr_value.keys())
            if expected_keys != given_keys:
                raise ValueError(
                    f"[ERROR] the provided attribute '{attr_name}' keys are not same as in cfg.conf_dict['nodes'].keys().\n"
                    f"Expected: {expected_keys}, but got: {given_keys}\n"
                    f"Please provide values for all variables."
                )

    def summary(self):
        """Print the DataFrame overview and the resolved settings of every node."""
        print("\n[TramDagDataset Summary]")
        print("=" * 60)

        # ---- DataFrame section ----
        print("\n[DataFrame]")
        print("Shape:", self.df.shape)
        print("\nHead:")
        print(self.df.head())

        print("\nDtypes:")
        print(self.df.dtypes)

        print("\nDescribe:")
        print(self.df.describe(include="all"))

        # ---- Settings per node ----
        print("\n[Node Settings]")
        for node in self.nodes_dict:
            print(
                f" Node '{node}': "
                f"batch_size={self._resolve('batch_size', node)}, "
                f"shuffle={bool(self._resolve('shuffle', node))}, "
                f"num_workers={self._resolve('num_workers', node)}, "
                f"pin_memory={self._resolve('pin_memory', node)}, "
                f"return_intercept_shift={self._resolve('return_intercept_shift', node)}, "
                f"debug={self._resolve('debug', node)}, "
                f"transform={self._resolve('transform', node)}"
            )
        print("=" * 60 + "\n")

    def __getitem__(self, idx):
        """Return row `idx` of the DataFrame (positional) as a dict."""
        return self.df.iloc[idx].to_dict()

    def __len__(self):
        """Return the number of rows of the DataFrame."""
        return len(self.df)


## testit -> works

In [20]:
# Build the training dataset with the DEFAULTS of TramDagDataset (shuffle=True).
td_train_data=TramDagDataset.from_dataframe(train_df,cfg)  


In [21]:
# Build the validation dataset; shuffle is switched off, all other settings keep their defaults.
td_val_data=TramDagDataset.from_dataframe(val_df,cfg,shuffle=False)  

# TramDagModel

In [11]:
td_model = TramDagModel.from_config(cfg, set_initial_weights=True) 
# builds one model per node according to the config
# returns a TramDagModel; its `.models` attribute is a dict {'x1': model1, 'x2': model2, ...}

# prints, for every node, the model class name and the settings it was built with
td_model.summary()


[INFO] Building model for node 'x1' with settings: {'set_initial_weights': True, 'debug': False}

[INFO] Building model for node 'x2' with settings: {'set_initial_weights': True, 'debug': False}

[INFO] Building model for node 'x3' with settings: {'set_initial_weights': True, 'debug': False}

[TramDagModel Summary]
 Node 'x1': TramModel
   - set_initial_weights: True
   - debug: False
 Node 'x2': TramModel
   - set_initial_weights: True
   - debug: False
 Node 'x3': TramModel
   - set_initial_weights: True
   - debug: False



In [10]:
from utils.tram_model_helpers import train_val_loop, get_fully_specified_tram_model
import torch
from torch.optim import Adam
import os


class TramDagModel:
    """
    Collection of one TRAM model per node of a configuration.

    Instances are meant to be created through `from_config`.

    Attributes (set by `from_config`):
        cfg: the configuration object.
        nodes_dict: `cfg.conf_dict["nodes"]`.
        models: dict {node: model}.
        settings: dict {setting name: {node: resolved value}}.
    """

    # ---- defaults used at construction time ----
    DEFAULTS_CONFIG = {
        "set_initial_weights": True,
        "debug": False,
    }

    # ---- defaults used at fit() time ----
    DEFAULTS_FIT = {
        "epochs": 100,
        "train_list": None,
        "callbacks": None,
        "learning_rate": 0.01,
        "device": "auto",
        "optimizers": None,
        "schedulers": None,
        "use_scheduler": False,
        "save_linear_shifts": True,
        "debug": False,
        "verbose": 1,
    }

    def __init__(self):
        """Empty init. Use classmethods like .from_config()."""
        pass

    @classmethod
    def from_config(cls, cfg, **kwargs):
        """
        Build one TramModel per node based on configuration and kwargs.
        Kwargs can be scalars (applied to all nodes) or dicts {node: value}.

        Args:
            cfg: configuration object exposing `conf_dict["nodes"]`.
            **kwargs: overrides / additions to DEFAULTS_CONFIG; the resolved
                per-node values are forwarded as keyword arguments to
                get_fully_specified_tram_model.

        Returns:
            The TramDagModel holding the per-node models in `.models`.

        Raises:
            ValueError: if a dict-valued argument does not cover exactly the
                configured nodes.
        """
        self = cls()
        self.cfg = cfg
        self.nodes_dict = self.cfg.conf_dict["nodes"]

        # merge defaults with user overrides
        settings = {**cls.DEFAULTS_CONFIG, **kwargs}

        # initialize settings storage
        self.settings = {k: {} for k in settings}

        # validate dict-typed args
        expected = set(self.nodes_dict.keys())
        for k, v in settings.items():
            if isinstance(v, dict):
                given = set(v.keys())
                if expected != given:
                    raise ValueError(
                        f"[ERROR] the provided argument '{k}' keys are not same as in cfg.conf_dict['nodes'].keys().\n"
                        f"Expected: {expected}, but got: {given}\n"
                        f"Please provide values for all variables."
                    )

        # build one model per node
        self.models = {}
        for node in self.nodes_dict:
            per_node_kwargs = {k: cls._per_node(v, node) for k, v in settings.items()}
            for k, resolved in per_node_kwargs.items():
                self.settings[k][node] = resolved
            print(f"\n[INFO] Building model for node '{node}' with settings: {per_node_kwargs}")
            self.models[node] = get_fully_specified_tram_model(
                node=node,
                configuration_dict=self.cfg.conf_dict,
                **per_node_kwargs
            )
        return self

    @staticmethod
    def _per_node(value, node):
        """Return `value[node]` for a {node: value} dict, otherwise `value` itself."""
        return value[node] if isinstance(value, dict) else value

    @staticmethod
    def _resolve_device(device):
        """Turn the `device` setting into a torch.device; "auto" means CUDA if available, else CPU."""
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.device(device)

    def fit(self, td_train_data, td_val_data, **kwargs):
        """
        Fit TRAM models for specified nodes.

        Args:
            td_train_data: object whose `.loaders[node]` is the training loader.
            td_val_data: object whose `.loaders[node]` is the validation loader.
            **kwargs: overrides for DEFAULTS_FIT.
                - epochs, learning_rate, save_linear_shifts, debug, verbose:
                  scalar (applied to all nodes) or dict {node: value}.
                - train_list: nodes to train; None/empty means all models.
                - device: "auto" (CUDA if available, otherwise CPU) or
                  anything torch.device accepts.
                - optimizers / schedulers: optional dicts {node: object};
                  nodes without an optimizer get Adam with the node's
                  learning rate, nodes without a scheduler get none.
                NOTE: `callbacks` and `use_scheduler` are accepted but not
                consulted here; whether a scheduler is used follows from
                whether one was supplied for the node.

        Returns:
            dict {node: return value of train_val_loop for that node}.
        """
        # merge defaults with overrides
        settings = {**self.DEFAULTS_FIT, **kwargs}

        # "auto" must fall back to CPU: torch.device("auto") is not a valid device.
        device = self._resolve_device(settings["device"])
        train_list = settings["train_list"] or list(self.models.keys())
        optimizers = settings["optimizers"] or {}
        schedulers = settings["schedulers"] or {}

        results = {}
        for node in train_list:
            model = self.models[node]

            # resolve per-node values of the scalar-or-dict settings
            node_epochs = self._per_node(settings["epochs"], node)
            node_verbose = self._per_node(settings["verbose"], node)

            # resolve optimizer
            if node in optimizers:
                optimizer = optimizers[node]
            else:
                optimizer = Adam(
                    model.parameters(),
                    lr=self._per_node(settings["learning_rate"], node),
                )

            # resolve scheduler (None when not supplied for this node)
            scheduler = schedulers[node] if node in schedulers else None

            # grab loaders
            train_loader = td_train_data.loaders[node]
            val_loader = td_val_data.loaders[node]

            # per-node output directory, relative to the current working directory
            NODE_DIR = os.path.join("models", node)
            os.makedirs(NODE_DIR, exist_ok=True)

            if node_verbose:
                print(f"\n[INFO] Training node '{node}' for {node_epochs} epochs on {device}")

            history = train_val_loop(
                node=node,
                target_nodes=self.nodes_dict,
                NODE_DIR=NODE_DIR,
                tram_model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                epochs=node_epochs,
                optimizer=optimizer,
                use_scheduler=(scheduler is not None),
                scheduler=scheduler,
                save_linear_shifts=self._per_node(settings["save_linear_shifts"], node),
                verbose=node_verbose,
                device=device,
                # forward the user's debug setting instead of a hard-coded False
                debug=self._per_node(settings["debug"], node),
            )
            results[node] = history

        return results

    def summary(self):
        """Print, for every node, the model class name and its construction settings."""
        print("\n[TramDagModel Summary]")
        print("=" * 60)
        for node, model in self.models.items():
            print(f" Node '{node}': {model.__class__.__name__}")
            for k, v in self.settings.items():
                if node in v:
                    print(f"   - {k}: {v[node]}")
        print("=" * 60 + "\n")


In [23]:
td_model.fit( td_train_data, td_val_data, epochs=1,debug=True)

# the fit function trains all models in the train_list independently for the specified epochs
# and returns a dict keyed by node holding what train_val_loop returned for that node

# planned: a td_fit object that contains the history and the best models for each node as well as the cfg


[INFO] Training node 'x1' for 1 epochs on cuda
No existing model found. Starting fresh...
Saved new best model.
Epoch 1/1  Train NLL: -0.4294  Val NLL: -0.5253  [Train: 26.06s  Val: 1.85s  Total: 27.92s]

[INFO] Training node 'x2' for 1 epochs on cuda
No existing model found. Starting fresh...
Saved new best model.
Epoch 1/1  Train NLL: 0.9413  Val NLL: 0.5135  [Train: 30.16s  Val: 1.81s  Total: 31.98s]

[INFO] Training node 'x3' for 1 epochs on cuda
No existing model found. Starting fresh...
Saved new best model.
Epoch 1/1  Train NLL: 1.2151  Val NLL: 1.2097  [Train: 18.55s  Val: 1.19s  Total: 19.74s]


{'x1': None, 'x2': None, 'x3': None}

In [None]:
# Planned API sketch (no `history` method is defined in the TramDagModel cell above).
td_model.history() # intended to wrap show_training_history(node_list,EXPERIMENT_DIR)

In [None]:
# Planned API sketch (no such method is defined in the TramDagModel cell above).
td_model.show_hdag_for_source_nodes()#show_hdag_for_source_nodes(configuration_dict,EXPERIMENT_DIR,device=device,xmin_plot=0,xmax_plot=1) # TODO: function for the other (non-source) nodes

In [None]:
# Planned API sketch (no such method is defined in the TramDagModel cell above);
# the commented line below is the existing function call it is meant to wrap.
td_model.inspect_trafo_standart_logistic()
#inspect_trafo_standart_logistic(configuration_dict,EXPERIMENT_DIR,train_df,val_df,device,verbose=False)

In [None]:
# Planned API sketch (no `get_latent` method is defined in the TramDagModel cell above).
td_model.get_latent() # should return the latents or save them as an attribute of td_model; intended to call the function below
#all_latents_df = create_latent_df_for_full_dag(configuration_dict, EXPERIMENT_DIR, train_df, verbose=True)


In [None]:
# Planned API sketch (no `sample` method is defined in the TramDagModel cell above).
td_model.sample()

# Existing function call that the planned sample() is meant to wrap.
# NOTE(review): configuration_dict, EXPERIMENT_DIR and device are not defined in the cells shown here.
sampled_by_node, latents_by_node=sample_full_dag(configuration_dict,
                EXPERIMENT_DIR,
                device,
                do_interventions={},
                predefined_latent_samples_df=None,#all_latents_df,
                number_of_samples= 10_000,
                batch_size = 32,
                delete_all_previously_sampled=True,
                verbose=True,
                debug=False)