In [None]:
import gc
import os
import time
# import argparse
import torch
import torch.nn.functional as F
from glob import glob
from torch.utils.data import DataLoader
from tqdm import trange
from collections import OrderedDict

# from data import AblatingDataset

import utils

In [4]:
def get_text_embeddings(text_encoder, tokenized_text):
    # ref: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L377

    device = text_encoder.device
    weight_dtype = text_encoder.dtype

    text_embedding = text_encoder(tokenized_text.to(device))[0].to(weight_dtype)
    return text_embedding

def get_target_noise(scheduler, noise, latents=None, timesteps=None):
    if scheduler.config.prediction_type == "epsilon":
        target = noise
    elif scheduler.config.prediction_type == "v_prediction":
        target = scheduler.get_velocity(latents, noise, timesteps)
    
    return target

In [None]:
from PIL import Image
from torch.utils.data import Dataset

class AblatingDataset(Dataset):
    def __init__(
        self,
        data_root,
        tokenizer,
        placeholder_token,
        vae,
        # concept_type:str="object",
        size=512,
        interpolation="bicubic",
        center_crop=False,
        device="cuda:0",
        # batch_size:int=0,
        # is_zero_shot=False
    ) -> None:
        super().__init__()

        self.prompt = placeholder_token
        self.tokenizer = tokenizer
        self.image_embeddings = []
        
        # if not is_zero_shot:
        scaling_factor = 0.18215
        for f in glob(f"{data_root}/*"):
            if ".png" in f or ".jpg" in f or ".jpeg" in f:
                image = utils.preprocess(f, center_crop, size, interpolation).to(device)
                # image = Image.open(f).convert("RGB")
                with torch.no_grad():
                    latents = vae.encode(image).latent_dist.sample().detach().cpu() * scaling_factor
                self.image_embeddings.append(latents[0])
        # else:
        #     for _ in range(batch_size):
        #         self.image_embeddings.append(torch.rand((1, 4, 64, 64)))
        self._length = len(self.image_embeddings)
    
    def __len__(self):
        return self._length

    def __getitem__(self, index):
        # text = random.choice(self.templates).format(self.prompt)
        text = "red"
        tokenized = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        return tokenized, self.image_embeddings[index]



In [None]:

def train():
    
    target_prompt = "red"

    device = "cuda"
    dataset_path = "ds/not_red"
    # dataset_path = "ds/cyan"

    save_path = "R_maximize_Cyan"
    batch_size = 1
    num_epochs = 50
    
    
    tokenizer, text_encoder, vae, unet, scheduler = utils.load_models_from_local_optioned_path(
        text_encoder_path="openai/clip-vit-large-patch14",
        unet_path="models/sd-15/unet",
        vae_path="models/sd-15/vae",
        tokenizer_version="openai/clip-vit-large-patch14",
    )

    unet.to(device)
    text_encoder.to(device)
    vae.to(device)
    vae.eval()
    
    # freeze unet parameters
    for param in unet.parameters():
        param.requires_grad = False

    
    text_encoder = utils.freeze_and_unfreeze_text_encoder(text_encoder, method="mlp-final-attn")

    # optimizer setting
    # optimizer = utils.get_optimizer(
    #     text_encoder.parameters(),
    #     config.optimizer_name,
    #     config.lr,
    #     (config.beta1, config.beta2),
    #     config.weight_decay,
    #     config.eps,
    # )
    optimizer = torch.optim.Adam(text_encoder.parameters(), lr = 1e-5)

    train_dataset = AblatingDataset(
        data_root=dataset_path,
        tokenizer=tokenizer,
        size=512,
        # concept_type=config.concept_type,
        placeholder_token=target_prompt,
        center_crop=False,
        vae=vae,
        device=device,
        batch_size=batch_size,
        # is_zero_shot=config.is_zero_shot
    )


    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    del vae
    gc.collect()
    torch.cuda.empty_cache()
    
    history = {"loss": []}

    os.makedirs(save_path, exist_ok=True)
    start = time.perf_counter()



    pbar = trange(0, num_epochs, desc="Epoch")
    for epoch in pbar:

        loss_avg = 0
        cnt = 0
        text_encoder.train()
        print(f"Starting epoch {epoch}")
        for step, (tokenized, image_embedding) in enumerate(train_dataloader):
            print(f"Processing step {step}")
            text_embedding = get_text_embeddings(
                text_encoder=text_encoder, 
                tokenized_text=tokenized
            )
            
            # bs, 4, 64, 64
            # if zero shot, image_embedding is random noise
            latents = image_embedding.to(device)

            noise = torch.randn_like(latents).to(device)
            bsz = latents.shape[0]
            timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()
            noisy_latents = scheduler.add_noise(latents, noise, timesteps)
            
            model_pred = unet(noisy_latents, timesteps, text_embedding).sample

            target = get_target_noise(scheduler=scheduler, noise=noise, latents=latents, timesteps=timesteps)
            
            print(f"Device of text_embedding: {text_embedding.device}")
            print(f"Device of latents: {latents.device}")
            print(f"Device of noise: {noise.device}")

            # reversed Textual Inversion, loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            loss = -F.mse_loss(model_pred.float(), target.float(), reduction="mean") 

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_avg += loss.detach().item()
            cnt += step
            
            history["loss"].append(loss.detach().item())
        pbar.set_postfix(OrderedDict(loss=loss_avg / (cnt + 1e-9)))
        text_encoder.eval()
        text_encoder.save_pretrained(f"{save_path}/epoch-{epoch}")
    
    end = time.perf_counter()
    print(f"Time : {end - start}")

    utils.plot_loss(history, save_path)


In [7]:
train()



text_model.encoder.layers.0.mlp.fc1
text_model.encoder.layers.0.mlp.fc2
text_model.encoder.layers.1.mlp.fc1
text_model.encoder.layers.1.mlp.fc2
text_model.encoder.layers.2.mlp.fc1
text_model.encoder.layers.2.mlp.fc2
text_model.encoder.layers.3.mlp.fc1
text_model.encoder.layers.3.mlp.fc2
text_model.encoder.layers.4.mlp.fc1
text_model.encoder.layers.4.mlp.fc2
text_model.encoder.layers.5.mlp.fc1
text_model.encoder.layers.5.mlp.fc2
text_model.encoder.layers.6.mlp.fc1
text_model.encoder.layers.6.mlp.fc2
text_model.encoder.layers.7.mlp.fc1
text_model.encoder.layers.7.mlp.fc2
text_model.encoder.layers.8.mlp.fc1
text_model.encoder.layers.8.mlp.fc2
text_model.encoder.layers.9.mlp.fc1
text_model.encoder.layers.9.mlp.fc2
text_model.encoder.layers.10.mlp.fc1
text_model.encoder.layers.10.mlp.fc2
text_model.encoder.layers.11.self_attn.k_proj
text_model.encoder.layers.11.self_attn.v_proj
text_model.encoder.layers.11.self_attn.q_proj
text_model.encoder.layers.11.self_attn.out_proj
text_model.encoder.l

  hidden_states = F.scaled_dot_product_attention(
Epoch:   0%|          | 0/50 [00:00<?, ?it/s]

Starting epoch 0
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:   2%|▏         | 1/50 [00:01<01:28,  1.81s/it, loss=-2.12e+6]

Starting epoch 1
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:   4%|▍         | 2/50 [00:02<00:57,  1.21s/it, loss=-5.22e+6]

Starting epoch 2
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:   6%|▌         | 3/50 [00:03<00:47,  1.01s/it, loss=-1.07e+6]

Starting epoch 3
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:   8%|▊         | 4/50 [00:04<00:42,  1.08it/s, loss=-4.28e+5]

Starting epoch 4
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  10%|█         | 5/50 [00:04<00:39,  1.13it/s, loss=-5.55e+6]

Starting epoch 5
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  12%|█▏        | 6/50 [00:05<00:38,  1.14it/s, loss=-1.28e+6]

Starting epoch 6
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  14%|█▍        | 7/50 [00:06<00:38,  1.12it/s, loss=-4.68e+5]

Starting epoch 7
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  16%|█▌        | 8/50 [00:07<00:35,  1.18it/s, loss=-1.2e+6] 

Starting epoch 8
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  18%|█▊        | 9/50 [00:08<00:31,  1.30it/s, loss=-8.48e+5]

Starting epoch 9
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  20%|██        | 10/50 [00:08<00:29,  1.37it/s, loss=-8.77e+5]

Starting epoch 10
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  22%|██▏       | 11/50 [00:09<00:26,  1.47it/s, loss=-4.44e+6]

Starting epoch 11
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  24%|██▍       | 12/50 [00:09<00:25,  1.52it/s, loss=-1.38e+7]

Starting epoch 12
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  26%|██▌       | 13/50 [00:10<00:23,  1.57it/s, loss=-8.01e+5]

Starting epoch 13
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  28%|██▊       | 14/50 [00:11<00:22,  1.61it/s, loss=-2.41e+6]

Starting epoch 14
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  30%|███       | 15/50 [00:11<00:21,  1.64it/s, loss=-9.81e+5]

Starting epoch 15
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  32%|███▏      | 16/50 [00:12<00:20,  1.66it/s, loss=-2.14e+6]

Starting epoch 16
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  34%|███▍      | 17/50 [00:12<00:19,  1.67it/s, loss=-7.98e+6]

Starting epoch 17
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  36%|███▌      | 18/50 [00:13<00:18,  1.69it/s, loss=-2.59e+6]

Starting epoch 18
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  38%|███▊      | 19/50 [00:14<00:18,  1.64it/s, loss=-5.2e+6] 

Starting epoch 19
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  40%|████      | 20/50 [00:14<00:18,  1.65it/s, loss=-1.07e+7]

Starting epoch 20
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  42%|████▏     | 21/50 [00:15<00:17,  1.66it/s, loss=-1.66e+7]

Starting epoch 21
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  44%|████▍     | 22/50 [00:15<00:16,  1.67it/s, loss=-3.44e+7]

Starting epoch 22
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  46%|████▌     | 23/50 [00:16<00:16,  1.68it/s, loss=-3.58e+7]

Starting epoch 23
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  48%|████▊     | 24/50 [00:17<00:15,  1.68it/s, loss=-7.64e+7]

Starting epoch 24
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  50%|█████     | 25/50 [00:17<00:14,  1.67it/s, loss=-1.95e+7]

Starting epoch 25
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  52%|█████▏    | 26/50 [00:18<00:14,  1.68it/s, loss=-2.68e+7]

Starting epoch 26
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  54%|█████▍    | 27/50 [00:18<00:13,  1.68it/s, loss=-1.1e+7] 

Starting epoch 27
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  56%|█████▌    | 28/50 [00:19<00:13,  1.68it/s, loss=-1.72e+7]

Starting epoch 28
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  58%|█████▊    | 29/50 [00:20<00:13,  1.61it/s, loss=-2.53e+7]

Starting epoch 29
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  60%|██████    | 30/50 [00:20<00:12,  1.60it/s, loss=-8.66e+7]

Starting epoch 30
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  62%|██████▏   | 31/50 [00:21<00:11,  1.63it/s, loss=-8.2e+7] 

Starting epoch 31
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  64%|██████▍   | 32/50 [00:21<00:10,  1.66it/s, loss=-1.9e+7]

Starting epoch 32
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  66%|██████▌   | 33/50 [00:22<00:10,  1.68it/s, loss=-1.66e+8]

Starting epoch 33
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  68%|██████▊   | 34/50 [00:23<00:09,  1.68it/s, loss=-2.68e+7]

Starting epoch 34
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  70%|███████   | 35/50 [00:23<00:08,  1.69it/s, loss=-1.94e+8]

Starting epoch 35
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  72%|███████▏  | 36/50 [00:24<00:08,  1.70it/s, loss=-1.07e+8]

Starting epoch 36
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  74%|███████▍  | 37/50 [00:24<00:07,  1.65it/s, loss=-1.08e+8]

Starting epoch 37
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  76%|███████▌  | 38/50 [00:25<00:07,  1.67it/s, loss=-8.33e+7]

Starting epoch 38
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  78%|███████▊  | 39/50 [00:26<00:06,  1.68it/s, loss=-2.21e+7]

Starting epoch 39
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  80%|████████  | 40/50 [00:26<00:05,  1.69it/s, loss=-1.04e+7]

Starting epoch 40
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  82%|████████▏ | 41/50 [00:27<00:05,  1.70it/s, loss=-2.29e+8]

Starting epoch 41
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  84%|████████▍ | 42/50 [00:27<00:04,  1.70it/s, loss=-5.14e+7]

Starting epoch 42
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  86%|████████▌ | 43/50 [00:28<00:04,  1.71it/s, loss=-3.27e+8]

Starting epoch 43
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  88%|████████▊ | 44/50 [00:28<00:03,  1.70it/s, loss=-4.61e+8]

Starting epoch 44
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  90%|█████████ | 45/50 [00:29<00:02,  1.68it/s, loss=-1.53e+8]

Starting epoch 45
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  92%|█████████▏| 46/50 [00:30<00:02,  1.67it/s, loss=-2.03e+8]

Starting epoch 46
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  94%|█████████▍| 47/50 [00:30<00:01,  1.68it/s, loss=-9.15e+7]

Starting epoch 47
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  96%|█████████▌| 48/50 [00:31<00:01,  1.69it/s, loss=-1.53e+7]

Starting epoch 48
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch:  98%|█████████▊| 49/50 [00:31<00:00,  1.69it/s, loss=-1.65e+8]

Starting epoch 49
Processing step 0
Device of text_embedding: cuda:0
Device of latents: cuda:0
Device of noise: cuda:0


Epoch: 100%|██████████| 50/50 [00:32<00:00,  1.54it/s, loss=-5.57e+8]

Time : 32.55622520000907



