In [None]:
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
from models import *

Next, we load the pretrained weights for each network: the GAN generator, the GAN discriminator, and the MNIST classifier.

In [None]:
import os
from pathlib import Path

# Directory holding the three pretrained checkpoints.
# NOTE(review): torch.load unpickles these files, which can execute arbitrary
# code -- only load checkpoints from a trusted source.
WEIGHTS_DIR = Path('best_gan_weights')

gen_state_dict = torch.load(WEIGHTS_DIR / 'mnist_gen_epoch_91.pth', map_location=device)
d_state_dict = torch.load(WEIGHTS_DIR / 'mnist_dis_epoch_91.pth', map_location=device)
c_state_dict = torch.load(WEIGHTS_DIR / 'classifier_mnist.pth', map_location=device)

# Network hyper-parameters. Only `ngpu` is passed to a constructor below.
# NOTE(review): nz/ngf/ndf/nc are not referenced in the visible cells; they
# presumably mirror the values the classes in `models` were built with -- confirm.
ngpu = 1
nz = 100
ngf = 64
ndf = 64
nc = 1

generator = Generator(ngpu).to(device)
generator.load_state_dict(gen_state_dict)

discriminator = Discriminator(ngpu).to(device)
discriminator.load_state_dict(d_state_dict)

classifier = Network().to(device)
classifier.load_state_dict(c_state_dict)

# All three networks are only used for inference in this notebook. Switch them to
# evaluation mode so any BatchNorm/Dropout layers (defined in models.py, not
# visible here) use their stored statistics. In training mode, a candidate's score
# would depend on the other latents in the same batch, and BatchNorm running
# statistics would be overwritten even inside torch.no_grad().
for net in (generator, discriminator, classifier):
    net.eval()

In [None]:
from constellation_noise import *

In [None]:
# Single example stimulus: build a dot constellation from one MNIST digit image,
# with the same (15, 1, 0.003) arguments the batch inference later uses.
location = "mnist/6.jpg"
noise_dotted = constellation_create(
    location,
    15,
    1,
    0.003,
)


In [None]:
import matplotlib.image as mpimg
import scipy.ndimage as ndi

# Raw pixel array of the example stimulus image (grayscale vs. RGB depends on the file).
mass = mpimg.imread(fname=location)

def get_position(np_array):
    """Return the intensity-weighted centroid of ``np_array``.

    Thin wrapper around ``scipy.ndimage.center_of_mass``. The result has one
    coordinate per array axis, in axis order (row first for a 2-D image).
    """
    centroid = ndi.center_of_mass(np_array)
    return centroid

# Display the centroid of the example stimulus (one coordinate per array axis).
get_position(np_array=mass)

In [None]:
def pass_dots_loss(image_noise_dots, image):
    """Score how well a drawn figure passes through a dot constellation.

    Args:
        image_noise_dots: constellation stimulus. ``stimuli_dots`` extracts the
            dot positions from it; its format is defined in ``constellation_noise``,
            which is not visible here.
        image: 8-bit grayscale image of the candidate figure, handed to
            ``drawing_figure``.

    Returns:
        ``len(points_on_image(figure, dots, 10)) / len(dots)``. This is presumably
        the fraction of dots lying on the figure (helper semantics live in
        ``constellation_noise`` -- confirm). Returns ``0.0`` when the stimulus
        yields no dots, instead of raising ZeroDivisionError.
    """
    noise_dots = stimuli_dots(image_noise_dots)
    if len(noise_dots) == 0:
        return 0.0
    # NOTE(review): `drawing_figure` is looked up at call time. A later cell in
    # this notebook defines a `drawing_figure` that returns a (draw_img, contours)
    # tuple. If the inference cell is re-run after that cell, this call receives
    # the tuple version -- confirm which helper is intended here.
    gray = drawing_figure(image)
    dots_on_figure = points_on_image(gray, noise_dots, 10)
    return len(dots_on_figure) / len(noise_dots)
from tqdm import tqdm

from math import atan2, cos, sin, sqrt, pi
import numpy as np

def get_angle(numpy_array):
    """Estimate the orientation, in whole degrees, of the largest blob in an image.

    Steps:
    1. Binarise the image with Otsu's method and extract its contours.
    2. Fit the minimum-area rotated rectangle to the largest-area contour.
    3. Truncate the rectangle's angle to an int.
    4. Return ``90 - angle`` when the rectangle's (int) width is smaller than
       its height, else ``-angle``.

    Args:
        numpy_array: grayscale image. OpenCV's Otsu mode expects a single-channel
            8-bit array.

    Returns:
        The angle as an int, or 0 when no contour has positive area.
    """
    # With THRESH_OTSU the threshold is chosen automatically; the 50 is ignored.
    _, bw = cv2.threshold(numpy_array, 50, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    # findContours returns (contours, hierarchy) in OpenCV 4 but
    # (image, contours, hierarchy) in OpenCV 3; [-2] is the contour list in both.
    contours = cv2.findContours(bw, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)[-2]
    if len(contours) == 0:
        return 0
    # max() keeps the first of several equally large contours, like a strict '>' scan.
    largest = max(contours, key=cv2.contourArea)
    if cv2.contourArea(largest) <= 0:
        return 0
    # NOTE(review): minAreaRect's angle convention reportedly changed in OpenCV 4.5.1;
    # verify the width/height mapping below against the installed version.
    _, (rect_w, rect_h), rect_angle = cv2.minAreaRect(largest)
    angle = int(rect_angle)
    if int(rect_w) < int(rect_h):
        return 90 - angle
    return -angle

In [None]:
import torch.nn.functional as F

def get_prediction(imgs, classifier):
    """Return the classifier's predicted class index for every image in a batch.

    Args:
        imgs: batch of images accepted by ``classifier``.
        classifier: callable returning a ``(N, num_classes)`` score tensor.

    Returns:
        NumPy integer array of shape ``(N, 1)``: a column vector, so it can be
        concatenated with other per-image descriptors. Works for any batch size;
        previously the shape was hard-coded to ``(150, 1)``.
    """
    scores = classifier(imgs)
    return scores.argmax(dim=1).cpu().numpy().reshape(-1, 1)

In [None]:
def denorm(img_tensors):
    """Map values from [0, 1] to [-1, 1] by computing ``2 * x - 1``.

    NOTE(review): despite its name, this moves data *into* the [-1, 1] range;
    ``norm`` performs the inverse mapping.
    """
    return 2 * img_tensors - 1
def norm(img_tensors):
    """Map values from [-1, 1] to [0, 1] by computing ``(x + 1) / 2``.

    NOTE(review): despite its name, this moves data *out of* the [-1, 1] range;
    ``denorm`` performs the inverse mapping.
    """
    shifted = 1 + img_tensors
    return shifted / 2


In [None]:
import glob
# Stimulus images to process. glob's order depends on the filesystem, so sort it
# to keep the processing order (and the final grid) reproducible.
imagess = sorted(glob.glob('mnist_together/*'))


In [None]:
def _score_latents(sols, noised):
    # Decode one ask() batch of latent vectors; return (objectives, behaviour descriptors).
    with torch.no_grad():
        latents = torch.tensor(sols, dtype=torch.float32, device=device)
        # (N, latent_dim) -> (N, latent_dim, 1, 1), the layout fed to the generator.
        latents = latents.unsqueeze(-1).unsqueeze(-1)
        generated_imgs = generator(latents)
        classes = get_prediction(generated_imgs, classifier)
        # Flatten to one realness score per image, so the objective below is 1-D
        # whether the discriminator returns (N,), (N, 1) or (N, 1, 1, 1).
        realness = discriminator(generated_imgs).cpu().numpy().reshape(-1)
        # Generator output is treated as [-1, 1]; map it to [0, 1].
        normalized_imgs = ((generated_imgs + 1.0) / 2.0).cpu().numpy()

    # 8-bit rendering of channel 0 for the OpenCV-based dot scorer. Scale the
    # [0, 1] image rather than the raw [-1, 1] output: casting negative floats to
    # uint8 is undefined in NumPy (in practice it wraps, turning dark pixels bright).
    gray_imgs = np.clip(normalized_imgs[:, 0] * 255, 0, 255).astype(np.uint8)
    dot_scores = np.array(
        [pass_dots_loss(noised, gray) for gray in gray_imgs], dtype=np.float32
    )

    # Boldness = number of pixels at least half-bright, as an (N, 1) column.
    flat_imgs = normalized_imgs.reshape(normalized_imgs.shape[0], -1)
    boldness = np.count_nonzero(flat_imgs >= 0.5, axis=1, keepdims=True)

    objs = 0.9 * realness + 0.1 * dot_scores
    bcs = np.concatenate([boldness, dot_scores.reshape(-1, 1), classes], axis=1)
    return objs, bcs


def _best_elite(archive):
    # Decode the archive's highest-objective elite; return (generated image, objective).
    elites = archive.as_pandas().sort_values(by=['objective'], ascending=False)
    if len(elites) == 0:
        raise ValueError("archive is empty: no elite to decode")
    # Select the latent columns by name rather than by position, so the lookup
    # does not depend on how many index/behaviour columns precede them.
    solution_cols = [col for col in elites.columns if str(col).startswith('solution_')]
    if not solution_cols:
        raise KeyError("archive dataframe has no 'solution_*' columns")
    best_latent = elites[solution_cols].iloc[0].to_numpy(dtype=np.float32)
    best_objective = elites['objective'].iloc[0]
    with torch.no_grad():
        latent = torch.tensor(best_latent, dtype=torch.float32, device=device)
        output = generator(latent.view(1, -1, 1, 1))
    return output[0], best_objective


def inference_image_generator(immage, distance, times):
    """Search the generator's latent space for a digit that fits a dot constellation.

    Runs a quality-diversity loop: a SlidingBoundariesArchive, five
    ImprovementEmitters and an Optimizer (presumably pyribs, star-imported),
    for ``times`` ask/tell iterations.

    - Each latent is decoded by ``generator``.
    - Its objective is ``0.9 * discriminator realness + 0.1 * pass_dots_loss``.
    - Its behaviour descriptors are (boldness, pass_dots_loss, predicted class).

    Args:
        immage: path of the source image handed to ``constellation_create``.
        distance: second argument to ``constellation_create``.
        times: number of ask/tell iterations.

    Returns:
        ``(noised, (best_image, best_objective))``:
        - ``noised`` is the constellation stimulus.
        - ``best_image`` is ``generator(best_latent)[0]``.
        - ``best_objective`` is that elite's archived objective.
    """
    archive = SlidingBoundariesArchive(
        dims=[200, 60, 6],
        ranges=[(0, 500), (0.1, 0.6), (0, 9)],  # boldness, pass_dots, class.
    )
    # 5 emitters x batch_size 30 = 150 candidates per ask(); the latent size 100
    # matches nz in the weights cell.
    emitters = [
        ImprovementEmitter(archive, np.zeros(100), 0.1, batch_size=30)
        for _ in range(5)
    ]
    optimizer = Optimizer(archive, emitters)
    noised = constellation_create(immage, distance, 1, 0.003)

    for _ in range(times):
        sols = optimizer.ask()
        objs, bcs = _score_latents(sols, noised)
        optimizer.tell(objs, bcs)

    best_image, best_objective = _best_elite(archive)
    return noised, (best_image, best_objective)

In [None]:
import glob
# Sorted so the processing order (and therefore the final grid) is reproducible;
# glob's order depends on the filesystem.
imagess = sorted(glob.glob('mnist_together/*'))

noised_images = []
top_image = []
for image_path in tqdm(imagess):
    # distance=15, 35 ask/tell iterations per stimulus.
    noised, tuple_image_score = inference_image_generator(image_path, 15, 35)
    top_image.append(tuple_image_score)
    noised_images.append(noised)

In [None]:
# Collects one overlay tensor per reconstruction for the grid plot.
imgs_drawn = list()

def drawing_figure(image):
    """Upscale ``image`` to 800x800 and trace its large contours.

    Args:
        image: grayscale image; ``cv2.threshold`` at 40 is applied to it after resizing.

    Returns:
        ``(draw_img, sc)``:
        - ``draw_img`` is a float array of the resized image's shape. It is 0
          everywhere except 255 on every point of the kept contours.
        - ``sc`` is a list of copies of the contours whose area is at least
          25000 (measured at 800x800).

    NOTE(review): ``pass_dots_loss``, defined earlier in this notebook, also calls
    a ``drawing_figure``. Running this cell rebinds that global name to this
    tuple-returning version -- confirm it is the helper that function expects.
    """
    upscaled = cv2.resize(image, (800, 800))
    _, binary = cv2.threshold(upscaled, 40, 255, cv2.THRESH_BINARY)
    # [-2] picks the contour list under both the OpenCV 3 and OpenCV 4 return layouts.
    contours = cv2.findContours(binary, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)[-2]
    # Drop small contours (area below 25000 px at this resolution).
    kept = [contour.copy() for contour in contours if cv2.contourArea(contour) >= 25000]
    canvas = np.zeros(upscaled.shape)
    for contour in kept:
        # Each contour is an array of [[x, y]] points; write them as canvas[y, x].
        canvas[contour[:, 0, 1], contour[:, 0, 0]] = 255
    return canvas, kept
# Overlay the contour of each best reconstruction onto its constellation stimulus.
# Iterate over every result; the previous range(len(top_image) - 1) silently
# dropped the last one.
for stimulus, (best_image, _) in zip(noised_images, top_image):
    noise_dotted_for_drawing = cv2.resize(stimulus, dsize=(800, 800))
    # best_image is one generator output; channel 0 is the digit. Map the
    # generator's [-1, 1] range to [0, 255] before the uint8 cast, because casting
    # negative floats to uint8 is undefined in NumPy.
    digit = best_image[0].detach().cpu().numpy()
    t = np.clip((digit + 1.0) / 2.0 * 255, 0, 255).astype('uint8')
    _, hh = drawing_figure(t)
    im2 = cv2.drawContours(noise_dotted_for_drawing, hh, -1, (255, 255, 255), 8)
    imgs_drawn.append(torch.tensor(im2).unsqueeze(0))


In [None]:
figsize = (8, 5)
from torchvision.utils import make_grid

# Fail with a clear message, rather than deep inside make_grid, when the
# overlay loop produced nothing (e.g. the glob matched no files).
assert imgs_drawn, "imgs_drawn is empty: nothing to plot"

# Tile all overlays, 10 per row, separated by white (255) padding.
img_grid1 = make_grid(imgs_drawn, nrow=10, padding=10, pad_value=255)

fig, ax = plt.subplots(figsize=figsize)
# make_grid returns (C, H, W); imshow expects (H, W, C).
ax.imshow(np.transpose(img_grid1.detach().cpu().numpy(), (1, 2, 0)), cmap='gray')
ax.axis('off')
fig.savefig('drawn_gan.png', bbox_inches='tight', pad_inches=0)