In [None]:
import torch
import numpy as np
from torch_geometric.data import Data


def filter_and_remap_edges(edge_index, unit_mask, n_units, n_attrs):
    """
    Filter and remap edges in bipartite graph (units <-> attributes).
    
    Args:
        edge_index: torch.Tensor of shape [2, E] 
        unit_mask: torch.Tensor of shape [n_units] with boolean values for units to keep
        n_units: Original number of units
        n_attrs: Number of attributes (unchanged)
    
    Returns:
        new_edge_index: torch.Tensor with filtered and remapped edges
    """
    # Create mapping for unit nodes (0 to n_units-1)
    unit_mapping = torch.cumsum(unit_mask, dim=0) - 1
    unit_mapping[~unit_mask] = -1
    
    # Create mapping for attribute nodes (n_units to n_units+n_attrs-1)
    # Attribute indices need to be shifted down by the number of removed units
    n_units_kept = unit_mask.sum().item()
    attr_mapping = torch.arange(n_attrs) + n_units_kept  # New attribute indices
    
    # Create full node mapping
    full_mapping = torch.full((n_units + n_attrs,), -1, dtype=torch.long)
    full_mapping[:n_units] = unit_mapping
    full_mapping[n_units:] = attr_mapping
    
    # Filter edges: keep only edges involving kept units
    unit_nodes_in_edges = edge_index[0] < n_units  # Source is unit
    attr_nodes_in_edges = edge_index[0] >= n_units  # Source is attribute
    
    # For unit->attr edges: keep if unit is in mask
    unit_to_attr = unit_nodes_in_edges & unit_mask[edge_index[0].clamp(max=n_units-1)]
    # For attr->unit edges: keep if target unit is in mask  
    attr_to_unit = attr_nodes_in_edges & unit_mask[edge_index[1].clamp(max=n_units-1)]
    
    edge_mask = unit_to_attr | attr_to_unit
    
    # Apply mapping to filtered edges
    filtered_edges = edge_index[:, edge_mask]
    new_edge_index = full_mapping[filtered_edges]
    
    return new_edge_index


def filter_dataset(data: Data, mask: torch.Tensor) -> Data:
    """
    Filter dataset to keep only units where mask[i] = True.
    Keeps ALL attribute nodes but updates their indices.
    
    Args:
        data: PyTorch Geometric Data object with bipartite structure
        mask: Boolean tensor of shape [n_units] for units to keep
    
    Returns:
        train_data: New Data object with filtered units and remapped structure
    """
    
    # Ensure mask is boolean and on CPU for numpy operations
    mask = mask.bool()
    mask_np = mask.cpu().numpy()
    n_units_kept = mask.sum().item()
    
    # Create new data object
    train_data = copy.deepcopy(data)
    
    # Copy metadata (n_attrs doesn't change)
    for attr in ['n_attrs', 'n_rel_types', 'node_feature_dim', 'edge_attr_dim']:
        if hasattr(data, attr):
            setattr(train_data, attr, getattr(data, attr))
    
    # Update n_units to filtered count
    # n_units, n_attrs = data.n_units, data.n_attrs
    train_data.n_units = n_units_kept

    
    # 1. Filter node features: keep filtered units + all attributes
    if hasattr(data, 'x') and data.x is not None:
        unit_features = data.x[:data.n_units][mask]  # Filter unit nodes
        attr_features = data.x[data.n_units:]        # Keep all attribute nodes
        train_data.x = torch.cat([unit_features, attr_features], dim=0)
    
    # 2. Filter is_unit: keep filtered units + all attributes  
    if hasattr(data, 'is_unit') and data.is_unit is not None:
        unit_is_unit = data.is_unit[:data.n_units][mask]  # Filter unit flags
        attr_is_unit = data.is_unit[data.n_units:]        # Keep attribute flags
        train_data.is_unit = torch.cat([unit_is_unit, attr_is_unit], dim=0)
    
    # 3. Filter unit-level tensors
    if hasattr(data, 'treatment') and data.treatment is not None:
        train_data.treatment = data.treatment[mask]
    
    if hasattr(data, 'outcome') and data.outcome is not None:
        train_data.outcome = data.outcome[mask]
    
    if hasattr(data, 'true_effect') and data.true_effect is not None:
        train_data.true_effect = data.true_effect[mask]
    
    # 4. Filter unit-level masks
    if hasattr(data, 'treatment_mask') and data.treatment_mask is not None:
        train_data.treatment_mask = data.treatment_mask[mask]
    
    if hasattr(data, 'outcome_mask') and data.outcome_mask is not None:
        train_data.outcome_mask = data.outcome_mask[mask]
    
    # 5. Filter split masks (train/val/test)
    if hasattr(data, 'train_mask') and data.train_mask is not None:
        train_data.train_mask = data.train_mask[mask]
    
    if hasattr(data, 'val_mask') and data.val_mask is not None:
        train_data.val_mask = data.val_mask[mask]
    
    if hasattr(data, 'test_mask') and data.test_mask is not None:
        train_data.test_mask = data.test_mask[mask]
    
    # 6. Filter observed_mask (reshape, filter, flatten)
    if hasattr(data, 'observed_mask') and data.observed_mask is not None:
        # Reshape to [n_units, n_attrs], filter units, then flatten
        observed_reshaped = data.observed_mask.view(data.n_units, data.n_attrs)
        observed_filtered = observed_reshaped[mask]  # Shape: [n_units_kept, n_attrs]
        train_data.observed_mask = observed_filtered.flatten()
    
    # 7. Filter and remap bipartite edges (unit <-> attribute)
    if hasattr(data, 'edge_index') and data.edge_index is not None:
        train_data.edge_index = filter_and_remap_edges(
            data.edge_index, mask, data.n_units, data.n_attrs
        )
        
        # Filter edge attributes based on which edges were kept
        if hasattr(data, 'edge_attr') and data.edge_attr is not None:
            # Need to determine which edges were kept
            unit_nodes_in_edges = data.edge_index[0] < data.n_units
            attr_nodes_in_edges = data.edge_index[0] >= data.n_units
            
            unit_to_attr = unit_nodes_in_edges & mask[data.edge_index[0].clamp(max=data.n_units-1)]
            attr_to_unit = attr_nodes_in_edges & mask[data.edge_index[1].clamp(max=data.n_units-1)]
            
            edge_mask = unit_to_attr | attr_to_unit
            train_data.edge_attr = data.edge_attr[edge_mask]
    
    # 8. Filter relational edges (unit <-> unit only)
    if hasattr(data, 'rel_edge_index') and data.rel_edge_index is not None:
        # Relational edges only involve units (indices 0 to n_units-1)
        rel_edge_mask = mask[data.rel_edge_index[0]] & mask[data.rel_edge_index[1]]
        filtered_rel_edges = data.rel_edge_index[:, rel_edge_mask]
        
        # Remap unit indices (0 to n_units_kept-1)
        unit_mapping = torch.cumsum(mask, dim=0) - 1
        train_data.rel_edge_index = unit_mapping[filtered_rel_edges]
        
        # Filter relational edge types
        if hasattr(data, 'rel_edge_type') and data.rel_edge_type is not None:
            train_data.rel_edge_type = data.rel_edge_type[rel_edge_mask]
    
    # 9. Filter adjacency matrix (for network baselines)
    # if hasattr(data, 'A') and data.A is not None:
    #     # Handle sparse tensor - only involves units
    #     if hasattr(data.A, 'to_dense'):
    #         adj_dense = data.A.to_dense().cpu().numpy()
    #     else:
    #         adj_dense = data.A.cpu().numpy()
        
    #     # Filter adjacency matrix (units only)
    #     filtered_adj = adj_dense[np.ix_(mask_np, mask_np)]
        
    #     # Convert back to same format as original
    #     if hasattr(data.A, 'to_dense'):
    #         indices = torch.nonzero(torch.tensor(filtered_adj)).t()
    #         values = torch.tensor(filtered_adj)[torch.nonzero(torch.tensor(filtered_adj), as_tuple=True)]
    #         train_data.A = torch.sparse_coo_tensor(
    #             indices=indices,
    #             values=values,
    #             size=filtered_adj.shape
    #         ).to(data.A.device)
    #     else:
    #         train_data.A = torch.tensor(filtered_adj, device=data.A.device, dtype=data.A.dtype)
    
    # 10. Filter tabular data (numpy arrays) - units only
    if hasattr(data, 'arr_X') and data.arr_X is not None:
        train_data.arr_X = data.arr_X[mask_np]
    
    if hasattr(data, 'arr_YF') and data.arr_YF is not None:
        train_data.arr_YF = data.arr_YF[mask_np]
    
    if hasattr(data, 'arr_Y1') and data.arr_Y1 is not None:
        train_data.arr_Y1 = data.arr_Y1[mask_np]
    
    if hasattr(data, 'arr_Y0') and data.arr_Y0 is not None:
        train_data.arr_Y0 = data.arr_Y0[mask_np]
    
    # Filter multi-dimensional adjacency matrix (units only)
    if hasattr(data, 'arr_Adj') and data.arr_Adj is not None:
        if len(data.arr_Adj.shape) == 2:  # Single adjacency matrix
            train_data.arr_Adj = data.arr_Adj[np.ix_(mask_np, mask_np)]
        elif len(data.arr_Adj.shape) == 3:  # Multi-relational adjacency matrices
            # Convert boolean mask to indices first:
            node_indices = np.where(mask_np)[0]
            # For 3D arrays [R, N, N], filter the last two dimensions:
            train_data.arr_Adj = data.arr_Adj[:, node_indices, :][:, :, node_indices]
            # train_data.arr_Adj = data.arr_Adj[:, np.ix_(mask_np, mask_np)]
    
    # # 11. Filter dataframes (units only)
    # if hasattr(data, 'df_full') and data.df_full is not None:
    #     train_data.df_full = data.df_full[mask_np].reset_index(drop=True)
    
    # if hasattr(data, 'df_miss') and data.df_miss is not None:
    #     train_data.df_miss = data.df_miss[mask_np].reset_index(drop=True)
    
    # if hasattr(data, 'df_imputed') and data.df_imputed is not None:
    #     train_data.df_imputed = data.df_imputed[mask_np].reset_index(drop=True)
    
    return train_data


# Convenience functions
def create_train_val_data(data: Data) -> Data:
    """Create dataset with only train and validation units."""
    mask = data.train_mask | data.val_mask
    return filter_dataset(data, mask)


def create_train_data_only(data: Data) -> Data:
    """Create dataset with only training units."""
    return filter_dataset(data, data.train_mask)


def create_test_data_only(data: Data) -> Data:
    """Create dataset with only test units."""
    return filter_dataset(data, data.test_mask)

import torch 


full_data_path = '/Users/jason/Documents/Coding Projects/2025_Claude/NetDeconf_main_hao/datasets/exps/BlogCatalog/p=0.0_k=9_seed=194.pt'
# full_data_path = 'datasets/exps/Syn/p=0.0_k=0_seed=919.pt'

data = torch.load(full_data_path, weights_only=False)
create_train_data_only(data)
create_train_val_data(data)
create_test_data_only(data)


Data(n_attrs=20, n_rel_types=1, node_feature_dim=20, edge_attr_dim=1, n_units=1039, x=[1059, 20], is_unit=[1059], treatment=[1039], outcome=[1039], true_effect=[1039], treatment_mask=[1039], outcome_mask=[1039], train_mask=[1039], val_mask=[1039], test_mask=[1039], observed_mask=[20780], edge_index=[2, 41560], edge_attr=[41560], rel_edge_index=[2, 13524], rel_edge_type=[13524], arr_X=[1039, 20], arr_YF=[1039], arr_Y1=[1039], arr_Y0=[1039], arr_Adj=[1, 1039, 1039])

In [3]:
import torch
import numpy as np
from torch_geometric.data import Data


def filter_and_remap_edges(edge_index, unit_mask, n_units, n_attrs):
    """
    Filter and remap edges in bipartite graph (units <-> attributes).
    
    Args:
        edge_index: torch.Tensor of shape [2, E] 
        unit_mask: torch.Tensor of shape [n_units] with boolean values for units to keep
        n_units: Original number of units
        n_attrs: Number of attributes (unchanged)
    
    Returns:
        new_edge_index: torch.Tensor with filtered and remapped edges
    """
    # Create mapping for unit nodes (0 to n_units-1)
    unit_mapping = torch.cumsum(unit_mask, dim=0) - 1
    unit_mapping[~unit_mask] = -1
    
    # Create mapping for attribute nodes (n_units to n_units+n_attrs-1)
    # Attribute indices need to be shifted down by the number of removed units
    n_units_kept = unit_mask.sum().item()
    attr_mapping = torch.arange(n_attrs) + n_units_kept  # New attribute indices
    
    # Create full node mapping
    full_mapping = torch.full((n_units + n_attrs,), -1, dtype=torch.long)
    full_mapping[:n_units] = unit_mapping
    full_mapping[n_units:] = attr_mapping
    
    # Filter edges: keep only edges involving kept units
    unit_nodes_in_edges = edge_index[0] < n_units  # Source is unit
    attr_nodes_in_edges = edge_index[0] >= n_units  # Source is attribute
    
    # For unit->attr edges: keep if unit is in mask
    unit_to_attr = unit_nodes_in_edges & unit_mask[edge_index[0].clamp(max=n_units-1)]
    # For attr->unit edges: keep if target unit is in mask  
    attr_to_unit = attr_nodes_in_edges & unit_mask[edge_index[1].clamp(max=n_units-1)]
    
    edge_mask = unit_to_attr | attr_to_unit
    
    # Apply mapping to filtered edges
    filtered_edges = edge_index[:, edge_mask]
    new_edge_index = full_mapping[filtered_edges]
    
    return new_edge_index


def test_edge_filtering_basic():
    """Test 1: Basic functionality with simple bipartite graph"""
    print("=== Test 1: Basic Edge Filtering ===")
    
    # Setup: 3 units, 2 attributes (total 5 nodes)
    n_units, n_attrs = 3, 2
    
    # Create bipartite edges: units (0,1,2) <-> attributes (3,4)
    edge_index = torch.tensor([
        [0, 1, 2, 3, 4, 3],  # sources: unit0->attr0, unit1->attr0, unit2->attr1, attr0->unit0, attr1->unit1, attr0->unit2
        [3, 3, 4, 0, 1, 2]   # targets
    ])
    
    # Edge attributes (values for each edge)
    edge_attr = torch.tensor([1.0, 2.0, 3.0, 1.0, 3.0, 2.0])  # Bidirectional edges have same values
    
    # Keep units 0 and 2 (remove unit 1)
    mask = torch.tensor([True, False, True])
    
    print(f"Original graph: {n_units} units + {n_attrs} attrs = {n_units + n_attrs} nodes")
    print(f"Original edges: {edge_index}")
    print(f"Original edge_attr: {edge_attr}")
    print(f"Unit mask: {mask} (keep units 0, 2)")
    
    # Apply filtering
    new_edge_index = filter_and_remap_edges(edge_index, mask, n_units, n_attrs)
    
    # Filter edge attributes using the same logic
    unit_nodes_in_edges = edge_index[0] < n_units
    attr_nodes_in_edges = edge_index[0] >= n_units
    
    unit_to_attr = unit_nodes_in_edges & mask[edge_index[0].clamp(max=n_units-1)]
    attr_to_unit = attr_nodes_in_edges & mask[edge_index[1].clamp(max=n_units-1)]
    
    edge_mask = unit_to_attr | attr_to_unit
    new_edge_attr = edge_attr[edge_mask]
    
    print(f"\nFiltered edges: {new_edge_index}")
    print(f"Filtered edge_attr: {new_edge_attr}")
    print(f"Edge mask: {edge_mask}")
    
    # Expected result:
    # - Keep edges involving units 0 and 2
    # - Unit 0 -> Unit 0 (remapped), Unit 2 -> Unit 1 (remapped)
    # - Attributes 3,4 -> 2,3 (remapped)
    expected_edges = torch.tensor([
        [0, 1, 2, 3, 2],  # unit0->attr0, unit2->attr1, attr0->unit0, attr1->unit2, attr0->unit1
        [2, 3, 0, 1, 1]
    ])
    
    print(f"Expected edges: {expected_edges}")
    print(f"Match: {torch.equal(new_edge_index, expected_edges)}")
    

def test_edge_filtering_no_edges():
    """Test 2: Edge case with no edges"""
    print("\n=== Test 2: No Edges Case ===")
    
    n_units, n_attrs = 2, 2
    edge_index = torch.zeros((2, 0), dtype=torch.long)  # No edges
    edge_attr = torch.zeros(0)
    mask = torch.tensor([True, False])
    
    print(f"Original: no edges, mask={mask}")
    
    new_edge_index = filter_and_remap_edges(edge_index, mask, n_units, n_attrs)
    
    unit_nodes_in_edges = edge_index[0] < n_units
    attr_nodes_in_edges = edge_index[0] >= n_units
    unit_to_attr = unit_nodes_in_edges & mask[edge_index[0].clamp(max=n_units-1)]
    attr_to_unit = attr_nodes_in_edges & mask[edge_index[1].clamp(max=n_units-1)]
    edge_mask = unit_to_attr | attr_to_unit
    new_edge_attr = edge_attr[edge_mask]
    
    print(f"Result: edges={new_edge_index}, attrs={new_edge_attr}")
    print(f"Shapes: edges={new_edge_index.shape}, attrs={new_edge_attr.shape}")


def test_edge_filtering_keep_all():
    """Test 3: Keep all units"""
    print("\n=== Test 3: Keep All Units ===")
    
    n_units, n_attrs = 2, 2
    edge_index = torch.tensor([
        [0, 1, 2, 3],  # unit0->attr0, unit1->attr1, attr0->unit0, attr1->unit1
        [2, 3, 0, 1]
    ])
    edge_attr = torch.tensor([5.0, 6.0, 5.0, 6.0])
    mask = torch.tensor([True, True])  # Keep all units
    
    print(f"Original: edges={edge_index}, mask={mask}")
    
    new_edge_index = filter_and_remap_edges(edge_index, mask, n_units, n_attrs)
    
    unit_nodes_in_edges = edge_index[0] < n_units
    attr_nodes_in_edges = edge_index[0] >= n_units
    unit_to_attr = unit_nodes_in_edges & mask[edge_index[0].clamp(max=n_units-1)]
    attr_to_unit = attr_nodes_in_edges & mask[edge_index[1].clamp(max=n_units-1)]
    edge_mask = unit_to_attr | attr_to_unit
    new_edge_attr = edge_attr[edge_mask]
    
    print(f"Result: edges={new_edge_index}, attrs={new_edge_attr}")
    
    # Should be identical (just attribute indices shifted)
    expected_edges = torch.tensor([
        [0, 1, 2, 3],  # Same structure but attrs now at indices 2,3
        [2, 3, 0, 1]
    ])
    print(f"Expected: {expected_edges}")
    print(f"Match: {torch.equal(new_edge_index, expected_edges)}")


def test_edge_filtering_complex():
    """Test 4: More complex graph with multiple connections"""
    print("\n=== Test 4: Complex Graph ===")
    
    n_units, n_attrs = 4, 3
    
    # Create a more complex bipartite graph
    # Units 0,1,2,3 connect to attributes 4,5,6
    edge_index = torch.tensor([
        [0, 0, 1, 2, 2, 3, 4, 5, 6, 5],  # Various unit->attr and attr->unit connections
        [4, 5, 4, 5, 6, 6, 0, 0, 2, 2]
    ])
    edge_attr = torch.tensor([1.1, 1.2, 2.1, 3.1, 3.2, 4.1, 1.1, 1.2, 3.2, 3.1])
    
    # Keep units 0, 2, 3 (remove unit 1)
    mask = torch.tensor([True, False, True, True])
    
    print(f"Original graph: {n_units} units + {n_attrs} attrs")
    print(f"Original edges: {edge_index}")
    print(f"Original edge_attr: {edge_attr}")
    print(f"Unit mask: {mask} (remove unit 1)")
    
    new_edge_index = filter_and_remap_edges(edge_index, mask, n_units, n_attrs)
    
    unit_nodes_in_edges = edge_index[0] < n_units
    attr_nodes_in_edges = edge_index[0] >= n_units
    unit_to_attr = unit_nodes_in_edges & mask[edge_index[0].clamp(max=n_units-1)]
    attr_to_unit = attr_nodes_in_edges & mask[edge_index[1].clamp(max=n_units-1)]
    edge_mask = unit_to_attr | attr_to_unit
    new_edge_attr = edge_attr[edge_mask]
    
    print(f"\nFiltered edges: {new_edge_index}")
    print(f"Filtered edge_attr: {new_edge_attr}")
    print(f"Edge mask: {edge_mask}")
    
    # Manual verification
    print("\nManual verification:")
    for i, keep in enumerate(edge_mask):
        if keep:
            orig_edge = edge_index[:, i]
            new_edge = new_edge_index[:, edge_mask[:i+1].sum()-1] if edge_mask[:i+1].sum() > 0 else None
            print(f"  Edge {i}: {orig_edge} -> {new_edge} (attr: {edge_attr[i]} -> {new_edge_attr[edge_mask[:i+1].sum()-1] if edge_mask[:i+1].sum() > 0 else 'N/A'})")


def test_edge_attr_consistency():
    """Test 5: Verify edge_attr filtering matches edge_index filtering"""
    print("\n=== Test 5: Edge Attribute Consistency ===")
    
    n_units, n_attrs = 3, 2
    edge_index = torch.tensor([
        [0, 1, 2, 3, 4],
        [3, 4, 3, 1, 2]
    ])
    edge_attr = torch.tensor([10.0, 20.0, 30.0, 10.0, 30.0])
    mask = torch.tensor([False, True, True])  # Keep units 1, 2
    
    print(f"Original: {edge_index.shape[1]} edges")
    print(f"Edge values: {edge_attr}")
    print(f"Mask: {mask}")
    
    new_edge_index = filter_and_remap_edges(edge_index, mask, n_units, n_attrs)
    
    # Apply the same filtering logic to edge attributes
    unit_nodes_in_edges = edge_index[0] < n_units
    attr_nodes_in_edges = edge_index[0] >= n_units
    unit_to_attr = unit_nodes_in_edges & mask[edge_index[0].clamp(max=n_units-1)]
    attr_to_unit = attr_nodes_in_edges & mask[edge_index[1].clamp(max=n_units-1)]
    edge_mask = unit_to_attr | attr_to_unit
    new_edge_attr = edge_attr[edge_mask]
    
    print(f"Filtered: {new_edge_index.shape[1]} edges")
    print(f"Filtered values: {new_edge_attr}")
    
    # Verify same number of edges and attributes
    assert new_edge_index.shape[1] == new_edge_attr.shape[0], \
        f"Mismatch: {new_edge_index.shape[1]} edges vs {new_edge_attr.shape[0]} attributes"
    
    print("âœ… Edge count matches attribute count")
    
    # Verify edge attribute values make sense
    print("\nEdge-by-edge verification:")
    for i in range(new_edge_index.shape[1]):
        edge = new_edge_index[:, i]
        attr_val = new_edge_attr[i]
        print(f"  Edge {i}: {edge[0]} -> {edge[1]}, value: {attr_val}")


def run_comprehensive_tests():
    """Run all tests"""
    print("ğŸ§ª Testing Edge Filtering and Edge Attribute Logic\n")
    
    test_edge_filtering_basic()
    test_edge_filtering_no_edges()
    test_edge_filtering_keep_all()
    test_edge_filtering_complex()
    test_edge_attr_consistency()
    
    print("\nâœ… All tests completed!")


# Example of how to use this in your actual filtering function
def example_usage():
    """Example showing how this integrates with your filter_dataset function"""
    print("\n=== Example Integration ===")
    
    # Simulate a data object
    class MockData:
        def __init__(self):
            self.n_units = 3
            self.n_attrs = 2
            self.edge_index = torch.tensor([
                [0, 1, 2, 3, 4],
                [3, 4, 3, 0, 1]
            ])
            self.edge_attr = torch.tensor([1.0, 2.0, 3.0, 1.0, 2.0])
    
    data = MockData()
    mask = torch.tensor([True, False, True])  # Keep units 0, 2
    
    print(f"Original data: {data.n_units} units, {data.n_attrs} attrs")
    print(f"Original edges: {data.edge_index}")
    print(f"Original edge_attr: {data.edge_attr}")
    
    # Apply the filtering logic from your function
    train_data = MockData()  # New filtered data object
    
    # Filter and remap edges
    train_data.edge_index = filter_and_remap_edges(
        data.edge_index, mask, data.n_units, data.n_attrs
    )
    
    # Filter edge attributes based on which edges were kept
    if hasattr(data, 'edge_attr') and data.edge_attr is not None:
        # Need to determine which edges were kept
        unit_nodes_in_edges = data.edge_index[0] < data.n_units
        attr_nodes_in_edges = data.edge_index[0] >= data.n_units
        
        unit_to_attr = unit_nodes_in_edges & mask[data.edge_index[0].clamp(max=data.n_units-1)]
        attr_to_unit = attr_nodes_in_edges & mask[data.edge_index[1].clamp(max=data.n_units-1)]
        
        edge_mask = unit_to_attr | attr_to_unit
        train_data.edge_attr = data.edge_attr[edge_mask]
    
    print(f"\nFiltered data:")
    print(f"Filtered edges: {train_data.edge_index}")
    print(f"Filtered edge_attr: {train_data.edge_attr}")
    print(f"Edge count consistency: {train_data.edge_index.shape[1] == train_data.edge_attr.shape[0]}")


if __name__ == "__main__":
    run_comprehensive_tests()
    example_usage()

ğŸ§ª Testing Edge Filtering and Edge Attribute Logic

=== Test 1: Basic Edge Filtering ===
Original graph: 3 units + 2 attrs = 5 nodes
Original edges: tensor([[0, 1, 2, 3, 4, 3],
        [3, 3, 4, 0, 1, 2]])
Original edge_attr: tensor([1., 2., 3., 1., 3., 2.])
Unit mask: tensor([ True, False,  True]) (keep units 0, 2)

Filtered edges: tensor([[0, 1, 2, 2],
        [2, 3, 0, 1]])
Filtered edge_attr: tensor([1., 3., 1., 2.])
Edge mask: tensor([ True, False,  True,  True, False,  True])
Expected edges: tensor([[0, 1, 2, 3, 2],
        [2, 3, 0, 1, 1]])
Match: False

=== Test 2: No Edges Case ===
Original: no edges, mask=tensor([ True, False])
Result: edges=tensor([], size=(2, 0), dtype=torch.int64), attrs=tensor([])
Shapes: edges=torch.Size([2, 0]), attrs=torch.Size([0])

=== Test 3: Keep All Units ===
Original: edges=tensor([[0, 1, 2, 3],
        [2, 3, 0, 1]]), mask=tensor([True, True])
Result: edges=tensor([[0, 1, 2, 3],
        [2, 3, 0, 1]]), attrs=tensor([5., 6., 5., 6.])
Expected: 