In [None]:
import numpy as np
import matplotlib.pyplot as plt

def construct_contour_gauss(centers, weights, learned_variance, interp=100):
    """Evaluate a weighted 2-D Gaussian mixture on a square grid.

    The grid covers [min, max] x [min, max], where ``min`` and ``max`` are the
    module-level bounds defined in this notebook (they shadow the builtins).

    Parameters
    ----------
    centers : ndarray
        ``centers[i]`` is the mean of component ``i``; ``centers.shape[0]``
        is the number of components.
    weights : sequence
        ``weights[i]`` multiplies the density of component ``i``.
    learned_variance : sequence
        ``learned_variance[i]`` is the covariance matrix passed to
        ``gaussian_nd`` for component ``i``.
    interp : int, optional
        Number of grid points per axis.  Previously this argument was
        silently overwritten with 100; it is now honoured.

    Returns
    -------
    ndarray of shape (interp, interp)
        Mean over components of the weighted densities, multiplied by
        ``delta**2`` with ``delta = (max - min) / interp``.
    """
    # Scale factor applied to the density values (one factor per axis).
    delta = (max-min)/interp

    x_axis = np.linspace(min, max, interp)
    y_axis = np.linspace(min, max, interp)
    xv, yv = np.meshgrid(x_axis, y_axis)

    # One row per grid point: shape (interp*interp, 2).
    grid_points = np.array((xv, yv)).reshape(2, -1).T

    component_values = []
    for i in range(0, centers.shape[0]):
        # The centre is subtracted here, so gaussian_nd is called with mean 0.
        component_values.append(weights[i]*gaussian_nd(grid_points - centers[i], 0, learned_variance[i]))
    mixture = np.mean(np.array(component_values), 0)*delta*delta

    return mixture.reshape(interp, interp)

def gaussian_nd_numpy(MEAN, VARIANCE):
    """Batched zero-centred Gaussian density.

    ``VARIANCE`` is a stack of square covariance matrices (batch, dim, dim)
    and ``MEAN`` holds one offset vector per matrix (reshaped to (-1, dim)).
    Returns a 1-D array with the density of each offset under its matrix.
    """
    n_dims = VARIANCE.shape[1]

    determinants = np.linalg.det(VARIANCE)
    precisions = np.linalg.pinv(VARIANCE)

    # Quadratic form offset^T * precision * offset, one value per batch row.
    offsets = MEAN.reshape(-1, n_dims)
    projected = (MEAN.reshape(-1, 1, n_dims)@precisions).reshape(-1, n_dims)
    quad_form = np.sum(projected*offsets, 1)

    normaliser = ((2*np.pi)**(-n_dims/2))*determinants**(-1/2)
    return normaliser*np.exp(-(1/2)*quad_form)

def compute_TRUE_ENTROPY(MEAN_matrix, COV_matrix, weights_matrix):
    """Square root of the weighted sum of pairwise Gaussian terms of a mixture.

    For every ordered pair (i, j) of the K components the term is
    ``w_i * w_j * N(mu_i - mu_j; 0, cov_i + cov_j)``, evaluated with
    ``gaussian_nd_numpy``.  The K*K terms are summed and the square root of
    that sum is returned.

    MEAN_matrix is (K, dim), COV_matrix is reshaped to (K, dim, dim) and
    weights_matrix to (K,).
    """
    n_components, n_dims = MEAN_matrix.shape[0], MEAN_matrix.shape[1]

    # Broadcast to all (i, j) pairs.
    pair_mean_gap = (MEAN_matrix.reshape(n_components, 1, n_dims)
                     - MEAN_matrix.reshape(1, n_components, n_dims))
    pair_cov_sum = (COV_matrix.reshape(n_components, 1, n_dims, n_dims)
                    + COV_matrix.reshape(1, n_components, n_dims, n_dims))
    pair_weight = weights_matrix.reshape(n_components, 1)*weights_matrix.reshape(1, n_components)

    pair_density = gaussian_nd_numpy(pair_mean_gap.reshape(-1, n_dims),
                                     pair_cov_sum.reshape(-1, n_dims, n_dims))
    return np.sqrt(np.sum(pair_weight.reshape(-1)*pair_density))

def gaussian_nd(input, m, sigma):
    """Multivariate normal density of each row of ``input``.

    ``m`` is the mean (anything that broadcasts against ``input``) and
    ``sigma`` the square covariance matrix; its pseudo-inverse is used as the
    precision.  Returns a 1-D array with one density value per row.
    """
    n_dims = sigma.shape[0]
    centred = input - m
    precision = np.linalg.pinv(sigma)

    norm_const = ((2*np.pi)**(-n_dims/2))*np.linalg.det(sigma)**(-1/2)
    mahalanobis_sq = np.sum(centred@precision*centred, 1)
    return norm_const*np.exp(-(1/2)*mahalanobis_sq)

#### CONSTRUCT A MORE REASONABLE MKM

# Lower/upper bound of the square domain read by construct_contour_gauss and by
# the imshow/contour ``extent`` arguments in this notebook.
# NOTE: these names shadow the built-in min()/max() for the rest of the notebook.
min = 0
max = 1

def makediag3d(a):
    """Turn each row of a 2-D array into a diagonal matrix.

    ``a`` has shape (depth, size); the result is a float array of shape
    (depth, size, size) whose i-th matrix has ``a[i]`` on its diagonal and
    zeros elsewhere.
    """
    rows = np.asarray(a)
    n_mats, width = rows.shape
    stacked = np.zeros((n_mats, width, width))
    # Write all diagonals at once via paired row/column indices.
    diag_idx = np.arange(width)
    stacked[:, diag_idx, diag_idx] = rows
    return stacked

def create_Gaussian_mixture_MC():
    """Build the fixed 100-component 2-D Gaussian mixture used in this notebook.

    Re-seeds numpy's global RNG with 4, so the result is the same on every call.

    Returns
    -------
    means : (100, 2) array of component centres (meshgrid of the shuffled axes)
    covariances : (100, 2, 2) diagonal covariance matrices
    mix_weights : (10, 10) weight grid, 0.7 on the diagonal and 0.3/9 elsewhere
    center_x, center_y : the two shuffled length-10 axis arrays
    """
    np.random.seed(4)

    n_side = 10
    center_x = np.linspace(0.2, 0.8, n_side)
    center_y = np.linspace(0.2, 0.8, n_side)

    # In-place shuffles; the order of these RNG calls fixes the output.
    np.random.shuffle(center_x)
    np.random.shuffle(center_y)

    grid_x, grid_y = np.meshgrid(center_x, center_y)
    means = np.array((grid_x, grid_y)).reshape(2, -1).T

    # Two independent variances per component, placed on the diagonal.
    diag_entries = np.random.uniform(0.0005, 0.002, means.shape[0]*2).reshape(n_side*n_side, 2)
    covariances = makediag3d(diag_entries)

    off_diag = 0.3/(n_side-1)
    mix_weights = np.ones((n_side, n_side))*off_diag+np.eye(n_side)*(0.7-off_diag)

    return means, covariances, mix_weights, center_x, center_y

# Build the ground-truth mixture, scale every covariance by 3 and flatten the
# (num, num) weight grid to one weight per component.
MEAN_matrix, COV_matrix, weights_matrix, center_x, center_y = create_Gaussian_mixture_MC()
COV_matrix = COV_matrix*3
weights_matrix = weights_matrix.reshape(-1)

# Evaluate the mixture on the grid, then divide every grid row by its row sum.
pdf = construct_contour_gauss(MEAN_matrix, weights_matrix.reshape(-1), COV_matrix)
normalize = pdf/np.sum(pdf,1).reshape(-1, 1)

# NOTE: this message is printed after the grid has already been computed.
print('Generating pdf for EXP #1...')

# Grid values of the mixture.
plt.imshow(pdf, origin='lower', extent=[min, max, min, max])
plt.show()

# Row-normalised grid values.
plt.imshow(normalize, origin='lower', extent=[min, max, min, max])
plt.show()

In [None]:
def gaussian_1d(input, m, sigma):
    """Pairwise 1-D normal densities.

    ``input`` is flattened to a column and ``m`` to a row, so the result has
    shape (len(input), len(m)): entry (i, j) is the density of ``input[i]``
    under a normal with mean ``m[j]`` and variance ``sigma`` (``sigma``
    broadcasts along the columns).
    """
    points = input.reshape(-1, 1)
    centres = m.reshape(1, -1)

    precision = 1/sigma
    scale = ((2*np.pi)**(-1/2))*sigma**(-1/2)
    return scale*np.exp(-(1/2)*((points-centres)**2*precision))

def generate_gauss_samples_various(MEAN, COV_matrix, samples_per_class=3000000):
    """Draw a pool of 1-D normal samples for every component.

    Component ``k`` has mean ``MEAN[k]`` and variance ``COV_matrix[k]``
    (the square root is passed to numpy as the standard deviation).
    Samples come from numpy's global RNG, one component after another.

    Returns an array of shape (MEAN.shape[0], samples_per_class).
    """
    per_class = [
        np.random.normal(MEAN[k], np.sqrt(COV_matrix[k]), int(samples_per_class))
        for k in range(0, MEAN.shape[0])
    ]
    return np.array(per_class)

# Simulate `inter` parallel chains for `iter` steps.  At every step each chain
# picks a mixture component with probability proportional to
# weight * N(current value; MEAN[:, 1], COV[:, 1, 1]) and takes its next value
# from that component's pre-drawn pool (built from MEAN[:, 0], COV[:, 0, 0]).
np.random.seed(4)

# Number of parallel chains.
inter = 100

# Initial values of the chains.
x_0 = np.linspace(0.1, 0.9, inter)
#x_0 = np.array([0.5])
next_samples = np.copy(x_0)
density = []
# NOTE(review): uses the default samples_per_class (3,000,000 per component),
# while the loop below consumes at most iter*inter samples in total.
stored_samples = generate_gauss_samples_various(MEAN_matrix[:, 0], COV_matrix[:, 0, 0])

# Number of steps.  NOTE: this name shadows the built-in iter().
iter = 1000

# current_[i] / next_[i] hold the chain values before / after step i.
current_ = np.zeros((iter, inter))
next_ = np.zeros((iter, inter))
# Per-component cursor into stored_samples so no pooled sample is reused.
num_samples = np.zeros((weights_matrix.shape[0]), dtype=int)

print('Start generating samples...')

for i in range(0, iter):
    current_[i] = np.copy(next_samples)
    # (inter, K) unnormalised component probabilities, then row-normalised.
    weights_x0 = weights_matrix*gaussian_1d(next_samples, MEAN_matrix[:, 1], (COV_matrix[:, 1, 1]))
    weights_x0 = weights_x0/np.sum(weights_x0, 1).reshape(-1, 1)

    # One component index per chain.
    next_chosen = np.array([np.random.choice(weights_.shape[0], 1, p=weights_)[0] for weights_ in weights_x0])
    next_samples = []
    for j in next_chosen:
        # Take the next unused sample of component j and advance its cursor.
        next_samples.append(stored_samples[j, num_samples[j]])
        num_samples[j]+=1
    next_samples = np.array(next_samples)
    # Histogram counts of this step's values (100 bins over their own range).
    density.append(np.histogram(next_samples, bins=100)[0])
    
    next_[i] = np.copy(next_samples)
    
print('Done')
# (iter*inter, 2) array of (current, next) pairs.
joint_samples = np.array((current_, next_)).reshape(2, -1).T

In [None]:
import matplotlib.pyplot as plt

# plt.style.reload_library()

# plt.style.use('science')
# plt.style.use(['science','no-latex', 'high-vis', 'notebook'])
# plt.rcParams["figure.figsize"] = [6,4]

# WHETHER TO COMMENT

import numpy as np
import matplotlib.pyplot as plt

import csv

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from sklearn.linear_model import LinearRegression

import sklearn
from sklearn import datasets

#torch.cuda.set_device(1)

def GD(net, lr):
    """Apply one plain gradient-descent step to ``net`` and clear its gradients.

    Every trainable parameter that has a gradient is replaced by
    ``param - lr * grad``.  Parameters whose ``.grad`` is ``None`` (for
    example because they took no part in the last backward pass) are left
    untouched instead of raising a ``TypeError``.

    Parameters
    ----------
    net : torch.nn.Module
        Model whose parameters are updated in place.
    lr : float
        Step size.

    Returns
    -------
    int
        Always 0.
    """
    for param in net.parameters():
        if param.requires_grad and param.grad is not None:
            param.data = param.data - lr*param.grad

    net.zero_grad()
    return 0

def sample_uniform_data(min=-5, max=5, batch=300):
    """Return ``batch`` samples drawn uniformly from [min, max) using numpy's global RNG."""
    lower, upper = min, max
    draws = np.random.uniform(low=lower, high=upper, size=batch)
    return draws

def gaussian_nd_pytorch(MEAN, VARIANCE):
    """Zero-centred Gaussian density with diagonal covariance, in torch.

    ``VARIANCE`` is (batch, dim) and holds the per-dimension variances;
    ``MEAN`` is reshaped to (-1, dim) and holds the offsets to evaluate.
    Returns a 1-D tensor with one density value per row.
    """
    n_dims = VARIANCE.shape[1]

    # Determinant of a diagonal covariance: product of the variances.
    determinant = VARIANCE[:, 0]
    for axis in range(1, n_dims):
        determinant = determinant*VARIANCE[:, axis]

    offsets = MEAN.reshape(-1, n_dims)
    inv_var = (1/VARIANCE).view(-1, n_dims)
    quad_form = torch.sum((offsets*inv_var)*offsets, 1)

    norm_const = ((2*np.pi)**(-n_dims/2))*determinant**(-1/2)
    return norm_const*torch.exp(-(1/2)*quad_form)

def sample_c(batchsize=32, dis_category=5):
    """Sample uniformly distributed one-hot category codes.

    Each of the ``batchsize`` rows is one multinomial draw (n=1) over
    ``dis_category`` equally likely categories, taken from numpy's global RNG.

    Returns
    -------
    rand_c : float32 tensor (batchsize, dis_category), one-hot rows
    label_c : int64 tensor (batchsize,), index of the hot entry in each row
    """
    one_hot = np.zeros((batchsize, dis_category), dtype='float32')
    for row in range(0, batchsize):
        one_hot[row] = np.random.multinomial(1, dis_category*[1/float(dis_category)], size=1)

    labels = torch.LongTensor(np.argmax(one_hot, axis=1).astype('int'))
    codes = torch.from_numpy(one_hot.astype('float32'))
    return codes, labels

def generate_laplacian(center_x, center_y,  scale=0.01, samples_per_class=3000):
    """Draw 2-D samples from independent Laplace distributions around each centre.

    The original body read an undefined name ``MEAN`` and ignored both centre
    arguments; the centres are now built from ``center_x`` and ``center_y``.

    Parameters
    ----------
    center_x, center_y : array-like
        x and y coordinates of the class centres (same length).
    scale : float, optional
        Laplace scale used for both coordinates.
    samples_per_class : int, optional
        Number of points drawn per centre.

    Returns
    -------
    ndarray of shape (num_class * samples_per_class, 2)
        Samples of class 0 first, then class 1, and so on.

    Raises
    ------
    ValueError
        If ``center_x`` and ``center_y`` differ in length.
    """
    xs = np.asarray(center_x).reshape(-1)
    ys = np.asarray(center_y).reshape(-1)
    if xs.shape[0] != ys.shape[0]:
        raise ValueError("center_x and center_y must have the same length")

    # One (x, y) centre per row.
    MEAN = np.array((xs, ys)).T
    num_class = MEAN.shape[0]

    component_ = []
    for i in range(0, num_class):
        dim_1 = np.random.laplace(MEAN[i, 0], scale=scale, size=samples_per_class)
        dim_2 = np.random.laplace(MEAN[i, 1], scale=scale, size=samples_per_class)
        component_.append(np.array((dim_1, dim_2)).T)

    return np.array(component_).reshape(-1, MEAN.shape[1])

def generate_uniform(center_x, center_y,  length=0.5, samples_per_class=3000):
    """Draw 2-D samples uniformly from a square anchored at each centre.

    For class ``i`` the x coordinate is uniform on
    [center_x[i], center_x[i] + length) and the y coordinate is uniform on
    [center_y[i], center_y[i] + length).

    The original body took the number of classes from an undefined name
    ``MEAN``; it is now derived from ``center_x``.

    Parameters
    ----------
    center_x, center_y : array-like
        Lower corners of the squares (same length).
    length : float, optional
        Side length of every square.
    samples_per_class : int, optional
        Number of points drawn per class.

    Returns
    -------
    ndarray of shape (num_class * samples_per_class, 2)
        Samples of class 0 first, then class 1, and so on; shape (0, 2) when
        there are no centres.

    Raises
    ------
    ValueError
        If ``center_x`` and ``center_y`` differ in length.
    """
    xs = np.asarray(center_x).reshape(-1)
    ys = np.asarray(center_y).reshape(-1)
    if xs.shape[0] != ys.shape[0]:
        raise ValueError("center_x and center_y must have the same length")

    num_class = xs.shape[0]
    component_ = []
    for i in range(0, num_class):
        x_1 = np.random.uniform(xs[i], xs[i]+length, samples_per_class)
        x_3 = np.random.uniform(ys[i], ys[i]+length, samples_per_class)
        component_.append(np.array((x_1, x_3)).T)

    # np.concatenate rejects an empty list, so handle "no centres" explicitly.
    if not component_:
        return np.zeros((0, 2))
    return np.concatenate(component_, 0)

class DIS_MOG_relu(nn.Module):
    """MLP that maps every input row to one Gaussian component (weight, mean, variance).

    ``forward`` runs a shared two-layer trunk (``fc1_``/``bn1_`` and
    ``fc2_``/``bn2_``), then three one-layer branches (``fc33_``, ``fc33``,
    ``fc33_w``) that feed the output heads ``fc5`` (mean), ``fc6`` (variance)
    and ``fcw`` (weight).  All other layers created in ``__init__`` are not
    referenced by ``forward``.
    """
    def __init__(self, rand, HIDDEN, dim):
        """Create all layers.

        rand   -- number of input features per row (in_features of the first Linear layers).
        HIDDEN -- width of every hidden layer.
        dim    -- number of output dimensions of the mean and variance heads.
        """
        super(DIS_MOG_relu, self).__init__()
        self.dim = dim
    
        # Un-suffixed stack: only fc33 is used by forward (variance branch).
        self.fc1 = nn.Linear(rand, HIDDEN, bias=True)
        self.bn1 = torch.nn.BatchNorm1d(HIDDEN)
        self.fc2 = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn2 = torch.nn.BatchNorm1d(HIDDEN)
        self.fc3 = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn3 = torch.nn.BatchNorm1d(HIDDEN)
        self.fc33 = nn.Linear(HIDDEN, HIDDEN, bias=True)
        
        # "_" stack: fc1_/bn1_/fc2_/bn2_ form the shared trunk, fc33_ is the mean branch.
        self.fc1_ = nn.Linear(rand, HIDDEN, bias=True)
        self.bn1_ = torch.nn.BatchNorm1d(HIDDEN)
        self.fc2_ = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn2_ = torch.nn.BatchNorm1d(HIDDEN)
        self.fc3_ = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn3_ = torch.nn.BatchNorm1d(HIDDEN)
        self.fc33_ = nn.Linear(HIDDEN, HIDDEN, bias=True)

        # "_w" stack: only fc33_w is used by forward (weight branch).
        self.fc1_w = nn.Linear(rand, HIDDEN, bias=True)
        self.bn1_w = torch.nn.BatchNorm1d(HIDDEN)
        self.fc2_w = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn2_w = torch.nn.BatchNorm1d(HIDDEN)
        self.fc3_w = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn3_w = torch.nn.BatchNorm1d(HIDDEN)
        self.fc33_w = nn.Linear(HIDDEN, HIDDEN, bias=True)

        # Each head emits this many sigmoid outputs per output dimension;
        # forward averages over them.
        self.sum_dim_mean = 128
        self.sum_dim_var = 128
        self.sum_dim_weights = 128

        self.fc5 = nn.Linear(HIDDEN, dim*self.sum_dim_mean, bias=True)
        self.fc6 = nn.Linear(HIDDEN, dim*self.sum_dim_var, bias=True)
        self.fcw = nn.Linear(HIDDEN, self.sum_dim_weights, bias=True)

    def forward(self, x):
        """Return ``(weights, mean, variance)`` for a batch ``x``.

        weights  -- (batch, 1), average of 128 sigmoid outputs.
        mean     -- (batch, dim), average of 128 sigmoid outputs per dimension.
        variance -- (batch, dim), the same kind of average scaled by 0.1 plus 1e-6.
        """

        # Shared trunk; batch norm is applied after the ReLU.
        x = self.bn1_(torch.relu((self.fc1_(x))))
        x = self.bn2_(torch.relu((self.fc2_(x))))
        #x = self.bn3_(torch.relu((self.fc3_(x))))

        # x_m = self.bn1_(torch.sigmoid((self.fc1_(x))))
        # x_m = self.bn2_(torch.sigmoid((self.fc2_(x_m))))
        #x_m = self.bn3_(torch.relu((self.fc3_(x))))
        # Mean branch.
        x_m = (torch.relu((self.fc33_(x))))

        # x_v = self.bn1(torch.sigmoid((self.fc1(x))))
        # x_v = self.bn2(torch.sigmoid((self.fc2(x_v))))
        #x_v = self.bn3(torch.relu((self.fc3(x))))
        # Variance branch.
        x_v = (torch.relu((self.fc33(x))))

        # x_w = self.bn1_w(torch.sigmoid((self.fc1_w(x))))
        # x_w = self.bn2_w(torch.sigmoid((self.fc2_w(x_w))))
        #x_w = self.bn3_w(torch.relu((self.fc3_w(x))))
        # Weight branch.
        x_w = (torch.relu((self.fc33_w(x))))

        dim = self.dim 
        
        # Averaging sigmoid outputs keeps every mean coordinate inside (0, 1).
        mean = torch.sigmoid(self.fc5(x_m)).view(x.shape[0], dim, self.sum_dim_mean)
        mean = torch.mean(mean, 2)

        # The 1e-6 offset keeps the variance strictly positive.
        variance = torch.sigmoid(self.fc6(x_v)).view(x.shape[0], dim, self.sum_dim_var)
        variance = (torch.mean(variance, 2))*0.1+1e-6
        
        # weights = torch.sigmoid(self.fcw(x_w)).view(x.shape[0], 1, self.sum_dim_weights)
        # weights = torch.mean(weights, 2)

        weights = torch.sigmoid(self.fcw(x_w)).view(x.shape[0], 1, self.sum_dim_weights)
        weights = torch.mean(weights, 2)

        return weights, mean, variance

def gaussian_nd(input, m, sigma):
    """Density of a multivariate normal at every row of ``input``.

    ``m`` is the mean (broadcast against ``input``) and ``sigma`` the square
    covariance matrix, inverted with a pseudo-inverse.  Returns one value per
    row.  A function with this name and the same behaviour is also defined in
    the first cell of the notebook; this definition replaces it.
    """
    n_dims = sigma.shape[0]
    deviation = input - m

    inv_cov = np.linalg.pinv(sigma)
    exponent = -(1/2)*np.sum(deviation@inv_cov*deviation, 1)

    prefactor = ((2*np.pi)**(-n_dims/2))*np.linalg.det(sigma)**(-1/2)
    return prefactor*np.exp(exponent)

def construct_contour1d(centers, weights, learned_variance, interp=100):
    """Evaluate a weighted 1-D Gaussian mixture on ``interp`` points of [0, 1].

    Each component density (from ``gaussian_1d``) is multiplied by its weight
    and divided by the sum of all weights; the result is averaged over the
    components and divided by the bin width ``1/interp``.

    Returns a 1-D array of length ``interp``.
    """
    lower, upper = 0, 1
    bin_width = (upper-lower)/interp
    grid = np.linspace(lower, upper, interp)

    # (interp, n_components) table of normalised weighted densities.
    weighted = (weights*gaussian_1d(grid, centers, learned_variance))/np.sum(weights)
    return np.mean(np.array(weighted), 1)/bin_width

def visualization_():
    """Plot the model's conditional mixture as a contour map.

    Reads the module-level ``MOG_NET``, ``rand``, ``min`` and ``max``.  For
    each of 100 scan values in [0, 1], 300 latent inputs (one uniform noise
    value plus a one-hot code each) are pushed through ``MOG_NET``, and the
    resulting 300 components are turned into a 1-D curve with
    ``construct_contour1d``.  The stacked curves are drawn with
    ``plt.contour`` and shown.

    The network is put in eval mode for the forward pass and is always
    switched back to train mode afterwards, even if plotting fails.  The
    forward pass runs under ``torch.no_grad()`` because only detached outputs
    are used, which avoids building an autograd graph for bs*interp rows.
    """
    MOG_NET.eval()
    try:
        d_class = 300
        num = 1
        bs = d_class*num
        discrete_vec = generate_fix_discrete(bs, d_class).float().cuda()

        uniform_vector_3 = torch.cat((torch.rand(bs, rand).cuda(), discrete_vec), 1)
        scanning_input = torch.from_numpy(np.linspace(0, 1, 100)).float().cuda()

        # Pair every latent row with every scan value: (bs*interp, features).
        interp = scanning_input.reshape(-1, 1).shape[0]
        uniform_vector_3_ = uniform_vector_3.reshape(bs, 1, -1).repeat(1, interp, 1)
        scanning_input_ = scanning_input.reshape(1, -1, 1).repeat(bs, 1, 1)
        net_input = torch.cat((uniform_vector_3_, scanning_input_), 2).reshape(bs*interp, -1)

        with torch.no_grad():
            WEIGHTS_3, MEAN_3, VARIANCE_3 = MOG_NET(net_input)

        learned_mean = MEAN_3.detach().cpu().numpy().reshape(bs, interp)
        learned_variance = VARIANCE_3.detach().cpu().numpy().reshape(bs, interp)
        learned_weights = WEIGHTS_3.detach().cpu().numpy().reshape(bs, interp)

        # One mixture curve per scan value.
        gaussian_plot_joint_ = []
        for i in range(0, interp):
            gaussian_plot_joint = construct_contour1d(learned_mean[:, i], learned_weights[:, i], learned_variance[:, i], interp=100)
            gaussian_plot_joint_.append(gaussian_plot_joint)

        plt.contour(np.array(gaussian_plot_joint_), origin='lower', extent=[min, max, min, max])
        plt.show()
    finally:
        MOG_NET.train()
    
# Seed both RNGs used by the training loop.
torch.manual_seed(6)
np.random.seed(4)

# Number of training iterations.  NOTE: this name shadows the built-in iter().
iter = 200000
# Width of the uniform-noise part of the model input.
rand = 1
HIDDEN = 256
# Output dimension of the model's mean/variance heads.
dim = 1

# Number of one-hot codes and how many times the code set is repeated per batch.
d_class = 300
num = 20
bs = d_class*num

d_howmany = 1

# Model input = noise (rand) + one-hot code (d_class*d_howmany) + conditioning value (dim).
MOG_NET = DIS_MOG_relu(rand+d_class*d_howmany+dim, HIDDEN, dim).cuda()

optimizer = optim.AdamW([
            {'params': MOG_NET.parameters(), 'lr': 0.001, 'betas': (0.9, 0.999)},
        ])

# Per-iteration value of numerator_unbiased/square_term_unbiased.
entropy_list = []
# NOTE(review): discrete_prob is not referenced anywhere else in the visible cells.
discrete_prob = torch.ones((d_class,)).float().cuda()

# Decay rate and state of the running average of square_term.
beta = 0.999
v_t = 0.

# Decay rate and state of the running average of numerator.
beta_2 = 0.999
c_t = 0.

# Flattened (current, next) pairs produced by the sampling cell.
x = current_[:].reshape(-1)
y = next_[:].reshape(-1)

# NOTE(review): VAR_ZERO and WEIGHT_ZERO are not referenced anywhere else in the visible cells.
VAR_ZERO = torch.zeros((bs, 1, 1)).cuda()
WEIGHT_ZERO = torch.zeros((bs, 1)).cuda()

def generate_fix_discrete(bs, d_class):
    """Stack ``int(bs / d_class)`` copies of the d_class x d_class one-hot matrix.

    Returns an int64 tensor of shape (int(bs / d_class) * d_class, d_class)
    in which row ``r`` is the one-hot code of ``r % d_class``.
    """
    repeats = int(bs/d_class)
    identity_codes = torch.nn.functional.one_hot(torch.arange(0, d_class))
    return torch.cat([identity_codes for _ in range(repeats)], 0)

# Fixed block of one-hot codes, reused (and reshuffled together with fresh noise) every iteration.
discrete_vec = generate_fix_discrete(bs, d_class).float().cuda()

# Starts at 1 so that the bias-correction factors (1 - beta**i) are non-zero.
for i in range(1, iter):
    
    # Random minibatch of (x, y) pairs, sampled with replacement.
    b1 = np.random.choice(x.shape[0], bs)
    x_i = torch.from_numpy(x[b1].reshape(-1, 1)).float().cuda()
    y_i = torch.from_numpy(y[b1].reshape(-1, 1)).float().cuda()
        
    # First latent batch: fresh noise + one-hot codes, rows shuffled, then x appended.
    uniform_vector_3 = torch.cat((torch.rand(bs, rand).cuda(), discrete_vec), 1)
    uniform_vector_3 = uniform_vector_3[torch.randperm(uniform_vector_3.size()[0])]
    
    uniform_vector_3 = torch.cat((uniform_vector_3, x_i), 1)

    # Second, independently drawn latent batch paired with the same x values.
    uniform_vector_4 = torch.cat((torch.rand(bs, rand).cuda(), discrete_vec), 1)
    uniform_vector_4 = uniform_vector_4[torch.randperm(uniform_vector_4.size()[0])]

    uniform_vector_4 = torch.cat((uniform_vector_4, x_i), 1)

    WEIGHTS_3, MEAN_3, VARIANCE_3 = MOG_NET(uniform_vector_3)    
    WEIGHTS_4, MEAN_4, VARIANCE_4 = MOG_NET(uniform_vector_4)    

    # Data term inputs: offsets of the observed y from both sets of component means.
    MEAN_DATA = torch.cat((y_i - MEAN_3, y_i - MEAN_4))
    VARIANCE_DATA = torch.cat((VARIANCE_3, VARIANCE_4))  
    WEIGHT_DATA = torch.cat((WEIGHTS_3, WEIGHTS_4))  
    
    # Model-vs-model term inputs: mean differences, summed variances, weight products.
    MEAN_DIFF = MEAN_3 - MEAN_4
    VARIANCE_SUM = VARIANCE_3 + VARIANCE_4
    WEIGHT_DIFF = WEIGHTS_3*WEIGHTS_4
    
    # Bias-corrected running average of the model-vs-model term, then its square root.
    square_term = torch.mean(WEIGHT_DIFF*gaussian_nd_pytorch(MEAN_DIFF, VARIANCE_SUM))
    v_t = beta*v_t + (1-beta)*square_term.detach()
    square_term_unbiased = torch.sqrt(v_t/(1-beta**i))
    
    # Bias-corrected running average of the data term.
    numerator = torch.mean(WEIGHT_DATA*gaussian_nd_pytorch(MEAN_DATA, VARIANCE_DATA))
    c_t = beta_2*c_t + (1-beta_2)*numerator.detach()
    numerator_unbiased = c_t/(1-beta_2**i)
    
    # Objective to maximise; only `numerator` and `square_term` carry gradients.
    # NOTE(review): this looks like a first-order expansion of
    # numerator/sqrt(square_term) around the running averages — confirm.
    corr_ = (numerator/square_term_unbiased) - 0.5*numerator_unbiased*square_term/(square_term_unbiased)**3
    (-corr_).backward()
    
    optimizer.step()
    optimizer.zero_grad()
    
    entropy_list.append((numerator_unbiased/square_term_unbiased).item())
    
    if i%100 == 0:
        
        # The title lands on the contour figure that visualization_() draws and shows.
        plt.title('Model Conditional')
        visualization_()
        
        # Scatter the component means of both latent batches against x.
        learned_mean = MEAN_3.detach().cpu().numpy().reshape(-1, 1)
        plt.scatter(x_i.detach().cpu(), learned_mean, s=0.1, color='black')
        
        learned_mean = MEAN_4.detach().cpu().numpy().reshape(-1, 1)
        plt.scatter(x_i.detach().cpu(), learned_mean, s=0.1, color='black', label='Model Centers')

        # NOTE(review): the colour is given both by the 'ro' format string and by color='red'.
        plt.plot(center_y, center_x, 'ro',color='red', label='True Main Components')
        # plt.contour(normalize, origin='lower', extent=[min, max, min, max])

        # Label every marker with the 1-based position of j in np.argsort(center_y).
        for j in range(0, center_y.shape[0]):
            plt.text(center_y[j]+0.01, center_x[j]+0.01, str(np.where(np.argsort(center_y)==j)[0][0]+1), color='red', fontsize=12)
        plt.legend()
        plt.show()
        # The value printed as 'Cross-Entropy' is numerator_unbiased/square_term_unbiased.
        print('Iteration:', i, 'Cross-Entropy:', entropy_list[-1])