In [1]:
import numpy as np
import torch as th
import matplotlib.pyplot as plt
from matplotlib import patches as mpatches
import seaborn as sns
import pandas as pd
import os
import sys
from append_directories import *
data_generation_folder = (append_directory(3) + "/generate_data")
sys.path.append(data_generation_folder)
from true_conditional_data_generation import *

In [56]:
index_to_matrix_index(1,n)

(1, 0)

In [2]:
#index is assumed to be in i*n+j form where (i,j) is index of matrix
def index_to_spatial_location(minX, maxX, minY, maxY, n, index):

    # create one-dimensional arrays for x and y
    x = np.linspace(minX, maxX, n)
    y = np.linspace(minY, maxY, n)
    # create the mesh based on these arrays
    X, Y = np.meshgrid(x, y)
    X = X.reshape((np.prod(X.shape),1))
    Y = Y.reshape((np.prod(Y.shape),1))
    
    xlocation = (X[index])[0]
    ylocation = (Y[index])[0]
    return (xlocation, ylocation)


def index_to_matrix_index(index, n):
    print(index)
    return (int(index / n), int(index % n))

def visualize_spatial_field(observation, min_value, max_value, figname):

    fig, ax = plt.subplots(1)
    plt.imshow(observation, vmin = min_value, vmax = max_value)
    plt.savefig(figname)

def produce_bivariate_density(mask, minX, maxX, minY, maxY, n, variance, lengthscale,
                              masked_vector, number_of_replicates, missing_two_indices,
                              missing_indices, mask_type, folder_name, m, observed_vector):
    
    #missing_index is in between 0 and m, it's not the original missing index from n x n field
    conditional_vectors = sample_conditional_distribution(mask, minX, maxX, minY, maxY, n,
                                                     variance, lengthscale, masked_vector,
                                                     number_of_replicates)
    #conditional_vectors is shape (number of replicates, m)
    bivariate_density = (conditional_vectors[:,missing_two_indices]).reshape((number_of_replicates,2))
    fig, axs = plt.subplots(ncols = 2, figsize = (10,5))
    #emp_mean = round(np.mean(marg), 2)
    #emp_var = round(np.std(marginal_density)**2, 2)
    pdd = pd.DataFrame(bivariate_density,
                                    columns = None)

    #partially_observed_field = np.multiply(mask.astype(bool), observed_vector.reshape((n,n)))
    axs[0].imshow(observed_vector.reshape((n,n)), alpha = (1-mask), vmin = -2, vmax = 2)
    missing_true_index1 = missing_indices[missing_two_indices[0]]
    missing_true_index2 = missing_indices[missing_two_indices[1]]
    matrix_index1 = index_to_matrix_index(missing_true_index1, n)
    matrix_index2 = index_to_matrix_index(missing_true_index2, n)
    axs[0].plot(matrix_index1[0], matrix_index1[1], "r+")
    axs[0].plot(matrix_index2[0], matrix_index2[1], "r+")
    sns.kdeplot(x = bivariate_density[:,0], y = bivariate_density[:,1],
                ax = axs[1])
    plt.axvline(observed_vector[missing_true_index1], color='red', linestyle = 'dashed')
    plt.axhline(observed_vector[missing_true_index2], color='red', linestyle = 'dashed')
    plt.xlim(-2,2)
    plt.ylim(-2,2)
    axs[1].set_title("Marginal")
    location1 = index_to_spatial_location(minX, maxX, minY, maxY, n, missing_true_index1)
    rlocation1 = (round(location1[0],2), round(location1[1],2))
    location2 = index_to_spatial_location(minX, maxX, minY, maxY, n, missing_true_index2)
    rlocation2 = (round(location2[0],2), round(location2[1],2))
    axs[1].set_xlabel("location: " + str(rlocation1))
    axs[1].set_ylabel("location: " + str(rlocation2))
    axs[1].legend(labels = ['true'])
    plt.show()

def produce_marginal_density(mask, minX, maxX, minY, maxY, n, variance, lengthscale,
                                  number_of_replicates, missing_index,
                                  missing_indices, mask_type, folder_name, m, observed_vector, ref_image):

    #missing_index is in between 0 and m, it's not the original missing index from n x n field
    conditional_vectors = sample_conditional_distribution(mask, minX, maxX, minY, maxY, n,
                                                     variance, lengthscale, observed_vector,
                                                     number_of_replicates)
    #conditional_vectors is shape (number of replicates, m)
    marginal_density = (conditional_vectors[:,missing_index]).reshape((number_of_replicates,1))

    #fig, ax = plt.subplots(1)
    #ax.hist(marginal_disalsotribution, density = True, histtype = 'step', bins = 100)
    fig, axs = plt.subplots(ncols = 2, figsize = (10,5))
    emp_mean = round(np.mean(marginal_density), 2)
    emp_var = round(np.std(marginal_density)**2, 2)
    pdd = pd.DataFrame(marginal_density,
                                    columns = None)
    #partially_observed_field = np.multiply(mask.astype(bool), observed_vector.reshape((n,n)))
    axs[0].imshow(observed_vector.reshape((n,n)), alpha = (1-mask), vmin = -2, vmax = 2)
    missing_true_index = missing_indices[missing_index]
    matrix_index = index_to_matrix_index(missing_true_index, n)
    axs[0].plot(matrix_index[0], matrix_index[1], "r+")
    sns.kdeplot(data = pdd, ax = axs[1])
    plt.axvline(ref_image[int(matrix_index[0]),int(matrix_index[1])], color='red', linestyle = 'dashed')
    axs[1].set_title("Marginal")
    location = index_to_spatial_location(minX, maxX, minY, maxY, n, missing_true_index)
    rlocation = (round(location[0],2), round(location[1],2))
    axs[1].set_xlabel("location: " + str(rlocation))
    axs[1].legend(labels = ['true'])
    plt.show()


def produce_true_and_generated_marginal_density(mask, minX, maxX, minY, maxY, n, variance, lengthscale,
                                  number_of_replicates, missing_index,
                                  missing_indices, folder_name, m, observed_vector,
                                  conditional_generated_samples, ref_image, figname):

    #missing_index is in between 0 and m, it's not the original missing index from n x n field
    conditional_vectors = sample_conditional_distribution(mask, minX, maxX, minY, maxY, n,
                                                     variance, lengthscale, observed_vector,
                                                     number_of_replicates)
    #conditional_vectors is shape (number of replicates, m)
    marginal_density = (conditional_vectors[:,missing_index]).reshape((number_of_replicates,1))
    missing_true_index = missing_indices[missing_index]
    matrix_missing_index = index_to_matrix_index(missing_true_index, n)
    generated_marginal_density = conditional_generated_samples[:,int(matrix_missing_index[0]),int(matrix_missing_index[1])]

    #fig, ax = plt.subplots(1)
    #ax.hist(marginal_disalsotribution, density = True, histtype = 'step', bins = 100)
    fig, axs = plt.subplots(ncols = 2, figsize = (10,5))
    emp_mean = round(np.mean(marginal_density), 2)
    emp_var = round(np.std(marginal_density)**2, 2)
    pdd = pd.DataFrame(marginal_density,
                                    columns = None)
    generated_pdd = pd.DataFrame(generated_marginal_density,
                                    columns = None)

    #partially_observed_field = np.multiply(mask.astype(bool), observed_vector.reshape((n,n)))
    mask = mask.astype(float).reshape((n,n))
    axs[0].imshow(ref_image.reshape((n,n)), alpha = (1-mask), vmin = -2, vmax = 2)
    axs[0].plot(matrix_missing_index[1], matrix_missing_index[0], "r+")
    sns.kdeplot(data = pdd, ax = axs[1], palette=['blue'])
    sns.kdeplot(data = generated_pdd, palette = ["orange"], ax = axs[1])
    plt.axvline(ref_image[int(matrix_missing_index[0]),int(matrix_missing_index[1])], color='red', linestyle = 'dashed')
    axs[1].set_title("Marginal")
    location = index_to_spatial_location(minX, maxX, minY, maxY, n, missing_true_index)
    rlocation = (round(location[0],2), round(location[1],2))
    axs[1].set_xlabel("location: " + str(rlocation))
    axs[1].legend(labels = ['true', 'generated'])
    plt.savefig(figname)

def produce_true_and_generated_bivariate_density(mask, minX, maxX, minY, maxY, n, variance, lengthscale,
                                                 number_of_replicates, missing_two_indices,
                                                 missing_indices, mask_type, folder_name, m, observed_vector,
                                                 conditional_generated_samples, ref_image, figname):
    
    #missing_index is in between 0 and m, it's not the original missing index from n x n field
    conditional_vectors = sample_conditional_distribution(mask, minX, maxX, minY, maxY, n,
                                                          variance, lengthscale, observed_vector,
                                                          number_of_replicates)
    #conditional_vectors is shape (number of replicates, m)
    bivariate_density = (conditional_vectors[:,missing_two_indices]).reshape((number_of_replicates,2))
    missing_true_index1 = missing_indices[missing_two_indices[0]]
    missing_true_index2 = missing_indices[missing_two_indices[1]]
    matrix_index1 = index_to_matrix_index(missing_true_index1, n)
    matrix_index2 = index_to_matrix_index(missing_true_index2, n)
    number_of_replicates = conditional_generated_samples.shape[0]
    generated_bivariate_density = np.concatenate([(conditional_generated_samples[:,int(matrix_index1[0]),int(matrix_index1[1])]).reshape((number_of_replicates,1)),
                                                   (conditional_generated_samples[:,int(matrix_index2[0]),int(matrix_index2[1])]).reshape((number_of_replicates,1))],
                                                   axis = 1)
    bivariate_density = np.concatenate([bivariate_density, generated_bivariate_density], axis = 0)
    class_vector = np.concatenate([(np.repeat('true', number_of_replicates)).reshape((number_of_replicates,1)),
                                   (np.repeat('generated', number_of_replicates)).reshape((number_of_replicates,1))], axis = 0)
    bivariate_density = np.concatenate([bivariate_density, class_vector], axis = 1)
    fig, axs = plt.subplots(ncols = 2, figsize = (10,5))
    #emp_mean = round(np.mean(marg), 2)
    #emp_var = round(np.std(marginal_density)**2, 2)
    pdd = pd.DataFrame(bivariate_density,
                                    columns = ['x', 'y', 'class'])
    pdd = pdd.astype({'x': 'float64', 'y': 'float64'})
    #partially_observed_field = np.multiply(mask.astype(bool), observed_vector.reshape((n,n)))
    axs[0].imshow(ref_image.reshape((n,n)), alpha = (1-mask.reshape((n,n))), vmin = -2, vmax = 2)
    print(matrix_index1)
    axs[0].plot(matrix_index1[0], matrix_index1[1], "r+")
    axs[0].plot(matrix_index2[0], matrix_index2[1], "r+")
    kde1 = sns.kdeplot(data = pdd, x = 'x', y = 'y',
                ax = axs[1], hue = 'class', shade = True, levels = 5, alpha = .5)
    #kde2 = sns.kdeplot(x = generated_bivariate_density[:,0], y = generated_bivariate_density[:,1],
                #ax = axs[1], color = 'orange', levels = 5, label = "generated")
    blue_patch = mpatches.Patch(color='blue')
    orange_patch = mpatches.Patch(color='orange')
    plt.axvline(ref_image[int(matrix_index1[0]),int(matrix_index1[1])], color='red', linestyle = 'dashed')
    plt.axhline(ref_image[int(matrix_index2[0]),int(matrix_index2[1])], color='red', linestyle = 'dashed')
    plt.xlim(-2,2)
    plt.ylim(-2,2)
    axs[1].set_title("Marginal")
    location1 = index_to_spatial_location(minX, maxX, minY, maxY, n, missing_true_index1)
    rlocation1 = (round(location1[0],2), round(location1[1],2))
    location2 = index_to_spatial_location(minX, maxX, minY, maxY, n, missing_true_index2)
    rlocation2 = (round(location2[0],2), round(location2[1],2))
    axs[1].set_xlabel("location: " + str(rlocation1))
    axs[1].set_ylabel("location: " + str(rlocation2))
    #axs[1].legend(handles = [blue_patch, orange_patch],labels = ['true', 'generated'])
    plt.savefig(figname)
    plt.clf()

In [3]:
n = 32
number_of_replicates = 1000 
conditional_samples = np.load((data_generation_folder + "/data/ref_image2/diffusion/model4_beta_min_max_01_20_random50_1000.npy"))
conditional_samples = conditional_samples.reshape((number_of_replicates,n,n))
#mask = np.load((data_generation_folder + "/data/ref_image1/mask.npy"), allow_pickle = True)
n = 32
#mask = th.zeros((1,n,n))
#mask[:, int(n/4):int(n/4*3), int(n/4):int(n/4*3)] = 1
device = "cuda:0"
p = .5
mask = np.load((data_generation_folder + "/data/ref_image2/mask.npy"))
ref_image = (np.load((data_generation_folder + "/data/ref_image2/ref_image2.npy")))
minX = -10
maxX = 10
minY = -10
maxY = 10
variance = .4
lengthscale = 1.6                                                                                        
missing_indices = np.squeeze(np.argwhere((1-mask).reshape((n**2,))))
mask_type = "random50"
folder_name = (data_generation_folder + "/data/ref_image2/marginal_density")
m = missing_indices.shape[0]
observed_vector = ref_image.reshape((n**2))
observed_vector = np.delete(observed_vector, missing_indices)

In [30]:
conditional_vectors = sample_conditional_distribution((1-mask), minX, maxX, minY, maxY, n,
                                                     variance, lengthscale, observed_vector,
                                                     number_of_replicates)

In [33]:
missing_index = 0
missing_true_index = missing_indices[missing_index]
matrix_missing_index = index_to_matrix_index(missing_true_index, n)
generated_marginal_density = conditional_samples[:,int(matrix_missing_index[0]),int(matrix_missing_index[1])]


In [37]:
#missing_index is in between 0 and m, it's not the original missing index from n x n field
conditional_vectors = sample_conditional_distribution((1-mask), minX, maxX, minY, maxY, n,
                                                     variance, lengthscale, observed_vector,
                                                     number_of_replicates)

In [4]:
for i in range(50,100):
    missing_index = i
    true_missing_index = missing_indices[missing_index]
    true_missing_matrix_index = index_to_matrix_index(true_missing_index, n)
    figname = (folder_name + "/marginal_density_model4_" + str(int(true_missing_matrix_index[0]))
                + "_" + str(int(true_missing_matrix_index[1])) + ".png")
    produce_true_and_generated_marginal_density((1-mask), minX, maxX, minY, maxY, n, variance, lengthscale,
                                        number_of_replicates, missing_index,
                                        missing_indices, folder_name, m, observed_vector,
                                        conditional_samples, ref_image, figname)

94
94


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


95
95


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


96
96


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


98
98


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


99
99


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


100
100


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


101
101


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


104
104


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


105
105


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


106
106


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


108
108


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


109
109
111


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


111


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


113
113


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


116
116


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


118
118


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


119
119


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


120
120


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


122
122


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


124
124


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


125
125


  fig, axs = plt.subplots(ncols = 2, figsize = (10,5))
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


126
126
128


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


128


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


130
130


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


132
132


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


133
133


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


134
134


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


135
135


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


136
136


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


137
137


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


140
140


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


141
141


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


142
142


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


143
143


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


144
144


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


145
145


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


146
146


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


152
152


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


153
153


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


156
156


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


158


In [68]:
matrix_missing_index

(17, 0)

In [77]:
missing_index = 20
missing_true_index = missing_indices[missing_index]
matrix_missing_index = index_to_matrix_index(missing_true_index, n)
conditional_samples[:,matrix_missing_index[1],matrix_missing_index[0]]

42


array([-3.57364535e-01,  1.23831972e-01,  2.79935837e-01, -8.26655626e-01,
        7.87797868e-02, -4.27647322e-01, -5.33545434e-01,  4.52441603e-01,
        3.04330319e-01, -7.06139266e-01,  6.93361878e-01,  5.91756636e-03,
       -4.10902470e-01, -1.01192415e+00,  3.00612360e-01, -2.97114342e-01,
       -1.61815405e-01,  4.38174456e-01, -3.44166607e-01,  2.22219467e-01,
       -3.00160497e-01,  3.37844610e-01,  5.69529057e-01,  5.27656555e-01,
        5.99801302e-01, -1.30042702e-01,  6.63312256e-01, -2.54531085e-01,
        9.44514930e-01,  1.09138596e+00, -5.68095982e-01, -9.36052978e-01,
       -1.78911611e-02, -5.02843082e-01, -2.05404311e-01,  3.77273619e-01,
       -4.35521305e-01,  5.41186392e-01, -2.24787086e-01, -3.75308990e-01,
       -1.50101230e-01,  2.65033960e-01,  4.26889241e-01, -7.22273290e-02,
        5.15574634e-01,  3.03639352e-01,  5.83440661e-01, -1.05000293e+00,
       -4.16698530e-02, -2.29196668e-01, -2.51895428e-01, -4.82900053e-01,
        4.25429732e-01, -