# NumPy

In [None]:
import numpy as np

In [None]:
def per(n):
    """Enumerate every binary vector of length ``n``.

    Row ``i`` of the result holds the ``n``-bit binary expansion of ``i``,
    most significant bit first, so the rows appear in increasing numeric
    order (``00..0``, ``00..1``, ..., ``11..1``).

    Args:
        n: Number of bits; a non-negative integer.

    Returns:
        Integer ``numpy`` array of shape ``(2**n, n)`` with entries in
        ``{0, 1}``. For ``n == 0`` this is the single empty vector, i.e.
        an array of shape ``(1, 0)``.

    Raises:
        ValueError: If ``n`` is negative (raised by the shift ``1 << n``).
    """
    # One column vector of state indices, one row vector of bit positions;
    # broadcasting the shift extracts every bit of every index at once.
    codes = np.arange(1 << n)[:, np.newaxis]
    shifts = np.arange(n - 1, -1, -1)
    return (codes >> shifts) & 1

def p_x_given_y(M, X, Y, V, a):
    """
        Score every discrete state in X against every row of Y.

        For state x (a row of X) and input y (a row of Y) the returned
        value is  x^T Q x + y^T (V x + a)  with  Q = (M + M^T) / 2.
        The values are returned as-is (no normalisation is applied here).

        M: (t, t), symmetrised internally into Q
        X: (2**t, t)
        Y: (b, d)
        V: (d, t)
        a: (d, 1)

        returns: (b, 2**t)
    """
    n_bits = M.shape[0]
    n_states = 2**n_bits
    n_feat = V.shape[0]
    n_batch = Y.shape[0]

    # Symmetric part of M; only this part contributes to x^T M x.
    Q = 0.5 * (M + M.T)

    expected_shapes = (
        (V, (n_feat, n_bits)),
        (Q, (n_bits, n_bits)),
        (a, (n_feat, 1)),
        (X, (n_states, n_bits)),
        (Y, (n_batch, n_feat)),
    )
    for array, shape in expected_shapes:
        assert array.shape == shape

    # Quadratic term x^T Q x, one value per state: (2**t,)
    quadratic = np.einsum('ki,ij,kj -> k', X, Q, X)
    # Linear term y^T (V x + a): built as (2**t, b), then flipped to (b, 2**t)
    linear = ((V @ X.T + a).T @ Y.T).T
    logits = quadratic + linear

    assert logits.shape == (n_batch, n_states)
    return logits

In [None]:
# Problem sizes shared by the NumPy and PyTorch cells.
t = 3       # number of binary components of the discrete variable
d = 4       # dimension of the continuous variable
b = 6       # batch size (for training later on)
n = 1 << t  # number of discrete states, i.e. 2**t

In [None]:
# Sanity check of the problem sizes: bits, continuous dim, states, batch.
print(t, d, n, b)

In [None]:
# Random float32 model parameters, uniform in [0, 1).
# M need not be symmetric: p_x_given_y symmetrises it.
M = np.random.random((t, t)).astype(np.float32)
V = np.random.random((d, t)).astype(np.float32)
a = np.random.random((d, 1)).astype(np.float32)
# Random batch of b continuous inputs, and the table of all binary states.
Y = np.random.random((b, d)).astype(np.float32)
X = per(t)

In [None]:
# Show the state table: one row per discrete state, one column per bit.
print(X)

In [None]:
# Reference logits from the NumPy implementation; the shape should be (b, n).
logits_np = p_x_given_y(M, X, Y, V, a)
print(logits_np.shape)

In [None]:
# Display the NumPy reference logits.
logits_np

# Torch

In [None]:
import torch

In [None]:
class LHBarlowTwins(torch.nn.Module):
    """Barlow-Twins-style loss on logits over all binary states.

    For an input batch ``Y`` of shape ``(b, d_d)`` the module scores each of
    the ``2**d_t`` binary vectors ``x`` with
    ``x^T Q x + y^T (V x + a)``, where ``Q = (M + M^T) / 2``. ``forward``
    batch-normalises the scores of two batches, forms their
    cross-correlation matrix and penalises its distance from the identity.

    ``M``, ``V`` and ``a`` are registered as ``torch.nn.Parameter`` so they
    show up in ``parameters()`` / ``state_dict()`` and follow ``.to(...)``.
    The state table ``X`` is a non-persistent buffer: it follows the module
    across devices but is not stored in checkpoints, because it is fully
    determined by ``d_t``.

    Args:
        d_t: Number of binary components of the discrete variable.
        d_d: Dimension of the continuous input.
        loss_param_scale: Factor applied to both the on-diagonal and the
            off-diagonal loss terms.
        loss_param_lmbda: Extra weight applied to the off-diagonal term.
    """

    def __init__(self, d_t, d_d, loss_param_scale=1., loss_param_lmbda=1.):
        super().__init__()
        self.d_t = d_t
        self.d_d = d_d
        self.d_n = 2**d_t

        # Learnable parameters, initialised uniformly in [0, 1).
        self.M = torch.nn.Parameter(torch.rand((d_t, d_t)))
        self.V = torch.nn.Parameter(torch.rand((d_d, d_t)))
        self.a = torch.nn.Parameter(torch.rand((d_d, 1)))

        # Constant table of all binary states, shape (2**d_t, d_t).
        self.register_buffer('X', self._binary_states(d_t), persistent=False)

        # affine = False -> no learnable parameters
        self.bn = torch.nn.BatchNorm1d(self.d_n, affine=False)

        self.loss_param_scale = loss_param_scale
        self.loss_param_lmbda = loss_param_lmbda

    @staticmethod
    def _binary_states(d_t):
        """Return the float32 ``(2**d_t, d_t)`` table of all binary vectors.

        Row ``i`` is the ``d_t``-bit binary expansion of ``i``, most
        significant bit first. Built here so the module does not depend on
        helpers defined outside the class.
        """
        codes = torch.arange(2**d_t, dtype=torch.long).unsqueeze(1)
        masks = torch.tensor(
            [1 << k for k in range(d_t - 1, -1, -1)], dtype=torch.long)
        return (codes & masks).ne(0).to(torch.float32)

    def p_x_given_y(self, Y):
        """Return the ``(b, 2**d_t)`` logits of every state for each row of ``Y``.

        Args:
            Y: Tensor of shape ``(b, d_d)``.
        """
        Q = 0.5 * (self.M + self.M.T)
        # Quadratic term x^T Q x, one value per state: (2**d_t,)
        E1 = torch.einsum('ki,ij,kj -> k', self.X, Q, self.X)
        # Linear term: (V x + a) per state, then contracted with each y.
        E2 = (self.V.matmul(self.X.T) + self.a).T
        E3 = torch.matmul(E2, Y.T)
        logits = E1 + E3.T
        return logits

    def forward(self, y1, y2):
        """Return the scalar loss for two batches of shape ``(b, d_d)``."""
        z1 = self.p_x_given_y(y1)
        z2 = self.p_x_given_y(y2)

        # empirical cross-correlation matrix of the batch-normalised logits
        c = self.bn(z1).T @ self.bn(z2)
        c = c / y1.shape[0]

        loss = self.loss(c)
        return loss

    def from_numpy(self, M, V, a):
        """Replace the parameters with the given arrays.

        The values are copied and cast to the dtype and device of the
        module's state table, so default float64 NumPy arrays are accepted.
        If the arrays imply a different number of binary components than
        the module was built with, the state table and the batch-norm layer
        are rebuilt to match.

        Args:
            M: Array of shape ``(d_t, d_t)``.
            V: Array of shape ``(d_d, d_t)``.
            a: Array of shape ``(d_d, 1)``.

        Raises:
            AssertionError: If the shapes are not mutually consistent.
        """
        d_t = M.shape[0]
        d_d = V.shape[0]
        d_n = 2**d_t

        assert M.shape == (d_t, d_t)
        assert V.shape == (d_d, d_t)
        assert a.shape == (d_d, 1)

        dtype, device = self.X.dtype, self.X.device

        def as_param(values):
            # Copy so later changes to the caller's array do not leak in.
            tensor = torch.as_tensor(values, dtype=dtype, device=device)
            return torch.nn.Parameter(tensor.detach().clone())

        self.M = as_param(M)
        self.V = as_param(V)
        self.a = as_param(a)

        if d_t != self.d_t:
            # X and bn are sized by 2**d_t; keep them in sync with M and V.
            self.X = self._binary_states(d_t).to(dtype=dtype, device=device)
            self.bn = torch.nn.BatchNorm1d(d_n, affine=False).to(
                dtype=dtype, device=device)
            self.bn.train(self.training)

        self.d_t = d_t
        self.d_d = d_d
        self.d_n = d_n

    def off_diagonal(self, x):
        """Return the off-diagonal elements of a square matrix, flattened."""
        n, m = x.shape
        assert n == m
        # Dropping the last element and reshaping to (n - 1, n + 1) puts
        # every diagonal entry in column 0, which is then sliced away.
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

    def loss(self, c):
        """Return the scalar loss for a cross-correlation matrix ``c``.

        The result is ``scale * sum_i (c_ii - 1)^2
        + lmbda * scale * sum_{i != j} c_ij^2``. Out-of-place operations are
        used so that ``c`` is left unmodified.
        """
        on_diag = (torch.diagonal(c) - 1).pow(2).sum().mul(
            self.loss_param_scale)
        off_diag = self.off_diagonal(c).pow(2).sum().mul(
            self.loss_param_scale)

        loss = on_diag + self.loss_param_lmbda * off_diag
        return loss

In [None]:
# Build the PyTorch model, then overwrite its random initial parameters with
# the NumPy ones so both implementations use the same M, V and a.
model = LHBarlowTwins(t, d)
model.from_numpy(M, V, a)

In [None]:
# Logits from the PyTorch model for the same batch Y used in the NumPy cell.
logits_pt = model.p_x_given_y(torch.Tensor(Y))

In [None]:
# Two random (b, d) batches, used below as the y1 / y2 inputs of the model.
y1 = torch.rand((b, d))
y2 = torch.rand((b, d))

In [None]:
# Call the module itself rather than .forward() so that nn.Module.__call__
# (and any registered hooks) runs; the cell displays the scalar loss.
model(y1, y2)