In [4]:
import torch

In [5]:
"""### Data"""

import copy

def compute_total_polynomial_terms(poly_degree, latent_dim):
    """
    Compute the total number of possible terms for polynomials of degree poly_degree.
    """
    count=0
    for degree in range(poly_degree+1):
        count+= pow(latent_dim, degree)
    return count


def compute_kronecker_product(degree, latent):
    """
    Compute the kronecker product of the latent vector with itself for a given degree.
    """
    if degree ==0:
        out= torch.tensor([1.])
    else:
        out=copy.deepcopy(latent)
        for idx in range(1, degree):
            out= torch.kron(out, latent)
    #print(out.shape)
    return out

def compute_decoder_polynomial(poly_degree,latent):
    """
    Compute all the kronecker products of the latent vector with itself up to degree poly_degree.
    """
    out=[]
    for degree in range(poly_degree+1):
        kroneck = compute_kronecker_product(degree, latent)
        out.append(kroneck)
    out= torch.concatenate(out)
    out= torch.reshape(out, (1,out.shape[0]))
    return out

def compute_decoder_polynomial_function(poly_degree):
    """
    Compute the function that computes the polynomial terms of a given degree.
    """
    def poly_decoder(latent):
        out=[]
        for degree in range(poly_degree+1):
    #         print('Computing polynomial term of degree ', degree)
            kroneck = compute_kronecker_product(degree, latent)
            #print(kroneck.shape)
            out.append(kroneck)
        out= torch.concatenate(out)
        out= torch.reshape(out, (1,out.shape[0]))
        return out
    return poly_decoder

class Poly_dec(nn.Module):
    def __init__(self, poly_degree, x_dim, lat_dim):
        super().__init__()
        self.deg = poly_degree
        self.x_dim = x_dim
        self.lat_dim = lat_dim

        self.poly_size = compute_total_polynomial_terms(poly_degree, lat_dim)
        #self.coff_matrix = np.random.multivariate_normal(np.zeros(self.poly_size), np.eye(self.poly_size), size=20).T
        self.coff_matrix = self.full_rank_coef_matrix()

    def full_rank_coef_matrix(self):
        """
        Generate a full rank coefficient matrix for the polynomial decoder.
        """
        M = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(self.poly_size), torch.eye(self.poly_size)).sample((self.x_dim,)).t()
        if torch.linalg.matrix_rank(M) == self.poly_size:
            return M
        else:
            return self.full_rank_coef_matrix()

    def one_sample_forw(self, sample):
        out = []
        for degree in range(self.deg+1):
            kroneck = compute_kronecker_product(degree, sample)
            out.append(kroneck)
        out= torch.concatenate(out)
        out= torch.reshape(out, (1,out.shape[0]))
        return out

    def forward(self, latent):
        if latent.shape[0] == 1:
            x = self.one_sample_forw(latent)
        else:
            x = []
            for idx in range(latent.shape[0]):
                x.append( self.one_sample_forw(latent[idx, :]))
            x = torch.cat(x, dim=0)
        x1= torch.matmul(x[:, :1+self.lat_dim], self.coff_matrix[:1+self.lat_dim, :])
        x2= torch.matmul(x[:, 1+self.lat_dim:], self.coff_matrix[1+self.lat_dim:, :])
        norm_factor= 0.5 * torch.max(torch.abs(x2)) / torch.max(torch.abs(x1))
        x2 = x2 / norm_factor
        x = (x1+x2)
        return x

class Data:
    def __init__(self,
                 X_dim,
                 latent_dim,
                 n_per_env,
                 n_env,
                 eta,
                 device,
                 target_intervention = False,
                 v_mean = None,
                 target_children = False,
                 degree = 2,
                 n_batch = 10,
                 misspecified = False):

        self.X_dim = X_dim
        self.lat_dim = latent_dim
        self.n_per_env = n_per_env
        self.n_env = n_env
        self.eta = eta
        self.device = device
        self.target_intervention = target_intervention
        self.v_mean = v_mean
        self.n_batch = n_batch
        self.ncond = n_env # + 1
        self.lower_triangular = not target_children
        self.degree = degree

        self.graph = get_dag((latent_dim + 1), connected = True)
        self.B = torch.from_numpy(adjacency(self.graph, l_triangular = self.lower_triangular)).to(torch.float32)
        self.C = torch.linalg.inv(torch.eye(latent_dim+1) - self.B).to(torch.float32)
        self.b = self.B[-1,:-1].unsqueeze(1)
        self.misspecified = misspecified

        self.latent_fn()

        self.cov_eps = get_covariance(latent_dim + 1, normed = True, zerow = False)

        self.means_delta = [np.random.uniform(-3,3,size = latent_dim + 1) for e in range(1,n_env)]
        if True:
            self.means_delta = [m/np.linalg.norm(m) for m in self.means_delta]
        self.covs_delta = [get_covariance(latent_dim + 1, normed = True, zerow = (not target_intervention)) for e in range(1,n_env)] #1 for reference env

        self.aa = np.array([self.covs_delta[e] + self.means_delta[e].reshape(-1,1) @ self.means_delta[e].reshape(1,-1) for e in range(self.n_env - 1)]
                      ).reshape(self.n_env -1, self.lat_dim+1, self.lat_dim +1)

        self.vvt = (self.eta/self.n_env) * np.sum(self.aa, 0)

        self.X,self.Y,self.Z,self.E = self.get_train_data()

        self.X_test,self.Y_test,self.Z_test = self.get_test_data(v_mean = self.v_mean)

        self.loader, self.full_loader = self.get_loader(batch_size=(self.n_per_env*self.n_env)//n_batch)

    def latent_fn(self, seed=None):
        '''
        Nontrained feedfowards to generate X from given Z
        '''
        if seed is not None:
            torch.manual_seed(seed)
        with torch.no_grad():
            self.lat_fn = Poly_dec(poly_degree=self.degree, x_dim=self.X_dim, lat_dim=self.lat_dim)

    def get_train_data(self):
        Z = torch.tensor([])
        Y = torch.tensor([])
        X = torch.tensor([])
        E = torch.tensor([])
        eye = torch.eye(self.n_env)

        for e in range(self.n_env):
            eps_e = torch.distributions.multivariate_normal.MultivariateNormal(
                torch.zeros(self.lat_dim + 1), torch.from_numpy(self.cov_eps).to(torch.float32)
                ).sample((self.n_per_env,))

            delta_e = np.zeros((self.n_per_env,self.lat_dim + 1))
            if e != 0:
                if not self.target_intervention:
                    self.means_delta[e-1][-1] = 0.0
                delta_e = np.random.multivariate_normal(self.means_delta[e-1], self.covs_delta[e-1], self.n_per_env)
            delta_e = torch.from_numpy(delta_e).to(torch.float32)

            observations = (eps_e + delta_e) @ self.C.t()
            Z_e = observations[:, :-1]
            with torch.no_grad():
                X_e = self.lat_fn(Z_e)
            Y_e = observations[:,-1]
            E_e = torch.cat([eye[e,:].unsqueeze(0) for _ in range(self.n_per_env)])
            X = torch.cat((X,X_e), dim = 0)
            Y = torch.cat((Y,Y_e), dim = 0)
            Z = torch.cat((Z,Z_e), dim = 0)
            E = torch.cat((E,E_e), dim = 0)
        return X,Y.unsqueeze(1),Z,E

    def get_test_data(self, v_mean = None):
        if v_mean is not None:
            v_mu = v_mean
        else:
            v_mu = (self.eta/self.n_env) * np.sum(self.means_delta, 0)#np.random.normal(size = (self.lat_dim + 1))
        if not self.target_intervention:
            v_mu[-1] = 0.0
            self.vvt[:, -1], self.vvt[-1, :] = 0., 0.

        self.v_mean = v_mu #/ np.linalg.norm(v_mu)
        if self.misspecified:
            eps_plus_v = torch.distributions.chi2.Chi2(torch.tensor([np.linalg.norm(self.v_mean, 1)], dtype=torch.float32)+0.5).sample((self.n_per_env, self.lat_dim+1)).squeeze(-1)
            observations = (eps_plus_v) @ self.C.t()
        else:
            eps = torch.distributions.multivariate_normal.MultivariateNormal(
              torch.zeros(self.lat_dim + 1), torch.from_numpy(self.cov_eps).to(torch.float32)
              ).sample((self.n_per_env,))

            v = np.random.multivariate_normal(self.v_mean,
                                              self.vvt - self.v_mean.reshape(-1,1)@self.v_mean.reshape(1,-1),
                                              self.n_per_env)
            v = torch.from_numpy(v).to(torch.float32)

            observations = (eps + v) @ self.C.t()
        Z_test = observations[:,:-1]
        Y_test = observations[:,-1]
        X_test = self.lat_fn(Z_test)
        return X_test, Y_test.unsqueeze(1), Z_test

In [68]:
from scipy.special import comb
from torch import nn
from sklearn.preprocessing import PolynomialFeatures

class Poly_dec(nn.Module):
    """ 
    Injective polynomial decoder for the latent space.
    """
    def __init__(self, deg, x_dim, lat_dim, debug=False):
        super().__init__()
        # Properties of data and latent space
        self.deg = deg
        self.x_dim = x_dim
        self.lat_dim = lat_dim
        self.debug = debug

        # Compute the total number of polynomial terms
        self.poly_size = self.num_polynomial_terms_of_degree_p(self.deg) #self.compute_num_polynomial_terms()

        # Check the implicit dimensionality condition for full column-rankedness
        assert self.x_dim >= self.num_polynomial_terms_of_degree_p(self.deg), "The polynomial degree is too high for the latent dimensionality to guarantee an injective polynomial decoder."

        # Generate a full rank coefficient matrix - full rank for injectivity
        self.coef_matrix = self.random_full_column_rank()
        
    
    def num_polynomial_terms_of_degree_p(self, p):
        """
        Compute the number of polynomial terms for a given degree.
        """
        count = 0 

        # Using the combinatorial formula for the number of non-negative integer solutions to the equation x1 + x2 + ... + xk = p
        for r in range(p+1):
            count += comb(r + self.lat_dim - 1, self.lat_dim - 1)

        return int(count)

    def compute_total_num_polynomial_terms(self):
        """
        Compute the total number of possible terms for polynomials of degree deg.
        """
        count = 0
        for p in range(self.deg + 1):
            count += self.num_polynomial_terms_of_degree_p(p)
        return count
        
    def random_full_column_rank(self):
        """
        Generate an n x p matrix (p <= n) with real entries drawn from a 
        normal distribution. With probability 1, it will be full column rank.
        If not, the function regenerates until it is.
        """
        n = self.x_dim
        p = self.poly_size
        # print(f"Generating a full column rank matrix of size {n} x {p}...")

        while True:
            # Generate a random matrix with high probability of being full column rank
            M = torch.randn(n, p)

            # Check if the matrix is full column rank
            rank_M = torch.linalg.matrix_rank(M)
            if rank_M == p:
                return M

    def compute_decoder_polynomial(self, latent):
        """
        Compute all the kronecker products with distinct entries of the latent vector 
        with itself up to degree poly_degree. 
        """
        assert latent.shape[0] == self.lat_dim, "The latent dimensionality of the sample is incorrect."

        out = []
     
        poly = PolynomialFeatures(degree=self.deg, include_bias=True)  # exclude the constant term
        out = poly.fit_transform(latent.reshape(1, -1)).T  # shape = (1, number_of_features)
    
        return torch.tensor(out, dtype=torch.float32)
    
    def forward(self, latent):
        """
        Forward pass of the polynomial decoder.
        """
        latent_poly = self.compute_decoder_polynomial(latent)

        
        # Apply coefficients to the polynomial terms
        X = torch.matmul(self.coef_matrix, latent_poly)

        if self.debug:
            print(f"\nShape of the latent vector: {latent.shape}")
            print(latent)
            print(f"\nShape of the polynomial terms: {latent_poly.shape}")
            print(latent_poly)
            print(f"\nShape of coefficients: {self.coef_matrix.shape}")
            print(self.coef_matrix)
            print(f"\nShape of the output: {X.shape}")
            print(X)
            
        return X


In [73]:
# Test the polynomial decoder
poly_degree = 4
x_dim = 6
lat_dim = 2
sample_latent = torch.randn(lat_dim)
#sample_latent = torch.tensor([2,3])

poly_decoder = Poly_dec(poly_degree, x_dim, lat_dim, debug=True)

# Test the forward pass
X = poly_decoder.forward(sample_latent)

AssertionError: The polynomial degree is too high for the latent dimensionality.

In [55]:
X

tensor([-0.4733, -0.0166,  0.5124,  0.8485,  1.1833,  1.3318, -0.0339, -0.0971,
        -1.3365, -0.9610])