In [1]:
import torch
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
from tqdm.auto import tqdm
from torch import autocast
from PIL import Image
import cv2
import numpy as np
import math
import random
import bisect
import operator
import matplotlib.pyplot as plt
import copy
import sys
import os
import datetime

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler

# 1. Load the autoencoder model which will be used to decode the latents into image space. 
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")

# 2. Load the tokenizer and text encoder to tokenize and encode the text. 
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# 3. The UNet model for generating the latents.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.8.layer_norm2.weight', 'vision_model.encoder.layers.10.self_attn.k_proj.weight', 'vision_model.embeddings.position_embedding.weight', 'vision_model.encoder.layers.13.layer_norm2.weight', 'vision_model.encoder.layers.2.mlp.fc2.bias', 'vision_model.encoder.layers.2.mlp.fc1.bias', 'vision_model.encoder.layers.21.mlp.fc1.bias', 'vision_model.encoder.layers.14.self_attn.q_proj.weight', 'vision_model.encoder.layers.9.mlp.fc2.bias', 'vision_model.encoder.layers.17.layer_norm2.bias', 'vision_model.encoder.layers.3.layer_norm1.bias', 'vision_model.encoder.layers.19.self_attn.v_proj.bias', 'vision_model.encoder.layers.9.layer_norm1.bias', 'vision_model.encoder.layers.17.mlp.fc2.weight', 'vision_model.encoder.layers.7.self_attn.out_proj.weight', 'vision_model.encoder.layers.16.self_attn.out_proj.bias', 'vision_model.encoder.layers.14.self_attn.v_proj.b

In [3]:
# from diffusers import LMSDiscreteScheduler

# scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
    trained_betas=None,
    predict_epsilon=True,
    thresholding=False,
    algorithm_type="dpmsolver++",
    solver_type="midpoint",
    lower_order_final=True,
)

  predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)


In [4]:
vae = vae.to(torch_device)
text_encoder = text_encoder.to(torch_device)
unet = unet.to(torch_device)

In [34]:
def func(latents,text_embeddings):
    
    num_inference_steps = 50            # Number of denoising steps

    guidance_scale = 7.5                # Scale for classifier-free guidance

    #############################################

    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
      uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]

    ###############################################

    embeddings = torch.cat([uncond_embeddings, text_embeddings])

    #############################################

    scheduler.set_timesteps(num_inference_steps)

    ##########################################

    latents_1 = latents * scheduler.init_noise_sigma

    #############################################



    for t in tqdm(scheduler.timesteps):
      # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
      latent_model_input_a = torch.cat([latents_1] * 2)

      latent_model_input = scheduler.scale_model_input(latent_model_input_a, t)

      # predict the noise residual
      with torch.no_grad():
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=embeddings).sample

      # perform guidance
      noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
      noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

      # compute the previous noisy sample x_t -> x_t-1
      latents_1 = scheduler.step(noise_pred, t, latents_1).prev_sample

    ###############################################

    # scale and decode the image latents with vae
    latents_3 = 1 / 0.18215 * latents_1

    with torch.no_grad():
      image = vae.decode(latents_3).sample


    ##########################################

    image_1 = (image / 2 + 0.5).clamp(0, 1)
    image_2 = image_1.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image_2 * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    result_image = pil_images[0]
    
    return result_image

In [37]:
#gaの関数
population = 30 #個体数
generations = 100 #世代数
X = 5
Y = 6
tournament = 5

# mutation = 10 #突然変異の個体数
# population = 3 #個体数
# generations = 2 #世代数
# X = 1
# Y = 3
# # mutation = 1 #突然変異の個体数
# tournament = 3

if X*Y!=population:
    print('Error: parameter error', file=sys.stderr)
    sys.exit(1)
    
gene_length = 64
elite = 3 #エリートの数
initializa_txt_num = 100 #初期化個体においてtxtのベクトルをどれくらい元から変異させるか
batch_size = 1
height = 512                        # default height of Stable Diffusion
width = 512                         # default width of Stable Diffusion

if max(elite,tournament)>population:
    print('Error: parameter error', file=sys.stderr)
    sys.exit(1)
    
def pil2cv(image):
    ''' PIL型 -> OpenCV型 '''
    new_image = np.array(image, dtype=np.uint8)
    if new_image.ndim == 2:  # モノクロ
        pass
    elif new_image.shape[2] == 3:  # カラー
        new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
    elif new_image.shape[2] == 4:  # 透過
        new_image = cv2.cvtColor(new_image, cv2.COLOR_RGBA2BGRA)
    return new_image


#相関係数 大きければ大きいほどよい
def calc_fitness(target,images):
    fitness = []
    for image in images:
        target_hist = cv2.calcHist([target], [0], None, [256], [0, 256])
        comparing_hist = cv2.calcHist([pil2cv(image)], [0], None, [256], [0, 256])
        ret = cv2.compareHist(target_hist, comparing_hist, 0)
#         ans = 25*(np.exp(max(0,ret)))-24 #適合度の式
        fitness.append(ret)
    return fitness

#seed値を複数個用意
def initialize_gene(text_embeddings):
    arr = []
    for i in range(population):
        embeddings = text_embeddings.clone()
        seed_here = random.randrange(1000)
        latents_torch = torch.randn(
          (batch_size, unet.in_channels, height // 8, width // 8),
          generator=torch.manual_seed(seed_here),
        )
        latents = latents_torch.to(torch_device)
        for j in range(initializa_txt_num):
            a = random.randrange(77)
            b = random.randrange(768)
            text_embeddings[0][a][b] = np.random.randn()
        arr.append([latents,embeddings])    
    return arr


def show_image(images,epoch):
    os.mkdir('experiment_data/'+experiment_id+'/epoch'+str(epoch))
    fig = plt.figure()
    for i in range(len(images)):
        arrPIL = np.asarray(images[i])
        images[i].save('experiment_data/'+experiment_id+'/epoch'+str(epoch)+"/"+str(i)+".png")
        ax1 = fig.add_subplot(X, Y, i+1)
        ax1.set_title(str(i))
        ax1.axis("off")
        plt.imshow(arrPIL)
#     print(epoch)
    plt.savefig('experiment_data/'+experiment_id+'/epoch'+str(epoch)+"/"+"ttl.png")
    plt.show()
    
    
def evolve(fitness,genes):
    new_genes = []
    
    #エリート戦略
    elite_array = []
    for i in range(population):
        extra_gene = copy.deepcopy(genes[i])
        elite_array.append([fitness[i],extra_gene])
    elite_array = sorted(elite_array, key=operator.itemgetter(0))
    elite_array = list(reversed(elite_array))
    for i in range(elite):
        new_genes.append(elite_array[i][1])
    
    #mutation, cross overで使う配列の準備
    index = []
    num = 0
    for i in range(population):
        num += fitness[i]
        index.append(num)
    index.append(num)
    
    #mutation
#     for i in range(mutation):
#         p1 = random.randrange(num)
#         g1 = bisect.bisect(index,p1)
#         new_gene1 = copy.deepcopy(genes[g1])
#         for j in range(4):
#             for k in range(64):
#                 for l in range(64):
#                     mutation_flag = random.randrange(100)
#                     if mutation_flag==0:
#                         new_gene1[0][0][j][k][l] = np.random.randn()
#         for j in range(77):
#             for k in range(768):
#                 mutation_flag = random.randrange(100)
#                 if mutation_flag==0:
#                     new_gene1[1][0][j][k] = np.random.randn()
#         new_genes.append(new_gene1)

    # cross over
    for i in range(population-elite):
        used1 = [0]*population 
        p1 = []
        while len(p1)<tournament:
            r1 = random.randrange(population)
            if used1[r1]==0:
                p1.append([fitness[r1],r1])
        p1 = sorted(p1, key=operator.itemgetter(0))
#         print(p1)
        new_gene = copy.deepcopy(genes[p1[tournament-1][1]])
#         print(p1[tournament-1][1])
        used1 = [0]*population 
        p1 = []
        while len(p1)<tournament:
            r1 = random.randrange(population)
            if used1[r1]==0:
                p1.append([fitness[r1],r1])
        p1 = sorted(p1, key=operator.itemgetter(0))
#         print(p1)
        new_gene1 = copy.deepcopy(genes[p1[tournament-1][1]])
#         print(p1[tournament-1][1])
        cross = random.randrange(gene_length)
        for j in range(4):
            for k in range(64):
                for l in range(64):
                    a = random.randrange(2)
                    if a==0:
                        new_gene[0][0][j][k][l] = new_gene1[0][0][j][k][l]
        for j in range(77):
            for k in range(768):
                a = random.randrange(2)
                if a==0:
                    new_gene[1][0][j][k] = new_gene1[1][0][j][k]
                    
        for j in range(4):
            for k in range(64):
                for l in range(64):
                    mutation_flag = random.randrange(300)
                    if mutation_flag==0:
                        new_gene[0][0][j][k][l] += np.random.randn()
                        if new_gene[0][0][j][k][l]>= 5:
                            new_gene[0][0][j][k][l] = 4.99
                        if new_gene[0][0][j][k][l]<= -5:
                            new_gene[0][0][j][k][l] = -4.99
        for j in range(77):
            for k in range(768):
                mutation_flag = random.randrange(6000)
                if mutation_flag==0:
                    new_gene[1][0][j][k] += np.random.randn()
                    if new_gene[1][0][j][k]>= 5:
                        new_gene[1][0][j][k] = 4.99
                    if new_gene[1][0][j][k]<= -5:
                        new_gene[1][0][j][k] = -4.99

        new_genes.append(new_gene)
                    
    if len(new_genes)!=population:
        print('Error: evolve error', file=sys.stderr)
        sys.exit(1)
    return new_genes

In [38]:
def do_experiment():
    experiment_time_delta = datetime.timedelta(hours=9)
    experiment_JST = datetime.timezone(experiment_time_delta, 'JST')
    global experiment_id
    experiment_now = datetime.datetime.now(experiment_JST)
    experiment_id = experiment_now.strftime('%Y%m%d%H%M%S')
    #ディレクトリ初期化
    os.mkdir('experiment_data/'+experiment_id)
    
    prompt = ["a beautiful photograph of mountains, green covered hills, meadows, sunset, sky"]
    # prompt = ["a photograph of a rice field in summer with clear sky"]
    target_image = Image.open("target_image/mountain.png").convert('RGB').resize((512,512))
    target = pil2cv(target_image)
    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")


    with torch.no_grad():
      text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]

    genes = initialize_gene(text_embeddings)

    with autocast("cuda"):
        #初期配置
        images = []
        for i in range(population):
            image = func(genes[i][0],genes[i][1])
            images.append(image)
        show_image(images,0)
        #進化+選択
        for i in range(1,generations+1):
            #選択
            fitness = calc_fitness(target,images)
            with open('experiment_data/'+experiment_id+'/fitness.txt', 'a') as f:
                print(fitness, file=f)
            #進化
            genes = evolve(fitness,genes)
            #画像生成
            images.clear()
            for j in range(population):
                image = func(genes[j][0],genes[j][1])
                images.append(image)
            #表示
            show_image(images,i)
        fitness = calc_fitness(target,images)
        with open('experiment_data/'+experiment_id+'fitness.txt', 'a') as f:
            print(fitness, file=f)

In [44]:
do_experiment()

In [45]:
do_experiment()

In [46]:
do_experiment()

In [47]:
do_experiment()

In [48]:
do_experiment()