[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lexiconium/textual-inversion/blob/main/conceptualizer.ipynb)

In [None]:
#@title ## Install Dependencies

# Install the textual-inversion package straight from its GitHub repository.
# NOTE: stdout and stderr are redirected to /dev/null, so a failed install
# produces no visible output; drop the redirection when debugging import
# errors in the following cell.
!pip install git+https://github.com/lexiconium/textual-inversion > /dev/null 2>&1

In [41]:
#@title ## Import Dependencies

import os
import secrets
import shutil

import gradio as gr
import torch
import torchvision
from diffusers import AutoencoderKL, PNDMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

from textual_inversion.dataset import TextualInversionDataset
from textual_inversion.pipeline import TextualInversionDiffusionPipeline, TextualInversionDiffusionTrainingConfig

In [None]:
#@title ## Configuration

#@markdown ---

#@markdown #### 1. Model and training configs

# `pretrained_model_name_or_path` and `use_auth_token` are used in the Setup
# cell when loading the tokenizer, text encoder, UNet and VAE; `size` is the
# (size, size) target passed to the image Resize transform. The remaining
# values are forwarded as-is to TextualInversionDiffusionTrainingConfig.
# NOTE(review): both `warmup_ratio` and `num_warmup_steps` are handed to the
# training config; which one takes precedence is decided inside the
# textual_inversion package -- confirm before relying on either.
pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"  #@param {type:"string"}
size = 512  #@param {type:"integer"}
num_training_epochs = 1000  #@param {type:"integer"}
training_batch_size = 1  #@param {type:"integer"}
gradient_accumulation_steps = 4  #@param {type:"integer"}
learning_rate = 1e-4  #@param {type:"number"}
adam_beta1 = 0.9  #@param {type:"number"}
adam_beta2 = 0.999  #@param {type:"number"}
adam_weight_decay = 1e-2  #@param {type:"number"}
adam_epsilon = 1e-8  #@param {type:"number"}
lr_scheduler = "constant"  #@param ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
warmup_ratio = 0.2  #@param {type:"number"}
num_warmup_steps = 0  #@param {type:"integer"}
mixed_precision = "fp16"  #@param ["no", "fp16", "bf16"]
seed = 42  #@param {type:"integer"}
use_auth_token = True  #@param {type:"boolean"}
output_dir = "outputs"  #@param {type:"string"}

#@markdown ---

#@markdown #### 2. Conceptualization config
#@markdown Both `placeholder_token` and `initializer_token` must be filled in. \\
#@markdown For instance, to teach the model a concept from pictures of an exotic cat, `placeholder_token`
#@markdown and `initializer_token` could be "\<exotic-cat\>" and "cat", respectively.

# `learnable_property` and `placeholder_token` are passed to
# TextualInversionDataset in the Setup cell; `placeholder_token` and
# `initializer_token` are passed to the training config.
learnable_property = "object"  #@param ["object", "style"]
placeholder_token = "<exotic-cat>"  #@param {type:"string"}
initializer_token = "cat"  #@param {type:"string"}

#@markdown ---

In [None]:
#@title ## Upload Data

# Directory holding the most recently saved upload batch. It stays empty until
# the Save button handler runs, and is read later as the dataset's `data_dir`.
g_save_dir = ""

# Upload UI: a multi-file picker on one side, a read-only status box plus a
# Save button in a column next to it.
with gr.Blocks() as uploader:
    with gr.Row().style(equal_height=True):
        uploaded_files = gr.File(
            file_count="multiple",
            label="Upload images",
            interactive=True
        )

        with gr.Column():
            # Read-only text area used to report the result of the Save action.
            status_msg = gr.Textbox(
                value="If uploaded, click Save.",
                lines=10,
                label="Status",
                show_label=False,
                interactive=False
            )
            save_button = gr.Button("Save", variant="primary")


    def save_fn(tmpfiles):
        """Copy the uploaded files into a fresh, randomly named directory.

        The directory is created under the current working directory and its
        path is stored in the module-level ``g_save_dir`` once every file has
        been copied.

        Args:
            tmpfiles: The uploaded files, either objects exposing their
                on-disk path as ``.name`` or plain path strings. May be
                ``None`` or empty when nothing has been uploaded yet.

        Returns:
            A status message for the UI: ``"Done."`` on success, or a hint
            to upload files first when there is nothing to save.
        """
        global g_save_dir

        # Clicking Save before uploading passes None; report it instead of
        # raising and leaving an empty directory behind.
        if not tmpfiles:
            return "No files to save. Upload images first, then click Save."

        save_dir = os.path.join(os.getcwd(), secrets.token_hex(nbytes=16))
        os.makedirs(save_dir, exist_ok=True)

        for tmpfile in tmpfiles:
            # Use `.name` when present, otherwise treat the item as a path.
            shutil.copy(getattr(tmpfile, "name", tmpfile), save_dir)

        # Publish the directory only after all copies succeeded.
        g_save_dir = save_dir

        return "Done."


    # Wire the Save button: the current value of the file picker is passed to
    # save_fn, and the string it returns is shown in the status box.
    save_button.click(
        save_fn,
        inputs=[uploaded_files],
        outputs=[status_msg]
    )

# Start the upload UI defined above.
uploader.launch()

In [None]:
#@title ## Login to Hugging Face Hub

from huggingface_hub import notebook_login

# Prompt for Hugging Face credentials inside the notebook.
# NOTE(review): presumably needed because the Setup cell loads the model with
# `use_auth_token` set -- confirm whether the chosen checkpoint requires it.
notebook_login()

In [None]:
#@title ## Setup

def _from_hub(component_cls, subfolder):
    """Load one sub-model of the configured checkpoint from its subfolder."""
    return component_cls.from_pretrained(
        pretrained_model_name_or_path,
        subfolder=subfolder,
        use_auth_token=use_auth_token,
    )


# Components are loaded in the order: tokenizer, text encoder, UNet, VAE.
tokenizer = _from_hub(CLIPTokenizer, "tokenizer")
text_encoder = _from_hub(CLIPTextModel, "text_encoder")
unet = _from_hub(UNet2DConditionModel, "unet")
vae = _from_hub(AutoencoderKL, "vae")

# The scheduler is built from explicit settings rather than loaded from the
# checkpoint.
_scheduler_settings = {
    "beta_start": 0.00085,
    "beta_end": 0.012,
    "beta_schedule": "scaled_linear",
    "num_train_timesteps": 1000,
    "tensor_format": "pt",
}
scheduler = PNDMScheduler(**_scheduler_settings)

# Bundle everything into the project's pipeline object.
_pipeline_parts = {
    "tokenizer": tokenizer,
    "text_encoder": text_encoder,
    "unet": unet,
    "vae": vae,
    "scheduler": scheduler,
}
pipeline = TextualInversionDiffusionPipeline(**_pipeline_parts)

# Collect the values chosen in the Configuration cell, grouped by topic, and
# hand them to the training config in a single call.
_training_settings = {
    # Concept tokens
    "placeholder_token": placeholder_token,
    "initializer_token": initializer_token,
    # Schedule and batching
    "num_training_epochs": num_training_epochs,
    "training_batch_size": training_batch_size,
    "gradient_accumulation_steps": gradient_accumulation_steps,
    # Optimizer
    "learning_rate": learning_rate,
    "adam_beta1": adam_beta1,
    "adam_beta2": adam_beta2,
    "adam_weight_decay": adam_weight_decay,
    "adam_epsilon": adam_epsilon,
    # Learning-rate schedule
    "lr_scheduler": lr_scheduler,
    "warmup_ratio": warmup_ratio,
    "num_warmup_steps": num_warmup_steps,
    # Runtime
    "mixed_precision": mixed_precision,
    "seed": seed,
    "output_dir": output_dir,
}
training_config = TextualInversionDiffusionTrainingConfig(**_training_settings)

# Image preprocessing steps, applied in this order: bicubic resize to a
# (size, size) square, random horizontal flip, tensor conversion, per-channel
# normalization.
_resize = torchvision.transforms.Resize(
    (size, size),
    torchvision.transforms.InterpolationMode.BICUBIC,
)
_random_flip = torchvision.transforms.RandomHorizontalFlip(p=0.5)
_to_tensor = torchvision.transforms.ToTensor()
# NOTE(review): these statistics look like the CLIP image-preprocessing
# mean/std -- confirm they are what the training pipeline expects.
_normalize = torchvision.transforms.Normalize(
    mean=[0.48145466, 0.4578275, 0.40821073],
    std=[0.26862954, 0.26130258, 0.27577711],
)
transforms = torchvision.transforms.Compose(
    [_resize, _random_flip, _to_tensor, _normalize]
)

# The dataset reads from the directory recorded by the upload cell's Save
# handler (`g_save_dir`).
_dataset_settings = {
    "data_dir": g_save_dir,
    "transforms": transforms,
    "tokenizer": tokenizer,
    "placeholder_token": placeholder_token,
    "learnable_property": learnable_property,
}
dataset = TextualInversionDataset(**_dataset_settings)

# Run training; `pipeline` is rebound to whatever `train` returns and is used
# by the generation cell afterwards.
pipeline = pipeline.train(training_config=training_config, dataset=dataset)

In [None]:
#@title ## Generate

def _int_or_none(value):
    """Return *value* converted to ``int``, passing ``None`` through unchanged."""
    return None if value is None else int(value)


def text_to_image(
    prompt,
    num_samples,
    height,
    width,
    num_inference_steps,
    guidance_scale
):
    """Generate images for *prompt* with the module-level ``pipeline``.

    Args:
        prompt: Text prompt; it is repeated ``num_samples`` times to form the
            batch passed to the pipeline.
        num_samples: Number of images to request. Converted with ``int`` so
            that float values coming from UI widgets (e.g. ``2.0``) work.
        height: Forwarded to the pipeline as ``height``, coerced to ``int``
            unless ``None``.
        width: Forwarded to the pipeline as ``width``, coerced to ``int``
            unless ``None``.
        num_inference_steps: Forwarded to the pipeline, coerced to ``int``
            unless ``None``.
        guidance_scale: Forwarded to the pipeline unchanged.

    Returns:
        The ``images`` attribute of the pipeline's result.

    Raises:
        TypeError: If ``num_samples`` is ``None``.
    """
    # List repetition needs a real int; widgets may deliver floats.
    sample_count = int(num_samples)

    # Autocast on the GPU when one is available, otherwise on the CPU.
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    with torch.autocast(device_type):
        images = pipeline(
            [prompt] * sample_count,
            height=_int_or_none(height),
            width=_int_or_none(width),
            num_inference_steps=_int_or_none(num_inference_steps),
            guidance_scale=guidance_scale
        ).images

    return images


# Input widgets, created in the same order as text_to_image's parameters.
_prompt_box = gr.Textbox(
    placeholder="Write a prompt you want to generate in image.",
    label="Prompt",
)
_sample_count_box = gr.Number(value=2, label="Number of samples", precision=0)
_height_box = gr.Number(value=512, label="Height", precision=0)
_width_box = gr.Number(value=512, label="Width", precision=0)
_steps_slider = gr.Slider(
    minimum=10,
    maximum=100,
    value=50,
    step=1,
    label="Number of inference steps",
)
_guidance_slider = gr.Slider(
    minimum=1,
    maximum=10,
    value=7.5,
    step=0.1,
    label="guidance_scale",
)

interface_inputs = [
    _prompt_box,
    _sample_count_box,
    _height_box,
    _width_box,
    _steps_slider,
    _guidance_slider,
]

# The generated images are shown in a single read-only gallery.
_gallery = gr.Gallery(label="Images from prompt", interactive=False)
interface_output = [_gallery]

# Build the interface around text_to_image and start it.
_generation_ui = gr.Interface(
    fn=text_to_image,
    inputs=interface_inputs,
    outputs=interface_output,
)
_generation_ui.launch()