In [None]:
import torch
import torch_geometric as pyg
import numpy as np

# https://towardsdatascience.com/a-beginners-guide-to-graph-neural-networks-using-pytorch-geometric-part-1-d98dc93e7742
# https://towardsdatascience.com/a-beginners-guide-to-graph-neural-networks-using-pytorch-geometric-part-2-cd82c01330ab
# https://github.com/tkipf/gcn

#from torch_geometric.datasets import KarateClub
#data = KarateClub()

class MultiGraphModel(torch.nn.Module):
    """Latent-distance generator for K independent random graphs over N nodes.

    Node i has the fixed latent vector ``Z[i]``. In every one of the K graphs the
    directed edge i -> j is drawn independently with probability
    ``exp(-||Z[i] - Z[j]||)``. Self-loops therefore always occur (distance 0),
    and A[..., i, j] and A[..., j, i] are drawn independently.

    Args:
        Z: array of shape (N, D), one latent vector per node.
        K: number of graphs drawn per sample.
    """

    def __init__(self, Z, K):
        super().__init__()
        self.N, self.D = Z.shape
        self.Z = Z
        self.K = K

    def sample(self, n, Zs = None):
        """Draw ``n`` samples of the K adjacency matrices.

        Args:
            n: number of samples.
            Zs: unused; accepted for interface compatibility.

        Returns:
            float64 array of shape (n, K, N, N) with entries in {0, 1}.
        """
        # Edge probabilities depend only on the fixed latents, so build them once.
        offsets = self.Z[:, None, :] - self.Z[None, :, :]
        distances = np.linalg.norm(offsets, axis=-1)
        edge_prob = np.exp(-np.repeat(distances[np.newaxis, :, :], n, axis=0))

        graphs = np.zeros((n, self.K, self.N, self.N))
        for graph_k in range(self.K):
            graphs[:, graph_k] = np.random.binomial(n=1, p=edge_prob)
        return graphs

# Fix the RNG so the sampled latents and graphs are reproducible across runs.
SEED = 0
np.random.seed(SEED)

J = 3  # number of graphs (marginals) drawn per sample
n_nodes, latent_dim_per_node = 300, 1
# Latent positions are uniform on [0, 1) and embedded on the unit circle as (cos, sin).
Z = np.random.random(size=(n_nodes, latent_dim_per_node))
TZ = np.concatenate([np.cos(2*np.pi*Z), np.sin(2*np.pi*Z)], axis=1)
m = MultiGraphModel(TZ, J)

dim_Z_per_node = latent_dim_per_node
dim_T_per_node = 2 * dim_Z_per_node  # (cos, sin) per latent coordinate
dim_Z = dim_Z_per_node * n_nodes
dim_T = dim_T_per_node * n_nodes
dim_js = [n_nodes] * J  # dimensions of marginals

# 'single': one fixed graph per marginal (Net_singlegraph);
# 'multi': a dataset of graphs fed through Net_multigraph.
graph_comp_mode = 'single'

N = 1
A = m.sample(n=N)  # shape (N, J, n_nodes, n_nodes)
A.shape

In [None]:
from rpm import RPMEmpiricalMarginals, EmpiricalDistribution, LogPartition_gauss_diagonal, LogPartition_discrete, LogPartition_vonMises
from rpm import ExpFam, ConditionalExpFam, SemiparametricConditionalExpFam, RPM

import torch_geometric
from torch_geometric.nn import GCNConv

import torch
import numpy as np

dtype = torch.float

# Exponential family used for the latent prior, recognition factors and q:
# one of 'gaussian', 'discrete' or 'vonMises'.
setup = 'vonMises'

if setup == 'gaussian':  # conditional Gaussian case
    # define Gaussian prior in natural parametrization
    def activation_out(x, d=dim_Z_per_node):
        """Map raw network outputs to Gaussian natural parameters (m/sig2, -1/(2*sig2)).

        The first `d` columns pass through unchanged. The remaining columns go
        through -softplus, so the second natural parameter is strictly negative.
        """
        return torch.cat([x[:, :d], -torch.nn.functional.softplus(x[:, d:])], dim=-1)
    log_partition = LogPartition_gauss_diagonal(d=dim_Z)
elif setup == 'discrete':  # conditional categorical case
    def activation_out(x, d=None):
        """Map raw network outputs to log-probabilities over the last axis (`d` is unused)."""
        return torch.log_softmax(x, dim=-1)
    log_partition = LogPartition_discrete(D=dim_T)
elif setup == 'vonMises':  # conditional von Mises case
    def activation_out(x, d=None):
        """No constraint is applied in the von Mises case: return `x` unchanged (`d` is unused)."""
        return x
    log_partition = LogPartition_vonMises(d=dim_Z)
else:
    raise ValueError(f"unknown setup {setup!r}; expected 'gaussian', 'discrete' or 'vonMises'")

# Latent prior with randomly initialised natural parameters of shape (1, dim_T).
latent_prior = ExpFam(natparam=torch.normal(mean=0.0, std=torch.ones(dim_T).reshape(1,-1)),
                      log_partition=log_partition, activation_out=activation_out)


# GCN model with 2 layers 
class Net_multigraph(torch.nn.Module):
    """Two-layer GCN mapping dense adjacency matrices to natural parameters.

    Node features are one-hot identities (``torch.eye(num_nodes)``), so ``n_in``
    must equal the number of nodes. Per-node outputs of shape
    ``(num_nodes, n_out)`` are transposed before flattening. The result is
    therefore ordered channel-major:
    ``[channel 0 of all nodes, channel 1 of all nodes, ...]``.
    """

    def __init__(self, n_in, n_out, n_hidden, activation_out=torch.nn.Identity()):
        super(Net_multigraph, self).__init__()
        self.conv1 = GCNConv(n_in, n_hidden)
        self.conv2 = GCNConv(n_hidden, n_out)
        self.activation_out = activation_out

    def forward(self, x):
        """Apply the GCN to one graph or to a batch of graphs.

        Args:
            x: dense adjacency matrix of shape ``(num_nodes, num_nodes)``, or a
               batch of them with shape ``(batch, num_nodes, num_nodes)``.

        Returns:
            ``(num_nodes * n_out,)`` for a single graph, and
            ``(batch, num_nodes * n_out)`` for a batch. A batch of size 1
            keeps its batch dimension, like any other batch size.
        """
        assert x.ndim in [2, 3]
        if x.ndim == 3:
            # The batch dimension is decided from the *input* rank; checking the
            # 2-D GCN output instead would silently drop it for batch size 1.
            return torch.stack([self._forward_single(xn) for xn in x], dim=0)
        return self._forward_single(x)

    def _forward_single(self, connectivity):
        """Run both GCN layers on one dense ``(num_nodes, num_nodes)`` adjacency matrix."""
        num_nodes = connectivity.shape[0]
        assert connectivity.ndim == 2 and connectivity.shape[1] == num_nodes
        # One-hot node identities, created on the same device as the adjacency.
        node_features = torch.eye(num_nodes, device=connectivity.device)
        # Non-zero entries become edges; their values (edge weights) are not passed to the convolutions.
        edge_index, _ = torch_geometric.utils.to_edge_index(connectivity.to_sparse())

        h = torch.nn.functional.relu(self.conv1(node_features, edge_index))
        #h = torch.nn.functional.dropout(h, training=self.training)
        h = self.conv2(h, edge_index)
        h = self.activation_out(h)
        return h.transpose(-2, -1).flatten()

class Net_singlegraph(torch.nn.Module):
    """Two-layer GCN over one fixed graph; the argument of `forward` is ignored.

    The graph (`edge_index`) and the node features are stored at construction,
    so every forward pass evaluates the same graph with the current weights.
    Node ids in `edge_index` must be exactly 0..num_nodes-1, i.e. every node
    must appear in at least one edge. When `node_features` is None, one-hot
    identity features are used.
    """

    def __init__(self, edge_index, node_features, n_out, n_hidden, activation_out=torch.nn.Identity()):
        super().__init__()

        self.edge_index = edge_index
        present_nodes = edge_index.flatten().unique()
        n_present = len(present_nodes)
        assert torch.all(present_nodes == torch.arange(n_present))
        self.num_nodes = n_present
        if node_features is None:
            node_features = torch.eye(self.num_nodes)
        self.node_features = node_features
        assert self.node_features.shape[0] == self.num_nodes

        in_channels = self.node_features.shape[-1]
        self.conv1 = GCNConv(in_channels, n_hidden)
        self.conv2 = GCNConv(n_hidden, n_out)
        self.activation_out = activation_out

    def forward(self, x=None):
        hidden = torch.nn.functional.relu(self.conv1(self.node_features, self.edge_index))
        #hidden = torch.nn.functional.dropout(hidden, training=self.training)
        out = self.activation_out(self.conv2(hidden, self.edge_index))
        # (num_nodes, n_out) -> (1, n_out * num_nodes), grouped by output channel.
        return out.transpose(-2, -1).reshape(1, -1)

"""
class Net(torch.nn.Module):
    def __init__(self, n_in, n_out, n_hidden, activation_out=torch.nn.Identity()):
        super(Net, self).__init__()
        self.activation_out = activation_out
        self.fc1 = torch.nn.Linear(n_in, n_hidden, bias=True)
        self.fc2 = torch.nn.Linear(n_hidden, n_out, bias=True)

    def forward(self, x):
        x = torch.nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return self.activation_out(x)
"""

# One data tensor and one natural-parameter network per marginal j.
if graph_comp_mode == 'multi':
    # Dense adjacency tensors of shape (N, n_nodes, n_nodes), one per marginal.
    xjs = [torch.tensor(A[:, j], dtype=dtype) for j in range(J)]
    natparam_models = [Net_multigraph(dim_js[j], dim_T_per_node, n_hidden=8, activation_out=activation_out) for j in range(J)]
elif graph_comp_mode == 'single':
    # Only the first sample's graphs are used; each becomes a fixed edge_index.
    xjs = [torch_geometric.utils.to_edge_index(torch.tensor(A[0, j], dtype=dtype).to_sparse())[0] for j in range(J)]
    natparam_models = [Net_singlegraph(xjs[j], None, dim_T_per_node, n_hidden=8, activation_out=activation_out) for j in range(J)]
else:
    raise ValueError(f"unknown graph_comp_mode {graph_comp_mode!r}; expected 'multi' or 'single'")

pxjs = RPMEmpiricalMarginals(xjs)
rec_factors = [ConditionalExpFam(model=model_j, log_partition=log_partition) for model_j in natparam_models]

# Free natural parameters of q, one row per data point: shape (N, dim_T).
q = SemiparametricConditionalExpFam(natparams=torch.normal(mean=0.0, std=torch.ones(N, dim_T)),
                                    log_partition=log_partition, activation_out=activation_out)

# construct implicit RPM
rpm = RPM(rec_factors, latent_prior, pxjs, q)

In [None]:
# Shapes of the learnable tensors of the first natural-parameter network
# (shown as the cell output instead of printing and displaying a list of Nones).
[p.shape for p in natparam_models[0].parameters()]

In [None]:
# Two Adam optimizers: one over rpm.joint_model's parameters and one over rpm.q's
# (the semiparametric factor built above). Both are stepped on every training step.
optimizer_p = torch.optim.Adam(rpm.joint_model.parameters(), lr=1e-3)
optimizer_q = torch.optim.Adam(rpm.q.parameters(), lr=1e-3)

# Number of passes over the dataset and mini-batch size.
epochs = 100000
batch_size = 1

class RPMDatasetMultigraph(torch.utils.data.Dataset):
    """Dataset of N data points, each made of one entry per marginal.

    Args:
        xjs: list of J indexables of equal length N (e.g. adjacency tensors
             of shape (N, n_nodes, n_nodes)).
        num_features: unused; kept for interface compatibility.

    Item ``idx`` is ``([xjs[0][idx], ..., xjs[J-1][idx]], idx)``.
    """
    def __init__(self, xjs, num_features=None):
        self.J = len(xjs)
        assert all(len(xjs[0]) == len(xjs[j]) for j in range(self.J)), \
            "all marginals need the same number of data points"
        self.xjs = xjs

    def __len__(self):
        return len(self.xjs[0])

    def __getitem__(self, idx):
        # No per-item printing: the DataLoader calls this on every training step.
        return [self.xjs[j][idx] for j in range(self.J)], idx

class RPMDatasetSinglegraph(torch.utils.data.Dataset):
    """One-item placeholder dataset for the fixed-graph ('single') mode.

    The graph structure is held by the natural-parameter networks themselves,
    so the only item is a list of J empty (0, 1) tensors plus the given index.
    `num_features` is unused.
    """
    def __init__(self, xjs, num_features=None):
        self.J = len(xjs)
        sizes = [len(xj) for xj in xjs]
        assert all(size == sizes[0] for size in sizes)
        self.xjs = xjs

    def __len__(self):
        # the whole fixed graph counts as a single data point
        return 1

    def __getitem__(self, idx):
        placeholders = [torch.zeros((0, 1)) for _ in range(self.J)]
        return placeholders, idx

def RPMDataset(xjs, mode=None):
    """Build the dataset matching the graph computation mode.

    Args:
        xjs: per-marginal data, as built for `mode`.
        mode: 'multi' or 'single'; defaults to the notebook-level `graph_comp_mode`.

    Returns:
        RPMDatasetMultigraph for 'multi', RPMDatasetSinglegraph for 'single'.

    Raises:
        ValueError: for any other mode.
    """
    if mode is None:
        mode = graph_comp_mode
    if mode == 'multi':
        return RPMDatasetMultigraph(xjs)
    if mode == 'single':
        return RPMDatasetSinglegraph(xjs)
    raise ValueError(f"unknown graph_comp_mode {mode!r}; expected 'multi' or 'single'")

ds = RPMDataset(xjs)
dl = torch.utils.data.DataLoader(dataset=ds, batch_size=batch_size, shuffle=True, drop_last=True)

log_every = 100  # print progress every `log_every` optimisation steps

# One loss slot per optimisation step. len(dl) is the number of batches per epoch in
# both graph modes ('single' mode has a one-item dataset regardless of N).
ls, t = np.zeros(epochs * len(dl)), 0

for epoch in range(epochs):
    for (batch, idx_data) in dl:
        optimizer_p.zero_grad()
        optimizer_q.zero_grad()

        loss = rpm.training_step(batch, idx_data, batch_idx=t)
        loss.backward()

        optimizer_p.step()
        optimizer_q.step()
        # .item() also works when the loss lives on a GPU.
        ls[t] = loss.item()
        if t % log_every == 0:
            print('step #', t, '/', len(ls), ', loss = ', ls[t])
        t += 1


In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
# Only the first `t` slots of `ls` were filled by the training loop.
losses = ls[:t]
if losses.size and losses.min() < 0:
    # semilogy needs positive values: shift so the smallest loss sits at 1e-7
    plt.semilogy( np.arange(t), losses - losses.min() + 1e-7 )
    plt.ylabel('shifted loss (axis shifted to ensure positive values for semilogy)')
else:
    plt.semilogy( np.arange(t), losses)
    plt.ylabel('loss')
plt.xlabel('training step')
plt.show()

In [None]:
# NOTE(review): `mu` is only assigned in the angle-plotting cell that follows, so this
# cell raises NameError on a fresh kernel. Once that cell has run, this shows the shape
# of whatever `mu` was assigned last. Move it after that cell or drop it.
mu.shape

In [None]:
# Overlay three angle curves, one value per latent coordinate:
#   solid        - q's natural parameters (prior natparam passed as offset),
#   red dotted   - q's mean parameters,
#   green dashed - the prior's natural parameters.
# NOTE(review): 300 hard-codes dim_Z (= n_nodes * latent_dim_per_node in this config).
d = 300
eta = rpm.q.nat_param(nat_param_offset=rpm.latent_prior.nat_param)
# NOTE(review): 0::2 / 1::2 slicing assumes the two components of each node are
# interleaved, but Net_multigraph / Net_singlegraph flatten their outputs channel-major
# ([component 0 of all nodes, component 1 of all nodes]). Confirm the layout
# LogPartition_vonMises expects.
# NOTE(review): if the first slice holds the cos component (as in TZ = [cos, sin]),
# arctan2(cos, sin) = pi/2 - angle; arctan2(sin, cos) would recover 2*pi*Z directly.
angle = torch.arctan2(eta[:,0::2], eta[:,1::2]).detach().numpy()
plt.plot(angle[0])

# Manual von Mises mean per 2-D pair: eta_pair * I1(r) / I0(r) / r with r = |eta_pair|
# (the r == 0 term avoids division by zero).
r = torch.sqrt((eta.reshape(-1,d,2)**2).sum(axis=-1)).reshape(-1,d,1)
mu =  eta.reshape(-1,d,2) * (torch.special.i1(r)/torch.i0(r)/(r+1.*(r==0.)))
mu = mu.reshape(-1,d*2)
# NOTE(review): this overwrites `mu`, so the manual computation above is unused.
mu = rpm.q.mean_param(nat_param_offset=rpm.latent_prior.nat_param)

angle = torch.arctan2(mu[0,0::2], mu[0,1::2]).detach().numpy()
plt.plot(angle, ':', color='red')

mu = rpm.latent_prior.nat_param
angle = torch.arctan2(mu[0,0::2], mu[0,1::2]).detach().numpy()
plt.plot(angle, '--', color='green')


In [None]:
mu_x = rpm.q.mean_param(nat_param_offset=rpm.latent_prior.nat_param).detach().numpy()

# Columns dim_Z: hold the second half of the mean-parameter vector
# (dim_Z = n_nodes * latent_dim_per_node; 300 in the current config).
plt.plot(Z, mu_x[0, dim_Z:], '.')
plt.xlabel('true latent Z')
plt.ylabel(f'mean parameter, columns {dim_Z}:')
plt.show()

In [None]:
import matplotlib 
# One hsv colour per node, assigned by rank of the node's true latent Z.
# (matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9.)
clrs = matplotlib.colormaps['hsv'](np.linspace(0, 1, n_nodes))[:, :3]

mu_x = rpm.q.mean_param(nat_param_offset=rpm.latent_prior.nat_param).detach().numpy().reshape(-1,2)
# NOTE(review): reshape(-1, 2) treats the mean parameters as interleaved 2-D pairs,
# while the previous cell slices them as two contiguous halves. Confirm rpm.q's layout.

idx = np.argsort(Z.flatten())
Z_sorted = Z.flatten()[idx]
# Angle of each node's 2-D mean parameter (same argument order as np.arctan2(*mu_x[k])).
angles = np.arctan2(mu_x[:, 0], mu_x[:, 1])
# One angle per node, in the same order as Z_sorted. Indexing the flattened
# (2 * n_nodes)-long mu_x with node indices would mix the components of different nodes.
mu_sorted = angles[idx]

plt.scatter(Z_sorted, mu_sorted, c=clrs, marker='.')
plt.xlabel('true latent Z')
plt.ylabel('angle of posterior mean parameter')
plt.show()

# Pairwise absolute differences, with nodes sorted by their true latent.
pwdists_Z = np.abs(Z_sorted[:, None] - Z_sorted[None, :])
pwdists_mu = np.abs(mu_sorted[:, None] - mu_sorted[None, :])
plt.figure(figsize=(16,7))
plt.subplot(1,2,1)
plt.imshow(pwdists_mu)
plt.title('|angle_i - angle_j| (posterior means)')
plt.colorbar()
plt.subplot(1,2,2)
plt.imshow(pwdists_Z)
plt.title('|Z_i - Z_j| (true latents)')
plt.colorbar()
plt.show()

In [None]:
import matplotlib 
# One hsv colour per node (matplotlib.cm.get_cmap was removed in Matplotlib 3.9).
clrs = matplotlib.colormaps['hsv'](np.linspace(0, 1, n_nodes))[:, :3]


def _plot_node_means(mu_x, nodes):
    """Plot estimated 2-D means (left) next to the true 2-D latents (right) for `nodes`."""
    plt.figure(figsize=(16, 8))
    for i in nodes:
        plt.subplot(1, 2, 1)
        plt.plot(mu_x[:, 2*i+0], mu_x[:, 2*i+1], '.', color=clrs[i])
        plt.subplot(1, 2, 2)
        plt.plot(Z[i, 0], Z[i, 1], 'x', color=clrs[i])
    plt.show()


# This cell targets setup == 'gaussian' with a 2-D latent per node. With the 1-D latent
# configured at the top of the notebook, Z[i, 1] would raise IndexError.
if setup == 'gaussian' and Z.shape[1] >= 2:
    p = rpm.q.nat_param(nat_param_offset=rpm.latent_prior.nat_param).detach().numpy()
    d = rpm.q.log_partition.d
    # Gaussian natural parameters (m/sig2, -1/(2*sig2))  =>  m = -eta1 / (2 * eta2).
    mu_x = -0.5 * p[:,:d] / p[:,d:]
    # Nodes are drawn in two halves, one figure each.
    _plot_node_means(mu_x, range(n_nodes // 2))
    _plot_node_means(mu_x, range(n_nodes // 2, n_nodes))
else:
    print(f"skipped: needs setup == 'gaussian' and 2-D latents (setup={setup!r}, Z.shape={Z.shape})")

# Toy test cases

# Linear-Gaussian true generative model with univariate latents

In [None]:
import torch
import torch_geometric as pyg
import numpy as np

# Toy data: each marginal is x_j = A[j] * z + standard Gaussian noise, with z ~ N(0, 1).
dim_T = 2
J = 3
dim_js = [20] * J  # dimensions of marginals
dim_Z = 1

N = 1000
Z_true = np.random.normal(size=(N, dim_Z))
A = [-2., 0., 1.]  # loading of the latent in each marginal

xjs = []
for loading, dim_j in zip(A, dim_js):
    # the latent signal is broadcast over all dim_j coordinates, plus independent noise
    signal = loading * Z_true
    xjs.append(torch.tensor(signal + np.random.normal(size=(N, dim_j)), dtype=torch.float32))

[xj.shape for xj in xjs]

# 2D Gaussian copula with Gaussian or Exponential marginals

In [None]:
from implicitRPM import ObservedMarginal, IndependentMarginal, GaussianCopula_ExponentialMarginals
from discreteRPM import discreteRPM, Prior_discrete, RecognitionFactor_discrete

import torch
import numpy as np

J = 2                           # number of marginals
dim_js = [1 for j in range(J)]  # dimensions of marginals
dim_Z = 1                       # dimension of latent
dim_T = 2                       # dimension of sufficient statistics


# currently playing with either Gaussian or Exponential marginals
marginals = 'exponential'
if marginals == 'exponential':
    rates = [1.0, 0.5, 3.0][:J]
    pxjs = [ObservedMarginal(torch.distributions.exponential.Exponential(rate=rates[j])) for j in range(J)]
    # presumably the correlation matrix of the Gaussian copula -- TODO confirm
    P = np.array([[1.0, -0.85], [-0.85, 1.0]])
    print('P:', P)
    px = GaussianCopula_ExponentialMarginals(P=P, rates=rates, dims=dim_js)
elif marginals == 'gaussian':
    locs, scales = [-1.5, -0.5, 3.0][:J], [1.0, 2.0, 0.25][:J]
    # NOTE(review): unlike the exponential branch these are not wrapped in
    # ObservedMarginal -- confirm IndependentMarginal accepts raw distributions.
    pxjs = [torch.distributions.normal.Normal(loc=locs[j], scale=scales[j]) for j in range(J)]
    # No joint (copula) sampler exists for Gaussian marginals yet. Setting px explicitly
    # avoids silently sampling from a stale `px` left over from an earlier run.
    px = None
else:
    raise Exception('marginals not implemented')
pxind = IndependentMarginal(pxjs, dims=dim_js)

if px is None:
    raise NotImplementedError(f"no joint sampler defined for marginals == {marginals!r}")
N = 10000
xjs = px.sample_n(n=N)


In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

# Training-loss curve, shifted so its minimum sits at 1e-7 (semilogy needs positive values).
fig = plt.figure()
ax = fig.add_subplot()
ax.semilogy( np.arange(t), ls[:t] - ls[:t].min() + 1e-7 )
ax.set_xlabel('training step')
ax.set_ylabel('loss - min(loss) + 1e-7')
plt.show()

In [None]:
# NOTE(review): the `notebook` (nbagg) backend only works in the classic Notebook (<7);
# in JupyterLab / Notebook 7, `%matplotlib widget` (ipympl) is the usual replacement.
%matplotlib notebook
# Posterior means for a 1-D Gaussian latent: from natural parameters
# (m/sig2, -1/(2*sig2)) it follows that m = -eta1 / (2 * eta2).
# Assumes setup == 'gaussian' with dim_Z == 1 -- TODO confirm.
p = rpm.q.nat_param(nat_param_offset=rpm.latent_prior.nat_param).detach().numpy()
mu_x = -0.5 * p[:,0] / p[:,1]

# 3-D scatter: first observed marginal, second observed marginal, posterior mean.
# NOTE(review): xjs comes from the copula cell above, while rpm was built and trained on
# the graph data earlier. The point counts only match if rpm was retrained on these xjs.
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(xjs[0][:,0], xjs[1][:,0], mu_x, marker='o')

In [None]:
# Posterior means m = -eta1 / (2 * eta2) against the true latents of the linear-Gaussian toy data.
p = rpm.q.nat_param(nat_param_offset=rpm.latent_prior.nat_param).detach().numpy()
plt.plot(Z_true, -0.5 * p[:,0] / p[:,1], '.')
plt.xlabel('true latent Z_true')
plt.ylabel('posterior mean')
plt.show()
# Show only the first rows; q holds one row of natural parameters per data point.
p[:5]

In [None]:
# Learnable tensors of the first element of rpm.joint_model.
list(rpm.joint_model[0].parameters())