論文  
https://arxiv.org/abs/2211.00575<br>
<br>
GitHub  
https://github.com/DavidHuji/CapDec<br>
<br>
<a href="https://colab.research.google.com/github/kaz12tech/ai_demos/blob/master/CapDec_demo.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 環境セットアップ

## GPU確認

In [None]:
# Show the NVIDIA GPU(s) visible to this runtime (shell command via IPython `!`).
!nvidia-smi

## GitHubからコード取得

## ライブラリのインストール

In [None]:
%cd /content

# CLIP is installed from a pinned commit of the OpenAI repository.
!pip install git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
# diffusers is pinned to 0.7.2; transformers, scipy and ftfy are left unpinned.
!pip install diffusers==0.7.2 transformers scipy ftfy
# gdown is upgraded to its latest release; it is used later to fetch the checkpoint.
!pip install --upgrade gdown

In [None]:
# Access token passed as `use_auth_token` when the Stable Diffusion weights are
# loaded. Replace the placeholder string with a real token before running.
access_tokens="ここにアクセストークンを設定" # @param {type:"string"}

## ライブラリのインポート

In [None]:
%cd /content

import os



# For Stable Diffusion
import torch
from torch import autocast
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = "cpu"
print("using device is", device)
from diffusers import StableDiffusionPipeline
import matplotlib.pyplot as plt

# For CapDec
from typing import Tuple, List, Union, Optional
import numpy as np

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch import nn
import torch.nn.functional as nnf

import clip

from PIL import Image

## 学習済みモデルのダウンロード

In [None]:
# 015.pt
model_path = './015.pt'

!gdown --id 1zfadJ41bo8HrCKTdoVsKGHWO4UBXCy9W \
        -O {model_path}

# Text to Image

## load model

In [None]:
# Load the pretrained Stable Diffusion v1.4 pipeline, authenticating with the
# access token configured earlier in the notebook.
stable_diffusion_model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=access_tokens)
# Move the pipeline to the device chosen at import time instead of a hard-coded
# "cuda", so it always matches the torch.Generator that is built from `device`.
stable_diffusion_model.to(device)

## Inference

In [None]:
# Text prompt passed to the Stable Diffusion pipeline in the generation cell.
prompt = "A man and a woman posing for a picture next to tokyo tower." #@param {type:"string"}

In [None]:
# Create the output directory idempotently: a plain `mkdir` prints a
# "File exists" error every time this cell is re-run.
os.makedirs("outputs", exist_ok=True)

create_img = "outputs/test_01.png"

# Fix the random seed so the generation is reproducible.
generator = torch.Generator(device).manual_seed(0)

# Feed the prompt to the pipeline and take the first generated image.
# `.images` is the same field the original reached via positional `[0]`.
image = stable_diffusion_model(
    prompt,
    generator=generator,
    width=720, height=512).images[0]
# Save the image to disk.
image.save(create_img)

In [None]:
# Read the saved image back from disk and display it with the axes hidden.
plt.imshow(plt.imread(create_img))
plt.axis('off')
plt.show()

# Image Captioning

## load model

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head attention of queries from `x` over keys/values from `y`.

  Queries are projected from `x` (width `dim_self`); keys and values are
  produced together by a single projection of `y` (width `dim_ref`).
  A `dropout` module is created but is not applied in `forward`.
  """

  def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
    super().__init__()
    self.num_heads = num_heads
    # Attention logits are scaled by 1 / sqrt(per-head width).
    self.scale = (dim_self // num_heads) ** -0.5
    self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
    self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
    self.project = nn.Linear(dim_self, dim_self)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, y=None, mask=None):
    """Return `(output, attention)`; attends over `x` itself when `y` is None.

    `attention` has layout (batch, query, key, head) and is normalized over
    the key axis. Positions where `mask` is True receive zero weight.
    """
    if y is None:
      y = x
    batch, n_query, width = x.shape
    _, n_ref, _ = y.shape
    per_head = width // self.num_heads

    # (batch, query, head, per_head)
    queries = self.to_queries(x).reshape(batch, n_query, self.num_heads, per_head)
    # (batch, key, 2, head, per_head) -> split axis 2 into keys and values
    packed = self.to_keys_values(y).reshape(batch, n_ref, 2, self.num_heads, per_head)
    keys, values = packed.unbind(dim=2)

    attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
    if mask is not None:
      # A 2-D mask gets a singleton query axis so it broadcasts over queries.
      expanded = mask.unsqueeze(1) if mask.dim() == 2 else mask
      attention = attention.masked_fill(expanded.unsqueeze(3), float("-inf"))
    attention = attention.softmax(dim=2)

    mixed = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(batch, n_query, width)
    return self.project(mixed), attention

class MlpTransformer(nn.Module):
  """Two-layer feed-forward block: fc1 -> activation -> dropout -> fc2 -> dropout.

  The output width equals `in_dim` unless `out_d` is given.
  """

  def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
    super().__init__()
    if out_d is None:
      out_d = in_dim
    self.fc1 = nn.Linear(in_dim, h_dim)
    self.act = act
    self.fc2 = nn.Linear(h_dim, out_d)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    """Apply the feed-forward block to `x` and return the result."""
    hidden = self.dropout(self.act(self.fc1(x)))
    return self.dropout(self.fc2(hidden))

In [None]:
class TransformerLayer(nn.Module):
  """One residual layer: normalized attention followed by a normalized MLP.

  Only `x` is normalized before attention; `y` is passed through as given.
  """

  def __init__(
      self, dim_self, dim_ref, num_heads, mlp_ratio=4.,
      bias=False, dropout=0., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm):
    super().__init__()
    hidden_width = int(dim_self * mlp_ratio)
    self.norm1 = norm_layer(dim_self)
    self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
    self.norm2 = norm_layer(dim_self)
    self.mlp = MlpTransformer(dim_self, hidden_width, act=act, dropout=dropout)

  def forward_with_attention(self, x, y=None, mask=None):
    """Return `(output, attention)` for this layer."""
    attended, attention = self.attn(self.norm1(x), y, mask)
    x = x + attended
    return x + self.mlp(self.norm2(x)), attention

  def forward(self, x, y=None, mask=None):
    """Return only the layer output, discarding the attention weights."""
    return self.forward_with_attention(x, y, mask)[0]


class Transformer(nn.Module):
  """Stack of `TransformerLayer`s, optionally alternating cross/self attention.

  With `enc_dec=True` the layer count is doubled: even-indexed layers attend
  over `y` (reference width `dim_ref`), odd-indexed layers attend over `x`
  itself (reference width `dim_self`). Otherwise every layer uses `dim_ref`.
  """

  def __init__(
      self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
      mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
    super().__init__()
    if dim_ref is None:
      dim_ref = dim_self
    self.enc_dec = enc_dec
    total_layers = num_layers * 2 if enc_dec else num_layers
    self.layers = nn.ModuleList([
        TransformerLayer(
            dim_self,
            # Only the odd (self-attention) layers of an enc-dec stack use dim_self.
            dim_self if (enc_dec and index % 2 == 1) else dim_ref,
            num_heads, mlp_ratio, act=act, norm_layer=norm_layer)
        for index in range(total_layers)
    ])

  def forward_with_attention(self, x, y=None, mask=None):
    """Run every layer with `(x, y, mask)` and collect each layer's attention."""
    attentions = []
    for block in self.layers:
      x, weights = block.forward_with_attention(x, y, mask)
      attentions.append(weights)
    return x, attentions

  def forward(self, x, y=None, mask=None):
    """Run the stack and return the final hidden states."""
    for index, block in enumerate(self.layers):
      if not self.enc_dec:  # self or cross
        x = block(x, y, mask)
      elif index % 2 == 0:  # cross
        x = block(x, y)
      else:  # self
        x = block(x, x, mask)
    return x

In [None]:
class TransformerMapper(nn.Module):
  """Map a CLIP embedding to a sequence of prefix embeddings.

  The input is linearly expanded into `clip_length` vectors, concatenated with
  a learned constant of `prefix_length` vectors, passed through a transformer,
  and only the positions after the first `clip_length` are returned.
  """

  def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
    super().__init__()
    self.clip_length = clip_length
    self.transformer = Transformer(dim_embedding, 8, num_layers)
    self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
    self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)

  def forward(self, x):
    """Return the transformed learned-prefix positions for each input row."""
    batch = x.shape[0]
    clip_tokens = self.linear(x).view(batch, self.clip_length, -1)
    # Broadcast the learned constant across the batch without copying it.
    learned = self.prefix_const.unsqueeze(0).expand(batch, *self.prefix_const.shape)
    sequence = torch.cat((clip_tokens, learned), dim=1)
    return self.transformer(sequence)[:, self.clip_length:]

class ClipCaptionModel(nn.Module):
  """GPT-2 language model conditioned on a prefix projected from a CLIP embedding.

  Args:
    prefix_size: Width of the incoming CLIP embedding fed to the mapper.
    num_layers: Number of layers in the mapping transformer.
    prefix_length: Number of prefix embeddings placed before the text tokens.
      Defaults to 40, the value that was previously hard-coded.
    clip_length: Number of vectors the CLIP embedding is expanded into inside
      the mapper. Defaults to 40, the value that was previously hard-coded.
  """

  def __init__(
      self, prefix_size: int = 640, num_layers: int = 8,
      prefix_length: int = 40, clip_length: int = 40
      ):
    super(ClipCaptionModel, self).__init__()
    self.prefix_length = prefix_length
    self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
    # Embedding width is read from GPT-2's token-embedding matrix.
    self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
    self.clip_project = TransformerMapper(
        prefix_size, self.gpt_embedding_size, prefix_length, clip_length, num_layers)

  def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
    """Return a `(batch_size, prefix_length)` int64 tensor of zeros on `device`."""
    return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

  def forward(
      self, tokens: torch.Tensor, prefix: torch.Tensor,
      mask: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None
      ):
    """Run GPT-2 on the projected prefix followed by the embedded `tokens`.

    When `labels` is not None its value is not used directly: the labels passed
    to GPT-2 are zeros for the prefix positions concatenated with `tokens`.
    Returns whatever `self.gpt(...)` returns.
    """
    embedding_text = self.gpt.transformer.wte(tokens)
    prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
    embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
    if labels is not None:
      dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
      labels = torch.cat((dummy_token, tokens), dim=1)
    out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
    return out

In [None]:
# Instantiate the captioning network with its default configuration.
cap_model = ClipCaptionModel()

# Restore the downloaded weights (deserialized onto the CPU first), then switch
# to evaluation mode and move the model to the selected device.
cap_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
cap_model = cap_model.eval().to(device)

In [None]:
# Load the CLIP "RN50x4" model (non-JIT) together with its image preprocessing
# transform, and the GPT-2 tokenizer used to decode generated token ids.
clip_model, preprocess = clip.load("RN50x4", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

## Inference

In [None]:
def generate_beam(
    model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
    entry_length=67, temperature=1., stop_token: str = '.'):
  """Generate text with beam search over `model.gpt` and return it best-first.

  Generation starts from `embed` when it is given; otherwise `prompt` is
  tokenized with `tokenizer` and embedded with GPT-2's token embedding.

  Args:
    model: Object exposing `gpt` (called with `inputs_embeds`) and `parameters()`.
    tokenizer: Provides `encode` and `decode`.
    beam_size: Number of beams kept at each step.
    prompt: Text to start from; only used when `embed` is None.
    embed: Starting input embeddings. Assumed to have a leading batch axis of
      size 1, since it is expanded to `beam_size` — TODO confirm with callers.
    entry_length: Maximum number of generation steps.
    temperature: Divisor applied to the logits; values <= 0 fall back to 1.0.
    stop_token: String whose first encoded token id ends a beam.

  Returns:
    A list of decoded strings, one per beam, each truncated to that beam's
    recorded length and ordered by descending length-normalized score.
  """
  model.eval()
  # Only the first token id of the encoded stop string is used as the stop marker.
  stop_token_index = tokenizer.encode(stop_token)[0]
  tokens = None
  scores = None
  device = next(model.parameters()).device
  # Per-beam bookkeeping: number of generated tokens and a "finished" flag.
  seq_lengths = torch.ones(beam_size, device=device)
  is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)

  with torch.no_grad():
    if embed is not None:
      generated = embed
    else:
      # `tokens` is always None here, so the prompt is always tokenized.
      if tokens is None:
        tokens = torch.tensor(tokenizer.encode(prompt))
        tokens = tokens.unsqueeze(0).to(device)
        generated = model.gpt.transformer.wte(tokens)
    for i in range(entry_length):
      outputs = model.gpt(inputs_embeds=generated)
      logits = outputs.logits
      # Keep only the last position's logits and apply the temperature.
      logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
      # Convert to log-probabilities.
      logits = logits.softmax(-1).log()
      if scores is None:
        # First step: seed the beams with the top `beam_size` next tokens and
        # replicate the running embeddings once per beam.
        scores, next_tokens = logits.topk(beam_size, -1)
        generated = generated.expand(beam_size, *generated.shape[1:])
        next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
        if tokens is None:
          tokens = next_tokens
        else:
          tokens = tokens.expand(beam_size, *tokens.shape[1:])
          tokens = torch.cat((tokens, next_tokens), dim=1)
      else:
        # A finished beam gets exactly one viable continuation (token id 0 with
        # log-probability 0), so its cumulative score stays unchanged.
        logits[is_stopped] = -float(np.inf)
        logits[is_stopped, 0] = 0
        scores_sum = scores[:, None] + logits
        # Only beams that are still running grow in length.
        seq_lengths[~is_stopped] += 1
        scores_sum_average = scores_sum / seq_lengths[:, None]
        # Pick the best `beam_size` candidates over the flattened (beam, vocab) grid.
        scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
        # Recover the originating beam and the token id from the flat index.
        next_tokens_source = next_tokens // scores_sum.shape[1]
        seq_lengths = seq_lengths[next_tokens_source]
        next_tokens = next_tokens % scores_sum.shape[1]
        next_tokens = next_tokens.unsqueeze(1)
        # Reorder every per-beam state to follow the selected source beams.
        tokens = tokens[next_tokens_source]
        tokens = torch.cat((tokens, next_tokens), dim=1)
        generated = generated[next_tokens_source]
        # Turn the averaged score back into a cumulative sum for the next step.
        scores = scores_sum_average * seq_lengths
        is_stopped = is_stopped[next_tokens_source]
      # Append the embedding of the chosen tokens to each beam's input sequence.
      next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
      generated = torch.cat((generated, next_token_embed), dim=1)
      # Accumulate stop flags; NOTE(review): `+` on bool tensors is relied on to
      # behave as a logical OR here.
      is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
      if is_stopped.all():
        break
  # Rank beams by length-normalized score.
  scores = scores / seq_lengths
  output_list = tokens.cpu().numpy()
  # Decode each beam, truncated to its recorded length.
  output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
  order = scores.argsort(descending=True)
  output_texts = [output_texts[i] for i in order]
  return output_texts

In [None]:
# Load the generated image, apply CLIP's preprocessing and add a batch axis.
pil_image = Image.open(create_img)
image = preprocess(pil_image).unsqueeze(0).to(device)

with torch.no_grad():
  # Encode the image with CLIP and cast the embedding to float32 for the mapper.
  prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
  # Read the prefix length from the model instead of repeating the literal 40,
  # so this cell cannot drift from the model's configuration.
  prefix_embed = cap_model.clip_project(prefix).reshape(1, cap_model.prefix_length, -1)

# Beam search returns candidates best-first; keep the top one.
generated_text_prefix = generate_beam(cap_model, tokenizer, embed=prefix_embed)[0]

print("\nCaption:", generated_text_prefix)

# Image Caption to Image

In [None]:
create_img = "outputs/test_02.png"

# Fix the random seed. The generator must be re-created here: the one built for
# the first image has already been consumed, so reusing it would not reproduce
# a seed-0 run and the result would depend on what was executed before.
generator = torch.Generator(device).manual_seed(0)

# Feed the generated caption to the pipeline and take the first image.
# `.images` is the same field the original reached via positional `[0]`.
image = stable_diffusion_model(
    generated_text_prefix,
    generator=generator,
    width=720, height=512).images[0]
# Save the image to disk.
image.save(create_img)

In [None]:
# Read the saved image back from disk and display it with the axes hidden.
plt.imshow(plt.imread(create_img))
plt.axis('off')
plt.show()