<a href="https://colab.research.google.com/github/ongsim0629/draw-animal/blob/main/draw_animal.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import cv2, random, os, sys
import numpy as np
import urllib.request
from google.colab.patches import cv2_imshow
from copy import deepcopy
from skimage.metrics import mean_squared_error as compare_mse
import multiprocessing as mp
import re


#zip 파일 압축 해제
! unzip  -uq "/content/animal.zip" -d "/content/animal"

# Animals for which drawing data exists (an image named after the animal),
# plus the synonyms a user may type instead of one of those names.
animalList = [
    'alligator',
    'badger', 'bear', 'beaver', 'boar', 'bull', 'butterfly',
    'camel', 'cat', 'chick', 'chicken', 'chimpanzee', 'cow', 'crow',
    'deer', 'dog', 'dolphin', 'donkey', 'dragonfly', 'duck',
    'eagle', 'earthworm', 'elephant', 'elk',
    'falcon', 'fly', 'fox', 'frog',
    'gazelle', 'giraffe', 'goat', 'gold fish', 'goose', 'gorilla', 'grampus', 'grasshopper',
    'hamster', 'hedgehog', 'hippopotamus', 'horse',
    'kangaroo', 'koala',
    'ladybug', 'leopard', 'lion', 'lizard', 'lobster',
    'mole', 'monkey', 'moth', 'mouse',
    'ostrich', 'otter', 'owl',
    'parrot', 'peacock', 'penguin', 'pig', 'platypus', 'pigeon', 'polar bear',
    'rabbit', 'racoon', 'rhinoceros',
    'scorpion', 'seal', 'shark', 'sheep', 'skunk', 'snail',
    'snake', 'sparrow', 'spider', 'squirrel', 'swan',
    'tiger', 'toad', 'turtle',
    'whale', 'wolf',
    'zebra',
]
changeList = [
    'crocodile', 'wild pig', 'ox', 'kitten', 'rooster', 'ape', 'doe', 'puppy',
    'porpoise', 'worm', 'moose', 'hawk', 'vixen', 'killer whale', 'locust', 'hippo',
    'koala bear', 'ladybird', 'rat', 'peahen', 'piglet', 'dove', 'raccoon', 'rhino',
    'pup', 'lamb', 'serpent', 'chipmunk', 'tortoise',
]

# Ask which animal to draw. Case and surrounding/repeated whitespace are
# normalised so answers like "  Polar  Bear " still match the lower-case
# names in animalList and changeList.
myAnimal = ' '.join(input('Enter your favorite animal: ').split()).lower()

# Replace synonyms in a string with their canonical animal names.
def trans(myAnimal, translation):
    """Return ``myAnimal`` with every key of ``translation`` replaced by its value.

    Replacement is a single left-to-right pass, so replaced text is never
    translated again. Longer keys are tried first, so when one key is a
    prefix of another (e.g. ``'pup'`` and ``'puppy'``) the longer match wins
    regardless of the mapping's insertion order.

    Args:
        myAnimal: text to translate.
        translation: mapping from synonym to replacement string.

    Returns:
        The translated string; ``myAnimal`` unchanged if ``translation`` is empty.
    """
    if not translation:
        # An empty alternation would match the empty string everywhere and
        # then fail the dictionary lookup.
        return myAnimal
    # Regex alternation takes the first alternative that matches, so order
    # candidates longest-first (sorted() is stable for equal lengths).
    keys = sorted(translation, key=len, reverse=True)
    pattern = re.compile('|'.join(map(re.escape, keys)))
    return pattern.sub(lambda match: translation[match.group(0)], myAnimal)

# If the user typed a synonym, swap it for the canonical animal name used by
# the data. Afterwards `myAnimal` holds the canonical name and `translated`
# keeps the word the user actually typed.
if myAnimal in changeList:
  translation = {
      'crocodile': 'alligator', 'wild pig': 'boar', 'kitten': 'cat',
      'rooster': 'chicken', 'ape': 'chimpanzee', 'ox': 'bull', 'doe': 'deer',
      'puppy': 'dog', 'porpoise': 'dolphin', 'worm': 'earthworm',
      'moose': 'elk', 'hawk': 'falcon', 'vixen': 'fox',
      'killer whale': 'grampus', 'locust': 'grasshopper',
      'hippo': 'hippopotamus', 'koala bear': 'koala', 'ladybird': 'ladybug',
      'rat': 'mouse', 'peahen': 'peacock', 'piglet': 'pig', 'dove': 'pigeon',
      'raccoon': 'racoon', 'rhino': 'rhinoceros', 'pup': 'seal',
      'lamb': 'sheep', 'serpent': 'snake', 'chipmunk': 'squirrel',
      'tortoise': 'turtle',
  }
  translated = myAnimal
  myAnimal = trans(translated, translation)

#데이터가 존재하는 동물이 입력되었을 때 - 그림 그리기 시작
if myAnimal in animalList:
  print("OK. Let's draw")
  filepath = f'/content/animal/animal/{myAnimal}.jpg'
  filename, ext = os.path.splitext(os.path.basename(filepath))

  # cv2.imread does not raise for a missing/unreadable file; it returns None,
  # which would otherwise fail later with an obscure error on `.shape`.
  img = cv2.imread(filepath)
  if img is None:
    raise FileNotFoundError(f'Could not read image file: {filepath}')
  height, width, channels = img.shape

  # Genetic-algorithm hyper-parameters
  n_initial_genes = 50   # circles in the very first genome
  n_population = 50      # children produced per generation
  prob_mutation = 0.01   # per-gene mutation probability (fraction of genes for 200+ genes)
  prob_add = 0.3         # chance that a child gains one new random gene
  prob_remove = 0.2      # chance that a child loses one random gene

  min_radius, max_radius = 5, 15   # radius range for newly created genes
  save_every_n_iter = 100          # save the current best image every N generations
  # A gene is one filled circle; a genome is a list of genes.
  class Gene():
    """One filled circle of a candidate drawing.

    On creation the attributes are random: ``center`` lies in
    [0, width] x [0, height], ``radius`` in [min_radius, max_radius] and each
    of the three ``color`` channels in [0, 255].
    """

    def __init__(self):
      self.center = np.array([random.randint(0, width), random.randint(0, height)])
      self.radius = random.randint(min_radius, max_radius)
      self.color = np.array([random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)])

    @staticmethod
    def _perturb(value, fraction, low, high):
      """Return ``value`` moved by up to +/- ``fraction`` of itself, clipped to [low, high].

      The step is at least 1, so small values (including 0) can still move in
      both directions. Truncating ``value * (1 +/- fraction)`` with ``int()``
      made values below ~1/fraction unable to grow and pinned 0 forever.
      """
      value = int(value)
      step = max(1, int(round(abs(value) * fraction)))
      return int(np.clip(random.randint(value - step, value + step), low, high))

    def mutate(self):
      """Randomly perturb exactly one attribute (radius, center or color) in place."""
      # Relative mutation strength: about 15% +/- 4%, never below 1%.
      mutation_size = max(1, int(round(random.gauss(15, 4)))) / 100

      r = random.uniform(0, 1)
      if r < 0.33:  # radius
        self.radius = self._perturb(self.radius, mutation_size, 1, 100)
      elif r < 0.66:  # center
        self.center = np.array([
            self._perturb(self.center[0], mutation_size, 0, width),
            self._perturb(self.center[1], mutation_size, 0, height),
        ])
      else:  # color: every channel is perturbed independently
        self.color = np.array([self._perturb(c, mutation_size, 0, 255) for c in self.color])
  # Fitness: render a genome and score it against the target image.
  def compute_fitness(genome):
    """Render ``genome`` on a white canvas and score it against ``img``.

    Returns ``(fitness, canvas)``: ``fitness = 255 / MSE(img, canvas)``, so
    higher is better and an exact match scores ``inf``; ``canvas`` is a uint8
    image of shape ``(height, width, channels)``.
    """
    out = np.full((height, width, channels), 255, dtype=np.uint8)

    for gene in genome:
      # Pass plain Python ints: some OpenCV builds refuse NumPy integer
      # scalars for point/radius/color arguments.
      cv2.circle(
          out,
          center=(int(gene.center[0]), int(gene.center[1])),
          radius=int(gene.radius),
          color=tuple(int(c) for c in gene.color[:3]),
          thickness=-1,
      )

    mse = compare_mse(img, out)
    # A perfect reproduction has MSE 0; report it as infinitely fit instead
    # of dividing by zero.
    fitness = 255. / mse if mse > 0 else float('inf')

    return fitness, out

  def compute_population(g):
    """Create and score one mutated child of genome ``g``.

    ``g`` itself is not modified: the child starts as a deep copy. Returns
    ``(fitness, child_genome, rendered_image)``.
    """
    child = deepcopy(g)

    # 1) Mutation. Genomes with 200+ genes: exactly int(len * prob_mutation)
    #    distinct genes are picked. Smaller genomes: each gene mutates
    #    independently with probability prob_mutation.
    if len(child) >= 200:
      n_picked = int(len(child) * prob_mutation)
      for gene in random.sample(child, k=n_picked):
        gene.mutate()
    else:
      for gene in child:
        if random.random() < prob_mutation:
          gene.mutate()

    # 2) Sometimes grow the genome by one brand-new random gene.
    if random.random() < prob_add:
      child.append(Gene())

    # 3) Sometimes drop one randomly chosen gene (only if any are left).
    if child and random.random() < prob_remove:
      child.remove(random.choice(child))

    fitness, rendered = compute_fitness(child)
    return fitness, child, rendered


  # Driver: evolve the drawing until interrupted (or 'q' is pressed).
  if __name__ == '__main__':
    os.makedirs('result', exist_ok=True)

    # Keep one core for the notebook itself, but never request zero workers:
    # mp.Pool(0) raises ValueError on a single-CPU machine.
    n_workers = max(1, mp.cpu_count() - 1)

    # First generation: one random genome and its score.
    best_genome = [Gene() for _ in range(n_initial_genes)]
    best_fitness, best_out = compute_fitness(best_genome)

    n_gen = 0
    # The context manager terminates the worker processes on every exit
    # path (break, KeyboardInterrupt or any other exception).
    with mp.Pool(n_workers) as p:
      try:
        # Endless loop; stop it with the notebook's interrupt button.
        while True:
          results = p.map(compute_population, [deepcopy(best_genome)] * n_population)
          # Elitism: the current best competes with its children, so the
          # best fitness never decreases between generations.
          results.append([best_fitness, best_genome, best_out])

          # max() returns the first maximal entry, the same pick as a stable
          # descending sort, without sorting the whole population.
          best_fitness, best_genome, best_out = max(results, key=lambda x: x[0])

          print('Generation #%s, Fitness %s' % (n_gen, best_fitness))
          n_gen += 1

          if n_gen % save_every_n_iter == 0:
            cv2.imwrite('result/%s_%s.jpg' % (filename, n_gen), best_out)

          cv2_imshow(best_out)
          if cv2.waitKey(1) == ord('q'):
            break
      except KeyboardInterrupt:
        # Interrupting is the intended way to stop; fall through and show the
        # best drawing found so far.
        pass

    # cv2.imshow is disabled in Colab; use the cv2_imshow patch (as in the
    # loop above) to show the final result.
    cv2_imshow(best_out)

#데이터가 존재하지 않는 동물 입력 받았을 때 출력되는 문구
else: print("Sorry. It doesn't exist in the data. T.T")