In [None]:
#### GENERATE THE SYNTHETIC DATASETS (one per seed, see the loop below)
# Bounds of the square [min, max] x [min, max] on which densities are
# evaluated and plotted.
# NOTE(review): these module-level names shadow the builtins `min`/`max`;
# construct_contour and the imshow `extent` calls read them as globals.
min = 0
max = 1

# Output directory for saved models and entropy traces. This is an IPython
# shell escape, so this cell only runs inside a notebook.
!mkdir -p model

import numpy as np
import matplotlib.pyplot as plt

def laplacian_pdf(input, mean, scale):
    """Evaluate the 1-D Laplace density with location `mean` and scale `scale`.

    Works element-wise on array `input`:
    f(x) = exp(-|x - mean| / scale) / (2 * scale).
    """
    norm = 1 / (2 * scale)
    decay = np.exp(-np.abs(input - mean) / scale)
    return norm * decay

def gaussian_nd(input, m, sigma):
    """Evaluate the multivariate normal density N(m, sigma) at each row of `input`.

    Args:
        input: array of shape (n, k), one query point per row.
        m: mean, broadcastable against `input`. Callers in this file pass 0
            together with an already-centred `input`.
        sigma: (k, k) covariance matrix. Its pseudo-inverse is used in the
            quadratic form and its determinant in the normaliser.

    Returns:
        Array of shape (n,) with the density at each row.
    """
    dim_k = sigma.shape[0]
    centred = input - m
    precision = np.linalg.pinv(sigma)
    # Row-wise quadratic form (x - m)^T Sigma^+ (x - m).
    mahalanobis = np.sum(centred @ precision * centred, 1)
    normaliser = ((2 * np.pi) ** (-dim_k / 2)) * np.linalg.det(sigma) ** (-1 / 2)
    return normaliser * np.exp(-(1 / 2) * mahalanobis)

def construct_contour(centers, weights, learned_variance, interp=100):
    """Rasterise a diagonal-covariance Gaussian mixture onto a square grid.

    The grid spans [min, max] x [min, max] (the module-level bounds) with
    `interp` points per axis. The `interp` argument is now honoured;
    previously it was overwritten with 100.

    Args:
        centers: component means, reshaped to (n_components, 2).
        weights: per-component weights, indexed by component (a scalar or
            a length-1 row per component).
        learned_variance: (n_components, >=2) array; columns 0 and 1 are the
            per-axis variances.
        interp: number of grid points per axis.

    Returns:
        (interp, interp) array equal to the average over components of
        weights[i] * N(grid; centers[i], diag(learned_variance[i, :2])),
        multiplied by the cell area delta**2. Rows follow the y axis and
        columns follow the x axis (np.meshgrid's default 'xy' indexing).
    """
    delta = (max - min) / interp

    axis = np.linspace(min, max, interp)
    xv, yv = np.meshgrid(axis, axis)
    grid = np.array((xv, yv)).reshape(2, -1).T  # (interp*interp, 2) as (x, y)

    # difference[i] holds the grid points relative to component i's mean.
    difference = grid.reshape(1, -1, 2) - np.reshape(centers, (-1, 1, 2))

    components = [
        weights[i] * gaussian_nd(
            difference[i], 0, np.diag(learned_variance[i, :2])
        ).reshape(interp, interp)
        for i in range(difference.shape[0])
    ]
    return np.mean(np.array(components), 0) * delta * delta

def construct_pdf_mix_mixture(center_x, center_y, COV, samples_per_class=3000, interp=1000):
    """Sample a random 2-D mixture and rasterise its (unnormalised) density.

    For each centre i, the following are drawn in this order:
      - spread parameters COV[0, 0], COV[1, 1] ~ U(0.0001, 0.002);
      - a weight w ~ U(0.5, 1.5);
      - a component type rv in {0, 1, 2}.
    Then int(w * samples_per_class) samples are drawn from that component,
    and w * (component pdf) is added to the density map:
      rv == 0: Gaussian N(MEAN[i], COV);
      rv == 1: uniform box with half-width 10 * COV[d, d] on axis d;
      rv == 2: product of Laplaces with scale 10 * COV[d, d] on axis d.

    Args:
        center_x, center_y: (n, 1) column arrays of component centres,
            concatenated along axis 1 to form the means.
        COV: 2x2 template covariance. Its diagonal is overwritten for each
            component on a private float copy; the caller's array is not
            modified.
        samples_per_class: base sample count, scaled by each component weight.
        interp: grid points per axis for the density map over [0, 1]^2.

    Returns:
        samples: (N, 2) samples of all components, concatenated in order.
        MEAN: (n, 2) component means.
        COV_: (n, 2, 2) per-component spread parameters actually used.
        weights_: (n,) component weights.
        pdf_map: (interp, interp) weighted mixture density on the grid, not
            normalised. Rows follow y and columns follow x.
        delta: grid cell width, 1 / interp.
    """
    # Work on a copy; the caller's COV used to be overwritten in place.
    COV = np.array(COV, dtype=float)
    MEAN = np.concatenate([center_x, center_y], 1)
    num_class = MEAN.shape[0]

    lo, hi = 0, 1
    delta = (hi - lo) / interp

    axis = np.linspace(lo, hi, interp)
    xv, yv = np.meshgrid(axis, axis)
    grid = np.array((xv, yv)).reshape(2, -1).T  # (interp*interp, 2) as (x, y)
    pdf_map = np.zeros((interp, interp))

    component_, COV_, weights_ = [], [], []
    for i in range(num_class):
        # Same RNG call sequence as before, so seeded datasets are unchanged.
        COV[0, 0] = np.random.uniform(0.0001, 0.002, 1)[0]
        COV[1, 1] = np.random.uniform(0.0001, 0.002, 1)[0]
        weights = np.random.uniform(0, 1, 1)[0] + 0.5
        n_samples = int(weights * samples_per_class)
        rv = np.random.choice(3, 1)[0]

        if rv == 0:
            samples = np.random.multivariate_normal(MEAN[i], COV, n_samples)
            pdf = gaussian_nd(grid - MEAN[i], 0, COV)
        elif rv == 1:
            half_x = COV[0, 0] * 10
            half_y = COV[1, 1] * 10
            dim_1 = np.random.uniform(MEAN[i, 0] - half_x, MEAN[i, 0] + half_x, n_samples)
            dim_2 = np.random.uniform(MEAN[i, 1] - half_y, MEAN[i, 1] + half_y, n_samples)
            samples = np.array((dim_1, dim_2)).T
            inside = ((grid[:, 0] > MEAN[i, 0] - half_x) * (grid[:, 0] < MEAN[i, 0] + half_x)
                      * (grid[:, 1] > MEAN[i, 1] - half_y) * (grid[:, 1] < MEAN[i, 1] + half_y))
            # Box area is (2 * half_x) * (2 * half_y) = 400 * COV[0,0] * COV[1,1].
            pdf = inside / (400 * COV[0, 0] * COV[1, 1])
        else:
            dim_1 = np.random.laplace(MEAN[i, 0], scale=COV[0, 0] * 10, size=n_samples)
            dim_2 = np.random.laplace(MEAN[i, 1], scale=COV[1, 1] * 10, size=n_samples)
            samples = np.array((dim_1, dim_2)).T
            pdf = (laplacian_pdf(grid[:, 0], MEAN[i, 0], COV[0, 0] * 10)
                   * laplacian_pdf(grid[:, 1], MEAN[i, 1], COV[1, 1] * 10))

        # Every component type is weighted like its sample count. The uniform
        # box used to be added unweighted, so pdf_map disagreed with the samples.
        pdf_map = pdf_map + weights * pdf.reshape(interp, interp)

        component_.append(samples)
        COV_.append(np.copy(COV))
        weights_.append(weights)

    component_ = np.concatenate(component_).reshape(-1, COV.shape[0])

    return component_, MEAN, np.array(COV_), np.array(weights_), pdf_map, delta

# Number of mixture components (centres) per dataset.
K = 20
# Dimension of each of the two variables; the joint space is dim*2 = 2-D.
dim = 1

# One dataset per seed; printed experiment numbers are seed - 3 (1..5).
seed_ = [4, 5, 6, 7, 8]
for seed in seed_:

    np.random.seed(seed)

    # (K, 1) columns of centre coordinates drawn from U(0.2, 0.8).
    center_x = np.array([np.random.uniform(0.2, 0.8, K)]*dim).T
    center_y = np.array([np.random.uniform(0.2, 0.8, K)]*dim).T

    # Base samples per component (scaled by each component's weight).
    how_many_samples = 20000

    # Initial diagonal variance of the template covariance; its diagonal is
    # re-drawn per component inside construct_pdf_mix_mixture.
    var = 0.001

    # NOTE(review): redundant, dim is already 1.
    dim = 1
    # JOINT DISTRIBUTION
    COV = np.eye(dim*2)*var

    # NOTE(review): QMI_LIST is never used.
    QMI_LIST = []
    # NOTE(review): reseeding restarts the RNG stream, so the first draws in
    # construct_pdf_mix_mixture reuse the uniforms that produced center_x
    # (e.g. component 0's COV[0, 0] is an affine function of center_x[0]).
    # Left as-is to keep seeded datasets reproducible; confirm whether this
    # is intended.
    np.random.seed(seed)
    samples, MEAN, COV_, weights_, pdf_map, delta = construct_pdf_mix_mixture(center_x, center_y, COV, samples_per_class=how_many_samples)
    # Normalise so the cells hold probability masses that sum to 1.
    pdf_map = pdf_map/np.sum(pdf_map)

    # sqrt(sum(mass^2) / delta^2) is the square root of the discretised
    # integral of p(x)^2, i.e. the density's L2 norm. It is printed under the
    # label "entropy" (-2 * log of it would be the Renyi-2 entropy).
    true_entropy = np.sqrt(np.sum(pdf_map*pdf_map)/(delta*delta))
    print('EXP #{0} entropy:'.format(seed-3), true_entropy)
    # Log-density image; the 1e-5 floor avoids log(0) and origin='lower'
    # puts y = min at the bottom.
    plt.imshow(np.log(pdf_map+1e-5), origin='lower', extent=[min, max, min, max])
    plt.show()

# plt.style.reload_library()

# plt.style.use('science')
# plt.style.use(['science','no-latex', 'high-vis', 'notebook'])
# plt.rcParams["figure.figsize"] = [6,4]

# WHETHER TO COMMENT

import numpy as np
import matplotlib.pyplot as plt

import csv

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from sklearn.linear_model import LinearRegression

import sklearn
from sklearn import datasets

#torch.cuda.set_device(1)

def GD(net, lr):
    """Apply one plain gradient-descent step to `net` in place, then clear grads.

    Each trainable parameter is updated as ``p <- p - lr * p.grad``.
    Parameters that received no gradient (``grad is None``, e.g. layers that
    are defined but never used in ``forward``) are skipped. The previous
    version raised a TypeError on them.

    Args:
        net: a torch.nn.Module whose gradients were populated by backward().
        lr: learning rate (step size).

    Returns:
        0, kept for compatibility with existing callers.
    """
    with torch.no_grad():
        for param in net.parameters():
            if param.requires_grad and param.grad is not None:
                param.sub_(lr * param.grad)

    net.zero_grad()
    return 0

def sample_uniform_data(min=-5, max=5, batch=300):
    """Draw `batch` i.i.d. samples from U[min, max) using NumPy's global RNG."""
    draws = np.random.uniform(low=min, high=max, size=batch)
    return draws

def gaussian_nd_pytorch(MEAN, VARIANCE):
    """Batched diagonal-covariance Gaussian density of an offset vector.

    Computes N(MEAN; 0, diag(VARIANCE)) row by row, so `MEAN` is expected to
    already be the difference (x - mu) for each row.

    Args:
        MEAN: tensor reshapeable to (batch, dim), the per-row offsets.
        VARIANCE: (batch, dim) tensor of per-axis variances.

    Returns:
        (batch,) tensor of densities.
    """
    n_dim = VARIANCE.shape[1]

    # Determinant of a diagonal covariance = product of its diagonal entries,
    # accumulated column by column.
    columns = VARIANCE.unbind(1)
    det = columns[0]
    for col in columns[1:]:
        det = det * col

    offset = MEAN.reshape(-1, n_dim)
    inv_var = (1 / VARIANCE).view(-1, n_dim)
    mahalanobis = torch.sum((offset * inv_var) * offset, 1)

    return ((2 * np.pi) ** (-n_dim / 2)) * det ** (-1 / 2) * torch.exp(-(1 / 2) * mahalanobis)

def sample_c(batchsize=32, dis_category=5):
    """Sample a batch of uniformly random one-hot categorical codes.

    Args:
        batchsize: number of codes (rows).
        dis_category: number of categories (columns).

    Returns:
        (rand_c, label_c): a float32 tensor of shape (batchsize, dis_category)
        with exactly one 1 per row, and a LongTensor of the hot indices.
    """
    probs = dis_category * [1 / float(dis_category)]
    one_hot = np.zeros((batchsize, dis_category), dtype='float32')
    # One multinomial draw per row, in the same call sequence as before, so a
    # seeded NumPy RNG yields the same codes.
    for row in range(batchsize):
        one_hot[row] = np.random.multinomial(1, probs, size=1)

    labels = torch.LongTensor(np.argmax(one_hot, axis=1).astype('int'))
    codes = torch.from_numpy(one_hot.astype('float32'))
    return codes, labels

def generate_laplacian(center_x, center_y, scale=0.01, samples_per_class=3000):
    """Draw isotropic-Laplace clusters around the given centres.

    The cluster means are built from `center_x`/`center_y`. The previous
    version ignored both arguments and read a global `MEAN` instead.

    Args:
        center_x, center_y: per-cluster x / y centres; any shape that
            flattens to (n_clusters,), e.g. (n, 1) columns.
        scale: Laplace scale parameter on each axis.
        samples_per_class: number of samples drawn per cluster.

    Returns:
        (n_clusters * samples_per_class, 2) array, clusters stacked in order.
    """
    means = np.column_stack((np.ravel(center_x), np.ravel(center_y)))

    component_ = []
    for mu_x, mu_y in means:
        # x is drawn before y for each cluster, as before.
        dim_1 = np.random.laplace(mu_x, scale=scale, size=samples_per_class)
        dim_2 = np.random.laplace(mu_y, scale=scale, size=samples_per_class)
        component_.append(np.array((dim_1, dim_2)).T)

    return np.array(component_).reshape(-1, means.shape[1])

def generate_uniform(center_x, center_y, length=0.5, samples_per_class=3000):
    """Draw axis-aligned uniform square clusters.

    Cluster i is uniform on
    [center_x[i], center_x[i] + length] x [center_y[i], center_y[i] + length].
    The given point is the lower-left corner of the square, not its centre.
    The number of clusters now comes from the arguments; the previous
    version read it from a global `MEAN`.

    Args:
        center_x, center_y: per-cluster corner coordinates (1-D or (n, 1)).
        length: side length of each square.
        samples_per_class: number of samples drawn per cluster.

    Returns:
        (n_clusters * samples_per_class, 2) array, clusters stacked in order.
        Returns an empty (0, 2) array when there are no clusters.
    """
    component_ = []
    for cx, cy in zip(center_x, center_y):
        x_1 = np.random.uniform(cx, cx + length, samples_per_class)
        x_3 = np.random.uniform(cy, cy + length, samples_per_class)
        component_.append(np.array((x_1, x_3)).T)

    if not component_:
        return np.empty((0, 2))
    return np.concatenate(component_, 0)

class DIS_MOG_relu(nn.Module):
    """Network mapping each input row to one diagonal Gaussian component.

    For every input row, forward() returns a component weight, a mean and a
    per-axis variance, so a batch of outputs describes a mixture of Gaussians
    (run_SGM uses it this way).

    Args:
        rand: input feature width (size of the rows fed to forward()).
        HIDDEN: hidden-layer width.
        dim: dimensionality of each component's mean and variance.

    NOTE(review): forward() only uses fc1_/bn1_, fc2_/bn2_, fc33_, fc33,
    fc33_w, fc5, fc6 and fcw. The remaining layers are never called. The
    unused nn.Linear layers still draw from the torch RNG at construction and
    appear in state_dict(), so removing them would change seeded
    initialisation.
    """
    def __init__(self, rand, HIDDEN, dim):
        super(DIS_MOG_relu, self).__init__()
        self.dim = dim
    
        # Group 1: only fc33 (variance head) is used; fc1/bn1/fc2/bn2/fc3/bn3 are unused.
        self.fc1 = nn.Linear(rand, HIDDEN, bias=True)
        self.bn1 = torch.nn.BatchNorm1d(HIDDEN)
        self.fc2 = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn2 = torch.nn.BatchNorm1d(HIDDEN)
        self.fc3 = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn3 = torch.nn.BatchNorm1d(HIDDEN)
        self.fc33 = nn.Linear(HIDDEN, HIDDEN, bias=True)
        
        # Group 2: fc1_/bn1_ and fc2_/bn2_ form the shared trunk, fc33_ feeds
        # the mean head; fc3_/bn3_ are unused.
        self.fc1_ = nn.Linear(rand, HIDDEN, bias=True)
        self.bn1_ = torch.nn.BatchNorm1d(HIDDEN)
        self.fc2_ = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn2_ = torch.nn.BatchNorm1d(HIDDEN)
        self.fc3_ = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn3_ = torch.nn.BatchNorm1d(HIDDEN)
        self.fc33_ = nn.Linear(HIDDEN, HIDDEN, bias=True)

        # Group 3: only fc33_w (weight head) is used.
        self.fc1_w = nn.Linear(rand, HIDDEN, bias=True)
        self.bn1_w = torch.nn.BatchNorm1d(HIDDEN)
        self.fc2_w = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn2_w = torch.nn.BatchNorm1d(HIDDEN)
        self.fc3_w = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn3_w = torch.nn.BatchNorm1d(HIDDEN)
        self.fc33_w = nn.Linear(HIDDEN, HIDDEN, bias=True)

        # Each output value is the average of this many sigmoid units (see forward).
        self.sum_dim_mean = 128
        self.sum_dim_var = 128
        self.sum_dim_weights = 128

        # Output heads: mean, variance and weight.
        self.fc5 = nn.Linear(HIDDEN, dim*self.sum_dim_mean, bias=True)
        self.fc6 = nn.Linear(HIDDEN, dim*self.sum_dim_var, bias=True)
        self.fcw = nn.Linear(HIDDEN, self.sum_dim_weights, bias=True)

    def forward(self, x):
        """Return (weights, mean, variance) for a batch of input rows.

        Args:
            x: (batch, rand) input tensor.

        Returns:
            weights: (batch, 1), values in (0, 1).
            mean: (batch, dim), values in (0, 1); each is the mean of 128 sigmoids.
            variance: (batch, dim), values in (1e-6, 1 + 1e-6).
        """

        # Shared trunk: Linear -> ReLU -> BatchNorm, twice.
        x = self.bn1_(torch.relu((self.fc1_(x))))
        x = self.bn2_(torch.relu((self.fc2_(x))))
        #x = self.bn3_(torch.relu((self.fc3_(x))))

        # Mean branch.
        # x_m = self.bn1_(torch.sigmoid((self.fc1_(x))))
        # x_m = self.bn2_(torch.sigmoid((self.fc2_(x_m))))
        #x_m = self.bn3_(torch.relu((self.fc3_(x))))
        x_m = (torch.relu((self.fc33_(x))))

        # Variance branch.
        # x_v = self.bn1(torch.sigmoid((self.fc1(x))))
        # x_v = self.bn2(torch.sigmoid((self.fc2(x_v))))
        #x_v = self.bn3(torch.relu((self.fc3(x))))
        x_v = (torch.relu((self.fc33(x))))

        # Weight branch.
        # x_w = self.bn1_w(torch.sigmoid((self.fc1_w(x))))
        # x_w = self.bn2_w(torch.sigmoid((self.fc2_w(x_w))))
        #x_w = self.bn3_w(torch.relu((self.fc3_w(x))))
        x_w = (torch.relu((self.fc33_w(x))))

        dim = self.dim 
        
        # Average of sum_dim_mean sigmoids per output dimension -> (batch, dim).
        mean = torch.sigmoid(self.fc5(x_m)).view(x.shape[0], dim, self.sum_dim_mean)
        mean = torch.mean(mean, 2)

        # Same construction; the 1e-6 offset keeps the variance strictly positive.
        variance = torch.sigmoid(self.fc6(x_v)).view(x.shape[0], dim, self.sum_dim_var)
        variance = (torch.mean(variance, 2))+1e-6
        
        # weights = torch.sigmoid(self.fcw(x_w)).view(x.shape[0], 1, self.sum_dim_weights)
        # weights = torch.mean(weights, 2)

        # One scalar weight per row -> (batch, 1).
        weights = torch.sigmoid(self.fcw(x_w)).view(x.shape[0], 1, self.sum_dim_weights)
        weights = torch.mean(weights, 2)

        return weights, mean, variance

def generate_fix_discrete(bs, d_class):
    """Return a fixed (bs, d_class) one-hot matrix that cycles through the classes.

    Row r is the one-hot vector of class r % d_class. When bs is a multiple
    of d_class, this equals stacking bs // d_class identity matrices, which
    is the previous behaviour.

    Previously a non-multiple `bs` was silently truncated to
    floor(bs / d_class) * d_class rows, which breaks later `.view(bs, ...)`
    calls, and bs < d_class crashed. Exactly `bs` rows are now always returned.

    Returns:
        int64 tensor of shape (bs, d_class).
    """
    classes = torch.arange(int(bs)) % d_class
    # num_classes is explicit so the width stays d_class even when bs < d_class.
    return torch.nn.functional.one_hot(classes, num_classes=d_class)

def compute_VR(learned_weights, learned_mean, MEAN, tol=1e-4):
    """Weighted fraction of learned components that lie close to a true mean.

    A learned component counts as valid when its squared Euclidean distance
    to the nearest true mean is below `tol`. The validity indicators are
    averaged with the learned weights, normalised to mean 1, so the result
    equals sum(w * valid) / sum(w).

    Previously (n, 1) weights broadcast against the (n,) indicator into an
    (n, n) matrix. That cancelled the weighting and returned the unweighted
    fraction.

    Args:
        learned_weights: per-component weights; any shape that flattens to
            (n_learned,), e.g. the (n, 1) network output.
        learned_mean: learned means, reshaped to (n_learned, 2).
        MEAN: true means, reshaped to (n_true, 2).
        tol: squared-distance threshold (default 1e-4, as before).

    Returns:
        Scalar validation rate, in [0, 1] for non-negative weights.
    """
    weights = np.ravel(learned_weights)
    weights = weights / np.mean(weights)
    diff = np.reshape(learned_mean, (-1, 1, 2)) - np.reshape(MEAN, (1, -1, 2))
    nearest_sq_dist = (diff ** 2).sum(2).min(1)
    return np.mean(weights * (nearest_sq_dist < tol))

def run_SGM(SEED, x, MEAN, iter=60000):
  """Fit a neural mixture of Gaussians to the samples `x` and report progress.

  A DIS_MOG_relu network maps bs = 300 * 20 fixed one-hot codes (each with
  one appended uniform-noise feature) to bs diagonal Gaussian components
  (weight w_j, mean mu_j, variance var_j). Let
  q(z) = sum_j w_j N(z; mu_j, var_j). Each iteration computes:
    numerator   = (1 / bs) * mean over a minibatch z_i of q(z_i)
    square_term = (1 / bs^2) * integral of q(z)^2 dz, via the Gaussian
                  product identity sum_ij w_i w_j N(mu_i - mu_j; 0, var_i + var_j).
  It then ascends numerator / sqrt(square_term), which estimates
  <p, q> / ||q||. By Cauchy-Schwarz this is maximised when q is
  proportional to the data density p. The logged value uses bias-corrected
  running averages of both terms.

  Args:
    SEED: seed for both torch and NumPy RNGs; also used in output file names.
    x: NumPy array of shape (n_samples, 2) with the training samples.
    MEAN: (n_true, 2) true component means, used only by compute_VR.
    iter: the loop runs for iterations 1 .. iter - 1.

  Side effects:
    Requires a CUDA device. Prints progress, and every 10 iterations (and
    once at the end) writes ./model/MOG_NET_AM_seed{SEED}.npy and
    ./model/entropy_list_AM_seed{SEED}.npy.

  NOTE(review):
    - If iter <= 1 the loop never runs and MEAN_3 is undefined at the end
      (NameError).
    - np.save pickles the whole nn.Module into an object .npy; torch.save of
      the state_dict is the conventional format. Saving every 10 iterations
      is also I/O-heavy for a 60000-iteration run.
  """
  torch.manual_seed(SEED)
  np.random.seed(SEED)

  # Network input = `rand` uniform-noise features + d_howmany one-hot codes of width d_class.
  rand = 1
  HIDDEN = 1024
  dim = 1

  # bs = number of mixture components per step, and also the data minibatch size.
  d_class = 300
  num = 20
  bs = d_class*num

  d_howmany = 1

  MOG_NET = DIS_MOG_relu(rand+d_class*d_howmany, HIDDEN, dim*2).cuda()

  optimizer = optim.Adam([
              {'params': MOG_NET.parameters(), 'lr': 0.001, 'betas': (0.9, 0.999)},
          ])

  entropy_list = []
  # Only used by the commented-out random-code sampling in the loop.
  discrete_prob = torch.ones((d_class,)).float().cuda()

  # Exponential moving average (EMA) decay and running value for square_term.
  beta = 0.999
  v_t = 0.

  # EMA decay and running value for numerator.
  beta_2 = 0.999
  c_t = 0.

  # Zero tensors used only to broadcast per-component tensors to all (bs, bs) pairs.
  VAR_ZERO = torch.zeros((bs, 1, 2)).cuda()
  WEIGHT_ZERO = torch.zeros((bs, 1)).cuda()

  # Fixed codes: every class appears `num` times.
  discrete_vec = generate_fix_discrete(bs, d_class).float().cuda()

  for i in range(1, iter):
  #     if i%1000 == 0 and bs<11000:
  #         bs+=1000
  #         VAR_ZERO = torch.zeros((bs, 1, 2)).cuda()
  #         WEIGHT_ZERO = torch.zeros((bs, 1)).cuda()

  #     uniform_vector_1 = torch.cat((torch.rand(bs, rand).cuda(), 
  #         torch.nn.functional.one_hot(torch.multinomial(discrete_prob, d_howmany*bs, replacement=True), d_class).view(bs, d_howmany*d_class).float()), 1)
  #     WEIGHTS_1, MEAN_1, VARIANCE_1 = MOG_NET(uniform_vector_1)
      
  #     uniform_vector_2 = torch.cat((torch.rand(bs, rand).cuda(), 
  #         torch.nn.functional.one_hot(torch.multinomial(discrete_prob, d_howmany*bs, replacement=True), d_class).view(bs, d_howmany*d_class).float()), 1)
  #     WEIGHTS_2, MEAN_2, VARIANCE_2 = MOG_NET(uniform_vector_2)
      
  #     uniform_vector_3 = torch.cat((torch.rand(bs, rand).cuda(), 
  #         torch.nn.functional.one_hot(torch.multinomial(discrete_prob, d_howmany*bs, replacement=True), d_class).view(bs, d_howmany*d_class).float()), 1)
  #     WEIGHTS_3, MEAN_3, VARIANCE_3 = MOG_NET(uniform_vector_3)
      #WEIGHTS_3.fill_(1)
      
      # Network input: one fresh uniform noise feature per row + the fixed one-hot codes.
      uniform_vector_3 = torch.cat((torch.rand(bs, rand).cuda(), discrete_vec), 1)
      WEIGHTS_3, MEAN_3, VARIANCE_3 = MOG_NET(uniform_vector_3)
      
      # NOTE(review): only b1 is used (disjoint_ is never read), but b2 and b3
      # still advance the NumPy RNG, so removing them changes later minibatches.
      b1 = np.random.choice(x.shape[0], bs)
      b2 = np.random.choice(x.shape[0], bs)
      b3 = np.random.choice(x.shape[0], bs)
      
      joint_ = torch.from_numpy(x[b1, :]).float().cuda()
      disjoint_ = torch.from_numpy(np.concatenate((x[b2, :dim], x[b3, dim:]), 1)).float().cuda()
      
      input = joint_
      
      # All component pairs (i, j): the integral of q^2 is
      # sum_ij w_i w_j N(mu_i - mu_j; 0, var_i + var_j).
      MEAN_DIFF = (MEAN_3.view(bs, 1, 2) - MEAN_3.view(1, bs, 2)).view(-1, 2)
      VARIANCE_SUM = (VARIANCE_3.view(bs, 1, 2) + VARIANCE_3.view(1, bs, 2)).view(-1, 2)
      WEIGHT_DIFF = (WEIGHTS_3.view(bs, 1)*WEIGHTS_3.view(1, bs)).view(-1)

  #     MEAN_DIFF = MEAN_1 - MEAN_3
  #     VARIANCE_SUM = VARIANCE_1 + VARIANCE_3

      square_term = torch.mean(WEIGHT_DIFF*gaussian_nd_pytorch(MEAN_DIFF, VARIANCE_SUM))
      # Bias-corrected EMA (Adam-style) of square_term; square_term_unbiased is its square root.
      v_t = beta*v_t + (1-beta)*square_term.detach()
      square_term_unbiased = torch.sqrt(v_t/(1-beta**i))
      
      # All (data sample, component) pairs: q evaluated at each minibatch sample.
      MEAN_DATA = (input.view(bs, 1, 2) - MEAN_3.view(1, bs, 2)).view(-1, 2)
      VARIANCE_DATA = (VAR_ZERO + VARIANCE_3.view(1, bs, 2)).view(-1, 2)
      WEIGHT_DATA = (WEIGHT_ZERO + WEIGHTS_3.view(1, bs)).view(-1)
      
  #     MEAN_DATA = input - MEAN_3
  #     VARIANCE_DATA = VARIANCE_3
      
      numerator = torch.mean(WEIGHT_DATA*gaussian_nd_pytorch(MEAN_DATA, VARIANCE_DATA))
      # Bias-corrected EMA of numerator.
      c_t = beta_2*c_t + (1-beta_2)*numerator.detach()
      numerator_unbiased = c_t/(1-beta_2**i)
      
      # Surrogate objective. Its gradient equals the gradient of
      # numerator / sqrt(square_term), with the non-differentiated factors
      # replaced by the detached running averages.
      corr_ = (numerator/square_term_unbiased) - 0.5*numerator_unbiased*square_term/(square_term_unbiased)**3
      #corr_ = -square_term + 2*numerator

      # Maximise corr_ by minimising its negative.
      (-corr_).backward()
      
      optimizer.step()
      optimizer.zero_grad()
      # Logged estimate of <p, q> / ||q||, built from the running averages.
      entropy_list.append((numerator_unbiased/square_term_unbiased).item())
      
      if i%10 == 0:
          
          # NOTE(review): rcParams and the learned_* arrays below are computed
          # but unused here (the plotting code is commented out).
          plt.rcParams["figure.figsize"] = [4,4]

          learned_mean = MEAN_3.detach().cpu().numpy().reshape(-1, 2)
          learned_variance = VARIANCE_3.detach().cpu().numpy().reshape(-1, 2)
          learned_weights = WEIGHTS_3.detach().cpu().numpy()

  #         WEIGHTS_3 = WEIGHTS_3.detach().cpu().numpy()
          
          # plt.scatter(learned_mean[0:,0], learned_mean[0:,1], s=1, c=learned_weights[:,0], zorder=333)
          # joint = joint_.detach().cpu()
          # plt.scatter(joint[:,0], joint[:,1], s=1, alpha=0.6, color='pink')
          # plt.show()

          print('Iteration:',i, 'Cross-entropy:', entropy_list[-1])
          np.save('./model/MOG_NET_AM_seed{0}.npy'.format(SEED), MOG_NET)
          np.save('./model/entropy_list_AM_seed{0}.npy'.format(SEED), entropy_list)

          #print(i,'2',(numerator/torch.sqrt(square_term)).item())

  # Final mixture parameters from the last iteration.
  learned_mean = MEAN_3.detach().cpu().numpy().reshape(-1, 2)
  learned_variance = VARIANCE_3.detach().cpu().numpy().reshape(-1, 2)
  learned_weights = WEIGHTS_3.detach().cpu().numpy()

  np.save('./model/MOG_NET_AM_seed{0}.npy'.format(SEED), MOG_NET)
  np.save('./model/entropy_list_AM_seed{0}.npy'.format(SEED), entropy_list)
  # Weighted fraction of learned means within squared distance 1e-4 of a true mean.
  rate = compute_VR(learned_weights, learned_mean, MEAN)

  print('seed: {0}, entropy: {1}, validation rate:{2}'.format(SEED, entropy_list[-1], rate))

# Regenerate each seeded dataset (the same as the preview loop above) and train on it.
seed_ = [4, 5, 6, 7, 8]
for seed in seed_:

    np.random.seed(seed)

    # (K, 1) columns of centre coordinates drawn from U(0.2, 0.8); K and dim
    # are module-level globals set earlier.
    center_x = np.array([np.random.uniform(0.2, 0.8, K)]*dim).T
    center_y = np.array([np.random.uniform(0.2, 0.8, K)]*dim).T

    # Base samples per component (scaled by each component's weight).
    how_many_samples = 20000

    var = 0.001

    dim = 1
    # JOINT DISTRIBUTION
    COV = np.eye(dim*2)*var

    # NOTE(review): QMI_LIST is never used.
    QMI_LIST = []
    # NOTE(review): reseeding makes the component parameters reuse the RNG
    # stream that produced center_x. Kept so the datasets match the preview loop.
    np.random.seed(seed)
    samples, MEAN, COV_, weights_, pdf_map, delta = construct_pdf_mix_mixture(center_x, center_y, COV, samples_per_class=how_many_samples)
    pdf_map = pdf_map/np.sum(pdf_map)

    # Square root of the discretised integral of p(x)^2 (the density's L2
    # norm), printed under the label "entropy". The training loop's logged
    # ratio targets the same quantity.
    true_entropy = np.sqrt(np.sum(pdf_map*pdf_map)/(delta*delta))
    print('EXP #{0} entropy:'.format(seed-3), true_entropy)
    plt.imshow(np.log(pdf_map+1e-5), origin='lower', extent=[min, max, min, max])
    plt.show()

    # Train on the sampled points; MEAN is used only for the validation rate.
    run_SGM(seed, samples, MEAN, iter=60000)