論文<br>
https://arxiv.org/abs/2202.03052<br>
<br>
GitHub<br>
https://github.com/OFA-Sys/OFA<br>
<br>
<a href="https://colab.research.google.com/github/kaz12tech/ai_demos/blob/master/OFA_demo.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 環境セットアップ

## GPU確認

In [None]:
# Show the GPU (if any) attached to this Colab runtime.
!nvidia-smi

## GitHubからコード取得

In [None]:
# Clone the OFA repository into /content/OFA.
%cd /content

!git clone https://github.com/OFA-Sys/OFA.git

## 学習済みモデルのダウンロード

In [None]:
%cd /content

# Download the OFA-Large checkpoint and move it to OFA/checkpoints/ofa_large.pt,
# the path later passed as --path=checkpoints/ofa_large.pt (relative to /content/OFA).
!mkdir -p /content/OFA/checkpoints/
!wget https://ofa-silicon.oss-us-west-1.aliyuncs.com/checkpoints/ofa_large_clean.pt
!mv ofa_large_clean.pt OFA/checkpoints/ofa_large.pt

## ライブラリのインストール

In [None]:
# fairseq
# If Colab shows a "RESTART RUNTIME" button, restart the runtime.
%cd /content
# NOTE(review): this clones fairseq's default branch, i.e. the version is not pinned — confirm
# it is compatible with the OFA code imported below.
!git clone https://github.com/pytorch/fairseq.git

%cd /content/fairseq
# NOTE(review): --use-feature=in-tree-build may be rejected by newer pip releases — confirm
# against the pip version in the runtime.
!pip install --use-feature=in-tree-build ./

In [None]:
%cd /content/OFA

# Install OFA's requirements one line at a time, skipping the first line of requirements.txt.
# NOTE(review): presumably the first line refers to fairseq, installed separately above — confirm.
!sed '1d' requirements.txt | xargs -I {} pip install {}

## ライブラリのインポート

In [None]:
# Imports and evaluation options for OFA inference.
# NOTE(review): assumes the working directory is /content/OFA (set by an earlier cell) so the
# OFA-local `tasks` / `models` packages and the relative paths below resolve — confirm.
import torch
import numpy as np
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from tasks.mm_tasks.refcoco import RefcocoTask

from models.ofa import OFAModel
from PIL import Image

import cv2
import numpy
from google.colab.patches import cv2_imshow

# Register OFA's RefcocoTask under the name used by --task=refcoco below.
# `tasks` here is fairseq.tasks: `from tasks.mm_tasks.refcoco import ...` binds only RefcocoTask.
tasks.register_task('refcoco', RefcocoTask)

# turn on cuda if GPU is available
use_cuda = torch.cuda.is_available()
# fp16 is off; the intent is to enable it only when a GPU is available (use_cuda is True).
use_fp16 = False

# specify some options for evaluation
# NOTE(review): the leading "" presumably fills fairseq's positional data argument — confirm.
# --beam=10 leaves enough hypotheses for the later cells, which read the first five.
parser = options.get_generation_parser()
input_args = ["", "--task=refcoco", "--beam=10", "--path=checkpoints/ofa_large.pt", "--bpe-dir=utils/BPE"]
args = options.parse_args_and_arch(parser, input_args)
cfg = convert_namespace_to_omegaconf(args)

# モデルビルド

In [None]:
# configファイルと学習済みモデルのロード
task = tasks.setup_task(cfg.task)
models, cfg = checkpoint_utils.load_model_ensemble(
    utils.split_paths(cfg.common_eval.path),
    task=task
)

# GPUに載せる
for model in models:
    model.eval()
    if use_fp16:
        model.half()
    if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
        model.cuda()
    model.prepare_for_inference_(cfg)

# generatorの初期化
generator = task.build_generator(models, cfg.generation)

# Preprocess
## transformation定義

In [None]:
# Image transform
from torchvision import transforms
# Per-channel normalization: (x - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1].
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

# PIL image -> RGB -> square patch_image_size x patch_image_size (bicubic) -> float tensor -> normalized.
patch_resize_transform = transforms.Compose([
    lambda image: image.convert("RGB"),
    transforms.Resize((task.cfg.patch_image_size, task.cfg.patch_image_size), interpolation=Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Text preprocess
# One-element BOS/EOS index tensors that encode_text concatenates around a prompt,
# and the padding index construct_sample uses when counting prompt lengths.
bos_item = torch.LongTensor([task.src_dict.bos()])
eos_item = torch.LongTensor([task.src_dict.eos()])
pad_idx = task.src_dict.pad()


def get_symbols_to_strip_from_output(generator):
    """Return the symbol indices to drop when turning generator output into text.

    Prefers the generator's own ``symbols_to_strip_from_output`` attribute;
    generators without one fall back to their BOS and EOS indices.
    """
    try:
        return generator.symbols_to_strip_from_output
    except AttributeError:
        return {generator.bos, generator.eos}


def decode_fn(x, tgt_dict, bpe, generator, tokenizer=None):
    """Turn generated token ids into ``(text, bin_tokens, code_tokens)`` strings.

    The ids are rendered with ``tgt_dict.string`` (ignoring the symbols from
    ``get_symbols_to_strip_from_output``), split on whitespace and routed into
    three streams:

    * ``<bin_*>`` tokens are collected as-is,
    * ``<code_*>`` tokens are collected as-is,
    * every other piece is decoded with ``bpe`` and then ``tokenizer`` (when
      given) and stitched into words: a piece whose decoded form starts with
      a space, or the very first piece, opens a new (stripped) word; any
      other piece is appended to the previous word.

    Returns:
        A 3-tuple of space-joined strings: text words, bin tokens, code tokens.
    """
    ignored = get_symbols_to_strip_from_output(generator)
    rendered = tgt_dict.string(x.int().cpu(), extra_symbols_to_ignore=ignored)

    words, bin_tokens, code_tokens = [], [], []
    for piece in rendered.strip().split():
        if piece.startswith('<bin_'):
            bin_tokens.append(piece)
            continue
        if piece.startswith('<code_'):
            code_tokens.append(piece)
            continue

        if bpe is not None:
            piece = bpe.decode(piece)
        if tokenizer is not None:
            piece = tokenizer.decode(piece)

        # Continuation piece: glue onto the current word.
        if words and not piece.startswith(' '):
            words[-1] += piece
        else:
            words.append(piece.strip())

    return ' '.join(words), ' '.join(bin_tokens), ' '.join(code_tokens)


def coord2bin(coords, w_resize_ratio, h_resize_ratio, max_image_size=None, num_bins=None):
    """Quantize an ``"x0 y0 x1 y1"`` box (original-image pixels) into OFA location tokens.

    Each coordinate is scaled into the resized image by its resize ratio
    (width ratio for x, height ratio for y), divided by ``max_image_size`` and
    mapped onto ``num_bins`` bins. The result is truncated to an int and
    clamped to ``[0, num_bins - 1]``, so boxes that stick out of the image
    still produce valid ``<bin_N>`` tokens.

    Args:
        coords: Whitespace-separated numbers; the first four are used.
        w_resize_ratio: Resized width / original width.
        h_resize_ratio: Resized height / original height.
        max_image_size: Normalizer; defaults to ``task.cfg.max_image_size``.
        num_bins: Number of location bins; defaults to ``task.cfg.num_bins``.

    Returns:
        A string ``"<bin_a> <bin_b> <bin_c> <bin_d>"``.

    Raises:
        ValueError: If ``coords`` holds fewer than four numbers.
    """
    if max_image_size is None:
        max_image_size = task.cfg.max_image_size
    if num_bins is None:
        num_bins = task.cfg.num_bins

    coord_list = [float(coord) for coord in coords.strip().split()]
    if len(coord_list) < 4:
        raise ValueError(
            "expected 4 coordinates 'x0 y0 x1 y1', got {!r}".format(coords)
        )

    max_bin = num_bins - 1
    ratios = (w_resize_ratio, h_resize_ratio, w_resize_ratio, h_resize_ratio)
    bin_list = []
    for coord, ratio in zip(coord_list[:4], ratios):
        index = int(coord * ratio / max_image_size * max_bin)
        # Out-of-image coordinates would otherwise yield tokens missing from the vocabulary.
        index = min(max(index, 0), max_bin)
        bin_list.append("<bin_{}>".format(index))
    return ' '.join(bin_list)


def bin2coord(bins, w_resize_ratio, h_resize_ratio, max_image_size=None, num_bins=None):
    """Map ``"<bin_a> <bin_b> <bin_c> <bin_d>"`` back to original-image pixel coordinates.

    This is the inverse of ``coord2bin`` up to quantization. Each index becomes
    ``index / (num_bins - 1) * max_image_size / ratio``, using the width ratio
    for x and the height ratio for y.

    Args:
        bins: Whitespace-separated ``<bin_N>`` tokens; the first four are used.
        w_resize_ratio: Resized width / original width.
        h_resize_ratio: Resized height / original height.
        max_image_size: Normalizer; defaults to ``task.cfg.max_image_size``.
        num_bins: Number of location bins; defaults to ``task.cfg.num_bins``.

    Returns:
        ``[x0, y0, x1, y1]`` as floats.

    Raises:
        ValueError: If fewer than four tokens are present (e.g. the model did
            not emit a full box), or a token is not of the form ``<bin_N>``.
    """
    if max_image_size is None:
        max_image_size = task.cfg.max_image_size
    if num_bins is None:
        num_bins = task.cfg.num_bins

    tokens = bins.strip().split()
    if len(tokens) < 4:
        raise ValueError(
            "expected 4 location tokens '<bin_x0> <bin_y0> <bin_x1> <bin_y1>', got {!r}".format(bins)
        )

    prefix = '<bin_'
    ratios = (w_resize_ratio, h_resize_ratio, w_resize_ratio, h_resize_ratio)
    return [
        # Slice off '<bin_' and the trailing '>' to get the integer index.
        int(token[len(prefix):-1]) / (num_bins - 1) * max_image_size / ratio
        for token, ratio in zip(tokens[:4], ratios)
    ]


def encode_text(text, length=None, append_bos=False, append_eos=False):
    """Encode a prompt into a 1-D LongTensor of ``task.tgt_dict`` indices.

    Ordinary words are BPE-encoded with a leading space. Words starting with
    ``<code_`` or ``<bin_`` are special tokens and are looked up verbatim.
    Unknown symbols are not added to the dictionary.

    Args:
        text: Prompt string, split on whitespace.
        length: Optional cap on the number of indices, applied before BOS/EOS.
        append_bos: Prepend the global ``bos_item``.
        append_eos: Append the global ``eos_item``.
    """
    def _to_symbol(word):
        if word.startswith(('<code_', '<bin_')):
            return word
        return task.bpe.encode(' {}'.format(word.strip()))

    joined = ' '.join([_to_symbol(word) for word in text.strip().split()])
    ids = task.tgt_dict.encode_line(
        line=joined,
        add_if_not_exist=False,
        append_eos=False,
    ).long()

    if length is not None:
        ids = ids[:length]
    if append_bos:
        ids = torch.cat([bos_item, ids])
    if append_eos:
        ids = torch.cat([ids, eos_item])
    return ids

def construct_sample(image: Image, instruction: str):
    """Build a one-example fairseq-style batch from a PIL image and a prompt.

    The prompt is lower-cased, stripped, prefixed with a space and encoded
    with BOS/EOS. The image goes through ``patch_resize_transform``.

    Returns:
        ``{"id": ..., "net_input": {...}}``, where ``net_input`` holds
        ``src_tokens`` (1, L), ``src_lengths`` (1,) = count of non-pad tokens,
        ``patch_images`` (1, C, H, W) and ``patch_masks`` (1,).
    """
    prompt = ' {}'.format(instruction.lower().strip())
    src_tokens = encode_text(prompt, append_bos=True, append_eos=True).unsqueeze(0)
    src_lengths = torch.LongTensor([row.ne(pad_idx).long().sum() for row in src_tokens])

    net_input = {
        "src_tokens": src_tokens,
        "src_lengths": src_lengths,
        "patch_images": patch_resize_transform(image).unsqueeze(0),
        "patch_masks": torch.tensor([True]),
    }
    return {"id": np.array(['42']), "net_input": net_input}
  
# Cast float32 tensors to float16 (mapped over a sample via utils.apply_to_sample).
def apply_half(t):
    """Return ``t`` as float16 if it is float32; any other dtype is returned unchanged."""
    if t.dtype is not torch.float32:
        return t
    return t.to(dtype=torch.half)

# Image Captioning

## テスト画像取得

In [None]:
%cd /content/OFA
!mkdir test_imgs
%cd test_imgs

# https://www.pakutaso.com/20180239038post-15116.html
!wget https://www.pakutaso.com/shared/img/thumb/smIMGL4174_TP_V4.jpg

%cd /content/OFA
image = Image.open('/content/OFA/test_imgs/smIMGL4174_TP_V4.jpg')

In [None]:
# Fixed image-captioning prompt.
instruction = "what does the image describe?"

# Construct input sample & preprocess for GPU if cuda available
sample = construct_sample(image, instruction)
sample = utils.move_to_cuda(sample) if use_cuda else sample
sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample

# Generate result
# Decode the first five hypotheses for the single input (hypos[0]); this needs --beam >= 5.
# NOTE(review): presumably the generator returns them best-first — confirm.
with torch.no_grad():
    hypos = task.inference_step(generator, models, sample)
    tokens1, bins1, imgs1 = decode_fn(hypos[0][0]["tokens"], task.tgt_dict, task.bpe, generator)
    tokens2, bins2, imgs2 = decode_fn(hypos[0][1]["tokens"], task.tgt_dict, task.bpe, generator)
    tokens3, bins3, imgs3 = decode_fn(hypos[0][2]["tokens"], task.tgt_dict, task.bpe, generator)
    tokens4, bins4, imgs4 = decode_fn(hypos[0][3]["tokens"], task.tgt_dict, task.bpe, generator)
    tokens5, bins5, imgs5 = decode_fn(hypos[0][4]["tokens"], task.tgt_dict, task.bpe, generator)

# display result
# "Probs" prints exp(score). NOTE(review): presumably score is a log-probability — confirm.
display(image)
print('Instruction: {}'.format(instruction))
print('OFA\'s Output1: {}, Probs: {}'.format(tokens1, hypos[0][0]["score"].exp().item()))
print('OFA\'s Output2: {}, Probs: {}'.format(tokens2, hypos[0][1]["score"].exp().item()))
print('OFA\'s Output3: {}, Probs: {}'.format(tokens3, hypos[0][2]["score"].exp().item()))
print('OFA\'s Output4: {}, Probs: {}'.format(tokens4, hypos[0][3]["score"].exp().item()))
print('OFA\'s Output5: {}, Probs: {}'.format(tokens5, hypos[0][4]["score"].exp().item()))

# Visual Question Answering: VQA

## テスト画像取得

In [None]:
%cd /content/OFA
!mkdir test_imgs
%cd test_imgs

# https://www.pakutaso.com/20180239038post-15116.html
!wget https://www.pakutaso.com/shared/img/thumb/smIMGL4174_TP_V4.jpg

%cd /content/OFA
image = Image.open('/content/OFA/test_imgs/smIMGL4174_TP_V4.jpg')

## Question設定

In [None]:
# Question to ask about the image (editable Colab form field).
instruction = "What does a woman have in her left hand?" #@param {type:"string"}

In [None]:
# Construct input sample & preprocess for GPU if cuda available
sample = construct_sample(image, instruction)
sample = utils.move_to_cuda(sample) if use_cuda else sample
sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample

# Generate result
# Decode the first five hypotheses for the single input (hypos[0]); this needs --beam >= 5.
# NOTE(review): presumably the generator returns them best-first — confirm.
with torch.no_grad():
    hypos = task.inference_step(generator, models, sample)
    tokens1, bins1, imgs1 = decode_fn(hypos[0][0]["tokens"], task.tgt_dict, task.bpe, generator)
    tokens2, bins2, imgs2 = decode_fn(hypos[0][1]["tokens"], task.tgt_dict, task.bpe, generator)
    tokens3, bins3, imgs3 = decode_fn(hypos[0][2]["tokens"], task.tgt_dict, task.bpe, generator)
    tokens4, bins4, imgs4 = decode_fn(hypos[0][3]["tokens"], task.tgt_dict, task.bpe, generator)
    tokens5, bins5, imgs5 = decode_fn(hypos[0][4]["tokens"], task.tgt_dict, task.bpe, generator)

# display result
# "Probs" prints exp(score). NOTE(review): presumably score is a log-probability — confirm.
display(image)
print('Instruction: {}'.format(instruction))
print('OFA\'s Output1: {}, Probs: {}'.format(tokens1, hypos[0][0]["score"].exp().item()))
print('OFA\'s Output2: {}, Probs: {}'.format(tokens2, hypos[0][1]["score"].exp().item()))
print('OFA\'s Output3: {}, Probs: {}'.format(tokens3, hypos[0][2]["score"].exp().item()))
print('OFA\'s Output4: {}, Probs: {}'.format(tokens4, hypos[0][3]["score"].exp().item()))
print('OFA\'s Output5: {}, Probs: {}'.format(tokens5, hypos[0][4]["score"].exp().item()))

# Grounded QA

## テスト画像取得

In [None]:
%cd /content/OFA
!mkdir test_imgs
%cd test_imgs

# https://www.pakutaso.com/20180239038post-15116.html
!wget https://www.pakutaso.com/shared/img/thumb/smIMGL4174_TP_V4.jpg

%cd /content/OFA
image = Image.open('/content/OFA/test_imgs/smIMGL4174_TP_V4.jpg')

## Coordinate, Question設定

In [None]:
# Query region as "x0 y0 x1 y1" in original-image pixels (editable Colab form field).
coords = "522.0 396.0 595.0 469.0" #@param {type:"string"}
w, h = image.size
# Scale factors from the original image to the square patch_image_size model input.
w_resize_ratio = task.cfg.patch_image_size / w
h_resize_ratio = task.cfg.patch_image_size / h
bins = coord2bin(coords, w_resize_ratio, h_resize_ratio)
question = "What's in the region?" #@param {type:"string"}
# Grounded-QA prompt: "<question> region: <bin_x0> <bin_y0> <bin_x1> <bin_y1>".
# `bins` must be concatenated, not embedded in the literal, or the region never reaches the model.
instruction = question + " region: " + bins

In [None]:
# Construct input sample & preprocess for GPU if cuda available
sample = construct_sample(image, instruction)
sample = utils.move_to_cuda(sample) if use_cuda else sample
sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample

# Generate result
# Decode the first five hypotheses for the single input (hypos[0]); this needs --beam >= 5.
# NOTE(review): presumably the generator returns them best-first — confirm.
with torch.no_grad():
    hypos = task.inference_step(generator, models, sample)
    tokens1, bins1, imgs1 = decode_fn(hypos[0][0]["tokens"], task.tgt_dict, task.bpe, generator)
    tokens2, bins2, imgs2 = decode_fn(hypos[0][1]["tokens"], task.tgt_dict, task.bpe, generator)
    tokens3, bins3, imgs3 = decode_fn(hypos[0][2]["tokens"], task.tgt_dict, task.bpe, generator)
    tokens4, bins4, imgs4 = decode_fn(hypos[0][3]["tokens"], task.tgt_dict, task.bpe, generator)
    tokens5, bins5, imgs5 = decode_fn(hypos[0][4]["tokens"], task.tgt_dict, task.bpe, generator)

# display result
# Draw the *query* region (`bins` from the prompt cell, not the model output) on a BGR copy
# of the image, since cv2 drawing / cv2_imshow use BGR channel order.
img = cv2.cvtColor(numpy.asarray(image), cv2.COLOR_RGB2BGR)
coord_list = bin2coord(bins, w_resize_ratio, h_resize_ratio)
cv2.rectangle(
    img,
    (int(coord_list[0]), int(coord_list[1])),
    (int(coord_list[2]), int(coord_list[3])),
    (0, 255, 0),
    3
)
cv2_imshow(img)

# "Probs" prints exp(score). NOTE(review): presumably score is a log-probability — confirm.
print('Instruction: {}'.format(instruction))
print('OFA\'s Output1: {}, Probs: {}'.format(tokens1, hypos[0][0]["score"].exp().item()))
print('OFA\'s Output2: {}, Probs: {}'.format(tokens2, hypos[0][1]["score"].exp().item()))
print('OFA\'s Output3: {}, Probs: {}'.format(tokens3, hypos[0][2]["score"].exp().item()))
print('OFA\'s Output4: {}, Probs: {}'.format(tokens4, hypos[0][3]["score"].exp().item()))
print('OFA\'s Output5: {}, Probs: {}'.format(tokens5, hypos[0][4]["score"].exp().item()))

# Visual Grounding

## テスト画像取得

In [None]:
%cd /content/OFA
!mkdir test_imgs
%cd test_imgs

# https://www.pakutaso.com/20180239038post-15116.html
!wget https://www.pakutaso.com/shared/img/thumb/smIMGL4174_TP_V4.jpg

%cd /content/OFA
image = Image.open('/content/OFA/test_imgs/smIMGL4174_TP_V4.jpg')

## Question設定

In [None]:
# Referring expression to localize (editable Colab form field).
question = "Fruit that a woman has in her right hand"  #@param {type:"string"}
# Spaces on both sides of the quotes keep `"` as separate words after whitespace splitting.
# NOTE(review): intended to match OFA's refcoco prompt ' which region does the text " {} " describe?' — confirm.
instruction = 'which region does the text " ' + question + ' " describe?'

In [None]:
# Construct input sample & preprocess for GPU if cuda available
sample = construct_sample(image, instruction)
sample = utils.move_to_cuda(sample) if use_cuda else sample
sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample

# Generate result (only the first hypothesis for the single input is used)
with torch.no_grad():
    hypos = task.inference_step(generator, models, sample)
    tokens, bins, imgs = decode_fn(hypos[0][0]["tokens"], task.tgt_dict, task.bpe, generator)

# display result
# Take the size from this section's image rather than relying on w/h left over from
# another section, so this cell also works when run on its own.
w, h = image.size
w_resize_ratio = task.cfg.patch_image_size / w
h_resize_ratio = task.cfg.patch_image_size / h
# Draw the predicted box on a BGR copy of the image (cv2 drawing / cv2_imshow use BGR).
img = cv2.cvtColor(numpy.asarray(image), cv2.COLOR_RGB2BGR)
coord_list = bin2coord(bins, w_resize_ratio, h_resize_ratio)
cv2.rectangle(
    img,
    (int(coord_list[0]), int(coord_list[1])),
    (int(coord_list[2]), int(coord_list[3])),
    (0, 255, 0),
    3
)
cv2_imshow(img)