# Pre-compute the mid-U features

In [None]:
# Install the third-party packages imported by the cells below.
# `-qq` keeps pip's output quiet.
!pip install -qq transformers diffusers
!pip install -qq datasets torchvision
!pip install -qq wandb

In [None]:
from diffusers import StableDiffusionPipeline, DDPMScheduler

device = "cuda"
pretrained_model_name = "stabilityai/stable-diffusion-2-1-base"
# Load the full Stable Diffusion pipeline; the cells below only use `pipe.unet`.
pipe = StableDiffusionPipeline.from_pretrained(pretrained_model_name).to(device)
def hook_fn(module, input, output):
    """Forward hook that stashes the module's latest output on the module itself."""
    module.output = output
# After each forward pass of the UNet, the mid block's output can be read back
# as `pipe.unet.mid_block.output`.
pipe.unet.mid_block.register_forward_hook(hook_fn)

In [None]:
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
import torch
from diffusers import DDPMScheduler

device = "cuda"

pretrained_model_name = "stabilityai/stable-diffusion-2-1-base"
# Load each component from its own subfolder of the checkpoint.
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name, subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(pretrained_model_name, subfolder="vae")
tokenizer = CLIPTokenizer.from_pretrained(
        pretrained_model_name, subfolder="tokenizer"
    )

# Neither module is trained in this notebook, so turn off gradient tracking.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

weight_dtype = torch.float32
# The trailing semicolons stop the notebook from echoing the returned module.
text_encoder.to(device, dtype=weight_dtype);
vae.to(device, dtype=weight_dtype);

In [None]:
from datasets import load_dataset
from torchvision import transforms
from PIL import Image
import torch
import random
import numpy as np

im_dir = "data/laion-art"

# Preprocessing and loading settings for the feature-extraction pass.
resolution = 512
train_batch_size = 4
dataloader_num_workers = 8
lr = 1e-4

# Image pipeline: resize the short side to `resolution`, crop the centre
# square, randomly mirror, convert to a tensor and rescale with
# mean 0.5 / std 0.5.
train_transforms = transforms.Compose([
    transforms.Resize(
        size=resolution,
        interpolation=transforms.InterpolationMode.BILINEAR,
    ),
    transforms.CenterCrop(size=resolution),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

def tokenize_captions(examples, is_train=True):
    """Tokenize the ``'text'`` column of a batch of examples.

    Each entry may be a single string, or a list / ndarray of alternative
    captions. For alternatives, a random one is used when ``is_train`` is
    true and the first one otherwise.

    Returns the ``input_ids`` produced by the module-level ``tokenizer``,
    padded/truncated to ``tokenizer.model_max_length``.

    Raises:
        ValueError: if an entry is neither a string nor a list/ndarray.
    """
    def pick_caption(entry):
        # Resolve one column entry to the single caption that gets tokenized.
        if isinstance(entry, str):
            return entry
        if isinstance(entry, (list, np.ndarray)):
            return random.choice(entry) if is_train else entry[0]
        raise ValueError(
            "Caption column 'text' should contain either strings or lists of strings."
        )

    selected = [pick_caption(entry) for entry in examples['text']]
    encoded = tokenizer(
        selected,
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    return encoded.input_ids

def preprocess_train(examples):
    """Turn a raw batch into model inputs, in place.

    Adds ``pixel_values`` (transformed RGB images) and ``input_ids``
    (tokenized captions), and replaces ``aesthetic`` with a float tensor.
    Returns the same ``examples`` mapping.
    """
    rgb_images = [picture.convert("RGB") for picture in examples["image"]]
    examples["pixel_values"] = list(map(train_transforms, rgb_images))
    examples["input_ids"] = tokenize_captions(examples)
    scores = torch.tensor(examples["aesthetic"])
    examples["aesthetic"] = scores.float()
    return examples

# Load the dataset by its Hugging Face Hub id.
dataset = load_dataset("fantasyfish/laion-art")
# Attach preprocess_train as the transform for the train split.
train_dataset = dataset["train"].with_transform(preprocess_train)

def collate_fn(examples):
    """Merge a list of preprocessed examples into one batch dict.

    Returns a dict with stacked ``pixel_values`` (contiguous, float),
    stacked ``input_ids`` and a float tensor of ``aesthetics`` scores.
    """
    images = torch.stack([item["pixel_values"] for item in examples])
    batch = {
        "pixel_values": images.to(memory_format=torch.contiguous_format).float(),
        "input_ids": torch.stack([item["input_ids"] for item in examples]),
        "aesthetics": torch.FloatTensor([item["aesthetic"] for item in examples]),
    }
    return batch

# DataLoaders creation:
# shuffle=False keeps batches in dataset order, so the features saved by the
# extraction loop line up row-for-row with the dataset. The classifier
# training loop depends on this when it slices the saved features by batch index.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=False,
    collate_fn=collate_fn,
    batch_size=train_batch_size,
    num_workers=dataloader_num_workers,
)

In [None]:
# Sanity check: compare the pixel-space and latent-space shapes of one batch.
# Fetch a batch explicitly so this cell also works on a fresh top-to-bottom
# run (otherwise `batch` is only bound later, by the extraction loop).
batch = next(iter(train_dataloader))
print(batch["pixel_values"].shape)
with torch.no_grad():  # shape inspection only, no gradients needed
    latents = vae.encode(batch["pixel_values"].to(device, dtype=weight_dtype)).latent_dist.sample()
print(latents.shape)

In [None]:
# Run every training image through the VAE, add noise at a random timestep,
# run the UNet once, and collect the mid-block activations captured by the
# forward hook. The stacked activations are saved to "midU_features.pt".
num_inference_steps = 30
features_all = []

# This is pure feature extraction. The UNet's parameters still require grad,
# so without no_grad every stored activation would keep its whole autograd
# graph alive and GPU memory would grow with each batch.
with torch.no_grad():
    for batch in train_dataloader:
        latents = vae.encode(batch["pixel_values"].to(device, dtype=weight_dtype)).latent_dist.sample()
        latents = latents * vae.config.scaling_factor  # 0.18215

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)

        bsz = latents.shape[0]
        # Sample a random timestep for each image.
        # NOTE(review): this draws from [0, num_inference_steps) = [0, 30), not
        # from the scheduler's full training range — confirm that restricting
        # to these low-noise timesteps is intended.
        timesteps = torch.randint(0, num_inference_steps, (bsz,), device=latents.device)
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each
        # timestep (this is the forward diffusion process)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Get the text embedding for conditioning
        encoder_hidden_states = text_encoder(batch["input_ids"].long().to(device))[0]
        # The UNet output itself is discarded; the forward pass is run only so
        # the hook on the mid block records its activations.
        _ = pipe.unet(noisy_latents, timesteps.to(device), encoder_hidden_states).sample
        features = pipe.unet.mid_block.output
        # Move each batch off the GPU right away so only one batch of
        # activations is resident on the device at a time.
        features_all.append(features.detach().cpu())

features_all = torch.vstack(features_all)
torch.save(features_all, "midU_features.pt")

# Train the aesthetic classifier

In [2]:
from datasets import load_dataset
import torch

batch_size = 64
dataloader_num_workers = 16
dataset = load_dataset("fantasyfish/laion-art")
# Only the aesthetic scores are read from the dataset in this part; the image
# information comes from the precomputed feature files loaded further down.
train_dataset = dataset["train"].remove_columns(["image", "text"])
test_dataset = dataset["test"].remove_columns(["image", "text"])

def collate_fn(examples):
    """Batch the per-example aesthetic scores into a single float tensor."""
    scores = [row["aesthetic"] for row in examples]
    return {"aesthetics": torch.FloatTensor(scores)}
# Both loaders keep shuffle=False: the training loop pairs batch `i` with the
# slice [i*batch_size:(i+1)*batch_size] of the precomputed feature tensors,
# which is only valid while batches arrive in dataset order.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=False,
    collate_fn=collate_fn,
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    shuffle=False,
    collate_fn=collate_fn,
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
)



  0%|          | 0/2 [00:00<?, ?it/s]



In [5]:
from google.colab import drive
# Mount Google Drive; the feature files are read from under this mount point.
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import os
mount_dir = '/content/drive/MyDrive'
# Load the precomputed features onto the CPU; the training loop moves one
# batch-sized slice at a time to the training device.
midU_features = torch.load(os.path.join(mount_dir, "midU_features.pt"), map_location=torch.device("cpu"))
test_midU_features = torch.load(os.path.join(mount_dir, "test_midU_features.pt"), map_location=torch.device("cpu"))

In [4]:
import torch.nn.functional as F

# Range of aesthetic scores in each split, printed so the train and test
# ranges can be compared.
min_val, max_val = min(train_dataset['aesthetic']), max(train_dataset['aesthetic'])
print(min_val, max_val)
print(min(test_dataset['aesthetic']), max(test_dataset['aesthetic']))
def encode_onehot(labels, min_val, max_val, num_classes=10):
    """Bin continuous scores into equal-width classes and one-hot encode them.

    Args:
        labels: tensor of scores to encode.
        min_val: score at the lower edge of the first bin.
        max_val: score that falls into the last bin.
        num_classes: number of bins; defaults to 10, the value that used to
            be hard-coded.

    Returns:
        Float tensor with one extra trailing dimension of size ``num_classes``.
        Scores below ``min_val`` go to the first bin and scores above
        ``max_val`` to the last.
    """
    # The small epsilon in the denominator keeps labels == max_val strictly
    # below `num_classes` before flooring.
    scaled = (labels - min_val) / (max_val - min_val + 1e-4) * float(num_classes)
    # Clamp so out-of-range scores (e.g. another split encoded with this
    # split's range) land in the edge bins rather than producing an index
    # that F.one_hot rejects.
    index = torch.floor(scaled).long().clamp(0, num_classes - 1)
    return F.one_hot(index, num_classes=num_classes).float()

# batch = next(iter(train_dataloader))
# encode_onehot(batch['aesthetics'], min_val, max_val)[:10]

8.000041961669922 10.2408447265625
8.00003433227539 10.080093383789062


In [5]:
from torch import nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler

device = "cuda"
lr = 1e-4

# Small convolutional classifier head: 1280 input channels in, 10 logits out
# (one per score bin produced by encode_onehot).
# NOTE(review): 1280 is presumably the channel count of the saved mid-block
# features — confirm against the feature tensor's shape.
model = nn.Sequential(
    nn.Conv2d(1280, 256, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(256, 128, kernel_size=3, padding=1),
    nn.ReLU(),
    # Pool to a fixed 2x2 grid so the flattened size is 128 * 4.
    nn.AdaptiveAvgPool2d(output_size=(2, 2)),
    nn.Flatten(),
    nn.Linear(128*4, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=lr,
    weight_decay=1e-2,
)

In [6]:
%env WANDB_NOTEBOOK_NAME prompt_engineer_exploration.ipynb
import wandb
wandb.login()

wandb.init(project="aesthetic-classifier",
           config={
               "batch_size": batch_size,
               "learning_rate": lr,
               "dataset": "liason-art",
           })

env: WANDB_NOTEBOOK_NAME=prompt_engineer.ipynb


[34m[1mwandb[0m: Currently logged in as: [33mfantasy-fish[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
%%wandb

# Train the classifier on the precomputed mid-block features and validate
# after every epoch. Batches come from the unshuffled dataloaders, so batch
# `i` is paired with rows [i*batch_size, (i+1)*batch_size) of the feature tensor.
global_step = 0

n_epochs = 100
n_epochs_decay_start = 50
n_epochs_save_checkpoint = 5
# Linear decay from lr to 0.1*lr, spread over the epochs that follow
# n_epochs_decay_start. The scheduler is stepped once per epoch (see below),
# so total_iters must equal the number of decay epochs; the default of 5
# would finish the decay almost immediately.
scheduler = lr_scheduler.LinearLR(
    optimizer,
    start_factor=1.0,
    end_factor=0.1,
    total_iters=n_epochs - n_epochs_decay_start,
)
# A single score range is used for binning both splits so that class k means
# the same score interval in training and in validation.
train_min_val, train_max_val = min(train_dataset['aesthetic']), max(train_dataset['aesthetic'])

for epoch in range(n_epochs):
    for batch_id, batch in enumerate(train_dataloader):
        optimizer.zero_grad()  # clear gradients left over from the previous step

        start = batch_id * batch_size
        features = midU_features[start:start + batch_size].to(device)
        logits = model(features)
        labels = encode_onehot(batch['aesthetics'], train_min_val, train_max_val).to(device)
        loss = criterion(logits, labels)

        loss.backward()  # compute gradients for each parameter
        optimizer.step()  # apply the update to each parameter

        # Count the samples actually seen (the last batch may be smaller).
        global_step += labels.shape[0]
        # criterion already averages over the batch, so log the loss as-is.
        wandb.log({"step_loss": loss.item(), "lr": scheduler.get_last_lr()[0]})

    # Step the scheduler *instance* once per epoch after the decay start.
    # (`lr_scheduler` is the torch module and has no `step`.)
    if epoch >= n_epochs_decay_start:
        scheduler.step()

    # run validation
    model.eval()
    test_loss = 0.0
    n_seen = 0
    with torch.no_grad():
        for batch_id, batch in enumerate(test_dataloader):
            start = batch_id * batch_size
            features = test_midU_features[start:start + batch_size].to(device)
            # Clamp test scores into the training range before binning, so
            # scores just outside it fall into the edge bins.
            scores = batch['aesthetics'].clamp(min=train_min_val, max=train_max_val)
            labels = encode_onehot(scores, train_min_val, train_max_val).to(device)
            loss = criterion(model(features), labels)
            # Weight each batch mean by its size to get a per-sample average.
            test_loss += loss.item() * labels.shape[0]
            n_seen += labels.shape[0]
    wandb.log({"test_loss": test_loss / max(n_seen, 1)})
    model.train()

    # Periodic checkpoints, plus one after the final epoch so the fully
    # trained weights are not lost.
    is_periodic = epoch > 0 and epoch % n_epochs_save_checkpoint == 0
    if is_periodic or epoch == n_epochs - 1:
        torch.save(model.state_dict(), f"model_{epoch}.pt")