In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset



In [16]:
# def create_masks(adjacency_matrix, neurons_per_layer):
#     num_layers = len(neurons_per_layer)
#     num_nodes = adjacency_matrix.shape[0]
#     layer_masks = []

#     # Ensure each layer, except the last, has an even number of neurons if required
#     assert all(n % 2 == 0 for n in neurons_per_layer[:-1]), "Each layer, except the last, must have an even number of neurons."

#     # The last layer must match the number of DAG nodes (variables)
#     assert neurons_per_layer[-1] == num_nodes, "The last layer must exactly match the number of nodes in the DAG."

#     for layer_index in range(1, num_layers):
#         prev_layer_size = neurons_per_layer[layer_index - 1]
#         curr_layer_size = neurons_per_layer[layer_index]
#         mask = np.zeros((curr_layer_size, prev_layer_size))

#         prev_neurons_per_node = prev_layer_size // num_nodes
#         curr_neurons_per_node = curr_layer_size // num_nodes

#         for i in range(num_nodes):
#             for j in range(num_nodes):
#                 if adjacency_matrix[i, j] == 1:
#                     mask[j * curr_neurons_per_node:(j + 1) * curr_neurons_per_node,
#                          i * prev_neurons_per_node:(i + 1) * prev_neurons_per_node] = 1
#         layer_masks.append(mask)

#     return layer_masks


def create_masks(adjacency_matrix, neurons_per_layer):
    """Build one binary connectivity mask per linear layer from a DAG adjacency matrix.

    Args:
        adjacency_matrix: Array-like of shape (num_nodes, num_nodes), expected
            to be square. An entry equal to 1 at ``[i, j]`` marks an edge from
            node ``i`` to node ``j``.
        neurons_per_layer: Sequence of layer widths, input layer first.

    Returns:
        A list of ``len(neurons_per_layer) - 1`` float NumPy arrays. The mask
        for the transition into layer ``k`` has shape
        ``(neurons_per_layer[k], neurons_per_layer[k - 1])``, i.e. the same
        (out_features, in_features) layout as a linear layer's weight.

        * The first mask contains only the adjacency edges (``mask[j, i] = 1``
          for every edge ``i -> j``); it has no diagonal.
        * Every later mask has its main diagonal set to 1 and additionally
          contains the adjacency edges.

    Edges whose source index does not fit in the previous layer, or whose
    destination index does not fit in the current layer, are dropped instead of
    raising ``IndexError``. Layers may therefore be narrower than the number of
    nodes.
    """
    adjacency = np.asarray(adjacency_matrix)
    num_nodes = adjacency.shape[0]
    layer_masks = []

    for layer_index in range(1, len(neurons_per_layer)):
        prev_layer_size = neurons_per_layer[layer_index - 1]
        curr_layer_size = neurons_per_layer[layer_index]
        mask = np.zeros((curr_layer_size, prev_layer_size))

        # Only layers after the first keep a unit's own signal via the diagonal.
        if layer_index > 1:
            np.fill_diagonal(mask, 1)

        # Clamp both node indices to what this mask can actually address.
        num_sources = min(num_nodes, prev_layer_size)
        num_targets = min(num_nodes, curr_layer_size)

        # adjacency[i, j] == 1 (edge i -> j) opens mask[j, i], hence the transpose.
        edges = adjacency[:num_sources, :num_targets].T == 1
        mask[:num_targets, :num_sources][edges] = 1

        layer_masks.append(mask)

    return layer_masks


class MaskedLinear(nn.Module):
    """Linear layer whose weight is multiplied element-wise by a fixed mask.

    The weight has shape (out_features, in_features) and is initialised to all
    ones. Until a mask is installed with ``set_mask`` the layer behaves like an
    unmasked, bias-free linear map.

    NOTE(review): ``bias`` is created and initialised but ``forward`` never
    adds it, so it receives no gradient. Confirm whether a bias-free layer is
    intended before wiring it in.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
        # The mask is fixed, non-trainable state: keep it as a buffer so it
        # follows .to()/state_dict() but is never handed to the optimizer.
        self.register_buffer("mask", None)
        self.reset_parameters()

    def set_mask(self, mask):
        """Install the connectivity mask applied to the weight in ``forward``.

        Args:
            mask: Tensor (or array-like) with the same shape as ``weight``,
                i.e. (out_features, in_features).

        Raises:
            ValueError: If the mask shape differs from the weight shape.
        """
        mask = torch.as_tensor(
            mask, dtype=self.weight.dtype, device=self.weight.device
        ).detach()
        if mask.shape != self.weight.shape:
            raise ValueError(
                f"mask shape {tuple(mask.shape)} does not match "
                f"weight shape {tuple(self.weight.shape)}"
            )
        self.mask = mask

    def reset_parameters(self):
        """Reset the weight to all ones and the bias to all zeros."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, input):
        """Apply the (masked) weight to ``input``; no bias is added."""
        weight = self.weight if self.mask is None else self.weight * self.mask
        return nn.functional.linear(input, weight)
    
    
class DAGAutoencoder(nn.Module):
    """Stack of MaskedLinear layers with a ReLU after every layer but the last.

    ``neurons_per_layer`` lists the layer widths, input first; consecutive
    widths define one MaskedLinear each.
    """

    def __init__(self, neurons_per_layer):
        super().__init__()
        widths = list(zip(neurons_per_layer[:-1], neurons_per_layer[1:]))
        self.layers = nn.ModuleList(
            [MaskedLinear(fan_in, fan_out) for fan_in, fan_out in widths]
        )
        # One ReLU per hidden transition; the output layer stays linear.
        self.activations = nn.ModuleList([nn.ReLU() for _ in widths[:-1]])

    def set_masks(self, masks):
        """Install one mask per linear layer, in layer order."""
        assert len(masks) == len(self.layers), "The number of masks must match the number of linear layers."
        for mask, layer in zip(masks, self.layers):
            layer.set_mask(mask)

    def forward(self, x):
        """Run ``x`` through every hidden layer + ReLU, then the final linear layer."""
        hidden = x
        for index, relu in enumerate(self.activations):
            hidden = relu(self.layers[index](hidden))
        # The output layer is applied without a non-linearity.
        return self.layers[-1](hidden)



In [17]:
input_size = 4  # Adjust according to your model's expected input size
neurons_per_layer = [input_size,  4, 4, 6, input_size]  # Symmetric for autoencoding
num_samples = 100  # Number of synthetic data samples
batch_size = 10  # Batch size for training
epochs = 500  # Number of epochs for training

# Seed the RNG so the synthetic data (and the shuffling order) is reproducible
# under Restart Kernel -> Run All.
torch.manual_seed(0)

# Generate synthetic data
data = torch.randn(num_samples, input_size)
dataset = TensorDataset(data, data)  # Using the same data as both input and target
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


model = DAGAutoencoder(neurons_per_layer)
# adjacency[i, j] == 1 means an edge from node i to node j; here only 0 -> 1.
initial_adj_matrix = np.array([
    [0, 1, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0]
])
initial_masks = [torch.from_numpy(mask).float() for mask in create_masks(initial_adj_matrix, neurons_per_layer)]
model.set_masks(initial_masks)

# Push an all-ones input through the masked weights by hand (no activations)
# and show, layer by layer, the signal together with the mask that shaped it.
input_tensor = torch.ones((1, input_size))
o = input_tensor
for layer in model.layers:
    o = torch.mm(o, (layer.weight * layer.mask).T)
    print('output:', o, 'mask used:', layer.mask)

output: tensor([[0., 1., 0., 0.]], grad_fn=<MmBackward0>) mask used: Parameter containing:
tensor([[0., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
output: tensor([[0., 1., 0., 0.]], grad_fn=<MmBackward0>) mask used: Parameter containing:
tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])
output: tensor([[0., 1., 0., 0., 0., 0.]], grad_fn=<MmBackward0>) mask used: Parameter containing:
tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
output: tensor([[0., 1., 0., 0.]], grad_fn=<MmBackward0>) mask used: Parameter containing:
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0.]])


In [None]:
# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(epochs):
    model.train()  # Set model to training mode
    running_loss = 0.0

    for inputs, targets in dataloader:
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean, so weight it by the batch size
        # to accumulate a per-sample sum.
        running_loss += loss.item() * inputs.size(0)

    # Average over the samples actually iterated rather than a separate constant.
    epoch_loss = running_loss / len(dataloader.dataset)
    if epoch % 10 == 9:  # Print every 10 epochs
        print(f'Epoch {epoch + 1}, Loss: {epoch_loss}')

print('Finished Training')

# Test the model with one example
test_input = torch.randn(1, input_size)
model.eval()  # Set model to evaluation mode
with torch.no_grad():  # Inference only: do not build an autograd graph
    test_output = model(test_input)
print("Test input:", test_input)
print("Reconstructed output:", test_output)

# Intervention: rebuild the masks from an adjacency matrix with every edge removed
# and install them on the already-trained model (this replaces the training masks).
intervention_adj_matrix = np.zeros_like(initial_adj_matrix)
intervention_masks = [torch.from_numpy(mask).float() for mask in create_masks(intervention_adj_matrix, neurons_per_layer)]
model.set_masks(intervention_masks)

# Example usage after intervention
with torch.no_grad():
    output_after_intervention = model(test_input)
print("Output of the network after intervention:", output_after_intervention)

In [None]:
# Display `o`, the final value left by the manual layer-by-layer propagation cell.
o
