In [133]:
import torch
import matplotlib.pyplot as plt
import torch.distributions as tdist

In [6]:
M = 3  # total number of datasets
N = 1000  # total number of samples
d = 5  # latent variable z dimension
y_dims = torch.tensor([20, 30, 50])  # dimensions of each dataset
x_dims = torch.tensor([4, 5, 6])


In [7]:
def generate_W_L_Phi(y_dims, x_dims, d=5, M=3, minval=-100, maxval=100, max_noise=5):
    all_Ws = []
    all_Ls = []
    all_Phis = []
    
    for i, y_val in enumerate(y_dims):
        # declare new uniform distributions from which to sample ground truth W, L, and Phi for each dataset
        W_uniform = tdist.uniform.Uniform(torch.ones(y_val, d)*minval, torch.ones(y_val, d)*maxval)
        L_uniform = tdist.uniform.Uniform(torch.ones(y_val, x_dims[i])*minval, torch.ones(y_val, x_dims[i])*maxval)
        Phi_uniform = tdist.uniform.Uniform(torch.zeros(y_val), max_noise*torch.ones(y_val))

        # store ground truth W, L, and Phi
        all_Ws.append(W_uniform.sample())
        all_Ls.append(L_uniform.sample())
        all_Phis.append(torch.diag(Phi_uniform.sample()))
    
    return all_Ws, all_Ls, all_Phis

In [8]:
all_Ws, all_Ls, all_Phis = generate_W_L_Phi(y_dims, x_dims)

In [9]:
print([W.shape for W in all_Ws])
print([L.shape for L in all_Ls])
print([Phi.shape for Phi in all_Phis])

[torch.Size([20, 5]), torch.Size([30, 5]), torch.Size([50, 5])]
[torch.Size([20, 4]), torch.Size([30, 5]), torch.Size([50, 6])]
[torch.Size([20, 20]), torch.Size([30, 30]), torch.Size([50, 50])]


In [13]:
def generate_samples(all_Ws, all_Ls, all_Phis, y_dims, x_dims, N=1000, d=5):
    datasets = [torch.zeros(1, y_d) for y_d in y_dims]

    # z-distribution remains fixed
    z_distribution = tdist.multivariate_normal.MultivariateNormal(torch.zeros(d), torch.eye(d))
    
    # simulate the graphical model
    for sample in range(N):
        print('Generating sample: {}/{}'.format(sample, N), end='\r', flush=True)
        # for each sample, retrieve the latent z latent variable
        z = z_distribution.sample()
        # for each dataset, compute the dataset-specific mean and variance, and obtain 1 sample
        for i, dim in enumerate(x_dims):
            x = tdist.multivariate_normal.MultivariateNormal(torch.zeros(dim), torch.eye(dim)).sample()
            y_i = tdist.multivariate_normal.MultivariateNormal(all_Ws[i] @ z + all_Ls[i] @ x, all_Phis[i]).sample()
            datasets[i] = torch.cat([datasets[i], y_i[None,:]])
    datasets = [dataset[1:] for dataset in datasets]      
    return datasets

In [14]:
datasets = generate_samples(all_Ws, all_Ls, all_Phis, y_dims, x_dims)

Generating sample: 999/1000

In [17]:
# concatenate the y-vectors for each sample together
y_concat = torch.cat(datasets, axis=1)

In [24]:
# we have the low-rank structure and shit here
def initialize_model(y_dims, x_dims, d=5, std=1e-2, mean=0):
    Ws_to_stack = []
    Phis_to_stack = []
    Ls_to_stack = []
    
    for i, y_dim in enumerate(y_dims):
        cur_W = torch.nn.init.normal_(torch.zeros(y_dim, d), mean=mean, std=std)
        cur_L = torch.nn.init.normal_(torch.zeros(y_dim, x_dims[i]), mean=mean, std=std)
        cur_Phi = torch.nn.init.normal_(torch.zeros(y_dim, y_dim), mean=mean, std=std)
        Ws_to_stack.append(cur_W)
        Ls_to_stack.append(cur_L)
        Phis_to_stack.append(cur_Phi)
    return torch.cat(Ws_to_stack, axis=0), torch.block_diag(*Ls_to_stack), torch.block_diag(*Phis_to_stack)
        

In [25]:
# compute the ground-truth parameters
W_GT = torch.cat(all_Ws, axis=0)
L_GT = torch.block_diag(*all_Ls)
Phi_GT = torch.block_diag(*all_Phis)

In [26]:
# initialize the model parameters
W_model, L_model, Phi_model = initialize_model(y_dims, x_dims)

In [97]:
# the E-Step required values (is there a way to batch this intelligently?)
def E_step(W, L, Phi, x_dims, d, y_i):
    # schur-complement (M/D)^{-1}
    sigma_22_inv = torch.inverse(W@W.T + L @ L.T)

    # other necessary block matrices
    sigma_12 = torch.cat([W.T, L.T], axis=0)
    sigma_11 = torch.eye(torch.sum(x_dims)+d)

    # compute the posterior mean of z and x; y should be a matrix with all samples aligned as columns
    posterior_z_x_mean = sigma_12 @ sigma_22_inv @ (y_i)
    posterior_z_mean = posterior_z_x_mean[:d]
    posterior_x_mean = posterior_z_x_mean[d:]

    # posterior covariance
    posterior_x1_cov = sigma_11 - sigma_12 @ sigma_22_inv @ sigma_12.T
    posterior_z_x_cov = posterior_x1_cov[:d, d:]  # cross covariance
    posterior_z_z_cov = posterior_x1_cov[:d, :d]  # upper left block matrix
    posterior_x_x_cov = posterior_x1_cov[d:, d:]  # bottom right block matrix
    
    # need to batch zmu and xmu: [n_samples, <[z, x]>.shape, 1]
    zmu_batched = posterior_z_mean.T[:, :, None]
    xmu_batched = posterior_x_mean.T[:, :, None]

    # posterior <zx.T> = cov(z, x) + <z><x.T>
    posterior_zxT = posterior_z_x_cov + zmu_batched @ xmu_batched.permute(0, 2, 1)  # shape: (n_samples, z_dim, x_dim)
    posterior_zzT = posterior_z_z_cov + zmu_batched @ zmu_batched.permute(0, 2, 1)  # shape: (n_samples, z_dim, z_dim)
    posterior_xxT = posterior_x_x_cov + xmu_batched @ xmu_batched.permute(0, 2, 1)  # shape: (n_samples, x_dim, x_dim)

    return posterior_zxT, posterior_zzT, posterior_xxT, zmu_batched, xmu_batched

In [98]:
y_concat_T = y_concat.T
print(y_concat_T.shape)

torch.Size([100, 1000])


In [99]:
zxT, zzT, xxT, zmu, xmu = E_step(W_model, L_model, Phi_model, x_dims, d, y_concat_T)

In [128]:
def M_step(zxT, zzT, xxT, zmu, xmu, y_i, Phi_model, L_model, W_model, N):
    y_i_batched = y_i[:, :, None]  # (n_samples, batch_dim, 1)
    new_L = torch.sum(y_i_batched @ xmu.permute(0, 2, 1) - W_model @ zxT, axis=0) @ torch.inverse(torch.sum(xxT, axis=0))
    new_W = torch.sum(y_i_batched @ zmu.permute(0, 2, 1) - L_model @ zxT.permute(0, 2, 1), axis=0) @ torch.inverse(torch.sum(zzT, axis=0))
    
    phi_term_1 = 2 / N * torch.sum( (L_model @ xmu + W_model @ zmu) @ y_i_batched.permute(0, 2, 1), axis=0)
    phi_term_2 = -1 / N * torch.sum(y_i_batched @ y_i_batched.permute(0, 2, 1), axis=0)
    #     print(phi_term_1.shape)
    #     print(phi_term_2.shape)
    #     print(new_L.shape)
    #     print(new_W.shape)
    phi_term_3 = -1 / N * torch.sum( (L_model @ xxT @ L_model.T) + W_model @ zzT @ W_model.T + 2 * L_model @ zxT.permute(0, 2, 1) @ W_model.T, axis=0)
    
    new_Phi = torch.inverse(phi_term_1 + phi_term_2 + phi_term_3)
    return new_W, new_L, new_Phi

In [129]:
new_W, new_L, new_Phi = M_step(zxT, zzT, xxT, zmu, xmu, y_concat, Phi_model, L_model, W_model, N)

In [137]:
# number of E/M-Steps To Run
steps = 10000
# initialize the model parameters
W_model, L_model, Phi_model = initialize_model(y_dims, x_dims)

# ground truth (GT) values
W_GT = torch.cat(all_Ws, axis=0)
L_GT = torch.block_diag(*all_Ls)
Phi_GT = torch.block_diag(*all_Phis)

# compute the reconstruction error
W_losses = []
L_losses = []
Phi_losses = []

# GPU setup
# device = 'cpu'
# if torch.cuda.is_available():
#     device = 'cuda:0'

# # move to the GPU    
# W_model = W_model.to(device)
# L_model = L_model.to(device)
# Phi_model = Phi_model.to(device)
# x_dims = x_dims.to(device)

# iterate through E/M Steps
for i in range(steps):
    # E-Step, then M-Step
    zxT, zzT, xxT, zmu, xmu = E_step(W_model, L_model, Phi_model, x_dims, d, y_concat_T)
    W_model, L_model, Phi_model = M_step(zxT, zzT, xxT, zmu, xmu, y_concat, Phi_model, L_model, W_model, N)
    
    # compute training stats
    W_mse = torch.mean(torch.pow(W_model - W_GT, 2))
    L_mse = torch.mean(torch.pow(L_model - L_GT, 2))
    Phi_mse = torch.mean(torch.pow(Phi_model - Phi_GT, 2))
    
    # store training stats
    W_losses.append(W_mse.item())
    L_losses.append(L_mse.item())
    Phi_losses.append(Phi_mse.item())
    
    if (i % 50 == 0):
        print("{}/{}: W_mse: {} Phi_mse: {} L_mse: {}".format(i, steps, W_mse.item(), Phi_mse.item(), L_mse.item()), flush=True, end='\r')


2350/10000: W_mse: 3297.847900390625 Phi_mse: 0.08523043990135193 L_mse: 2912775.5.087e+18

KeyboardInterrupt: 