In [None]:
%%capture
!pip install -r requirements.txt

In [None]:
# to load and transform image datasets to tensor
from datasets import load_dataset, load_from_disk, VerificationMode
from datasets.arrow_dataset import Dataset
# diffusers model
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker

# Text processing
from transformers import CLIPTextModel, CLIPTokenizer

# Image processing
from torchvision.transforms import v2
from torchvision.transforms import InterpolationMode

In [None]:
from configs import DATA_FOLDER

In [None]:
import os
from accelerate.utils import write_basic_config

# Write a default accelerate config the first time only, then hard-restart the kernel.
# Calling os._exit() unconditionally killed the kernel on every run, so
# "Restart & Run All" could never get past this cell.
# NOTE(review): relies on write_basic_config() returning a falsy value (and not writing)
# when a config file already exists, as recent accelerate releases do -- confirm for
# the version pinned in requirements.txt.
if write_basic_config():
    os._exit(0)  # Restart the notebook kernel

# Note: steps 1–3 should be migrated to EMR

## 1. Load Images
Images are loaded with the `imagefolder` builder of Hugging Face's `load_dataset`.

In [None]:
# Build a Dataset from the image files under DATA_FOLDER with the "imagefolder"
# builder, keeping only the "train" split.
img_dataset: Dataset = load_dataset(
    "imagefolder",
    split="train",
    data_dir=DATA_FOLDER,
)

In [None]:
# Inspect the first row of the dataset (a mapping of column name -> value).
img_dataset[0]

In [None]:
# Display the "image" column of the first row to eyeball the raw input.
img_dataset[0].get("image")

## 2. Process Image

(References: 
https://huggingface.co/docs/diffusers/v0.27.2/en/training/text2image  
https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L40  
https://pytorch.org/vision/main/transforms.html  
)

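This step should be moved to EMR. It includes:
+ Resize images to 512x512 (the code currently uses Lanczos interpolation; the Hugging Face reference script uses bilinear). 512x512 is the native resolution for Stable Diffusion v1.x, but a plain resize distorts images that are not square, so we may want to scale by ratio and add padding instead.
+ Normalize pixel values (the Stable Diffusion VAE expects inputs in [-1, 1]).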

In [None]:
from torchvision.transforms import v2
from torchvision.transforms import InterpolationMode

In [None]:
import torch  # needed for v2.ToDtype; torch is otherwise only imported later in the notebook

# Image preprocessing for the Stable Diffusion VAE:
#   resize -> float tensor in [0, 1] -> shift/scale to [-1, 1].
# The SD VAE works on pixels in [-1, 1] (mean = std = 0.5, as in the Hugging Face
# text-to-image training script), not ImageNet-normalized pixels.
transform_pipeline = v2.Compose([
    # TODO
    # Instead of resize, enlarge the photo by ratio and add padding
    # (a plain resize distorts images whose aspect ratio is not 1:1).
    v2.Resize(size=(512, 512), interpolation=InterpolationMode.LANCZOS),
    # v2.ToTensor is deprecated; ToImage + ToDtype(scale=True) is its v2 replacement.
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


# Preview on one sample. Convert to RGB first so grayscale/RGBA/palette images
# also yield the 3 channels Normalize is configured for.
new_img = transform_pipeline(img_dataset[0].get("image").convert("RGB"))

In [None]:
# Sanity check: 3 channels at the 512x512 size configured in transform_pipeline.
assert tuple(new_img.shape) == (3, 512, 512), f"unexpected shape {tuple(new_img.shape)}"
new_img.shape

## 3. Process Caption
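To process the (English) captions we need CLIP: the `CLIPTokenizer` here turns each caption into token IDs, and the `CLIPTextModel` loaded in the training step turns those IDs into conditioning embeddings.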
+ https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel  
+ https://huggingface.co/runwayml/stable-diffusion-v1-5  
+ https://huggingface.co/openai/clip-vit-large-patch14  

In [None]:
from transformers import CLIPTextModel, CLIPTokenizer
# Tokenizer for the same CLIP checkpoint whose text encoder is loaded in train().
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_name_or_path="openai/clip-vit-large-patch14"
)

In [None]:
def preprocess_train(examples):
    """Turn a batch of raw dataset rows into model inputs (used via ``with_transform``).

    ``with_transform`` calls this lazily on every access with a dict of column lists
    covering however many rows were requested (one row for ``ds[i]``, a slice for
    ``ds[a:b]``), so the batch size is not fixed.

    Adds two columns and returns the same dict:
        pixel_values: one ``transform_pipeline`` tensor per image.
        input_ids: CLIP token ids for each caption, padded/truncated to the
            tokenizer's ``model_max_length``.
    """
    # Images may be grayscale, RGBA or palette-based; the pipeline expects 3-channel RGB.
    examples["pixel_values"] = [
        transform_pipeline(image.convert("RGB")) for image in examples["image"]
    ]

    inputs = tokenizer(
        list(examples["caption"]),
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    examples["input_ids"] = inputs.input_ids
    return examples

train_set = img_dataset.with_transform(preprocess_train)

In [None]:
# TODO
# Upload to S3 parquet

# Note: step 4 onward is carried out in SageMaker

## 4. Setup CLIP Embedding, VAE, UNET and remaining part of architecture for training
Other options to try, e.g. whether to keep an EMA (exponential moving average) copy of the UNet weights: https://huggingface.co/stabilityai/stable-diffusion-2-1/discussions/22

In [None]:
# Sanity check: the object returned by img_dataset.with_transform(...) in step 3.
type(train_set)

## 5. Get a scheduler for adding noise

## 6. Training loop
This training loop is adapted from the Hugging Face training script https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L912

Ref: https://huggingface.co/docs/accelerate/quicktour  


In [None]:
import torch
import torch.nn.functional as F
from accelerate import Accelerator

In [None]:
# TODO: add checkpoint and resume from checkpoint

In [None]:
def collate_fn(examples):
    """
    Merge a list of preprocessed examples into a single training batch.

    Returns a dict with:
        pixel_values: stacked image tensors, contiguous and cast to float32.
        input_ids: stacked token-id tensors.
    """
    images = [item["pixel_values"] for item in examples]
    token_ids = [item["input_ids"] for item in examples]

    batch_pixels = torch.stack(images).to(memory_format=torch.contiguous_format).float()
    return {
        "pixel_values": batch_pixels,
        "input_ids": torch.stack(token_ids),
    }




In [None]:
# Show the summary of the lazily-transformed training dataset.
train_set

# TODO:
Research the noise-adding step:
+ input_perturbation
+ noise_offset

Prediction Type: epsilon vs v_prediction 
https://medium.com/@zljdanceholic/three-stable-diffusion-training-losses-x0-epsilon-and-v-prediction-126de920eb73


In [None]:
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import LMSDiscreteScheduler

def train(
    lr=1e-5,
    batch_size=4,
    epochs=2,
    prediction_type="v_prediction",
    mixed_precision="fp16",
    max_grad_norm=1.0,
    output_dir="sd-model-finetuned",
    pretrained_model="runwayml/stable-diffusion-v1-5",
):
    """Fine-tune the Stable Diffusion UNet on ``train_set`` and save a pipeline.

    The VAE and the CLIP text encoder are frozen; only the UNet is optimised.
    Designed to be launched with ``accelerate.notebook_launcher``.
    Relies on the notebook globals ``train_set``, ``collate_fn`` and ``tokenizer``.

    Args:
        lr: Adam learning rate for the UNet. The Hugging Face text-to-image guide
            fine-tunes with 1e-5; the previous 1e-3 is far too high for
            pretrained weights.
        batch_size: Per-process DataLoader batch size.
        epochs: Number of passes over ``train_set``.
        prediction_type: Training target, ``"epsilon"`` or ``"v_prediction"``;
            ``None`` keeps the noise scheduler's default.
        mixed_precision: Passed to ``Accelerator`` (e.g. ``"no"``, ``"fp16"``, ``"bf16"``).
        max_grad_norm: Gradient-norm clipping threshold applied before each optimizer step.
        output_dir: Directory the trained ``StableDiffusionPipeline`` is saved to.
        pretrained_model: Hub id providing the VAE, the UNet and the remaining
            pipeline components (safety checker, feature extractor).
            NOTE(review): the runwayml repo was reportedly removed from the Hub;
            "stable-diffusion-v1-5/stable-diffusion-v1-5" may be needed -- confirm.

    Raises:
        ValueError: if the resulting prediction type is not epsilon / v_prediction.
    """
    # Local import keeps this cell self-contained; DDPM is only needed for training.
    from diffusers import DDPMScheduler

    accelerator = Accelerator(mixed_precision=mixed_precision)
    # Frozen models are run in float32 (autocast only wraps the prepared UNet).
    weight_dtype = torch.float32

    # --- Pretrained models --------------------------------------------------
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    vae = AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(pretrained_model, subfolder="unet")

    # Freeze vae and text_encoder (we only train the UNet).
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # --- Forward (noising) process ----------------------------------------
    # DDPM with the same beta schedule as before. PNDM is an inference sampler:
    # it does not provide get_velocity() (needed for v-prediction targets), and its
    # step() is a denoising step, not something to call after optimizer.step().
    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
    )
    if prediction_type is not None:
        # Configure once up front instead of on every training step.
        noise_scheduler.register_to_config(prediction_type=prediction_type)
    prediction_type = noise_scheduler.config.prediction_type
    if prediction_type not in ("epsilon", "v_prediction"):
        raise ValueError(f"Unknown prediction type {prediction_type}")

    # --- Optimiser and data -----------------------------------------------
    optimizer = torch.optim.Adam(unet.parameters(), lr=lr)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
    )

    # Only trainable objects go through prepare(); the frozen models are just moved.
    unet, optimizer, train_loader = accelerator.prepare(unet, optimizer, train_loader)
    vae.to(accelerator.device)
    text_encoder.to(accelerator.device)
    vae.eval()
    text_encoder.eval()

    accelerator.wait_for_everyone()
    for epoch in range(epochs):
        # The UNet is the model being trained (previously text_encoder.train() by mistake).
        unet.train()
        epoch_loss = 0.0
        num_steps = 0

        for batch in train_loader:
            with accelerator.accumulate(unet):
                with torch.no_grad():
                    # Encode images to latent space; scale to the range the UNet was trained on.
                    latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor
                    # Text embedding for conditioning.
                    encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]

                # Sample noise and a random timestep per image, then apply the
                # forward diffusion process.
                noise = torch.randn_like(latents)
                timesteps = torch.randint(
                    0,
                    noise_scheduler.config.num_train_timesteps,
                    (latents.shape[0],),
                    device=latents.device,
                ).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Regression target depends on the prediction type.
                if prediction_type == "epsilon":
                    target = noise
                else:
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)

                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                # Clip between backward() and step() so clipping affects this update.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()

            epoch_loss += loss.detach().item()
            num_steps += 1

        accelerator.wait_for_everyone()
        accelerator.print(f"epoch {epoch}: mean loss {epoch_loss / max(num_steps, 1):.4f}")

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
        # from_pretrained supplies the safety checker / feature extractor, which were
        # previously referenced as undefined globals.
        pipeline = StableDiffusionPipeline.from_pretrained(
            pretrained_model,
            text_encoder=text_encoder,
            vae=vae,
            unet=accelerator.unwrap_model(unet),
            tokenizer=tokenizer,
            # The inference scheduler must predict the same quantity the UNet learned.
            scheduler=PNDMScheduler(
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                skip_prk_steps=True,
                prediction_type=prediction_type,
            ),
        )
        pipeline.save_pretrained(output_dir)

    accelerator.end_training()

In [None]:
from accelerate import notebook_launcher


In [None]:
# Launch train() (no positional arguments) through accelerate with a single process.
notebook_launcher(train, args=(), num_processes=1)