# Otter Image Demo (In-context Learning)

Here is an example of multi-modal ICL (in-context learning) with 🦦 Otter. We provide two demo images with corresponding instructions and answers, then we ask the model to generate an answer given our instruction. You may change the instruction and see how the model responds.

You can also try our [online demo](https://otter.cliangyu.com/) to see more in-context learning demonstrations.

必要なモジュールは各自インストール<br>
mlflow==2.6.0はバグがあるため使わないこと(https://github.com/mlflow/mlflow/issues/9331) (2023.08.23)

In [None]:
# !pip install --upgrade mlflow==2.5.0 pydantic==1.10.12 deepspeed==0.10.3

In [1]:
import requests
import torch
import transformers
from PIL import Image
import matplotlib.pyplot as plt
import sys
import os
import textwrap

sys.path.append("../..")
from otter.modeling_otter import OtterForConditionalGeneration

  from .autonotebook import tqdm as notebook_tqdm


[2023-10-24 05:30:33,801] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


## 読み込み

In [2]:
# Load the Otter model. The commented-out line pulls the weights from the Hugging Face Hub;
# the active line reads them from a local directory instead.
# model = OtterForConditionalGeneration.from_pretrained("luodian/OTTER-Image-MPT7B", device_map="auto") # Hugging Face
model = OtterForConditionalGeneration.from_pretrained("/home/ueno/Otter/weights/OTTER-Image-MPT7B/", device_map="auto")
# Tokenizer that ships with the loaded model.
tokenizer = model.text_tokenizer
# CLIP image processor with its default settings; used below to turn PIL images into pixel tensors.
image_processor = transformers.CLIPImageProcessor()

Using pad_token, but it is not set yet.


You are using config.init_device='cpu', but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.
The current model version is configured for Otter-Image with max_num_frames set to None.
Total Trainable param: 1.385404 B


The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
Loading checkpoint shards: 100%|██████████| 4/4 [00:14<00:00,  3.50s/it]


### トークンの確認

In [None]:
# Special tokens registered on the tokenizer
model.text_tokenizer.all_special_tokens

In [None]:
# Print a few token ids next to the text the tokenizer decodes them to.
lst = [0, 50277, 50278, 50279, 50280]

for token_id in lst:
    decoded_text = tokenizer.decode(token_id)
    print(f"{token_id}: {decoded_text}")
    

In [None]:
# Show the token -> id mapping. get_vocab is a method, so it must be called;
# the bare attribute only displays the bound-method repr.
model.text_tokenizer.get_vocab()

In [None]:
# List every attribute and method name exposed by the tokenizer object.
attributes_and_methods = dir(model.text_tokenizer)
print(attributes_and_methods)

In [None]:
# Decode a single token id to see which text it maps to.
tokenizer.decode(4374)

In [None]:
# Decode a single token id to see which text it maps to.
tokenizer.decode(2302)

In [None]:
# Ids the tokenizer uses for end-of-sequence and beginning-of-sequence.
print(tokenizer.eos_token_id)
print(tokenizer.bos_token_id)

In [None]:
import torch

# Toy tensor in which -100 is the filler value to be skipped.
arr = torch.tensor([-100, -100, -100, -100, 296, 404, 50277, 0, -100, -100])
# Position of the first element that is not -100.
index = (arr != -100).nonzero(as_tuple=True)[0][0].item()
print("インデックス:", index)

# Read the value at the position found above. The original hard-coded arr[4]
# while its label claimed index 5; using `index` keeps the code and the label in sync.
value = arr[index].item()

# Print the value together with the index it was read from.
print(f"インデックス{index}の値:", value)

In [None]:
# Labels captured from train_one_epoch in instruction_following.py.
# NOTE: two samples are assigned to the same name; the second assignment overwrites
# the first, so only the second list reaches the decoding code at the end of this cell.
sakai = [50343,  6989,    27,   769,   436,  2505,  6266,   247,  3295,    32,
         4496,  4754,   310,   667,    13,    13,  4496,  2085,   247,  4278,
           15,    15,   443,   627,    13,  4496,  3662,   594,    15, 50277,
         5736,    27,  4374,  4374,  9370, 50277,     0, 50396, 50280, 50280]
sakai = [    0, 50278,  6989,    27, 18566,   436,  2460,   452,   667, 12834,
           32,   604,   627,   403,   667, 12834,    13,  4496,  2085,   253,
         7071,  1416,    15,   604,   417,    13,  4496,  1333,  5293,    15,
          443,  5736,    27, 50279,  2302,  8256, 50277,     0, 50280, 50280]
# modeling_mpt.pyのclass MPTForCausalLM(MPTPreTrainedModel):のforward
# lossに入るGTの可視化
# sakai = [-100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
#          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
#          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
#          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
#          2302,    15,   831, 23069,  1057,   417,   452,   667, 12834,   824,
#           347, 31806, 23069,    13, 28290, 23069,    13,   439,  6321,  3612,
#         23069,    13, 13968, 23069,    13,   285, 15070, 23069,    13,   594,
#           352,   310,  1327,    14,   615,   738,   422,    15, 50277,  -100,
#          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
#          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
#          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
#          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  4374,
#            15,   831, 23069,   556,   690, 31806, 23069,    13,   594,   352,
#           310, 22327,    15, 50277,  -100,  -100,  -100,  -100,  -100,  -100,
#          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
#          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
#          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
#          -100,  -100,  -100,  -100,  2302,    15,   831, 23069,  1057,   417,
#           452,   667, 12834,   824,   347, 31806, 23069,    13, 28290, 23069,
#            13,   439,  6321,  3612, 23069,    13, 13968, 23069,    13,   285,
#         15070, 23069,    13,   594,   352,   310,  1327,    14,   615,   738,
#           422,    15, 50277,     0,  -100,  -100,  -100,  -100,  -100,  -100,
#          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
#          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
#          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
#          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100]
# # lossに入るモデルの出力の可視化
# sakai = [50343,  6989,    27,   769,   694,   247,  2460,   273,   247,    15,
#          4496,   352,  3761,   452,   247,  3102,   390,   347, 11385,  2739,
#           390, 11385, 23069,    13,   390,   357,  3612, 23069,    13,   390,
#         23069,    13,   390, 15070, 23069,    32,  4496,  5736,    27,  1621,
#          1621,    13, 50277,   310,  1057,   417,   452,   667,   273,    15,
#           347, 31806, 23069,    13, 28290, 23069,    13,   439,  6321,  3612,
#         23069,    13, 13968, 23069,    13,   390, 15070, 23069,    15,   285,
#           352,   310,   275,    14,   615,   738,   422,    15, 50277,     0,
#          6989,    27,   752,   310,   271,  2460,   273,   247,    15,  1057,
#           436, 23069,   452,   667, 12834,   824,   347, 31806, 23069,    13,
#         28290, 23069,    13,   439,  6321,  3612, 23069,    13, 13968, 23069,
#            13,   285, 15070, 23069,    32,   443,  5736,    27,  1621,  1621,
#            13,   380, 23069,   556,   247, 12834, 12834, 12834, 28290,   352,
#           310,   417,    15, 50277,     0,  6989,    27,   752,   310,   271,
#          2460,   273,   247,    15,   752,   436, 23069,   452,   667, 12834,
#           824,   347, 31806, 23069,    13, 28290, 23069,    13,   439,  6321,
#          3612, 23069,    13, 13968, 23069,    13,   285, 15070, 23069,    32,
#           443,  5736,    27,  6279,  6279,    13,   831, 23069,  1057,   417,
#           452,   667, 12834,   824,   347, 31806, 23069,    13, 28290, 23069,
#            13,   439,  6321,  3612, 23069,    13, 13968, 23069,    13,   285,
#         15070, 23069,    13,   594,   352,   310,  1327,    14,   615,   738,
#           422,    15, 50277,     0, 50396, 50280, 50280, 50280, 50280, 50280,
#         50280, 50280, 50280, 50280, 50280, 50280, 50280, 50280, 50280, 50280,
#         50280, 50280, 50280, 50280, 50280, 50280, 50280, 50280, 50280, 50280,
#         50280, 50280, 50280, 50280, 50280, 50280, 50280, 50280, 50280, 50280,
#         50280, 50280, 50280, 50280, 50280, 50280, 50280, 50280, 50280, 50280]
# sakai = [50280 if x == -100 else x for x in sakai]
# Show the raw ids, then the text obtained by decoding them one id at a time
# and concatenating the pieces on a single line.
print(sakai)
decoded_pieces = [tokenizer.decode(token_id) for token_id in sakai]
print("".join(decoded_pieces))

## 推論

In [None]:
# Download the two in-context demonstration images and the query image.
image_urls = (
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    "http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
    "http://images.cocodataset.org/test-stuff2017/000000028352.jpg",
)
demo_image_one, demo_image_two, query_image = [
    Image.open(requests.get(url, stream=True).raw) for url in image_urls
]
all_images = [demo_image_one, demo_image_two, query_image]

# Preprocess the three images together, then add two extra axes (see the shape printed below).
pixel_values = image_processor.preprocess(all_images, return_tensors="pt")["pixel_values"]
vision_x = pixel_values.unsqueeze(1).unsqueeze(0)

# Two answered examples followed by the open question for the query image.
model.text_tokenizer.padding_side = "left"
prompt = "<image>User: a photo of GPT:<answer> two cats sleeping.<|endofchunk|><image>User: a photo of GPT:<answer> a bathroom sink.<|endofchunk|><image>User: a photo of GPT:<answer>"
lang_x = model.text_tokenizer([prompt], return_tensors="pt")

print(vision_x.shape) # torch.Size([1, 3, 1, 3, 224, 224]) shape (B, num_imgs, Frames=1, C, H, W)

# Show the three images side by side without axes.
fig, axes = plt.subplots(1, 3, figsize=(10, 4))
for axis, picture in zip(axes, all_images):
    axis.imshow(picture)
    axis.axis('off')
plt.tight_layout()
plt.show()

# Cast the pixel tensor to the dtype of the model parameters.
model_dtype = next(model.parameters()).dtype
vision_x = vision_x.to(dtype=model_dtype)
lang_x_input_ids = lang_x["input_ids"]
lang_x_attention_mask = lang_x["attention_mask"]

# Token sequences the generator is not allowed to emit.
bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids
generation_kwargs = dict(
    max_new_tokens=512,
    num_beams=3,
    no_repeat_ngram_size=3,
    bad_words_ids=bad_words_id,
)
generated_text = model.generate(
    vision_x=vision_x.to(model.device),
    lang_x=lang_x_input_ids.to(model.device),
    attention_mask=lang_x_attention_mask.to(model.device),
    **generation_kwargs,
)

# Keep the text after the last "<answer>" and before the next "<|endofchunk|>",
# trimmed of surrounding whitespace and double quotes.
decoded = model.text_tokenizer.decode(generated_text[0])
after_answer = decoded.split("<answer>")[-1].strip()
parsed_output = after_answer.split("<|endofchunk|>")[0].strip().strip('"')

print("Generated text: ", parsed_output)

## MIMIC-ITのjsonファイルの確認

In [None]:
# Inspect a "<name>_instructions.json" file

import orjson

# Read the whole file as bytes and parse it; the last line displays the parsed object.
mimicit_path="../../data/LA/LACR_I2I_instructions.json"
with open(mimicit_path, "rb") as f:
    dataset = orjson.loads(f.read())
    # dataset = orjson.loads(f.read())["data"]
dataset

In [None]:
# Inspect the image store ("<name>.json")

import ijson

images_path = "../../data/LA/LA.json"
with open(images_path, "rb") as image_stream:
    # Collect every top-level key/value pair yielded by ijson into a dict.
    images = dict(ijson.kvitems(image_stream, "", use_float=True))
images

In [None]:
# Visualize images stored as base64-encoded strings

import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO

# Base64-encoded strings for the two images to preview
str_data1 = images["LA_IMG_000000215677"]
str_data2 = images["LA_IMG_000000429446"]

# Decode each string to bytes and open the bytes as a PIL image
image1, image2 = [
    Image.open(BytesIO(base64.b64decode(encoded)))
    for encoded in (str_data1, str_data2)
]

# One row, two columns: show both images with the axes hidden
fig, axarr = plt.subplots(1, 2)
for axis, picture in zip(axarr, (image1, image2)):
    axis.imshow(picture)
    axis.axis('off')

plt.show()


In [None]:
# Inspect a "<name>_train.json" file

import orjson

# Read the whole file as bytes and parse it; the last line displays the parsed object.
train_config_path="../../data/LA/LACR_I2I_train.json"
with open(train_config_path, "rb") as f:
    cache_train_config = orjson.loads(f.read())
cache_train_config

In [None]:
# Number of entries in the train config and a peek at the first ten keys.
cache_train_list = list(cache_train_config)
print(len(cache_train_list))
print(cache_train_list[:10])

In [None]:
# Display the train-config entry stored under this key.
cache_train_config['LACR_I2I_INS_000000296754']

In [None]:
# Display the train-config entry stored under this key.
cache_train_config['LACR_I2I_INS_000000222475']

## 自作データセット確認

In [None]:
import orjson

# Parse the instructions file of the custom dataset; the last line displays the parsed object.
mimicit_path="/data/dataset/MIMIC-IT/VI/train_VI_long_instructions.json"
with open(mimicit_path, "rb") as f:
    dataset = orjson.loads(f.read())
dataset

In [None]:
import ijson

images_path = "/data/dataset/MIMIC-IT/VI/train_VI.json"
with open(images_path, "rb") as image_stream:
    # Collect every top-level key/value pair yielded by ijson into a dict.
    images = dict(ijson.kvitems(image_stream, "", use_float=True))
# images

In [None]:
import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO

# Base64-encoded strings for the two images to preview
str_data1 = images["metal+metal+image_55"]
str_data2 = images["metal+metal_rust+image_8"]

# URL-safe base64 decode, open as a PIL image, convert to RGB
image1, image2 = [
    Image.open(BytesIO(base64.urlsafe_b64decode(encoded))).convert("RGB")
    for encoded in (str_data1, str_data2)
]

# One row, two columns: show both images with the axes hidden
fig, axarr = plt.subplots(1, 2)
for axis, picture in zip(axarr, (image1, image2)):
    axis.imshow(picture)
    axis.axis('off')

plt.show()

In [None]:
import orjson

# Parse the train config of the custom dataset; the last line displays the parsed object.
train_config_path="/data/dataset/MIMIC-IT/VI/train_VI_pairs25_train.json"
with open(train_config_path, "rb") as f:
    cache_train_config = orjson.loads(f.read())
cache_train_config

In [None]:
# Number of entries in the train config and a peek at the first ten keys.
cache_train_list = list(cache_train_config)
print(len(cache_train_list))
print(cache_train_list[:10])

In [None]:
# Display the train-config entry stored under this key.
cache_train_config['book+aged_book+image_2=0']

## 自作データセット重み性能調査(欠陥名当て)

In [None]:
model_name = "context_true"
trained_ckpt_path = f'../../log/{model_name}/final_weights.pt'

# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
train_ckpt = torch.load(trained_ckpt_path, map_location="cpu")
# Unwrap the weights when the checkpoint stores them under "model_state_dict".
if train_ckpt.get("model_state_dict", None) is not None:
    train_ckpt = train_ckpt["model_state_dict"]

# strict=False accepts checkpoints that do not cover every parameter, but it also hides
# key mismatches. Report what was (not) matched instead of discarding the result.
load_result = model.load_state_dict(train_ckpt, strict=False)
print(
    f"loaded {trained_ckpt_path}: "
    f"{len(load_result.missing_keys)} missing keys, "
    f"{len(load_result.unexpected_keys)} unexpected keys"
)
if load_result.unexpected_keys:
    # Unexpected keys mean part of the checkpoint was ignored, e.g. because of a name mismatch.
    print("unexpected keys (first 10):", load_result.unexpected_keys[:10])

trainデータで検証

In [None]:
# Load the data
import ijson
import orjson

# Image store: collect every top-level key/value pair yielded by ijson into a dict.
images_path = "/data/dataset/MIMIC-IT/AC/AC_train.json"
with open(images_path, "rb") as image_stream:
    images = dict(ijson.kvitems(image_stream, "", use_float=True))

# Train config
train_config_path = "/data/dataset/MIMIC-IT/AC/AC_train_train.json"
with open(train_config_path, "rb") as config_stream:
    cache_train_config = orjson.loads(config_stream.read())

# Instructions
mimicit_path = "/data/dataset/MIMIC-IT/AC/AC_train_instructions.json"
with open(mimicit_path, "rb") as instruction_stream:
    instructions = orjson.loads(instruction_stream.read())

In [None]:
# Accuracy on a seeded random sample of the loaded split
from IPython.display import clear_output
import random
import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO

keys = list(cache_train_config.keys())
random.seed(42)
random.shuffle(keys)
count = 0
NUM = 10
# May hold fewer than NUM entries when the split is small; the accuracy below
# is computed over the number of samples actually evaluated.
sample_keys = keys[:NUM]

# Loop-invariant setup, done once instead of once per image.
model.text_tokenizer.padding_side = "left"
model_dtype = next(model.parameters()).dtype
bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

for i, key in enumerate(sample_keys):
    print(i)
    # The part of the config key before "=" is the id used in `images` and `instructions`.
    query = key.split('=')[0]
    sample = instructions["data"][query]

    # Decode the base64 string into an RGB PIL image.
    query_image = Image.open(BytesIO(base64.urlsafe_b64decode(images[query]))).convert("RGB")
    vision_x = image_processor.preprocess([query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
    vision_x = vision_x.to(dtype=model_dtype)

    # Single-line prompt; any newline coming from the instruction text is removed.
    inputs = "".join(f'<image>User: {sample["instruction"]} GPT:<answer>'.split("\n"))
    lang_x = model.text_tokenizer([inputs], return_tensors="pt")

    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )

    # Keep the text after the last "<answer>" and before "<|endofchunk|>".
    decoded = model.text_tokenizer.decode(generated_text[0])
    parsed_output = decoded.split("<answer>")[-1].strip().split("<|endofchunk|>")[0].strip().strip('"')

    # Case-insensitive exact match against the reference answer.
    if sample["answer"].lower() == parsed_output.lower():
        count += 1
    clear_output(wait=True)

total = len(sample_keys)
acc = (count / total) * 100 if total else float("nan")
print(f"correct: {count}, total: {total}, acc: {acc:.2f}%")

In [None]:
# Visualize predictions for a seeded random sample of the loaded split
import random
import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO


keys = list(cache_train_config.keys())
random.seed(42)
random.shuffle(keys)
NUM = 10
count = 0
# May hold fewer than NUM entries when the split is small; the accuracy below
# is computed over the number of samples actually evaluated.
sample_keys = keys[:NUM]

# Loop-invariant setup, done once instead of once per image.
model.text_tokenizer.padding_side = "left"
model_dtype = next(model.parameters()).dtype
bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

for key in sample_keys:
    # The part of the config key before "=" is the id used in `images` and `instructions`.
    query = key.split('=')[0]
    sample = instructions["data"][query]

    # Decode the base64 string of the query image into an RGB PIL image.
    query_image = Image.open(BytesIO(base64.urlsafe_b64decode(images[query]))).convert("RGB")
    vision_x = image_processor.preprocess([query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
    vision_x = vision_x.to(dtype=model_dtype)

    # Single-line prompt; any newline coming from the instruction text is removed.
    inputs = "".join(f'<image>User: {sample["instruction"]} GPT:<answer>'.split("\n"))
    lang_x = model.text_tokenizer([inputs], return_tensors="pt")

    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )

    # Keep the text after the last "<answer>" and before "<|endofchunk|>".
    decoded = model.text_tokenizer.decode(generated_text[0])
    parsed_output = decoded.split("<answer>")[-1].strip().split("<|endofchunk|>")[0].strip().strip('"')

    # Case-insensitive exact match against the reference answer.
    if sample["answer"].lower() == parsed_output.lower():
        count += 1

    # Show the image together with its id, the prompt and the model answer.
    fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(10,5))
    axes.imshow(query_image)
    axes.axis('off')
    print("--------------------------------------------------------")
    print(query)
    print(inputs)
    print("GPT:", parsed_output)
    plt.show()

total = len(sample_keys)
acc = (count / total) * 100 if total else float("nan")
print(f"correct: {count}, total: {total}, acc: {acc:.2f}%")

valデータで検証

In [None]:
# Load the data
import ijson
import orjson

# Image store: collect every top-level key/value pair yielded by ijson into a dict.
images_path = "/data/dataset/MIMIC-IT/AC/AC_val.json"
with open(images_path, "rb") as image_stream:
    images = dict(ijson.kvitems(image_stream, "", use_float=True))

# Config listing the validation samples
train_config_path = "/data/dataset/MIMIC-IT/AC/AC_val_train.json"
with open(train_config_path, "rb") as config_stream:
    cache_train_config = orjson.loads(config_stream.read())

# Instructions
mimicit_path = "/data/dataset/MIMIC-IT/AC/AC_val_instructions.json"
with open(mimicit_path, "rb") as instruction_stream:
    instructions = orjson.loads(instruction_stream.read())

In [None]:
# Accuracy on a seeded random sample of the loaded split
from IPython.display import clear_output
import random
import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO

keys = list(cache_train_config.keys())
random.seed(42)
random.shuffle(keys)
count = 0
NUM = 1000
# May hold fewer than NUM entries when the split is small; the accuracy below
# is computed over the number of samples actually evaluated.
sample_keys = keys[:NUM]

# Loop-invariant setup, done once instead of once per image.
model.text_tokenizer.padding_side = "left"
model_dtype = next(model.parameters()).dtype
bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

for i, key in enumerate(sample_keys):
    print(i)
    # The part of the config key before "=" is the id used in `images` and `instructions`.
    query = key.split('=')[0]
    sample = instructions["data"][query]

    # Decode the base64 string into an RGB PIL image.
    query_image = Image.open(BytesIO(base64.urlsafe_b64decode(images[query]))).convert("RGB")
    vision_x = image_processor.preprocess([query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
    vision_x = vision_x.to(dtype=model_dtype)

    # Single-line prompt; any newline coming from the instruction text is removed.
    inputs = "".join(f'<image>User: {sample["instruction"]} GPT:<answer>'.split("\n"))
    lang_x = model.text_tokenizer([inputs], return_tensors="pt")

    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )

    # Keep the text after the last "<answer>" and before "<|endofchunk|>".
    decoded = model.text_tokenizer.decode(generated_text[0])
    parsed_output = decoded.split("<answer>")[-1].strip().split("<|endofchunk|>")[0].strip().strip('"')

    # Case-insensitive exact match against the reference answer.
    if sample["answer"].lower() == parsed_output.lower():
        count += 1
    clear_output(wait=True)

total = len(sample_keys)
acc = (count / total) * 100 if total else float("nan")
print(f"correct: {count}, total: {total}, acc: {acc:.2f}%")

In [None]:
# Visualize predictions for a seeded random sample of the loaded split
import random
import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO


keys = list(cache_train_config.keys())
random.seed(42)
random.shuffle(keys)
NUM = 10
count = 0
# May hold fewer than NUM entries when the split is small; the accuracy below
# is computed over the number of samples actually evaluated.
sample_keys = keys[:NUM]

# Loop-invariant setup, done once instead of once per image.
model.text_tokenizer.padding_side = "left"
model_dtype = next(model.parameters()).dtype
bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

for key in sample_keys:
    # The part of the config key before "=" is the id used in `images` and `instructions`.
    query = key.split('=')[0]
    sample = instructions["data"][query]

    # Decode the base64 string of the query image into an RGB PIL image.
    query_image = Image.open(BytesIO(base64.urlsafe_b64decode(images[query]))).convert("RGB")
    vision_x = image_processor.preprocess([query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
    vision_x = vision_x.to(dtype=model_dtype)

    # Single-line prompt; any newline coming from the instruction text is removed.
    inputs = "".join(f'<image>User: {sample["instruction"]} GPT:<answer>'.split("\n"))
    lang_x = model.text_tokenizer([inputs], return_tensors="pt")

    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )

    # Keep the text after the last "<answer>" and before "<|endofchunk|>".
    decoded = model.text_tokenizer.decode(generated_text[0])
    parsed_output = decoded.split("<answer>")[-1].strip().split("<|endofchunk|>")[0].strip().strip('"')

    # Case-insensitive exact match against the reference answer.
    if sample["answer"].lower() == parsed_output.lower():
        count += 1

    # Show the image together with its id, the prompt and the model answer.
    fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(10,5))
    axes.imshow(query_image)
    axes.axis('off')
    print("--------------------------------------------------------")
    print(query)
    print(inputs)
    print("GPT:", parsed_output)
    plt.show()

total = len(sample_keys)
acc = (count / total) * 100 if total else float("nan")
print(f"correct: {count}, total: {total}, acc: {acc:.2f}%")

testデータで検証

In [None]:
# Load the data
import ijson
import orjson

# Image store: collect every top-level key/value pair yielded by ijson into a dict.
images_path = "../../data/VI_test_jsons/VI_test.json"
with open(images_path, "rb") as image_stream:
    images = dict(ijson.kvitems(image_stream, "", use_float=True))

# Config listing the test samples
train_config_path = "../../data/VI_test_jsons/VI_test_train.json"
with open(train_config_path, "rb") as config_stream:
    cache_train_config = orjson.loads(config_stream.read())

# Instructions
mimicit_path = "../../data/VI_test_jsons/VI_test_instructions.json"
with open(mimicit_path, "rb") as instruction_stream:
    instructions = orjson.loads(instruction_stream.read())

In [None]:
# Accuracy on a seeded random sample of the test split, reported separately for the two groups
from IPython.display import clear_output
import random
import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO

keys = list(cache_train_config.keys())
random.seed(42)
random.shuffle(keys)
total_ok = 0
total_ng = 0
count_ok = 0
count_ng = 0
miss_ok = []  # ids of misclassified samples in the "None" group
miss_ng = []  # ids of misclassified samples in the other group
NUM = 1000

# Loop-invariant setup, done once instead of once per image.
model.text_tokenizer.padding_side = "left"
model_dtype = next(model.parameters()).dtype
bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

for i, key in enumerate(keys[:NUM]):
    print(i)
    # The part of the config key before "=" is the id used in `images` and `instructions`.
    query = key.split('=')[0]
    sample = instructions["data"][query]

    # Decode the base64 string into an RGB PIL image.
    query_image = Image.open(BytesIO(base64.urlsafe_b64decode(images[query]))).convert("RGB")
    vision_x = image_processor.preprocess([query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
    vision_x = vision_x.to(dtype=model_dtype)

    # Single-line prompt; any newline coming from the instruction text is removed.
    inputs = "".join(f'<image>User: {sample["instruction"]} GPT:<answer>'.split("\n"))
    lang_x = model.text_tokenizer([inputs], return_tensors="pt")

    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )

    # Keep the text after the last "<answer>" and before "<|endofchunk|>".
    decoded = model.text_tokenizer.decode(generated_text[0])
    parsed_output = decoded.split("<answer>")[-1].strip().split("<|endofchunk|>")[0].strip().strip('"')

    # Case-insensitive exact match against the reference answer.
    is_correct = sample["answer"].lower() == parsed_output.lower()

    # Ids whose second "+"-separated field is "None" form the OK group; all others the NG group.
    if query.split('+')[1] == "None":
        total_ok += 1
        if is_correct:
            count_ok += 1
        else:
            miss_ok.append(query)
    else:
        total_ng += 1
        if is_correct:
            count_ng += 1
        else:
            miss_ng.append(query)
    clear_output(wait=True)

# A group may be empty in the sample; avoid dividing by zero in that case.
acc_ok = (count_ok / total_ok) * 100 if total_ok else float("nan")
acc_ng = (count_ng / total_ng) * 100 if total_ng else float("nan")
print(f"ok correct: {count_ok}, total: {total_ok}, acc: {acc_ok:.2f}%")
print(f"ng correct: {count_ng}, total: {total_ng}, acc: {acc_ng:.2f}%")

In [None]:
# Visualize predictions for a slice of the test samples whose id is not in the "None" group
import random
import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO


keys = list(cache_train_config.keys())
random.seed(42)
random.shuffle(keys)
NUM = 400
count = 0
# Keep only keys whose second "+"-separated field is not "None".
keys = [s for s in keys if s.split('+')[1] != "None"]
# Entries 300..NUM-1 of the filtered list; empty when fewer than 301 keys remain.
sample_keys = keys[300:NUM]

# Loop-invariant setup, done once instead of once per image.
model.text_tokenizer.padding_side = "left"
model_dtype = next(model.parameters()).dtype
bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

for key in sample_keys:
    # The part of the config key before "=" is the id used in `images` and `instructions`.
    query = key.split('=')[0]
    sample = instructions["data"][query]

    # Decode the base64 string of the query image into an RGB PIL image.
    query_image = Image.open(BytesIO(base64.urlsafe_b64decode(images[query]))).convert("RGB")
    vision_x = image_processor.preprocess([query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
    vision_x = vision_x.to(dtype=model_dtype)

    # Single-line prompt; any newline coming from the instruction text is removed.
    inputs = "".join(f'<image>User: {sample["instruction"]} GPT:<answer>'.split("\n"))
    lang_x = model.text_tokenizer([inputs], return_tensors="pt")

    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )

    # Keep the text after the last "<answer>" and before "<|endofchunk|>".
    decoded = model.text_tokenizer.decode(generated_text[0])
    parsed_output = decoded.split("<answer>")[-1].strip().split("<|endofchunk|>")[0].strip().strip('"')

    # Case-insensitive exact match against the reference answer.
    if sample["answer"].lower() == parsed_output.lower():
        count += 1

    # Show the image together with its id, the prompt and the model answer.
    fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(10,5))
    axes.imshow(query_image)
    axes.axis('off')
    print("--------------------------------------------------------")
    print(query)
    print(inputs)
    print("GPT:", parsed_output)
    plt.show()

# Use the slice length rather than the loop variable, which is undefined (or left over
# from an earlier cell) when the slice is empty.
total = len(sample_keys)
acc = (count / total) * 100 if total else float("nan")
print(f"correct: {count}, total: {total}, acc: {acc:.2f}%")

### MVTecAD(欠陥名当て)

In [None]:
import os
def get_image_paths(folder_path):
    """Return the sorted paths of the image entries directly inside ``folder_path``.

    An entry counts as an image when its extension, compared case-insensitively,
    is ``.jpg``, ``.jpeg`` or ``.png``. Sub-folders are not descended into.
    """
    image_extensions = ['.jpg', '.jpeg', '.png']  # accepted image extensions
    image_paths = []
    # Walk the folder entries in sorted name order so the result is deterministic.
    for entry in sorted(os.listdir(folder_path)):
        _, extension = os.path.splitext(entry)
        if extension.lower() in image_extensions:
            image_paths.append(os.path.join(folder_path, entry))
    return image_paths

def write_text_file(file_path, text):
    """Append ``text`` followed by a newline to the file at ``file_path``.

    The file is opened in append mode, so it is created when missing and
    existing content is kept.
    """
    with open(file_path, "a") as handle:
        handle.write(text + "\n")
        
def generate_list_string(items):
    """Join ``items`` into an English enumeration, replacing underscores with spaces.

    Returns ``""`` for an empty input, the item itself for one item,
    ``"a and b"`` for two items, and ``"a, b, and c"`` for three or more.
    """
    # Replace underscores with spaces
    items = [item.replace('_', ' ') for item in items]
    if not items:
        # The original indexed items[-1] here and raised IndexError on an empty list.
        return ""
    if len(items) == 1:
        return items[0]
    if len(items) == 2:
        return f"{items[0]} and {items[1]}"
    return ", ".join(items[:-1]) + f", and {items[-1]}"

In [None]:
# クエリのみ使用
from IPython.display import clear_output

def test(folder, sub_folder, GTs, model_name):
    """Evaluate defect naming on one MVTec AD category using only the query image.

    For every (sub-folder, expected label) pair, each image under
    ``/data/dataset/mvtec/{folder}/test/{sub}`` is passed to the module-level
    ``model`` together with a fixed prompt. A prediction counts as correct when
    the parsed answer equals the expected label, ignoring case. Per-image
    outputs and the accuracy summary are appended to
    ``./result/{folder}/{sub}/{model_name}/AC.txt``; the per-sub-folder
    summaries are printed at the end.

    Relies on the module-level ``model``, ``image_processor``,
    ``get_image_paths`` and ``write_text_file``.

    Args:
        folder: Category directory name under ``/data/dataset/mvtec``.
        sub_folder: Test sub-folder names to evaluate.
        GTs: Expected answer for each entry of ``sub_folder`` (same length).
        model_name: Directory name used for this run's result files.

    Raises:
        ValueError: If ``sub_folder`` and ``GTs`` differ in length.
    """
    if len(sub_folder) != len(GTs):
        # zip() would silently drop the unmatched tail, so some sub-folders
        # would be skipped or scored against the wrong label.
        raise ValueError(
            f"sub_folder has {len(sub_folder)} entries but GTs has {len(GTs)}; "
            "every sub-folder needs exactly one expected label"
        )

    # The prompt and generation settings are the same for every image, so
    # they are prepared once instead of once per image.
    model.text_tokenizer.padding_side = "left"
    inputs = "<image>User: What are the defects present in this image? If there are none, please say None. GPT:<answer>"
    lang_x = model.text_tokenizer([inputs], return_tensors="pt")
    lang_x_input_ids = lang_x["input_ids"].to(model.device)
    lang_x_attention_mask = lang_x["attention_mask"].to(model.device)
    bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids
    # Image tensors are cast to the dtype of the model's parameters.
    model_dtype = next(model.parameters()).dtype

    acc = []
    for sub, gt in zip(sub_folder, GTs):
        folder_name = f'./result/{folder}/{sub}/{model_name}'
        os.makedirs(folder_name, exist_ok=True)
        log_path = f'{folder_name}/AC.txt'
        # Truncate any log left over from a previous run.
        with open(log_path, mode='w'):
            pass

        write_text_file(log_path, f"{sub} --> {gt}")
        write_text_file(log_path, f'-----{sub} start-----')
        write_text_file(log_path, "")

        query_image_paths = get_image_paths(f"/data/dataset/mvtec/{folder}/test/{sub}")
        count = 0
        for query_image_path in query_image_paths:
            query_image = Image.open(query_image_path).resize((224, 224)).convert("RGB")
            vision_x = image_processor.preprocess([query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
            vision_x = vision_x.to(dtype=model_dtype)

            generated_text = model.generate(
                vision_x=vision_x.to(model.device),
                lang_x=lang_x_input_ids,
                attention_mask=lang_x_attention_mask,
                max_new_tokens=512,
                num_beams=3,
                no_repeat_ngram_size=3,
                bad_words_ids=bad_words_id,
            )

            # Keep the text after the last <answer> tag and before
            # <|endofchunk|>, then strip whitespace and surrounding quotes.
            decoded = model.text_tokenizer.decode(generated_text[0])
            parsed_output = decoded.split("<answer>")[-1].strip().split("<|endofchunk|>")[0].strip().strip('"')

            if parsed_output.lower() == gt.lower():
                count += 1

            write_text_file(log_path, query_image_path)
            write_text_file(log_path, parsed_output)
            write_text_file(log_path, "")
            clear_output(wait=True)

        total = len(query_image_paths)
        # An empty sub-folder used to raise ZeroDivisionError here.
        percent = (count / total) * 100 if total else 0.0
        accuracy = f"correct: {count}, total: {total}, acc: {percent:.2f}%"
        acc.append((sub, accuracy))

        write_text_file(log_path, f'-----{sub} end-----')
        write_text_file(log_path, accuracy)

    for a in acc:
        print(a)
        

In [None]:
# MVTec AD "bottle": each test sub-folder paired with the label passed to
# test() as its expected answer.
folder = "bottle"
label_by_subfolder = {
    "good": 'None',
    "broken_large": 'broken',
    "broken_small": 'broken',
    "contamination": 'contamination',
}
sub_folder = list(label_by_subfolder)
GTs = list(label_by_subfolder.values())

test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "cable": each test sub-folder paired with the label passed to
# test() as its expected answer.
folder = "cable"
label_by_subfolder = {
    "good": "None",
    "bent_wire": 'bent',
    # NOTE(review): "swapp" looks like a typo for "swap"/"swapped" - confirm
    # against the labels used in training before changing it.
    "cable_swap": 'swapp',
    "cut_inner_insulation": 'crack',
    "cut_outer_insulation": 'crack',
    "missing_cable": 'missing',
    "missing_wire": 'missing',
    "poke_insulation": 'hole',
}
sub_folder = list(label_by_subfolder)
GTs = list(label_by_subfolder.values())

test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "capsule": each test sub-folder paired with the label passed to
# test() as its expected answer.
folder = "capsule"
label_by_subfolder = {
    "good": "None",
    "crack": 'crack',
    "faulty_imprint": 'misprint',
    "poke": 'hole',
    "scratch": 'scratch',
    "squeeze": 'misshapen',
}
sub_folder = list(label_by_subfolder)
GTs = list(label_by_subfolder.values())

test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "carpet": each test sub-folder paired with the label passed to
# test() as its expected answer.
folder = "carpet"
label_by_subfolder = {
    "good": "None",
    "color": 'stain',
    "cut": 'cut',
    "hole": 'hole',
    "metal_contamination": 'contamination',
    "thread": 'contamination',
}
sub_folder = list(label_by_subfolder)
GTs = list(label_by_subfolder.values())

test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "grid": each test sub-folder paired with the label passed to
# test() as its expected answer.
folder = "grid"
label_by_subfolder = {
    "good": "None",
    "bent": "bent",
    "broken": "broken",
    "glue": "contamination",
    "metal_contamination": "contamination",
    "thread": "contamination",
}
sub_folder = list(label_by_subfolder)
GTs = list(label_by_subfolder.values())

test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "hazelnut": each test sub-folder paired with the label passed to
# test() as its expected answer.
folder = "hazelnut"
label_by_subfolder = {
    "good": 'None',
    "crack": 'crack',
    "cut": 'scratch',
    "hole": 'hole',
    "print": 'misprint',
}
sub_folder = list(label_by_subfolder)
GTs = list(label_by_subfolder.values())

test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "leather": each test sub-folder paired with the label passed to
# test() as its expected answer.
#
# The original lists had six sub-folders but only five labels, so "glue" was
# scored against 'poked leather' and "poke" was never evaluated. Pairing the
# entries explicitly keeps the two lists the same length.
folder = "leather"
label_by_subfolder = {
    "good": 'None',
    "color": 'stain',
    "cut": 'scratch',
    "fold": 'wrinkle',
    # NOTE(review): this label was missing; 'contamination' is chosen by
    # analogy with the glue defects of the grid and tile categories - confirm.
    "glue": 'contamination',
    "poke": 'poked leather',
}
sub_folder = list(label_by_subfolder)
GTs = list(label_by_subfolder.values())

test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "metal_nut": each test sub-folder paired with the label passed to
# test() as its expected answer.
folder = "metal_nut"
label_by_subfolder = {
    "good": 'None',
    "bent": 'bent',
    "color": 'stain',
    "flip": 'flip',
    "scratch": 'scratch',
}
sub_folder = list(label_by_subfolder)
GTs = list(label_by_subfolder.values())

test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "pill": each test sub-folder paired with the label passed to
# test() as its expected answer.
folder = "pill"
label_by_subfolder = {
    "good": 'None',
    "color": 'stain',
    "contamination": 'contamination',
    "crack": 'crack',
    "faulty_imprint": 'misprint',
    "scratch": 'scratch',
    "pill_type": 'stain',
}
sub_folder = list(label_by_subfolder)
GTs = list(label_by_subfolder.values())

test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "screw": each test sub-folder paired with the label passed to
# test() as its expected answer.
folder = "screw"
label_by_subfolder = {
    "good": 'None',
    "manipulated_front": 'strip',
    "scratch_head": 'chip',
    "scratch_neck": 'chip',
    "thread_side": 'chip',
    "thread_top": 'chip',
}
sub_folder = list(label_by_subfolder)
GTs = list(label_by_subfolder.values())

test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "tile": each test sub-folder paired with the label passed to
# test() as its expected answer.
folder = "tile"
label_by_subfolder = {
    "good": 'None',
    "crack": 'crack',
    "glue_strip": 'contamination',
    "gray_stroke": 'stain',
    "oil": 'stain',
    "rough": 'contamination',
}
sub_folder = list(label_by_subfolder)
GTs = list(label_by_subfolder.values())

test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "toothbrush": each test sub-folder paired with the label passed to
# test() as its expected answer.
folder = "toothbrush"
label_by_subfolder = {
    "good": "None",
    "defective": "broken",
}
sub_folder = list(label_by_subfolder)
GTs = list(label_by_subfolder.values())

test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "transistor": each test sub-folder paired with the label passed to
# test() as its expected answer.
folder = "transistor"
label_by_subfolder = {
    "good": "None",
    "bent_lead": "bent",
    "cut_lead": "cut",
    "damaged_case": "broken",
    "misplaced": "misalignment",
}
sub_folder = list(label_by_subfolder)
GTs = list(label_by_subfolder.values())

test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "wood": each test sub-folder paired with the label passed to
# test() as its expected answer.
folder = "wood"
label_by_subfolder = {
    "good": "None",
    "color": "stain",
    "scratch": "scratch",
    "liquid": "stain",
    "hole": "hole",
}
sub_folder = list(label_by_subfolder)
GTs = list(label_by_subfolder.values())

test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "zipper": each test sub-folder paired with the label passed to
# test() as its expected answer.
folder = "zipper"
label_by_subfolder = {
    "good": "None",
    "broken_teeth": "broken",
    "fabric_border": "tear",
    "fabric_interior": "frayed",
    "rough": "frayed",
    "split_teeth": "misshapen",
    "squeezed_teeth": "misshapen",
}
sub_folder = list(label_by_subfolder)
GTs = list(label_by_subfolder.values())

test(folder, sub_folder, GTs, model_name)

### 非公開データセットで検証

In [None]:
# category = "rice"
# anormaly_reason = ["brokened rice","rice with milky white"]
# anormaly_type = ["broken","milky_white"]

# test(category, anormaly_reason, anormaly_type, model_name)

## 自作データセット重み性能調査(コンテキスト)

In [7]:
# Run to evaluate: its name selects the checkpoint path below. The
# commented-out lines are alternative runs/paths kept for quick switching.
# model_name = "AC-VI_loss_context_query/batch128_epoch1_lr-5_pairs25_weight5"
model_name = "context_true"
# model_name = "debug"
# trained_ckpt_path = f'/home/yyamada/Otter_/log/VI_batch128_long_pairs25/final_weights.pt'
trained_ckpt_path = f'../../log/{model_name}/final_weights.pt'

# Deserialize the checkpoint onto the CPU.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
train_ckpt = torch.load(trained_ckpt_path, map_location="cpu")
# Some checkpoints keep the weights under "model_state_dict"; unwrap if so.
wrapped_state = train_ckpt.get("model_state_dict")
if wrapped_state is not None:
    train_ckpt = wrapped_state
# strict=False tolerates keys that are missing from or extra in the checkpoint.
# Assigning to `_` keeps the notebook from echoing the returned key report.
_ = model.load_state_dict(train_ckpt, strict=False)

trainデータで検証

In [None]:
import ijson
import orjson

# Image store: top-level key -> image data (later cells decode the values
# with base64.urlsafe_b64decode).
images = {}
images_path="/data/dataset/MIMIC-IT/VI_jsons/VI_train.json"
with open(images_path, "rb") as images_file:
    # Stream the top-level key/value pairs rather than parsing the file in one go.
    images.update(ijson.kvitems(images_file, "", use_float=True))

# Pairing config: later cells read two context ids from each value and take
# the query id from the part of the key before '='.
train_config_path="/data/dataset/MIMIC-IT/VI_jsons/VI_train_pairs25_train.json"
with open(train_config_path, "rb") as config_file:
    cache_train_config = orjson.loads(config_file.read())

# Instruction/answer text, looked up later as instructions["data"][image_id].
mimicit_path="/data/dataset/MIMIC-IT/VI_jsons/VI_train_instructions.json"
with open(mimicit_path, "rb") as instructions_file:
    instructions = orjson.loads(instructions_file.read())

In [None]:
# 正解率
from IPython.display import clear_output
import random
import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO

# Accuracy over at most NUM randomly sampled (context1, context2, query)
# triplets of the training split.
keys = list(cache_train_config.keys())
random.seed(42)  # fixed seed so the same subset is evaluated on every run
random.shuffle(keys)
yesno_count = 0
reason_count = 0
both_count = 0
NUM = 500
eval_keys = keys[:NUM]

# Settings that do not depend on the sample are prepared once.
model.text_tokenizer.padding_side = "left"
# Image tensors are cast to the dtype of the model's parameters.
model_dtype = next(model.parameters()).dtype
bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

for i, key in enumerate(eval_keys):
    print(i)
    context1 = cache_train_config[key][0]
    context2 = cache_train_config[key][1]
    query = key.split('=')[0]

    # Decode the stored base64 data into RGB images (context 1, context 2, query).
    demo_image_one = Image.open(BytesIO(base64.urlsafe_b64decode(images[context1]))).convert("RGB")
    demo_image_two = Image.open(BytesIO(base64.urlsafe_b64decode(images[context2]))).convert("RGB")
    query_image = Image.open(BytesIO(base64.urlsafe_b64decode(images[query]))).convert("RGB")

    vision_x = image_processor.preprocess([demo_image_one, demo_image_two, query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
    vision_x = vision_x.to(dtype=model_dtype)

    # Two answered in-context examples followed by the unanswered query.
    inputs = textwrap.dedent(f"""
        <image>User:{instructions["data"][context1]["instruction"]} GPT:<answer>{instructions["data"][context1]["answer"]}<|endofchunk|>
        <image>User:{instructions["data"][context2]["instruction"]} GPT:<answer>{instructions["data"][context2]["answer"]}<|endofchunk|>
        <image>User:{instructions["data"][query]["instruction"]} GPT:<answer>
    """)
    inputs = "".join(inputs.split("\n"))
    lang_x = model.text_tokenizer([inputs], return_tensors="pt")

    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )

    # Keep the text after the last <answer> tag and before <|endofchunk|>,
    # then strip whitespace and surrounding double quotes.
    decoded = model.text_tokenizer.decode(generated_text[0])
    parsed_output = decoded.split("<answer>")[-1].strip().split("<|endofchunk|>")[0].strip().strip('"')

    # First word = yes/no verdict, second word = reason.
    expected_words = [word.lower() for word in instructions["data"][query]["answer"].split(" ")]
    predicted_words = [word.lower() for word in parsed_output.split(" ")]
    yesno_ok = expected_words[0] == predicted_words[0]
    # Both lengths are checked: a one-word reference answer used to raise
    # IndexError whenever the prediction had two or more words.
    reason_ok = (
        len(expected_words) > 1
        and len(predicted_words) > 1
        and expected_words[1] == predicted_words[1]
    )
    if yesno_ok:
        yesno_count += 1
    if reason_ok:
        reason_count += 1
    if yesno_ok and reason_ok:
        both_count += 1
    clear_output(wait=True)

# Divide by the number of samples actually evaluated, which is smaller than
# NUM when the config holds fewer entries.
total = len(eval_keys)
for label, correct in (("yesno", yesno_count), ("reason", reason_count), ("both", both_count)):
    percent = (correct / total) * 100 if total else 0.0
    print(f"{label} correct: {correct}, total: {total}, acc: {percent:.2f}%")

In [None]:
# 一部可視化
import random
import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO


# Show predictions for at most NUM randomly sampled (context1, context2,
# query) triplets of the training split, then report accuracy on them.
keys = list(cache_train_config.keys())
random.seed(42)  # fixed seed so the same subset is shown on every run
random.shuffle(keys)
NUM = 50
yesno_count = 0
reason_count = 0
both_count = 0
eval_keys = keys[:NUM]

# Settings that do not depend on the sample are prepared once.
model.text_tokenizer.padding_side = "left"
# Image tensors are cast to the dtype of the model's parameters.
model_dtype = next(model.parameters()).dtype
bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

for key in eval_keys:
    context1 = cache_train_config[key][0]
    context2 = cache_train_config[key][1]
    query = key.split('=')[0]

    # Decode the stored base64 data into RGB images (context 1, context 2, query).
    demo_image_one = Image.open(BytesIO(base64.urlsafe_b64decode(images[context1]))).convert("RGB")
    demo_image_two = Image.open(BytesIO(base64.urlsafe_b64decode(images[context2]))).convert("RGB")
    query_image = Image.open(BytesIO(base64.urlsafe_b64decode(images[query]))).convert("RGB")

    vision_x = image_processor.preprocess([demo_image_one, demo_image_two, query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
    vision_x = vision_x.to(dtype=model_dtype)

    # Two answered in-context examples followed by the unanswered query.
    inputs = textwrap.dedent(f"""
        <image>User:{instructions["data"][context1]["instruction"]} GPT:<answer>{instructions["data"][context1]["answer"]}<|endofchunk|>
        <image>User:{instructions["data"][context2]["instruction"]} GPT:<answer>{instructions["data"][context2]["answer"]}<|endofchunk|>
        <image>User:{instructions["data"][query]["instruction"]} GPT:<answer>
    """)
    inputs = "".join(inputs.split("\n"))
    lang_x = model.text_tokenizer([inputs], return_tensors="pt")

    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )

    # Keep the text after the last <answer> tag and before <|endofchunk|>,
    # then strip whitespace and surrounding double quotes.
    decoded = model.text_tokenizer.decode(generated_text[0])
    parsed_output = decoded.split("<answer>")[-1].strip().split("<|endofchunk|>")[0].strip().strip('"')

    # First word = yes/no verdict, second word = reason.
    expected_words = [word.lower() for word in instructions["data"][query]["answer"].split(" ")]
    predicted_words = [word.lower() for word in parsed_output.split(" ")]
    yesno_ok = expected_words[0] == predicted_words[0]
    # Both lengths are checked: a one-word reference answer used to raise
    # IndexError whenever the prediction had two or more words.
    reason_ok = (
        len(expected_words) > 1
        and len(predicted_words) > 1
        and expected_words[1] == predicted_words[1]
    )
    if yesno_ok:
        yesno_count += 1
    if reason_ok:
        reason_count += 1
    if yesno_ok and reason_ok:
        both_count += 1

    # Context images on the left and in the middle, query image on the right.
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10,5))
    for axis, picture in zip(axes, (demo_image_one, demo_image_two, query_image)):
        axis.imshow(picture)
        axis.axis('off')
    print("--------------------------------------------------------")
    print(query)
    print(inputs)
    print("GPT:", parsed_output)
    plt.show()

# Divide by the number of samples actually evaluated, which is smaller than
# NUM when the config holds fewer entries.
total = len(eval_keys)
for label, correct in (("yesno", yesno_count), ("reason", reason_count), ("both", both_count)):
    percent = (correct / total) * 100 if total else 0.0
    print(f"{label} correct: {correct}, total: {total}, acc: {percent:.2f}%")

valデータで検証

In [None]:
# データ読み込み
import ijson
import orjson

# Image store: top-level key -> image data (later cells decode the values
# with base64.urlsafe_b64decode).
images = {}
images_path="/data/dataset/MIMIC-IT/VI_jsons/VI_val.json"
with open(images_path, "rb") as images_file:
    # Stream the top-level key/value pairs rather than parsing the file in one go.
    images.update(ijson.kvitems(images_file, "", use_float=True))

# Pairing config: later cells read two context ids from each value and take
# the query id from the part of the key before '='.
train_config_path="/data/dataset/MIMIC-IT/VI_jsons/VI_val_pairs1_train.json"
with open(train_config_path, "rb") as config_file:
    cache_train_config = orjson.loads(config_file.read())

# Instruction/answer text, looked up later as instructions["data"][image_id].
mimicit_path="/data/dataset/MIMIC-IT/VI_jsons/VI_val_instructions.json"
with open(mimicit_path, "rb") as instructions_file:
    instructions = orjson.loads(instructions_file.read())

In [None]:
# 正解率
from IPython.display import clear_output
import random
import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO

# Accuracy over at most NUM randomly sampled (context1, context2, query)
# triplets of the validation split.
keys = list(cache_train_config.keys())
random.seed(42)  # fixed seed so the same subset is evaluated on every run
random.shuffle(keys)
yesno_count = 0
reason_count = 0
both_count = 0
NUM = 500
eval_keys = keys[:NUM]

# Settings that do not depend on the sample are prepared once.
model.text_tokenizer.padding_side = "left"
# Image tensors are cast to the dtype of the model's parameters.
model_dtype = next(model.parameters()).dtype
bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

for i, key in enumerate(eval_keys):
    print(i)
    context1 = cache_train_config[key][0]
    context2 = cache_train_config[key][1]
    query = key.split('=')[0]

    # Decode the stored base64 data into RGB images (context 1, context 2, query).
    demo_image_one = Image.open(BytesIO(base64.urlsafe_b64decode(images[context1]))).convert("RGB")
    demo_image_two = Image.open(BytesIO(base64.urlsafe_b64decode(images[context2]))).convert("RGB")
    query_image = Image.open(BytesIO(base64.urlsafe_b64decode(images[query]))).convert("RGB")

    vision_x = image_processor.preprocess([demo_image_one, demo_image_two, query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
    vision_x = vision_x.to(dtype=model_dtype)

    # Two answered in-context examples followed by the unanswered query.
    inputs = textwrap.dedent(f"""
        <image>User:{instructions["data"][context1]["instruction"]} GPT:<answer>{instructions["data"][context1]["answer"]}<|endofchunk|>
        <image>User:{instructions["data"][context2]["instruction"]} GPT:<answer>{instructions["data"][context2]["answer"]}<|endofchunk|>
        <image>User:{instructions["data"][query]["instruction"]} GPT:<answer>
    """)
    inputs = "".join(inputs.split("\n"))
    lang_x = model.text_tokenizer([inputs], return_tensors="pt")

    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )

    # Keep the text after the last <answer> tag and before <|endofchunk|>,
    # then strip whitespace and surrounding double quotes.
    decoded = model.text_tokenizer.decode(generated_text[0])
    parsed_output = decoded.split("<answer>")[-1].strip().split("<|endofchunk|>")[0].strip().strip('"')

    # First word = yes/no verdict, second word = reason.
    expected_words = [word.lower() for word in instructions["data"][query]["answer"].split(" ")]
    predicted_words = [word.lower() for word in parsed_output.split(" ")]
    yesno_ok = expected_words[0] == predicted_words[0]
    # Both lengths are checked: a one-word reference answer used to raise
    # IndexError whenever the prediction had two or more words.
    reason_ok = (
        len(expected_words) > 1
        and len(predicted_words) > 1
        and expected_words[1] == predicted_words[1]
    )
    if yesno_ok:
        yesno_count += 1
    if reason_ok:
        reason_count += 1
    if yesno_ok and reason_ok:
        both_count += 1
    clear_output(wait=True)

# Divide by the number of samples actually evaluated, which is smaller than
# NUM when the config holds fewer entries.
total = len(eval_keys)
for label, correct in (("yesno", yesno_count), ("reason", reason_count), ("both", both_count)):
    percent = (correct / total) * 100 if total else 0.0
    print(f"{label} correct: {correct}, total: {total}, acc: {percent:.2f}%")

In [None]:
# 一部可視化
import random
import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO


# Show predictions for at most NUM randomly sampled (context1, context2,
# query) triplets of the validation split, then report accuracy on them.
keys = list(cache_train_config.keys())
random.seed(42)  # fixed seed so the same subset is shown on every run
random.shuffle(keys)
NUM = 50
yesno_count = 0
reason_count = 0
both_count = 0
eval_keys = keys[:NUM]

# Settings that do not depend on the sample are prepared once.
model.text_tokenizer.padding_side = "left"
# Image tensors are cast to the dtype of the model's parameters.
model_dtype = next(model.parameters()).dtype
bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

for key in eval_keys:
    context1 = cache_train_config[key][0]
    context2 = cache_train_config[key][1]
    query = key.split('=')[0]

    # Decode the stored base64 data into RGB images (context 1, context 2, query).
    demo_image_one = Image.open(BytesIO(base64.urlsafe_b64decode(images[context1]))).convert("RGB")
    demo_image_two = Image.open(BytesIO(base64.urlsafe_b64decode(images[context2]))).convert("RGB")
    query_image = Image.open(BytesIO(base64.urlsafe_b64decode(images[query]))).convert("RGB")

    vision_x = image_processor.preprocess([demo_image_one, demo_image_two, query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
    vision_x = vision_x.to(dtype=model_dtype)

    # Two answered in-context examples followed by the unanswered query.
    inputs = textwrap.dedent(f"""
        <image>User:{instructions["data"][context1]["instruction"]} GPT:<answer>{instructions["data"][context1]["answer"]}<|endofchunk|>
        <image>User:{instructions["data"][context2]["instruction"]} GPT:<answer>{instructions["data"][context2]["answer"]}<|endofchunk|>
        <image>User:{instructions["data"][query]["instruction"]} GPT:<answer>
    """)
    inputs = "".join(inputs.split("\n"))
    lang_x = model.text_tokenizer([inputs], return_tensors="pt")

    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )

    # Keep the text after the last <answer> tag and before <|endofchunk|>,
    # then strip whitespace and surrounding double quotes.
    decoded = model.text_tokenizer.decode(generated_text[0])
    parsed_output = decoded.split("<answer>")[-1].strip().split("<|endofchunk|>")[0].strip().strip('"')

    # First word = yes/no verdict, second word = reason.
    expected_words = [word.lower() for word in instructions["data"][query]["answer"].split(" ")]
    predicted_words = [word.lower() for word in parsed_output.split(" ")]
    yesno_ok = expected_words[0] == predicted_words[0]
    # Both lengths are checked: a one-word reference answer used to raise
    # IndexError whenever the prediction had two or more words.
    reason_ok = (
        len(expected_words) > 1
        and len(predicted_words) > 1
        and expected_words[1] == predicted_words[1]
    )
    if yesno_ok:
        yesno_count += 1
    if reason_ok:
        reason_count += 1
    if yesno_ok and reason_ok:
        both_count += 1

    # Context images on the left and in the middle, query image on the right.
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10,5))
    for axis, picture in zip(axes, (demo_image_one, demo_image_two, query_image)):
        axis.imshow(picture)
        axis.axis('off')
    print("--------------------------------------------------------")
    print(query)
    print(inputs)
    print("GPT:", parsed_output)
    plt.show()

# Divide by the number of samples actually evaluated, which is smaller than
# NUM when the config holds fewer entries.
total = len(eval_keys)
for label, correct in (("yesno", yesno_count), ("reason", reason_count), ("both", both_count)):
    percent = (correct / total) * 100 if total else 0.0
    print(f"{label} correct: {correct}, total: {total}, acc: {percent:.2f}%")

### MVTecAD(コンテキスト)

In [8]:
import os
def get_image_paths(folder_path):
    """List the image entries directly inside ``folder_path``, sorted by name.

    Entries are kept when their extension, lower-cased, is ``.jpg``,
    ``.jpeg`` or ``.png``; each result is ``folder_path`` joined with the
    entry name.
    """
    image_extensions = ['.jpg', '.jpeg', '.png']

    def has_image_extension(entry_name):
        # Compare only the final extension, ignoring case.
        return os.path.splitext(entry_name)[1].lower() in image_extensions

    entry_names = sorted(os.listdir(folder_path))
    return [os.path.join(folder_path, entry_name) for entry_name in filter(has_image_extension, entry_names)]

def write_text_file(file_path, text):
    """Append one line, ``text`` plus a newline, to the file at ``file_path``.

    Append mode is used, so content already in the file is preserved.
    """
    with open(file_path, mode="a") as out_file:
        out_file.write(text + "\n")
        
def generate_list_string(items):
    """Join ``items`` into an English enumeration.

    Underscores in every item are replaced by spaces first.

    Returns:
        ``""`` for no items, the item itself for one item, ``"a and b"`` for
        two items, and ``"a, b, and c"`` (Oxford comma) for three or more.
    """
    # Convert underscores to spaces.
    items = [item.replace('_', ' ') for item in items]
    if not items:
        # An empty input used to fall through to items[-1] and raise IndexError.
        return ""
    if len(items) == 1:
        return items[0]
    if len(items) == 2:
        return f"{items[0]} and {items[1]}"
    return ", ".join(items[:-1]) + f", and {items[-1]}"

In [9]:
# コンテキストのみ使用
from IPython.display import clear_output

def _format_accuracy(label, correct, total):
    """Return the one-line accuracy summary for one metric (0.00% when total is 0)."""
    percent = (correct / total) * 100 if total > 0 else 0.0
    return f"correct: {correct}, total: {total}, {label} acc: {percent:.2f}%"


def _evaluate_queries(query_folder_path, demo_images, lang_x, inputs, log_path, sub, expected_yesno, expected_reason):
    """Query the model with every image of one folder, log the answers and score them.

    Args:
        query_folder_path: Folder whose images are used as queries.
        demo_images: The two in-context demo images, in prompt order.
        lang_x: Tokenized prompt (provides "input_ids" and "attention_mask").
        inputs: The prompt text; only printed for inspection.
        log_path: Text file that receives the per-image answers and the summary.
        sub: Defect sub-folder name, used in the log markers and progress output.
        expected_yesno: Lower-case first word a correct answer must start with.
        expected_reason: Second word a correct answer must have.

    Returns:
        A list with three summary strings: yes/no, reason and both accuracy.
    """
    write_text_file(log_path, f'-----{sub} start-----')
    write_text_file(log_path, "")

    # The first image in sorted order is skipped; presumably it is the 000.png
    # that the caller uses as an in-context demo -- TODO confirm for all folders.
    query_image_paths = get_image_paths(query_folder_path)[1:]

    # Everything below is identical for every query image, so compute it once.
    model_dtype = next(model.parameters()).dtype
    lang_x_input_ids = lang_x["input_ids"].to(model.device)
    lang_x_attention_mask = lang_x["attention_mask"].to(model.device)
    bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

    yesno_count = 0
    reason_count = 0
    both_count = 0
    for i, query_image_path in enumerate(query_image_paths):
        query_image = Image.open(query_image_path).resize((224, 224)).convert("RGB")
        vision_x = image_processor.preprocess([*demo_images, query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
        # Convert the image tensor to the model's data type.
        vision_x = vision_x.to(dtype=model_dtype)

        generated_text = model.generate(
            vision_x=vision_x.to(model.device),
            lang_x=lang_x_input_ids,
            attention_mask=lang_x_attention_mask,
            max_new_tokens=512,
            num_beams=3,
            no_repeat_ngram_size=3,
            bad_words_ids=bad_words_id,
        )

        # Keep only the text after the last "<answer>" and before the end-of-chunk
        # marker, without surrounding whitespace or double quotes.
        decoded = model.text_tokenizer.decode(generated_text[0])
        parsed_output = decoded.split("<answer>")[-1].strip().split("<|endofchunk|>")[0].strip().strip('"')

        print(inputs)
        print(parsed_output)

        # Scoring is an exact match on the first two space-separated words, so
        # punctuation attached to a word (e.g. "broken.") counts as wrong.
        words = parsed_output.split(" ")
        yesno_ok = words[0].lower() == expected_yesno
        reason_ok = len(words) > 1 and words[1].lower() == expected_reason
        if yesno_ok:
            yesno_count += 1
        if reason_ok:
            reason_count += 1
        if yesno_ok and reason_ok:
            both_count += 1

        write_text_file(log_path, query_image_path)
        write_text_file(log_path, parsed_output)
        write_text_file(log_path, "")

        print(i, sub)
        clear_output(wait=True)

    total = len(query_image_paths)
    summaries = [
        _format_accuracy("yesno", yesno_count, total),
        _format_accuracy("reason", reason_count, total),
        _format_accuracy("both", both_count, total),
    ]
    write_text_file(log_path, f'-----{sub} end-----')
    for summary in summaries:
        write_text_file(log_path, summary)
    return summaries


def test(folder, sub_folder, GTs, model_name, order=True):
    """Evaluate two-shot in-context defect detection on one MVTec AD category.

    For every (sub-folder, defect word) pair, the prompt shows one good and one
    defective demo image with their answers, then asks the same question about
    a query image. Defective queries are logged to
    ./result/{folder}/{sub}/{model_name}/detective.txt and good queries to
    non-detective.txt in the same directory; both files are truncated first.

    Args:
        folder: MVTec AD category name (a directory under /data/dataset/mvtec).
        sub_folder: Defect sub-folder names under {folder}/test.
        GTs: Defect word expected in the answer for each entry of sub_folder;
            paired positionally (extra entries of the longer list are ignored).
        model_name: Name used for the result directory.
        order: If True the demos are (good, defective); otherwise (defective, good).

    Returns:
        None. The accuracy summaries are printed at the end.
    """
    acc = []
    # "grid" is described to the model as "metal grid"; underscores become spaces.
    folder__ = ("metal grid" if folder == "grid" else folder).replace('_', ' ')

    for sub, gt in zip(sub_folder, GTs):
        folder_name = f'./result/{folder}/{sub}/{model_name}'
        os.makedirs(folder_name, exist_ok=True)
        detective_log = f'{folder_name}/detective.txt'
        non_detective_log = f'{folder_name}/non-detective.txt'
        # Start both log files empty; later writes append.
        for log_path in (detective_log, non_detective_log):
            with open(log_path, mode='w'):
                pass

        subfolder_string = generate_list_string(GTs)
        model.text_tokenizer.padding_side = "left"

        write_text_file(detective_log, f"{sub} --> {gt}")

        good_demo_path = f"/data/dataset/mvtec/{folder}/test/good/000.png"
        defect_demo_path = f"/data/dataset/mvtec/{folder}/test/{sub}/000.png"
        question = (
            f"<image>User: This is an image of {folder__}. Does this {folder__} have any defects such as {subfolder_string}? "
            "If there are any defects, please provide the defect name. If not, please say None. GPT:<answer>"
        )
        good_chunk = f"{question}No None<|endofchunk|>"
        defect_chunk = f"{question}Yes {gt}<|endofchunk|>"

        if order:  # demo_image_one: good, demo_image_two: defective
            context_label = "context1: OK, context2: NG"
            write_text_file(detective_log, f"{context_label}, query: NG")
            demo_image_one = Image.open(good_demo_path).resize((224, 224)).convert("RGB")
            demo_image_two = Image.open(defect_demo_path).resize((224, 224)).convert("RGB")
            inputs = good_chunk + defect_chunk + question
        else:  # demo_image_one: defective, demo_image_two: good
            context_label = "context1: NG, context2: OK"
            write_text_file(detective_log, f"{context_label}, query: NG")
            demo_image_one = Image.open(defect_demo_path).resize((224, 224)).convert("RGB")
            demo_image_two = Image.open(good_demo_path).resize((224, 224)).convert("RGB")
            inputs = defect_chunk + good_chunk + question
        demo_images = [demo_image_one, demo_image_two]

        # The prompt is a single line: drop any newline that came in via the arguments.
        inputs = "".join(inputs.split("\n"))
        lang_x = model.text_tokenizer(
            [
                inputs
            ],
            return_tensors="pt",
        )

        # Query: defective images -- a correct answer is "Yes <defect word>".
        defect_summaries = _evaluate_queries(
            f"/data/dataset/mvtec/{folder}/test/{sub}",
            demo_images, lang_x, inputs, detective_log, sub,
            "yes", f"{gt}",
        )
        for summary in defect_summaries:
            acc.append((sub, summary))

        # Query: good images -- a correct answer is "No None".
        write_text_file(non_detective_log, f"{sub} --> {gt}")
        write_text_file(non_detective_log, f"{context_label}, query: OK")
        good_summaries = _evaluate_queries(
            f"/data/dataset/mvtec/{folder}/test/good",
            demo_images, lang_x, inputs, non_detective_log, sub,
            "no", "none",
        )
        for summary in good_summaries:
            acc.append(("good", summary))

    for a in acc:
        print(a)

In [10]:
# # クエリのみ使用
# def test(category, anormaly_reason, anormaly_type, model_name, order):
#     if category=="grid":
#         category__ = "metal grid"
#     else:
#         category__ = category
#     category__ = category__.replace('_', ' ')
#     for j, (ano_type,ano_reason) in enumerate(zip(anormaly_type,anormaly_reason)):
#         folder_name = f'./result/{category}/{ano_type}/{model_name}'
#         os.makedirs(folder_name, exist_ok=True)
#         with open(f'{folder_name}/detective_1.txt', mode='w') as f:
#             f.close()
#         with open(f'{folder_name}/non-detective_1.txt', mode='w') as f:
#             f.close()
        
#         subfolder_string = generate_list_string(anormaly_reason)
#         model.text_tokenizer.padding_side = "left"
        
#         """ クエリ：不良品 """
#         sentence = f"query: NG"
#         # print(sentence)
#         write_text_file(f'{folder_name}/detective_1.txt',sentence)
#         # long
#         inputs = textwrap.dedent(f"""
#            <image>User: This is an image of {category__}. Does this wood have any defects such as {subfolder_string}? GPT:<answer>
#         """)
#         # short
#         # inputs = textwrap.dedent(f"""
#         #     <image>User: This is an image of {category__}. Does this wood have any defects? GPT:<answer>
#         # """)        
        
#         inputs = "".join(inputs.split("\n"))
#         lang_x = model.text_tokenizer(
#             [
#                 inputs
#             ],
#             return_tensors="pt",
#         )
        
#         write_text_file(f'{folder_name}/detective_1.txt',f'-----{ano_type} start-----')
#         write_text_file(f'{folder_name}/detective_1.txt',"")
            
#         query_folder_path = f"/data/dataset/mvtec/{category}/test/{ano_type}"
#         query_image_paths = get_image_paths(query_folder_path)
#         count = 0
#         for i, query_image_path in enumerate(query_image_paths[1:]):
#             # print(query_image_path)
#             query_image = Image.open(query_image_path).resize((224, 224)).convert("RGB")
#             vision_x = image_processor.preprocess([query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
        
#             # Get the data type from model's parameters
#             model_dtype = next(model.parameters()).dtype

#             # Convert tensors to the model's data type
#             vision_x = vision_x.to(dtype=model_dtype)
#             lang_x_input_ids = lang_x["input_ids"]
#             lang_x_attention_mask = lang_x["attention_mask"]

#             bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids
#             generated_text = model.generate(
#                 vision_x=vision_x.to(model.device),
#                 lang_x=lang_x_input_ids.to(model.device),
#                 attention_mask=lang_x_attention_mask.to(model.device),
#                 max_new_tokens=512,
#                 num_beams=3,
#                 no_repeat_ngram_size=3,
#                 bad_words_ids=bad_words_id,
#             )

#             parsed_output = (
#                 model.text_tokenizer.decode(generated_text[0]).split("<answer>")[-1].lstrip().rstrip().split("<|endofchunk|>")[0].lstrip().rstrip().lstrip('"').rstrip('"')
#             )
            
#             if parsed_output.split(".")[0].lower()=="yes":
#                 count += 1
            
#             write_text_file(f'{folder_name}/detective_1.txt',query_image_path)
#             write_text_file(f'{folder_name}/detective_1.txt',parsed_output)
#             write_text_file(f'{folder_name}/detective_1.txt',"")
            
#             # print(inputs)
#             # print("GPT:", parsed_output)
            
#             # fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10,5))
#             # axes[0].imshow(demo_image_one)
#             # axes[0].axis('off')
#             # axes[1].imshow(demo_image_two)
#             # axes[1].axis('off')
#             # axes[2].imshow(query_image)
#             # axes[2].axis('off')
#             # plt.show()
            
#         acc = f"correct: {count}, total: {len(query_image_paths)-1}, acc: {(count / (len(query_image_paths)-1)) * 100:.2f}%"
#         print(acc)
        
#         write_text_file(f'{folder_name}/detective_1.txt',f'-----{ano_type} end-----')
#         write_text_file(f'{folder_name}/detective_1.txt',acc)
        
#         """ クエリ：良品 """
#         if j==0:
#             sentence = f"query: OK"
#             # print(sentence)
#             write_text_file(f'{folder_name}/non-detective_1.txt',sentence)
#             write_text_file(f'{folder_name}/non-detective_1.txt',f'-----{ano_type} start-----')
#             write_text_file(f'{folder_name}/non-detective_1.txt',"")
                
#             query_folder_path = f"/data/dataset/mvtec/{category}/test/good"
#             query_image_paths = get_image_paths(query_folder_path)
#             count = 0
#             for i, query_image_path in enumerate(query_image_paths[1:]):
#                 # print(query_image_path)
#                 query_image = Image.open(query_image_path).resize((224, 224)).convert("RGB")
#                 vision_x = image_processor.preprocess([query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
            
#                 # Get the data type from model's parameters
#                 model_dtype = next(model.parameters()).dtype

#                 # Convert tensors to the model's data type
#                 vision_x = vision_x.to(dtype=model_dtype)
#                 lang_x_input_ids = lang_x["input_ids"]
#                 lang_x_attention_mask = lang_x["attention_mask"]

#                 bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids
#                 generated_text = model.generate(
#                     vision_x=vision_x.to(model.device),
#                     lang_x=lang_x_input_ids.to(model.device),
#                     attention_mask=lang_x_attention_mask.to(model.device),
#                     max_new_tokens=512,
#                     num_beams=3,
#                     no_repeat_ngram_size=3,
#                     bad_words_ids=bad_words_id,
#                 )

#                 parsed_output = (
#                     model.text_tokenizer.decode(generated_text[0]).split("<answer>")[-1].lstrip().rstrip().split("<|endofchunk|>")[0].lstrip().rstrip().lstrip('"').rstrip('"')
#                 )
                
#                 if parsed_output.split(".")[0].lower()=="no":
#                     count += 1
                
#                 write_text_file(f'{folder_name}/non-detective_1.txt',query_image_path)
#                 write_text_file(f'{folder_name}/non-detective_1.txt',parsed_output)
#                 write_text_file(f'{folder_name}/non-detective_1.txt',"")
                
#                 # print(inputs)
#                 # print("GPT:", parsed_output)
                
#                 # fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10,5))
#                 # axes[0].imshow(demo_image_one)
#                 # axes[0].axis('off')
#                 # axes[1].imshow(demo_image_two)
#                 # axes[1].axis('off')
#                 # axes[2].imshow(query_image)
#                 # axes[2].axis('off')
#                 # plt.show()
                
#             acc = f"correct: {count}, total: {len(query_image_paths)-1}, acc: {(count / (len(query_image_paths)-1)) * 100:.2f}%"
#             print(acc)
            
#             write_text_file(f'{folder_name}/non-detective_1.txt',f'-----{ano_type} end-----')
#             write_text_file(f'{folder_name}/non-detective_1.txt',acc)

In [11]:
# Report whether PyTorch can see a CUDA device; the value is shown as the cell output.
torch.cuda.is_available()

True

In [12]:
# MVTec AD "bottle": each defect sub-folder mapped to the defect word used in the prompt.
folder = "bottle"
defect_words = {
    "broken_large": "broken",
    "contamination": "contamination",
}
sub_folder = list(defect_words)
GTs = list(defect_words.values())
test(folder, sub_folder, GTs, model_name)

# sub_folder = ["broken_small","contamination"]
# test(folder, sub_folder, GTs, model_name)

('broken_large', 'correct: 18, total: 19, yesno acc: 94.74%')
('broken_large', 'correct: 0, total: 19, reason acc: 0.00%')
('broken_large', 'correct: 0, total: 19, both acc: 0.00%')
('good', 'correct: 0, total: 19, yesno acc: 0.00%')
('good', 'correct: 0, total: 19, reason acc: 0.00%')
('good', 'correct: 0, total: 19, both acc: 0.00%')
('contamination', 'correct: 17, total: 20, yesno acc: 85.00%')
('contamination', 'correct: 0, total: 20, reason acc: 0.00%')
('contamination', 'correct: 0, total: 20, both acc: 0.00%')
('good', 'correct: 0, total: 19, yesno acc: 0.00%')
('good', 'correct: 0, total: 19, reason acc: 0.00%')
('good', 'correct: 0, total: 19, both acc: 0.00%')


In [None]:
# MVTec AD "cable": two runs that share the same defect words; the i-th
# sub-folder of each run is paired with the i-th word of GTs.
# NOTE(review): 'swapp' may be a typo for 'swap' -- confirm before changing,
# because the word is inserted into the prompt and compared with the answer.
folder = "cable"
GTs = ['bent', 'swapp', 'crack', 'missing', 'hole']
for sub_folder in (
    ["bent_wire", "cable_swap", "cut_inner_insulation", "missing_cable", "poke_insulation"],
    ["bent_wire", "cable_swap", "cut_outer_insulation", "missing_wire", "poke_insulation"],
):
    test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "capsule": each defect sub-folder mapped to the defect word used in the prompt.
folder = "capsule"
defect_words = {
    "crack": "crack",
    "faulty_imprint": "misprint",
    "poke": "hole",
    "scratch": "scratch",
    "squeeze": "misshapen",
}
sub_folder = list(defect_words)
GTs = list(defect_words.values())
test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "carpet": two runs that share the same defect words; the i-th
# sub-folder of each run is paired with the i-th word of GTs.
folder = "carpet"
GTs = ['stain', 'cut', 'hole', 'contamination']
for sub_folder in (
    ["color", "cut", "hole", "metal_contamination"],
    ["color", "cut", "hole", "thread"],
):
    test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "grid": three runs that share the same defect words; the i-th
# sub-folder of each run is paired with the i-th word of GTs.
folder = "grid"
GTs = ["bent", "broken", "contamination"]
for sub_folder in (
    ["bent", "broken", "glue"],
    ["bent", "broken", "metal_contamination"],
    ["bent", "broken", "thread"],
):
    test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "hazelnut": each defect sub-folder mapped to the defect word used in the prompt.
folder = "hazelnut"
defect_words = {
    "crack": "crack",
    "cut": "scratch",
    "hole": "hole",
    "print": "misprint",
}
sub_folder = list(defect_words)
GTs = list(defect_words.values())
test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "leather": three runs, each given as (sub-folders, defect words).
# The last run uses a different ordering of both lists.
folder = "leather"
for sub_folder, GTs in (
    (["color", "cut", "fold", "poke"], ['stain', 'scratch', 'wrinkle', 'hole']),
    (["glue", "cut", "fold", "poke"], ['stain', 'scratch', 'wrinkle', 'hole']),
    (["glue", "poke", "fold", "cut"], ['stain', 'hole', 'wrinkle', 'scratch']),
):
    test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "metal_nut": each defect sub-folder mapped to the defect word used in the prompt.
folder = "metal_nut"
defect_words = {
    "bent": "bent",
    "color": "stain",
    "flip": "flip",
    "scratch": "scratch",
}
sub_folder = list(defect_words)
GTs = list(defect_words.values())
test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "pill": two runs that share the same defect words; the i-th
# sub-folder of each run is paired with the i-th word of GTs.
folder = "pill"
GTs = ['stain', 'contamination', 'crack', 'misprint', 'scratch']
for sub_folder in (
    ["color", "contamination", "crack", "faulty_imprint", "scratch"],
    ["pill_type", "contamination", "crack", "faulty_imprint", "scratch"],
):
    test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "screw": four runs that share the same defect words; only the
# first sub-folder changes between runs.
folder = "screw"
GTs = ['chip', 'strip']
for first_defect in ("scratch_head", "scratch_neck", "thread_side", "thread_top"):
    sub_folder = [first_defect, "manipulated_front"]
    test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "tile": two runs that share the same defect words; the i-th
# sub-folder of each run is paired with the i-th word of GTs.
folder = "tile"
GTs = ['crack', 'contamination', 'stain']
for sub_folder in (
    ["crack", "glue_strip", "gray_stroke"],
    ["crack", "rough", "oil"],
):
    test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "toothbrush": a single defect sub-folder and its prompt word.
folder = "toothbrush"
defect_words = {"defective": "broken"}
sub_folder = list(defect_words)
GTs = list(defect_words.values())
test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "transistor": each defect sub-folder mapped to the defect word used in the prompt.
folder = "transistor"
defect_words = {
    "bent_lead": "bent",
    "cut_lead": "cut",
    "damaged_case": "broken",
    "misplaced": "misalignment",
}
sub_folder = list(defect_words)
GTs = list(defect_words.values())
test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "wood": two runs that share the same defect words; the i-th
# sub-folder of each run is paired with the i-th word of GTs.
folder = "wood"
GTs = ["stain", "scratch", "hole"]
for sub_folder in (
    ["color", "scratch", "hole"],
    ["liquid", "scratch", "hole"],
):
    test(folder, sub_folder, GTs, model_name)

In [None]:
# MVTec AD "zipper": two runs that share the same defect words; the i-th
# sub-folder of each run is paired with the i-th word of GTs.
folder = "zipper"
GTs = ["broken", "tear", "frayed", "misshapen"]
for sub_folder in (
    ["broken_teeth", "fabric_border", "fabric_interior", "split_teeth"],
    ["broken_teeth", "fabric_border", "rough", "squeezed_teeth"],
):
    test(folder, sub_folder, GTs, model_name)

### Imagenet検証

In [None]:
# Load fine-tuned weights into the already constructed model.
model_name = "MI_loss_context_query_debug/batch128_epoch1_lr-5_pairs4"
trained_ckpt_path = f'../../log/{model_name}/final_weights.pt'

# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
train_ckpt = torch.load(trained_ckpt_path, map_location="cpu")
# The weights may be stored directly or wrapped under "model_state_dict".
if train_ckpt.get("model_state_dict") is not None:
    train_ckpt = train_ckpt["model_state_dict"]

# strict=False tolerates a checkpoint that holds only part of the model, but it
# also hides a checkpoint whose keys do not match at all, so report the result.
load_result = model.load_state_dict(train_ckpt, strict=False)
if load_result.unexpected_keys:
    print(
        f"WARNING: {len(load_result.unexpected_keys)} checkpoint keys were not used, "
        f"e.g. {load_result.unexpected_keys[:5]}"
    )
print(
    f"checkpoint keys: {len(train_ckpt)}, "
    f"unused: {len(load_result.unexpected_keys)}, "
    f"model keys missing from checkpoint: {len(load_result.missing_keys)}"
)


In [None]:
import ijson
import orjson

print("train")

# Image table: top-level key -> value streamed from the JSON file. Later cells
# base64-decode these values into images.
images_path = "/data/dataset/MIMIC-IT/MiniImagenet_jsons/MI_train.json"
with open(images_path, "rb") as f:
    images = dict(ijson.kvitems(f, "", use_float=True))

# Pairing config: later cells read two context ids per key from it.
train_config_path = "/data/dataset/MIMIC-IT/MiniImagenet_jsons/MI_train_pairs4_train.json"
with open(train_config_path, "rb") as f:
    raw_train_config = f.read()
cache_train_config = orjson.loads(raw_train_config)

# Instruction/answer texts, looked up later under instructions["data"][id].
mimicit_path = "/data/dataset/MIMIC-IT/MiniImagenet_jsons/MI_train_instructions.json"
with open(mimicit_path, "rb") as f:
    raw_instructions = f.read()
instructions = orjson.loads(raw_instructions)

In [None]:
import random
import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
from IPython.display import clear_output

# Accuracy on a fixed random sample of context/query pairs.
keys = list(cache_train_config.keys())
random.seed(42)  # fixed seed so the same sample is evaluated on every run
random.shuffle(keys)
count = 0
NUM = 100
eval_keys = keys[:NUM]  # may hold fewer than NUM entries

# These do not change between samples, so set them up once.
model.text_tokenizer.padding_side = "left"
model_dtype = next(model.parameters()).dtype
bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

for i, pair_key in enumerate(eval_keys):
    print(i)
    context1 = cache_train_config[pair_key][0]
    context2 = cache_train_config[pair_key][1]
    # The query id is the part of the key before the first '='.
    query = pair_key.split('=')[0]

    # Decode the base64 strings (context 1, context 2, query) into RGB images.
    demo_image_one, demo_image_two, query_image = (
        Image.open(BytesIO(base64.urlsafe_b64decode(images[image_id]))).convert("RGB")
        for image_id in (context1, context2, query)
    )

    vision_x = image_processor.preprocess([demo_image_one, demo_image_two, query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
    # Convert the image tensor to the model's data type.
    vision_x = vision_x.to(dtype=model_dtype)

    # Two answered demos followed by the open query; newlines are removed below.
    inputs = textwrap.dedent(f"""
        <image>User:{instructions["data"][context1]["instruction"]} GPT:<answer>{instructions["data"][context1]["answer"]}<|endofchunk|>
        <image>User:{instructions["data"][context2]["instruction"]} GPT:<answer>{instructions["data"][context2]["answer"]}<|endofchunk|>
        <image>User:{instructions["data"][query]["instruction"]} GPT:<answer>
    """)
    inputs = "".join(inputs.split("\n"))
    lang_x = model.text_tokenizer(
        [
            inputs
        ],
        return_tensors="pt",
    )

    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )

    # Keep only the text after the last "<answer>" and before the end-of-chunk
    # marker, without surrounding whitespace or double quotes.
    decoded = model.text_tokenizer.decode(generated_text[0])
    parsed_output = decoded.split("<answer>")[-1].strip().split("<|endofchunk|>")[0].strip().strip('"')

    # Correct only if the whole answer matches, ignoring case.
    if instructions["data"][query]["answer"].lower() == parsed_output.lower():
        count += 1
    clear_output(wait=True)

# Divide by the number of samples actually evaluated, which can be below NUM.
total = len(eval_keys)
accuracy = (count / total) * 100 if total else 0.0
print(f"correct: {count}, total: {total}, acc: {accuracy:.2f}%")

In [None]:
# 一部可視化
import random
import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO


# Evaluate a smaller fixed random sample and show the images with each answer.
keys = list(cache_train_config.keys())
random.seed(42)  # fixed seed so the same sample is shown on every run
random.shuffle(keys)
NUM = 50
count = 0
eval_keys = keys[:NUM]  # may hold fewer than NUM entries

# These do not change between samples, so set them up once.
model.text_tokenizer.padding_side = "left"
model_dtype = next(model.parameters()).dtype
bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

for pair_key in eval_keys:
    context1 = cache_train_config[pair_key][0]
    context2 = cache_train_config[pair_key][1]
    # The query id is the part of the key before the first '='.
    query = pair_key.split('=')[0]

    # Decode the base64 strings (context 1, context 2, query) into RGB images.
    demo_image_one, demo_image_two, query_image = (
        Image.open(BytesIO(base64.urlsafe_b64decode(images[image_id]))).convert("RGB")
        for image_id in (context1, context2, query)
    )

    vision_x = image_processor.preprocess([demo_image_one, demo_image_two, query_image], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
    # Convert the image tensor to the model's data type.
    vision_x = vision_x.to(dtype=model_dtype)

    # Two answered demos followed by the open query; newlines are removed below.
    inputs = textwrap.dedent(f"""
        <image>User:{instructions["data"][context1]["instruction"]} GPT:<answer>{instructions["data"][context1]["answer"]}<|endofchunk|>
        <image>User:{instructions["data"][context2]["instruction"]} GPT:<answer>{instructions["data"][context2]["answer"]}<|endofchunk|>
        <image>User:{instructions["data"][query]["instruction"]} GPT:<answer>
    """)
    inputs = "".join(inputs.split("\n"))
    lang_x = model.text_tokenizer(
        [
            inputs
        ],
        return_tensors="pt",
    )

    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )

    # Keep only the text after the last "<answer>" and before the end-of-chunk
    # marker, without surrounding whitespace or double quotes.
    decoded = model.text_tokenizer.decode(generated_text[0])
    parsed_output = decoded.split("<answer>")[-1].strip().split("<|endofchunk|>")[0].strip().strip('"')

    # Correct only if the whole answer matches, ignoring case.
    if instructions["data"][query]["answer"].lower() == parsed_output.lower():
        count += 1

    # Show context 1, context 2 and the query side by side without axes.
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10,5))
    for ax, shown_image in zip(axes, (demo_image_one, demo_image_two, query_image)):
        ax.imshow(shown_image)
        ax.axis('off')
    print("--------------------------------------------------------")
    print(query)
    print(inputs)
    print("GPT:", parsed_output)
    plt.show()

# Divide by the number of samples actually evaluated, which can be below NUM.
total = len(eval_keys)
accuracy = (count / total) * 100 if total else 0.0
print(f"correct: {count}, total: {total}, acc: {accuracy:.2f}%")