# Otter Image Demo (In-context Learning)

Here is an example of multi-modal ICL (in-context learning) with 🦦 Otter. We provide two demo images with their corresponding instructions and answers, then ask the model to generate an answer for a new image and instruction. You may change the instruction and see how the model responds.

You can also try our [online demo](https://otter.cliangyu.com/) to see more in-context learning demonstrations.

必要なモジュールは各自インストール<br>
mlflow==2.6.0はバグがあるため使わないこと(https://github.com/mlflow/mlflow/issues/9331) (2023.08.23)

In [1]:
# !pip install --upgrade mlflow==2.5.0 pydantic==1.10.12 deepspeed==0.10.3

In [1]:
import requests
import torch
import transformers
from PIL import Image
import matplotlib.pyplot as plt
import sys
import os
import textwrap

sys.path.append("../..")
from otter.modeling_otter import OtterForConditionalGeneration

  from .autonotebook import tqdm as notebook_tqdm


[2023-09-28 04:51:02,409] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


## 読み込み

In [3]:
# Load the Otter checkpoint from a local directory. The commented-out line is the
# alternative that loads by Hugging Face Hub id (presumably the same weights).
# model = OtterForConditionalGeneration.from_pretrained("luodian/OTTER-Image-MPT7B", device_map="auto") # Hugging Face
model = OtterForConditionalGeneration.from_pretrained("/data/dataset/otter/OTTER-Image-MPT7B/", device_map="auto")
# Short alias for the tokenizer attached to the model.
tokenizer = model.text_tokenizer
# Image preprocessor constructed without arguments, i.e. with its default settings.
image_processor = transformers.CLIPImageProcessor()

You are using config.init_device='cpu', but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.


Using pad_token, but it is not set yet.


The current model version is configured for Otter-Image with max_num_frames set to None.
Total Trainable param: 1.385404 B


The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
Loading checkpoint shards: 100%|██████████| 4/4 [00:18<00:00,  4.57s/it]


### トークンの確認

In [4]:
model.text_tokenizer.all_special_tokens

['<|endoftext|>', '<PAD>', '<|endofchunk|>', '<image>', '<answer>']

In [5]:
model.text_tokenizer.get_vocab

<bound method PreTrainedTokenizerFast.get_vocab of GPTNeoXTokenizerFast(name_or_path='mosaicml/mpt-7b-instruct', vocab_size=50254, model_max_length=2048, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<PAD>', 'additional_special_tokens': ['<|endofchunk|>', '<image>', '<answer>']}, clean_up_tokenization_spaces=True)>

In [6]:
tokenizer.convert_ids_to_tokens

<bound method PreTrainedTokenizerFast.convert_ids_to_tokens of GPTNeoXTokenizerFast(name_or_path='mosaicml/mpt-7b-instruct', vocab_size=50254, model_max_length=2048, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<PAD>', 'additional_special_tokens': ['<|endofchunk|>', '<image>', '<answer>']}, clean_up_tokenization_spaces=True)>

In [7]:
attributes_and_methods = dir(model.text_tokenizer)
print(attributes_and_methods)



### 学習済み重みを使用する場合

In [4]:
# Fine-tuned weights to evaluate; the commented-out path is the default checkpoint.
trained_ckpt_path = '../../log/VI_batch128_long_pairs25/final_weights.pt'
# trained_ckpt_path = '../../weights/OTTER-Image-MPT7B/final_weights.pt' # default

train_ckpt = torch.load(trained_ckpt_path, map_location="cpu")
# The weights may be nested under "model_state_dict"; unwrap them when that key holds a value.
wrapped_state_dict = train_ckpt.get("model_state_dict", None)
train_ckpt = train_ckpt if wrapped_state_dict is None else wrapped_state_dict
# strict=False: keys missing from / unknown to the model are tolerated.
_ = model.load_state_dict(train_ckpt, strict=False)

In [None]:
train_ckpt.keys()

## 画像とプロンプトの用意

In [None]:
# Two in-context demonstration images followed by the query image, streamed over HTTP.
image_urls = (
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    "http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
    "http://images.cocodataset.org/test-stuff2017/000000028352.jpg",
)
demo_image_one, demo_image_two, query_image = (
    Image.open(requests.get(url, stream=True).raw) for url in image_urls
)
context_images = [demo_image_one, demo_image_two, query_image]

pixel_values = image_processor.preprocess(context_images, return_tensors="pt")["pixel_values"]
vision_x = pixel_values.unsqueeze(1).unsqueeze(0)
model.text_tokenizer.padding_side = "left"
# Two answered demonstrations, then the open question the model has to complete.
prompt = "<image>User: a photo of GPT:<answer> two cats sleeping.<|endofchunk|><image>User: a photo of GPT:<answer> a bathroom sink.<|endofchunk|><image>User: a photo of GPT:<answer>"
lang_x = model.text_tokenizer([prompt], return_tensors="pt")

print(vision_x.shape) # torch.Size([1, 3, 1, 3, 224, 224]) shape (B, num_imgs, Frames=1, C, H, W)

# Show the three images side by side without axes.
fig, axes = plt.subplots(1, 3, figsize=(10, 4))
for axis, image in zip(axes, context_images):
    axis.imshow(image)
    axis.axis('off')
plt.tight_layout()
plt.show()

# Cast the image tensor to the dtype of the model parameters.
model_dtype = next(model.parameters()).dtype
vision_x = vision_x.to(dtype=model_dtype)
lang_x_input_ids = lang_x["input_ids"]
lang_x_attention_mask = lang_x["attention_mask"]

# Token sequences the generator is not allowed to produce (speaker tags).
bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids
generation_options = dict(
    max_new_tokens=512,
    num_beams=3,
    no_repeat_ngram_size=3,
    bad_words_ids=bad_words_id,
)
generated_text = model.generate(
    vision_x=vision_x.to(model.device),
    lang_x=lang_x_input_ids.to(model.device),
    attention_mask=lang_x_attention_mask.to(model.device),
    **generation_options,
)

# Keep the text after the last <answer> tag, cut at the chunk terminator, drop surrounding quotes.
decoded_text = model.text_tokenizer.decode(generated_text[0])
last_answer = decoded_text.split("<answer>")[-1].strip()
parsed_output = last_answer.split("<|endofchunk|>")[0].strip().strip('"')

print("Generated text: ", parsed_output)

## 哺乳類か哺乳類でないかの推論

In [15]:
import os
def get_image_paths(folder_path):
    """Return paths of the entries in folder_path ending in .jpg, .jpeg or .png.

    The extension check is case-insensitive and the result follows the sorted
    order of the entry names.
    """
    image_extensions = ['.jpg', '.jpeg', '.png']
    image_paths = []
    for file_name in sorted(os.listdir(folder_path)):
        extension = os.path.splitext(file_name)[1]
        if extension.lower() in image_extensions:
            image_paths.append(os.path.join(folder_path, file_name))
    return image_paths

In [None]:
# In-context demonstrations: a dog (answer "Yes") and a house (answer "No").
demo_image_one = Image.open("../../../data/test_kosmos/animal_or_not/dog.jpg")
demo_image_two = Image.open("../../../data/test_kosmos/animal_or_not/home.jpg")
query_folder_path = "../../../data/test_kosmos/animal_or_not/"
query_image_paths = get_image_paths(query_folder_path)

# The prompt, its tokenization, the model dtype and the banned token sequences do not
# depend on the query image, so they are prepared once instead of on every iteration.
model.text_tokenizer.padding_side = "left"
inputs = textwrap.dedent("""
    <image>User: Is the first image a mammal? Please answer with Yes or No. GPT:<answer> Yes.<|endofchunk|>
    <image>User: Is the second image a mammal? Please answer with Yes or No. GPT:<answer> No.<|endofchunk|>
    <image>User: Is the next image a mammal? Please answer with Yes or No. GPT:<answer>
""")
# Collapse the readable multi-line layout into the single-line prompt.
inputs = "".join(inputs.split("\n"))
lang_x = model.text_tokenizer([inputs], return_tensors="pt")
lang_x_input_ids = lang_x["input_ids"]
lang_x_attention_mask = lang_x["attention_mask"]
model_dtype = next(model.parameters()).dtype
bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

for query_image_path in query_image_paths:
    query_image = Image.open(query_image_path)
    vision_x = image_processor.preprocess([demo_image_one, demo_image_two, query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
    # Cast the image tensor to the dtype of the model parameters.
    vision_x = vision_x.to(dtype=model_dtype)

    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x_input_ids.to(model.device),
        attention_mask=lang_x_attention_mask.to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )

    # Keep the text after the last <answer> tag, cut at the chunk terminator, drop quotes.
    parsed_output = (
        model.text_tokenizer.decode(generated_text[0]).split("<answer>")[-1].strip().split("<|endofchunk|>")[0].strip().strip('"')
    )

    # Show both demonstrations and the query next to the prompt and the answer.
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10,5))
    for axis, image in zip(axes, (demo_image_one, demo_image_two, query_image)):
        axis.imshow(image)
        axis.axis('off')
    print(inputs)
    print("GPT:", parsed_output)
    plt.show()

## 哺乳類であればA, 哺乳類で無ければBの推論

In [None]:
# In-context demonstrations: a dog (answer "A") and a house (answer "B").
demo_image_one = Image.open("../../../data/test_kosmos/animal_or_not/dog.jpg")
demo_image_two = Image.open("../../../data/test_kosmos/animal_or_not/home.jpg")
query_folder_path = "../../../data/test_kosmos/animal_or_not/"
query_image_paths = get_image_paths(query_folder_path)

# The prompt, its tokenization, the model dtype and the banned token sequences do not
# depend on the query image, so they are prepared once instead of on every iteration.
model.text_tokenizer.padding_side = "left"
inputs = textwrap.dedent("""
    <image>User: If the subject in this image is a mammal, answer A. If not, answer B. GPT:<answer> A. This is a dog. Dogs are mammals.<|endofchunk|>
    <image>User: If the subject in this image is a mammal, answer A. If not, answer B. GPT:<answer> B. This is a house. Houses are not mammals.<|endofchunk|>
    <image>User: If the subject in this image is a mammal, answer A. If not, answer B. GPT:<answer>
""")
# Collapse the readable multi-line layout into the single-line prompt.
inputs = "".join(inputs.split("\n"))
lang_x = model.text_tokenizer([inputs], return_tensors="pt")
lang_x_input_ids = lang_x["input_ids"]
lang_x_attention_mask = lang_x["attention_mask"]
model_dtype = next(model.parameters()).dtype
bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

for query_image_path in query_image_paths:
    query_image = Image.open(query_image_path)
    vision_x = image_processor.preprocess([demo_image_one, demo_image_two, query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
    # Cast the image tensor to the dtype of the model parameters.
    vision_x = vision_x.to(dtype=model_dtype)

    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x_input_ids.to(model.device),
        attention_mask=lang_x_attention_mask.to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )

    # Keep the text after the last <answer> tag, cut at the chunk terminator, drop quotes.
    parsed_output = (
        model.text_tokenizer.decode(generated_text[0]).split("<answer>")[-1].strip().split("<|endofchunk|>")[0].strip().strip('"')
    )

    # Show both demonstrations and the query next to the prompt and the answer.
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10,5))
    for axis, image in zip(axes, (demo_image_one, demo_image_two, query_image)):
        axis.imshow(image)
        axis.axis('off')
    print(inputs)
    print("GPT:", parsed_output)
    plt.show()

## jsonファイルの確認

In [None]:
# ○○_instructions.json

import orjson

# Read the instruction file as bytes, then parse it after the file is closed.
mimicit_path = "../../data/LA/LACR_I2I_instructions.json"
with open(mimicit_path, "rb") as f:
    instruction_bytes = f.read()
dataset = orjson.loads(instruction_bytes)
# dataset = orjson.loads(instruction_bytes)["data"]
dataset

In [None]:
# ○○.json

import ijson

images_path = "../../data/LA/LA.json"
images = {}
with open(images_path, "rb") as f:
    # Stream the top-level key/value pairs of the JSON object into the dict.
    images.update(ijson.kvitems(f, "", use_float=True))
images

In [None]:
# エンコードされた文字列から画像可視化

import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO

# Base64-encoded image strings
str_data1 = images["LA_IMG_000000215677"]
str_data2 = images["LA_IMG_000000429446"]

# Decode to raw bytes
decoded_data1, decoded_data2 = (base64.b64decode(encoded) for encoded in (str_data1, str_data2))

# Turn the decoded bytes into Image objects
image1, image2 = (Image.open(BytesIO(raw)) for raw in (decoded_data1, decoded_data2))

# One row with two subplots, one image each, axes hidden
fig, axarr = plt.subplots(1, 2)
for axis, image in zip(axarr, (image1, image2)):
    axis.imshow(image)
    axis.axis('off')

plt.show()


In [None]:
# ○○_train.json

import orjson

# Read the train config as bytes, then parse it after the file is closed.
train_config_path = "../../data/LA/LACR_I2I_train.json"
with open(train_config_path, "rb") as f:
    train_config_bytes = f.read()
cache_train_config = orjson.loads(train_config_bytes)
cache_train_config

In [None]:
# Materialize the config keys, then report how many there are and show the first ten.
cache_train_list = [*cache_train_config.keys()]
key_count = len(cache_train_list)
print(key_count)
first_ten_keys = cache_train_list[:10]
print(first_ten_keys)

In [None]:
cache_train_config['LACR_I2I_INS_000000296754']

In [None]:
cache_train_config['LACR_I2I_INS_000000222475']

## 自作データセット確認

In [None]:
import orjson

# Read the instruction file as bytes, then parse it after the file is closed.
mimicit_path = "/home/data/MIMIC-IT/VI/train_VI_long_instructions.json"
with open(mimicit_path, "rb") as f:
    instruction_bytes = f.read()
dataset = orjson.loads(instruction_bytes)
dataset

In [6]:
import ijson

images_path = "/home/data/MIMIC-IT/VI/train_VI.json"
images = {}
with open(images_path, "rb") as f:
    # Stream the top-level key/value pairs of the JSON object into the dict.
    images.update(ijson.kvitems(f, "", use_float=True))
# images

In [None]:
import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO

# Base64-encoded image strings
str_data1 = images["metal+metal+image_55"]
str_data2 = images["metal+metal_rust+image_8"]

# Decode (URL-safe alphabet) and open each one as an RGB Image object
image1, image2 = (
    Image.open(BytesIO(base64.urlsafe_b64decode(encoded))).convert("RGB")
    for encoded in (str_data1, str_data2)
)

# One row with two subplots, one image each, axes hidden
fig, axarr = plt.subplots(1, 2)
for axis, image in zip(axarr, (image1, image2)):
    axis.imshow(image)
    axis.axis('off')

plt.show()

In [None]:
import orjson

# Read the train config as bytes, then parse it after the file is closed.
train_config_path = "/home/data/MIMIC-IT/VI/train_VI_pairs25_train.json"
with open(train_config_path, "rb") as f:
    train_config_bytes = f.read()
cache_train_config = orjson.loads(train_config_bytes)
cache_train_config

In [None]:
# Materialize the config keys, then report how many there are and show the first ten.
cache_train_list = [*cache_train_config.keys()]
key_count = len(cache_train_list)
print(key_count)
first_ten_keys = cache_train_list[:10]
print(first_ten_keys)

In [None]:
cache_train_config['book+aged_book+image_2=0']

## 自作データセット重み性能調査

In [7]:
# Fine-tuned weights to evaluate; the commented-out path is the checkpoint before training.
trained_ckpt_path = '../../log/VI_batch128_long_pairs25/final_weights.pt'
# trained_ckpt_path = '../../weights/OTTER-Image-MPT7B/final_weights.pt' # before training

train_ckpt = torch.load(trained_ckpt_path, map_location="cpu")
# The weights may be nested under "model_state_dict"; unwrap them when that key holds a value.
wrapped_state_dict = train_ckpt.get("model_state_dict", None)
train_ckpt = train_ckpt if wrapped_state_dict is None else wrapped_state_dict
# strict=False: keys missing from / unknown to the model are tolerated.
_ = model.load_state_dict(train_ckpt, strict=False)

MVTecAD

In [8]:
import os
def get_image_paths(folder_path):
    """List the image entries (.jpg, .jpeg, .png; any letter case) of folder_path.

    Returns the joined paths ordered by sorted entry name.
    """
    image_extensions = ['.jpg', '.jpeg', '.png']

    def has_image_extension(file_name):
        _, extension = os.path.splitext(file_name)
        return extension.lower() in image_extensions

    entry_names = sorted(os.listdir(folder_path))
    return [os.path.join(folder_path, name) for name in entry_names if has_image_extension(name)]

def write_text_file(file_path, text):
    """Append text followed by a newline to the file at file_path.

    The file is created if it does not exist. It is written as UTF-8 so that
    non-ASCII text (for example model output) is stored the same way on every
    platform instead of depending on the locale's default encoding.
    """
    with open(file_path, mode="a", encoding="utf-8") as f:
        f.write(text + "\n")
        
def generate_list_string(items):
    """Join items into an English enumeration.

    Underscores inside each item are replaced by spaces first. One item is
    returned as is, two items are joined with "and", three or more are
    comma-separated with ", and" before the last one. An empty input yields
    an empty string instead of raising IndexError.
    """
    # Convert underscores to spaces
    items = [item.replace('_', ' ') for item in items]
    if not items:
        return ""
    if len(items) == 1:
        return items[0]
    if len(items) == 2:
        return f"{items[0]} and {items[1]}"
    return ", ".join(items[:-1]) + f", and {items[-1]}"

In [17]:
# long 用
def _build_defect_prompt(category_name, defect_summary, defect_reason, good_first):
    """Build the single-line two-shot prompt: two answered demonstrations plus the open query.

    good_first selects which demonstration answer comes first so that the
    answers stay aligned with the order of the demonstration images.
    """
    question = (
        f"This is an image of {category_name}. "
        f"Does this {category_name} have any defects such as {defect_summary}?"
    )
    good_answer = (
        f"No. This {category_name} does not have any defects such as {defect_summary}, "
        "so it is non-defective."
    )
    defect_answer = f"Yes. This {category_name} has some {defect_reason}, so it is defective."
    answers = (good_answer, defect_answer) if good_first else (defect_answer, good_answer)
    demonstrations = "".join(
        f"<image>User: {question} GPT:<answer> {answer}<|endofchunk|>" for answer in answers
    )
    return f"{demonstrations}<image>User: {question} GPT:<answer>"


def _evaluate_query_folder(log_path, header, ano_type, query_folder_path, context_images, lang_x, expected_answer):
    """Run the model on every query image of one folder and log outputs and accuracy to log_path.

    An output counts as correct when its text before the first "." equals
    expected_answer, ignoring case.
    """
    write_text_file(log_path, header)
    write_text_file(log_path, f'-----{ano_type} start-----')
    write_text_file(log_path, "")

    # The first image in sorted order is skipped: presumably 000.png, which is
    # the one used as in-context demonstration.
    query_image_paths = get_image_paths(query_folder_path)[1:]

    # None of these depend on the query image, so compute them once per folder.
    model_dtype = next(model.parameters()).dtype
    input_ids = lang_x["input_ids"].to(model.device)
    attention_mask = lang_x["attention_mask"].to(model.device)
    bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

    count = 0
    for query_image_path in query_image_paths:
        query_image = Image.open(query_image_path).resize((224, 224)).convert("RGB")
        vision_x = image_processor.preprocess([*context_images, query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
        # Cast the image tensor to the dtype of the model parameters.
        vision_x = vision_x.to(dtype=model_dtype)

        generated_text = model.generate(
            vision_x=vision_x.to(model.device),
            lang_x=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=512,
            num_beams=3,
            no_repeat_ngram_size=3,
            bad_words_ids=bad_words_id,
        )

        # Keep the text after the last <answer> tag, cut at the chunk terminator, drop quotes.
        parsed_output = (
            model.text_tokenizer.decode(generated_text[0]).split("<answer>")[-1].strip().split("<|endofchunk|>")[0].strip().strip('"')
        )

        if parsed_output.split(".")[0].lower() == expected_answer:
            count += 1

        write_text_file(log_path, query_image_path)
        write_text_file(log_path, parsed_output)
        write_text_file(log_path, "")

    total = len(query_image_paths)
    # Guard against folders that contain nothing besides the demonstration image.
    accuracy = (count / total) * 100 if total else 0.0
    acc = f"correct: {count}, total: {total}, acc: {accuracy:.2f}%"
    print(acc)

    write_text_file(log_path, f'-----{ano_type} end-----')
    write_text_file(log_path, acc)


# For the "long" prompt format (the question lists the possible defects).
def test(category, anormaly_reason, anormaly_type, model_name, order):
    """Evaluate two-shot defect detection on one MVTec category and log the results.

    For every (anormaly_type, anormaly_reason) pair, the images
    /home/data/mvtec/{category}/test/good/000.png and
    /home/data/mvtec/{category}/test/{ano_type}/000.png serve as demonstrations.
    The remaining images of the defect folder are expected to be answered
    "yes", the remaining images of the good folder "no". Outputs and accuracy
    go to ./result/{category}/{ano_type}/{model_name}/detective.txt and
    non-detective.txt, which are truncated at the start of each run.

    Uses the notebook globals `model` and `image_processor`.

    Args:
        category: MVTec category folder name; "grid" is called "metal grid" in the prompt.
        anormaly_reason: Defect descriptions used in the prompt, paired by position
            with anormaly_type. All of them are listed in the question.
        anormaly_type: Defect sub-folder names under the category's test folder.
        model_name: Only used to name the result folder.
        order: Truthy puts the good demonstration first and the defective one second;
            falsy reverses both the images and their answers.
    """
    pairs = list(zip(anormaly_type, anormaly_reason))
    if not pairs:
        return

    category__ = ("metal grid" if category == "grid" else category).replace('_', ' ')
    subfolder_string = generate_list_string(anormaly_reason)
    model.text_tokenizer.padding_side = "left"
    first_label, second_label = ("OK", "NG") if order else ("NG", "OK")
    data_root = f"/home/data/mvtec/{category}/test"

    for ano_type, ano_reason in pairs:
        folder_name = f'./result/{category}/{ano_type}/{model_name}'
        os.makedirs(folder_name, exist_ok=True)
        detective_log = f'{folder_name}/detective.txt'
        non_detective_log = f'{folder_name}/non-detective.txt'
        for log_path in (detective_log, non_detective_log):
            # Opening in 'w' mode empties logs left over from a previous run.
            with open(log_path, mode='w'):
                pass

        good_image = Image.open(f"{data_root}/good/000.png").resize((224, 224)).convert("RGB")
        defect_image = Image.open(f"{data_root}/{ano_type}/000.png").resize((224, 224)).convert("RGB")
        context_images = [good_image, defect_image] if order else [defect_image, good_image]

        inputs = _build_defect_prompt(category__, subfolder_string, ano_reason, good_first=order)
        lang_x = model.text_tokenizer([inputs], return_tensors="pt")

        # Queries: defective images
        _evaluate_query_folder(
            detective_log,
            f"context1: {first_label}, context2: {second_label}, query: NG",
            ano_type,
            f"{data_root}/{ano_type}",
            context_images,
            lang_x,
            "yes",
        )

        # Queries: good images
        _evaluate_query_folder(
            non_detective_log,
            f"context1: {first_label}, context2: {second_label}, query: OK",
            ano_type,
            f"{data_root}/good",
            context_images,
            lang_x,
            "no",
        )

スクレイピングに含まれるカテゴリ
1. bottle
2. carpet
3. leather
4. tile
5. wood

スクレイピングに含まれないカテゴリ
1. cable
2. capsule
3. hazelnut
4. pill
5. screw
6. toothbrush
7. transistor
8. zipper

グレー
1. grid (プロンプトはmetal gridにする)
2. metal_nut (metalがある nutはテストに)

In [None]:
# Evaluation settings for the "wood" category; reasons and types are paired by position.
category = "wood"
anormaly_reason = [
    "scratched wood",
    "stained wood",
    "wood with holes",
]
# anormaly_type = ["scratch","color","hole"]
anormaly_type = ["scratch", "liquid", "hole"]
model_name = "VI_batch128_long_pairs25"

test(
    category=category,
    anormaly_reason=anormaly_reason,
    anormaly_type=anormaly_type,
    model_name=model_name,
    order=True,
)

In [None]:
# Evaluation settings for the "hazelnut" category; reasons and types are paired by position.
category = "hazelnut"
anormaly_reason = [
    'cracked hazelnut',
    'scratched hazelnut',
    'hazelnut with holes',
    'hazelnut with white marks',
]
anormaly_type = ["crack", "cut", "hole", "print"]
model_name = "VI_batch128_long_pairs25"

test(
    category=category,
    anormaly_reason=anormaly_reason,
    anormaly_type=anormaly_type,
    model_name=model_name,
    order=True,
)

In [None]:
# Evaluation settings for the "bottle" category; each type list is evaluated in turn.
category = "bottle"
anormaly_reason = [
    'broken bottle',
    'contaminated bottle',
]
model_name = "VI_batch128_long_pairs25"

for anormaly_type in (
    ["broken_large", "contamination"],
    ["broken_small", "contamination"],
):
    test(category, anormaly_reason, anormaly_type, model_name, True)

In [None]:
# Evaluation settings for the "cable" category; each type list is evaluated in turn.
category = "cable"
anormaly_reason = [
    'bent cable',
    'swapped cable',
    'broken cable',
    'missing cable',
    'poked cable',
]
model_name = "VI_batch128_long_pairs25"

for anormaly_type in (
    ["bent_wire", "cable_swap", "cut_inner_insulation", "missing_cable", "poke_insulation"],
    ["bent_wire", "cable_swap", "cut_outer_insulation", "missing_wire", "poke_insulation"],
):
    test(category, anormaly_reason, anormaly_type, model_name, True)

In [None]:
# Evaluation settings for the "capsule" category; reasons and types are paired by position.
category = "capsule"
anormaly_reason = [
    'cracked capsule',
    'misprinted capsule',
    'poked capsule',
    'scratched capsule',
    'damaged capsule',
]
anormaly_type = ["crack", "faulty_imprint", "poke", "scratch", "squeeze"]
model_name = "VI_batch128_long_pairs25"
test(
    category=category,
    anormaly_reason=anormaly_reason,
    anormaly_type=anormaly_type,
    model_name=model_name,
    order=True,
)

In [None]:
# Evaluation settings for the "carpet" category; each type list is evaluated in turn.
category = "carpet"
anormaly_reason = [
    'stained carpet',
    'cut carpet',
    'carpet with holes',
    'contaminated carpet',
]
model_name = "VI_batch128_long_pairs25"

for anormaly_type in (
    ["color", "cut", "hole", "metal_contamination"],
    ["color", "cut", "hole", "thread"],
):
    test(category, anormaly_reason, anormaly_type, model_name, True)

In [None]:
# Evaluation settings for the "grid" category; each type list is evaluated in turn.
category = "grid"
anormaly_reason = [
    'bent metal grid',
    'broken metal grid',
    'contaminated metal grid',
]
model_name = "VI_batch128_long_pairs25"

for anormaly_type in (
    ["bent", "broken", "glue"],
    ["bent", "broken", "metal_contamination"],
    ["bent", "broken", "thread"],
):
    test(category, anormaly_reason, anormaly_type, model_name, True)

In [None]:
# Evaluation settings for the "leather" category; each type list is evaluated in turn.
category = "leather"
anormaly_reason = [
    'stained leather',
    'cut leather',
    'folded leather',
    'poked leather',
]
model_name = "VI_batch128_long_pairs25"

for anormaly_type in (
    ["color", "cut", "fold", "poke"],
    ["glue", "cut", "fold", "poke"],
):
    test(category, anormaly_reason, anormaly_type, model_name, True)

In [None]:
# Evaluation settings for the "metal_nut" category; reasons and types are paired by position.
category = "metal_nut"
anormaly_reason = [
    'bent metal nut',
    'stained metal nut',
    'flipped metal nut',
    'scratched metal nut',
]
anormaly_type = ["bent", "color", "flip", "scratch"]
model_name = "VI_batch128_long_pairs25"
test(
    category=category,
    anormaly_reason=anormaly_reason,
    anormaly_type=anormaly_type,
    model_name=model_name,
    order=True,
)

In [None]:
# Evaluation settings for the "pill" category; each type list is evaluated in turn.
category = "pill"
anormaly_reason = [
    'stained pill',
    'contaminated pill',
    'cracked pill',
    'misprinted pill',
    'scratched pill',
]
model_name = "VI_batch128_long_pairs25"

for anormaly_type in (
    ["color", "contamination", "crack", "faulty_imprint", "scratch"],
    ["pill_type", "contamination", "crack", "faulty_imprint", "scratch"],
):
    test(category, anormaly_reason, anormaly_type, model_name, True)

In [None]:
# Evaluation settings for the "screw" category; each type list is evaluated in turn.
category = "screw"
anormaly_reason = [
    'stripped screw',
    'scratched screw',
    'broken screw',
]
model_name = "VI_batch128_long_pairs25"

for anormaly_type in (
    ["manipulated_front", "scratch_head", "thread_side"],
    ["manipulated_front", "scratch_neck", "thread_top"],
):
    test(category, anormaly_reason, anormaly_type, model_name, True)

In [None]:
# Evaluation settings for the "tile" category; each type list is evaluated in turn.
category = "tile"
anormaly_reason = [
    'cracked tile',
    'contaminated tile',
    'stained tile',
]
model_name = "VI_batch128_long_pairs25"

for anormaly_type in (
    ["crack", "glue_strip", "gray_stroke"],
    ["crack", "glue_strip", "oil"],
    ["crack", "glue_strip", "rough"],
):
    test(category, anormaly_reason, anormaly_type, model_name, True)

In [None]:
# Evaluation settings for the "toothbrush" category (a single defect type).
category = "toothbrush"
anormaly_reason = [
    'damaged toothbrush',
]
anormaly_type = [
    "defective",
]
model_name = "VI_batch128_long_pairs25"
test(
    category=category,
    anormaly_reason=anormaly_reason,
    anormaly_type=anormaly_type,
    model_name=model_name,
    order=True,
)

In [None]:
# Evaluation settings for the "transistor" category; reasons and types are paired by position.
category = "transistor"
anormaly_reason = [
    'bent transistor',
    'cut transistor',
    'damaged transistor',
    'misplaced transistor',
]
anormaly_type = ["bent_lead", "cut_lead", "damaged_case", "misplaced"]
model_name = "VI_batch128_long_pairs25"
test(
    category=category,
    anormaly_reason=anormaly_reason,
    anormaly_type=anormaly_type,
    model_name=model_name,
    order=True,
)

In [None]:
# Evaluation settings for the "zipper" category; each type list is evaluated in turn.
category = "zipper"
anormaly_reason = [
    'broken zipper',
    'torn zipper',
    'damaged zipper',
]
model_name = "VI_batch128_long_pairs25"

for anormaly_type in (
    ["broken_teeth", "fabric_border", "rough"],
    ["split_teeth", "fabric_interior", "squeezed_teeth"],
):
    test(category, anormaly_reason, anormaly_type, model_name, True)

### 非公開データセットで検証

In [None]:
# Evaluation settings for the "rice" category; reasons and types are paired by position.
category = "rice"
anormaly_reason = [
    "brokened rice",
    "rice with milky white",
]
anormaly_type = ["broken", "milky_white"]
model_name = "VI_batch128_long_pairs25"

test(
    category=category,
    anormaly_reason=anormaly_reason,
    anormaly_type=anormaly_type,
    model_name=model_name,
    order=True,
)

学習に使用していないデータで検証 (カテゴリ、欠陥名は既知)

In [None]:
import ijson
import orjson

# Validation images: stream the top-level key/value pairs of the JSON object into a dict.
images_path = "/home/data/MIMIC-IT/VI/val_VI.json"
images = {}
with open(images_path, "rb") as f:
    images.update(ijson.kvitems(f, "", use_float=True))

# Validation train-config file: read the bytes, parse after the file is closed.
train_config_path = "/home/data/MIMIC-IT/VI/val_VI_pairs25_train.json"
with open(train_config_path, "rb") as f:
    train_config_bytes = f.read()
cache_train_config = orjson.loads(train_config_bytes)

# Validation instructions (long variant; the short variant is the commented-out path).
# mimicit_path="/home/data/MIMIC-IT/VI/val_VI_short_instructions.json"
mimicit_path = "/home/data/MIMIC-IT/VI/val_VI_long_instructions.json"
with open(mimicit_path, "rb") as f:
    instruction_bytes = f.read()
instructions = orjson.loads(instruction_bytes)

In [None]:
import random
import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO


# Evaluate the model on a random sample of in-context triplets (two context
# examples plus one query) drawn from the currently loaded split, show every
# triplet with the generated answer, and report first-sentence accuracy.
keys = list(cache_train_config.keys())
random.seed(42)  # fixed seed so the same sample is drawn on every run
random.shuffle(keys)
NUM = 50  # upper bound on the number of triplets to evaluate
sample_keys = keys[:NUM]  # holds fewer than NUM entries when the config is small
count = 0

# Loop-invariant setup, done once instead of on every iteration.
data = instructions["data"]
model.text_tokenizer.padding_side = "left"
model_dtype = next(model.parameters()).dtype
# Token-id sequences handed to `generate` as `bad_words_ids`.
bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

for key in sample_keys:
    context1 = cache_train_config[key][0]
    context2 = cache_train_config[key][1]
    query = key.split("=")[0]  # the query id is the part of the key before "="

    # Decode the base64 payloads (context 1, context 2, query) into RGB images.
    demo_image_one, demo_image_two, query_image = (
        Image.open(BytesIO(base64.urlsafe_b64decode(images[image_id]))).convert("RGB")
        for image_id in (context1, context2, query)
    )

    # Two extra axes are added around the image axis; presumably the layout is
    # (batch, num_images, frames, C, H, W) — confirm against the Otter model.
    vision_x = image_processor.preprocess(
        [demo_image_one, demo_image_two, query_image], return_tensors="pt"
    )["pixel_values"].unsqueeze(1).unsqueeze(0)
    vision_x = vision_x.to(dtype=model_dtype)

    # Build the prompt by plain concatenation. De-indenting an interpolated
    # multi-line template breaks as soon as an instruction or answer contains a
    # newline (the common margin disappears and the indent leaks into the
    # prompt). Newlines are still removed, as before.
    inputs = "".join(
        [
            f'<image>User: {data[context1]["instruction"]} GPT:<answer> {data[context1]["answer"]}<|endofchunk|>',
            f'<image>User: {data[context2]["instruction"]} GPT:<answer> {data[context2]["answer"]}<|endofchunk|>',
            f'<image>User: {data[query]["instruction"]} GPT:<answer>',
        ]
    ).replace("\n", "")
    lang_x = model.text_tokenizer([inputs], return_tensors="pt")

    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )

    # Keep only the text after the last "<answer>" marker and before the next
    # "<|endofchunk|>", without surrounding whitespace or double quotes.
    decoded = model.text_tokenizer.decode(generated_text[0])
    parsed_output = decoded.split("<answer>")[-1].strip().split("<|endofchunk|>")[0].strip().strip('"')

    # A prediction is correct when its first sentence (text before the first
    # ".") equals that of the reference answer, ignoring case.
    expected_head = data[query]["answer"].split(".")[0].lower()
    if expected_head == parsed_output.split(".")[0].lower():
        count += 1

    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 5))
    for axis, picture in zip(axes, (demo_image_one, demo_image_two, query_image)):
        axis.imshow(picture)
        axis.axis("off")
    print("--------------------------------------------------------")
    print(query)
    print(inputs)
    print("GPT:", parsed_output)
    plt.show()

# Divide by the number of triplets actually evaluated, not by NUM, so the
# figure stays correct (and defined) when fewer than NUM keys are available.
total = len(sample_keys)
accuracy = (count / total) * 100 if total else 0.0
print(f"correct: {count}, total: {total}, acc: {accuracy:.2f}%")

学習に使用したデータで検証

In [None]:
import ijson
import orjson

# Load the three JSON files of the training split into module-level names
# (`images`, `cache_train_config`, `instructions`) used by the evaluation cell.

# Image table: top-level key/value pairs are streamed with ijson instead of
# parsing the whole document in one go.
images_path = "/home/data/MIMIC-IT/VI/train_VI.json"
with open(images_path, "rb") as image_file:
    images = dict(ijson.kvitems(image_file, "", use_float=True))

# Pairing config: the evaluation cell reads entries [0] and [1] of each value
# as the two context ids and takes the query id from the key.
train_config_path = "/home/data/MIMIC-IT/VI/train_VI_pairs25_train.json"
with open(train_config_path, "rb") as config_file:
    cache_train_config = orjson.loads(config_file.read())

# Instruction/answer texts; switch to the "short" file to use the short prompts.
# mimicit_path="/home/data/MIMIC-IT/VI/train_VI_short_instructions.json"
mimicit_path = "/home/data/MIMIC-IT/VI/train_VI_long_instructions.json"
with open(mimicit_path, "rb") as instruction_file:
    instructions = orjson.loads(instruction_file.read())

In [None]:
import random
import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO


# Evaluate the model on a random sample of in-context triplets (two context
# examples plus one query) drawn from the currently loaded split, show every
# triplet with the generated answer, and report first-sentence accuracy.
keys = list(cache_train_config.keys())
random.seed(42)  # fixed seed so the same sample is drawn on every run
random.shuffle(keys)
NUM = 50  # upper bound on the number of triplets to evaluate
sample_keys = keys[:NUM]  # holds fewer than NUM entries when the config is small
count = 0

# Loop-invariant setup, done once instead of on every iteration.
data = instructions["data"]
model.text_tokenizer.padding_side = "left"
model_dtype = next(model.parameters()).dtype
# Token-id sequences handed to `generate` as `bad_words_ids`.
bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

for key in sample_keys:
    context1 = cache_train_config[key][0]
    context2 = cache_train_config[key][1]
    query = key.split("=")[0]  # the query id is the part of the key before "="

    # Decode the base64 payloads (context 1, context 2, query) into RGB images.
    demo_image_one, demo_image_two, query_image = (
        Image.open(BytesIO(base64.urlsafe_b64decode(images[image_id]))).convert("RGB")
        for image_id in (context1, context2, query)
    )

    # Two extra axes are added around the image axis; presumably the layout is
    # (batch, num_images, frames, C, H, W) — confirm against the Otter model.
    vision_x = image_processor.preprocess(
        [demo_image_one, demo_image_two, query_image], return_tensors="pt"
    )["pixel_values"].unsqueeze(1).unsqueeze(0)
    vision_x = vision_x.to(dtype=model_dtype)

    # Build the prompt by plain concatenation. De-indenting an interpolated
    # multi-line template breaks as soon as an instruction or answer contains a
    # newline (the common margin disappears and the indent leaks into the
    # prompt). Newlines are still removed, as before.
    inputs = "".join(
        [
            f'<image>User: {data[context1]["instruction"]} GPT:<answer> {data[context1]["answer"]}<|endofchunk|>',
            f'<image>User: {data[context2]["instruction"]} GPT:<answer> {data[context2]["answer"]}<|endofchunk|>',
            f'<image>User: {data[query]["instruction"]} GPT:<answer>',
        ]
    ).replace("\n", "")
    lang_x = model.text_tokenizer([inputs], return_tensors="pt")

    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )

    # Keep only the text after the last "<answer>" marker and before the next
    # "<|endofchunk|>", without surrounding whitespace or double quotes.
    decoded = model.text_tokenizer.decode(generated_text[0])
    parsed_output = decoded.split("<answer>")[-1].strip().split("<|endofchunk|>")[0].strip().strip('"')

    # A prediction is correct when its first sentence (text before the first
    # ".") equals that of the reference answer, ignoring case.
    expected_head = data[query]["answer"].split(".")[0].lower()
    if expected_head == parsed_output.split(".")[0].lower():
        count += 1

    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 5))
    for axis, picture in zip(axes, (demo_image_one, demo_image_two, query_image)):
        axis.imshow(picture)
        axis.axis("off")
    print("--------------------------------------------------------")
    print(query)
    print(inputs)
    print("GPT:", parsed_output)
    plt.show()

# Divide by the number of triplets actually evaluated, not by NUM, so the
# figure stays correct (and defined) when fewer than NUM keys are available.
total = len(sample_keys)
accuracy = (count / total) * 100 if total else 0.0
print(f"correct: {count}, total: {total}, acc: {accuracy:.2f}%")

学習に使用していないデータで検証 (カテゴリは未知、欠陥名はほぼ既知)

In [None]:
import ijson
import orjson

# Load the three JSON files of the test split into module-level names
# (`images`, `cache_train_config`, `instructions`) used by the evaluation cell.

# Image table: top-level key/value pairs are streamed with ijson instead of
# parsing the whole document in one go.
images_path = "/home/data/MIMIC-IT/VI/test_VI.json"
with open(images_path, "rb") as image_file:
    images = dict(ijson.kvitems(image_file, "", use_float=True))

# Pairing config: the evaluation cell reads entries [0] and [1] of each value
# as the two context ids and takes the query id from the key.
train_config_path = "/home/data/MIMIC-IT/VI/test_VI_pairs5_train.json"
with open(train_config_path, "rb") as config_file:
    cache_train_config = orjson.loads(config_file.read())

# Instruction/answer texts; switch to the "short" file to use the short prompts.
# mimicit_path="/home/data/MIMIC-IT/VI/test_VI_short_instructions.json"
mimicit_path = "/home/data/MIMIC-IT/VI/test_VI_long_instructions.json"
with open(mimicit_path, "rb") as instruction_file:
    instructions = orjson.loads(instruction_file.read())

In [None]:
import random
import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO


# Evaluate the model on a random sample of in-context triplets (two context
# examples plus one query) drawn from the currently loaded split, show every
# triplet with the generated answer, and report first-sentence accuracy.
keys = list(cache_train_config.keys())
random.seed(42)  # fixed seed so the same sample is drawn on every run
random.shuffle(keys)
NUM = 50  # upper bound on the number of triplets to evaluate
sample_keys = keys[:NUM]  # holds fewer than NUM entries when the config is small
count = 0

# Loop-invariant setup, done once instead of on every iteration.
data = instructions["data"]
model.text_tokenizer.padding_side = "left"
model_dtype = next(model.parameters()).dtype
# Token-id sequences handed to `generate` as `bad_words_ids`.
bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

for key in sample_keys:
    context1 = cache_train_config[key][0]
    context2 = cache_train_config[key][1]
    query = key.split("=")[0]  # the query id is the part of the key before "="

    # Decode the base64 payloads (context 1, context 2, query) into RGB images.
    demo_image_one, demo_image_two, query_image = (
        Image.open(BytesIO(base64.urlsafe_b64decode(images[image_id]))).convert("RGB")
        for image_id in (context1, context2, query)
    )

    # Two extra axes are added around the image axis; presumably the layout is
    # (batch, num_images, frames, C, H, W) — confirm against the Otter model.
    vision_x = image_processor.preprocess(
        [demo_image_one, demo_image_two, query_image], return_tensors="pt"
    )["pixel_values"].unsqueeze(1).unsqueeze(0)
    vision_x = vision_x.to(dtype=model_dtype)

    # Build the prompt by plain concatenation. De-indenting an interpolated
    # multi-line template breaks as soon as an instruction or answer contains a
    # newline (the common margin disappears and the indent leaks into the
    # prompt). Newlines are still removed, as before.
    inputs = "".join(
        [
            f'<image>User: {data[context1]["instruction"]} GPT:<answer> {data[context1]["answer"]}<|endofchunk|>',
            f'<image>User: {data[context2]["instruction"]} GPT:<answer> {data[context2]["answer"]}<|endofchunk|>',
            f'<image>User: {data[query]["instruction"]} GPT:<answer>',
        ]
    ).replace("\n", "")
    lang_x = model.text_tokenizer([inputs], return_tensors="pt")

    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )

    # Keep only the text after the last "<answer>" marker and before the next
    # "<|endofchunk|>", without surrounding whitespace or double quotes.
    decoded = model.text_tokenizer.decode(generated_text[0])
    parsed_output = decoded.split("<answer>")[-1].strip().split("<|endofchunk|>")[0].strip().strip('"')

    # A prediction is correct when its first sentence (text before the first
    # ".") equals that of the reference answer, ignoring case.
    expected_head = data[query]["answer"].split(".")[0].lower()
    if expected_head == parsed_output.split(".")[0].lower():
        count += 1

    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 5))
    for axis, picture in zip(axes, (demo_image_one, demo_image_two, query_image)):
        axis.imshow(picture)
        axis.axis("off")
    print("--------------------------------------------------------")
    print(query)
    print(inputs)
    print("GPT:", parsed_output)
    plt.show()

# Divide by the number of triplets actually evaluated, not by NUM, so the
# figure stays correct (and defined) when fewer than NUM keys are available.
total = len(sample_keys)
accuracy = (count / total) * 100 if total else 0.0
print(f"correct: {count}, total: {total}, acc: {accuracy:.2f}%")