In [None]:
import gc
import os
import time
# import argparse
import torch
import torch.nn.functional as F
import numpy as np
from glob import glob
from torch.utils.data import DataLoader
from tqdm import trange
from diffusers import StableDiffusionPipeline
from collections import OrderedDict



import utils

In [4]:
def get_text_embeddings(text_encoder, tokenized_text):
    """Run ``text_encoder`` on token ids and return its first output.

    The token ids are moved onto the encoder's device before the forward pass,
    and the first element of the encoder output is cast to the encoder's dtype.
    ref: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L377

    Args:
        text_encoder: callable module exposing ``.device`` and ``.dtype``.
        tokenized_text: tensor of token ids.

    Returns:
        The first element of the encoder output, in the encoder's dtype.
    """
    target_device, target_dtype = text_encoder.device, text_encoder.dtype
    encoder_outputs = text_encoder(tokenized_text.to(target_device))
    hidden_states = encoder_outputs[0]
    return hidden_states.to(target_dtype)

def get_target_noise(scheduler, noise, latents=None, timesteps=None):
    """Return the regression target matching the scheduler's prediction type.

    Args:
        scheduler: object whose ``config.prediction_type`` selects the target.
        noise: the noise that was added to ``latents``.
        latents: clean latents; required for ``"v_prediction"``.
        timesteps: timesteps used when noising; required for ``"v_prediction"``.

    Returns:
        ``noise`` for ``"epsilon"``; ``scheduler.get_velocity(latents, noise,
        timesteps)`` for ``"v_prediction"``.

    Raises:
        ValueError: if the prediction type is unsupported, or if
            ``"v_prediction"`` is requested without ``latents``/``timesteps``.
    """
    prediction_type = scheduler.config.prediction_type
    if prediction_type == "epsilon":
        return noise
    if prediction_type == "v_prediction":
        if latents is None or timesteps is None:
            raise ValueError("v_prediction requires both `latents` and `timesteps`")
        return scheduler.get_velocity(latents, noise, timesteps)
    # Previously this fell through and raised an opaque UnboundLocalError.
    raise ValueError(f"Unsupported prediction_type: {prediction_type!r}")

In [5]:
from PIL import Image
from torch.utils.data import Dataset

class AblatingDataset(Dataset):
    """Dataset pairing one tokenized prompt with pre-encoded VAE latents.

    Every ``.png``/``.jpg``/``.jpeg`` file (case-insensitive) directly under
    ``data_root`` is loaded with ``utils.preprocess``, encoded once through
    ``vae`` (``encode(...).latent_dist.sample()``), scaled, and cached on the
    CPU. Items are ``(token_ids, latent)``, where ``token_ids`` are the
    tokenized ``placeholder_token``.

    Args:
        data_root: directory scanned (non-recursively, in sorted order) for images.
        tokenizer: HF-style tokenizer; must expose ``model_max_length``.
        placeholder_token: prompt text tokenized for every item.
        vae: autoencoder used for encoding; only touched when images are found.
        size, interpolation, center_crop: forwarded to ``utils.preprocess``.
        device: device the VAE runs on while encoding.
        batch_size: unused; kept for call-site compatibility.
        preview: if True, decode each latent back to pixels and show it
            (debug aid; off by default).
    """

    _IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")
    # Fallback latent scaling factor when ``vae.config.scaling_factor`` is unavailable.
    _DEFAULT_SCALING_FACTOR = 0.18215

    def __init__(
        self,
        data_root,
        tokenizer,
        placeholder_token,
        vae,
        size=512,
        interpolation="bicubic",
        center_crop=False,
        device="cuda:0",
        batch_size: int = 0,
        preview: bool = False,
    ) -> None:
        super().__init__()

        self.prompt = placeholder_token
        self.tokenizer = tokenizer
        self.image_embeddings = []

        # Prefer the VAE's own configured scaling; fall back to the SD default.
        scaling_factor = (
            getattr(getattr(vae, "config", None), "scaling_factor", None)
            or self._DEFAULT_SCALING_FACTOR
        )

        # Sorted for a deterministic dataset order; match on the real extension
        # rather than a substring so e.g. "x.png.bak" is skipped and ".PNG" is kept.
        image_paths = [
            path
            for path in sorted(glob(os.path.join(data_root, "*")))
            if os.path.splitext(path)[1].lower() in self._IMAGE_EXTENSIONS
        ]

        for path in image_paths:
            image = utils.preprocess(path, center_crop, size, interpolation).to(device)
            with torch.no_grad():
                latents = vae.encode(image).latent_dist.sample().detach().cpu() * scaling_factor
                if preview:
                    self._show_reconstruction(vae, latents, scaling_factor, device)
            self.image_embeddings.append(latents[0])

        self._length = len(self.image_embeddings)

    @staticmethod
    def _show_reconstruction(vae, latents, scaling_factor, device):
        """Decode scaled ``latents`` back to an image and display it (debug only)."""
        import torchvision.transforms as T

        # Undo the scaling applied after encoding before decoding.
        decoded = vae.decode(latents.to(device) / scaling_factor).sample
        # Decoder output assumed to lie in [-1, 1]; map to [0, 1] for ToPILImage.
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        T.ToPILImage()(decoded[0].cpu()).show()

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        """Return ``(token_ids, latent)``; ``token_ids`` encode ``self.prompt``."""
        tokenized = self.tokenizer(
            self.prompt,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        return tokenized, self.image_embeddings[index]



In [None]:

def _run_training_loop(pipe, text_encoder, optimizer, train_dataset, batch_size, num_epochs, save_path, device):
    """Gradient-ascend the UNet denoising error w.r.t. the text encoder; checkpoint each epoch."""
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # The VAE was only needed to pre-encode the dataset; release its GPU memory.
    pipe.vae.to("cpu")
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    history = {"loss": []}
    os.makedirs(save_path, exist_ok=True)
    start = time.perf_counter()

    pbar = trange(0, num_epochs, desc="Epoch")
    for epoch in pbar:
        loss_sum = 0.0
        num_steps = 0
        text_encoder.train()

        for tokenized, image_embedding in train_dataloader:
            text_embedding = get_text_embeddings(text_encoder=text_encoder, tokenized_text=tokenized)

            latents = image_embedding.to(device)
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, pipe.scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
            ).long()
            noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

            model_pred = pipe.unet(noisy_latents, timesteps, text_embedding).sample
            target = get_target_noise(scheduler=pipe.scheduler, noise=noise, latents=latents, timesteps=timesteps)

            # Negated MSE (Textual Inversion minimizes the positive form): this
            # pushes the text encoder to make the concept harder to denoise.
            loss = -F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = loss.detach().item()
            loss_sum += loss_value
            num_steps += 1  # was `cnt += step`, which skewed the epoch average
            history["loss"].append(loss_value)

        pbar.set_postfix(OrderedDict(loss=loss_sum / max(num_steps, 1)))
        text_encoder.eval()
        text_encoder.save_pretrained(f"{save_path}/epoch-{epoch}")

    print(f"Time : {time.perf_counter() - start}")
    utils.plot_loss(history, save_path)


def train(
    target_prompt="red",
    dataset_path="ds/red",
    save_path="updated_CLIP/R_maximize_multi_red_layer7",
    batch_size=6,
    num_epochs=50,
    model="runwayml/stable-diffusion-v1-5",
    trainable_pattern="layers.7",
    lr=1e-5,
    device="cuda",
    run_training=False,
):
    """Prepare (and optionally run) fine-tuning of the CLIP text encoder.

    Steps:
      1. Loads the Stable Diffusion pipeline ``model`` (safety checker off) and
         swaps in the ``openai/clip-vit-large-patch14`` tokenizer and text encoder.
      2. Installs a PNDM scheduler.
      3. Freezes the UNet and the whole text encoder, then unfreezes only the
         parameters whose name contains ``trainable_pattern``.
      4. Builds Adam over those parameters and encodes every image in
         ``dataset_path`` into latents.

    If ``run_training`` is True, also runs ``num_epochs`` of negated-MSE
    training and saves the text encoder to ``{save_path}/epoch-{epoch}`` after
    every epoch. It defaults to False, so a bare ``train()`` keeps its previous
    side effects: no optimization and no checkpoints.

    Raises:
        ValueError: if no text-encoder parameter name contains ``trainable_pattern``.
    """
    from diffusers import PNDMScheduler
    from transformers import CLIPTextModel, CLIPTokenizer

    pipe = StableDiffusionPipeline.from_pretrained(model, safety_checker=None).to(device)
    pipe.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    pipe.text_encoder = CLIPTextModel.from_pretrained(
        "openai/clip-vit-large-patch14", low_cpu_mem_usage=False
    ).to(device)
    pipe.scheduler = PNDMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
        skip_prk_steps=True,
        steps_offset=1,
    )
    pipe.vae.eval()

    # The UNet stays frozen; gradients still flow through it to the text embedding.
    pipe.unet.requires_grad_(False)

    text_encoder = pipe.text_encoder
    # Freeze everything, then unfreeze the selected sub-layers. (The old loop set
    # the misspelled `require_grad`, so nothing was actually frozen.)
    text_encoder.requires_grad_(False)
    trainable_params = []
    for name, param in text_encoder.named_parameters():
        if trainable_pattern in name:
            param.requires_grad_(True)
            trainable_params.append(param)
    if not trainable_params:
        raise ValueError(f"No text-encoder parameters match {trainable_pattern!r}")

    optimizer = torch.optim.Adam(trainable_params, lr=lr)

    train_dataset = AblatingDataset(
        data_root=dataset_path,
        tokenizer=pipe.tokenizer,
        size=512,
        placeholder_token=target_prompt,
        center_crop=False,
        vae=pipe.vae,
        device=device,
        batch_size=batch_size,
    )

    if run_training:
        _run_training_loop(
            pipe, text_encoder, optimizer, train_dataset, batch_size, num_epochs, save_path, device
        )


In [7]:
# Run the setup defined above with its default configuration. This loads the
# pretrained pipeline and the CLIP text encoder (from the Hub, or from the local
# cache when the Hub is unreachable), builds the optimizer, and encodes the
# dataset images into latents. It does not execute the optimization loop.
train()

Couldn't connect to the Hub: 404 Client Error. (Request ID: Root=1-66df54c8-1bd8151d1a2556094e1e6231;ea529015-dd7b-42d9-a17f-3399ccf220cb)

Repository Not Found for url: https://huggingface.co/api/models/runwayml/stable-diffusion-v1-5.
Please make sure you specified the correct `repo_id` and `repo_type`.
If you are trying to access a private or gated repo, make sure you are authenticated..
Will try to load from local cache.


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.7.self_attn.out_proj.weight', 'vision_model.encoder.layers.23.mlp.fc2.bias', 'vision_model.encoder.layers.16.mlp.fc2.bias', 'vision_model.encoder.layers.23.self_attn.q_proj.weight', 'vision_model.encod