Script to generate the CLIP text embeddings (text-encoder outputs) for a prompt

In [1]:
# Add the repository root to sys.path so project modules can be imported
from pathlib import Path
import os, sys


def find_repo_root(start: Path, marker: str = '.gitignore') -> Path:
    """Walk up from ``start`` until a directory containing ``marker`` is found.

    Args:
        start: Directory to begin the search from.
        marker: File name that identifies the repository root.

    Returns:
        The first directory (``start`` itself or one of its ancestors) that contains ``marker``.

    Raises:
        FileNotFoundError: If the filesystem root is reached without finding ``marker``.
            Without this check the search would loop forever, since the root is its own parent.
    """
    current = start.resolve()
    while marker not in os.listdir(current):
        if current.parent == current:  # reached the filesystem root
            raise FileNotFoundError(f"No '{marker}' found in {start} or any parent directory")
        current = current.parent
    return current


exp_path = Path.cwd().resolve()  # path to the experiment folder (the notebook's working directory)
repo_path = find_repo_root(exp_path)
if str(repo_path) not in sys.path:
    sys.path.insert(0, str(repo_path))

print(f"Repo Path: {repo_path}")
print(f"Experiment Path: {exp_path}")

/home/benet/tfg/experiments
/home/benet/tfg
Repo Path: /home/benet/tfg
Experiment Path: /home/benet/tfg/experiments/latent_finetuning


In [2]:
#Libraries
import yaml
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision.transforms import (
    Compose,
    Resize,
    CenterCrop,
    ToTensor,
    Normalize,
    InterpolationMode,
)
import wandb
import datasets, diffusers
from diffusers import (
    UNet2DModel,
    DDPMScheduler,
)   
from diffusers import DDPMPipeline, AutoencoderKL, DiffusionPipeline
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTokenizer, CLIPTextModel
import logging
from accelerate.logging import get_logger
from accelerate import Accelerator

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


In [3]:
# Load the experiment configuration (better to pass its path via an argument parser)
config_path = exp_path / 'config_latent_finetuning.yaml'  # expected next to this notebook
with open(config_path, encoding='utf-8') as file:
    # safe_load only constructs standard YAML types (dicts, lists, scalars);
    # FullLoader can additionally instantiate some Python objects.
    config = yaml.safe_load(file)

# Load the Stable Diffusion v1-4 pipeline; get_embeddings below uses its tokenizer and text encoder
ldm = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")


# if config['text_promt']['prompt'] is not None:
#         # Encode prompt properly
#     prompt = [config['text_promt']['prompt']] * config['processing']['batch_size']
#     name = config['text_promt']['prompt'].replace(" ", "_")

# else:
#     # Make sure the text embeddings are None but in the format (batch_size, num_tokens, hidden_size)
#     prompt = [""] * config['processing']['batch_size']
#     name = "empty_prompt"
    
# tokenizer = ldm.tokenizer
# text_encoder = ldm.text_encoder.to(device)

# # Load tokenizer and text encoder and encode prompt
# text_inputs = tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(device)
# text_embeddings = text_encoder(**text_inputs).last_hidden_state
# print(text_embeddings.shape)

# # save the embeddings in a folder for later use with the name of the prompt
# print(name)
# os.makedirs(exp_path / 'text_embeddings', exist_ok=True)
# torch.save(text_embeddings, exp_path / 'text_embeddings' / f'{name}.pt')
# print(f"Text embeddings saved in {exp_path / 'text_embeddings' / f'{name}.pt'}")


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [14]:
def get_embeddings(prompt, bs, ldm):
    """Encode a text prompt into text-encoder hidden states using ``ldm``'s components.

    Args:
        prompt: Prompt string to encode. ``None`` is encoded as the empty prompt ``""``.
        bs: Batch size. Currently unused: the result always has batch dimension 1, and
            callers expand it to the batch size themselves. Kept for backward compatibility.
        ldm: Pipeline exposing ``tokenizer`` and ``text_encoder`` attributes
            (e.g. a Stable Diffusion ``DiffusionPipeline``).

    Returns:
        The encoder's ``last_hidden_state``, computed without autograd tracking.
        Its shape is ``(1, tokenizer.model_max_length, hidden_size)``, e.g. ``(1, 77, 768)``
        for Stable Diffusion v1-4.
    """
    if prompt is None:
        prompt = ""

    tokenizer = ldm.tokenizer
    text_encoder = ldm.text_encoder

    # Pad AND truncate to the tokenizer's context length so every prompt yields the same
    # sequence length. Without truncation, long prompts produce more tokens than the
    # encoder's position embeddings cover.
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = text_inputs["input_ids"].to(text_encoder.device)

    # Mirror StableDiffusionPipeline.encode_prompt: only pass the attention mask when the
    # encoder config requests it. Otherwise the embeddings (including those at padding
    # positions) differ from what the UNet was conditioned on.
    attention_mask = None
    if getattr(getattr(text_encoder, "config", None), "use_attention_mask", False):
        attention_mask = text_inputs["attention_mask"].to(text_encoder.device)

    with torch.no_grad():  # inference only: do not build or keep an autograd graph
        text_embeddings = text_encoder(input_ids, attention_mask=attention_mask).last_hidden_state
    return text_embeddings

In [15]:
text_embeddings = get_embeddings(
    prompt=config['text_promt']['prompt'],
    bs=config['processing']['batch_size'],
    ldm=ldm,
)


In [16]:
text_embeddings.size()  # torch.Size of the freshly computed embeddings

torch.Size([1, 77, 768])

In [17]:
# Broadcast the single prompt embedding across the configured batch size.
# `expand` returns a view (no copy); use `.repeat` if batch rows must be modified independently.
bs = config['processing']['batch_size']
text_embeddings = text_embeddings.expand(bs, -1, -1)
text_embeddings.size()

torch.Size([4, 77, 768])

In [4]:
# Load previously saved embeddings from text_embeddings/<embedding_name>.pt
name = config['text_promt']['embedding_name']
embedding_file = exp_path / 'text_embeddings' / f'{name}.pt'
# The file may have been saved from a CUDA tensor (the generation code moved the encoder to
# `device`); remap storages to the current device so loading also works on CPU-only machines.
text_embeddings = torch.load(embedding_file, map_location=device)
print(text_embeddings.shape)

torch.Size([1, 77, 768])
