INITIAL MVP ALL IN HERE -> spread into proper files and implement properly

- currently only uses: LSTMConceptCorrector
- realignment (i.e., intervention updates) happens only within a cluster, more specifically:
  A custom “three-state mask” pipeline for concept realignment, with:
  0 = open (the LSTM can update/realign this concept),
  1 = permanently locked to ground truth (once an intervention happens -> replace by ground truth),
  2 = temporarily locked for the current iteration (but reverts to 0 in the next iteration if not permanently locked -> for out of cluster concepts)
- for now use synthetic data generation

TODO:

- get 2nd perspective on code
- make sure everything moved to GPU
- extend by other models for alignment (GRU and MLP and ...)
- spread code properly over multiple files and do not implement in notebook
- fit into whole pipeline, i.e. using real data and CBM model
- correct all variable and par names and make them in line with proposal/final text

MINOR NOTES:

- ensure that all clusters and labels have atleast one observation (synthetic data context)

GOOD TEXT CHUNKS FOR WRITTING (shorten and cut out certain parts):

## Concept Clusters and Implicit Propagation

### Concept Clusters

In machine learning models, especially in complex domains like medical diagnosis, **concept clusters** are essential. These clusters consist of related or interdependent concepts that collectively represent a broader phenomenon. For example, in diagnosing respiratory infections, concepts such as "fever," "cough," and "fatigue" may form a cluster. Clustering helps manage inherent dependencies among concepts, ensuring that interventions on one concept naturally influence others within the same cluster, thereby maintaining consistency and coherence in the model's predictions.

### Implicit Propagation via LSTM

To capture dependencies within concept clusters, **Long Short-Term Memory (LSTM)** networks are utilized for their ability to model **sequential dependencies** and **temporal patterns**. When an intervention targets a single concept, the LSTM effectively models the intricate relationships within the cluster. During training, intervening on a concept like "fever" influences related concepts such as "cough" and "fatigue," enabling the LSTM to **implicitly propagate** the intervention's effects across the entire cluster. This fosters a holistic and interconnected understanding of the underlying concepts.

## Pipeline Mechanism for Implicit Propagation

### Intervention on Individual Concepts

Central to the pipeline is the **intervene function**, which selects and applies interventions on individual concepts based on an **uncertainty-based policy**. For each sample in a batch, the function identifies the most critical concept—determined by its uncertainty—and sets it to its ground truth value, marking it as **permanently locked** (`mask=1`). Additionally, all other concepts within the same cluster are **temporarily locked** (`mask=2`) to prevent unintended modifications during the current intervention round. This targeted intervention provides a clear ground truth signal for the model to correct its predictions. Consequently, the LSTM observes these corrections and learns to adjust related concepts within the same cluster **implicitly**, enhancing the model's ability to refine its predictions systematically.

### LSTM Concept Corrector

The **LSTMConceptCorrector** is responsible for realigning and correcting concept vectors based on interventions. It processes sequences of concept vectors, capturing the **temporal evolution** of corrections across multiple intervention steps. The model inputs a combination of **locked** (intervened) and **open** (modifiable) concepts. Locked concepts remain unchanged, while open concepts are adjusted based on the LSTM's learned patterns. Temporarily locked concepts are held steady during the current intervention round but are reset to open (`mask=0`) after realignment, allowing for future interventions if necessary. Through successive intervention and correction cycles, the LSTM learns the dependencies within concept clusters, effectively modeling how changes in one concept influence others within the same group. This dynamic adjustment ensures coherence and accuracy in the model's concept representations.

### Sample Trajectory Function

The **sample trajectory function** simulates the iterative process of interventions and realignments. It orchestrates a sequence of intervention steps, each followed by a realignment phase handled by the LSTMConceptCorrector. By conducting multiple rounds of interventions, the function enables the LSTM to progressively refine its concept predictions based on prior corrections. This iterative approach enhances the model's ability to correct individual concepts and reinforces the **implicit propagation** of interventions across related concepts within clusters. Consequently, the model develops more accurate and interdependent concept representations, closely aligning with the underlying structure of the concept clusters.


In [None]:
# JUST A FEW THINGS TO BE AWARE OF:

# B - Batch size
# T - Number of time steps
# k - number of concepts

In [1]:
# IMPORTS


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [2]:
# SYNTHETIC DATA GENERATION


def generate_synthetic_data(k, n, J, m, seed):
    """
    k: number of concepts
    n: number of observations
    J: number of target classes
    m: number of concept clusters
    seed: random seed

    Returns:
      predicted_concepts: float in [0,1], shape (n, k)
      groundtruth_concepts: binary in {0,1}, shape (n, k)
      cluster_assignments: dict {cluster_id: [concept_indices]}
      labels: integer class label for each observation in [0..J-1], shape (n,)
    """
    torch.manual_seed(seed)

    # predicted concepts in [0,1]
    predicted_concepts = torch.rand(n, k)

    # ground truth concepts in {0,1}
    groundtruth_concepts = (torch.rand(n, k) > 0.5).float()

    # create random cluster assignment
    cluster_assignments = {cid: [] for cid in range(m)}
    for concept_idx in range(k):
        assigned_cluster = torch.randint(low=0, high=m, size=(1,)).item()
        cluster_assignments[assigned_cluster].append(concept_idx)

    # randomly assign labels in {0,...,J-1}
    labels = torch.randint(low=0, high=J, size=(n,))

    return predicted_concepts, groundtruth_concepts, cluster_assignments, labels

In [3]:
# INTERVENTION POLICY


def ucp(concepts, already_intervened_concepts):
    """
    Uncertainty-based Concept Picking (UCP) policy.

    Args:
        concepts (torch.Tensor): Current concept values, shape (B, k).
        already_intervened_concepts (torch.Tensor): Mask indicating interventions, shape (B, k).
                                                  1 => permanently locked, 0 => open, 2 => temporarily locked.

    Returns:
        importances (torch.Tensor): Importance scores for each concept, shape (B, k).
    """
    eps = 1e-8
    # Importance inversely proportional to distance from 0.5
    importances = 1.0 / (torch.abs(concepts - 0.5) + eps)

    # Exclude permanently and temporarily locked concepts by setting their importance to a large negative value
    importances[(already_intervened_concepts == 1) | (already_intervened_concepts == 2)] = -1e10

    return importances

In [4]:
# INTERVENTION FUNCTION


def intervene(concepts, estimated_concepts, concept_to_cluster, already_intervened_concepts, groundtruth_concepts, 
             intervention_policy=ucp):
    """
    Applies the intervention policy to select and intervene on individual concepts.

    Args:
        concepts (torch.Tensor): Current concept values, shape (B, k).
        estimated_concepts (torch.Tensor): Estimated concept predictions, shape (B, k).
        concept_to_cluster (list): List mapping each concept to its cluster, length k.
        already_intervened_concepts (torch.Tensor): Mask indicating interventions, shape (B, k).
                                                  1 => permanently locked, 0 => open, 2 => temporarily locked.
        groundtruth_concepts (torch.Tensor): Ground truth concept values, shape (B, k).
        intervention_policy (function): Function to compute importances.

    Returns:
        concepts_new (torch.Tensor): Updated concept values after intervention, shape (B, k).
        intervened_concepts_new (torch.Tensor): Updated mask, shape (B, k).
    """
    # Compute importances using the chosen policy
    importances = intervention_policy(concepts, already_intervened_concepts)
    
    B, k = concepts.shape

    # Select the most important concept to intervene on for each sample
    concepts_to_intervene = torch.argmax(importances, dim=1)  # shape: (B,)

    # Replace selected concepts with ground truth values
    concepts_new = concepts.clone()
    concepts_new[range(B), concepts_to_intervene] = groundtruth_concepts[range(B), concepts_to_intervene]

    # Update the intervention mask to permanently lock the intervened concepts
    intervened_concepts_new = already_intervened_concepts.clone()
    intervened_concepts_new[range(B), concepts_to_intervene] = 1

    # Temporarily lock other concepts in the same cluster
    for b in range(B):
        selected_concept = concepts_to_intervene[b].item()
        cluster_id = concept_to_cluster[selected_concept]
        cluster_concepts = [c for c in range(k) if concept_to_cluster[c] == cluster_id and c != selected_concept]
        if cluster_concepts:
            intervened_concepts_new[b, cluster_concepts] = 2  # Temporarily locked

    return concepts_new, intervened_concepts_new

In [5]:
# LSTMConceptCorrector WITH 3-STATE MASK


class LSTMConceptCorrector(nn.Module):
    """
    An LSTM-based model that realigns concept vectors based on interventions.

    Mask values:
        0 => open (LSTM can adjust this concept)
        1 => permanently locked to ground truth (once intervened) -> entered ground truth value
        2 => temporarily locked (cannot be changed in the current round) -> for concepts outside of intervention cluster
    """
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        """
        Initializes the LSTMConceptCorrector.

        Args:
            input_size (int): Number of input features (concepts).
            hidden_size (int): Number of features in the hidden state of the LSTM.
            num_layers (int): Number of recurrent layers in the LSTM.
            output_size (int): Number of output features (concepts).
        """
        super(LSTMConceptCorrector, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Define the LSTM layer
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

        # Define a fully connected layer to map LSTM outputs to concept space
        self.fc = nn.Linear(hidden_size, output_size)

    def prepare_initial_hidden(self, batch_size, device):
        """
        Prepares the initial hidden and cell states for the LSTM.

        Args:
            batch_size (int): Number of samples in the batch.
            device (torch.device): Device to place the hidden states.

        Returns:
            tuple: (h0, c0) initial hidden and cell states.
        """
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
        return (h0, c0)

    def forward(self, inputs, mask, estimated_concepts, hidden):
        """
        Forward pass of the LSTMConceptCorrector.

        Args:
            inputs (torch.Tensor): Current concept values, shape (B, T, k).
            mask (torch.Tensor): Mask indicating interventions, shape (B, T, k).
                                 0 => open, 1 => permanently locked, 2 => temporarily locked.
            estimated_concepts (torch.Tensor): Estimated concept predictions, shape (B, T, k).
            hidden (tuple): Initial hidden and cell states for the LSTM.

        Returns:
            torch.Tensor: Updated concept values after realignment, shape (B, T, k).
            tuple: Updated hidden and cell states.
        """
        # Define which concepts are open, permanently locked, and temporarily locked
        mask_open = (mask == 0).float()          # 1.0 where open, 0.0 otherwise
        mask_perma_locked = (mask == 1).float()  # 1.0 where permanently locked
        mask_temp_locked = (mask == 2).float()   # 1.0 where temporarily locked

        # Create input for LSTM:
        # - For permanently locked concepts (mask=1), use the current inputs (ground truth).
        # - For temporarily locked (mask=2) and open concepts (mask=0), use the estimated predictions.
        x = mask_perma_locked * inputs + (mask_temp_locked + mask_open) * estimated_concepts  # Shape: (B, T, k)

        # Pass through LSTM
        lstm_out, hidden = self.lstm(x, hidden)  # lstm_out: (B, T, hidden_size)

        # Map LSTM outputs to concept space and apply sigmoid activation
        corrected_raw = torch.sigmoid(self.fc(lstm_out))  # Shape: (B, T, k)

        # Combine corrected concepts with locked and temporarily locked concepts:
        # - Keep permanently locked concepts as-is
        # - Keep temporarily locked concepts as estimated predictions
        # - Update open concepts with the LSTM's corrections
        output = mask_perma_locked * inputs + mask_temp_locked * estimated_concepts + mask_open * corrected_raw  # Shape: (B, T, k)

        return output, hidden

    def forward_single_timestep(self, inputs, mask, estimated_concepts, hidden):
        """
        Forward pass for a single time step.

        Args:
            inputs (torch.Tensor): Current concept values, shape (B, k).
            mask (torch.Tensor): Mask indicating interventions, shape (B, k).
                                 0 => open, 1 => permanently locked, 2 => temporarily locked.
            estimated_concepts (torch.Tensor): Estimated concept predictions, shape (B, k).
            hidden (tuple): Initial hidden and cell states for the LSTM.

        Returns:
            torch.Tensor: Updated concept values after realignment, shape (B, k).
            tuple: Updated hidden and cell states.
        """
        # Add a time dimension of 1 to match the LSTM's expected input shape
        inputs_ = inputs.unsqueeze(1)          # Shape: (B, 1, k)
        mask_ = mask.unsqueeze(1)              # Shape: (B, 1, k)
        est_ = estimated_concepts.unsqueeze(1) # Shape: (B, 1, k)

        # Forward pass through the LSTMConceptCorrector
        out, hidden = self.forward(inputs_, mask_, est_, hidden)  # out: (B, 1, k)

        # Remove the time dimension
        out = out.squeeze(1)  # Shape: (B, k)

        return out, hidden

In [8]:
# SAMPLE TRAJECTORY FUNCTION


def sample_trajectory(model, predicted_concepts, groundtruth_concepts, 
                     concept_to_cluster, max_interventions=3, intervention_policy=ucp):
    """
    Simulates multiple rounds of interventions and realignments.

    Args:
        model (nn.Module): The concept corrector model.
        predicted_concepts (torch.Tensor): Predicted concepts, shape (B, k).
        groundtruth_concepts (torch.Tensor): Ground truth concepts, shape (B, k).
        concept_to_cluster (list): List mapping each concept to its cluster, length k.
        max_interventions (int): Number of intervention steps to perform.
        intervention_policy (function): Function to compute importances.

    Returns:
        list of torch.Tensor: Concept vectors at each intervention step, length (max_interventions + 1),
                              each tensor of shape (B, k).
    """
    device = predicted_concepts.device
    B, k = predicted_concepts.shape

    # Initialize masks: 1 => permanently locked, 0 => open, 2 => temporarily locked
    already_intervened_concepts = torch.zeros(B, k).to(device)

    # Clone predicted concepts to start
    current_concepts = predicted_concepts.clone()

    # Prepare initial hidden state
    hidden = model.prepare_initial_hidden(B, device)

    # Store all steps for analysis
    all_steps = [current_concepts.clone()]

    for step in range(max_interventions):
        # Apply intervention
        concepts_new, intervened_concepts_new = intervene(
            concepts=current_concepts,
            estimated_concepts=predicted_concepts,
            concept_to_cluster=concept_to_cluster,
            already_intervened_concepts=already_intervened_concepts,
            groundtruth_concepts=groundtruth_concepts,
            intervention_policy=intervention_policy
        )

        # Update concepts and masks
        current_concepts = concepts_new
        already_intervened_concepts = intervened_concepts_new

        # Realign with the model
        corrected_concepts, hidden = model.forward_single_timestep(
            inputs=current_concepts,
            mask=already_intervened_concepts.float(),  # Convert mask to float for the model
            estimated_concepts=predicted_concepts,
            hidden=hidden
        )

        # Update current concepts with corrected values
        current_concepts = corrected_concepts

        # Reset temporary locks (mask=2) back to open (mask=0)
        already_intervened_concepts = torch.where(already_intervened_concepts == 2, 
                                                 torch.zeros_like(already_intervened_concepts), 
                                                 already_intervened_concepts)

        # Store the step
        all_steps.append(current_concepts.clone())

    return all_steps

In [9]:
# TRAINING LOOP


def main():

    # hyperpar synthetic data
    k = 6           # Number of concepts
    n = 100         # Number of observations
    J = 3           # Number of target classes
    m = 2           # Number of concept clusters
    seed = 42       # Random seed for reproducibility

    # Generate synthetic data
    predicted_concepts, groundtruth_concepts, cluster_assignments, labels = generate_synthetic_data(
        k=k,
        n=n,
        J=J,
        m=m,
        seed=seed
    )

    # Print Cluster Assignments for Reference
    print("=== Cluster Assignments ===")
    for cid, c_list in cluster_assignments.items():
        print(f"Cluster {cid}: {c_list}")
    print()

    # Prepare concept_to_cluster list
    concept_to_cluster = [0] * k  # Initialize list
    for cid, c_list in cluster_assignments.items():
        for c in c_list:
            concept_to_cluster[c] = cid

    # Create dataset and dataloader
    dataset = TensorDataset(predicted_concepts, groundtruth_concepts, labels)
    batch_size = 16
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Initialize the LSTMConceptCorrector model
    hidden_size = 16
    num_layers = 1
    output_size = k  # Same as number of concepts
    model = LSTMConceptCorrector(input_size=k, hidden_size=hidden_size, num_layers=num_layers, output_size=output_size)
    device = torch.device("cpu")  # Change to "cuda" if GPU is available
    model = model.to(device)

    # Move data to device
    predicted_concepts = predicted_concepts.to(device)
    groundtruth_concepts = groundtruth_concepts.to(device)

    # Define loss criterion and optimizer
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Training hyperparameters
    epochs = 10
    max_interventions = 3  # Number of intervention steps per batch

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for batch_idx, (pred_c, gt_c, lbls) in enumerate(dataloader):
            # Move batch data to device
            pred_c = pred_c.to(device)  # Shape: (B, k)
            gt_c = gt_c.to(device)      # Shape: (B, k)
            lbls = lbls.to(device)      # Shape: (B,)

            optimizer.zero_grad()

            # Perform multiple intervention steps and realignments
            all_steps = sample_trajectory(
                model=model,
                predicted_concepts=pred_c,
                groundtruth_concepts=gt_c,
                concept_to_cluster=concept_to_cluster,
                max_interventions=max_interventions,
                intervention_policy=ucp
            )

            # Use the final corrected concepts for loss computation
            final_corrected_concepts = all_steps[-1]  # Shape: (B, k)

            # Compute loss against ground truth concepts
            loss = criterion(final_corrected_concepts, gt_c)

            # Backpropagation and optimization step
            loss.backward()
            optimizer.step()

            # Accumulate loss
            total_loss += loss.item()

        # Compute average loss for the epoch
        average_loss = total_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{epochs}] | Average Loss: {average_loss:.4f}")

    print("\nTraining complete!")

    # (Optional) Inspect final cluster assignments and model predictions
    print("\n=== Final Cluster Assignments ===")
    for cid, c_list in cluster_assignments.items():
        print(f"Cluster {cid}: {c_list}")

    # (Optional) Print some example corrected concepts
    print("\n=== Example Corrected Concepts ===")
    example_steps = sample_trajectory(
        model=model,
        predicted_concepts=predicted_concepts[:5],      # Take first 5 samples
        groundtruth_concepts=groundtruth_concepts[:5],
        concept_to_cluster=concept_to_cluster,
        max_interventions=max_interventions,
        intervention_policy=ucp
    )

    for step_idx, cvec in enumerate(example_steps):
        print(f"Step {step_idx}:")
        print(cvec)
        print()

In [10]:
# RUN MAIN FUNCTION


if __name__ == "__main__":
    main()

=== Cluster Assignments ===
Cluster 0: [0, 4]
Cluster 1: [1, 2, 3, 5]

Epoch [1/10] | Average Loss: 0.5128
Epoch [2/10] | Average Loss: 0.5079
Epoch [3/10] | Average Loss: 0.5263
Epoch [4/10] | Average Loss: 0.5074
Epoch [5/10] | Average Loss: 0.5125
Epoch [6/10] | Average Loss: 0.5175
Epoch [7/10] | Average Loss: 0.5127
Epoch [8/10] | Average Loss: 0.5187
Epoch [9/10] | Average Loss: 0.5218
Epoch [10/10] | Average Loss: 0.5058

Training complete!

=== Final Cluster Assignments ===
Cluster 0: [0, 4]
Cluster 1: [1, 2, 3, 5]

=== Example Corrected Concepts ===
Step 0:
tensor([[0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009],
        [0.2566, 0.7936, 0.9408, 0.1332, 0.9346, 0.5936],
        [0.8694, 0.5677, 0.7411, 0.4294, 0.8854, 0.5739],
        [0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317],
        [0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062]])

Step 1:
tensor([[0.4813, 0.9150, 0.3829, 0.9593, 0.4593, 0.0000],
        [0.4818, 0.7936, 0.9408, 0.1332, 0.4586, 0.0000],
        [0.