In [1]:
import sys
import torch
import numpy as np
from numpy.random import random
from torch import nn
from torch import digamma
from torch.distributions import MultivariateNormal as MVN
from torch.distributions import Bernoulli as Bern
from torch.distributions import Poisson as Pois
from torch.distributions import Categorical as Categorical

In [None]:
class UncollapsedGibbsIBP(nn.Module):
    """
    Uncollapsed Gibbs sampler for a two-view Indian Buffet Process (IBP) model.

    A binary feature-ownership matrix Z (N x K) explains two observations:
      * F (N x D), real-valued:  F = Z A + Gaussian noise with std sigma_n,
        where each row A_k ~ N(0, sigma_a^2 I).
      * X (N x T), binary noisy-OR:
            p(x_nt = 0 | Z, Y) = (1 - epsilon) * (1 - lambd)^{(Z Y)_nt}
        where Y_kt ~ Bern(phi).

    The number of features K changes during sampling: new features are proposed
    per row in ``sample_k_new`` and unused ones are dropped in
    ``remove_allzeros_ZAY``. ``max_K`` is stored but no sampling step uses it.
    """

    def __init__(self, alpha, K, max_K, sigma_a, sigma_n, epsilon, lambd, phi):
        super(UncollapsedGibbsIBP, self).__init__()

        # Constant hyperparameters, stored as tensors (requires_grad=False).
        self.alpha = torch.tensor(alpha)      # IBP concentration
        self.K = torch.tensor(K)              # initial number of features
        self.max_K = torch.tensor(max_K)      # stored only; not used by the sampler
        self.sigma_a = torch.tensor(sigma_a)  # prior std of A
        self.sigma_n = torch.tensor(sigma_n)  # observation-noise std of F
        self.epsilon = torch.tensor(epsilon)  # leak probability of X pixels
        self.lambd = torch.tensor(lambd)      # noisy-OR efficacy of one active feature
        self.phi = torch.tensor(phi)          # prior p(Y_kt = 1)

    def init_A(self, K, D):
        '''
        Sample A (K x D) from its prior: every row A_k ~ N(0, sigma_a^2 I).

        K may be an int or a 0-d integer tensor.
        '''
        prior = MVN(torch.zeros(D), self.sigma_a.pow(2) * torch.eye(D))
        return prior.sample((int(K),))

    def init_Y(self, n_filters, n_pixels):
        '''
        Sample Y (n_filters x n_pixels) from its prior Y_kt ~ Bern(phi).

        phi is the same prior that resample_Y and X_loglik_given_k_new use.
        '''
        return Bern(self.phi).sample((int(n_filters), int(n_pixels)))

    def left_order_form(self, Z):
        '''
        Reorder the columns of binary Z into left-ordered form.

        Each column is read as a binary number with row 0 as the most
        significant bit. Columns are sorted by that value in decreasing order.
        NOTE(review): the weights are float64 powers of two, so this assumes
        N stays well below ~1000 rows.
        '''
        Z_numpy = Z.clone().numpy()
        twos = np.ones(Z_numpy.shape[0]) * 2.0
        twos[0] = 1.0
        powers = np.cumprod(twos)[::-1]  # [2^(N-1), ..., 2, 1]
        values = np.dot(powers, Z_numpy)
        idx = values.argsort()[::-1]
        return torch.from_numpy(np.take(Z_numpy, idx, axis=1))

    def init_Z(self, N=20):
        '''
        Sample Z (N x K) from the IBP prior, truncated to K = self.K columns.

        Customer i (1-indexed) takes each previously sampled dish k with
        probability m_k / i, where m_k is how many earlier customers took it.
        The customer then takes Poisson(alpha / i) new dishes. New dishes are
        clipped so the total never exceeds K. The result is returned in
        left-ordered form.
        '''
        K = int(self.K.item())
        Z = torch.zeros(N, K)
        n_dishes = 0
        for i in range(N):
            # Row i is still all zeros here, so assigning the draw is equivalent
            # to switching on only the selected dishes.
            popularity = Z[:, :n_dishes].sum(dim=0) / (i + 1.0)
            Z[i, :n_dishes] = (torch.rand(n_dishes) < popularity).float()

            new_dishes = int(Pois(self.alpha / (i + 1)).sample().item())
            new_dishes = min(new_dishes, K - n_dishes)  # respect the K-column cap
            Z[i, n_dishes:n_dishes + new_dishes] = 1.0
            n_dishes += new_dishes

        return self.left_order_form(Z)

    def remove_allzeros_ZAY(self, Z, A, Y):
        """
        Drop the columns of Z that no row uses, and the matching rows of A and Y.
        """
        to_keep = Z.sum(dim=0) > 0
        return Z[:, to_keep], A[to_keep, :], Y[to_keep, :]

    def F_loglik_given_ZA(self, F, Z, A):
        '''
        Gaussian log-likelihood of F given Z and A.

        log p(F|Z,A) = -(N D / 2) log(2 pi sigma_n^2)
                       - ||F - Z A||_F^2 / (2 sigma_n^2)

        Returns a tensor of shape (1,).
        '''
        N, D = F.size()
        sig_n2 = self.sigma_n.pow(2)
        residual = F - Z @ A
        log_norm = -(N * D / 2.0) * torch.log(2 * np.pi * sig_n2)
        # trace(R^T R) is simply the sum of squared entries of R.
        log_kernel = -residual.pow(2).sum() / (2 * sig_n2)
        return (log_norm + log_kernel).reshape(1)

    def X_loglik_given_ZY(self, X, Z, Y):
        '''
        Noisy-OR log-likelihood of binary X given Z and Y.

        With E = Z Y (number of active features covering each pixel):
            p(x_nt = 0 | Z, Y) = (1 - lambd)^{E_nt} (1 - epsilon)
            p(x_nt = 1 | Z, Y) = 1 - (1 - lambd)^{E_nt} (1 - epsilon)

        Returns a 0-d tensor, which is the sum over all N x T entries.
        '''
        off_prob = (1 - self.lambd) ** (Z @ Y) * (1 - self.epsilon)
        return torch.sum(X * torch.log(1 - off_prob) + (1 - X) * torch.log(off_prob))

    def resample_Z_ik(self, Z, F, X, A, Y, i, k):
        '''
        Gibbs-sample z_ik given every other entry of Z, plus F, X, A and Y.

        Prior: the IBP is exchangeable, so row i is treated as the last customer.
            p(z_ik = 1 | Z_-ik) = m / N
        where m is the number of rows other than i that own feature k.
        A denominator of N - 1 would make z_ik = 0 impossible whenever every
        other row owns the feature.

        Only row i of Z differs between the two hypotheses. The likelihood terms
        of all other rows are identical and cancel on normalization, so only
        row i of F and X is scored.

        Returns a 0-d tensor equal to 0. or 1.
        '''
        N = X.size()[0]
        m = Z[:, k].sum() - Z[i, k]  # called m_-ik in the paper
        p_on = m / N

        F_i = F[i:i + 1]
        X_i = X[i:i + 1]
        log_scores = torch.zeros(2)
        for value in (0, 1):
            Z_i = Z[i:i + 1].clone()
            Z_i[0, k] = value
            log_prior = torch.log(p_on) if value == 1 else torch.log(1 - p_on)
            log_scores[value] = (log_prior
                                 + self.F_loglik_given_ZA(F_i, Z_i, A).sum()
                                 + self.X_loglik_given_ZY(X_i, Z_i, Y))

        probs = self.renormalize_log_probs(log_scores)
        return Bern(probs[1]).sample()

    def renormalize_log_probs(self, log_probs):
        '''
        Turn unnormalized log-probabilities into probabilities that sum to 1.

        The max is subtracted before exponentiating to avoid underflow.
        '''
        log_probs = log_probs - log_probs.max()
        likelihoods = log_probs.exp()
        return likelihoods / likelihoods.sum()

    def F_loglik_given_k_new(self, cur_F_minus_ZA, Z, D, i, j):
        '''
        Log-likelihood ratio of F for "row i owns j new features" against j = 0,
        with the new rows of A integrated out.

        cur_F_minus_ZA: F - Z_old A_old, computed before any new columns were
            appended. Z changes in the caller's loop, so this is passed in.
        Z: N x (K + j). The last j columns are the candidate new features; the
            caller sets only row i to 1 in those columns.
        D: F.size()[1]
        i: the row whose new features are being sampled. It is unused here and
            kept for interface compatibility.
        j: number of new features being scored.

        Returns 0.0 when j == 0; otherwise a tensor of shape (1,).
        '''
        if j == 0:
            return 0.0

        sig_n = self.sigma_n
        sig_a = self.sigma_a
        Z_new = Z[:, -j:]
        w = torch.ones(j, j) + (sig_n / sig_a).pow(2) * torch.eye(j)
        _, log_det = np.linalg.slogdet(w.numpy())
        log_det = torch.tensor([log_det], dtype=torch.float32)

        # Log of the normalizer ratio.
        first_term = j * D * (sig_n / sig_a).log() - (D / 2) * log_det
        # Quadratic-form gain from explaining the residual with new features.
        second_term = 0.5 * torch.trace(
            cur_F_minus_ZA.transpose(0, 1) @ Z_new @ w.inverse()
            @ Z_new.transpose(0, 1) @ cur_F_minus_ZA) / sig_n.pow(2)
        return first_term + second_term

    def X_loglik_given_k_new(self, Z, Y, X, i, orig_k, k_new):
        """
        Log-likelihood of X[i, :] when row i gains ``k_new`` new features,
        with the new features' Y rows marginalized under Y_kt ~ Bern(phi).

        Parameters:
        - Z: Tensor (N, >= orig_k), binary feature ownership. Only its first
          ``orig_k`` columns are used.
        - Y: Tensor (>= orig_k, T), binary feature-to-pixel activations.
        - X: Tensor (N, T), binary observed images.
        - i: int, the row of X being scored.
        - orig_k: int, number of existing features.
        - k_new: int, proposed number of new features.

        The model parameters lambd, epsilon and phi are read from self.

        Returns:
        - 0-d tensor: the log-likelihood summed over the T pixels of row i.
        """
        lamb = self.lambd
        ep = self.epsilon
        p = self.phi

        # Number of existing active features covering each pixel of row i.
        e = torch.matmul(Z[i, 0:orig_k], Y[0:orig_k, :])

        # Probability that a pixel stays off: leak term, existing features,
        # and each new feature, which is on with prob. p and fires with prob. lamb.
        off_prob = (1 - ep) * (1 - lamb) ** e * ((1 - lamb * p) ** k_new)

        on_mask = X[i, :] == 1
        lhood_XiT = torch.sum(torch.log(1 - off_prob[on_mask]))
        lhood_XiT += torch.sum(torch.log(off_prob[~on_mask]))
        return lhood_XiT

    def sample_k_new(self, Z, F, X, A, Y, i, truncation=10):
        '''
        Sample how many new features (k_new) row i should own.

        prior:      k_new ~ Poisson(alpha / N)
        likelihood: F with the new A rows integrated out, and X with the new
                    Y rows marginalized
        The posterior is evaluated for k_new in 0 .. truncation-1, normalized,
        and sampled. The infinite support is truncated at ``truncation``.

        Returns a 0-d integer tensor. Z is not modified.
        '''
        N, K = Z.size()
        D = X.size()[1]

        p_k_new = Pois((self.alpha / N).reshape(1))
        cur_F_minus_ZA = F - Z @ A

        log_posterior = torch.zeros(truncation)
        Z_ext = Z  # grows by one "row i only" column per candidate j
        for j in range(truncation):
            log_prior = p_k_new.log_prob(torch.tensor(j))
            F_ll = self.F_loglik_given_k_new(cur_F_minus_ZA, Z_ext, D, i, j)
            X_ll = self.X_loglik_given_k_new(Z, Y, X, i, K, j)
            log_posterior[j] = (log_prior + F_ll + X_ll).sum()

            # Append one more new feature owned only by row i, ready for j + 1.
            new_col = torch.zeros(N, 1)
            new_col[i] = 1
            Z_ext = torch.cat((Z_ext, new_col), 1)

        sample_probs = self.renormalize_log_probs(log_posterior)
        return Categorical(sample_probs).sample()

    def resample_Z(self, Z, F, X, A, Y):
        '''
        One Gibbs sweep over the rows of Z.

        For each row i:
          - resample every existing z_ik with p(z_ik | Z_-ik, F, X, A, Y);
          - sample k_new ~ p(k_new | F, X, Z, A, Y);
          - if k_new > 0: append k_new columns to Z that are owned by row i
            only, draw their A rows from p(A_new | F, Z, A_old), and append
            zero Y rows that are then Gibbs-resampled.

        Z is updated in place for the existing columns. Returns the possibly
        enlarged (Z, A, Y).
        '''
        N = F.size()[0]
        K = A.size()[0]  # features that existed at the start of the sweep

        for i in range(N):
            for k in range(K):
                Z[i, k] = self.resample_Z_ik(Z, F, X, A, Y, i, k)

            k_new = int(self.sample_k_new(Z, F, X, A, Y, i))
            if k_new > 0:
                new_cols = torch.zeros(N, k_new)
                new_cols[i] = 1
                Z = torch.cat((Z, new_cols), 1)

                A = torch.cat((A, self.A_new(F, k_new, Z, A)), dim=0)

                Y = torch.cat((Y, self.Y_new(k_new, Y.size()[1])), dim=0)
                # Resample Y for every feature added during this sweep.
                Y = self.resample_Y(Z, X, Y, start_idx=K)

        return Z, A, Y

    def resample_A(self, F, Z):
        '''
        Sample A from its Gaussian posterior p(A | F, Z):
            P   = Z^T Z + (sigma_n^2 / sigma_a^2) I
            mu  = P^{-1} Z^T F
            cov = sigma_n^2 P^{-1}   (shared by every column of A)

        Returns A of shape (K, D). When Z has no columns the result is (0, D).
        '''
        N, D = F.size()
        K = Z.size()[1]
        if K == 0:
            return torch.zeros(0, D)

        sig_n = self.sigma_n
        sig_a = self.sigma_a
        ZT = Z.transpose(0, 1)
        precision_inv = (ZT @ Z + (sig_n / sig_a).pow(2) * torch.eye(K)).inverse()
        mu = precision_inv @ ZT @ F  # K x D
        cov = sig_n.pow(2) * precision_inv
        cov = 0.5 * (cov + cov.transpose(0, 1))  # remove round-off asymmetry

        # One batched draw: D independent K-dimensional Gaussians.
        return MVN(mu.transpose(0, 1), cov).sample().transpose(0, 1).contiguous()

    def A_new(self, F, k_new, Z, A):
        '''
        Sample the A rows of k_new newly added features:
            p(A_new | F, Z, A_old) propto p(F | Z, A_old, A_new) p(A_new)
                                   = N(mu, cov) for each column d, with
            W   = ones(k_new, k_new) + (sigma_n^2 / sigma_a^2) I
            mu  = W^{-1} Z_new^T (F - Z_old A_old)
            cov = sigma_n^2 W^{-1}
        The form of W assumes the new columns are owned by a single row.

        Z must already include the k_new new columns as its last columns.
        Returns a tensor of shape (k_new, D).
        '''
        N, D = F.size()
        K = Z.size()[1]
        assert K == A.size()[0] + k_new
        if k_new == 0:
            return torch.zeros(0, D)

        sig_n = self.sigma_n
        sig_a = self.sigma_a
        Z_new = Z[:, -k_new:]
        Z_old = Z[:, :-k_new]
        w_inv = (torch.ones(k_new, k_new) + (sig_n / sig_a).pow(2) * torch.eye(k_new)).inverse()

        mu = w_inv @ Z_new.transpose(0, 1) @ (F - Z_old @ A)  # k_new x D
        cov = sig_n.pow(2) * w_inv
        cov = 0.5 * (cov + cov.transpose(0, 1))  # remove round-off asymmetry

        return MVN(mu.transpose(0, 1), cov).sample().transpose(0, 1).contiguous()

    def resample_Y(self, Z, X, Y, start_idx=0):
        """
        Gibbs-resample Y[k, t] for every k >= start_idx and every pixel t.

        Each entry is scored under both values using prior Bern(phi) and the
        noisy-OR likelihood of column t of X. It is then sampled from the
        normalized posterior. The posterior is normalized in log space: the
        likelihood sums over all N rows, so the raw exponentials underflow to
        0 for moderate N.

        Y is updated in place and also returned.
        """
        K = Z.size()[1]
        N, T = X.size()
        ep = self.epsilon
        lamb = self.lambd
        p = self.phi

        log_prior = torch.stack((torch.log(1 - p), torch.log(p)))  # [a = 0, a = 1]

        for t in range(T):
            x_t = X[:, t]
            for k in range(start_idx, K):
                log_post = torch.zeros(2)
                for a in (0, 1):
                    Y[k, t] = a
                    off_prob = (1 - lamb) ** torch.matmul(Z, Y[:, t]) * (1 - ep)
                    log_lik = torch.sum(x_t * torch.log(1 - off_prob)
                                        + (1 - x_t) * torch.log(off_prob))
                    log_post[a] = log_prior[a] + log_lik

                probs = self.renormalize_log_probs(log_post)
                Y[k, t] = Bern(probs[1]).sample()

        return Y

    def Y_new(self, k_new, D):
        '''Placeholder Y rows (all zeros) for k_new new features, to be resampled.'''
        return torch.zeros(k_new, D)

    def gibbs(self, F, X, iters):
        '''
        Run the uncollapsed Gibbs sampler for ``iters`` sweeps.

        F: (N, D) real-valued observations; X: (N, T) binary observations.
        Each sweep resamples A, then Y, then Z (which may add features), and
        finally drops unused features.

        Returns three lists (As, Zs, Ys) holding one numpy copy per sweep.
        The number of features may differ between sweeps.
        '''
        n_obs_X, n_pixels = X.size()
        n_obs_F, n_features = F.size()
        assert n_obs_X == n_obs_F, "Number of observations in X and F must match"
        n_obs = n_obs_X

        K = int(self.K.item())

        Z = self.init_Z(n_obs)
        A = self.init_A(K, n_features)
        Y = self.init_Y(K, n_pixels)

        As, Zs, Ys = [], [], []

        for it in range(iters):
            print('iteration:', it, end='\r')
            # Gibbs resampling
            A = self.resample_A(F, Z)
            Y = self.resample_Y(Z, X, Y)
            Z, A, Y = self.resample_Z(Z, F, X, A, Y)

            # cleanup
            Z, A, Y = self.remove_allzeros_ZAY(Z, A, Y)

            # save the samples to the chain
            As.append(A.clone().numpy())
            Zs.append(Z.clone().numpy())
            Ys.append(Y.clone().numpy())

        return As, Zs, Ys


In [None]:
# Every non-empty combination of the three latent features (one row of Z each).
true_Z_options = [
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 1, 0],
    [1, 0, 1],
    [0, 1, 1],
    [1, 1, 1],
]

# Force direction contributed by each latent feature (rows of A).
true_A = torch.tensor(
    [
        [ 0.2,  0.2],
        [-0.2,  0.2],
        [   0, -0.2],
    ],
    dtype=torch.float32,
)

# Pixels switched on by each latent feature (rows of Y).
true_Y = torch.tensor(
    [
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1],
    ],
    dtype=torch.float32,
)

# Generate the dataset: the full list of combinations, repeated 5 times in order.
Z_latent = [combo for _ in range(5) for combo in true_Z_options]
Z_latent_float = torch.tensor(Z_latent).float()

# Noise-free observations: forces F = Z A and pixel images X = Z Y.
F_dataset = Z_latent_float @ true_A
X_dataset = Z_latent_float @ true_Y


In [17]:
def add_noise_to_obs(X, F, F_noise_std = 0.01, lambd=0.98, epsilon=0.02):
    """
    Corrupt clean observations without modifying the inputs.

    - F: i.i.d. Gaussian noise with std ``F_noise_std`` is added to every entry.
    - X: a pixel equal to 1 stays on with probability ``lambd``. Any other
      pixel turns on with probability ``epsilon``.

    Returns (X_noisy, F_noisy). Both are new float tensors with the input shapes.
    The caller's F is not touched, so re-running a cell on the same clean data
    gives fresh noise instead of accumulating it.
    """
    F_noisy = F + torch.randn(F.size()) * F_noise_std

    u = torch.rand(X.size())
    is_on = X == 1
    X_noisy = ((is_on & (u < lambd)) | (~is_on & (u < epsilon))).float()

    return X_noisy, F_noisy

In [27]:
# Corrupt the data. NOTE: this rebinds X_dataset / F_dataset, so re-running
# this cell alone adds noise on top of already-noisy data.
noisy_X, noisy_F = add_noise_to_obs(X_dataset, F_dataset)
X_dataset, F_dataset = noisy_X, noisy_F

# First five corrupted pixel observations.
print(X_dataset[:5])

tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [1., 1., 0., 0.],
        [1., 0., 0., 1.]])


In [31]:
# Seed torch's RNG so the Gibbs chain is reproducible for a given dataset.
torch.manual_seed(0)

inf = UncollapsedGibbsIBP(
    alpha=0.05,
    K=1,
    max_K=4,
    sigma_a=0.2,
    sigma_n=0.1,
    epsilon=0.01,
    lambd=0.99,
    phi=0.25,
)

As, Zs, Ys = inf.gibbs(F_dataset, X_dataset, 100)

iteration: 99

In [7]:
def extract_mean_from_samples(As, Zs, Ys, n=10):
    """
    Estimate A, Z and Y as the element-wise mean of the last ``n`` Gibbs samples.

    Each returned array is rounded to 2 decimals.

    Raises ValueError if ``n < 1``: ``samples[-0:]`` would otherwise silently
    select the whole chain. Also raises ValueError if the selected samples do
    not share one shape, which happens when the sampler added or removed
    features within the averaging window.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")

    def _tail_mean(samples, label):
        # Average the last n samples of a single chain, checking they align.
        tail = [np.asarray(s) for s in samples[-n:]]
        if not tail:
            raise ValueError(f"no {label} samples to average")
        shapes = {s.shape for s in tail}
        if len(shapes) != 1:
            raise ValueError(
                f"{label} samples have differing shapes {sorted(shapes)}; "
                "the number of features changed within the last n iterations"
            )
        return np.round(np.mean(np.array(tail), axis=0), 2)

    return _tail_mean(As, "A"), _tail_mean(Zs, "Z"), _tail_mean(Ys, "Y")

def compare_distance(reference_matrix, inferred_matrix):
    """
    Match rows of ``inferred_matrix`` to rows of ``reference_matrix``.

    The pairing minimizes the total Euclidean distance between matched rows
    (Hungarian algorithm via scipy's ``linear_sum_assignment``).

    Parameters
    ----------
    reference_matrix : array of shape (n, m)
    inferred_matrix : array of shape (k, m) with k >= n. It may have extra
        rows, e.g. surplus inferred features. Unmatched rows are left out.

    Returns
    -------
    selection : ndarray of shape (n, k)
        ``selection[i, c] == 1`` when inferred row c is matched to reference
        row i, so ``selection @ inferred_matrix`` lines the inferred rows up
        with the reference rows. When k == n this is a permutation matrix.

    Raises
    ------
    ValueError
        If the column counts differ or the inferred matrix has fewer rows.
    """
    from scipy.optimize import linear_sum_assignment

    reference_matrix = np.asarray(reference_matrix)
    inferred_matrix = np.asarray(inferred_matrix)
    n, m = reference_matrix.shape
    if inferred_matrix.ndim != 2 or inferred_matrix.shape[1] != m:
        raise ValueError(
            f"inferred_matrix must have {m} columns, got shape {inferred_matrix.shape}")
    k = inferred_matrix.shape[0]
    if k < n:
        raise ValueError(
            f"inferred_matrix has {k} rows but reference_matrix has {n}; cannot match every reference row")

    # distance_matrix[i, c] = ||reference[i] - inferred[c]||
    distance_matrix = np.linalg.norm(
        reference_matrix[:, None, :] - inferred_matrix[None, :, :], axis=-1)

    row_ind, col_ind = linear_sum_assignment(distance_matrix)

    selection = np.zeros((n, k))
    selection[row_ind, col_ind] = 1
    return selection


In [32]:
A, Z, Y = extract_mean_from_samples(As, Zs, Ys, n=10)

# reorder[i, c] == 1 when inferred feature c was matched to true feature i,
# so `reorder @ M` reorders the feature rows of A and Y.
reorder = compare_distance(true_A.numpy(), A)

print("True A:")
print(true_A.numpy())
print("Inferred A:")
print(np.round(reorder @ A,2))

print("\nTrue Y:")
print(true_Y.numpy())
print("Inferred Y:")
print(np.round(reorder @ Y,2))

print("\nTrue Z:")
print(np.array(Z_latent))
print("Inferred Z:")
# In Z the features are columns, so apply the same mapping from the right
# with its transpose. `Z @ reorder` would apply the inverse permutation.
print(np.round(Z @ reorder.T,0))

True A:
[[ 0.2  0.2]
 [-0.2  0.2]
 [ 0.  -0.2]]
Inferred A:
[[ 0.21  0.19]
 [-0.2   0.18]
 [-0.01 -0.17]]

True Y:
[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 0. 1.]]
Inferred Y:
[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 0. 1.]]

True Z:
[[1 0 0]
 [0 1 0]
 [0 0 1]
 [1 1 0]
 [1 0 1]
 [0 1 1]
 [1 1 1]
 [1 0 0]
 [0 1 0]
 [0 0 1]
 [1 1 0]
 [1 0 1]
 [0 1 1]
 [1 1 1]
 [1 0 0]
 [0 1 0]
 [0 0 1]
 [1 1 0]
 [1 0 1]
 [0 1 1]
 [1 1 1]
 [1 0 0]
 [0 1 0]
 [0 0 1]
 [1 1 0]
 [1 0 1]
 [0 1 1]
 [1 1 1]
 [1 0 0]
 [0 1 0]
 [0 0 1]
 [1 1 0]
 [1 0 1]
 [0 1 1]
 [1 1 1]]
Inferred Z:
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]
 [1. 1. 0.]
 [1. 0. 1.]
 [0. 1. 1.]
 [1. 1. 1.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]
 [1. 1. 0.]
 [1. 0. 1.]
 [0. 1. 1.]
 [1. 1. 1.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]
 [1. 1. 0.]
 [1. 0. 1.]
 [0. 1. 1.]
 [1. 1. 1.]
 [1. 1. 0.]
 [0. 1. 0.]
 [0. 0. 1.]
 [1. 1. 0.]
 [1. 0. 1.]
 [0. 1. 1.]
 [1. 1. 1.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]
 [1. 1. 0.]
 [1. 0. 1.]
 [0. 1. 1.]
 [1. 1. 1.]]
