## Testing dataloader V4

New ideas with ordinal c and yc.
Has ordinal, binary and continuous outputs.

In [1]:
# Load dependencies
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
import os
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split



from utils.graph import *
from utils.loss_ordinal import *
from utils.tram_model_helpers import *
from utils.tram_models import *
from utils.tram_data import *


import torch

# Select the compute device once for the whole notebook and report the choice.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(
    "Train with GPU support."
    if device.type == 'cuda'
    else "No GPU found, train with CPU support."
)

Train with GPU support.


Adjusted functions for ordinal outcomes

dev ordinal

In [2]:
experiment_name = "testing_v4_dataloader"   ## <--- set experiment name
seed=42
# Seed both RNGs used below: NumPy generates the synthetic data, while the
# shuffling DataLoaders draw their permutation from PyTorch's global generator.
np.random.seed(seed)
torch.manual_seed(seed)

# NOTE(review): absolute, machine-specific path - adjust when running elsewhere.
LOG_DIR="/home/bule/TramDag/dev_experiment_logs"
EXPERIMENT_DIR = os.path.join(LOG_DIR, experiment_name)
DATA_PATH = EXPERIMENT_DIR # <----------- change to different source if needed
CONF_DICT_PATH = os.path.join(EXPERIMENT_DIR, "configuration.json")

In [3]:
### New Functions and classes


def create_levels_dict(df:pd.DataFrame,data_type:dict):
    """
    Build the ``{variable: number_of_levels}`` dictionary for every variable
    whose data type string contains "ordinal" (case-insensitive).

    Args:
        df: DataFrame holding one column per variable named in ``data_type``.
        data_type: dict mapping variable name -> data type string
            (e.g. "continous", "ordinal_Xn_Yo").

    Returns:
        dict mapping each ordinal variable to its number of distinct
        non-missing values.

    Raises:
        ValueError: if the non-missing values of an ordinal variable are not
            exactly ``0, 1, ..., num_classes - 1``.
    """
    levels_dict = {}
    for variable, datatype in data_type.items():
        if "ordinal" not in datatype.lower():
            continue

        # Missing values are not a level: drop them before counting.
        unique_vals = set(df[variable].dropna().unique())
        num_classes = len(unique_vals)

        expected_vals = set(range(num_classes))
        if unique_vals != expected_vals:
            raise ValueError(
                f"Variable '{variable}' has values {sorted(unique_vals)}, "
                f"but expected values are {sorted(expected_vals)} (0 to {num_classes - 1}). "
                "Multiclass ordinal variables must be zero-indexed and contiguous."
            )
        # Use the validated count; counting on the raw column would include NaN
        # as an extra level and disagree with the check above.
        levels_dict[variable] = num_classes
    return levels_dict




def create_node_dict_v3(adj_matrix, nn_names_matrix, data_type, min_vals, max_vals,levels_dict=None):
    """
    Creates a configuration dictionary for TRAMADAG based on an adjacency matrix,
    a neural network names matrix, and a data type dictionary.

    Args:
        adj_matrix: square matrix of transformation terms; must pass
            ``validate_adj_matrix``.
        nn_names_matrix: matrix of model names indexed like ``adj_matrix``;
            the string "0" marks "no model".
        data_type: dict mapping variable name -> data type string; its key
            order defines the row/column positions in both matrices.
        min_vals, max_vals: objects supporting ``.iloc[i]``; entry ``i`` is
            stored for the i-th graph node.
        levels_dict: dict mapping ordinal variable -> number of levels.
            Required as soon as any node has an ordinal data type.

    Returns:
        dict mapping each node to its settings ('Modelnr', 'data_type',
        optional 'levels', 'node_type', 'parents', 'parents_datatype',
        'transformation_terms_in_h()', 'min', 'max',
        'transformation_term_nn_models_in_h()').

    Raises:
        ValueError: invalid adjacency matrix, size mismatch between
            ``data_type`` and ``adj_matrix``, or ordinal node without
            ``levels_dict``.
        KeyError: ``levels_dict`` has no entry for an ordinal node.
    """
    if not validate_adj_matrix(adj_matrix):
        raise ValueError("Invalid adjacency matrix. Please check the criteria.")

    if len(data_type) != adj_matrix.shape[0]:
        raise ValueError("Data type dictionary should have the same length as the adjacency matrix.")

    node_names = list(data_type.keys())
    # Matrix position of every variable, computed once instead of per parent.
    position = {name: idx for idx, name in enumerate(node_names)}

    target_nodes = {}
    G, edge_labels = create_nx_graph(adj_matrix, node_labels=node_names)

    sources = {node for node in G.nodes if G.in_degree(node) == 0}
    sinks = {node for node in G.nodes if G.out_degree(node) == 0}

    for i, node in enumerate(G.nodes):
        parents = list(G.predecessors(node))
        node_cfg = {}
        target_nodes[node] = node_cfg
        node_cfg['Modelnr'] = i
        node_cfg['data_type'] = data_type[node]

        # write the levels of the ordinal outcome
        # (case-insensitive, matching how create_levels_dict detects ordinals)
        if 'ordinal' in data_type[node].lower():
            if levels_dict is None:
                raise ValueError(
                    "levels_dict must be provided for ordinal nodes; "
                    "e.g. levels_dict={'x3': 3}"
                )
            if node not in levels_dict:
                raise KeyError(
                    f"levels_dict is missing an entry for node '{node}'. "
                    f"Expected something like levels_dict['{node}'] = <num_levels>"
                )
            node_cfg['levels'] = levels_dict[node]

        node_cfg['node_type'] = "source" if node in sources else "sink" if node in sinks else "internal"
        node_cfg['parents'] = parents
        node_cfg['parents_datatype'] = {parent: data_type[parent] for parent in parents}
        node_cfg['transformation_terms_in_h()'] = {
            parent: edge_labels[(parent, node)]
            for parent in parents
            if (parent, node) in edge_labels
        }
        # Positional lookup: assumes min_vals/max_vals are ordered like the
        # graph nodes - TODO confirm against the caller.
        node_cfg['min'] = min_vals.iloc[i].tolist()
        node_cfg['max'] = max_vals.iloc[i].tolist()

        child_idx = position[node]
        transformation_term_nn_models = {}
        for parent in parents:
            model_name = nn_names_matrix[position[parent], child_idx]
            if model_name != "0":
                transformation_term_nn_models[parent] = model_name
        node_cfg['transformation_term_nn_models_in_h()'] = transformation_term_nn_models
    return target_nodes


class GenericDataset_v4(Dataset):
    """
    Per-node dataset: yields ``(x_tuple, y)`` for one target variable.

    ``x_tuple`` contains, in order, an optional constant ``1.0`` (simple
    intercept) followed by one tensor per predictor; ``y`` is either a float
    scalar tensor or a one-hot vector, depending on the target's data type.
    """

    def __init__(
        self,
        df,
        target_col,
        target_nodes=None,
        parents_dataype_dict=None,
        transformation_terms_in_h=None,
        transform=None
    ):
        """
        df: pd.DataFrame
        target_col: str, name of the target column / node
        target_nodes: dict mapping each node -> metadata; the target's entry
            must provide 'node_type' and 'data_type', ordinal variables
            additionally 'levels'
        parents_dataype_dict: dict predictor name -> data type string.
            "continous" or anything containing "xc" is fed as a float,
            "ordinal" + "xn" is one-hot encoded, everything else is treated
            as an image path (all checks are case-insensitive)
        transform: transform applied to loaded images
        transformation_terms_in_h: dict predictor -> transformation term;
            used to decide whether a constant intercept input is needed

        Raises:
            KeyError: ``target_col`` has no entry in ``target_nodes``.
            ValueError: an ordinal variable fails the level checks.
        """
        self.df = df.reset_index(drop=True)
        self.target_col = target_col
        self.target_nodes = target_nodes or {}
        self.parents_dataype_dict = parents_dataype_dict or {}
        self.predictors = list(self.parents_dataype_dict.keys())
        self.transform = transform
        self.transformation_terms_in_h = transformation_terms_in_h or {}

        # Fail early with a readable message instead of an opaque TypeError
        # when target_nodes is left at its default or lacks the target.
        if self.target_col not in self.target_nodes:
            raise KeyError(
                f"target_nodes has no entry for target column '{self.target_col}'."
            )
        target_info = self.target_nodes[self.target_col]

        self.target_is_source = target_info.get('node_type').lower() == "source"
        # A constant intercept input is needed when no term contains an 'i'
        # (i.e. no 'ci*' intercept term is configured).
        self.h_needs_simple_intercept = all(
            'i' not in str(v) for v in self.transformation_terms_in_h.values()
        )
        # One-hot width of every nominally modelled ordinal predictor.
        self.ordinal_num_classes = {
            var: self.df[var].nunique()
            for var in self.predictors
            if "ordinal" in self.parents_dataype_dict[var].lower()
            and "xn" in self.parents_dataype_dict[var].lower()
        }

        self.target_data_type = target_info.get('data_type').lower()
        # None when the target has no 'levels' entry.
        self.target_num_classes = target_info.get('levels') or None

        # checks
        self._check_multiclass_predictors_of_df()
        self._check_ordinal_levels()

    def _transform_y(self, row):
        """Return the target of ``row`` as a float scalar or a one-hot vector."""
        if self.target_data_type == "continous" or "yc" in self.target_data_type:
            return torch.tensor(row[self.target_col], dtype=torch.float32)
        elif self.target_num_classes is not None:
            y_int = int(row[self.target_col])
            return F.one_hot(
                torch.tensor(y_int, dtype=torch.long),
                num_classes=self.target_num_classes,
            ).float().squeeze()
        else:
            raise ValueError(
                f"Could not determine how to encode target '{self.target_col}'.\n"
                f"target_data_type: {self.target_data_type}, target_num_classes: {self.target_num_classes}"
            )

    def _check_multiclass_predictors_of_df(self):
        """
        Verify that every nominally modelled ordinal predictor ("ordinal" +
        "xn") has zero-indexed, contiguous values in this DataFrame.
        """
        for var in self.predictors:
            # Lowercased so the check agrees with how the rest of the class
            # interprets the data type string.
            dtype = self.parents_dataype_dict[var].lower()
            # continuous predictors or ordinal predictors modelled as
            # continuous are skipped
            if "ordinal" not in dtype:
                continue
            if "xc" in dtype:
                continue
            if "xn" not in dtype:
                continue

            unique_vals = set(self.df[var].dropna().unique())
            num_classes = len(unique_vals)

            expected_vals = set(range(num_classes))
            if unique_vals != expected_vals:
                raise ValueError(
                    f"Variable '{var}' has values {sorted(unique_vals)}, "
                    f"but expected values are {sorted(expected_vals)} (0 to {num_classes - 1}). "
                    "Multiclass ordinal predictors must be zero-indexed and contiguous."
                )

    def _check_ordinal_levels(self):
        """
        Ensures all ordinal variables (including the target) have a 'levels' key in target_nodes,
        and that the values in the DataFrame are zero-indexed and contiguous.
        """
        all_ordinal_vars = (
            [self.target_col]
            if "ordinal" in self.target_nodes[self.target_col]['data_type'].lower()
            else []
        )

        all_ordinal_vars += [
            var for var in self.predictors
            if "ordinal" in self.parents_dataype_dict[var].lower()
            and "xn" in self.parents_dataype_dict[var].lower()
        ]

        for var in all_ordinal_vars:
            node_info = self.target_nodes.get(var, {})
            levels = node_info.get("levels")
            if levels is None:
                raise ValueError(
                    f"[Ordinal Check] Variable '{var}' is marked ordinal but has no 'levels' entry in target_nodes."
                )

            unique_vals = sorted(self.df[var].dropna().unique())
            expected_vals = list(range(levels))

            if unique_vals != expected_vals:
                raise ValueError(
                    f"[Ordinal Check] Variable '{var}' has values {unique_vals}, "
                    f"but expected values {expected_vals} (0 to {levels - 1}).\n"
                    f"Fix this in the DataFrame or correct 'levels' in target_nodes."
                )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # --- SOURCE NODE: no parents -> x = (1.0,) ---
        if self.target_is_source:
            return (torch.tensor(1.0),), self._transform_y(row)

        # --- not a source: assemble the predictor tensors ---
        x_data = []

        # --- SIMPLE INTERCEPT if needed: first entry of x is 1.0 ---
        if self.h_needs_simple_intercept:
            x_data.append(torch.tensor(1.0))

        # --- BUILD FEATURES ---
        for var in self.predictors:
            dtype = self.parents_dataype_dict[var].lower()
            # Continuous feature (or ordinal modelled as continuous)
            if dtype == "continous" or "xc" in dtype:
                x_data.append(torch.tensor(row[var], dtype=torch.float32))

            # Nominally modelled ordinal feature: always one-hot encoded,
            # binary variables included (width 2).
            elif "ordinal" in dtype and "xn" in dtype:
                x_ord = int(row[var])
                var_num_classes = self.ordinal_num_classes[var]
                x_ord_onehot = F.one_hot(
                    torch.tensor(x_ord, dtype=torch.long),
                    num_classes=var_num_classes,
                ).float()
                x_data.append(x_ord_onehot.squeeze())

            else:  # "other": the cell holds an image path
                img = Image.open(row[var]).convert("RGB")
                if self.transform:
                    img = self.transform(img)
                x_data.append(img)

        # --- BUILD TARGET ---
        y = self._transform_y(row)

        return tuple(x_data), y



def get_dataloader_v4(node, target_nodes, train_df, val_df, batch_size=32, verbose=False,
                      num_workers=4, pin_memory=True, image_size=(128, 128)):
    """
    Build the train and validation DataLoaders for one target node.

    Args:
        node: name of the target node.
        target_nodes: node configuration dict, passed to ``ordered_parents``
            and to ``GenericDataset_v4``.
        train_df, val_df: DataFrames for the two splits.
        batch_size: batch size of both loaders.
        verbose: if True, print the ordered parent data types.
        num_workers: worker processes per loader (default 4, as before).
        pin_memory: forwarded to ``DataLoader`` (default True, as before).
        image_size: size passed to ``transforms.Resize`` for image predictors
            (default (128, 128), as before).

    Returns:
        (train_loader, val_loader); only the training loader shuffles.
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor()
    ])

    ordered_parents_dataype_dict, ordered_transformation_terms_in_h, _ = ordered_parents(node, target_nodes)
    if verbose:
        print(f"Parents dtype: {ordered_parents_dataype_dict}")

    # Both splits share exactly the same dataset configuration.
    dataset_kwargs = dict(
        target_col=node,
        target_nodes=target_nodes,
        parents_dataype_dict=ordered_parents_dataype_dict,
        transform=transform,
        transformation_terms_in_h=ordered_transformation_terms_in_h,
    )
    train_ds = GenericDataset_v4(train_df, **dataset_kwargs)
    val_ds = GenericDataset_v4(val_df, **dataset_kwargs)

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    return train_loader, val_loader

In [4]:

# datatype= "ordinal_Xc_Yc" # modelled as continuous both as predictor (X) and as outcome (Y)
# datatype= "ordinal_Xc_Yo"
# datatype= "ordinal_Xn_Yc"
# datatype= "ordinal_Xn_Yo"


In [5]:
# Synthetic toy data: one continuous column, two binary columns and one 4-level column.
n_samples = 1000
df = pd.DataFrame({
    "x1": np.random.normal(loc=0, scale=1, size=n_samples),
    # "x2": np.random.uniform(low=0, high=10, size=1000),
    "ord_bin1":  np.random.binomial(1, p=0.4, size=n_samples),
    "ord_multi": np.random.choice([0, 1, 2, 3], size=n_samples, p=[0.2, 0.3, 0.3, 0.2]),
    "ord_bin2":  np.random.binomial(1, p=0.4, size=n_samples),
})

# 80 % train, then the remaining 20 % split evenly into validation and test.
train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

# 5 % / 95 % quantiles of the training split serve as min / max values.
quantiles = train_df.quantile([0.05, 0.95])
min_vals, max_vals = quantiles.loc[0.05], quantiles.loc[0.95]

# First column continuous, the remaining three ordinal (Xn / Yo).
data_type = dict(zip(
    train_df.columns,
    ['continous', 'ordinal_Xn_Yo', 'ordinal_Xn_Yo', 'ordinal_Xn_Yo'],
))

configuration_dict = new_conf_dict(experiment_name, EXPERIMENT_DIR, DATA_PATH, LOG_DIR)

# Every variable except the last gets an "ls" term into the last variable.
columns = list(data_type)
n_vars = len(columns)
adj_matrix = np.full((n_vars, n_vars), "0", dtype=object)
adj_matrix[:-1, -1] = "ls"

nn_names_matrix = create_nn_model_names(adj_matrix, data_type)
levels_dict = create_levels_dict(df, data_type)
target_nodes = create_node_dict_v3(
    adj_matrix, nn_names_matrix, data_type, min_vals, max_vals, levels_dict
)

# Show one training batch (batch size 1) for every node.
for node in target_nodes:
    print(f'-----------------{node}-------------------')
    train_loader, _ = get_dataloader_v4(
        node, target_nodes, train_df, val_df, batch_size=1, verbose=False
    )
    x, y = next(iter(train_loader))

    print(x)
    print(y)

-----------------x1-------------------


[tensor([1.])]
tensor([-1.5148])
-----------------ord_bin1-------------------
[tensor([1.])]
tensor([[1., 0.]])
-----------------ord_multi-------------------
[tensor([1.])]
tensor([[0., 0., 0., 1.]])
-----------------ord_bin2-------------------
[tensor([1.]), tensor([-0.9815]), tensor([[1., 0.]]), tensor([[0., 1., 0., 0.]])]
tensor([[1., 0.]])


In [6]:
# Continuous x1; ord_bin1 as Xc/Yc; ord_multi and ord_bin2 as Xn/Yc.
data_type = dict(zip(
    train_df.columns,
    ['continous', 'ordinal_Xc_Yc', 'ordinal_Xn_Yc', 'ordinal_Xn_Yc'],
))
print(data_type)
configuration_dict = new_conf_dict(experiment_name, EXPERIMENT_DIR, DATA_PATH, LOG_DIR)

# Every variable except the last gets an "ls" term into the last variable.
columns = list(data_type)
n_vars = len(columns)
adj_matrix = np.full((n_vars, n_vars), "0", dtype=object)
adj_matrix[:-1, -1] = "ls"

nn_names_matrix = create_nn_model_names(adj_matrix, data_type)
levels_dict = create_levels_dict(df, data_type)
target_nodes = create_node_dict_v3(
    adj_matrix, nn_names_matrix, data_type, min_vals, max_vals, levels_dict
)

# Show one training batch (batch size 1) for every node.
for node in target_nodes:
    print(f'-----------------{node}-------------------')
    train_loader, _ = get_dataloader_v4(
        node, target_nodes, train_df, val_df, batch_size=1, verbose=False
    )
    x, y = next(iter(train_loader))

    print(x)
    print(y)

{'x1': 'continous', 'ord_bin1': 'ordinal_Xc_Yc', 'ord_multi': 'ordinal_Xn_Yc', 'ord_bin2': 'ordinal_Xn_Yc'}
-----------------x1-------------------
[tensor([1.])]
tensor([-1.0777])
-----------------ord_bin1-------------------
[tensor([1.])]
tensor([1.])
-----------------ord_multi-------------------
[tensor([1.])]
tensor([2.])
-----------------ord_bin2-------------------
[tensor([1.]), tensor([-0.6270]), tensor([0.]), tensor([[0., 0., 1., 0.]])]
tensor([1.])


In [7]:
# Continuous x1; all three ordinal columns as Xn/Yo.
data_type = dict(zip(
    train_df.columns,
    ['continous', 'ordinal_Xn_Yo', 'ordinal_Xn_Yo', 'ordinal_Xn_Yo'],
))
print(data_type)
configuration_dict = new_conf_dict(experiment_name, EXPERIMENT_DIR, DATA_PATH, LOG_DIR)

# Every variable except the last gets an "ls" term into the last variable.
columns = list(data_type)
n_vars = len(columns)
adj_matrix = np.full((n_vars, n_vars), "0", dtype=object)
adj_matrix[:-1, -1] = "ls"

nn_names_matrix = create_nn_model_names(adj_matrix, data_type)
levels_dict = create_levels_dict(df, data_type)
target_nodes = create_node_dict_v3(
    adj_matrix, nn_names_matrix, data_type, min_vals, max_vals, levels_dict
)

# Show one training batch (batch size 1) for every node.
for node in target_nodes:
    print(f'-----------------{node}-------------------')
    train_loader, _ = get_dataloader_v4(
        node, target_nodes, train_df, val_df, batch_size=1, verbose=False
    )
    x, y = next(iter(train_loader))

    print(x)
    print(y)

{'x1': 'continous', 'ord_bin1': 'ordinal_Xn_Yo', 'ord_multi': 'ordinal_Xn_Yo', 'ord_bin2': 'ordinal_Xn_Yo'}
-----------------x1-------------------
[tensor([1.])]
tensor([0.1981])
-----------------ord_bin1-------------------
[tensor([1.])]
tensor([[1., 0.]])
-----------------ord_multi-------------------
[tensor([1.])]
tensor([[0., 0., 0., 1.]])
-----------------ord_bin2-------------------
[tensor([1.]), tensor([0.1504]), tensor([[0., 1.]]), tensor([[0., 0., 0., 1.]])]
tensor([[1., 0.]])


In [8]:
# Continuous x1; all ordinal columns as Xc predictors, with Yo / Yc / Yo outcomes.
data_type = dict(zip(
    train_df.columns,
    ['continous', 'ordinal_Xc_Yo', 'ordinal_Xc_Yc', 'ordinal_Xc_Yo'],
))
print(data_type)
configuration_dict = new_conf_dict(experiment_name, EXPERIMENT_DIR, DATA_PATH, LOG_DIR)

# Every variable except the last gets an "ls" term into the last variable.
columns = list(data_type)
n_vars = len(columns)
adj_matrix = np.full((n_vars, n_vars), "0", dtype=object)
adj_matrix[:-1, -1] = "ls"

nn_names_matrix = create_nn_model_names(adj_matrix, data_type)
levels_dict = create_levels_dict(df, data_type)
target_nodes = create_node_dict_v3(
    adj_matrix, nn_names_matrix, data_type, min_vals, max_vals, levels_dict
)

# Show one training batch (batch size 1) for every node.
for node in target_nodes:
    print(f'-----------------{node}-------------------')
    train_loader, _ = get_dataloader_v4(
        node, target_nodes, train_df, val_df, batch_size=1, verbose=False
    )
    x, y = next(iter(train_loader))

    print(x)
    print(y)

{'x1': 'continous', 'ord_bin1': 'ordinal_Xc_Yo', 'ord_multi': 'ordinal_Xc_Yc', 'ord_bin2': 'ordinal_Xc_Yo'}
-----------------x1-------------------
[tensor([1.])]
tensor([-1.7630])
-----------------ord_bin1-------------------
[tensor([1.])]
tensor([[0., 1.]])
-----------------ord_multi-------------------
[tensor([1.])]
tensor([2.])
-----------------ord_bin2-------------------
[tensor([1.]), tensor([1.8862]), tensor([0.]), tensor([1.])]
tensor([[0., 1.]])


In [9]:
# Continuous x1; all three ordinal columns as Xc/Yc.
data_type = dict(zip(
    train_df.columns,
    ['continous', 'ordinal_Xc_Yc', 'ordinal_Xc_Yc', 'ordinal_Xc_Yc'],
))
print(data_type)
configuration_dict = new_conf_dict(experiment_name, EXPERIMENT_DIR, DATA_PATH, LOG_DIR)

# Every variable except the last gets an "ls" term into the last variable.
columns = list(data_type)
n_vars = len(columns)
adj_matrix = np.full((n_vars, n_vars), "0", dtype=object)
adj_matrix[:-1, -1] = "ls"

nn_names_matrix = create_nn_model_names(adj_matrix, data_type)
levels_dict = create_levels_dict(df, data_type)
target_nodes = create_node_dict_v3(
    adj_matrix, nn_names_matrix, data_type, min_vals, max_vals, levels_dict
)

# Show one training batch (batch size 1) for every node.
for node in target_nodes:
    print(f'-----------------{node}-------------------')
    train_loader, _ = get_dataloader_v4(
        node, target_nodes, train_df, val_df, batch_size=1, verbose=False
    )
    x, y = next(iter(train_loader))

    print(x)
    print(y)

{'x1': 'continous', 'ord_bin1': 'ordinal_Xc_Yc', 'ord_multi': 'ordinal_Xc_Yc', 'ord_bin2': 'ordinal_Xc_Yc'}
-----------------x1-------------------
[tensor([1.])]
tensor([0.1504])
-----------------ord_bin1-------------------
[tensor([1.])]
tensor([0.])
-----------------ord_multi-------------------
[tensor([1.])]
tensor([2.])
-----------------ord_bin2-------------------
[tensor([1.]), tensor([1.1356]), tensor([0.]), tensor([1.])]
tensor([0.])


In [20]:
def preprocess_inputs_v2(x, transformation_terms, device='cuda'):
    """
    Prepares model input by grouping features by transformation term base:
      - ci11, ci12 → 'ci1' (intercept)
      - cs11, cs12 → 'cs1' (shift)
      - cs21 → 'cs2' (another shift group)
      - cs, ls → treated as full group keys

    Args:
      - x: sequence of batched tensors, one per transformation term (plus the
        leading constant intercept input when no 'ci' term is present)
      - transformation_terms: iterable of term strings; it is copied, the
        caller's object is never modified
      - device: target device; a CUDA request falls back to CPU when CUDA is
        not available

    Returns:
      - int_inputs: Tensor of shape (B, n_features) for intercept model
      - shift_list: List of tensors for each shift model, shape (B, group_features),
        or None when there is no shift term

    Raises:
      - ValueError: unknown term, no intercept tensor, or a different number
        of input tensors and transformation terms
    """
    # Honour the requested device, but keep working on CPU-only machines.
    device = torch.device(device)
    if device.type == 'cuda' and not torch.cuda.is_available():
        device = torch.device('cpu')
    transformation_terms = list(transformation_terms)

    print(transformation_terms)
    x = [xi.to(device, non_blocking=True) for xi in x]

    # Source node: no transformation terms, the single input is the intercept.
    if len(transformation_terms) == 0:
        return x[0].unsqueeze(1), None

    # Always ensure there's an intercept term
    if not any('ci' in str(value) for value in transformation_terms):
        transformation_terms.insert(0, 'si')

    # zip() below would silently drop the surplus entries otherwise.
    if len(x) != len(transformation_terms):
        raise ValueError(
            f"Got {len(x)} input tensors for {len(transformation_terms)} "
            f"transformation terms: {transformation_terms}"
        )

    # Lists to collect intercept tensors and shift-groups
    int_tensors = []
    shift_groups = []

    # Track the "current" shift-group for numbered suffixes
    current_group = None
    current_key = None

    for tensor, term in zip(x, transformation_terms):
        # 1) INTERCEPT terms (si*, ci*)
        if term.startswith(('si', 'ci')):
            int_tensors.append(tensor)

        # 2) SHIFT terms (cs*, ls*)
        elif term.startswith(('cs', 'ls')):
            # numbered suffix → group by the first 3 chars (e.g. 'cs11'/'cs12' → 'cs1')
            if len(term) > 2 and term[2].isdigit():
                key = term[:3]
                # start a new group if key changed
                if current_group is None or current_key != key:
                    current_group = []
                    shift_groups.append(current_group)
                    current_key = key
                current_group.append(tensor)

            # lone 'cs' or 'ls' → always its own group
            else:
                current_group = [tensor]
                shift_groups.append(current_group)
                current_key = None
        else:
            raise ValueError(f"Unknown transformation term: {term}")

    # Intercept: should be exactly one group
    if len(int_tensors) == 0:
        raise ValueError("No intercept tensors found!")

    # Tensors already live on `device`; flatten each to (B, -1) and concatenate.
    int_inputs = torch.cat([t.view(t.shape[0], -1) for t in int_tensors], dim=1)

    shift_list = [
        torch.cat([t.view(t.shape[0], -1) for t in group], dim=1)
        for group in shift_groups
    ]

    return int_inputs, shift_list if shift_list else None

In [21]:
# Continuous x1; all three ordinal columns as Xn/Yo.
data_type = dict(zip(
    train_df.columns,
    ['continous', 'ordinal_Xn_Yo', 'ordinal_Xn_Yo', 'ordinal_Xn_Yo'],
))
print(data_type)
configuration_dict = new_conf_dict(experiment_name, EXPERIMENT_DIR, DATA_PATH, LOG_DIR)

# Hand-written graph: grouped 'cs1*' terms into column 3, grouped 'ci1*' terms
# plus a lone 'cs' term into column 4.
adj_matrix = np.array([
    ['0', '0', 'cs11', 'ci11'],
    ['0', '0', 'cs12', 'ci12'],
    ['0', '0', '0', 'cs'],
    ['0', '0', '0', '0'],
])

nn_names_matrix = create_nn_model_names(adj_matrix, data_type)
levels_dict = create_levels_dict(df, data_type)
target_nodes = create_node_dict_v3(
    adj_matrix, nn_names_matrix, data_type, min_vals, max_vals, levels_dict
)

# For every node: draw one batch and run it through preprocess_inputs_v2.
for node in target_nodes:
    print(f'-----------------{node}-------------------')
    train_loader, _ = get_dataloader_v4(
        node, target_nodes, train_df, val_df, batch_size=1, verbose=False
    )
    x, y = next(iter(train_loader))

    print(x)
    print(y)

    _, ordered_transformation_terms_in_h, _ = ordered_parents(node, target_nodes)

    int_input, shift_list = preprocess_inputs_v2(
        x, ordered_transformation_terms_in_h.values(), device=device
    )

    # print(f'int_input {int_input}')
    # print(f'shift_list {shift_list}')


{'x1': 'continous', 'ord_bin1': 'ordinal_Xn_Yo', 'ord_multi': 'ordinal_Xn_Yo', 'ord_bin2': 'ordinal_Xn_Yo'}
*************
 Model has Complex intercepts and Complex shifts, please add your Model to the modelzoo 
*************
-----------------x1-------------------
[tensor([1.])]
tensor([0.0675])
[]
-----------------ord_bin1-------------------
[tensor([1.])]
tensor([[0., 1.]])
[]
-----------------ord_multi-------------------
[tensor([1.]), tensor([-0.2258]), tensor([[0., 1.]])]
tensor([[0., 1., 0., 0.]])
[np.str_('cs11'), np.str_('cs12')]
-----------------ord_bin2-------------------
[tensor([-0.0543]), tensor([[0., 1.]]), tensor([[0., 0., 1., 0.]])]
tensor([[1., 0.]])
[np.str_('ci11'), np.str_('ci12'), np.str_('cs')]


# Refactor preprocess inputs

In [None]:
# Continuous x1; all three ordinal columns as Xn/Yo.
data_type = dict(zip(
    train_df.columns,
    ['continous', 'ordinal_Xn_Yo', 'ordinal_Xn_Yo', 'ordinal_Xn_Yo'],
))
print(data_type)
configuration_dict = new_conf_dict(experiment_name, EXPERIMENT_DIR, DATA_PATH, LOG_DIR)

# Hand-written graph: grouped 'cs1*' terms into column 3, grouped 'ci1*' terms
# plus a lone 'cs' term into column 4.
adj_matrix = np.array([
    ['0', '0', 'cs11', 'ci11'],
    ['0', '0', 'cs12', 'ci12'],
    ['0', '0', '0', 'cs'],
    ['0', '0', '0', '0'],
])

nn_names_matrix = create_nn_model_names(adj_matrix, data_type)
levels_dict = create_levels_dict(df, data_type)
target_nodes = create_node_dict_v3(
    adj_matrix, nn_names_matrix, data_type, min_vals, max_vals, levels_dict
)


def preprocess_inputs_v3(x,transformation_terms_preprocessing,intercept_indices,shift_groups_indices,device=None):
    """
    Prepares model input by grouping features:
      - Intercepts: concatenated from intercept_indices
      - Shifts: list of concatenated tensors per group

    Args:
      x: List of batched input tensors; the index lists below refer to
         positions in this list
      transformation_terms_preprocessing: list of transformation terms; only
         its emptiness is inspected (empty → the single input is the intercept)
      intercept_indices: positions in x that feed the intercept
      shift_groups_indices: list of position lists, one per shift group
      device: optional target device ('cuda', 'cpu' or torch.device);
         defaults to CUDA when available, otherwise CPU

    Returns:
      int_inputs: Tensor of shape (B, n_intercept_features)
      shift_list: List of tensors for each shift group or None if empty

    Raises:
      ValueError: if intercept_indices selects no tensor
    """
    # Select device (auto-detect unless the caller asked for a specific one)
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    # Move inputs to device
    x = [xi.to(device, non_blocking=True) for xi in x]

    # No transformation terms: treat the single input as intercept
    if not transformation_terms_preprocessing:
        return x[0].unsqueeze(1), None

    # Build intercept inputs
    int_tensors = [x[idx] for idx in intercept_indices]
    if not int_tensors:
        raise ValueError("No intercept tensors found!")
    # Flatten every tensor to (B, -1) and concatenate along the feature axis
    int_inputs = torch.cat(
        [t.view(t.shape[0], -1) for t in int_tensors], dim=1
    )

    # Build shift groups: one concatenated tensor per group
    shift_list = [
        torch.cat([x[idx].view(x[idx].shape[0], -1) for idx in group_idxs], dim=1)
        for group_idxs in shift_groups_indices
    ]

    return int_inputs, (shift_list if shift_list else None)


def intercept_shift_indexes(transformation_terms_preprocessing):
    """
    Compute which input positions feed the intercept and which feed each
    shift group.

    The positions refer to the term list *with* a leading 'si' (simple
    intercept) entry whenever the given terms contain neither a 'ci' term nor
    an 'si' entry already. The caller's list is not modified, so calling the
    function repeatedly with the same list yields the same result.

    Args:
        transformation_terms_preprocessing: sequence of transformation terms
            (e.g. 'ci11', 'cs12', 'cs', 'ls').

    Returns:
        (intercept_indices, shift_groups_indices):
          - intercept_indices: positions of terms starting with 'si' or 'ci'
          - shift_groups_indices: list of position lists; consecutive numbered
            terms sharing their first three characters ('cs11', 'cs12' → 'cs1')
            form one group, a lone 'cs'/'ls' is always its own group
    """
    # Work on a copy: inserting into the caller's list would make a second
    # call insert another 'si' and shift every index by one.
    terms = [str(term) for term in transformation_terms_preprocessing]

    # Always ensure there's exactly one kind of intercept term
    if not any('ci' in term or term.startswith('si') for term in terms):
        terms.insert(0, 'si')

    intercept_indices = [i for i, term in enumerate(terms) if term.startswith(('si', 'ci'))]

    shift_groups_indices = []
    current_key = None
    for i, term in enumerate(terms):
        if not term.startswith(('cs', 'ls')):
            continue
        # numbered suffix → group by the first 3 chars
        if len(term) > 2 and term[2].isdigit():
            key = term[:3]
            if not shift_groups_indices or current_key != key:
                shift_groups_indices.append([i])
                current_key = key
            else:
                shift_groups_indices[-1].append(i)
        else:
            # lone 'cs' or 'ls'
            shift_groups_indices.append([i])
            current_key = None
    return intercept_indices, shift_groups_indices




# For every node: draw one batch, derive the intercept / shift positions from
# the node's transformation terms and assemble the model inputs.
for node in target_nodes:
    print(f'-----------------{node}-------------------')
    train_loader, _ = get_dataloader_v4(
        node, target_nodes, train_df, val_df, batch_size=1, verbose=False
    )
    x, y = next(iter(train_loader))

    print(x)
    print(y)

    _, ordered_transformation_terms_in_h, _ = ordered_parents(node, target_nodes)
    transformation_terms_preprocessing = list(ordered_transformation_terms_in_h.values())

    intercept_indices, shift_groups_indices = intercept_shift_indexes(
        transformation_terms_preprocessing
    )
    int_input, shift_list = preprocess_inputs_v3(
        x,
        transformation_terms_preprocessing,
        intercept_indices,
        shift_groups_indices,
    )

    print(f'int_input {int_input}')
    print(f'shift_list {shift_list}')

{'x1': 'continous', 'ord_bin1': 'ordinal_Xn_Yo', 'ord_multi': 'ordinal_Xn_Yo', 'ord_bin2': 'ordinal_Xn_Yo'}
*************
 Model has Complex intercepts and Complex shifts, please add your Model to the modelzoo 
*************
-----------------x1-------------------
[tensor([1.])]
tensor([-0.4081])
int_input tensor([[1.]], device='cuda:0')
shift_list None
-----------------ord_bin1-------------------
[tensor([1.])]
tensor([[1., 0.]])
int_input tensor([[1.]], device='cuda:0')
shift_list None
-----------------ord_multi-------------------
[tensor([1.]), tensor([0.4732]), tensor([[1., 0.]])]
tensor([[0., 0., 0., 1.]])
int_input tensor([[1.]], device='cuda:0')
shift_list [tensor([[0.4732, 1.0000, 0.0000]], device='cuda:0')]
-----------------ord_bin2-------------------
[tensor([-0.2490]), tensor([[0., 1.]]), tensor([[0., 1., 0., 0.]])]
tensor([[0., 1.]])
int_input tensor([[-0.2490,  0.0000,  1.0000]], device='cuda:0')
shift_list [tensor([[0., 1., 0., 0.]], device='cuda:0')]
