In [None]:
class SamplingDataset(Dataset):
    """Serve previously sampled parent values as model inputs for one node.

    For a node without usable parents (``target_nodes`` is ``None`` or the
    node's ``node_type`` is ``'source'``) every item is the constant tuple
    ``(tensor([1.0]),)``. Otherwise each item is a tuple holding one entry per
    parent of the node, in the order returned by ``ordered_parents``,
    optionally preceded by a constant ``1.0`` (see ``__getitem__``).
    """

    def __init__(self, node, EXPERIMENT_DIR, number_of_samples=100, rootfinder='bisection', target_nodes=None, transform=None):
        """
        Args:
            node (str): Name of the target node.
            EXPERIMENT_DIR (str): Directory containing one sub-directory per
                parent node; each one is handed to ``load_roots``.
            number_of_samples (int): Length of the dataset for source nodes.
            rootfinder (str): Forwarded unchanged to ``load_roots``.
            target_nodes (dict, optional): Per-node metadata. This class reads
                ``target_nodes[name]['node_type']`` and
                ``target_nodes[name]['data_type']`` ("cont", "ord" or "other").
            transform (callable, optional): Optional transform to apply to image data.
        """
        super().__init__()
        self.EXPERIMENT_DIR = EXPERIMENT_DIR
        self.node = node
        self.number_of_samples = number_of_samples
        self.target_nodes = target_nodes
        self.transform = transform
        self.rootfinder = rootfinder

        if target_nodes is None:
            # Without node metadata there is nothing to look up; skip
            # ordered_parents entirely instead of passing it None.
            self._parents_datatype_dict = {}
            self.transformation_terms_in_h = {}
            self.predictors = None
        else:
            # Query the parent structure once and reuse it everywhere below
            # (assumes ordered_parents is deterministic for fixed arguments).
            self._parents_datatype_dict, self.transformation_terms_in_h, _ = ordered_parents(
                self.node, self.target_nodes
            )
            self.predictors = self._ordered_keys()

        # One tensor per parent; for source nodes a single tensor of ones.
        self.datatensors = self._get_sampled_parent_tensors()

    def _is_source(self):
        """Return True when the node has no parents to read samples from."""
        return self.target_nodes is None or self.target_nodes[self.node]['node_type'] == 'source'

    @staticmethod
    def _squeeze_singleton(item):
        """Turn a tensor of shape (1,) into a 0-dim tensor; pass anything else through."""
        if isinstance(item, torch.Tensor) and item.dim() == 1 and item.shape[0] == 1:
            return item.squeeze(0)
        return item

    def _get_sampled_parent_tensors(self):
        """Load the sampled values of every parent of the node.

        Returns:
            list: ``[torch.ones(number_of_samples)]`` for a source node,
            otherwise one tensor per parent in ``ordered_parents`` order.
        """
        if self._is_source():
            return [torch.ones(self.number_of_samples)]

        tensor_list = []
        for parent in self._parents_datatype_dict:
            parent_dir = os.path.join(self.EXPERIMENT_DIR, f'{parent}')
            # TODO: rename load_roots - per the original author's note, what it
            # returns are sampled observations / interventional data, not roots.
            loaded = load_roots(parent_dir, rootfinder=self.rootfinder)
            # as_tensor accepts arrays and tensors alike and avoids the
            # copy-construct warning torch.tensor() emits for tensor input.
            tensor_list.append(torch.as_tensor(loaded))
        return tensor_list

    def _ordered_keys(self):
        """Return the parent names in the order given by ``ordered_parents``."""
        return list(self._parents_datatype_dict.keys())

    def __len__(self):
        """Return the number of samples, taken from the first data tensor."""
        return self.datatensors[0].shape[0]  # assuming same num_samples for all

    def __getitem__(self, idx):
        """Return the model inputs for sample ``idx`` as a tuple.

        Source nodes yield ``(tensor([1.0]),)``. For other nodes the tuple
        holds, in predictor order, a float entry for "cont" parents, a long
        entry for "ord" parents and an RGB image for "other" parents. A
        leading 0-dim ``1.0`` is added when no value of
        ``transformation_terms_in_h`` contains the character ``'i'``.
        """
        if self._is_source():
            return (torch.tensor([1.0], dtype=torch.float32),)

        x_data = []

        # NOTE(review): presumably 'i' marks an intercept term fed by a parent,
        # so the constant input is only needed when none exists - confirm.
        if all('i' not in str(term) for term in self.transformation_terms_in_h.values()):
            x_data.append(torch.tensor(1.0))

        for i, var in enumerate(self.predictors):
            data_type = self.target_nodes[var]['data_type']
            if data_type == "cont":
                x_data.append(self.datatensors[i][idx].unsqueeze(0))
            elif data_type == "ord":
                x_data.append(self.datatensors[i][idx].unsqueeze(0).long())
            elif data_type == "other":
                # NOTE(review): datatensors holds tensors, yet this element is
                # used as an image path - confirm what load_roots yields here.
                img_path = self.datatensors[i][idx]
                image = Image.open(img_path).convert("RGB")
                if self.transform:
                    image = self.transform(image)
                x_data.append(image)
            # NOTE(review): any other data_type is skipped silently, which
            # shortens the returned tuple - confirm this is intended.

        # Fixed shape here (quick and dirty, resolve later): scalar entries were
        # unsqueezed to (1,) above and are brought back to 0-dim.
        return tuple(self._squeeze_singleton(entry) for entry in x_data)