In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import pickle
from tqdm import tqdm
import random
import numpy as np

# Use the first GPU when CUDA is available; otherwise fall back to the CPU so every
# `.to(device)` call below still works on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [2]:
#Prepare the random seeds

def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's `random` module, NumPy's global RNG, and PyTorch's generators.
    It also configures cuDNN to choose deterministic kernels with auto-tuning disabled.

    Args:
        seed (int): Seed applied to every generator.
    """
    # Python's built-in `random` module (used e.g. to shuffle the DAS training data)
    random.seed(seed)

    # NumPy's legacy global RNG
    np.random.seed(seed)

    # PyTorch: torch.manual_seed seeds the default generator on all devices
    # (torch.random.manual_seed is the same function, so no second call is needed).
    torch.manual_seed(seed)

    # Explicitly seed every visible GPU as well (covers multi-GPU setups).
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Ask cuDNN for deterministic kernels and disable benchmark-based kernel selection,
    # which can otherwise pick different algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Set the seed
set_seed(53)

In [3]:
#Classification model

class MLPBlock(nn.Module):
    """One hidden layer: Linear -> ReLU -> Dropout."""

    def __init__(self, in_features=16, out_features=16, dropout_prob=0.0):
        """
        Args:
            in_features (int): Width of the block input.
            out_features (int): Width of the block output.
            dropout_prob (float): Dropout probability applied after the activation.
        """
        super().__init__()
        self.ff1 = nn.Linear(in_features, out_features)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x):
        x = self.ff1(x)
        x = self.act(x)
        x = self.dropout(x)
        return x


class MLPModel(nn.Module):
    """Input dropout followed by a stack of `num_blocks` MLPBlocks.

    The first block maps `input_size` -> `hidden_size`; every later block maps
    `hidden_size` -> `hidden_size`. With `num_blocks == 0` the input is only passed
    through the dropout, so the output width is then `input_size`.
    """

    def __init__(self, input_size=16, hidden_size=16, num_blocks=3, dropout_prob=0.0):
        """
        Args:
            input_size (int): Width of the raw input features.
            hidden_size (int): Width of every block's output.
            num_blocks (int): Number of stacked MLPBlocks.
            dropout_prob (float): Dropout probability (input dropout and per block).
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_prob)
        # Only the first block sees the raw input, so only it is sized by `input_size`.
        self.h = nn.ModuleList(
            [
                MLPBlock(input_size if idx == 0 else hidden_size, hidden_size, dropout_prob)
                for idx in range(num_blocks)
            ]
        )

    def forward(self, x):
        x = self.dropout(x)
        for layer in self.h:
            x = layer(x)
        return x


class MLPForClassification(nn.Module):
    """MLPModel backbone followed by a linear head that produces raw class logits."""

    def __init__(self, input_size=16, hidden_size=16, num_classes=2, num_blocks=3, dropout_prob=0.0):
        """
        Args:
            input_size (int): Width of the raw input features.
            hidden_size (int): Width of the hidden representation.
            num_classes (int): Number of output logits.
            num_blocks (int): Number of MLPBlocks in the backbone.
            dropout_prob (float): Dropout probability used in the backbone.
        """
        super().__init__()
        self.mlp = MLPModel(input_size, hidden_size, num_blocks, dropout_prob)
        self.score = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.mlp(x)
        x = self.score(x)
        return x

# Build the classifier directly on the target device (`.to` returns the module itself);
# the bare name on the last line displays the architecture as the cell output.
model = MLPForClassification().to(device)
model

MLPForClassification(
  (mlp): MLPModel(
    (dropout): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-2): 3 x MLPBlock(
        (ff1): Linear(in_features=16, out_features=16, bias=True)
        (act): ReLU()
        (dropout): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (score): Linear(in_features=16, out_features=2, bias=True)
)

In [4]:
# Training of the classification model

# Load Training Dataset:
# Load preprocessed training features (X_train) and labels (y_train) from pickle files.
# NOTE: pickle.load can execute arbitrary code -- only load files from a trusted source.
with open("X_train.pkl", "rb") as f:
    X_train = pickle.load(f)

with open("y_train.pkl", "rb") as f:
    y_train = pickle.load(f)

# Move data to the specified device (CPU/GPU) for training
X_train = X_train.to(device)
y_train = y_train.to(device)

# Create DataLoader:
batch_size = 1024  # Number of samples per batch
epochs = 3  # Number of passes over the entire dataset

# Create a PyTorch dataset and DataLoader for batch processing
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)  # Shuffle data for better generalization

# Initialize loss function and optimizer:
criterion = nn.CrossEntropyLoss()  # Loss on raw logits for multi-class classification
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam optimizer with learning rate of 0.001

# Training Loop:
model.train()  # Set model to training mode
for epoch in range(epochs):
    total_loss = 0  # Track total loss for the epoch

    # Iterate over training batches
    for X_batch, y_batch in tqdm(train_loader):
        optimizer.zero_grad()  # Reset gradients before each batch

        outputs = model(X_batch)  # Forward pass: compute logits
        # reshape(-1) always yields a 1-D target vector. squeeze() would turn a final
        # batch of a single sample into a 0-d tensor that does not match (1, C) logits.
        # (Assumes y_batch is (B,) or (B, 1) -- TODO confirm against the pickled labels.)
        loss = criterion(outputs, y_batch.reshape(-1).long())  # Compute loss

        loss.backward()  # Backpropagation: compute gradients
        optimizer.step()  # Update model parameters

        total_loss += loss.item()  # Accumulate loss

    # Print average loss for the epoch
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}")


100%|██████████████████████████████████████| 1024/1024 [00:09<00:00, 106.38it/s]


Epoch 1, Loss: 0.34166701680078404


100%|██████████████████████████████████████| 1024/1024 [00:09<00:00, 109.45it/s]


Epoch 2, Loss: 0.018978380689986807


100%|██████████████████████████████████████| 1024/1024 [00:09<00:00, 109.35it/s]

Epoch 3, Loss: 0.004886364653827968





In [5]:
# Testing the Classification Model:

# Load the held-out features (X_test) and labels (y_test) from their pickle files.
with open("X_test.pkl", "rb") as f:
    X_test = pickle.load(f)

with open("y_test.pkl", "rb") as f:
    y_test = pickle.load(f)

# Put the evaluation tensors on the same device as the model.
X_test, y_test = X_test.to(device), y_test.to(device)

# Evaluation keeps a fixed order, so no shuffling; `batch_size` comes from the training cell.
test_dataset = TensorDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

model.eval()  # Inference behavior for modules such as Dropout

correct, total = 0, 0  # Running counts of correct predictions and samples seen

with torch.no_grad():  # Inference only: skip autograd bookkeeping
    for X_batch, y_batch in tqdm(test_loader):
        outputs = model(X_batch)  # Raw logits
        predicted = outputs.argmax(dim=1)  # Index of the largest logit in each row
        correct += (predicted == y_batch.squeeze()).sum().item()
        total += y_batch.size(0)

# Fraction of test samples classified correctly
accuracy = correct / total
print(f"Test Accuracy: {accuracy:.4f}")


100%|███████████████████████████████████████████| 10/10 [00:00<00:00, 90.21it/s]

Test Accuracy: 0.9988





In [6]:
# Prepare the transformation model.
# This transformation is called "phi" in our current paper and "rotation" in the
# original DAS paper. At this stage it is simply a rotation (orthogonal) matrix.

# RotateLayer is adapted from the RotateLayer class of the DAS library.

class RotateLayer(torch.nn.Module):
    """A learnable square linear map, optionally initialized as an orthogonal matrix."""

    def __init__(self, n, init_orth=True):
        """
        Args:
            n (int): Side length of the square weight matrix.
            init_orth (bool): If True, fill the weight with an orthogonal initialization.
                Otherwise the weight is left uninitialized (e.g. when it will be loaded
                from a pretrained checkpoint).
        """
        super().__init__()
        matrix = torch.empty(n, n)
        if init_orth:
            torch.nn.init.orthogonal_(matrix)
        self.weight = torch.nn.Parameter(matrix, requires_grad=True)

    def forward(self, x):
        """Right-multiply `x` by the weight, after casting `x` to the weight's dtype."""
        x = x.to(self.weight.dtype)
        return x @ self.weight


class Transformation_Function(nn.Module):
    """phi: a RotateLayer whose weight is constrained to remain orthogonal."""

    def __init__(self, embed_dim=16):
        """
        Args:
            embed_dim (int): Width of the activations being rotated.
        """
        super().__init__()
        # The orthogonal parametrization keeps `weight` orthogonal throughout training,
        # which is what lets the inverse below be a plain transpose.
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(RotateLayer(embed_dim))

    def forward(self, x):
        """Rotate `x` into phi's space."""
        return self.rotate_layer(x)


class InverseTransformation_Function(nn.Module):
    """phi^-1: undoes a Transformation_Function via the transpose of its orthogonal weight."""

    def __init__(self, transformation_function: Transformation_Function):
        """
        Args:
            transformation_function (Transformation_Function): The phi to invert. It is held
                as a submodule, so this module always uses phi's current weight.
        """
        super().__init__()
        self.transformation_function = transformation_function

    def forward(self, x):
        """Map rotated activations back: for orthogonal W, x @ W @ W.T == x."""
        inverse = self.transformation_function.rotate_layer.weight.T
        return x.to(inverse.dtype) @ inverse


# Build phi and its inverse on the target device (`.to` returns the module itself).
# phi_inverse holds phi as a submodule, so both always use the same orthogonal weight.
phi = Transformation_Function(16).to(device)
phi_inverse = InverseTransformation_Function(phi).to(device)

# Loss and optimizer for training phi: only phi's parameters are handed to Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(phi.parameters(), lr=0.001)


In [7]:
# Hyperparameters used for DAS with this classification model

# Dimensionality of the hidden representation in the analyzed model
activation_dimension = 16

# Which indices of the rotated (phi) space belong to each feature:
#   feature 0 -> dimensions 0..7, feature 1 -> dimensions 8..15
dim_per_feature = [list(range(start, start + 8)) for start in (0, 8)]


In [8]:
# The DAS datasets used below are generated with the DAS paper's codebase:
# Reference: https://github.com/stanfordnlp/pyvene/blob/main/tutorials/advanced_tutorials/DAS_Main_Introduction.ipynb
# (See also the second uploaded file.) The helpers in this cell convert those training
# and testing datasets into the tensor format used by this notebook.

def chunk_list(input_list, batch_size):
    """
    Split a sequence into consecutive slices of at most `batch_size` elements.

    Args:
        input_list (list): The sequence to split.
        batch_size (int): Maximum length of each slice; only the last may be shorter.

    Returns:
        list: The slices, in their original order.
    """
    chunks = []
    for start in range(0, len(input_list), batch_size):
        chunks.append(input_list[start:start + batch_size])
    return chunks




def extract_base(input_list):
    """
    Stack the base inputs (`input_ids`) of a batch into one tensor on `device`.

    Args:
        input_list (list): Samples, each indexable by "input_ids" and holding a tensor.

    Returns:
        torch.Tensor: The `input_ids` stacked along a new leading batch axis.
    """
    base_ids = [sample["input_ids"] for sample in input_list]
    return torch.stack(base_ids).to(device)

def extract_sources(input_list):
    """
    Group the source inputs of a batch by position.

    Position j is later used to fill the rotated dimensions of feature j
    (`dim_per_feature[j]`). For `intervention_id` 0 or 2, source j goes to position j.
    For `intervention_id` 1 the two sources are swapped (source j goes to position 1 - j),
    so that source 0 lands in position 1.

    Args:
        input_list (list): Non-empty list of samples, each indexable by
            "source_input_ids" (a stack of sources along its first axis) and
            "intervention_id" (0, 1 or 2).

    Returns:
        list of torch.Tensor: One stacked tensor per position, moved to `device`.

    Raises:
        ValueError: If `input_list` is empty, a sample has an unsupported
            `intervention_id`, or an id-1 sample does not have exactly two sources.
    """
    if len(input_list) == 0:
        raise ValueError("extract_sources() needs at least one sample")

    res = [[] for _ in range(input_list[0]['source_input_ids'].shape[0])]  # One list per position

    for i in input_list:
        sources = i['source_input_ids']
        intervention_id = i['intervention_id']
        if intervention_id in [0, 2]:
            for j in range(sources.shape[0]):
                res[j].append(sources[j])
        elif intervention_id in [1]:
            # The 1 - j swap is only meaningful for exactly two sources; with more,
            # negative indices would silently put sources in the wrong position.
            if sources.shape[0] != 2:
                raise ValueError(
                    f"intervention_id 1 expects exactly 2 sources, got {sources.shape[0]}"
                )
            for j in range(sources.shape[0]):
                res[1 - j].append(sources[j])
        else:
            # Previously this printed "[ERROR]" and called exit(), which in a Jupyter
            # kernel neither stops the loop nor reports which sample was bad.
            raise ValueError(
                f"Unsupported intervention_id {intervention_id!r}; expected 0, 1 or 2"
            )

    for i in range(len(res)):  # Convert lists to tensors and move to device
        res[i] = torch.stack(res[i])
        res[i] = res[i].to(device)

    return res

def extract_labels(input_list):
    """
    Stack the post-intervention target labels of a batch into one tensor on `device`.

    Args:
        input_list (list): Samples, each indexable by "labels" and holding a tensor.

    Returns:
        torch.Tensor: The labels stacked along a new leading batch axis.
    """
    stacked = torch.stack([sample["labels"] for sample in input_list])
    return stacked.to(device)


def prepare_intervention_matrix(input_list):
    """
    Build the mask that selects, per sample and per rotated dimension, base or source values.

    Row k corresponds to sample k and column d to rotated dimension d:
    - a dimension in `dim_per_feature[0]` is True when the sample's `intervention_id` is 0 or 2;
    - a dimension in `dim_per_feature[1]` is True when the id is 1 or 2;
    - any other dimension is always False (keep the base value).

    Args:
        input_list (list): Samples, each indexable by "intervention_id".

    Returns:
        torch.Tensor: Boolean tensor of shape (len(input_list), activation_dimension)
            on `device`. An empty `input_list` yields `torch.tensor([])` moved to `device`.
    """
    first_dims = dim_per_feature[0]
    second_dims = dim_per_feature[1]

    mask_rows = []
    for sample in input_list:
        intervention_id = sample["intervention_id"]
        patch_first = intervention_id in [0, 2]
        patch_second = intervention_id in [1, 2]

        row = []
        for dim in range(activation_dimension):
            if dim in first_dims:
                row.append(patch_first)
            elif dim in second_dims:
                row.append(patch_second)
            else:
                row.append(False)
        mask_rows.append(row)

    return torch.tensor(mask_rows).to(device)

In [9]:
# This is the hook function responsible for performing interventions on activations.

mode_info = None  # None (no intervention), ["source", feature_index] or ["intervene", bool_mask].
source_activations = None  # A tensor of shape (batch_size, activation_dimension).
# This tensor stores the rotated activations of the source inputs at the positions where
# intervention is supposed to occur in the transformed space.

def hook_fn(module, input, output):
    """
    Forward hook that records or patches rotated activations, depending on `mode_info`.

    Behavior:
    - `mode_info is None`: no intervention is configured; the output passes through
      unchanged. This lets the plain classifier run while the hook is registered.
    - `mode_info[0] == "source"`: rotate `output` with phi and copy the dimensions of
      feature `mode_info[1]` (`dim_per_feature[mode_info[1]]`) into the global
      `source_activations`. The output itself is left unchanged.
    - `mode_info[0] == "intervene"`: rotate `output` with phi, take `source_activations`
      wherever the boolean mask `mode_info[1]` is True, and return the result rotated
      back with phi_inverse.
    - Any other mode: the output passes through unchanged.

    Args:
        module (nn.Module): The module where the hook is attached.
        input (tuple): The input to the module.
        output (torch.Tensor): The output activations from the module.

    Returns:
        torch.Tensor in "intervene" mode; None otherwise (PyTorch then keeps the original output).
    """
    global mode_info
    global source_activations

    if mode_info is None:
        return None

    if mode_info[0] == "source":
        # Store only the rotated dimensions assigned to feature mode_info[1]; other columns
        # are filled by the other source passes or keep their initial values.
        # output is detached because only phi is trained; the classifier stays frozen.
        feature_dims = dim_per_feature[mode_info[1]]
        source_activations[:, feature_dims] = phi(output.detach())[:, feature_dims]

    elif mode_info[0] == "intervene":
        # Rotated activations of the base input
        rotated_base = phi(output.detach())

        # Swap in the source activations wherever the mask is True
        patched = torch.where(mode_info[1], source_activations, rotated_base)

        # Rotate back to the original space; returning a tensor replaces the module output
        return phi_inverse(patched)

    return None

# Retrieve the first MLP layer from the model
first_layer = model.mlp.h[0].ff1  # Access the first MLPBlock's feed-forward layer

# Re-running this cell would otherwise attach a second copy of the hook to the same
# layer; drop the previously registered handle first so exactly one hook is active.
if "intervention_hook" in globals():
    intervention_hook.remove()

# Register the forward hook to apply interventions during the forward pass
intervention_hook = first_layer.register_forward_hook(hook_fn)


In [10]:
# Load the preprocessed DAS training set; the training cell shuffles and chunks it.
# NOTE: pickle.load can execute arbitrary code -- only open files from a trusted source.
with open("DAS_Train.pkl", "rb") as das_train_file:
    DAS_Train = pickle.load(das_train_file)


In [11]:
# DAS Training Process

# Define hyperparameters
epochs = 10  # Number of epochs for training
batch_size = 6400  # Size of each training batch (also reused by the DAS evaluation cell)

# Only phi is trained. Put the frozen classifier in eval mode explicitly so this cell
# does not depend on an earlier cell having called model.eval().
model.eval()
phi.train()

# Freeze the weights in the classification model (since we are only training the phi model)
for param in model.parameters():
    param.requires_grad = False  # Disable gradient updates for classification model parameters

# Loop through the epochs
for epoch in range(epochs):
    total_loss = 0  # Variable to track total loss for the epoch

    # Shuffle the DAS training data (in place) and create batches
    random.shuffle(DAS_Train)
    DAS_Train_Batches = chunk_list(DAS_Train, batch_size)  # Divide dataset into batches

    # Loop through each batch in the shuffled dataset
    for ac_batch in tqdm(DAS_Train_Batches):
        optimizer.zero_grad()  # Zero the gradients

        # Source pass: the hook copies each source's rotated activations into
        # `source_activations` at the dimensions of the corresponding feature.
        ac_sources = extract_sources(ac_batch)
        source_activations = torch.zeros(len(ac_batch), activation_dimension, device=device)
        for ac_source_pos in range(len(ac_sources)):
            mode_info = ["source", ac_source_pos]  # Source input for feature position ac_source_pos
            model(ac_sources[ac_source_pos])  # The hook saves the rotated source activations

        # Intervention pass: the hook patches the base activations with the stored sources
        intervention_bools = prepare_intervention_matrix(ac_batch)  # Where to intervene
        mode_info = ["intervene", intervention_bools]
        ac_base = extract_base(ac_batch)
        outputs = model(ac_base)  # Logits of the intervened model
        labels = extract_labels(ac_batch)  # Expected labels after intervention
        # reshape(-1) keeps the target 1-D even for a single-sample batch, where squeeze()
        # would produce a 0-d tensor (assumes labels are (B,) or (B, 1) -- TODO confirm).
        loss = criterion(outputs, labels.reshape(-1).long())
        loss.backward()  # Gradients reach only phi (the classifier is frozen)
        optimizer.step()  # Apply the gradient updates to the phi model
        total_loss += loss.item()  # Accumulate the loss for this batch

    # Print the average loss for this epoch
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(DAS_Train_Batches)}")

100%|█████████████████████████████████████████| 200/200 [02:34<00:00,  1.30it/s]


Epoch 1, Loss: 4.891106110811234


100%|█████████████████████████████████████████| 200/200 [02:39<00:00,  1.25it/s]


Epoch 2, Loss: 2.485404103398323


100%|█████████████████████████████████████████| 200/200 [02:34<00:00,  1.29it/s]


Epoch 3, Loss: 0.5313158228248358


100%|█████████████████████████████████████████| 200/200 [02:40<00:00,  1.25it/s]


Epoch 4, Loss: 0.1602864133194089


100%|█████████████████████████████████████████| 200/200 [02:35<00:00,  1.29it/s]


Epoch 5, Loss: 0.06925371028482914


100%|█████████████████████████████████████████| 200/200 [02:40<00:00,  1.25it/s]


Epoch 6, Loss: 0.023070329078473152


100%|█████████████████████████████████████████| 200/200 [02:35<00:00,  1.28it/s]


Epoch 7, Loss: 0.012349987239576875


100%|█████████████████████████████████████████| 200/200 [02:40<00:00,  1.24it/s]


Epoch 8, Loss: 0.009252700987271965


100%|█████████████████████████████████████████| 200/200 [02:35<00:00,  1.29it/s]


Epoch 9, Loss: 0.007962174187414349


100%|█████████████████████████████████████████| 200/200 [02:35<00:00,  1.29it/s]

Epoch 10, Loss: 0.007298748982138932





In [12]:
# Load the preprocessed DAS test set used by the evaluation cell.
# NOTE: pickle.load can execute arbitrary code -- only open files from a trusted source.
with open("DAS_Test.pkl", "rb") as das_test_file:
    DAS_Test = pickle.load(das_test_file)

In [13]:
# DAS Evaluation: accuracy of the intervened model on the DAS test set

# Evaluation only: put both the classifier and phi in eval mode.
model.eval()
phi.eval()

total_correct = 0
total_samples = 0

# Reuses `batch_size` from the DAS training cell.
DAS_Test_Batches = chunk_list(DAS_Test, batch_size)
softm = nn.Softmax(dim=1)  # NOTE(review): unused in this cell; remove if no later cell needs it

# No gradients are needed here. Without no_grad every batch would build an autograd
# graph through phi (via the hook) that is never used.
with torch.no_grad():
    for ac_batch in tqdm(DAS_Test_Batches):

        # Prepare Source Activations
        ac_sources = extract_sources(ac_batch)
        source_activations = torch.zeros(len(ac_batch), activation_dimension, device=device)
        for ac_source_pos in range(len(ac_sources)):
            mode_info = ["source", ac_source_pos]
            model(ac_sources[ac_source_pos])

        # Intervention
        intervention_bools = prepare_intervention_matrix(ac_batch)
        mode_info = ["intervene", intervention_bools]
        ac_base = extract_base(ac_batch)
        outputs = model(ac_base)  # Logits
        labels = extract_labels(ac_batch)  # Expected labels after intervention

        # Compute predictions
        predictions = torch.argmax(outputs, dim=1)  # Index of the largest logit
        labels = labels.to(predictions.device)  # Ensure labels and predictions are on the same device

        # Count correct predictions
        correct = (predictions.squeeze() == labels.squeeze()).sum().item()
        total_correct += correct
        total_samples += labels.size(0)

# Final accuracy calculation
accuracy = total_correct / total_samples
print(f"Test Accuracy: {accuracy * 100:.2f}%")
    

100%|█████████████████████████████████████████████| 2/2 [00:09<00:00,  4.67s/it]

Test Accuracy: 99.77%



