This file develops a potential API for TramDag.

## TRAM Config

In [None]:
"config.json" #has to be created either manually or by running createw_config.ipynb first!

In [None]:
# cfg = TramDagConfig.load(CONF_DICT_PATH="config.json")  # essentially loads a JSON file into a dictionary

# cfg.compute_scaling(df_train, write=True)  # computes min/max values and levels from the training data and writes them to cfg
# # (planned) checks all specifications and emits warnings
# # (planned) returns a dict

In [None]:
from utils.configuration import *

class TramDagConfig:
    """
    Thin wrapper around a TRAM-DAG configuration dictionary.

    The configuration lives in memory as ``conf_dict``. ``CONF_DICT_PATH`` (if set)
    is the file it was loaded from and the default target of ``save``.
    """

    def __init__(self, conf_dict: dict = None, CONF_DICT_PATH: str = None):
        """
        Initialize TramDagConfig.

        Args:
            conf_dict: optional dict with configuration. If None, starts empty.
                The dict is used as-is (not copied); an empty dict passed by the
                caller is kept rather than replaced by a new one.
            CONF_DICT_PATH: optional path to the config file.
        """
        self.conf_dict = conf_dict if conf_dict is not None else {}
        self.CONF_DICT_PATH = CONF_DICT_PATH
        # TODO write each configuration as an attribute? Or keep as dict?

    @classmethod
    def load(cls, CONF_DICT_PATH: str):
        """
        Alternative constructor: load config directly from a file.
        """
        conf = load_configuration_dict(CONF_DICT_PATH)
        return cls(conf, CONF_DICT_PATH=CONF_DICT_PATH)

    def save(self, CONF_DICT_PATH: str = None):
        """
        Save config to file. If path is not provided, fall back to stored path.

        Raises:
            ValueError: if neither an explicit path nor a stored path is available.
        """
        path = CONF_DICT_PATH or self.CONF_DICT_PATH
        if path is None:
            raise ValueError("No CONF_DICT_PATH provided to save config.")
        write_configuration_dict(self.conf_dict, path)

    def compute_scaling(self, df: pd.DataFrame, write: bool = True):
        """
        Derive scaling information (min, max, levels) from data. Use TRAINING data only.

        The 5% / 95% quantiles of the numeric columns serve as min / max values, and
        levels of categorical variables come from ``create_levels_dict``. The node
        dict built from them replaces ``conf_dict['nodes']`` in memory. If ``write``
        is True, the updated configuration is also written to ``CONF_DICT_PATH``.

        Args:
            df: training data (presumably one column per node — confirm).
            write: persist the updated configuration to CONF_DICT_PATH.

        Raises:
            ValueError: if CONF_DICT_PATH is not set. The adjacency matrix and NN
                names are still read from that file.
        """
        print("[INFO] Make sure to provide only training data to compute_scaling!")
        if self.CONF_DICT_PATH is None:
            raise ValueError(
                "compute_scaling requires CONF_DICT_PATH: the adjacency matrix and "
                "NN names are read from the configuration file."
            )

        # 5% and 95% quantiles as robust min/max values. numeric_only=True skips
        # non-numeric columns, which pandas >= 2.0 would otherwise reject with a TypeError.
        quantiles = df.quantile([0.05, 0.95], numeric_only=True)
        min_vals = quantiles.loc[0.05]
        max_vals = quantiles.loc[0.95]

        # calculate levels for categorical variables
        levels_dict = create_levels_dict(df, self.conf_dict['data_type'])

        # TODO remove outer dependency of these functions (re-reading the config file)
        adj_matrix = read_adj_matrix_from_configuration(self.CONF_DICT_PATH)
        nn_names_matrix = read_nn_names_matrix_from_configuration(self.CONF_DICT_PATH)

        node_dict = create_node_dict(
            adj_matrix,
            nn_names_matrix,
            self.conf_dict['data_type'],
            min_vals=min_vals,
            max_vals=max_vals,
            levels_dict=levels_dict
        )
        # Update the in-memory config instead of re-loading it from disk, so unsaved
        # in-memory changes are not silently discarded.
        self.conf_dict['nodes'] = node_dict

        if write:
            try:
                write_configuration_dict(self.conf_dict, self.CONF_DICT_PATH)
                print(f'[INFO] Configuration with updated scaling saved to {self.CONF_DICT_PATH}')
            except Exception as e:
                # Best-effort persistence: the in-memory config is already updated.
                print(f'[ERROR] Failed to save configuration: {e}')

### Quick test (works)

In [11]:
# Experiment directory holding both the training data and its configuration file.
EXPERIMENT_DIR = "/home/bule/TramDag/dev_experiment_logs/exp_6_2"

train_df = pd.read_csv(f"{EXPERIMENT_DIR}/exp_6_2_train.csv")
cfg = TramDagConfig.load(f"{EXPERIMENT_DIR}/configuration.json")

# Computes min/max values and levels from the training data and writes them to the config file.
cfg.compute_scaling(train_df)

[INFO] Make sure to provide only training data to compute_scaling!
[INFO] Configuration with updated scaling saved to /home/bule/TramDag/dev_experiment_logs/exp_6_2/configuration.json


# TramDagDataset

In [None]:
from utils.tram_data import *
from utils.tram_data_helpers import *


class TramDagDataset:
    def __init__(self, df: pd.DataFrame, cfg: TramDagConfig, split="train"):
        self.df = df
        self.cfg = cfg
        self.split = split
        self.datasets = {}

        for node, meta in cfg.conf_dict["nodes"].items():
            self.datasets[node] = GenericDataset(
                df=df,
                target_col=node,
                target_nodes=cfg.conf_dict["nodes"],
                return_intercept_shift=True,
                return_y=True
            )

    def get_dataloader(self, node: str, batch_size=128, shuffle=True):
        ds = self.datasets[node]
        return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=4, pin_memory=True)


In [None]:
def get_dataloader(node,
                   target_nodes,
                   train_df=None,
                   val_df=None,
                   batch_size=32,
                   return_intercept_shift=False,
                   debug=False,
                   transform=None,
                   num_workers=4,
                   pin_memory=True,
                   ):
    """
    Build train and/or validation DataLoaders for a single node.

    Args:
        node: target column / node name (forwarded as ``target_col``).
        target_nodes: node configuration forwarded to GenericDataset.
        train_df, val_df: optional DataFrames; a loader is built for each one given.
        batch_size: batch size for both loaders.
        return_intercept_shift, debug: forwarded to GenericDataset.
        transform: forwarded to GenericDataset; defaults to Resize((128, 128)) + ToTensor().
        num_workers, pin_memory: forwarded to DataLoader; the defaults keep the
            previously hard-coded values.

    Returns:
        (train_loader, val_loader). Either is None when its DataFrame is None.
        The train loader shuffles; the validation loader does not.

    Raises:
        ValueError: if both train_df and val_df are None.
    """
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
        ])

    def _make_loader(df, shuffle):
        # Both splits share dataset/loader settings; only shuffling differs.
        ds = GenericDataset(df, target_col=node, target_nodes=target_nodes, transform=transform,
                            return_intercept_shift=return_intercept_shift, debug=debug)
        return DataLoader(ds, batch_size=batch_size, shuffle=shuffle,
                          num_workers=num_workers, pin_memory=pin_memory)

    train_loader, val_loader = None, None

    if train_df is not None:
        train_loader = _make_loader(train_df, shuffle=True)
    else:
        print("[INFO] train_df is None → skipping train dataloader.")

    if val_df is not None:
        val_loader = _make_loader(val_df, shuffle=False)
    else:
        print("[INFO] val_df is None → skipping val dataloader.")

    if train_loader is None and val_loader is None:
        raise ValueError("[ERROR] Both train_df and val_df are None → no dataloaders created.")

    return train_loader, val_loader

In [None]:
# NOTE(review): confirm TramDagDataset provides a from_dataframe classmethod (its constructor
# is TramDagDataset(df, cfg, split=...)), and that df_train / df_val are defined in an earlier
# cell (the test cell above loads its training data into train_df).
td_train_data=TramDagDataset.from_dataframe(df_train,cfg) # planned: also extract min/max values and levels from the training data (currently done by TramDagConfig.compute_scaling)
td_val_data  =TramDagDataset.from_dataframe(df_val,cfg) 

# additional arguments could be batch_size, num_workers, pin_memory, etc.
# planned to return a dictionary {'x1': dataloader1, 'x2': dataloader2, ...}

In [None]:
# NOTE(review): TramDagModel is defined in the next cell, so that cell must run first; also
# confirm it provides a from_config classmethod (otherwise use TramDagModel(cfg)).
td_model = TramDagModel.from_config(cfg) 
# builds one model per node according to the config
# the per-node models are kept in a dict such as {'x1': model1, 'x2': model2, ...} (TramDagModel.models)

In [None]:
class TramDagModel:
    """
    Per-node TRAM models for a DAG, each trained independently of the others.

    Attributes:
        cfg: configuration the models were built from.
        device: torch.device every model lives on.
        models: dict node name -> model from ``get_fully_specified_tram_model``.
        loss_history: dict node name -> {"train": [...], "val": [...]} holding the
            mean per-batch loss of every epoch run by ``fit`` (accumulates across calls).
    """

    def __init__(self, cfg: "TramDagConfig", device="cuda"):
        """
        Build one model per node in ``cfg.conf_dict["nodes"]`` and move it to the device.

        Args:
            cfg: configuration object exposing ``conf_dict["nodes"]``.
            device: torch device string. "auto" selects CUDA when available. Any value
                falls back to CPU when CUDA is unavailable.
        """
        self.cfg = cfg
        self.device = self._resolve_device(device)
        self.models = {}
        self.loss_history = {}

        nodes = cfg.conf_dict["nodes"]
        for node in nodes:
            self.models[node] = get_fully_specified_tram_model(node, nodes).to(self.device)

    @staticmethod
    def _resolve_device(device):
        """Return ``torch.device(device)`` ("auto" -> "cuda"), or CPU when CUDA is unavailable."""
        if not torch.cuda.is_available():
            return torch.device("cpu")
        return torch.device("cuda" if device == "auto" else device)

    @classmethod
    def from_config(cls, cfg: "TramDagConfig", device="cuda"):
        """Alternative constructor, equivalent to ``TramDagModel(cfg, device=device)``."""
        return cls(cfg, device=device)

    def _inputs_to_device(self, int_input, shift_list):
        """Move the intercept input and, if non-empty, every shift input to ``self.device``."""
        int_input = int_input.to(self.device)
        shift_list = [s.to(self.device) for s in shift_list] if shift_list else None
        return int_input, shift_list

    def _node_loss(self, node, outputs, y):
        """
        Loss of one batch for ``node``.

        Nodes with ``data_type == "cont"`` use ``contram_nll`` with the node's
        (min, max) scaling; all other nodes use ``ontram_nll``.
        """
        node_cfg = self.cfg.conf_dict["nodes"][node]
        if node_cfg["data_type"] == "cont":
            return contram_nll(outputs, y, (node_cfg["min"], node_cfg["max"]))
        return ontram_nll(outputs, y)

    def _evaluate(self, node, loader):
        """Mean per-batch loss of ``node``'s model on ``loader`` (eval mode, no gradients); NaN if empty."""
        model = self.models[node]
        model.eval()
        total, n_batches = 0.0, 0
        with torch.no_grad():
            for (int_input, shift_list), y in loader:
                int_input, shift_list = self._inputs_to_device(int_input, shift_list)
                loss = self._node_loss(node, model(int_input, shift_list), y.to(self.device))
                total += loss.item()
                n_batches += 1
        return total / n_batches if n_batches else float("nan")

    def fit(self, train_dataset: "TramDagDataset", val_dataset: "TramDagDataset", epochs=20, lr=1e-3):
        """
        Train every node's model independently with Adam.

        Args:
            train_dataset: provides one training DataLoader per node via ``get_dataloader``.
            val_dataset: used for a per-epoch validation loss of each node (no updates).
                Pass None to skip validation.
            epochs: number of passes over each node's training data.
            lr: Adam learning rate.

        The per-epoch mean losses are appended to ``self.loss_history``.
        """
        optimizers = {node: torch.optim.Adam(model.parameters(), lr=lr)
                      for node, model in self.models.items()}
        # Build loaders once; a shuffling DataLoader draws a new order on every iteration.
        train_loaders = {node: train_dataset.get_dataloader(node) for node in self.models}
        val_loaders = ({node: val_dataset.get_dataloader(node, shuffle=False) for node in self.models}
                       if val_dataset is not None else {})
        for node in self.models:
            self.loss_history.setdefault(node, {"train": [], "val": []})

        for epoch in range(epochs):
            for node, model in self.models.items():
                model.train()
                total, n_batches = 0.0, 0
                for (int_input, shift_list), y in train_loaders[node]:
                    int_input, shift_list = self._inputs_to_device(int_input, shift_list)
                    loss = self._node_loss(node, model(int_input, shift_list), y.to(self.device))

                    optimizers[node].zero_grad()
                    loss.backward()
                    optimizers[node].step()

                    # detach() keeps the running sum on-device, avoiding a host sync per batch.
                    total = total + loss.detach()
                    n_batches += 1
                self.loss_history[node]["train"].append(
                    float(total) / n_batches if n_batches else float("nan"))
                if node in val_loaders:
                    self.loss_history[node]["val"].append(self._evaluate(node, val_loaders[node]))
            print(f"Epoch {epoch+1}/{epochs} finished.")

    def save(self, path: str):
        """Save the configuration dict and every node's state_dict to ``path`` with torch.save."""
        torch.save({
            "cfg": self.cfg.conf_dict,
            "state_dicts": {node: model.state_dict() for node, model in self.models.items()}
        }, path)

    @classmethod
    def load(cls, path: str, device="cuda"):
        """
        Rebuild a TramDagModel from a checkpoint written by ``save``.

        Tensors are mapped to the resolved device, so a checkpoint saved on CUDA also
        loads on a CPU-only machine.
        Security: torch.load is pickle-based — only load checkpoints from trusted sources.
        """
        checkpoint = torch.load(path, map_location=cls._resolve_device(device))
        cfg = TramDagConfig(checkpoint["cfg"])
        obj = cls(cfg, device=device)
        for node, model in obj.models.items():
            model.load_state_dict(checkpoint["state_dicts"][node])
        return obj

    def predict(self, df: pd.DataFrame, node: str, batch_size=128):
        """
        Run ``node``'s model over ``df`` and return the concatenated ``"int_out"`` outputs on CPU.

        Raises:
            KeyError: if ``node`` is not one of the model nodes.
        """
        dataset = TramDagDataset(df, self.cfg, split="test")
        loader = dataset.get_dataloader(node, batch_size=batch_size, shuffle=False)
        model = self.models[node].eval()
        preds = []
        with torch.no_grad():
            for (int_input, shift_list), _ in loader:
                int_input, shift_list = self._inputs_to_device(int_input, shift_list)
                preds.append(model(int_input, shift_list)["int_out"].cpu())
        return torch.cat(preds, dim=0)


In [None]:
# Trains the model of every node independently; fit returns None and updates td_model in place.
# NOTE(review): the planned API also lists train_list=['x1','x2','x3'], callbacks=[] and
# device="auto"; TramDagModel.fit does not accept these yet. The learning rate parameter is
# named `lr` (1e-3 is fit's default value).
td_model.fit(td_train_data, td_val_data, epochs=100, lr=1e-3)

# planned: the fit function trains all models in train_list independently for the specified epochs

# planned: a td_fit object containing the history and the best model for each node as well as the cfg

In [None]:
# NOTE(review): the TramDagModel class above does not define a history() method yet.
# Legacy equivalent: show_training_history(node_list, EXPERIMENT_DIR)
td_model.history() # show_training_history(node_list,EXPERIMENT_DIR)

In [None]:
# Planned method (not yet defined on the TramDagModel class above). Legacy equivalent:
#   show_hdag_for_source_nodes(configuration_dict, EXPERIMENT_DIR, device=device, xmin_plot=0, xmax_plot=1)
# TODO: add an equivalent function for non-source nodes.
td_model.show_hdag_for_source_nodes()

In [None]:
# Planned method (not yet defined on the TramDagModel class above). Legacy equivalent below.
td_model.inspect_trafo_standart_logistic()
#inspect_trafo_standart_logistic(configuration_dict,EXPERIMENT_DIR,train_df,val_df,device,verbose=False)

In [None]:
# Planned method (not yet defined on the TramDagModel class above): either return the latents
# for the full DAG or store them as an attribute of td_model. Legacy equivalent below.
td_model.get_latent()
#all_latents_df = create_latent_df_for_full_dag(configuration_dict, EXPERIMENT_DIR, train_df, verbose=True)


In [None]:
# Planned method (not yet defined on the TramDagModel class above).
td_model.sample()

# Legacy equivalent, kept for reference. NOTE(review): configuration_dict and device are not
# defined in the visible cells, and EXPERIMENT_DIR must point to the experiment directory;
# define them before running this call.
sampled_by_node, latents_by_node=sample_full_dag(configuration_dict,
                EXPERIMENT_DIR,
                device,
                do_interventions={},
                predefined_latent_samples_df=None,  # alternatively all_latents_df from the latent cell
                number_of_samples= 10_000,
                batch_size = 32,
                delete_all_previously_sampled=True,
                verbose=True,
                debug=False)