# 画像データとテキストデータの両方を同時に扱うことができるMMBTを用いる。

### まずはセッティング

In [8]:
## colabでこのノートを実行する時のみ使うセル
# from google.colab import drive
# drive.mount('/content/drive')

# !pip install transformers[ja]
# !pip install --quiet sentencepiece

In [11]:
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import os
import random

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.metrics import log_loss, accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from scipy.special import softmax

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset

from transformers import (
    AutoTokenizer, AutoModel, MMBTForClassification, MMBTConfig, AutoConfig,
    Trainer, TrainingArguments,
)
import transformers

from torchvision.io import read_image
from torchvision.models import ResNet152_Weights, resnet152

from matplotlib import pyplot as plt
import seaborn as sns

In [12]:
def seed_everything(seed=1234):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)
    and switches cuDNN to deterministic kernels so repeated runs are
    comparable.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # Only inherited by processes started later: the hash seed of the
    # interpreter that is already running was fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    # cuDNN autotuning may pick different (non-deterministic) kernels per run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(42)

In [14]:
# Paths to use when running on Colab:
# INPUT = "/content/drive/MyDrive/Nishika/bokete"
# train_image_path = "/content/train/"
# test_image_path = "/content/test/"

# Paths for a local run. Each location can be overridden through an
# environment variable so the notebook runs on another machine without
# editing this cell; the defaults are the original hard-coded locations.
INPUT = os.environ.get("BOKETE_INPUT_DIR", "/Users/koshidatatsuo/python/nishika/bokete")
# os.path.join(path, "") guarantees the trailing separator that the string
# concatenation below relies on.
train_image_path = os.path.join(os.environ.get("BOKETE_TRAIN_IMAGE_DIR", "/content/train/"), "")
test_image_path = os.path.join(os.environ.get("BOKETE_TEST_IMAGE_DIR", "/content/test/"), "")

train_df = pd.read_csv(os.path.join(INPUT, "train.csv"))
test_df = pd.read_csv(os.path.join(INPUT, "test.csv"))
submission_df = pd.read_csv(os.path.join(INPUT, "sample_submission.csv"))

# Full path of the prompt image that belongs to each row.
train_df["img_path"] = train_image_path + train_df["odai_photo_file_name"]
test_df["img_path"] = test_image_path + test_df["odai_photo_file_name"]

ちゃんとデータが取得できているか確認

In [16]:
# Print the shape and show the first rows of each split to confirm the CSVs loaded.
for split_name, frame in (("train_data", train_df), ("test_data", test_df)):
    print(f"{split_name}: {frame.shape}")
    display(frame.head())

train_data: (24962, 5)


Unnamed: 0,id,odai_photo_file_name,text,is_laugh,img_path
0,ge5kssftl,9fkys1gb2r.jpg,君しょっちゅうソレ自慢するけど、ツムジ２個ってそんなに嬉しいのかい？,0,/content/train/9fkys1gb2r.jpg
1,r7sm6tvkj,c6ag0m1lak.jpg,これでバレない？授業中寝てもバレない？,0,/content/train/c6ag0m1lak.jpg
2,yp5aze0bh,whtn6gb9ww.jpg,「あなたも感じる？」\n『ああ…、感じてる…』\n「後ろに幽霊いるよね…」\n『女のな…』,0,/content/train/whtn6gb9ww.jpg
3,ujaixzo56,6yk5cwmrsy.jpg,大塚愛聞いてたらお腹減った…さく、らんぼと牛タン食べたい…,0,/content/train/6yk5cwmrsy.jpg
4,7vkeveptl,0i9gsa2jsm.jpg,熊だと思ったら嫁だった,0,/content/train/0i9gsa2jsm.jpg


test_data: (6000, 4)


Unnamed: 0,id,odai_photo_file_name,text,img_path
0,rfdjcfsqq,nc1kez326b.jpg,僕のママ、キャラ弁のゆでたまごに８時間かかったんだ,/content/test/nc1kez326b.jpg
1,tsgqmfpef,49xt2fmjw0.jpg,かわいいが作れた！,/content/test/49xt2fmjw0.jpg
2,owjcthkz2,9dtscjmyfh.jpg,来世の志茂田景樹,/content/test/9dtscjmyfh.jpg
3,rvgaocjyy,osa3n56tiv.jpg,ちょ、あの、オカン、これ水風呂やねんけど、なんの冗談??,/content/test/osa3n56tiv.jpg
4,uxtwu5i69,yb1yqs4pvb.jpg,「今日は皆さんにザリガニと消防車の違いを知ってもらいたいと思います」『どっちも同じだろ。両方...,/content/test/yb1yqs4pvb.jpg


In [17]:
# The test split has no label column, but the dataset class reads
# row["is_laugh"] for every row, so add a constant placeholder label.
test_df = test_df.assign(is_laugh=0)

## MMBT

MMBTとはMultiModal BiTransformersの略で、BERTをベースに画像とテキストを同時に扱うマルチモーダルな深層学習モデルです。画像はResNet152で、テキストはBERTでそれぞれベクトルに変換し、両者をトークン列として連結したものを再度BERTに入力します。  
https://arxiv.org/pdf/1909.02950.pdf

https://github.com/facebookresearch/mmbt

すでにhuggingface内にモデルがあるので、今回はこちらを使用していきたいと思います。
https://huggingface.co/docs/transformers/main/en/model_summary#multimodal-models

In [18]:
# Embed the image data.
class ImageEncoder(nn.Module):
    """ResNet-152 backbone that turns an image batch into a short sequence of embeddings.

    The convolutional feature map is average-pooled onto a small grid and each
    grid cell becomes one embedding vector.

    Args:
        pretrained_weight: Value forwarded to ``resnet152(weights=...)``.
        num_image_embeds: Number of embeddings produced per image; must be a
            key of ``POOLING_BREAKDOWN`` (1-9). Defaults to 3, the value that
            was previously hard-coded.

    Raises:
        ValueError: If ``num_image_embeds`` is not a key of ``POOLING_BREAKDOWN``.
    """

    # Number of image embeddings -> output grid of the adaptive average pooling.
    POOLING_BREAKDOWN = {1: (1, 1), 2: (2, 1), 3: (3, 1), 4: (2, 2), 5: (5, 1), 6: (3, 2), 7: (7, 1), 8: (4, 2), 9: (3, 3)}

    def __init__(self, pretrained_weight, num_image_embeds=3):
        super().__init__()
        # Validate before the (expensive) backbone is built.
        if num_image_embeds not in self.POOLING_BREAKDOWN:
            raise ValueError(
                f"num_image_embeds must be one of {sorted(self.POOLING_BREAKDOWN)}, got {num_image_embeds!r}"
            )
        model = resnet152(weights=pretrained_weight)
        # Drop the last two children (avgpool and fc in torchvision's ResNet)
        # so the network outputs a spatial feature map.
        modules = list(model.children())[:-2]
        self.model = nn.Sequential(*modules)
        self.pool = nn.AdaptiveAvgPool2d(self.POOLING_BREAKDOWN[num_image_embeds])

    def forward(self, x):
        """Return embeddings of shape (batch, grid_h * grid_w, channels) for input batch ``x``."""
        out = self.pool(self.model(x))          # (batch, channels, grid_h, grid_w)
        out = torch.flatten(out, start_dim=2)   # (batch, channels, grid_h * grid_w)
        out = out.transpose(1, 2).contiguous()  # (batch, grid_h * grid_w, channels)
        return out

In [19]:
def read_jpg(path):
    """Read an image file and return it as a 3-channel (RGB) tensor.

    Args:
        path: Path of the image file to decode.

    Returns:
        The tensor produced by ``torchvision.io.read_image`` in RGB mode.
    """
    # Imported locally because the notebook's import cell only brings in read_image.
    from torchvision.io import ImageReadMode

    # Let the decoder convert to RGB: this covers the grayscale (1-channel)
    # images in the data set and also inputs that carry an alpha channel,
    # which the previous manual channel expansion did not handle.
    return read_image(path, mode=ImageReadMode.RGB)

class BoketeTextImageDataset(Dataset):
    """Dataset yielding the tokenized caption, the prompt image and the label of one row.

    Args:
        df: DataFrame with ``text``, ``img_path`` and ``is_laugh`` columns.
        tokenizer: Tokenizer whose ``encode`` returns ids that start with a
            start special token and end with an end special token.
        max_seq_len: Maximum number of ids (special tokens included) kept per caption.
        image_transform: Object whose ``transforms()`` returns the image
            preprocessing callable (e.g. a torchvision weights enum member).
    """

    def __init__(self, df, tokenizer, max_seq_len:int, image_transform):
        self.df = df
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.image_transforms = image_transform.transforms()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return a dict with the start/end tokens, caption ids, image tensor and label."""
        row = self.df.iloc[index]
        # Do NOT pad here. With padding="max_length" the last id of a short
        # caption was a padding id, so it was returned as the image end token,
        # and every caption had the same length, which made collate_fn mark
        # the padded positions as real tokens. Padding is done per batch in
        # collate_fn instead.
        token_ids = torch.tensor(self.tokenizer.encode(row["text"], max_length=self.max_seq_len, truncation=True))
        # First / last ids are the special tokens that will wrap the image
        # embeddings; the ids in between are the caption itself.
        start_token, sentence, end_token = token_ids[0], token_ids[1:-1], token_ids[-1]
        sentence = sentence[:self.max_seq_len]

        image = self.image_transforms(read_jpg(row["img_path"]))

        return {
            "image_start_token": start_token,
            "image_end_token": end_token,
            "sentence": sentence,
            "image": image,
            "label": torch.tensor(row["is_laugh"]),
        }

def collate_fn(batch):
    """Merge dataset items into one padded batch keyed by the model's argument names.

    Args:
        batch: List of dicts with ``sentence``, ``image``, ``label``,
            ``image_start_token`` and ``image_end_token`` entries.

    Returns:
        Dict with ``input_ids`` and ``attention_mask`` (zero-padded to the
        longest sentence in the batch), ``input_modal``,
        ``modal_start_tokens``, ``modal_end_tokens`` and ``labels``.
    """
    seq_lengths = [len(item["sentence"]) for item in batch]
    batch_size = len(batch)
    longest = max(seq_lengths)

    # Zero doubles as the padding id and as the "ignore" value of the mask.
    input_ids = torch.zeros((batch_size, longest), dtype=torch.long)
    attention_mask = torch.zeros_like(input_ids)
    for position, item in enumerate(batch):
        n_tokens = seq_lengths[position]
        input_ids[position, :n_tokens] = item["sentence"]
        attention_mask[position, :n_tokens] = 1

    def _stack(key):
        # Stack the per-item tensors stored under `key` along a new batch axis.
        return torch.stack([item[key] for item in batch])

    images = _stack("image")
    labels = _stack("label")
    start_tokens = _stack("image_start_token")
    end_tokens = _stack("image_end_token")

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "input_modal": images,
        "modal_start_tokens": start_tokens,
        "modal_end_tokens": end_tokens,
        "labels": labels,
    }

学習済みモデルには、東北大学の乾研究室が作成したものを使用します。

In [None]:
# Name of the pretrained Japanese BERT checkpoint used for tokenization.
PRETRAINED_MODEL_NAME = "cl-tohoku/bert-base-japanese-whole-word-masking"
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)

### データ分割

In [None]:
# Stratified 80/20 hold-out split over row positions, keeping the is_laugh
# ratio the same in both parts.
row_positions = list(range(len(train_df)))
trn_idx, val_idx = train_test_split(
    row_positions,
    test_size=0.2,
    random_state=42,
    stratify=train_df["is_laugh"],
)

In [None]:
# Training and validation datasets share the tokenizer, the sequence-length
# cap and the image preprocessing; only the selected rows differ.
dataset_kwargs = dict(
    tokenizer=tokenizer,
    max_seq_len=48,
    image_transform=ResNet152_Weights.IMAGENET1K_V2,
)
trn_ds = BoketeTextImageDataset(train_df.iloc[trn_idx], **dataset_kwargs)
val_ds = BoketeTextImageDataset(train_df.iloc[val_idx], **dataset_kwargs)

In [None]:
# Test dataset, built with the same settings as the training/validation ones.
test_ds = BoketeTextImageDataset(
    df=test_df,
    tokenizer=tokenizer,
    max_seq_len=48,
    image_transform=ResNet152_Weights.IMAGENET1K_V2,
)

In [None]:
# Load the pretrained Japanese BERT (same checkpoint as the tokenizer) that
# MMBT wraps. `transformer_config` and `transformer` were used below without
# being defined anywhere in the notebook, so this cell failed on a fresh kernel.
transformer_config = AutoConfig.from_pretrained("cl-tohoku/bert-base-japanese-whole-word-masking")
transformer = AutoModel.from_pretrained(
    "cl-tohoku/bert-base-japanese-whole-word-masking",
    config=transformer_config,
)

# Two output classes: is_laugh is either 0 or 1.
config = MMBTConfig(transformer_config, num_labels=2)
model = MMBTForClassification(config, transformer, ImageEncoder(ResNet152_Weights.IMAGENET1K_V2))

In [None]:
# Turn on dict-style outputs on the MMBT config.
# NOTE(review): presumably needed because the MMBT model reads this flag from
# its config and MMBTConfig does not set it by itself — confirm against the
# installed transformers version.
config.use_return_dict = True

In [None]:
# Expose the inner MMBT model's config as `model.config`.
# NOTE(review): presumably because Trainer expects the model object to have a
# `config` attribute — confirm against the installed transformers version.
model.config = model.mmbt.config

In [None]:
# Training configuration.
trainer_args = TrainingArguments(
    output_dir="/content/mmbt_exp01",
    overwrite_output_dir=True,
    do_train=True,
    do_eval=True,
    num_train_epochs=3,
    # Evaluate, checkpoint and log on the same 50-step schedule so the best
    # checkpoint can be restored at the end of training.
    evaluation_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    logging_steps=50,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=12,
    save_total_limit=1,
    # Mixed precision needs a CUDA device; fall back to full precision when
    # none is available (e.g. a local CPU run) instead of failing here.
    fp16=torch.cuda.is_available(),
    # Keep every key produced by the custom collate function.
    remove_unused_columns=False,
    # 8 samples x 20 accumulation steps = 160 samples per optimizer step per device.
    gradient_accumulation_steps=20,
    load_best_model_at_end=True,
    logging_dir='./logs',
    report_to="none"
)

In [None]:
# Wire the model, the training arguments, both datasets and the custom
# collate function into a single Trainer.
trainer_kwargs = dict(
    model=model,
    args=trainer_args,
    tokenizer=tokenizer,
    train_dataset=trn_ds,
    eval_dataset=val_ds,
    data_collator=collate_fn,
)
trainer = Trainer(**trainer_kwargs)

学習の開始

In [None]:
# Run training; the returned training summary is shown as the cell output.
train_output = trainer.train()
train_output

In [None]:
# Predict on the validation split and keep only the model's raw predictions.
val_output = trainer.predict(val_ds)
val_preds = val_output.predictions

In [None]:
# Sanity check: log loss of the softmax-normalized validation predictions
# against the true labels.
val_labels = val_ds.df["is_laugh"].values
val_proba = softmax(val_preds, axis=-1)
log_loss(val_labels, val_proba)

In [None]:
# Accuracy when the class with the highest score is taken as the prediction.
val_pred_labels = np.argmax(val_preds, axis=-1)
accuracy_score(val_ds.df["is_laugh"].values, val_pred_labels)

In [None]:
# Confusion matrix of the validation predictions.
_conf_options = {"normalize": None,}
_plot_options = {
        "cmap": "Blues",
        "annot": True,
        # Without an explicit format the annotations use two significant
        # digits, so large counts are rendered in scientific notation.
        # Raw counts are integers; normalized matrices hold fractions.
        "fmt": "d" if _conf_options["normalize"] is None else ".2f",
    }

conf = confusion_matrix(y_true=val_ds.df["is_laugh"].values,
                        y_pred=np.argmax(val_preds, axis=-1),
                        **_conf_options)

fig, ax = plt.subplots(figsize=(8, 8))
sns.heatmap(conf, ax=ax, **_plot_options)
ax.set_ylabel("Label")
ax.set_xlabel("Predict")
ax.set_title("Validation confusion matrix")
plt.show()

### 予測を行う

In [None]:
# Predict on the test split and keep only the model's raw predictions.
test_output = trainer.predict(test_ds)
preds = test_output.predictions

In [None]:
# Softmax over the last axis turns the predictions into class probabilities;
# column 1 (class 1) is what goes into the submission.
test_proba = softmax(preds, axis=-1)
submission_df["is_laugh"] = test_proba[:, 1]

In [None]:
# Store the submission column as Python float (float64).
submission_df = submission_df.astype({"is_laugh": float})

In [None]:
# Path to use when running on Colab:
# OUTPUT = "/content/drive/MyDrive/Nishika/bokete"

# Output directory. Set BOKETE_OUTPUT_DIR to write somewhere else without
# editing this cell; the default is the original hard-coded location.
OUTPUT = os.environ.get("BOKETE_OUTPUT_DIR", "/Users/koshidatatsuo/python/nishika/bokete")
# Create the directory if needed so a long run does not fail at the last step.
os.makedirs(OUTPUT, exist_ok=True)
submission_df.to_csv(os.path.join(OUTPUT, 'submission2.csv'), index=False)