## Import Metrics

In [30]:
import torch
import random
import requests
from transformers import AutoTokenizer, PretrainedConfig, CLIPTextModel, CLIPImageProcessor
import diffusers
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    DDPMScheduler,
    StableDiffusionControlNetPipeline, 
    StableDiffusionControlNetImg2ImgPipeline,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
)
import requests
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionUpscalePipeline
import torch, torchvision
from torch.utils.data import Dataset, DataLoader
import os

## Dataset & Data Loaders

In [34]:
class train_dataset(Dataset):
    def __init__(self, train_dir, temporal_radius = 1):
        self.train_dir = train_dir
        self.temporal_radius = temporal_radius
        self.video_names = os.listdir(os.path.join(train_dir, "train_sharp"))
        self.eligible_frames = [i for i in range(self.temporal_radius, 100-self.temporal_radius)]
        self.n_videos = len(self.video_names)
        self.n_eligible_frames = len(self.eligible_frames)
        self.n_total_eligible_images = self.n_videos * self.n_eligible_frames
        
        self.lr_h_bound = 180 - 128
        self.lr_w_bound = 320 - 128

    def __len__(self):
        return self.n_total_eligible_images

    def __getitem__(self, idx):
        vid_name = '{:03d}'.format(idx//self.n_eligible_frames)
        frame_name = '{:08d}.png'.format(self.temporal_radius + idx%self.n_eligible_frames)
        hr_frame = os.path.join(self.train_dir, "train_sharp", vid_name, frame_name)
        lr_frame = os.path.join(self.train_dir, "train_sharp_bicubic", "X4", vid_name, frame_name)

        hr_img = torchvision.io.read_image(hr_frame)
        lr_img = torchvision.io.read_image(lr_frame)

        ## Random Crop
        x = random.randint(0, self.lr_h_bound)
        y = random.randint(0, self.lr_w_bound)

        hr_img = hr_img[:, x*4:(x*4)+512, y*4:y*4+512]
        lr_img = lr_img[:, x:x+128, y:y+128]
        
        return {'hr_img': hr_img, "lr_img": lr_img}

In [35]:
dataset = train_dataset(train_dir = "data/train", temporal_radius = 1)

In [40]:
batch_size = 3
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [41]:
i = 1
for batch_idx, batch in enumerate(dataloader):
    print("batch_idx: {}".format(batch_idx))
    print("hr_frame: {}".format(batch['hr_img'].shape))
    print("lr_frame: {}".format(batch['lr_img'].shape))
    print("\n")
    if (i==2):
        break
    i+=1

batch_idx: 0
hr_frame: torch.Size([3, 3, 512, 512])
lr_frame: torch.Size([3, 3, 128, 128])


batch_idx: 1
hr_frame: torch.Size([3, 3, 512, 512])
lr_frame: torch.Size([3, 3, 128, 128])




In [3]:
Image.read("data/train/train_sharp/032/00000062.png")

AttributeError: module 'PIL.Image' has no attribute 'read'

In [5]:
img = torchvision.io.read_image("data/train/train_sharp/032/00000062.png")

In [6]:
img.shape

torch.Size([3, 720, 1280])

In [7]:
img2 = torchvision.io.read_image("data/train/train_sharp_bicubic/X4/032/00000062.png")

In [8]:
img2.shape

torch.Size([3, 180, 320])

In [9]:
180-128

52

In [12]:
lr_h_bound = 180 - 128
lr_w_bound = 320 - 128

In [15]:
x = random.randint(0, lr_h_bound)
y = random.randint(0, lr_w_bound)

In [17]:
img2[:, x:x+128, y:y+128]

torch.Size([3, 128, 128])

In [24]:
img[:, x*4:(x*4)+512, y*4:y*4+512].shape

torch.Size([3, 512, 512])

## Playground

In [5]:
# load model and scheduler
model_id = "stabilityai/stable-diffusion-x4-upscaler"
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_id)

Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 13.66it/s]


In [6]:
pipeline

StableDiffusionUpscalePipeline {
  "_class_name": "StableDiffusionUpscalePipeline",
  "_diffusers_version": "0.28.0.dev0",
  "_name_or_path": "stabilityai/stable-diffusion-x4-upscaler",
  "feature_extractor": [
    null,
    null
  ],
  "low_res_scheduler": [
    "diffusers",
    "DDPMScheduler"
  ],
  "max_noise_level": 350,
  "safety_checker": [
    null,
    null
  ],
  "scheduler": [
    "diffusers",
    "DDIMScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ],
  "watermarker": [
    null,
    null
  ]
}

In [7]:
tokenizer = AutoTokenizer.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler",
    subfolder="tokenizer"
)
scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler")
text_encoder_cls = CLIPTextModel
text_encoder_cls
text_encoder = text_encoder_cls.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", subfolder="text_encoder"
)

In [8]:
captions = [""]
inputs = tokenizer(
    captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
encoder_hidden_states = text_encoder(inputs.input_ids, return_dict=False)[0]
encoder_hidden_states.shape

torch.Size([1, 77, 1024])

In [9]:
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (1,))
timesteps = timesteps.long()
timesteps

tensor([410])

In [10]:
weight_dtype = torch.float32

In [11]:
inp = torch.zeros((1, 3, 512, 512)).to(weight_dtype)
inp.shape

torch.Size([1, 3, 512, 512])

In [12]:
vae_out = pipeline.vae.encode(inp).latent_dist.sample()
vae_out.shape

torch.Size([1, 4, 128, 128])

In [14]:
vae_out_ = torch.zeros((1, 7, 128, 128)).to(weight_dtype)
vae_out_.shape

torch.Size([1, 7, 128, 128])

In [15]:
unet_out = pipeline.unet(vae_out_, timesteps, encoder_hidden_states, class_labels = torch.zeros(1).to(torch.int))
unet_out.sample.shape

torch.Size([1, 4, 128, 128])

In [16]:
out = pipeline.vae.decode(unet_out.sample)
out.sample.shape

torch.Size([1, 3, 512, 512])

## Testing the Models

In [4]:
from transformers import AutoTokenizer, PretrainedConfig, CLIPTextModel, CLIPImageProcessor
import diffusers
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    DDPMScheduler,
    StableDiffusionControlNetPipeline, 
    StableDiffusionControlNetImg2ImgPipeline,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
)

In [33]:
pipeline = StableDiffusionControlNetImg2ImgPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    controlnet=controlnet,
    scheduler = noise_scheduler,
    safety_checker = None,
    feature_extractor = CLIPImageProcessor
)

You have disabled the safety checker for <class 'diffusers.pipelines.controlnet.pipeline_controlnet_img2img.StableDiffusionControlNetImg2ImgPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


In [16]:
weight_dtype = torch.float32

In [6]:
vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", subfolder="vae"
)

In [8]:
text_encoder_cls = CLIPTextModel
text_encoder_cls
text_encoder = text_encoder_cls.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", subfolder="text_encoder"
)

In [9]:
tokenizer = AutoTokenizer.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler",
    subfolder="tokenizer"
)

In [10]:
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", subfolder="unet"
)

In [None]:
low_res_scheduler=

In [11]:
scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler")

In [13]:
controlnet = ControlNetModel.from_unet(unet)

In [14]:
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder.requires_grad_(False)
controlnet.train()

ControlNetModel(
  (conv_in): Conv2d(7, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=256, out_features=1024, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (class_embedding): Embedding(1000, 1024)
  (controlnet_cond_embedding): ControlNetConditioningEmbedding(
    (conv_in): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (blocks): ModuleList(
      (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): Conv2d(32, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (4): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): Conv2d(96, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))

In [15]:
controlnet.enable_gradient_checkpointing()

In [19]:
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (1,))
timesteps = timesteps.long()
timesteps

tensor([262])

In [20]:
down_block_res_samples, mid_block_res_sample = controlnet(
    torch.zeros((1, 7, 32, 32)).to(dtype=weight_dtype),
    timesteps,
    encoder_hidden_states=encoder_hidden_states,
    controlnet_cond=torch.zeros((1, 3, 256, 256)).to(dtype=weight_dtype),
    return_dict=False,
    class_labels = torch.zeros(1).to(dtype=torch.int)
)

NameError: name 'encoder_hidden_states' is not defined

In [15]:
controlnet = ControlNetModel.from_unet(unet)
#controlnet.conv_in = torch.nn.Sequential(torch.nn.Conv2d(4, 7, kernel_size=1, padding=0, stride=1, bias=False), controlnet.conv_in)

In [40]:
unet(torch.zeros(1, 7, 32, 32).to(dtype=weight_dtype), )

TypeError: UNet2DConditionModel.forward() missing 2 required positional arguments: 'timestep' and 'encoder_hidden_states'

In [18]:
import bitsandbytes as bnb

In [19]:
optimizer_class = bnb.optim.AdamW8bit

In [20]:
params_to_optimize = controlnet.parameters()

In [21]:
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_weight_decay = 1e-2
adam_epsilon = 1e-08
optimizer = optimizer_class(
        params_to_optimize,
        lr=1e-5,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

In [23]:
url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
response = requests.get(url)
low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
low_res_img = low_res_img.resize((128, 128))

In [24]:
inp = torchvision.transforms.functional.pil_to_tensor(low_res_img)
inp = torch.unsqueeze(inp, 0)
inp.shape

NameError: name 'torchvision' is not defined

In [25]:
latents = vae.encode(torch.zeros((1, 3, 128, 128)).to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
latents.shape

torch.Size([1, 4, 32, 32])

In [26]:
out = vae.decode(latents)

In [27]:
out.sample.shape

torch.Size([1, 3, 128, 128])

In [28]:
out = vae.encode(torch.zeros((1, 3, 128, 128)).to(dtype=weight_dtype)).latent_dist.sample()
out.shape

torch.Size([1, 4, 32, 32])

In [29]:
outt = vae.decode(torch.zeros((1, 4, 32, 32)).to(dtype=weight_dtype)).sample
outt.shape

torch.Size([1, 3, 128, 128])

In [30]:
latents = vae.encode(inp.to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
latents.shape

NameError: name 'inp' is not defined

In [None]:
bsz = latents.shape[0]
bsz

In [None]:
noise = torch.randn_like(latents)

In [None]:
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
timesteps = timesteps.long()
timesteps

In [None]:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
noisy_latents.shape

In [32]:
tokenizer = AutoTokenizer.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler",
    subfolder="tokenizer"
)

In [None]:
captions = [""]
inputs = tokenizer(
    captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
encoder_hidden_states = text_encoder(inputs.input_ids, return_dict=False)[0]

In [None]:
controlnet_image = torch.zeros((1, 3, 32, 32)).to(dtype=weight_dtype)
controlnet_image.shape

In [None]:
controlnet

In [None]:
torch.zeros((1, 7, 32, 32)).to(dtype=weight_dtype)

In [None]:
sample = controlnet.conv_in(torch.zeros((1, 7, 32, 32)).to(dtype=weight_dtype)).shape

In [None]:
controlnet_cond = controlnet.controlnet_cond_embedding(torch.zeros((1, 3, 256, 256)).to(dtype=weight_dtype))
controlnet_cond.shape

In [None]:
noisy_latents.shape

In [None]:
down_block_res_samples, mid_block_res_sample = controlnet(
    torch.zeros((1, 7, 32, 32)).to(dtype=weight_dtype),
    timesteps,
    encoder_hidden_states=encoder_hidden_states,
    controlnet_cond=torch.zeros((1, 3, 256, 256)).to(dtype=weight_dtype),
    return_dict=False,
    class_labels = torch.zeros(1).to(dtype=torch.int)
)

In [None]:
model_pred = unet(
    torch.zeros((1, 7, 32, 32)).to(dtype=weight_dtype),
    timesteps,
    encoder_hidden_states=encoder_hidden_states,
    down_block_additional_residuals=[
        sample.to(dtype=weight_dtype) for sample in down_block_res_samples
    ],
    mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
    return_dict=False,
    class_labels = torch.zeros(1).to(dtype=torch.int)
)[0]

In [None]:
model_pred.shape

In [None]:
target = noise_scheduler.get_velocity(latents, noise, timesteps)

In [None]:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

In [None]:
loss.backward()

In [None]:
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=args.set_grads_to_none)