In [1]:
import torch
nn = torch.nn
import os
from functools import reduce

# Work from the notebook's parent directory, which contains the `causal_rl`
# package imported below. Anchor on the directory this cell first ran in, so
# re-running the cell does not climb one extra directory each time (which
# would break the import). A kernel restart clears NOTEBOOK_DIR and the cwd.
NOTEBOOK_DIR = globals().get('NOTEBOOK_DIR', os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, '..'))

from causal_rl.sem import StructuralEquationModel, DirectedAcyclicGraph

In [5]:
# Weight tensor of shape (2, 3, 3): slice 0 is all zeros; slice 1 is strictly
# lower triangular (row = child node, column = parent node, as `loopy` reads it).
graph = DirectedAcyclicGraph(torch.stack([
    torch.zeros(3, 3),
    torch.tensor([[0.0, 0.0, 0.0],
                  [0.3, 0.0, 0.0],
                  [-1.2, 0.15, 0.0]]),
]))

target = 1  # node that is intervened on
value = 5   # value the intervened node is clamped to

sem = StructuralEquationModel(graph, 'bernoulli', 0.5)
# Alternatives tried while experimenting:
# sem.noises = torch.tensor([0., 1., 0.])
# sem = StructuralEquationModel.random(7, 0.3, 'gaussian', 0.1)

# One factual sample, then the SEM's own counterfactual under do(target = value).
obs = sem(n=1, z_prev=None, intervention=None)
cnt = sem.counterfactual(z_prev=None, intervention=(target, value))

for _label, _tensor in (('obs', obs), ('ground', cnt), ('noise', sem.noises)):
    print(_label, _tensor)
print()

# Inputs for the hand-rolled propagation below: weight slice 1, the
# factual sample, and the noise the SEM reports.
B, X, N = sem.graph.weights[1], obs, sem.noises

print(B)
print(sem.graph.weights)

def loopy(X, N, B, intervention):
    """Propagate a linear SEM node by node under a single hard intervention.

    Nodes are visited in index order 0..len(B)-1. Node ``target`` is clamped to
    ``value``. Every other node i becomes ``sum_j B[i, j] * state[:, j] + N[i]``.
    Here ``state`` holds the already-updated values of nodes visited earlier and
    the input values of nodes not yet visited. When B is strictly lower
    triangular, node i therefore depends only on updated parents.

    Args:
        X: 2-D tensor (samples, nodes) of starting values. It is not modified;
            a clone is written.
        N: per-node noise; ``N[i]`` is added to column i (e.g. a 1-D tensor
            with one entry per node).
        B: square weight matrix; ``B[i, j]`` weights node j in node i's equation.
        intervention: ``(target, value)`` pair.

    Returns:
        A new tensor shaped like X holding the propagated values.
    """
    target, value = intervention
    result = X.clone()
    n_nodes = len(B)

    for i in range(n_nodes):
        if i == target:
            result[:, i] = value
            continue

        # Accumulate in a local tensor rather than in result[:, i]. Writing
        # partial sums into the column would make the j == i term scale the
        # partial sum instead of the node's current value. This keeps the
        # result equal to B[i] @ state + N[i].
        total = torch.zeros_like(result[:, i])
        for j in range(n_nodes):
            total = total + B[i, j] * result[:, j]
        result[:, i] = total + N[i]

    return result

# Compare by eye with the 'ground' line printed above (the SEM's own counterfactual).
print('loopy', loopy(X, N, B, intervention=(target, value)))

obs tensor([[0., 0., 0.]])
ground tensor([[0.0000, 5.0000, 0.7500]])
noise tensor([0., 0., 0.])

tensor([[ 0.0000,  0.0000,  0.0000],
        [ 0.3000,  0.0000,  0.0000],
        [-1.2000,  0.1500,  0.0000]])
tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000],
         [ 0.3000,  0.0000,  0.0000],
         [-1.2000,  0.1500,  0.0000]]])
loopy tensor([[0.0000, 5.0000, 0.7500]])


In [None]:
# Hand-written weights for the 3-node example (row = child node, column =
# parent node, as `loopy` reads them). The placeholder inputs are shaped the
# way `loopy` indexes them. X is 2-D (samples, nodes) because `loopy` writes
# result[:, i]. N has one entry per node because `loopy` reads N[i]. An empty
# N or a 1-D X would raise IndexError there.
B = torch.tensor([
    [0.0, 0.0, 0.0],
    [0.3, 0.0, 0.0],
    [-1.2, 0.15, 0.0],
], dtype=torch.float32)

X = torch.zeros(1, 3)
N = torch.zeros(3)

In [7]:
from torchviz import make_dot

class OrderedPredictor(nn.Module):
    """Predict counterfactual node values by propagating through a learned linear SEM.

    Nodes are visited in index order 0..dim-1. This order is assumed to be a
    valid topological order of the causal graph (TODO confirm against the SEM).

    Args:
        sem: object exposing ``dim``, the number of nodes.
    """

    def __init__(self, sem):
        super(OrderedPredictor, self).__init__()
        self.dim = sem.dim

        # Heuristic: the true weights are strictly lower triangular, i.e. node i
        # depends only on nodes j < i. forward() re-applies the same mask, so
        # gradient steps cannot introduce upper-triangular or self-loop weights.
        # Noise enters through the explicit `+ noise[:, i]` term in forward(),
        # so no self-connection (diagonal weight) is needed to carry it.
        self.linear1 = nn.Parameter(torch.randn((self.dim, self.dim)).tril_(-1))

    def forward(self, observation, noise, intervention):
        """Propagate ``noise`` through the learned weights under a hard intervention.

        Args:
            observation: 2-D tensor (samples, dim) of factual values. It is not
                modified, and no gradient flows back into it.
            noise: 2-D tensor (samples, dim); ``noise[:, i]`` is added to node i.
            intervention: ``(target, value)`` pair; node ``target`` is set to ``value``.

        Returns:
            A new (samples, dim) tensor of predicted values.
        """
        target, value = intervention
        # Enforce the lower-triangular structure on every call (see __init__).
        weights = self.linear1.tril(-1)

        # detach() alone shares storage with the caller's tensor. Clone it so the
        # in-place column writes below do not overwrite the caller's observation.
        output = observation.detach().clone()

        for i in range(self.dim):
            if i == target:
                output[:, i] = value
                continue

            # Snapshot before writing column i. matmul saves this operand for
            # backward, so it must not be modified in place afterwards.
            parents = output.clone()
            output[:, i] = weights[i].matmul(parents.t()) + noise[:, i]

        return output

#         for i in range(self.dim):
#             if i == target:
#                 observation[:, i] = value
#                 continue

#             observation[:, i] = 0
#             for j in range(self.dim):
#                 if i > j:
#                     observation[:, i] = observation.clone()[:, i] + self.linear1[i, j] * observation.clone()[:, j]
#             observation[:, i] = observation.clone()[:, i] + noise.clone()[:, i]
        
#         return observation

class TwoStepPredictor(nn.Module):
    """Two-stage counterfactual predictor: infer noise, then propagate it.

    Stage 1 (``infer_noise``) maps an observation to a per-node noise estimate
    with a single affine layer. Stage 2 (``predictor``) passes the observation,
    that estimate and the intervention to an ``OrderedPredictor``.

    Args:
        sem: object exposing ``dim``; it is also handed to ``OrderedPredictor``.
    """

    def __init__(self, sem):
        super().__init__()
        self.dim = sem.dim

        # NOTE: a dedicated NoiseInferencer(self.dim) module could replace this
        # plain affine layer.
        self.infer_noise = nn.Linear(self.dim, self.dim, bias=True)
        self.predictor = OrderedPredictor(sem)

    def forward(self, observation, intervention):
        # Defensive copy: the predictor may write into `observation` in place.
        # Without it, that write would invalidate the input nn.Linear saves
        # for the backward pass.
        inferred_noise = self.infer_noise(observation.clone())
        prediction = self.predictor(observation, inferred_noise, intervention)
        return prediction
    
def test_orderedpredictor(n, dim=3, value=5, visualize=False):
    """Smoke-test TwoStepPredictor on a random SEM: forward pass, loss, backward.

    No optimizer step is taken. The run only checks that a prediction is
    produced and that gradients flow without autograd errors. Parameter
    gradients accumulate across iterations.

    Args:
        n: number of random (observation, intervention) trials.
        dim: first argument to ``StructuralEquationModel.random`` (presumably the
            number of nodes; previously hard-coded to 3).
        value: value the intervened node is set to (previously hard-coded to 5).
        visualize: if True, render each prediction's autograd graph with
            torchviz ``make_dot(...).view()``. This is off by default because it
            opens an external viewer and writes files on every iteration.
    """
    sem = StructuralEquationModel.random(dim, 0.3, 'bernoulli', 0.5)
    model = TwoStepPredictor(sem)

    for _ in range(n):
        obs = sem(n=1)
        # Intervene on a uniformly random node, drawn from the SEM's own size.
        target = torch.randint(high=sem.dim, size=(1,)).item()
        # NOTE(review): assumes counterfactual() reuses the noise sampled by the
        # sem(n=1) call just above; confirm in causal_rl.sem.
        count = sem.counterfactual(intervention=(target, value))

        pred = model(obs, (target, value))
        print(pred)

        if visualize:
            make_dot(pred).view()

        loss = (pred - count).pow(2).sum()
        print(loss)
        loss.backward()
    
# Seed torch's global RNG so model initialisation and the random intervention
# (and hence the printed prediction/loss) reproduce on Restart & Run All.
# NOTE(review): StructuralEquationModel.random may draw from a different RNG;
# confirm it is covered.
torch.manual_seed(0)
test_orderedpredictor(1)

tensor([[ 0.0934,  5.0000, -4.5487]], grad_fn=<CopySlices>)
tensor(157.4437, grad_fn=<SumBackward0>)


In [None]:
# Scratch check: build a 3x3 "unrolled" state matrix row by row with in-place
# writes, then confirm autograd can still back-propagate through it.
in_weights = nn.Parameter(torch.randn(3, 3))
noise_weights = nn.Parameter(torch.randn(3, 3))

x = torch.randn(3)
n = (x @ noise_weights).view(-1)

# Row 0 holds the input. Each later row copies the previous row and then
# overwrites its own diagonal entry with a weighted sum of the previous row.
x_ = torch.zeros(3, 3)
x_[0] = x  # optionally x.detach()

for i in range(1, x_.shape[0]):
    prev_row = x_[i - 1].clone()  # snapshot: dot() saves it for backward
    x_[i] = x_[i - 1]
    x_[i, i] = in_weights[i].dot(prev_row) + n[i]

print(x_)
pred = x_[-1]

# Optional: inspect the autograd graph with torchviz.
# make_dot(pred).view()
loss = (pred - torch.randn(3)).sum()
loss.backward()