In [1]:
import numpy as np
import math
import torch.nn as nn
import torch
from torchvision import datasets, transforms
from PIL import Image
import scipy.stats as stats
import scipy as sp
from scipy import ndimage
import torch
import multiprocessing
import cv2
import os
import tensorflow as tf
import gc
import time
import random
np.random.seed(5340)
random.seed(5340)


In [None]:
from google.colab import drive

# This will prompt for authorization.
drive.mount('/content/drive')

In [None]:
multiprocessing.cpu_count()

In [2]:

class ImageLoader():
  def __init__(self, image_dir, file_extension, loaded_batch = 100):
      self.batch = loaded_batch
      self.image_dir = image_dir
      self.image_names = []
      self.load_image_names(file_extension)

  def load_image_names(self, file_extension):
      self.image_names.clear()
      for filename in os.listdir(self.image_dir):
        if file_extension in filename:
          self.image_names.append(filename)
      print("Num of Files : ", len(self.image_names))

  def next_loading_images(self):
      n_len = len(self.image_names)
      if n_len == 0:
        print("all images are loaded")
        return []

      loaded_images = []
      if n_len <= self.batch:
        for filename in self.image_names:
          img = cv2.imread(os.path.join(self.image_dir,filename))
          img = np.moveaxis(img, -1, 0)
          if img is not None:
            loaded_images.append(img/255.0)
        self.image_names.clear()
      else:
        indices = np.random.choice(n_len, self.batch, replace=False)
        del_list = []
        for idx in indices:
          filename = self.image_names[idx]
          del_list.append(filename)
          img = cv2.imread(os.path.join(self.image_dir,filename))
          img = np.moveaxis(img, -1, 0)
          if img is not None:
            loaded_images.append(img/255.0)
        
        for del_item in del_list:
          self.image_names.remove(del_item)
        
      return loaded_images

        
  



In [3]:
class Gaussian():
	def __init__(self, mean=0, std=0.1):
		self.mean = mean
		self.std = std

	def initialize(self, size):
		return np.random.normal(self.mean, self.std, size=size)

In [4]:
# This cell contains helper functions
# img2col -> convert conv operations to matrix product
# conv -> main convolution class
# conv2d -> wrapper around conv
def img2col(data, h_indices, w_indices, k_h, k_w):
    """
    Convert convolution operation into a matrix product / helper for convolutions
    """
    batch = data.shape[0]
    ###############################################
    _, c, w, h = data.shape
    h_len = len(h_indices)
    w_len = len(w_indices)
    
    def func(x):
        col = 0
        image2col_X = np.zeros((c*k_h*k_w, w_len*h_len))
        for h_i in h_indices:
            for w_i in w_indices:
                image2col_X[:, col] = x[0:c, h_i:h_i + k_h, w_i:w_i + k_w].flatten()
                col += 1
                
        return image2col_X
    
    out = np.stack(map(func, data), axis=0)
    ###############################################
    return out


class conv():
    def __init__(self, conv_params):
        """
        # Arguments
            conv_params: dictionary, containing these parameters:
                'kernel_h': The height of kernel.
                'kernel_w': The width of kernel.
                'stride': The number of pixels between adjacent receptive fields in the horizontal and vertical directions.
                'pad': The total number of 0s to be added along the height (or width) dimension; half of the 0s are added on the top (or left) and half at the bottom (or right). we will only test even numbers.
                'in_channel': The number of input channels.
                'out_channel': The number of output channels.
        """
        self.conv_params = conv_params

    def forward(self, input, weights):
        """
        # Arguments
            input: numpy array with shape (batch, in_channel, in_height, in_width)
            weights: numpy array with shape (out_channel, in_channel, kernel_h, kernel_w)
        # Returns
            output: numpy array with shape (batch, out_channel, out_height, out_width)
        """
        kernel_h = self.conv_params['kernel_h']  # height of kernel
        kernel_w = self.conv_params['kernel_w']  # width of kernel
        pad = self.conv_params['pad']
        stride = self.conv_params['stride']
        in_channel = self.conv_params['in_channel']
        out_channel = self.conv_params['out_channel']

        batch, in_channel, in_height, in_width = input.shape
        out_height = 1 + (in_height - kernel_h + pad) // stride
        out_width = 1 + (in_width - kernel_w + pad) // stride
        output = np.zeros((batch, out_channel, out_height, out_width))

        pad_scheme = (pad//2, pad - pad//2)
        input_pad = np.pad(input, pad_width=((0,0), (0,0), pad_scheme, pad_scheme),
                           mode='constant', constant_values=0)

        # get initial nodes of receptive fields in height and width direction
        recep_fields_h = [stride*i for i in range(out_height)]
        recep_fields_w = [stride*i for i in range(out_width)]
        input_conv = img2col(input_pad, recep_fields_h,
                             recep_fields_w, kernel_h, kernel_w)
        output = np.stack(map(
            lambda x: np.matmul(weights.reshape(out_channel, -1), x), input_conv), axis=0)
        
        output = output.reshape(batch, out_channel, out_height, out_width)
        return output

    def backward(self, out_grad, input, weights):
        """
        # Arguments
            out_grad: gradient to the forward output of conv layer, with shape (batch, out_channel, out_height, out_width)
            input: numpy array with shape (batch, in_channel, in_height, in_width)
            weights: numpy array with shape (out_channel, in_channel, kernel_h, kernel_w)
        # Returns
            in_grad: gradient to the forward input of conv layer, with same shape as input
            w_grad: gradient to weights, with same shape as weights
        """
        kernel_h = self.conv_params['kernel_h']  # height of kernel
        kernel_w = self.conv_params['kernel_w']  # width of kernel
        pad = self.conv_params['pad']
        stride = self.conv_params['stride']
        in_channel = self.conv_params['in_channel']
        out_channel = self.conv_params['out_channel']

        batch, in_channel, in_height, in_width = input.shape
        out_height = 1 + (in_height - kernel_h + pad) // stride
        out_width = 1 + (in_width - kernel_w + pad) // stride

        pad_scheme = (pad//2, pad - pad//2)
        input_pad = np.pad(input, pad_width=((0,0), (0,0), pad_scheme, pad_scheme),
                           mode='constant', constant_values=0)
                           
        # get initial nodes of receptive fields in height and width direction
        recep_fields_h = [stride*i for i in range(out_height)]
        recep_fields_w = [stride*i for i in range(out_width)]

        ########################################
        input_conv = img2col(input_pad, recep_fields_h, recep_fields_w, kernel_h, kernel_w)
        dX = np.stack(map(
            lambda x: np.matmul(weights.reshape(out_channel, -1).T, x), out_grad.reshape(batch, out_channel, -1)), axis=0)
                
        dX_pad = np.zeros(input_pad.shape)
        col = 0
        for h in recep_fields_h:
            for w in recep_fields_w:
                block = dX[:, :, col].reshape(batch, in_channel, kernel_h, kernel_w)
                dX_pad[:, :, h : h + kernel_h, w: w + kernel_w] += block
                col+=1
                
        in_grad = dX_pad[:, :, pad_scheme[0]: pad_scheme[0] + in_height, pad_scheme[0]:pad_scheme[0] + in_width] 
           
        w_grad = np.stack(map(lambda x, y: np.matmul(x.reshape(out_channel, -1), y.T).reshape(weights.shape), out_grad, input_conv), axis=0)
        w_grad = np.sum(w_grad, axis = 0)
        
        ###############################################

        return in_grad, w_grad



class Conv2D():
    def __init__(self, conv_params, initializer=Gaussian(), name='conv'):
        """Initialization
        # Arguments
            conv_params: dictionary, containing these parameters:
                'kernel_h': The height of kernel.
                'kernel_w': The width of kernel.
                'stride': The number of pixels between adjacent receptive fields in the horizontal and vertical directions.
                'pad': The total number of 0s to be added along the height (or width) dimension; half of the 0s are added on the top (or left) and half at the bottom (or right). we will only test even numbers.
                'in_channel': The number of input channels.
                'out_channel': The number of output channels.
            initializer: Initializer class, to initialize weights
        """
        self.conv_params = conv_params
        self.conv = conv(conv_params)

        self.trainable = True

        self.weights = initializer.initialize(
            (conv_params['out_channel'], conv_params['in_channel'], conv_params['kernel_h'], conv_params['kernel_w']))
        # self.bias = np.zeros((conv_params['out_channel']))

        self.w_grad = np.zeros(self.weights.shape)
        # self.b_grad = np.zeros(self.bias.shape)

    def forward(self, input):
        output = self.conv.forward(input, self.weights)
        return output

    def backward(self, out_grad, input):
        in_grad, self.w_grad = self.conv.backward(out_grad, input, self.weights)
        return self.w_grad

    def update(self, d_weights):
        """Update parameters (self.weights and self.bias) with new params
        # Arguments
            params: dictionary, one key contains 'weights' and the other contains 'bias'
        # Returns
            none
        """
        self.weights = self.weights + d_weights

    def getweights(self):
        return self.weights

    def loadweights(self, weights):
        self.weights = weights



In [5]:
class unormalized_distribution():
  def __init__(self, energy):
    self.energy = energy
  def forward(self,x):
    x = np.exp(-self.energy.forward(x))
    return x

In [6]:
# The main  Contrastive Divergence class
class energy(nn.Module):
  def __init__(self, expert_num, filter_size):

    """
    Initialized the experts and the filter parameters
        # Arguments
            expert_num: the number of experts in FOE
            filter_size: the filter size
        # Returns
            None
    """
    super().__init__()
    self.expert_num = expert_num
    params = { 
      'kernel_h': filter_size[0],
      'kernel_w': filter_size[1],
      'pad': 0,
      'stride': 1,
      'in_channel': 3,
      'out_channel': self.expert_num,
    }
    self.conv = Conv2D(params)
    self.alpha = np.ones((self.expert_num, 1), dtype=float)
    self.P = None
  
    # self.conv1 = nn.Conv2d(in_channels=1, out_channels=self.expert_num, kernel_size=3, bias=False)
    # self.conv = tf.keras.layers.Conv2D(self.expert_num, (3, 3), input_shape=(28, 28, 1))

  def forward(self, x):
    """
    Function to compute E_FoE, once parameters are learned
        # Arguments
            x: image sample
        # Returns
            returns E_FoE
    """
    out1 = self.conv.forward(x)
    b,c,h,w = out1.shape
    out2 = out1.reshape(c,b,h,w).reshape(c, b*h*w)
    # out2 = out2.reshape(c, b*h*w)
    out3 = self.alpha*np.log(1 + out2**2/2)
    out3 = out3.reshape(c,b,h,w).reshape(b,c,h,w).reshape(b,c*h*w)
    v = np.sum(out3, axis = 1)
    return v

  def gradient(self, x):
    """
    Compute derivative of E_Foe w.r.t to Theta 
    In this case, theta consist of a (alpha - expert parameters), and J (kernel - filter parameters)
        # Arguments
            x: image sample
        # Returns
            d_j: derivative of E_FoE w.r.t J
            d_alpha: derivative of E_FoE w.r.t alpha
    """

    # Generate conv_out - Jx
    conv_out = self.conv.forward(x)
    b,c,h,w = conv_out.shape
    conv_out = conv_out.reshape(c,b,h,w)
    conv_out = conv_out.reshape(c, b*h*w)

    # Get d_alpha
    d_alpha = np.log(1 + conv_out**2/2)
    d_alpha = np.sum(d_alpha,axis = 1)
    d_alpha = d_alpha.reshape(-1, 1)

    # d_conv_out represents derivative of E_FoE w.r.t f_i(j_i) or Jx
    d_conv_out = self.alpha * conv_out/(1 + conv_out**2/2)
    d_conv_out = d_conv_out.reshape(c,b,h,w)
    d_conv_out = d_conv_out.reshape(b,c,h,w)

    # Get d_j
    d_j = self.conv.backward(d_conv_out, x)

    return d_j, d_alpha

  def update(self, grad1, grad2, lr= 0.0001):
    """
    Update expert parameters and filter parameters according to grad1 and grad2
        # Arguments
            grad1: first term in contrastive divergence equation
            grad2: second term in contrastive divergence equation
            lr: learning rate
        # Returns
            none
    """

    # updating the parameters
    self.conv.update(lr*(grad1[0] - grad2[0]))
    self.alpha = self.alpha + lr*(grad1[1] - grad2[1])

  def print(self):
    weights = self.conv.getweights()
    print(weights)
    print(self.alpha)

  def save(self):
    weights = self.conv.getweights()
    state_dict = {'weights':weights, 'alpha':self.alpha}
    # state_dict = {'alpha':self.alpha}
    torch.save(state_dict, "energy_model.pth")

  def load(self):
    state = torch.load("energy_model.pth")
    self.alpha = state['alpha']
    weights = state['weights']
    self.conv.loadweights(weights)
    print("load weights :" , weights)
    print("load alpha :", self.alpha)

  def set_P(self, P):
    self.P = P

  # def sample(self, image, N=1,T=5):
  #   # phat = lambda x : stats.uniform.pdf(x)
  #   # qpdf = lambda x, mu, sig : stats.norm.pdf(x, mu, sig)
  #   # qsample = lambda mu, sig : stats.norm.rvs(mu,sig)

  #   chanel, heigh, width =  image.shape
  #   result = image
  #   n_len  = chanel*heigh*width

  #   for n in range (N):
  #       for t in range(T):
  #         sampled_image = stats.uniform.rvs(0,1,n_len).reshape(chanel, heigh, width)
  #         sampled_image_ex = np.expand_dims(sampled_image,axis = 0)
  #         score_1 = self.P.forward(sampled_image_ex)[0]

  #         result_image_ex = np.expand_dims(result,axis = 0)
  #         score_2 = self.P.forward(result_image_ex)[0]
  #         u = np.random.uniform(0,1)
  #         if u < min(1, np.nan_to_num((1./score_2 - 1.)/(1./score_1 - 1.))):
  #           result = sampled_image
  #   return result
  def sample(self, image_batch, N=1,T=5):
    # phat = lambda x : stats.uniform.pdf(x)
    # qpdf = lambda x, mu, sig : stats.norm.pdf(x, mu, sig)
    qsample = lambda mu, sig, size : stats.norm.rvs(mu,sig,size)
    batch, chanel, heigh, width =  image_batch.shape
    result = []
    for b in range(batch):
      image = image_batch[b]
      x = image
      n_len  = chanel*heigh*width
      for n in range (N):
        for t in range(T):
          sig = ndimage.standard_deviation(x)
          mu = ndimage.mean(x)
          sampled_image = qsample(mu,sig,n_len).reshape(chanel, heigh, width)
          sampled_image_ex = np.expand_dims(sampled_image,axis = 0)
          score_1 = self.P.forward(sampled_image_ex)[0]

          x_image_ex = np.expand_dims(x,axis = 0)
          score_2 = self.P.forward(x_image_ex)[0]
          u = np.random.uniform(0,1)
          if u < min(1, np.nan_to_num(score_1/score_2)):
            x = sampled_image
        result.append(x)
    return result



In [None]:
# import matplotlib.pyplot as plt
# import matplotlib.image as mpimg
# plt.imshow(rimage[0])



In [None]:
# x = tf.image.random_crop(images[0], size=[3, 15, 15,])

# import matplotlib.pyplot as plt
# import matplotlib.image as mpimg
# plt.imshow(x[0])


In [7]:

def randomCroppedImageSelection(images, sample):
  N = len(images)
  indices = np.random.choice(N, sample, replace=False)
  random_cropped_images = []
  for i in indices:
    x = tf.image.random_crop(images[i], size=[3, 15, 15])
    random_cropped_images.append(x)
  random_cropped_images = np.array(random_cropped_images)
  return random_cropped_images

In [8]:
energy_model = energy(32,(5,5))
# energy_model.load()


In [None]:
# energy_model = energy(16,(5,5))

image_loader = ImageLoader('/content/data','png', 200 )
iter = 5000
batch_size = 64


In [None]:
load_training_images = []
gc.collect()

# Train with N =1, T = 5

In [None]:
load_training_images = image_loader.next_loading_images()
N = len(load_training_images)

while N > 0:
  test_image = randomCroppedImageSelection(load_training_images, 10)
  print("-----forward N : ", N, "energy : ", energy_model.forward(test_image))

  # energy_model.sample(np.zeros((3,10,10)), P , 100, 1, True)
  count_iter = 0
  for i in range (iter):
    P = unormalized_distribution(energy_model) 
    energy_model.set_P(P)

    batch_cropped_images = randomCroppedImageSelection(load_training_images, batch_size)
    dw_x, da_x = energy_model.gradient(batch_cropped_images)
    dw_x = dw_x/batch_size
    da_x = da_x/batch_size

    # batch_cropped_images = randomCroppedImageSelection(load_training_images, batch_size)
    # pool = multiprocessing.Pool()
    # pool = multiprocessing.Pool(processes=4)
    # sampled_images = pool.map(energy_model.sample, batch_cropped_images)
    sampled_images = energy_model.sample(batch_cropped_images)
    sampled_images = np.array(sampled_images)
    # sampled_images = energy_model.sample(batch_cropped_images, P, 1, 2)
    dw_p, da_p = energy_model.gradient(sampled_images)
    dw_p = dw_p/batch_size
    da_p = da_p/batch_size

    d_x = (dw_x, da_x)
    d_y = (dw_p, da_p)
    energy_model.update(d_y, d_x, 0.001)
    print("-----forward citer : ", count_iter, "energy : ", energy_model.forward(test_image))
    count_iter += 1
    sampled_images = None
    gc.collect()

  load_training_images.clear()
  gc.collect()
  print("--------------------------------------------------------------------------")
  load_training_images = image_loader.next_loading_images()
  N = len(load_training_images)

energy_model.save()


In [None]:
print("128 th")
energy_model.print()


In [None]:
!cp 'energy_model.pth' '/content/drive/My Drive/AI' 

14925882.938054346


# Train with N = 2, T = 2

In [9]:
image_loader = ImageLoader('/content/data','png', 200 )
iter = 1000
batch_size = 64

Num of Files :  800


In [10]:
load_training_images = image_loader.next_loading_images()
N = len(load_training_images)
n_th = 0
while N > 0:
  test_image = randomCroppedImageSelection(load_training_images, 10)
  print("-----forward N : ", N, "energy : ", energy_model.forward(test_image))

  # energy_model.sample(np.zeros((3,10,10)), P , 100, 1, True)
  count_iter = 0
  for i in range (iter):
    P = unormalized_distribution(energy_model) 
    energy_model.set_P(P)

    batch_cropped_images = randomCroppedImageSelection(load_training_images, batch_size)
    dw_x, da_x = energy_model.gradient(batch_cropped_images)
    dw_x = dw_x/batch_size
    da_x = da_x/batch_size

    # batch_cropped_images = randomCroppedImageSelection(load_training_images, batch_size)
    # pool = multiprocessing.Pool()
    # pool = multiprocessing.Pool(processes=4)
    # sampled_images = pool.map(energy_model.sample, batch_cropped_images)
    sampled_images = energy_model.sample(batch_cropped_images,2,2)
    sampled_images = np.array(sampled_images)
    # sampled_images = energy_model.sample(batch_cropped_images, P, 1, 2)
    dw_p, da_p = energy_model.gradient(sampled_images)
    dw_p = dw_p/len(sampled_images)
    da_p = da_p/len(sampled_images)

    d_x = (dw_x, da_x)
    d_y = (dw_p, da_p)
    energy_model.update(d_y, d_x, 0.001)
    print("-----forward N : ", n_th, " sample len:",len(sampled_images), " citer : ", count_iter, "energy : ", energy_model.forward(test_image))
    count_iter += 1
    sampled_images = None
    gc.collect()

  load_training_images.clear()
  gc.collect()
  print("--------------------------------------------------------------------------")
  load_training_images = image_loader.next_loading_images()
  N = len(load_training_images)
  n_th += 1
energy_model.save()




-----forward N :  200 energy :  [293.19166696 361.3758298   24.29000769  38.73433362 415.23914988
 312.01707887 353.21678496 105.22358762   5.72424113 342.78951308]




[1;30;43mStreaming output truncated to the last 5000 lines.[0m
-----forward N :  2  sample len: 128  citer :  3 energy :  [11.65202689 16.16765767 42.00839882  3.75607406  6.36928026  2.99144319
  6.05070281  0.34491474  1.73646127 30.98241248]
-----forward N :  2  sample len: 128  citer :  4 energy :  [11.7277213  16.2314449  41.99258033  3.78212777  6.46687904  2.99481377
  6.18538301  0.34482274  1.73585672 31.09257158]
-----forward N :  2  sample len: 128  citer :  5 energy :  [11.69041411 16.21124688 41.97850978  3.77321949  6.43431601  2.99285264
  6.12550255  0.34469     1.73925798 31.01875626]
-----forward N :  2  sample len: 128  citer :  6 energy :  [11.59888    16.14770488 41.85758729  3.59666381  6.14349146  2.98519059
  5.93047776  0.34441764  1.74597077 31.14076898]
-----forward N :  2  sample len: 128  citer :  7 energy :  [11.53258701 16.07478465 41.77953704  3.47596447  5.95607481  2.9757675
  5.77754191  0.344666    1.74899129 31.18450736]
-----forward N :  2  sampl