In [61]:
import torch
from torch.nn import functional as F
import torch.nn as nn
import networkx as nx
import numpy as np

def adjacency_mod(adjacency_matrix, causal_ordering):  # TODO: remove if function not needed
    # this is just a temporary function to see if adding a 1 to the diagonal for all zeroth-order variables helps
    for i, var in enumerate(causal_ordering.keys()):
        order = causal_ordering[var]
        if order == 0:
            adjacency_matrix[i, i] = 1  # add 1 to the i-th diagonal element
    return adjacency_matrix



# adapted from example GPT code  https://github.com/karpathy/ng-video-lecture
class Head(nn.Module):
    def __init__(self, head_size, dropout_rate, dag):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        
        self.head_size = head_size
        # user a register buffer (not a module parameter) for the creation of self.dag
        # dag will determine what variables can communicate with each other

        self.dag_orig = dag.T  # transpose the DAG so that it pulls the correct embeddings

        self.register_buffer('dag_mod', self.dag_orig)  # include transpose
        self.dropout = nn.Dropout(dropout_rate)
        self.act = nn.LeakyReLU()
        self.att_wei = None
        
        
    def forward(self, X):

        K = self.key(X)  # B, T, hs
        Q = self.query(X)  # B, T, hs
        V = self.value(X)  # B, T, hs
        B, T, HS = Q .shape
        QK = torch.matmul(Q, K.transpose(1, 2)) / (self.head_size ** 0.5)
          
        self.att_wei = QK.masked_fill(self.dag_mod == 0, float('-inf')) 
        self.att_wei = F.softmax(self.att_wei, dim=-1)
        nan_rows = torch.any(torch.isnan(self.att_wei), dim=-1)  # check if any rows are <all> -inf, these need to be masked to 0
        nan_mask = nan_rows.unsqueeze(-1).expand_as(self.att_wei)
        self.att_wei = torch.where(nan_mask, torch.zeros_like(self.att_wei), self.att_wei) # set any rows have nan values (because they have no causal parents) to 0 to avoid nans
        
        out = self.att_wei @ V  # B, T, hs
        
        return self.act(out)

class MultiHeadAttention(nn.Module):

    def __init__(self, num_heads, head_size, dropout_rate, dag):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size=head_size, dropout_rate=dropout_rate, dag=dag) for _ in range(num_heads)])
        self.projection = nn.Linear(int(head_size*num_heads), 1)
        self.dropout = nn.Dropout(dropout_rate)
        self.act = nn.LeakyReLU()

    def forward(self, X):
        out = torch.cat([h(X) for h in self.heads], dim=-1)
        out = self.dropout(self.projection(out))
        return self.act(out)


class FF(nn.Module):
    def __init__(self, n_embed, dropout_rate):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout_rate),
        )

    def forward(self, X):
        out = self.net(X)
        return out


class Block(nn.Module):

    def __init__(self, n_embed, num_heads, head_size,  dropout_rate, dag):
        super().__init__()
        self.mha = MultiHeadAttention(num_heads, head_size, dropout_rate, dag)
        self.ff = FF(n_embed, dropout_rate)

    def forward(self, X):
        X = self.mha(X)  # + X  with skip connection (careful with adding back in after having masked it)
        X = X + self.ff(X)  # with skip connection
        return X


class MixedLoss(nn.Module):
    def __init__(self, var_types_sorted, causal_ordering):
        super(MixedLoss, self).__init__()
        self.causal_ordering = causal_ordering
        self.var_types_sorted = var_types_sorted  # sorted types for determining which loss to use
        self.cont_loss = nn.MSELoss()  # Loss for continuous variables
        self.bin_loss = nn.BCEWithLogitsLoss()  # Loss for binary variables
        self.cat_loss = nn.CrossEntropyLoss()   # takes logits for each class as input

    def forward(self, pred, target, shuffle_ordering):

        total_loss = 0
        loss_tracking = {}
        sorted_vars = [list(self.var_types_sorted.keys())[i] for i in shuffle_ordering]
        for i, var_name in enumerate(sorted_vars):
            var_type = self.var_types_sorted[var_name]
            order = self.causal_ordering[var_name]
            # if order != 0:  # don't generate a loss for predicting stuff with no parents
            if var_type == 'cont':
                loss = self.cont_loss(pred[:, i], target[:, i])
            elif var_type == 'bin':
                loss = self.bin_loss(pred[:, i], target[:, i])
            elif var_type == 'cat':
                loss = self.cat_loss(pred[:, i].unsqueeze(0), target[:, i].long())

            loss_tracking[var_name] = loss.item()
            total_loss += loss

        return total_loss, loss_tracking


def shuffler(X, targets, dag, shuffling=False):
    # shuffles the order of X, targets, and adjacency matrix for a batch
    if shuffling:
        shuffle_ordering = np.random.permutation(X.shape[1])
    else:
        shuffle_ordering = np.arange(0, X.shape[1])
    return X[:, shuffle_ordering], dag[shuffle_ordering, :], targets[:, shuffle_ordering], shuffle_ordering



class CaT(nn.Module):

    def __init__(self, num_heads, head_size, n_layers, dag):
        '''
        :param dag_type: Whether to use the lower-triangular with diagonal (=0), the lower-triangular without diagonal (=1) or causal adjacency matrix (=2)
        :param num_heads:
        :param head_size:
        :param n_layers:
        :param dag: adjacency matrix
        :param device: 'cuda' or 'cpu'
        '''

        super().__init__()
        self.device = 'cpu'
        self.n_layers = n_layers
        self.num_heads = num_heads
        dag = torch.tensor(dag).to(self.device).T
        self.blocks = nn.Sequential(
            *[Block(n_embed=1, num_heads=num_heads, head_size=head_size, dropout_rate=0.01, dag=dag) for _ in range(n_layers)])
        self.lm_head = nn.Linear(1, 1)

    def forward(self, X, targets=None):

        X = self.blocks(X)  # B, num_vars, head_size
        X = self.lm_head(X)
        X = X[:, :, 0]

        
# note that right now the network has MHA with blocks in paralle, but this is also done sequentially, combining
# both network width and network depth.

# for the causal transformer, we have to be careful that we include a 'diagonal' pass-thru after the first layer
# otherwise, and e.g. in a three variable chain A->B->C, the dependency structure will prevent B from being predicted
# from A <after the first layer>, because B is caused by A, not by itself. So the diagonal of ones should be
# introduced after the first layer.
# Be careful also that between parallel blocks there is no interaction with e.g. the linear/FF layers which 
# violates the contraints.

In [62]:
num_heads = 2
head_size = 10
n_layers = 3

dag = np.array([[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 0, 0]])  # Example DAG where rows are causes and columns are effects

# Initialize the CaT model
model = CaT(num_heads=num_heads, head_size=head_size, n_layers=n_layers, dag=dag)

# Random test data
B = 1  # Batch size
V = 4  # Number of variables
C = 10  # Number of channels/dimensions
n_embd = C
test_data = torch.randn(B, V, C)  # Random data (B, V, C)

output = model(test_data)

tensor([[[0., 0., 0., 0.],
         [1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 1., 0.]]], grad_fn=<SWhereBackward0>)
torch.Size([1, 4, 10])
tensor([[[0., 0., 0., 0.],
         [1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 1., 0.]]], grad_fn=<SWhereBackward0>)
torch.Size([1, 4, 10])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x1 and 10x10)

In [66]:

import numpy as np
A = np.array([
[0, 1, 1],
[0, 0, 1],
[0, 0, 0]
])
V = np.array([
[2, 2, 2, 2],
[3, 3, 3, 3],
[4, 4, 4, 4]
])
# Transpose A to correctly reflect the intended influence directions
A_transposed = A.T
A_transposed @ V

array([[0, 0, 0, 0],
       [2, 2, 2, 2],
       [5, 5, 5, 5]])