# Setup

In [None]:
# IPython shell escape: installs diffusers/accelerate/safetensors.
# NOTE(review): torch, torchvision, transformers, pandas, matplotlib and tqdm are
# imported later but not installed here — presumably preinstalled (e.g. Colab); confirm.
! pip install diffusers accelerate safetensors

# Stable diffusion pipeline

In [None]:
import torch
from transformers import CLIPModel, CLIPProcessor
from diffusers import StableDiffusionPipeline, DDPMScheduler

# Prefer the GPU when one is visible.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

# Hugging Face Hub repositories used by the rest of the notebook.
CLIP_REPO = "openai/clip-vit-large-patch14"
SD_REPO = "runwayml/stable-diffusion-v1-5"

# CLIP model and processor (image embeddings for the custom embedders and evaluation).
clip = CLIPModel.from_pretrained(CLIP_REPO)
clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO)

# Stable Diffusion pipeline plus a DDPM scheduler from the same repository;
# the training loops use the scheduler to noise latents.
pipe = StableDiffusionPipeline.from_pretrained(SD_REPO)
scheduler = DDPMScheduler.from_pretrained(SD_REPO, subfolder="scheduler")
# Disable the pipeline's safety checker.
pipe.safety_checker = None

# Textual inversion
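Classi e funzioni necessarie per il training della Textual Inversion classica (dataset e forward personalizzato del text encoder); il training loop si trova nella cella successiva.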

In [None]:
import os
import random
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from types import MethodType
from torch.utils.data import Dataset as TorchDataset

# Caption templates for style textual inversion. `{}` is filled with the
# artist's placeholder token ("<artist-dir-name>", built in Dataset) by
# ArtistDataset.__getitem__.
# NOTE: "a cropped painting ..." and "a close-up painting ..." appear twice;
# random.choice samples uniformly over entries, so they are drawn twice as often.
# NOTE(review): ArtistDataset reads this global at construction time, so any
# later cell that rebinds `prompts` changes the templates of datasets built
# afterwards.
prompts = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}"
]

class ArtistDataset(TorchDataset):
  """Images of a single artist, each paired with a randomly templated caption.

  Every image appears `repeats` times per pass. Items are dicts with
  "pixel_values" (fp16 tensor rescaled to [-1, 1]) and, when a tokenizer is
  given, "input_ids". With `clip_output=True` a second dict holding a
  CLIP-normalised 224px "pixel_values" is returned alongside.

  Args:
    data_root: directory containing the artist's images.
    tokenizer: caption tokenizer, or None to skip captions.
    placeholder_token: string substituted into the caption templates.
    clip_output: also return a CLIP-preprocessed copy of the image.
    size: side length images are resized to (aspect ratio is not kept).
    repeats: number of times each image appears in one pass.
    flip_p: probability of a random horizontal flip.
    device: device the returned tensors are moved to.
  """
  def __init__(self, data_root, tokenizer, placeholder_token, clip_output=False,
               size=512, repeats=100, flip_p=0.5, device="cpu"):
    self.data_root = data_root
    self.tokenizer = tokenizer
    self.size = size
    self.placeholder_token = placeholder_token
    self.device = device
    self.flip_p = flip_p
    self.clip_output = clip_output

    # Sorted for a filesystem-independent, reproducible order (the MoE embedder
    # concatenates per-image outputs, so order matters). Hidden entries such as
    # ".ipynb_checkpoints" and sub-directories are skipped: Image.open cannot read them.
    self.image_paths = [
        os.path.join(self.data_root, name)
        for name in sorted(os.listdir(self.data_root))
        if not name.startswith(".") and os.path.isfile(os.path.join(self.data_root, name))
    ]

    self.num_images = len(self.image_paths)
    self._length = self.num_images * repeats

    self.templates = prompts
    self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
    # Kept for reference; __getitem__ uses a plain square resize instead.
    self.resize = transforms.Resize(self.size)
    self.crop = transforms.CenterCrop(self.size)

    # CLIP preprocessing: 224px input, normalised with CLIP's per-channel mean/std.
    self.clip_resize = transforms.Resize(224)
    self.norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])
    self.norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])

  def get_artworks_for_eval(self):
    """Return every image of the artist as a PIL image, with its file closed."""
    artworks = []
    for path in self.image_paths:
      with Image.open(path) as img:
        # copy() loads the pixel data, so the file handle can be released
        artworks.append(img.copy())
    return artworks

  def __len__(self):
    return self._length

  def __getitem__(self, i):
    example = {}
    with Image.open(self.image_paths[i % self.num_images]) as img:
      # convert() returns a loaded copy (even for RGB input), so the file can be closed
      image = img.convert("RGB")

    if self.tokenizer:
      text = random.choice(self.templates).format(self.placeholder_token)
      tokenized = self.tokenizer(text,
                                 padding="max_length",
                                 truncation=True,
                                 max_length=self.tokenizer.model_max_length,
                                 return_tensors="pt")
      example["input_ids"] = tokenized.input_ids[0].to(self.device)

    # Square resize (aspect ratio not preserved), random flip, then a [0, 1] CHW tensor
    image = image.resize((self.size, self.size))
    image = self.flip_transform(image)
    image = transforms.functional.to_tensor(image)

    if self.clip_output:
      clip_image = self.clip_resize(image)
      clip_image = (clip_image - self.norm_mean[:, None, None]) / self.norm_std[:, None, None]
      clip_input = {"pixel_values": clip_image.to(self.device, dtype=torch.float16)}

    # Rescale [0, 1] -> [-1, 1] for the VAE encoder used in the training loops
    example["pixel_values"] = (2 * image - 1).to(self.device, dtype=torch.float16)

    if self.clip_output:
      return example, clip_input
    return example

class Dataset:
  """Per-artist datasets, one visible sub-directory of `data_root` per artist.

  Artists are sorted by directory name; the training cells add the
  placeholder tokens "<artist>" to the tokenizer in this same order.
  """
  def __init__(self, data_root, tokenizer=None, device="cpu", clip_output=False):
    self.data_root = data_root
    # Only visible directories are artists: a stray file would make
    # ArtistDataset's os.listdir fail, and hidden folders (e.g.
    # ".ipynb_checkpoints") would become bogus artists with their own token.
    self.artist_list = sorted(
        name for name in os.listdir(self.data_root)
        if not name.startswith(".") and os.path.isdir(os.path.join(self.data_root, name))
    )
    self.artist_iterables = {
        name: ArtistDataset(os.path.join(self.data_root, name), tokenizer, f"<{name}>",
                            clip_output=clip_output, device=device)
        for name in self.artist_list
    }

  def get_artists(self):
    """Return the sorted artist names."""
    return self.artist_list

  def get_artists_num(self):
    """Return the number of artists."""
    return len(self.artist_list)

  def __getitem__(self, artist):
    """Return the ArtistDataset for `artist`; raise KeyError if it is unknown."""
    if artist not in self.artist_iterables:
      raise KeyError(f"{artist} is not present in the dataset")
    return self.artist_iterables[artist]


def forward(self, input_ids = None, position_ids = None, inputs_embeds = None) -> torch.Tensor:
  """Replacement forward for the CLIP text embeddings module (bound via MethodType).

  Ids >= the original vocabulary size (`token_embedding.num_embeddings`) are
  placeholder tokens: they are embedded with `self.concept_embedding` at row
  `id - vocab_size`. All other ids use `self.token_embedding`. Position
  embeddings are then added.

  Args:
    input_ids: token ids; not modified.
    position_ids: optional; defaults to `self.position_ids[:, :seq_length]`.
    inputs_embeds: precomputed token embeddings; when given no lookup is done.

  Returns:
    Token embeddings plus position embeddings.
  """
  vocab_size = self.token_embedding.num_embeddings
  seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

  if position_ids is None:
    position_ids = self.position_ids[:, :seq_length]

  if inputs_embeds is None:
    concept_pos = input_ids >= vocab_size
    # Look placeholder positions up with a dummy in-vocabulary id (0) instead of
    # shifting the caller's tensor in place; those rows are overwritten below.
    inputs_embeds = self.token_embedding(input_ids.masked_fill(concept_pos, 0))
    if concept_pos.any():
      concept_embeds = self.concept_embedding(input_ids[concept_pos] - vocab_size)
      # Index assignment requires matching dtypes (e.g. fp16 text encoder, fp32 concepts)
      inputs_embeds[concept_pos] = concept_embeds.to(inputs_embeds.dtype)

  position_embeddings = self.position_embedding(position_ids)
  return inputs_embeds + position_embeddings

## Training loop

In [None]:
import torch
import torch.nn as nn
from safetensors.torch import save_file
from math import ceil
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.nn.functional import mse_loss
from accelerate import Accelerator
from tqdm import tqdm

# NOTE: despite its name, MAX_EPOCHS is used as an approximate number of
# optimisation steps per artist (epochs = ceil(MAX_EPOCHS / batches per epoch)).
MAX_EPOCHS = 1000
LR = 0.005
BATCH_SIZE = 4

# Custom forward function: placeholder ids are embedded with `concept_embedding`
obj = pipe.text_encoder.text_model.embeddings
obj.forward = MethodType(forward, obj)

# Dataset
dataset = Dataset('train/', pipe.tokenizer, device)

# unet and vae are converted to fp16 while the text_encoder keeps fp32
accelerator = Accelerator(mixed_precision='fp16')
vae = pipe.vae.to(accelerator.device, dtype=torch.float16)
unet = pipe.unet.to(accelerator.device, dtype=torch.float16)
text_encoder = pipe.text_encoder

# Add new tokens for each artist and their embeddings (one row per artist)
pipe.tokenizer.add_tokens(list(map(lambda x: f'<{x}>', dataset.get_artists())))
text_encoder.text_model.embeddings.concept_embedding = nn.Embedding(dataset.get_artists_num(), 768, device=text_encoder.device)

# Freeze parameters
for m in [vae, unet, text_encoder]:
  for param in m.parameters():
    param.requires_grad = False

# Unlock only concept embeddings
text_encoder.text_model.embeddings.concept_embedding.weight.requires_grad = True
text_encoder.train()
text_encoder = accelerator.prepare(text_encoder)

# Trainable table, and the original vocabulary size used by the custom forward
# to map a placeholder id to its row (id - vocab_size)
concept_weight = text_encoder.text_model.embeddings.concept_embedding.weight
vocab_size = text_encoder.text_model.embeddings.token_embedding.num_embeddings

random.seed(4316)
generator = torch.Generator()
generator.manual_seed(4316)
# Train
for artist in dataset.get_artists():
  concept_idx = pipe.tokenizer.convert_tokens_to_ids(f'<{artist}>') - vocab_size

  # Every row except this artist's is restored after each step. Otherwise
  # AdamW's decoupled weight decay (applied to the whole table every step) and
  # leftover momentum keep changing embeddings of already-trained artists.
  frozen_rows = torch.ones(concept_weight.shape[0], dtype=torch.bool, device=concept_weight.device)
  frozen_rows[concept_idx] = False
  frozen_values = concept_weight.detach().clone()

  # Fresh optimizer per artist so no Adam state carries over between artists
  opt = accelerator.prepare(AdamW([concept_weight], LR))

  data_loader = DataLoader(dataset[artist], BATCH_SIZE, True, generator=generator)
  data_loader = accelerator.prepare(data_loader)
  num_batches = len(data_loader)
  num_epochs = ceil(MAX_EPOCHS / num_batches)

  pbar = tqdm(range(num_epochs), f"{artist}")
  for epoch in pbar:
    for i, batch in enumerate(data_loader):
      # Get image latents
      with torch.no_grad():
        latents = vae.encode(batch['pixel_values']).latent_dist.sample(generator=generator) * pipe.vae.config.scaling_factor
      bsz = latents.shape[0]

      # Sample noise
      noise = torch.randn_like(latents)

      # Sample timesteps
      timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
      timesteps = timesteps.long()

      # Add noise
      noisy_latents = scheduler.add_noise(latents, noise, timesteps)

      # Text encoding
      text_embeddings = text_encoder(batch['input_ids'])[0].to(latents.dtype)

      # Model prediction: the loss compares it with the added noise
      pred = unet(noisy_latents, timesteps, text_embeddings).sample

      loss = mse_loss(pred.float(), noise.float())
      accelerator.backward(loss)
      opt.step()
      opt.zero_grad()
      with torch.no_grad():
        concept_weight[frozen_rows] = frozen_values[frozen_rows]
      # A single call: a second set_postfix would replace the first, hiding the loss
      pbar.set_postfix({'loss': loss.item(), 'batch': f"{i}/{num_batches}"})

  # Save style embeddings (float32); the file is rewritten after each artist
  print("Saving embeddings...")
  save_file(
      {'concept_embedding': concept_weight.detach()},
      'drive/MyDrive/styles.safetensors'
  )

text_encoder.eval()
# After training convert text_encoder's weights to float16
text_encoder.to(dtype=torch.float16)

# Textual Inversion - Custom
Training delle varianti di Textual Inversion, in cui l'embedding dello stile è generato a partire dagli embedding CLIP delle opere dell'artista:
- Textual Inversion con Attention layers
- Textual Inversion con Mixture of Experts

In [None]:
import torch
import torch.nn as nn
from types import MethodType

class AttentionEmbedder(nn.Module):
  """Pool a set of vectors into one with two stacked multi-head attention layers.

  A learned CLS row is prepended to the input. The first layer self-attends
  over that sequence; the second uses the first layer's output as queries
  against the original sequence. The CLS output goes through GELU and a Linear.
  """
  def __init__(self, embed_dim, num_heads):
    super().__init__()
    self.cls = nn.Parameter(torch.randn((1, embed_dim)))

    def make_attention():
      # No projection biases, 0.1 dropout on the attention weights
      return nn.MultiheadAttention(embed_dim, num_heads, dropout=0.1, bias=False)

    self.attn_1 = make_attention()
    self.attn_2 = make_attention()
    self.mlp = nn.Linear(embed_dim, embed_dim)
    self.act = nn.GELU()

  def forward(self, x, need_weights=False):
    """Pool `x` of shape (n, embed_dim) into one embed_dim vector.

    Returns (embedding, weights_1, weights_2). The attention weights are None
    unless `need_weights` is True.
    """
    tokens = torch.cat([self.cls, x], dim=0)
    hidden, weights_1 = self.attn_1(tokens, tokens, tokens, need_weights=need_weights)
    attended, weights_2 = self.attn_2(hidden, tokens, tokens, need_weights=need_weights)
    pooled = self.act(attended[0])
    return self.mlp(pooled), weights_1, weights_2

class Expert(nn.Module):
  """Feed-forward expert: Linear(d, 2d) -> GELU -> Dropout(0.1) -> Linear(2d, d), no biases."""
  def __init__(self, in_features):
    super().__init__()
    self.l1 = nn.Linear(in_features, in_features * 2, bias=False)
    self.l2 = nn.Linear(in_features * 2, in_features, bias=False)
    self.dropout = nn.Dropout(p=0.1)
    self.act = nn.GELU()
  def forward(self, x):
    return self.l2(self.dropout(self.act(self.l1(x))))

class MoE(nn.Module):
  """Noisy top-k mixture of experts that fuses a fixed number of vectors into one.

  Each of the `n_inputs` rows is sent to its `top_k` highest-gated experts.
  Gating noise is applied only in training mode. The gate-weighted expert
  outputs are concatenated across rows and compressed back to `in_features`.

  Args:
    in_features: feature dimension d.
    n_experts: number of experts.
    top_k: experts used per row.
    n_inputs: number of rows `forward` expects (the compress layer takes
      n_inputs * d features). Defaults to 4, the previously hard-coded value.
  """
  def __init__(self, in_features, n_experts=4, top_k=2, n_inputs=4):
    super().__init__()
    self.experts = nn.ModuleList([Expert(in_features) for _ in range(n_experts)])
    self.n_experts = n_experts
    self.top_k = top_k
    self.n_inputs = n_inputs
    # Router
    self.noise_std = 1 / n_experts
    self.router = nn.Linear(in_features, n_experts, bias=False)
    self.softmax = nn.Softmax(dim=1)
    self.dist = torch.distributions.Normal(0, self.noise_std)
    # Compress
    self.dropout = nn.Dropout(p=0.15)
    self.compress = nn.Linear(in_features * n_inputs, in_features)

  def forward(self, x, coef=False):
    """Fuse the rows of `x` (shape (n_inputs, d)) into one d-vector.

    Args:
      x: input rows.
      coef: if True, return the sparse gate matrix instead of the auxiliary loss.

    Returns:
      (result, aux): `result` has shape (d,). `aux` is the auxiliary loss
      0.5 * (importance + load) or, with `coef=True`, the (n_inputs, n_experts) gates.

    Raises:
      ValueError: if `x` does not have `n_inputs` rows.
    """
    bsz, dim = x.shape
    if bsz != self.n_inputs:
      raise ValueError(f"expected {self.n_inputs} input rows, got {bsz}")

    G_x = self.router(x)

    # Importance loss: squared coefficient of variation of per-expert gate mass
    imp_loss = self.softmax(G_x).sum(0)
    imp_loss = (imp_loss.std() / imp_loss.mean()).pow(2)

    # Gating noise is a training-time exploration aid; eval uses clean logits
    if self.training:
      G_x_noised = G_x + (torch.randn_like(G_x) * self.noise_std)
    else:
      G_x_noised = G_x

    # Load loss: per expert, sum over rows of P(G_x + noise > k-th largest noisy logit)
    threshold_k = G_x_noised.kthvalue(self.n_experts - self.top_k + 1, dim=1, keepdim=True)
    load_loss = (1 - self.dist.cdf((threshold_k.values - G_x))).sum(0)
    load_loss = (load_loss.std() / load_loss.mean()).pow(2)

    # Keep only the top-k softmax gates of each row, zero elsewhere
    gates = self.softmax(G_x_noised)
    topk = torch.topk(gates, self.top_k, dim=1)
    gates = torch.zeros_like(gates).scatter(1, topk.indices, topk.values)

    # y -> (n_inputs, n_experts, dim); only routed rows are evaluated per expert
    y = torch.zeros((bsz, self.n_experts, dim), dtype=x.dtype, device=x.device)
    for i, expert in enumerate(self.experts):
      routed = gates[:, i] != 0
      if routed.any():
        y[routed, i, :] = expert(x[routed])

    result = torch.einsum('ijk,ij->ik', y, gates)
    result = self.compress(self.dropout(result.reshape(-1)))

    if coef:
      return result, gates
    return result, 0.5 * (imp_loss + load_loss)

def forward(self, input_ids = None, position_ids = None, inputs_embeds = None) -> torch.Tensor:
  """Replacement forward for the CLIP text embeddings module (bound via MethodType).

  Ids >= the original vocabulary size (`token_embedding.num_embeddings`) are
  placeholder tokens; their embeddings are taken from `self.vision_input`,
  which the training loop sets before each call. Other ids use
  `self.token_embedding`. Position embeddings are then added.

  Args:
    input_ids: token ids; not modified.
    position_ids: optional; defaults to `self.position_ids[:, :seq_length]`.
    inputs_embeds: precomputed token embeddings; when given no lookup is done.

  Returns:
    Token embeddings plus position embeddings.
  """
  vocab_size = self.token_embedding.num_embeddings
  seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

  if position_ids is None:
    position_ids = self.position_ids[:, :seq_length]

  if inputs_embeds is None:
    concept_pos = input_ids >= vocab_size
    # Dummy in-vocabulary id (0) for placeholder positions instead of shifting the
    # caller's tensor in place; those rows are overwritten below.
    inputs_embeds = self.token_embedding(input_ids.masked_fill(concept_pos, 0))
    if concept_pos.any():
      # Expected: one `vision_input` row per placeholder occurrence, in row-major
      # order (or a single broadcastable row). Cast because index assignment
      # requires matching dtypes (e.g. fp16 text encoder, fp32 embedder output).
      inputs_embeds[concept_pos] = self.vision_input.to(inputs_embeds.dtype)

  position_embeddings = self.position_embedding(position_ids)
  return inputs_embeds + position_embeddings

## Training loop

In [None]:
import torch
import torch.nn as nn
from safetensors.torch import save_file
from math import ceil
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.nn.functional import mse_loss
from accelerate import Accelerator
from tqdm import tqdm

# NOTE: despite its name, MAX_EPOCHS is used as an approximate number of
# optimisation steps per artist (epochs = ceil(MAX_EPOCHS / batches per epoch)).
MAX_EPOCHS = 1000
LR = 0.001
BATCH_SIZE = 4
MODEL = 'moe'     # moe or attention

# Custom forward function: placeholder tokens take `embeddings.vision_input`
obj = pipe.text_encoder.text_model.embeddings
obj.forward = MethodType(forward, obj)

# Dataset
dataset = Dataset('train/', pipe.tokenizer, device, clip_output=True)

# unet and vae are converted to fp16 while the text_encoder keeps fp32
accelerator = Accelerator(mixed_precision='fp16')
vae = pipe.vae.to(accelerator.device, dtype=torch.float16)
unet = pipe.unet.to(accelerator.device, dtype=torch.float16)
vision_model = clip.vision_model.to(accelerator.device)
visual_projection = clip.visual_projection.to(accelerator.device)
text_encoder = pipe.text_encoder

# Add new tokens for each artist; row j of concept_embeddings stores artist j's embedding
pipe.tokenizer.add_tokens(list(map(lambda x: f'<{x}>', dataset.get_artists())))
concept_embeddings = torch.zeros((dataset.get_artists_num(), 768), device=accelerator.device)

# Freeze parameters
for m in [vision_model, visual_projection, vae, unet, text_encoder]:
  for param in m.parameters():
    param.requires_grad = False

text_encoder.train()
text_encoder = accelerator.prepare(text_encoder)
# Train
for j, artist in enumerate(dataset.get_artists()):
  # Reproducibility
  torch.manual_seed(4316)
  random.seed(4316)
  if MODEL == 'moe':
    embedder = MoE(768)
  else:
    embedder = AttentionEmbedder(768, 8)
  embedder.to(accelerator.device)
  embedder.train()
  opt = AdamW(embedder.parameters(), LR)

  data_loader = DataLoader(dataset[artist], BATCH_SIZE, True)
  num_batches = len(data_loader)
  num_epochs = ceil(MAX_EPOCHS / num_batches)

  # CLIP image embeddings of all the artist's artworks: the fixed embedder input
  with torch.no_grad():
    clip_emb = clip_processor(images=dataset[artist].get_artworks_for_eval(), return_tensors='pt')
    clip_emb = visual_projection(vision_model(clip_emb['pixel_values'].to(device)).pooler_output)

  data_loader, opt = accelerator.prepare(data_loader, opt)
  pbar = tqdm(range(num_epochs), f"{artist}")
  for epoch in pbar:
    # clip_input (per-batch CLIP pixels) is currently unused
    for i, (batch, clip_input) in enumerate(data_loader):
      # Get image latents
      with torch.no_grad():
        latents = vae.encode(batch['pixel_values']).latent_dist.sample() * pipe.vae.config.scaling_factor
      bsz = latents.shape[0]

      # Sample noise
      noise = torch.randn_like(latents)

      # Sample timesteps
      timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
      timesteps = timesteps.long()

      # Add noise
      noisy_latents = scheduler.add_noise(latents, noise, timesteps)

      # Text encoding: the style embedding is injected through the custom forward,
      # one copy per caption (each template has a single placeholder)
      embedder_out = embedder(clip_emb)
      token_emb = embedder_out[0]
      text_encoder.text_model.embeddings.vision_input = token_emb.repeat(bsz, 1)
      text_embeddings = text_encoder(batch['input_ids'])[0].to(latents.dtype)

      # Model prediction
      pred = unet(noisy_latents, timesteps, text_embeddings).sample

      loss = mse_loss(pred.float(), noise.float())
      if MODEL == 'moe':
        # Auxiliary load-balancing loss returned by MoE
        loss += 0.01 * embedder_out[1]
      accelerator.backward(loss)
      opt.step()
      opt.zero_grad()
      # A single call: a second set_postfix would replace the first, hiding the loss
      pbar.set_postfix({'loss': loss.item(), 'batch': f"{i}/{num_batches}"})

  # Generate the final style embedding with dropout disabled (eval mode) and
  # without autograd history in the saved tensor
  embedder.eval()
  with torch.no_grad():
    concept_embeddings[j, :] = embedder(clip_emb)[0]

  # Save style embeddings (float32); the file name follows the embedder type
  print("Saving embeddings...")
  save_file(
      {'concept_embedding': concept_embeddings},
      f'drive/MyDrive/styles_{MODEL}.safetensors'
  )

text_encoder.eval()
embedder.eval()
# After training convert text_encoder's weights to float16
text_encoder.to(dtype=torch.float16)

# CLIP evaluation
Valutazione quantitativa dei risultati

Caricamento degli embeddings salvati in *styles.safetensors*

In [None]:
from safetensors import safe_open

# This should be the same one used during training
# This should be the same one used during training: row i of the saved
# embeddings belongs to the i-th artist in dataset.get_artists() order.
dataset = Dataset('train/', None, device)

placeholder_tokens = list(map(lambda x: f'<{x}>', dataset.get_artists()))
pipe.tokenizer.add_tokens(placeholder_tokens)
# Size the table from the tokenizer and address rows by token id. This also
# works when the tokens already exist (re-run, or after training in the same
# session), where add_tokens adds nothing and slicing the last N rows with
# N = 0 would address the whole vocabulary.
pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
placeholder_ids = pipe.tokenizer.convert_tokens_to_ids(placeholder_tokens)

# NOTE(review): the training cells save to 'drive/MyDrive/styles*.safetensors';
# this reads 'styles.safetensors' from the working directory — confirm it was copied.
with safe_open("styles.safetensors", framework="pt", device=device) as f:
  style_emb = f.get_tensor('concept_embedding')

# Fetched after the resize, which replaces the embedding module
token_embedding = pipe.text_encoder.text_model.embeddings.token_embedding
token_embedding.requires_grad_(False)
with torch.no_grad():
  weight = token_embedding.weight
  # Match the table's device/dtype (the text encoder may already be fp16 or on CPU)
  weight[placeholder_ids] = style_emb.to(weight.device, weight.dtype)
pipe.to(device, dtype=torch.float16)

Evaluation loop

In [None]:
import torch
import pandas as pd
from tqdm import tqdm

eval_prompts = [
    'A portrait of a large family in the style of {}',
    'A close up of a young man in the style of {}',
    'A naturalistic landscape in the style of {}',
    'A landscape of a city in the style of {}'
]

dataset = Dataset('train/', None, device)
# Keep CLIP and all of its inputs on `device`: the custom training cell moves
# only clip.vision_model, which would otherwise leave CLIP split across devices.
clip.to(device)
# CLIP text inputs for the editability score: the prompts without the style
# clause. Stored under its own name so the global `prompts` (caption templates
# read by ArtistDataset) is not overwritten.
eval_prompt_inputs = clip_processor(
    text=[p.replace('in the style of {}', '').strip() for p in eval_prompts],
    return_tensors='pt', padding=True).to(device)
# Initialize result tables: one row per artist, one column per prompt
acc = {'Artist': []}
acc.update({f'prompt_{i + 1}': [] for i in range(len(eval_prompts))})
edi = {'Artist': []}
edi.update({f'prompt_{i + 1}': [] for i in range(len(eval_prompts))})

with torch.no_grad():
  prompts_emb = clip.get_text_features(**eval_prompt_inputs)
  prompts_emb = prompts_emb / prompts_emb.norm(p=2, dim=-1, keepdim=True)

for artist_name in tqdm(dataset.get_artists()):
  acc['Artist'].append(artist_name)
  edi['Artist'].append(artist_name)

  artist = dataset[artist_name]
  artworks = artist.get_artworks_for_eval()
  artworks = clip_processor(images=artworks, return_tensors='pt').to(device)
  with torch.no_grad():
    artworks_emb = clip.get_image_features(**artworks)
    artworks_emb = artworks_emb / artworks_emb.norm(p=2, dim=-1, keepdim=True)

  # Same seed for every artist, so all artists see the same initial noise
  generator = torch.Generator(device)
  generator.manual_seed(4316)
  for i, prompt in enumerate(eval_prompts):
    # Baseline variant: format with the plain artist name instead of the token,
    # e.g. ' '.join(x.capitalize() for x in artist_name.split('-'))
    sd_out = pipe(prompt.format(artist.placeholder_token),
                  num_images_per_prompt=8,
                  generator=generator).images
    sd_out = clip_processor(images=sd_out, return_tensors='pt').to(device)
    with torch.no_grad():
      sd_out_emb = clip.get_image_features(**sd_out)
      sd_out_emb = sd_out_emb / sd_out_emb.norm(p=2, dim=-1, keepdim=True)

    # Similarity between real and generated artworks (accuracy): cosine
    # similarity averaged over generated images, then the best real artwork.
    # NOTE(review): the original author asked whether the mean over artworks
    # should be used instead of the max.
    acc_score = torch.matmul(artworks_emb, sd_out_emb.t())
    acc_score = acc_score.mean(dim=1).max().item()
    acc[f'prompt_{i + 1}'].append(acc_score)

    # Similarity between textual prompt and generated artworks (editability)
    edi_score = torch.matmul(prompts_emb[i], sd_out_emb.t())
    edi_score = edi_score.mean().item()
    edi[f'prompt_{i + 1}'].append(edi_score)

acc_table = pd.DataFrame(acc)
edi_table = pd.DataFrame(edi)
acc_table.to_csv('accuracy.tsv', sep='\t')
edi_table.to_csv('editability.tsv', sep='\t')

# Immagini generate per artista
Genera un'immagine per ogni coppia artista/prompt per la valutazione visiva.
Gli embedding devono essere già caricati nel modello.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from random import choice

def add_headers(
    fig,
    *,
    row_headers=None,
    col_headers=None,
    row_pad=5,
    col_pad=5,
    rotate_row_headers=True,
    **text_kwargs
):
  """Label a subplot grid with column headers (top row) and row headers (left column).

  Args:
    fig: figure whose axes were created from a subplot grid.
    row_headers: one label per grid row, or None to skip row headers.
    col_headers: one label per grid column, or None to skip column headers.
    row_pad: extra offset, in points, left of the axes' y-label padding.
    col_pad: offset, in points, above the top edge of the axes.
    rotate_row_headers: multiplied by 90 to give the row-header rotation.
    **text_kwargs: forwarded to `Axes.annotate` (e.g. fontsize).
  """
  for ax in fig.get_axes():
    spec = ax.get_subplotspec()
    wants_col_header = col_headers is not None and spec.is_first_row()
    wants_row_header = row_headers is not None and spec.is_first_col()

    if wants_col_header:
      # Centred just above the axes
      ax.annotate(
        col_headers[spec.colspan.start],
        xy=(0.5, 1),
        xytext=(0, col_pad),
        xycoords="axes fraction",
        textcoords="offset points",
        ha="center",
        va="baseline",
        **text_kwargs,
      )

    if wants_row_header:
      # Vertically centred, to the left of the y-axis label area
      ax.annotate(
        row_headers[spec.rowspan.start],
        xy=(0, 0.5),
        xytext=(-ax.yaxis.labelpad - row_pad, 0),
        xycoords="axes fraction",
        textcoords="offset points",
        ha="right",
        va="center",
        rotation=rotate_row_headers * 90,
        **text_kwargs,
      )

# Prompts shown as grid columns (after the "Original" column). Dedicated names
# are used so the global `prompts` caption templates read by ArtistDataset are
# not clobbered.
grid_prompts = [
    'A portrait of a large family in the style of {}',
    'A close up of a young man in the style of {}',
    'A naturalistic landscape in the style of {}',
    'A landscape of a city in the style of {}'
]

dataset = Dataset('train/', None, device)
images = []

# Row-major grid: per artist, one random real artwork, then one generation per prompt
for artist_name in dataset.get_artists():
  artist = dataset[artist_name]
  images.append(choice(artist.get_artworks_for_eval()))
  for prompt in grid_prompts:
    # Baseline variant: format with the plain artist name instead of the token,
    # e.g. ' '.join(x.capitalize() for x in artist_name.split('-'))
    sd_out = pipe(prompt.format(artist.placeholder_token)).images
    images.append(sd_out[0])

images_np = [np.array(img.resize((128, 128))) for img in images]

n_cols = len(grid_prompts) + 1
n_rows = dataset.get_artists_num()
spacing = 5
fig, axs = plt.subplots(n_rows, n_cols, figsize=(14, 25))

for idx, ax in enumerate(axs.reshape(-1)):
  ax.imshow(images_np[idx])
  ax.axis('off')

plt.subplots_adjust(wspace=spacing/100, hspace=spacing/100)

col_labels = ['Original'] + [p.replace('in the style of {}', '').strip() for p in grid_prompts]
# "some-artist-name" -> "Some Artist Name"
row_labels = [' '.join([x.capitalize() for x in y.split('-')]) for y in dataset.get_artists()]
add_headers(fig, col_headers=col_labels, row_headers=row_labels, fontsize=12)

plt.show()