In [1]:
# Echo the value of every bare expression in a cell, not only the last one.
from IPython.core.interactiveshell import InteractiveShell
setattr(InteractiveShell, "ast_node_interactivity", "all")

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import gpytorch
from gpytorch.kernels import MaternKernel, ScaleKernel, \
RBFKernel, CosineKernel
from gpytorch.constraints import Positive

import matplotlib.pyplot as plt
import numpy as np

## Note: 
Define $\sigma: \mathbb{R} \rightarrow (0, 1)$ and $g: \mathbb{R}^{d \times n} \rightarrow (0, 1)^{n \times n}$ by $$\sigma(t) = \frac{e^t}{1+e^t}$$ and $$g(\theta)_{i,j} = \sigma\big[(\theta^T \theta)_{i,j}\big],$$ so that every entry $g(\theta)_{i,j}$ is a valid Bernoulli probability.


In [4]:
# Elementwise logistic function sigma(t) = e^t / (1 + e^t), kept as a reusable module instance.
sigmoid = torch.nn.Sigmoid()

## **Implementation:** 
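Let $\theta \in \mathbb{R}^{d \times n}$. For $t \in \{1, \ldots, T\}$ and $i,j \in \{1, \ldots, n\}$, suppose that the entries $A^{(t)}_{i,j}$ are independent with $A^{(t)}_{i,j} \sim \text{Ber}(g(\theta)_{i,j})$. That is, we observe $T$ independent $n \times n$ adjacency matrices $A^{(1)}, \ldots, A^{(T)}$, all drawn from the same edge-probability matrix $g(\theta)$.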

In [None]:
# negative average log-likelihood on the orbit space R^{d x n} / O(d)

def neg_avg_log_like_bar(eta, A):
    """Negative average log-likelihood of binary adjacency data on the orbit of theta_eta.

    Model: for t = 1..T and i, j = 1..n, independently
    A[i, j, t] ~ Ber(g(theta)_{ij}) with g(theta)_{ij} = sigmoid((theta^T theta)_{ij}).

    theta_eta is the embedding of R^{(d+1 choose 2)} into the lower-triangular
    matrices of R^{d x n}: the entries of eta fill the positions (i, j) with
    j <= i of a d x n zero matrix, in row-major order.

    Args:
        eta: tensor with (d+1 choose 2) real entries (any shape, e.g. (k, 1)).
            It is flattened, and d is recovered from its length.
        A: n x n x T tensor of binary data such that for each t, A[:, :, t] is an
            observed adjacency matrix. A single n x n matrix is treated as T = 1.

    Returns:
        nall: scalar tensor equal to the mean over i, j, t of
            -(A log g(theta_eta) + (1 - A) log(1 - g(theta_eta))).
            It is differentiable with respect to eta.

    Raises:
        ValueError: if len(eta) is not (d+1 choose 2) for some d >= 1, if A is not
            of shape (n, n, T) or (n, n), or if n < d.

    NOTE(review): with only (d+1 choose 2) parameters, columns d..n-1 of theta_eta
    are identically zero, so g(theta_eta)_{ij} = 1/2 whenever i >= d or j >= d.
    A full parametrization of R^{d x n} / O(d) needs dn - d(d-1)/2 parameters.
    Confirm this embedding is the intended one.
    """
    import math

    eta = eta.reshape(-1)
    k = eta.numel()
    # solve d(d+1)/2 = k for d
    d = (math.isqrt(8 * k + 1) - 1) // 2
    if k == 0 or d * (d + 1) // 2 != k:
        raise ValueError(f"eta must have (d+1 choose 2) entries for some d >= 1, got {k}")

    if A.dim() == 2:
        A = A.unsqueeze(-1)
    if A.dim() != 3 or A.shape[0] != A.shape[1]:
        raise ValueError(f"A must have shape (n, n, T) or (n, n), got {tuple(A.shape)}")
    n = A.shape[0]
    if n < d:
        raise ValueError(
            f"need n >= d to embed eta into a d x n lower-triangular matrix, got n={n}, d={d}"
        )

    # embed: fill the lower-triangular entries (j <= i) of a d x n matrix with eta
    rows, cols = torch.tril_indices(d, n, device=eta.device)
    theta = eta.new_zeros(d, n)
    theta[rows, cols] = eta  # in-place write keeps the autograd link back to eta

    logits = theta.T @ theta  # (n, n): (theta^T theta)_{ij}, the logit of g(theta)_{ij}
    target = A.to(dtype=logits.dtype, device=logits.device)

    # binary_cross_entropy_with_logits (mean reduction) computes
    # -mean(A log sigmoid(z) + (1 - A) log(1 - sigmoid(z))) stably. A naive
    # log(sigmoid(z)) returns -inf/nan once sigmoid saturates to 0 or 1.
    nall = F.binary_cross_entropy_with_logits(logits.unsqueeze(-1).expand_as(target), target)

    return nall

## **Comparison to MASE:**

In [None]:
# MASE (multiple adjacency spectral embedding) plan:

# Vt_hat = d leading left singular vectors of the n x n adjacency matrix At_hat (use torch.linalg.svd)
# U_hat = n x Td matrix of the concatenated Vt_hats, t = 1..T
# V_hat = d leading left singular vectors of U_hat