# In [None]:
import numpy as np
import matplotlib.pyplot as plt

import csv

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from sklearn.linear_model import LinearRegression

import sklearn
from sklearn import datasets

def GD(net, lr):
    """Apply one plain gradient-descent step to ``net`` and clear its grads.

    Every parameter with ``requires_grad`` and a populated ``.grad`` is
    updated as ``param <- param - lr * param.grad``. Parameters whose
    ``.grad`` is ``None`` (e.g. not reached by the last backward pass) are
    skipped instead of raising ``TypeError``.

    Args:
        net: ``nn.Module`` whose parameters are updated.
        lr: step size.

    Returns:
        0, kept for compatibility with existing callers.
    """
    for param in net.parameters():
        if param.requires_grad and param.grad is not None:
            param.data = param.data - lr * param.grad

    net.zero_grad()
    return 0

def sample_uniform_data(min=-5, max=5, batch=300):
    """Return ``batch`` i.i.d. draws from U[min, max) via NumPy's global RNG.

    The parameter names shadow the ``min``/``max`` builtins; they are kept
    so keyword callers keep working.
    """
    return np.random.uniform(low=min, high=max, size=batch)

############ WITH DISCRETE
##### NEW CENTERS
# def gaussian_nd_pytorch(MEAN, VARIANCE):
#     bs = VARIANCE.shape[0]
#     dim = VARIANCE.shape[1]

#     det = (VARIANCE[:, 0]*VARIANCE[:, 1])
    
#     product = torch.sum(((MEAN.reshape(-1, dim)*(1/VARIANCE).view(-1, dim))*MEAN.reshape(-1, dim)), 1)
    
#     return ((2*np.pi)**(-dim/2))*det**(-1/2)*torch.exp(-(1/2)*product)

def gaussian_nd_pytorch(MEAN, VARIANCE):
    """Evaluate a diagonal-covariance Gaussian density at the origin, row-wise.

    For each row ``b`` this returns
    ``(2*pi)^(-d/2) * prod(VARIANCE[b])^(-1/2) * exp(-0.5 * sum(MEAN[b]**2 / VARIANCE[b]))``.
    By symmetry, passing ``MEAN = mu_1 - mu_2`` and ``VARIANCE = var_1 + var_2``
    gives the overlap integral of two diagonal Gaussians.

    Args:
        MEAN: tensor reshapeable to ``(B, d)``: per-row mean offsets.
        VARIANCE: ``(B, d)`` tensor of per-dimension variances (must be > 0).

    Returns:
        ``(B,)`` tensor of densities.
    """
    dim = VARIANCE.shape[1]
    mean = MEAN.reshape(-1, dim)
    variance = VARIANCE.reshape(-1, dim)

    # Determinant of a diagonal covariance is the product of its diagonal.
    det = torch.prod(variance, dim=1)
    # Quadratic form m^T diag(v)^-1 m for every row.
    product = torch.sum(mean * (1 / variance) * mean, 1)

    return ((2 * np.pi) ** (-dim / 2)) * det ** (-1 / 2) * torch.exp(-(1 / 2) * product)

def sample_c(batchsize=32, dis_category=5):
    """Sample a batch of uniformly random one-hot category codes.

    Returns:
        ``(codes, labels)``: ``codes`` is a ``(batchsize, dis_category)``
        float32 one-hot tensor; ``labels`` is the ``LongTensor`` of hot indices.
    """
    probs = [1.0 / dis_category] * dis_category
    one_hot = np.zeros((batchsize, dis_category), dtype='float32')
    # One multinomial draw per row (kept per-row to preserve the RNG stream).
    for row in range(batchsize):
        one_hot[row] = np.random.multinomial(1, probs, size=1)

    labels = torch.LongTensor(np.argmax(one_hot, axis=1).astype('int'))
    codes = torch.from_numpy(one_hot.astype('float32'))
    return codes, labels

def generate_laplacian(center_x, center_y, scale=0.01, samples_per_class=3000):
    """Sample a 2-D mixture of axis-aligned Laplace blobs.

    Component ``i`` is centred at ``(center_x[i], center_y[i])`` and draws
    ``samples_per_class`` points with independent Laplace noise of ``scale``
    on each axis.

    Args:
        center_x, center_y: equal-length sequences of component centres.
        scale: Laplace scale parameter per axis.
        samples_per_class: points drawn per component.

    Returns:
        ``(len(center_x) * samples_per_class, 2)`` array, components stacked
        in order; ``(0, 2)`` when there are no centres.

    Raises:
        ValueError: if ``center_x`` and ``center_y`` differ in length.
    """
    if len(center_x) != len(center_y):
        raise ValueError("center_x and center_y must have the same length")

    component_ = []
    for cx, cy in zip(center_x, center_y):
        dim_1 = np.random.laplace(cx, scale=scale, size=samples_per_class)
        dim_2 = np.random.laplace(cy, scale=scale, size=samples_per_class)
        component_.append(np.array((dim_1, dim_2)).T)

    if not component_:
        return np.empty((0, 2))
    return np.concatenate(component_, 0)

def generate_uniform(center_x, center_y, length=0.5, samples_per_class=3000):
    """Sample a 2-D mixture of axis-aligned uniform squares.

    Component ``i`` draws ``samples_per_class`` points uniformly from
    ``[center_x[i], center_x[i] + length) x [center_y[i], center_y[i] + length)``.
    The "centre" is therefore the lower-left corner of each square, not its
    midpoint.

    Args:
        center_x, center_y: equal-length sequences of square corners.
        length: side length of each square.
        samples_per_class: points drawn per component.

    Returns:
        ``(len(center_x) * samples_per_class, 2)`` array, components stacked
        in order; ``(0, 2)`` when there are no centres.

    Raises:
        ValueError: if ``center_x`` and ``center_y`` differ in length.
    """
    if len(center_x) != len(center_y):
        raise ValueError("center_x and center_y must have the same length")

    component_ = []
    for cx, cy in zip(center_x, center_y):
        x_1 = np.random.uniform(cx, cx + length, samples_per_class)
        x_2 = np.random.uniform(cy, cy + length, samples_per_class)
        component_.append(np.array((x_1, x_2)).T)

    if not component_:
        return np.empty((0, 2))
    return np.concatenate(component_, 0)

class DIS_MOG_relu(nn.Module):
    """MLP mapping an input batch to per-sample mixture parameters.

    ``forward`` returns ``(weights, mean, variance)`` with shapes ``(B, 1)``,
    ``(B, dim)`` and ``(B, dim)``. Each output is the average of
    ``sum_dim_*`` sigmoid units, so ``weights`` and ``mean`` lie in (0, 1)
    and ``variance`` lies in (1e-6, 1 + 1e-6).

    NOTE(review): ``forward`` only uses ``fc1_``/``bn1_``, ``fc2_``/``bn2_``,
    ``fc33_``, ``fc33``, ``fc33_w``, ``fc5``, ``fc6`` and ``fcw``. The other
    layers (``fc1``, ``bn1``, ``fc2``, ``bn2``, ``fc3``, ``bn3``, ``fc3_``,
    ``bn3_``, ``fc1_w``, ``bn1_w``, ``fc2_w``, ``bn2_w``, ``fc3_w``,
    ``bn3_w``) are registered but unused. Removing them would change the
    ``state_dict`` keys and the RNG draws made during construction.
    """
    def __init__(self, rand, HIDDEN, dim):
        # rand: input width; HIDDEN: hidden width; dim: mixture dimensionality.
        super(DIS_MOG_relu, self).__init__()
        self.dim = dim
    
        # Unused stack (only referenced by commented-out code in forward).
        self.fc1 = nn.Linear(rand, HIDDEN, bias=True)
        self.bn1 = torch.nn.BatchNorm1d(HIDDEN)
        self.fc2 = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn2 = torch.nn.BatchNorm1d(HIDDEN)
        self.fc3 = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn3 = torch.nn.BatchNorm1d(HIDDEN)
        # fc33 feeds the variance head.
        self.fc33 = nn.Linear(HIDDEN, HIDDEN, bias=True)
        
        # Shared trunk (fc1_/bn1_, fc2_/bn2_); fc33_ feeds the mean head.
        self.fc1_ = nn.Linear(rand, HIDDEN, bias=True)
        self.bn1_ = torch.nn.BatchNorm1d(HIDDEN)
        self.fc2_ = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn2_ = torch.nn.BatchNorm1d(HIDDEN)
        self.fc3_ = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn3_ = torch.nn.BatchNorm1d(HIDDEN)
        self.fc33_ = nn.Linear(HIDDEN, HIDDEN, bias=True)

        # Weight-head stack; only fc33_w is used in forward.
        self.fc1_w = nn.Linear(rand, HIDDEN, bias=True)
        self.bn1_w = torch.nn.BatchNorm1d(HIDDEN)
        self.fc2_w = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn2_w = torch.nn.BatchNorm1d(HIDDEN)
        self.fc3_w = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn3_w = torch.nn.BatchNorm1d(HIDDEN)
        self.fc33_w = nn.Linear(HIDDEN, HIDDEN, bias=True)

        # Number of sigmoid units averaged into each scalar output.
        self.sum_dim_mean = 128
        self.sum_dim_var = 128
        self.sum_dim_weights = 128

        # Output heads: dim*sum_dim units for mean/variance, sum_dim for weights.
        self.fc5 = nn.Linear(HIDDEN, dim*self.sum_dim_mean, bias=True)
        self.fc6 = nn.Linear(HIDDEN, dim*self.sum_dim_var, bias=True)
        self.fcw = nn.Linear(HIDDEN, self.sum_dim_weights, bias=True)

    def forward(self, x):
        """Return ``(weights, mean, variance)`` for a ``(B, rand)`` batch.

        BatchNorm1d in training mode needs B > 1.
        """

        # Shared trunk: Linear -> ReLU -> BatchNorm, twice.
        x = self.bn1_(torch.relu((self.fc1_(x))))
        x = self.bn2_(torch.relu((self.fc2_(x))))
        #x = self.bn3_(torch.relu((self.fc3_(x))))

        # Mean branch.
        # x_m = self.bn1_(torch.sigmoid((self.fc1_(x))))
        # x_m = self.bn2_(torch.sigmoid((self.fc2_(x_m))))
        #x_m = self.bn3_(torch.relu((self.fc3_(x))))
        x_m = (torch.relu((self.fc33_(x))))

        # Variance branch.
        # x_v = self.bn1(torch.sigmoid((self.fc1(x))))
        # x_v = self.bn2(torch.sigmoid((self.fc2(x_v))))
        #x_v = self.bn3(torch.relu((self.fc3(x))))
        x_v = (torch.relu((self.fc33(x))))

        # Weight branch.
        # x_w = self.bn1_w(torch.sigmoid((self.fc1_w(x))))
        # x_w = self.bn2_w(torch.sigmoid((self.fc2_w(x_w))))
        #x_w = self.bn3_w(torch.relu((self.fc3_w(x))))
        x_w = (torch.relu((self.fc33_w(x))))

        dim = self.dim 
        
        # x.shape[0] is still the batch size: x now holds the (B, HIDDEN) trunk output.
        # Each of the dim coordinates is the mean of sum_dim_mean sigmoid units.
        mean = torch.sigmoid(self.fc5(x_m)).view(x.shape[0], dim, self.sum_dim_mean)
        mean = torch.mean(mean, 2)

        # Same averaging; +1e-6 keeps the variance strictly positive.
        variance = torch.sigmoid(self.fc6(x_v)).view(x.shape[0], dim, self.sum_dim_var)
        variance = (torch.mean(variance, 2))+1e-6
        
        # weights = torch.sigmoid(self.fcw(x_w)).view(x.shape[0], 1, self.sum_dim_weights)
        # weights = torch.mean(weights, 2)

        # One mixture weight per sample, shape (B, 1).
        weights = torch.sigmoid(self.fcw(x_w)).view(x.shape[0], 1, self.sum_dim_weights)
        weights = torch.mean(weights, 2)

        return weights, mean, variance

def generate_fix_discrete(bs, d_class):
    """Build a deterministic one-hot batch that cycles through every class.

    Rows are ``one_hot(0), ..., one_hot(d_class - 1)`` repeated
    ``int(bs / d_class)`` times as a float tensor. When ``bs`` is not a
    multiple of ``d_class``, the result has fewer than ``bs`` rows.
    """
    repeats = int(bs / d_class)
    identity = torch.nn.functional.one_hot(torch.arange(0, d_class))
    blocks = [identity for _ in range(repeats)]
    return torch.cat(blocks, 0).float()

def compute_VR(learned_weights, learned_mean, MEAN, dim=2, tol=1e-4):
    """Weighted fraction of learned components that land on a true centre.

    Weights are rescaled to have mean 1. A learned component counts as
    recovered when its squared Euclidean distance to the nearest row of
    ``MEAN`` is below ``tol``. The score is ``mean(weight * recovered)``.

    Args:
        learned_weights: NumPy array of per-component weights, ``(K,)``.
        learned_mean: NumPy array reshapeable to ``(K, dim)``.
        MEAN: ground-truth centres, reshapeable to ``(C, dim)``.
        dim: point dimensionality (default 2).
        tol: squared-distance threshold (default 1e-4).

    Returns:
        Scalar score.
    """
    learned_weights = learned_weights / np.mean(learned_weights)
    diff = learned_mean.reshape(-1, 1, dim) - MEAN.reshape(1, -1, dim)
    nearest_sq_dist = (diff ** 2).sum(2).min(1)
    return np.mean(learned_weights * (nearest_sq_dist < tol))

def sample_discrete_class(bs=3000, how_many=5, d_class=5):
    """Sample ``how_many`` independent one-hot codes per row, concatenated.

    Returns:
        ``(bs, how_many * d_class)`` float tensor. Each consecutive
        ``d_class``-wide slice of a row is a uniformly random one-hot vector.
    """
    codes = torch.randint(d_class, (bs * how_many,))
    one_hot = torch.nn.functional.one_hot(codes, num_classes=d_class)
    return one_hot.view(bs, -1).float()

import torchvision.utils as vutils
import matplotlib.image as mpimg

def compute_square_term(W_G, M_G, V_G):
    """Mean pairwise Gaussian overlap of a diagonal mixture with itself.

    For every ordered pair ``(i, j)`` of the ``n`` components this evaluates
    ``W_G[i] * W_G[j] * gaussian_nd_pytorch(M_G[i] - M_G[j], V_G[i] + V_G[j])``
    and returns the average over all ``n**2`` pairs.
    """
    n, d = M_G.shape[0], M_G.shape[1]

    # Broadcast (n, 1, d) against (1, n, d), then flatten the pair axes.
    pair_mean = (M_G.view(n, 1, d) - M_G.view(1, n, d)).view(-1, d)
    pair_var = (V_G.view(n, 1, d) + V_G.view(1, n, d)).view(-1, d)
    pair_weight = (W_G.view(n, 1) * W_G.view(1, n)).view(-1)

    return torch.mean(pair_weight * gaussian_nd_pytorch(pair_mean, pair_var))

def compute_cross_term(W_G, M_G, V_G, W_D, M_D, V_D):
    """Mean pairwise Gaussian overlap between two diagonal mixtures G and D.

    For every pair ``(i, j)`` with ``i`` a G component and ``j`` a D
    component this evaluates
    ``W_G[i] * W_D[j] * gaussian_nd_pytorch(M_G[i] - M_D[j], V_G[i] + V_D[j])``
    and returns the mean over all pairs. The two mixtures may have different
    numbers of components. D's count is inferred from ``M_D.numel() // dim``,
    which equals G's count in the equal-size case.
    """
    bs_g = M_G.shape[0]
    dim = M_G.shape[1]
    bs_d = M_D.numel() // dim

    MEAN_CROSS = (M_G.view(bs_g, 1, dim) - M_D.view(1, bs_d, dim)).view(-1, dim)
    VARIANCE_CROSS = (V_G.view(bs_g, 1, dim) + V_D.view(1, bs_d, dim)).view(-1, dim)
    WEIGHT_CROSS = (W_G.view(bs_g, 1) * W_D.view(1, bs_d)).view(-1)

    return torch.mean(WEIGHT_CROSS * gaussian_nd_pytorch(MEAN_CROSS, VARIANCE_CROSS))

def compute_cross_per_sample(W_G, M_G, V_G, W_D, M_D, V_D):
    """Mean Gaussian overlap between index-matched components of G and D.

    Unlike ``compute_cross_term``, component ``i`` of G is paired only with
    component ``i`` of D, so the inputs must broadcast element-wise. Returns
    ``mean_i W_G[i] * W_D[i] * gaussian_nd_pytorch(M_G[i] - M_D[i], V_G[i] + V_D[i])``.
    """
    MEAN_CROSS = M_G - M_D
    VARIANCE_CROSS = V_G + V_D
    WEIGHT_CROSS = (W_G * W_D).view(-1)

    return torch.mean(WEIGHT_CROSS * gaussian_nd_pytorch(MEAN_CROSS, VARIANCE_CROSS))

def adaptive_estimation(v_t, beta, square_term, i):
    """One step of an Adam-style bias-corrected exponential moving average.

    Args:
        v_t: previous running average (float or tensor).
        beta: decay factor.
        square_term: new observation tensor; it is detached, so it
            contributes no gradient.
        i: 1-based step count used for the ``1 - beta**i`` bias correction.

    Returns:
        ``(new_running_average, bias_corrected_estimate)``.
    """
    observation = square_term.detach()
    running = beta * v_t + (1 - beta) * observation
    return running, running / (1 - beta ** i)

# def adaptive_estimation(v_t, beta, square_term, i):
#     if i==1:
#         return square_term.detach(), square_term.detach()
#     new = beta*v_t + (1-beta)*square_term.detach()
#     return new, new

### NORMALIZE INPUT
def normalize(M_G_1, V_G_1):
    """Standardise ``M_G_1`` to zero mean and unit (unbiased) std.

    Statistics are taken over all elements, and gradients flow through them.
    ``V_G_1`` is returned unchanged.
    """
    centred = M_G_1 - torch.mean(M_G_1)
    scaled = centred / torch.std(centred)
    return scaled, V_G_1

def normalize_nogradient(M_G_1, V_G_1):
    """Standardise ``M_G_1`` using statistics treated as constants.

    Applies the same transform as ``normalize``: subtract the mean, then
    divide by the (unbiased) std of the centred values. Both statistics are
    detached, so gradients flow only through the element-wise operations.
    ``V_G_1`` is returned unchanged.
    """
    # The mean is computed here; the name ``mean_g`` used to refer to a
    # variable that only existed inside ``normalize``.
    mean_g = torch.mean(M_G_1).detach()
    M_G_1 = M_G_1 - mean_g
    M_G_1 = M_G_1 / torch.std(M_G_1).detach()
    return M_G_1, V_G_1

def gaussian_1d(input, m, sigma):
    """Evaluate 1-D Gaussian densities for every (point, component) pair.

    ``sigma`` is used as the variance: the normaliser is ``sigma ** -0.5``.
    The result has shape ``(n_points, n_components)``. When ``sigma`` has one
    entry per component (or is a scalar), entry ``[p, k]`` is
    ``N(input[p]; m[k], sigma[k])``.
    """
    precision = 1 / sigma
    points = input.reshape(-1, 1)   # one row per evaluation point
    centres = m.reshape(1, -1)      # one column per component
    sq_dist = (points - centres) ** 2
    return ((2 * np.pi) ** (-1 / 2)) * sigma ** (-1 / 2) * np.exp(-(1 / 2) * (sq_dist * precision))

def construct_contour1d(centers, weights, learned_variance, interp=100, x_min=0, x_max=1):
    """Evaluate a weighted 1-D Gaussian mixture on a regular grid.

    The grid is ``np.linspace(x_min, x_max, interp)``. Each component density
    (from ``gaussian_1d``, with ``learned_variance`` as the variance) is scaled
    by ``weights / sum(weights)``. The value at each grid point is the mean over
    components divided by ``delta = (x_max - x_min) / interp``.

    Args:
        centers, weights, learned_variance: 1-D NumPy arrays, one entry per component.
        interp: number of grid points.
        x_min, x_max: grid range (defaults 0 and 1).

    Returns:
        ``(interp,)`` array.
    """
    delta = (x_max - x_min) / interp
    x_axis = np.linspace(x_min, x_max, interp)

    weighted = (weights * gaussian_1d(x_axis, centers, learned_variance)) / np.sum(weights)
    # NOTE(review): two pieces of this scaling are kept unchanged and should
    # be confirmed as intended: averaging over components (rather than
    # summing), and dividing by (x_max - x_min) / interp (linspace spacing is
    # (x_max - x_min) / (interp - 1)).
    return np.mean(np.array(weighted), 1) / delta

def visualize():
    """Plot the 1-D mixture densities of G and D from notebook globals.

    Reads the module-level tensors ``M_G_1``/``V_G_1``/``W_G_1`` (curve 'G')
    and ``M_D_1``/``V_D_1``/``W_D_1`` (curve 'D'). Each set is flattened to
    NumPy vectors, evaluated with ``construct_contour1d`` on a 100-point grid,
    and both curves are shown in one figure.

    NOTE(review): the training loop in this file only assigns ``M_G_1`` and
    ``M_D_1``. The ``V_*``/``W_*`` globals must be provided elsewhere, or this
    raises ``NameError``.
    """
    def as_vector(tensor):
        return tensor.detach().cpu().numpy().reshape(-1)

    def plot_mixture(mean_t, var_t, weight_t, label):
        density = construct_contour1d(as_vector(mean_t), as_vector(weight_t), as_vector(var_t), interp=100)
        plt.plot(density, label=label)

    plot_mixture(M_G_1, V_G_1, W_G_1, 'G')
    plot_mixture(M_D_1, V_D_1, W_D_1, 'D')
    plt.legend()
    plt.show()

class avgpool(nn.Module):
    """Non-overlapping 2x2 mean pooling implemented with strided slicing.

    Assumes ``(N, C, H, W)`` input with even ``H`` and ``W``. In that case it
    matches ``nn.AvgPool2d(2, 2)`` in value. ``up_size`` is accepted but unused.
    """

    def __init__(self, up_size=0):
        super().__init__()

    def forward(self, x):
        top_left = x[:, :, ::2, ::2]
        bottom_left = x[:, :, 1::2, ::2]
        top_right = x[:, :, ::2, 1::2]
        bottom_right = x[:, :, 1::2, 1::2]
        return (top_left + bottom_left + top_right + bottom_right) / 4
    
class ResidualBlock(nn.Module):
    """Pre-activation residual block with optional 2x up- or down-sampling.

    ``resample`` selects the variant:

    * ``None``: BN-ReLU-conv3x3 twice with an identity shortcut. ``in_dim``
      must equal ``out_dim`` for the residual sum to be shape-compatible.
    * ``'up'``: BN-ReLU-conv3x3, then BN-ReLU-upsample-conv3x3. The
      shortcut is upsample followed by a 1x1 conv. Spatial size doubles.
    * ``'up_nbn'``: like ``'up'``, but the BatchNorm layers are not applied
      in ``forward``. They are still created, so the parameter layout matches
      ``'up'``.
    * ``'down'``: BN-ReLU-conv3x3, then BN-ReLU-conv3x3-avgpool. The shortcut
      is a 1x1 conv followed by 2x2 average pooling. Spatial size halves.

    ``up_size`` is accepted for signature compatibility but unused.

    Raises:
        ValueError: if ``resample`` is not one of the values above.
    """

    def __init__(self, in_dim, out_dim, resample=None, up_size=0):
        super(ResidualBlock, self).__init__()
        # Layer creation order is significant: it fixes the RNG draws made
        # during parameter initialisation and the state_dict key order.
        if resample in ('up', 'up_nbn'):
            self.bn1 = nn.BatchNorm2d(in_dim)
            self.conv1 = nn.Conv2d(in_dim, out_dim, 3, 1, 1, bias=True)
            self.upsample = torch.nn.Upsample(scale_factor=2)
            self.upsample_conv = nn.Conv2d(in_dim, out_dim, 1, 1, 0, bias=True)
            self.conv2 = nn.Conv2d(out_dim, out_dim, 3, 1, 1, bias=True)
            self.bn2 = nn.BatchNorm2d(out_dim)
        elif resample == 'down':
            self.bn1 = nn.BatchNorm2d(in_dim)
            self.bn2 = nn.BatchNorm2d(out_dim)
            self.conv1 = nn.Conv2d(in_dim, out_dim, 3, 1, 1, bias=True)
            self.conv2 = nn.Conv2d(out_dim, out_dim, 3, 1, 1, bias=True)
            self.pool = torch.nn.AvgPool2d(2, 2)
            self.pool_conv = nn.Conv2d(in_dim, out_dim, 1, 1, 0, bias=True)
        elif resample is None:
            self.bn1 = nn.BatchNorm2d(in_dim)
            self.bn2 = nn.BatchNorm2d(out_dim)
            self.conv1 = nn.Conv2d(in_dim, out_dim, 3, 1, 1, bias=True)
            self.conv2 = nn.Conv2d(out_dim, out_dim, 3, 1, 1, bias=True)
        else:
            raise ValueError(
                "resample must be None, 'up', 'up_nbn' or 'down', got %r" % (resample,)
            )

        self.resample = resample

    def forward(self, x):
        relu = nn.functional.relu

        if self.resample is None:
            shortcut = x
            output = self.conv1(relu(self.bn1(x)))
            output = self.conv2(relu(self.bn2(output)))

        elif self.resample in ('up', 'up_nbn'):
            use_bn = self.resample == 'up'
            # upsample -> 1x1 conv on the skip path
            shortcut = self.upsample_conv(self.upsample(x))

            output = self.bn1(x) if use_bn else x
            output = self.conv1(relu(output))
            if use_bn:
                output = self.bn2(output)
            # upsample -> conv3x3 on the main path
            output = self.conv2(self.upsample(relu(output)))

        elif self.resample == 'down':
            # 1x1 conv -> mean-pool on the skip path
            shortcut = self.pool(self.pool_conv(x))

            output = self.conv1(relu(self.bn1(x)))
            # conv3x3 -> mean-pool on the main path
            output = self.pool(self.conv2(relu(self.bn2(output))))

        else:
            raise ValueError("unsupported resample mode %r" % (self.resample,))

        return output + shortcut

class ResidualBlock_thefirstone(nn.Module):
    """Input residual block: conv3x3-ReLU-conv3x3-avgpool plus a pooled 1x1 skip.

    Both paths halve the spatial size. ``resample`` and ``up_size`` are
    accepted for signature compatibility with ``ResidualBlock`` but unused.
    ``bn1`` is created but not applied in ``forward``.
    """

    def __init__(self, in_dim, out_dim, resample=None, up_size=0):
        super().__init__()
        # Creation order fixes init RNG draws and state_dict ordering.
        self.conv1 = nn.Conv2d(in_dim, out_dim, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(out_dim, out_dim, 3, 1, 1, bias=True)
        self.pool = torch.nn.AvgPool2d(2, 2)
        self.pool_conv = nn.Conv2d(in_dim, out_dim, 1, 1, 0, bias=True)
        self.bn1 = nn.BatchNorm2d(out_dim)

    def forward(self, x):
        skip = self.pool_conv(self.pool(x))                      # meanpool -> 1x1 conv
        main = nn.functional.relu(self.conv1(x))
        main = self.pool(self.conv2(main))                        # conv -> meanpool
        return main + skip

#create_gan_architecture
class Generator(nn.Module):
    """Residual up-sampling generator producing 3-channel images in [0, 1].

    A linear layer maps the ``rand``-dim input to a 128 x 4 x 4 feature map.
    Three ``ResidualBlock(..., 'up')`` stages then upsample it, and a final
    BN-ReLU-conv3x3 produces 3 channels that are clamped to [0, 1].

    NOTE: ``Generator`` is defined again further down in this file (with
    ``nc`` output channels). After the module runs, that later class is the
    one bound to the name.
    """

    def __init__(self, rand=128):
        super(Generator, self).__init__()
        self.rand = rand
        # Registration order fixes init RNG draws and state_dict ordering.
        self.linear = nn.Linear(rand, 2048, bias=True)
        self.layer_up_1 = ResidualBlock(128, 128, 'up', up_size=8)
        self.layer_up_2 = ResidualBlock(128, 128, 'up', up_size=16)
        self.layer_up_3 = ResidualBlock(128, 128, 'up', up_size=32)
        self.bn1 = nn.BatchNorm2d(128)
        self.conv_last = nn.Conv2d(128, 3, 3, 1, 1, bias=True)

    def forward(self, x):
        h = self.linear(x.view(-1, self.rand))
        h = h.view(-1, 128, 4, 4)  # 2048 = 128 * 4 * 4
        for stage in (self.layer_up_1, self.layer_up_2, self.layer_up_3):
            h = stage(h)
        h = self.conv_last(nn.functional.relu(self.bn1(h)))
        # relu is retained ahead of the clamp; together they give values in [0, 1].
        return torch.clamp(torch.relu(h), 0, 1)


class Discriminator(nn.Module):
    """MLP critic mapping a flattened image to a per-sample score in [1e-5, 1].

    The input is flattened to ``dim_input`` features and passed through
    Linear-ReLU-BatchNorm (affine-free) twice. The ``fc33_``/``fc5`` head then
    averages ``sum_dim_mean`` ReLU units per output coordinate.

    Returns:
        ``(B, dim)`` tensor clamped to ``[1e-5, 1]``.

    NOTE(review): ``fc33``, ``fc33_w``, ``fc6``, ``fcw``, ``fc01_``-``fc03_``
    and ``bn3_``-``bn5_`` are registered but do not contribute to the output.
    They are kept so the ``state_dict`` layout and construction-time RNG
    draws are unchanged.
    """
    def __init__(self, dim_input=3*32*32, dim=1):
        super(Discriminator, self).__init__()
        self.dim = dim
        HIDDEN = 2048

        self.fc33 = nn.Linear(HIDDEN, HIDDEN, bias=True)

        self.fc1_ = nn.Linear(dim_input, HIDDEN, bias=True)
        self.bn1_ = torch.nn.BatchNorm1d(HIDDEN, affine=False)
        self.fc2_ = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn2_ = torch.nn.BatchNorm1d(HIDDEN, affine=False)

        self.fc01_ = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.fc02_ = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.fc03_ = nn.Linear(HIDDEN, HIDDEN, bias=True)

        self.bn3_ = torch.nn.BatchNorm1d(HIDDEN)
        self.bn4_ = torch.nn.BatchNorm1d(HIDDEN)
        self.bn5_ = torch.nn.BatchNorm1d(HIDDEN)

        self.fc33_ = nn.Linear(HIDDEN, HIDDEN, bias=True)

        self.fc33_w = nn.Linear(HIDDEN, HIDDEN, bias=True)

        self.sum_dim_mean = 1024
        self.sum_dim_var = 1024
        self.sum_dim_weights = 1024

        self.fc5 = nn.Linear(HIDDEN, dim*self.sum_dim_mean, bias=True)
        self.fc6 = nn.Linear(HIDDEN, dim*self.sum_dim_var, bias=True)
        self.fcw = nn.Linear(HIDDEN, self.sum_dim_weights, bias=True)

    def forward(self, x):
        # Flatten to the width fc1_ was built for, so a non-default
        # dim_input works; reshape also accepts non-contiguous input.
        x = x.reshape(-1, self.fc1_.in_features)

        x = self.bn1_(torch.relu(self.fc1_(x)))
        x = self.bn2_(torch.relu(self.fc2_(x)))

        # Only the mean head feeds the output; the variance/weight heads
        # were computed and discarded, so they are no longer evaluated.
        x_m = torch.relu(self.fc33_(x))
        mean = torch.relu(self.fc5(x_m)).view(x.shape[0], self.dim, self.sum_dim_mean)
        mean = torch.mean(mean, 2)
        return torch.clamp(mean, 1e-5, 1)

class Generator(nn.Module):
    """Residual up-sampling generator: latent vector -> 32 x 32 image in [0, 1].

    A linear layer maps the ``rand``-dim input to a 128 x 4 x 4 map. Three
    ``ResidualBlock(..., 'up')`` stages each double the spatial size, and a
    BN-ReLU-conv3x3 produces ``out_channels`` maps that are clamped to [0, 1].

    Args:
        rand: input vector length (input is reshaped to ``(-1, rand)``).
        out_channels: number of output channels. When ``None`` (default),
            the module-level ``nc`` global is read at construction time.
    """

    def __init__(self, rand=128, out_channels=None):
        super(Generator, self).__init__()
        if out_channels is None:
            # Backward-compatible fallback to the notebook-level channel count.
            out_channels = nc
        self.rand = rand
        # Registration order fixes init RNG draws and state_dict ordering.
        self.linear = nn.Linear(rand, 2048, bias=True)
        self.layer_up_1 = ResidualBlock(128, 128, 'up', up_size=8)
        self.layer_up_2 = ResidualBlock(128, 128, 'up', up_size=16)
        self.layer_up_3 = ResidualBlock(128, 128, 'up', up_size=32)
        self.bn1 = nn.BatchNorm2d(128)
        self.conv_last = nn.Conv2d(128, out_channels, 3, 1, 1, bias=True)

    def forward(self, x):
        x = x.view(-1, self.rand)
        x = self.linear(x)
        x = x.view(-1, 128, 4, 4)  # 2048 = 128 * 4 * 4
        x = self.layer_up_1(x)
        x = self.layer_up_2(x)
        x = self.layer_up_3(x)

        x = self.bn1(x)
        x = nn.functional.relu(x)
        x = self.conv_last(x)

        # relu followed by clamp: output values lie in [0, 1].
        x = torch.relu(x)
        x = torch.clamp(x, 0, 1)

        return x

class Discriminator(nn.Module):
    """MLP critic mapping a flattened image to a per-sample score in [0, 1].

    The input is flattened to ``dim_input`` features (default 1 x 32 x 32)
    and passed through Linear-ReLU-BatchNorm (affine-free) twice. The
    ``fc33_``/``fc5`` head then averages ``sum_dim_mean`` ReLU units per
    output coordinate.

    Returns:
        ``(B, dim)`` tensor clamped to ``[0, 1]``.

    NOTE(review): ``fc33``, ``fc33_w``, ``fc6``, ``fcw``, ``fc01_``-``fc03_``
    and ``bn3_``-``bn5_`` are registered but do not contribute to the output.
    They are kept so the ``state_dict`` layout and construction-time RNG
    draws are unchanged.
    """
    def __init__(self, dim_input=1*32*32, dim=1):
        super(Discriminator, self).__init__()
        self.dim = dim
        HIDDEN = 2048

        self.fc33 = nn.Linear(HIDDEN, HIDDEN, bias=True)

        self.fc1_ = nn.Linear(dim_input, HIDDEN, bias=True)
        self.bn1_ = torch.nn.BatchNorm1d(HIDDEN, affine=False)
        self.fc2_ = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.bn2_ = torch.nn.BatchNorm1d(HIDDEN, affine=False)

        self.fc01_ = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.fc02_ = nn.Linear(HIDDEN, HIDDEN, bias=True)
        self.fc03_ = nn.Linear(HIDDEN, HIDDEN, bias=True)

        self.bn3_ = torch.nn.BatchNorm1d(HIDDEN)
        self.bn4_ = torch.nn.BatchNorm1d(HIDDEN)
        self.bn5_ = torch.nn.BatchNorm1d(HIDDEN)

        self.fc33_ = nn.Linear(HIDDEN, HIDDEN, bias=True)

        self.fc33_w = nn.Linear(HIDDEN, HIDDEN, bias=True)

        self.sum_dim_mean = 1024
        self.sum_dim_var = 1024
        self.sum_dim_weights = 1024

        self.fc5 = nn.Linear(HIDDEN, dim*self.sum_dim_mean, bias=True)
        self.fc6 = nn.Linear(HIDDEN, dim*self.sum_dim_var, bias=True)
        self.fcw = nn.Linear(HIDDEN, self.sum_dim_weights, bias=True)

    def forward(self, x):
        # Flatten to the width fc1_ was built for, so a non-default
        # dim_input works; reshape also accepts non-contiguous input.
        x = x.reshape(-1, self.fc1_.in_features)

        x = self.bn1_(torch.relu(self.fc1_(x)))
        x = self.bn2_(torch.relu(self.fc2_(x)))

        # Only the mean head feeds the output; the variance/weight heads
        # were computed and discarded, so they are no longer evaluated.
        x_m = torch.relu(self.fc33_(x))
        mean = torch.relu(self.fc5(x_m)).view(x.shape[0], self.dim, self.sum_dim_mean)
        mean = torch.mean(mean, 2)
        return torch.clamp(mean, 0, 1)

from torch.utils.data import Subset
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset

# Run on CUDA device index 1 (at least two visible GPUs are required).
torch.cuda.set_device(1)

# Resize every image to 32x32, matching the Generator output and the
# Discriminator's default 1*32*32 input width.
transform = transforms.Compose(
    [transforms.Resize((32,32)), transforms.ToTensor()])

# Batch size, shared by the data loaders and the latent batches in training.
bs = 100

# trainset = torchvision.datasets.MNIST(root='./MNIST/', train=True,
#                                         download=True, transform=transform)
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs,
#                                           shuffle=True, num_workers=2)

# testset = torchvision.datasets.MNIST(root='./MNIST/', train=False,
#                                        download=True, transform=transform)
# testloader = torch.utils.data.DataLoader(testset, batch_size=bs,
#                                          shuffle=False, num_workers=2)

# FashionMNIST train/test sets (downloaded on first run).
trainset = torchvision.datasets.FashionMNIST(root='./fashion-mnist/', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.FashionMNIST(root='./fashion-mnist/', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=bs,
                                         shuffle=False, num_workers=2)

# nc: image channels; it is read by the second Generator class when it is
# constructed. ndf/ngf are not referenced in the visible code.
nc = 1
ndf = 32
ngf = 64

# v_t / v_t_D are the running-average states passed to adaptive_estimation
# in the training loop, which uses literal decays 0.9 / 0.999. The beta_*
# constants and v_t_g / v_t_D_g are not referenced in the visible loop.
beta = 0.9
v_t = 0.

beta_2 = 0.999
v_t_D = 0.

beta_3 = 0.9
v_t_g = 0.

beta_4 = 0.999
v_t_D_g = 0.

# Per-outer-iteration values of mean_cons / sqrt(quadratic_cons).
entropy_list = []

# 1-based step counter for the 1 - beta**i bias correction. Starting at 0
# would divide by zero.
i = 1

SEED = 4
torch.manual_seed(SEED)
np.random.seed(SEED)

# Latent noise width; the generator input is rand + d_class*how_many.
rand = 100
# NOTE(review): not referenced in the visible code, and it assumes 3
# channels while nc = 1.
output_dim = 32*32*3

# Not referenced in the visible code (the Discriminator default is also dim=1).
dim = 1

# how_many independent one-hot codes of d_class classes per latent vector.
d_class = 10
how_many = 10

def uniform(stdev, size):
    """Sample float32 values from U[-stdev*sqrt(3), stdev*sqrt(3)].

    This interval has standard deviation ``stdev``. NumPy's global RNG is used.
    """
    half_width = stdev * np.sqrt(3)
    samples = np.random.uniform(low=-half_width, high=half_width, size=size)
    return samples.astype('float32')

def initialize_conv(m,he_init=True):
    """Draw uniform conv filters with a fan-based scale, as a float32 NumPy array.

    Uses ``fan_in = in_channels * k**2`` and
    ``fan_out = out_channels * k**2 / stride**2`` (first kernel/stride entries
    only). The stdev is ``sqrt(4 / (fan_in + fan_out))`` for 3x3 kernels and
    ``sqrt(2 / (fan_in + fan_out))`` (Glorot & Bengio normalized init)
    otherwise. Values are drawn with ``uniform``.

    Returns:
        Array shaped ``(k, k, in_channels, out_channels)``. It presumably
        needs a permute (e.g. ``permute(3, 2, 1, 0)``) before being copied
        into a ``Conv2d`` weight.

    NOTE(review): ``he_init`` is accepted but ignored; the kernel size alone
    selects the scale.
    """
    k = m.kernel_size[0]
    fan_in = m.in_channels * k**2
    fan_out = m.out_channels * k**2 / (m.stride[0]**2)

    numerator = 4. if k == 3 else 2.
    filters_stdev = np.sqrt(numerator / (fan_in + fan_out))

    return uniform(filters_stdev, (k, k, m.in_channels, m.out_channels))

def initialize_linear(m):
    """Glorot-uniform weights for a Linear-like module, as a float32 NumPy array.

    Uses stdev ``sqrt(2 / (in_features + out_features))``, sampled by
    ``uniform``. The result is laid out ``(in_features, out_features)``;
    transpose it before copying into ``nn.Linear.weight``, which is
    ``(out, in)``.
    """
    fan_sum = m.in_features + m.out_features
    shape = (m.in_features, m.out_features)
    return uniform(np.sqrt(2. / fan_sum), shape)

def weights_init(m):
    """Per-module initialiser intended for ``net.apply(weights_init)``.

    Dispatches on the class name:

    * ``*Conv*``: zero the bias. A weight is drawn with ``initialize_conv``
      but is not copied into the layer, so conv weights keep their PyTorch
      default init. The draw still advances NumPy's global RNG.
    * ``*BatchNorm*``: weight ~ N(1, 0.02), bias = 0.
    * ``*Linear*``: Glorot-uniform weight from ``initialize_linear``, bias = 0.

    Missing tensors are skipped instead of raising ``AttributeError``. This
    covers ``bias=False`` layers and ``BatchNorm(affine=False)`` (used by
    ``Discriminator``), whose ``weight``/``bias`` are ``None``.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # Result intentionally discarded (weight copy is disabled); the
        # call is kept so the NumPy RNG stream is unchanged.
        initialize_conv(m)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif classname.find('BatchNorm') != -1:
        if m.weight is not None:
            m.weight.data.normal_(1.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        weight_values = torch.from_numpy(initialize_linear(m))
        # initialize_linear returns (in, out); nn.Linear stores (out, in).
        m.weight.data.copy_(weight_values.transpose(0, 1))
        if m.bias is not None:
            m.bias.data.fill_(0)

import os

gen_net = Generator(d_class*how_many+rand).cuda()
dis_net = Discriminator().cuda()

# gen_net.apply(weights_init)
# dis_net.apply(weights_init)

optimizer_g = optim.Adam([
            {'params': gen_net.parameters(), 'lr': 1e-5, 'betas': (0.5, 0.999)},
        ])

optimizer_d = optim.Adam([
            {'params': dis_net.parameters(), 'lr': 1e-5, 'betas': (0.5, 0.999)},
        ])

torch.manual_seed(SEED)
# DataLoader iterators implement only __next__; the old ``.next()`` alias is
# gone in current PyTorch, so use the builtin next().
dataiter = iter(trainloader)
fixed_input, _ = next(dataiter)
x = fixed_input.numpy()[:2]

iters = 40000
iter_d = 1
iter_g = 1

# Fixed latent batch (uniform noise + one-hot codes) for the preview grid.
fixed_noise_ = torch.cat((torch.rand(bs, rand).cuda(), sample_discrete_class(bs, how_many, d_class).cuda()), 1)
vtrack = []
vdtrack = []

# np.save below does not create missing directories.
os.makedirs('./model', exist_ok=True)

for j in range(1, 5555555):

    k = 0

    # Two joint G/D update steps per outer iteration (break once k > 1).
    while 1:
        for param in gen_net.parameters():
            param.requires_grad = True
        for param in dis_net.parameters():
            param.requires_grad = True

        # Latent input: uniform noise concatenated with random one-hot codes.
        uniform_vector_3 = torch.cat((torch.rand(bs, rand).cuda(), sample_discrete_class(bs, how_many, d_class).cuda()), 1)

        # Next real batch; restart the loader when the epoch is exhausted.
        try:
            input, labels = next(dataiter)
        except StopIteration:
            dataiter = iter(trainloader)
            input, labels = next(dataiter)
        input = input.float().cuda()

        x_generated = gen_net(uniform_vector_3)

        M_G_1 = dis_net(x_generated)
        M_D_1 = dis_net(input)

        top = torch.mean(M_G_1)
        down = torch.mean(M_D_1**2)

        # Bias-corrected moving averages of the two statistics (detached).
        v_t, mean_cons = adaptive_estimation(v_t, 0.9, top, i)
        v_t_D, quadratic_cons = adaptive_estimation(v_t_D, 0.999, down, i)

        vtrack.append(v_t.item())
        vdtrack.append(v_t_D.item())
        corr_ = top/torch.sqrt(quadratic_cons) - 0.5*mean_cons*down/torch.sqrt(quadratic_cons**3)

        # D's optimizer minimises -corr_ (i.e. maximises corr_). Flipping G's
        # gradients makes its optimizer step minimise corr_ instead.
        (-corr_).backward()

        for p in gen_net.parameters():
            if p.grad is not None:
                p.grad = -p.grad

        optimizer_g.step()
        optimizer_d.step()
        optimizer_g.zero_grad()
        optimizer_d.zero_grad()

        i += 1
        iter_d += 1
        k += 1

        if k > 1:
            break

    entropy_list.append((mean_cons/torch.sqrt(quadratic_cons)).item())

    if j % 10 == 0:
        print(j, entropy_list[-1])

        plt.rcParams["figure.figsize"] = [4,4]

        # Preview only, so no autograd graph is needed. BatchNorm layers are
        # still in train mode, so their running statistics are updated.
        with torch.no_grad():
            x_generated = gen_net(fixed_noise_)
        vutils.save_image(x_generated[0:100, :3, :, :].data, './fake_fmnist.png', nrow=10, normalize=True)
        img = mpimg.imread('fake_fmnist.png')
        plt.imshow(img)
        plt.show()

        # NOTE(review): np.save pickles the whole nn.Module into an object
        # array; torch.save(net.state_dict(), ...) is the conventional format.
        np.save('./model/model_gan_SGM', gen_net)
        np.save('./model/model_gan_d_SGM', dis_net)

        plt.rcParams["figure.figsize"] = [4,4]

        plt.title('Density Ratio')
        plt.hist(M_G_1.detach().cpu().view(-1), bins=100, label='Generated data')
        plt.hist(M_D_1.detach().cpu().view(-1), bins=100, label='Real data')
        plt.legend()
        plt.show()