In [0]:
import numpy as np
from collections import deque
import tensorflow as tf
from tensorflow.python.keras import models, layers, Sequential, optimizers, metrics
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from keras.datasets import mnist
from keras import backend as K
import skimage.util as sk
from skimage.util.shape import view_as_blocks
import csv

## Data creation

In [0]:
# Load the image array from disk and view it as 60000 frames of 27x27 pixels.
data = np.load('mnist_perso_train.npy').reshape((60000, 27, 27))

global_count = 0
# Size of each half of `data`: World samples indices in [0, data_size),
# getRealImg samples indices in [data_size, 2*data_size).
data_size = 30000

In [0]:
def getRealImg(number, step):
  """Draw `number` random images from the upper half of `data`, flattened.

  Each index is sampled independently (with replacement) from
  [data_size, 2*data_size). `step` is accepted but not used.

  Returns:
    An array with one flattened image per row.
  """
  def pick_one():
    index = np.random.randint(data_size)+data_size
    return np.array(data[index]).reshape(-1)

  return np.array([pick_one() for _ in range(number)])

def getTrainingD(generatedImg, batchsize, step):
  """Assemble one discriminator training batch.

  The generated images come first and are labelled 0; `batchsize` real
  images drawn with getRealImg follow and are labelled 1.

  Returns:
    (trainDataD, trainLabelsD): the concatenated images and their labels.
  """
  realImg = getRealImg(batchsize, step)
  trainDataD = np.concatenate((generatedImg, realImg))
  fakeLabels = np.zeros(batchsize)
  realLabels = np.ones(batchsize)
  trainLabelsD = np.concatenate((fakeLabels, realLabels))
  return trainDataD, trainLabelsD

# Model

In [0]:
p=9 #number of positions
f=9 #number of fragments
d=81 #dimension of one fragment encoding
D=f*d #dimension of the encoding of F

## Generative custom layer

In [0]:
class MagicLayer(layers.Layer):
    """Scores every (fragment, position) pair from fragment and image encodings.

    Inputs (passed as a list ``[A, B]``):
      A: fragment encodings, indexed (batch, fragment, d).
      B: image encoding, indexed (batch, D).

    The layer owns a kernel of shape (p, d, D) and a bias of shape (D,).
    It computes ``x[b,f,p,D] = sum_d kernel[p,d,D] * A[b,f,d]``, then
    ``out[b,f,p] = sum_D (B + bias)[b,D] * x[b,f,p,D]``, flattens the result
    to (batch, f*p) and applies a leaky ReLU with slope 0.1.

    Args:
      f: number of fragments.
      d: size of one fragment encoding.
      D: size of the image encoding.
      p: number of positions (defaults to 9, the value previously hard-coded).
    """

    def __init__(self, f, d, D, p=9):
        super(MagicLayer, self).__init__()
        self.p = p
        self.f = f
        self.d = d
        self.D = D

    def build(self, input_shape):
        # Use the layer's own p rather than the notebook-level global so the
        # layer does not depend on hidden notebook state.
        self.kernel = self.add_weight(name="kernel", shape=[self.p, self.d, self.D], trainable=True)
        self.bias = self.add_weight(name="bias", shape=[self.D], trainable=True)

    def call(self, input):
        assert isinstance(input, (list, tuple))
        A, B = input
        x = tf.einsum('pdD,bfd->bfpD', self.kernel, A)
        B_biased = tf.add(B, self.bias)
        x = tf.einsum('bD,bfpD->bfp', B_biased, x)
        # Flatten (batch, f, p) -> (batch, f*p). tf.reshape is available in
        # both TF 1.x and 2.x, unlike tf.layers.flatten.
        x = tf.reshape(x, [-1, self.f * self.p])
        return tf.nn.leaky_relu(x, 0.1)

    def compute_output_shape(self, input_shape):
        assert isinstance(input_shape, (list, tuple))
        # input_shape[0] is the shape of A; its first entry is the batch size.
        return (input_shape[0][0], self.f * self.p)

## Generative network

In [0]:
def get_fi_single_extractor_CNN():
  """Build the convolutional encoder applied to one fragment.

  The 81-value input is reshaped to a 9x9x1 image, passed through a 3x3
  convolution (32 filters, no bias, 'same' padding), 2x2 max pooling and a
  flatten, then projected to 81 ReLU features.

  Returns:
    An uncompiled Keras Model mapping (81,) -> (81,).
  """
  fragment_in = layers.Input(shape=(81,))
  stages = [
      layers.Reshape((9, 9, 1)),
      layers.Conv2D(32, (3, 3), use_bias=False, padding='same'),
      layers.MaxPooling2D((2, 2)),
      layers.Flatten(),
      layers.Dense(81, activation='relu'),
  ]
  features = fragment_in
  for stage in stages:
    features = stage(features)
  return models.Model(inputs=fragment_in, outputs=features)

def get_fi_single_extractor_FC():
  """Build the fully connected encoder applied to one fragment.

  Returns:
    An uncompiled Keras Model: a single Dense(81, relu) layer on an
    81-value input.
  """
  fragment_in = layers.Input(shape=(81,))
  dense = layers.Dense(81, activation='relu')
  return models.Model(inputs=fragment_in, outputs=dense(fragment_in))


def ccx(labels, output):
    """Element-wise squared error weighted by the sign of the label.

    Entries whose label is 0 contribute nothing, because tf.sign(0) is 0.
    """
    weight = tf.sign(labels)
    squared_error = tf.square(output-labels)
    return weight*squared_error

def get_Gnetwork():
    """Build and compile the generator network.

    Inputs:
      fi: fragments, shape (f, d) per sample.
      F:  current image, shape (D,) per sample.

    Every fragment is encoded by the shared CNN extractor, the image is
    encoded by a small CNN ending in a 9*81 ReLU layer, and MagicLayer
    combines both encodings into the network output.

    Returns:
      A Keras Model compiled with RMSprop and the `ccx` loss.
    """
    # System inputs
    fragments_in = layers.Input(shape=(f, d))
    image_in = layers.Input(shape=(D,))

    # Created before the extractors, as in the original definition order.
    scorer = MagicLayer(f, d, D)

    # Fragment branch: the same extractor is applied to each fragment.
    fragment_encoder = get_fi_single_extractor_CNN()
    fragment_codes = layers.TimeDistributed(fragment_encoder)(fragments_in)

    # Image branch
    image_stages = (
        layers.Reshape((27, 27, 1)),
        layers.Conv2D(32, (5, 5), activation='relu', padding='same'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(9*81, activation='relu'),
    )
    image_code = image_in
    for stage in image_stages:
        image_code = stage(image_code)

    scores = scorer([fragment_codes, image_code])

    network = models.Model(inputs=[fragments_in, image_in], outputs=scores)
    rmsprop = optimizers.RMSprop(lr=0.00001, rho=0.9, epsilon=None, decay=0.0)
    network.compile(optimizer=rmsprop, loss=ccx, metrics=['accuracy'])
    return network

## Discriminator

In [0]:
def getDnetwork():
  """Build and compile the discriminator.

  The input is reshaped to 27x27x1 and goes through two
  Conv2D / BatchNormalization / MaxPooling2D stages (32 then 64 filters),
  a 256-unit ReLU layer and a single sigmoid output.

  Returns:
    A Sequential model compiled with mean squared error and RMSprop.
  """
  stack = [
      layers.Reshape((27, 27, 1)),
      layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
      layers.BatchNormalization(),
      layers.MaxPooling2D((2, 2)),
      layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
      layers.BatchNormalization(),
      layers.MaxPooling2D((2, 2)),
      layers.Flatten(),
      layers.Dense(256, activation='relu'),
      layers.Dense(1, activation='sigmoid'),
  ]
  model = Sequential()
  for layer in stack:
    model.add(layer)
  rmsprop = optimizers.RMSprop(lr=0.000005)
  model.compile(loss='mean_squared_error', optimizer=rmsprop, metrics=['accuracy'])
  return model

## World

In [0]:
class World():
    """Puzzle environment: nine 9x9 fragments of a 27x27 image to place on a 3x3 grid.

    Block numbering is row-major, as produced by
    ``view_as_blocks(img, (9, 9)).reshape(9, 81)``: block k covers rows
    ``9*(k//3)`` .. ``9*(k//3)+8`` and columns ``9*(k%3)`` .. ``9*(k%3)+8``.

    Attributes set by initialize():
      t        number of moves played.
      F        the 27x27 canvas being assembled (filled with -0.1 at start).
      F_mask   length-9 booleans, True where a canvas position is filled.
      fi       (9, 81) fragments still to place; a placed fragment is set to -1.
      fi_mask  length-9 booleans, True where a fragment is still available.
      pi       copy of fi updated on every move (see NOTE in move()).
      history  list of the (fragment, position) actions played.

    Args:
      inputImg: index into `data` of the image to use, or a negative value to
        draw a random image from the first `data_size` entries.
      permutation: when True the fragments are shuffled with `fi_table`.
    """

    def __init__(self,  inputImg, permutation=True):
        self.permutation = permutation
        self.inputImg = inputImg
        # fi_table[i] is the original block index of fragment i. Without
        # permutation the fragments keep their order, so the table must be
        # the identity for terminal_value()/first_frag() to be meaningful.
        if permutation:
            self.fi_table = np.random.permutation(9)
        else:
            self.fi_table = np.arange(9)
        self.initialize()

    def initialize(self):
        """Reset the episode: empty canvas, fresh fragments, empty history."""
        self.t = 0
        self.history = []
        self.F = np.full((27, 27), -0.1)
        self.F_mask = np.full(9, False)

        #random selection
        if self.inputImg < 0:
            img = data[np.random.randint(data_size)]
        else:
            img = data[self.inputImg]

        self.fi = np.array(view_as_blocks(img, block_shape=(9, 9))).reshape(9, 81)

        if self.permutation:
            self.fi = self.fi[self.fi_table]
        # These must exist whether or not the fragments were permuted;
        # legal_move(), move() and __repr__() all read them.
        self.fi_mask = np.full(9, True)
        self.pi = self.fi.copy()

    def __repr__(self):
        return "<t:%d \nfi: \n%s\n%s \nF:\n%s\n%s \npi:\n%s>" % (self.t, self.fi, self.fi_mask, self.F, self.F_mask, self.pi)

    def legal_move(self, action):
        """Return True when the fragment is still available and the position is free."""
        index_fi, index_F = action
        return bool(self.fi_mask[index_fi] and not self.F_mask[index_F])

    def move(self, action):
        """Place fragment `action[0]` at position `action[1]`.

        Returns:
          True when the move was applied, False when it was illegal (the
          world is left untouched in that case).
        """
        if not self.legal_move(action):
            return False

        index_fi, index_F = action
        self.t += 1

        #add the new fragment on the image, using the same row-major block
        #numbering as the slicing done in initialize()
        row = 9 * (index_F // 3)
        col = 9 * (index_F % 3)
        self.F[row:row + 9, col:col + 9] = self.fi[index_fi].reshape(9, 9)

        self.F_mask[index_F] = True
        self.fi[index_fi] = -1
        self.fi_mask[index_fi] = False
        self.pi[index_fi] = -1
        # NOTE(review): pi has 81 columns (it is a copy of fi), so this zeroes
        # pixel column `index_F`, not a per-position entry - confirm the intent.
        self.pi[:,index_F] = 0
        self.history.append(action)
        return True

    def get_random_move(self):
        """Pick a uniformly random available fragment and a random free position."""
        index_fi = np.random.choice(np.flatnonzero(self.fi_mask))
        index_F = np.random.choice(np.flatnonzero(~self.F_mask))
        return (index_fi, index_F)

    def best_legal_move(self, P, rand_choice):
        """Return the legal (fragment, position) pair with the highest score in P.

        When the world uses a random image (inputImg < 0), a random legal
        move is returned instead with probability `rand_choice`.
        P is a (9, 9) score matrix indexed [fragment, position]; it is not
        modified.
        """
        #add noise for a better exploration on the training
        if self.inputImg < 0 and np.random.random() < rand_choice:
            return self.get_random_move()

        # Work on a copy so the caller's array is left intact, and mask with
        # -inf so that no real score can rank below an illegal move.
        scores = np.array(P, dtype=float)
        scores[~self.fi_mask] = -np.inf
        scores[:, self.F_mask] = -np.inf
        return np.unravel_index(scores.argmax(), scores.shape)

    def end_episode(self):
        """Return True once every position of the canvas is filled."""
        return bool(np.all(self.F_mask))

    def get_Fo(self):
        """Return a flattened copy of the canvas."""
        return self.F.reshape(-1).copy()

    def get_state(self):
        """Return copies of (fi, flattened F, pi)."""
        return self.fi.copy(), self.F.reshape(-1).copy(), self.pi.copy()

    def terminal_value(self):
        """Fraction of played moves that put a fragment on its original position.

        Returns a 0-d array; 0.0 when no move has been played yet.
        """
        if not self.history:
            return np.array(0.0)
        hits = [self.fi_table[action_fi]==action_F for action_fi, action_F in self.history]
        return np.array(np.mean(hits))

    def first_frag(self):
        """Return 1 when the first move was correct, else 0 (0 when no move was played)."""
        if not self.history:
            return 0
        action_fi, action_F = self.history[0]
        return 1 if self.fi_table[action_fi]==action_F else 0

## Memory

In [0]:
class memoryV2():
  """Bounded replay memory of episode steps.

  The capacity is `max_size*9` entries. Each entry is a dict with the keys
  'fi', 'Fi', 'Fo' and 'mv'.
  """

  def __init__(self, max_size):
    capacity = max_size*9
    self.size = capacity
    self.images = deque(maxlen = capacity)

  def addImg(self, fi, Fi, Fo, Mv):
    """Store one step, evicting a random entry when one slot remains.

    Mv is the (fragment, position) pair that was played; it is stored as
    its position in the 81-value probability vector, Mv[0]*9+Mv[1].
    """
    one_slot_left = (len(self.images) + 1 == self.size)
    if one_slot_left:
      victim = np.random.randint(self.size-1)
      del self.images[victim]
    flat_move = Mv[0]*9+Mv[1]
    entry = {'fi':fi, 'Fi':Fi, 'Fo':Fo, 'mv':flat_move}
    self.images.append(entry)

  def getMv(self, step):
    """Return the stored move index of entry `step`."""
    entry = self.images[step]
    return entry['mv']

  def getSample(self, number):
    """Return `number` entries chosen with np.random.choice."""
    return np.random.choice(self.images, size=number)

# Execution: generation and training



*   Generate episodes
*   Pick samples in this memory of episodes
*   Rate the images picked using the discriminator
*   Train both the discriminator and the generator



In [0]:
def generateEpisode(network, memory, nb_step, rand_choice, testOn, inputImg):
    """Play one episode with `network` and store every step in `memory`.

    At each step the current state is fed to `network.predict`, the output is
    read as a 9x9 score matrix and the move returned by
    `World.best_legal_move` is played.

    Args:
      network: model called with [fragments, image], each with a batch axis of 1.
      memory: object with an `addImg(fi, Fi, Fo, Mv)` method.
      nb_step: number of moves to play.
      rand_choice: exploration probability forwarded to `best_legal_move`.
      testOn: > 0 shows the final image, > 1 also prints every prediction.
      inputImg: image index forwarded to `World` (negative = random image).

    Returns:
      The World instance after the last move.
    """
    world  = World(inputImg)
    fi = []
    Fi = []
    action = []
    for step in range(nb_step):
        tmpfi, tmpFi, _ = world.get_state()
        fi.append(tmpfi)
        Fi.append(tmpFi)
        P = network.predict([np.expand_dims(tmpfi, axis=0), np.expand_dims(tmpFi, axis=0)])
        if testOn > 1:
          print(P)
        action.append(world.best_legal_move(P.reshape(9,9), rand_choice))
        world.move(action[step])

    # Only the final canvas is stored with every step, so read it once after
    # the loop; this also keeps Fo defined when nb_step is 0.
    Fo = world.get_Fo()
    for i in range(len(action)):
        memory.addImg(fi[i], Fi[i], Fo, action[i])
    if testOn > 0:
        print("image")
        plt.imshow(Fo.reshape(27,27).copy())
        plt.show()
    return world

In [0]:
def getTrainingG(sample, evaluation):
    """Build one generator training batch from sampled memory entries.

    Args:
      sample: sequence of dicts with the keys 'fi', 'Fi' and 'mv'.
      evaluation: one score per entry of `sample`; each score may be a
        scalar or a size-1 array (e.g. a row of a (batch, 1) prediction).

    Returns:
      (trainDatafi, trainDataFi, trainLabels). trainLabels has one row of 81
      values per entry: zero everywhere except at the entry's 'mv' index,
      which holds the entry's score.
    """
    trainLabels = np.zeros((len(sample), 81))
    for row, entry in enumerate(sample):
        # .item() turns a scalar or size-1 array into a Python number without
        # relying on NumPy's deprecated implicit array-to-scalar conversion.
        trainLabels[row, entry['mv']] = np.asarray(evaluation[row]).item()
    trainDatafi = np.array([entry['fi'] for entry in sample])
    trainDataFi = np.array([entry['Fi'] for entry in sample])
    return trainDatafi, trainDataFi, trainLabels

## Init of the system

Put some episodes in the memory to start

In [0]:
#Tunable settings for the run
rand_choice = 0.5 #probability of playing a random move instead of the best one
batchsize = 32
maxEpisodes = batchsize*4 #memoryV2 stores up to maxEpisodes*9 steps
nb_step = 9 #number of steps in the generated episodes
memory1 = memoryV2(maxEpisodes)
generator = get_Gnetwork()
discriminator = getDnetwork()
#Fill the memory with batchsize*2 episodes on random images (inputImg = -1) before training
for i in range(batchsize*2):
  generateEpisode(generator, memory1, nb_step, rand_choice, 0, -1)

## Training loop

In [0]:
#Running totals for the current reporting window; all are reset after each report
D_out_loss = 0
D_out_acc = 0
G_out_loss = 0
G_out_acc = 0
solving = 0
solvingList = []
D_predict_error = 0
FoList = []
localsolving = 0
first_frag = 0
window_iters = 0 #training iterations accumulated since the last report

for i in range(1000000):
  solvingList.clear()
  FoList.clear()
  
  #lower the random exploration rate, without letting it drift below 0
  if i%100 == 0:
    rand_choice = max(0.0, rand_choice - 0.02)
  
  #add new episodes
  for j in range(batchsize):
    world = generateEpisode(generator, memory1, nb_step, rand_choice, 0, -1)
    #getting some metrics
    solvingList.append(world.terminal_value())
    FoList.append(world.get_Fo())
    #NOTE(review): the World class in this notebook defines no local_value(),
    #so calling it unconditionally raised AttributeError on the first episode.
    #The metric is recorded as NaN until that method is provided.
    local_value = getattr(world, 'local_value', None)
    localsolving = localsolving + (local_value() if callable(local_value) else float('nan'))
    first_frag = first_frag + world.first_frag()
  
  #mean absolute error between the discriminator score and the ground truth
  #(already averaged over the batch here, so the report divides by iterations only)
  D_guess = discriminator.predict(np.array(FoList)).reshape(-1)
  D_predict_error = D_predict_error + np.mean(np.abs(np.subtract(solvingList, D_guess)))
  solving = solving + sum(solvingList)
        
  #get a sample
  sample = memory1.getSample(batchsize)
  generatedImg = np.array([d['Fo'] for d in sample])
  
  #generate trainingG (uses the discriminator scores from before its update below)
  trainDataGfi, trainDataGFi, trainLabelsG = getTrainingG(sample, discriminator.predict(generatedImg))
    
    
  #train discriminator
  trainDataD, trainLabelsD = getTrainingD(generatedImg, batchsize, nb_step)
  D_loss, D_acc = discriminator.train_on_batch(trainDataD, trainLabelsD)
  D_out_loss = D_out_loss + D_loss
  D_out_acc = D_out_acc + D_acc    
    
  #train generator
  G_loss, G_acc = generator.train_on_batch([trainDataGfi, trainDataGFi], trainLabelsG)
  G_out_loss = G_out_loss + G_loss
  G_out_acc = G_out_acc + G_acc
  
  window_iters = window_iters + 1
  
  if i%50 == 0:
    #normalise by what was actually accumulated: the first window (i == 0)
    #holds a single iteration, later windows hold 50
    nb_episodes = window_iters*batchsize
    print("current generated episode" )
    print(i)
    plt.imshow(world.get_Fo().reshape(27, 27).copy())
    plt.show()
    print("global solving %:" + str(solving/nb_episodes) + "     local solving: " + str(localsolving/nb_episodes) + "     first frag: " + str(first_frag/nb_episodes))
    print("discriminator:")
    print("loss: " + str(D_out_loss/window_iters) + "      accuracy: " + str(D_out_acc/window_iters) + "     accuracy2: " + str(D_predict_error/window_iters))
    print("generator:")
    print("loss: " + str(G_out_loss/window_iters) + "      accuracy: " + str(G_out_acc/window_iters))
    
                                       
    #save in csv (newline='' lets the csv module control line endings)
    with open('/content/drive/My Drive/M2/Research project/output/MNIST_metric3V2.csv', mode='a', newline='') as csv_file:
      csv_writer = csv.writer(csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
      csv_writer.writerow([i , solving/nb_episodes, localsolving/nb_episodes, first_frag/nb_episodes, D_out_acc/window_iters, D_predict_error/window_iters, D_out_loss/window_iters, G_out_acc/window_iters, G_out_loss/window_iters])
      
      
    #re init 
    first_frag = 0
    localsolving = 0
    solving = 0
    D_predict_error = 0
    D_out_loss = 0
    D_out_acc = 0
    G_out_loss = 0
    G_out_acc = 0
    window_iters = 0

# Tests

In [0]:
#Sanity check on 10 random real images: real images are trained with label 1,
#so the printed discriminator scores should be close to 1
print(discriminator.predict(getRealImg(10, nb_step)).reshape(-1))

In [0]:
#Sanity check on 10 final images sampled from the memory: generated images are
#trained with label 0, so the printed discriminator scores should be close to 0
sample = memory1.getSample(10)
generatedImg = np.array([d['Fo'] for d in sample])
print(discriminator.predict(generatedImg))

## Validation

In [0]:
#Greedy reconstruction (rand_choice = 0) of images 100..199. An image counts as
#failed when the summed absolute pixel difference to the original exceeds 10.
#Failed images that the discriminator still scores above 0.45 are saved to disk.
#Note: generateEpisode also appends these episodes to memory1.
score = 100
for i in range(100):
    episode = generateEpisode(generator, memory1, nb_step, 0, 0, (i%1000)+100)
    test = episode.get_Fo().reshape(27, 27).copy()
    reference = data[(i+100)%data_size]
    result = np.abs(test - reference).sum()
    if result > 10:
      score -= 1
      fooled = discriminator.predict(np.array([test])) > 0.45
      if fooled:
        mpimg.imsave(str(i+1000) + "error.png", test)
print("score total de:")
print(score/100)

In [0]:
#Greedy reconstruction (rand_choice = 0) of images 10000..10999.
#`error` counts the images whose reconstruction differs from the original by
#more than 10 (summed absolute pixel difference), so the printed value is the
#error rate. The previous version started at 1000 and subtracted failures,
#which printed the success rate under the "erreur" label.
#Note: generateEpisode also appends these episodes to memory1.
error = 0
for i in range(1000):
    test = generateEpisode(generator, memory1, nb_step, 0, 0, (i)+10000).get_Fo().reshape(27, 27).copy()
    result = np.sum(np.abs(test - data[(i)+10000]))
    if result > 10:
       error = error + 1
print("erreur total de:")
print(error/1000)