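# Let's Build a Boke-Judging AI! Tutorial 1
This notebook is a tutorial for the Nishika competition [ボケ判定AIを作ろう！](https://www.nishika.com/competitions/) ("Let's Build a Boke-Judging AI!").

Using data from "Bokete" (ボケて), the task is to predict from the image and the caption text whether a given boke (a joke caption attached to a prompt image) is funny or not.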

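In this notebook, we generate features from the images and the text separately, using the following approaches:

- Turning image data into features with a pretrained image model (BEiT in the code below)
- Turning text data into features with a BERT model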

The text features and image features are created independently of each other here, so adding features that combine the image and text data may improve accuracy. Please try out various ideas.
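One simple way to combine the two modalities is to concatenate the image vector and the text vector of each row before passing them to a model. The sketch below assumes both feature sets are already available as NumPy arrays with one row per sample; the array names, sizes, and values are hypothetical and only illustrate the shape handling.

```python
import numpy as np

# Hypothetical feature matrices: 3 samples, 4 image features and 2 text features each.
image_features = np.arange(12, dtype=float).reshape(3, 4)
text_features = np.ones((3, 2))

# Concatenate along the feature axis so each row holds both modalities.
combined = np.hstack([image_features, text_features])

print(combined.shape)  # (3, 6)
```

A tree-based model such as LightGBM can then learn splits over both groups of columns at once, which is one possible starting point for image-text interaction features.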

| Column | Description |
| ---- | ---- |
| id | ID |
| odai_photo_file_name | Bokete prompt (odai) image |
| text | Bokete caption text |
| is_laugh | Funniness (funny: 1, not funny: 0) |


Set up the directory structure as follows:

```
├── train.zip
│ ├── xxx.jpg
│ └── yyy.jpg
├── test.zip
│ ├── xxx.jpg
│ └── yyy.jpg
├── train.csv
├── test.csv
├── sample_submission.csv
└── submission.csv (submission file generated by this baseline)
```

### Setting
From "Runtime" > "Change runtime type" at the top of the page, select "GPU" and "High-RAM".

In [1]:
!pip3 install tf-nightly

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
!pip install keras-cv-attention-models

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
!nvidia-smi

Mon Sep 26 14:03:10 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   51C    P0    43W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Library

In [5]:
!pip install --quiet transformers==4.18.0
!pip install --quiet tokenizers==0.12.1
!pip install --quiet sentencepiece
!pip install --quiet japanize-matplotlib
!pip install transformers fugashi ipadic >> /dev/null

In [6]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

import torch


from sklearn.metrics import mean_squared_error
from sklearn.metrics import log_loss
from sklearn.model_selection import StratifiedKFold, KFold
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

import sys
import os
import re
import random

from time import time
from tqdm import tqdm

from contextlib import contextmanager
import lightgbm as lgb

import requests
import unicodedata
import nltk
from nltk.corpus import wordnet
from bs4 import BeautifulSoup
nltk.download(['wordnet', 'stopwords', 'punkt'])

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

# Setting

In [7]:
def seed_everything(seed=1234):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

seed_everything(42)

In [8]:
INPUT = "/content/drive/MyDrive/nishika/"  # Change this to your own directory.
train_image_path = "/content/drive/MyDrive/nishika/train/"
test_image_path = "/content/drive/MyDrive/nishika/test/"

# Read Data
We load the training data and the inference (test) data and check things such as the distribution of the target variable.

In [9]:
train_df = pd.read_csv(os.path.join(INPUT, "train.csv"))
test_df = pd.read_csv(os.path.join(INPUT, "test.csv"))
submission_df = pd.read_csv(os.path.join(INPUT, "sample_submission.csv"))

In [10]:
print(f"train_data: {train_df.shape}")
display(train_df.head())

print(f"test_data: {test_df.shape}")
display(test_df.head())

train_data: (24962, 4)


Unnamed: 0,id,odai_photo_file_name,text,is_laugh
0,ge5kssftl,9fkys1gb2r.jpg,君しょっちゅうソレ自慢するけど、ツムジ２個ってそんなに嬉しいのかい？,0
1,r7sm6tvkj,c6ag0m1lak.jpg,これでバレない？授業中寝てもバレない？,0
2,yp5aze0bh,whtn6gb9ww.jpg,「あなたも感じる？」\n『ああ…、感じてる…』\n「後ろに幽霊いるよね…」\n『女のな…』,0
3,ujaixzo56,6yk5cwmrsy.jpg,大塚愛聞いてたらお腹減った…さく、らんぼと牛タン食べたい…,0
4,7vkeveptl,0i9gsa2jsm.jpg,熊だと思ったら嫁だった,0


test_data: (6000, 3)


Unnamed: 0,id,odai_photo_file_name,text
0,rfdjcfsqq,nc1kez326b.jpg,僕のママ、キャラ弁のゆでたまごに８時間かかったんだ
1,tsgqmfpef,49xt2fmjw0.jpg,かわいいが作れた！
2,owjcthkz2,9dtscjmyfh.jpg,来世の志茂田景樹
3,rvgaocjyy,osa3n56tiv.jpg,ちょ、あの、オカン、これ水風呂やねんけど、なんの冗談??
4,uxtwu5i69,yb1yqs4pvb.jpg,「今日は皆さんにザリガニと消防車の違いを知ってもらいたいと思います」『どっちも同じだろ。両方...
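
The cells above print only the shapes and the first rows. To actually look at the distribution of the target variable, something like the following could be used. This is a minimal sketch on a small hypothetical DataFrame that mimics the `is_laugh` column; in the notebook, `train_df` would take the place of `toy_df`.

```python
import pandas as pd

# Hypothetical stand-in for train_df with only the target column.
toy_df = pd.DataFrame({"is_laugh": [0, 0, 1, 0, 1, 1, 0, 0]})

# Absolute counts per class.
counts = toy_df["is_laugh"].value_counts().sort_index()
# Relative frequencies, useful for spotting class imbalance.
ratios = toy_df["is_laugh"].value_counts(normalize=True).sort_index()

print(counts.to_dict())
print(ratios.to_dict())
```

If the classes turn out to be imbalanced, the stratified split used later in the notebook keeps the same ratio in the training and validation sets.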


# Create Image Features

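A bokete expresses its humor through the combination of an image and text, so the key point is how to get the model to learn from both the image data and the text data.

To use the image data as features, this notebook uses a pretrained BEiT model (`BeitBasePatch16` from `keras_cv_attention_models`).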

In [11]:
import cv2
from keras.models import Model
from keras.layers import GlobalAveragePooling2D, Input, Lambda, AveragePooling1D
import keras.backend as K
from tqdm import tqdm, tqdm_notebook
from keras_cv_attention_models import beit

In [12]:
class CFG:
    img_size = 224
    batch_size = 17

In [13]:
def resize_to_square(im):
    old_size = im.shape[:2] 
    ratio = float(CFG.img_size)/max(old_size)
    new_size = tuple([int(x*ratio) for x in old_size])
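    # Resize so the longer side equals CFG.img_size (224), keeping the aspect ratio;
    # the padding below then makes the image 224x224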
    im = cv2.resize(im, (new_size[1], new_size[0]))
    delta_w = CFG.img_size - new_size[1]
    delta_h = CFG.img_size - new_size[0]
    top, bottom = delta_h//2, delta_h-(delta_h//2)
    left, right = delta_w//2, delta_w-(delta_w//2)
    color = [0, 0, 0]
    new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT,value=color)
    return new_im


def load_image(ids, is_train=True):
  if is_train:
    image = cv2.imread(train_image_path+ids)
  else:
    image = cv2.imread(test_image_path+ids)
  new_image = resize_to_square(image)
  return new_image
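
To make the arithmetic in `resize_to_square` concrete, the following sketch reproduces only the size and padding calculation in plain Python, without OpenCV. The helper name `letterbox_geometry` is hypothetical and is not part of the notebook.

```python
def letterbox_geometry(height, width, target=224):
    """Return the resized (height, width) and the (top, bottom, left, right) padding."""
    # Scale so that the longer side becomes `target`, keeping the aspect ratio.
    ratio = float(target) / max(height, width)
    new_h, new_w = int(height * ratio), int(width * ratio)
    # Split the remaining pixels as evenly as possible between the two sides.
    delta_h, delta_w = target - new_h, target - new_w
    top, bottom = delta_h // 2, delta_h - delta_h // 2
    left, right = delta_w // 2, delta_w - delta_w // 2
    return (new_h, new_w), (top, bottom, left, right)


size, pads = letterbox_geometry(448, 299)
print(size, pads)
```

For a 448×299 image this gives a resized size of (224, 149) and padding of (0, 0, 37, 38), so the padded result is exactly 224×224.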

In [14]:
m = beit.BeitBasePatch16(pretrained="imagenet21k-ft1k")

>>>> Load pretrained from: /root/.keras/models/beit_base_patch16_224_imagenet21k-ft1k.h5


In [15]:
image_df_train = train_df[["id", "odai_photo_file_name"]].copy()
image_df_train.head()

Unnamed: 0,id,odai_photo_file_name
0,ge5kssftl,9fkys1gb2r.jpg
1,r7sm6tvkj,c6ag0m1lak.jpg
2,yp5aze0bh,whtn6gb9ww.jpg
3,ujaixzo56,6yk5cwmrsy.jpg
4,7vkeveptl,0i9gsa2jsm.jpg


In [16]:
image_ids = image_df_train["odai_photo_file_name"].values
n_batches = len(image_ids) // CFG.batch_size + 1
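
Note that `len(image_ids) // CFG.batch_size + 1` yields one extra, empty batch whenever the number of images is an exact multiple of the batch size. That does not happen with 24,962 rows and a batch size of 17, but a ceiling division avoids the edge case in general. A minimal sketch (the function name `num_batches` is hypothetical):

```python
import math


def num_batches(n_items, batch_size):
    """Number of batches needed to cover n_items, with no empty trailing batch."""
    return math.ceil(n_items / batch_size)


print(num_batches(24962, 17))  # 1469, same as the notebook's formula
print(num_batches(34, 17))     # 2
print(34 // 17 + 1)            # 3, i.e. one empty batch with the notebook's formula
```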

In [None]:
features = {}
for b in tqdm(range(n_batches)):
    start = b*CFG.batch_size
    end = (b+1)*CFG.batch_size
    batch_ids = image_ids[start:end]
    batch_images = np.zeros((len(batch_ids),CFG.img_size,CFG.img_size,3))
    for i,image_id in enumerate(batch_ids):
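        try:
            batch_images[i] = load_image(image_id)
        except Exception as e:
            # On failure the image stays all zeros; report which file failed
            print(f"Error loading {image_id}: {e}")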
    batch_preds = m.predict(batch_images)
    for i,image_id in enumerate(batch_ids):
        features[image_id] = batch_preds[i]

  0%|          | 0/1469 [00:00<?, ?it/s]
 19%|█▉        | 281/1469 [16:09<1:20:02,  4.04s/it]

In [None]:
image_feature = pd.DataFrame.from_dict(features, orient='index').add_prefix("beit_").reset_index()
image_feature.rename(columns={"index":"odai_photo_file_name"}, inplace=True)

In [None]:
# Merge the image features into the train data.
train_df = pd.merge(train_df, image_feature, on="odai_photo_file_name", how="left")
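
The two steps above, turning the `features` dictionary into a table and left-joining it onto the main DataFrame, can be illustrated on toy data. This is a sketch with hypothetical file names and two-dimensional feature vectors. The `validate` argument is an optional addition that the notebook does not use; it makes pandas raise an error if a file name appears more than once in the feature table.

```python
import numpy as np
import pandas as pd

# Hypothetical features keyed by image file name (2 values per image).
toy_features = {"a.jpg": np.array([0.1, 0.2]), "b.jpg": np.array([0.3, 0.4])}

# One row per image, columns named beit_0, beit_1, ...
feature_table = (
    pd.DataFrame.from_dict(toy_features, orient="index")
    .add_prefix("beit_")
    .reset_index()
    .rename(columns={"index": "odai_photo_file_name"})
)

# Hypothetical main table; c.jpg has no features and ends up with NaN.
toy_df = pd.DataFrame({"id": ["x", "y", "z"],
                       "odai_photo_file_name": ["a.jpg", "b.jpg", "c.jpg"]})
merged = pd.merge(toy_df, feature_table, on="odai_photo_file_name",
                  how="left", validate="many_to_one")
print(merged)
```

Because the join is a left join, rows whose image has no entry in the feature table are kept with missing values rather than dropped, so the row count of the main table is unchanged.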

In [None]:
train_df.shape

In [None]:
# Do the same for the test data.
image_df_test = test_df[["id", "odai_photo_file_name"]].copy()

image_ids = image_df_test["odai_photo_file_name"].values
n_batches = len(image_ids) // CFG.batch_size + 1


features = {}
for b in tqdm(range(n_batches)):
    start = b*CFG.batch_size
    end = (b+1)*CFG.batch_size
    batch_ids = image_ids[start:end]
    batch_images = np.zeros((len(batch_ids),CFG.img_size,CFG.img_size,3))
    for i,image_id in enumerate(batch_ids):
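        try:
            batch_images[i] = load_image(image_id, is_train=False)
        except Exception as e:
            # On failure the image stays all zeros; report which file failed
            print(f"Error loading {image_id}: {e}")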
    batch_preds = m.predict(batch_images)
    for i,image_id in enumerate(batch_ids):
        features[image_id] = batch_preds[i]

image_feature = pd.DataFrame.from_dict(features, orient='index').add_prefix("beit_").reset_index()
image_feature.rename(columns={"index":"odai_photo_file_name"}, inplace=True)

test_df = pd.merge(test_df, image_feature, on="odai_photo_file_name", how="left")

In [None]:
test_df.shape

# Data Split

In [None]:
train_df.to_csv('embedding_train_image_beit')

In [None]:
test_df.to_csv('embedding_test_image_beit')

In [None]:
# Split into training data and validation data.
train_df, valid_df = train_test_split(train_df, test_size=0.2, random_state=42, stratify=train_df["is_laugh"])

train_y = train_df["is_laugh"]
train_x = train_df.drop(["id", "odai_photo_file_name", "text","is_laugh"], axis=1)

valid_y = valid_df["is_laugh"]
valid_x = valid_df.drop(["id", "odai_photo_file_name", "text","is_laugh"], axis=1)

test_x = test_df.drop(["id", "odai_photo_file_name", "text"], axis=1)

In [None]:
print(train_x.shape)
print(valid_x.shape)

# Model

In [None]:
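lgbm_params = {
    "objective": "binary",
    "learning_rate": 0.05,
    "num_leaves": 32,
    "random_state": 71,
    "n_jobs": -1,
    "colsample_bytree": 0.8,
    "reg_lambda": 5,
    "max_depth": 5,
}

lgtrain = lgb.Dataset(train_x, train_y)
lgvalid = lgb.Dataset(valid_x, valid_y)

# "n_estimators": 20000 was an alias that overrode num_boost_round, so the limit
# is now set once below. "importance_type" is not a lgb.train parameter and is dropped.
# Early stopping and logging are passed as callbacks, replacing the
# early_stopping_rounds / verbose_eval arguments removed in LightGBM 4.0.
lgb_clf = lgb.train(
    lgbm_params,
    lgtrain,
    num_boost_round=20000,
    valid_sets=[lgtrain, lgvalid],
    valid_names=["train", "valid"],
    callbacks=[lgb.early_stopping(stopping_rounds=50), lgb.log_evaluation(period=50)],
)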

In [None]:
# Visualize the feature importances.
lgb.plot_importance(lgb_clf, figsize=(12,8), max_num_features=50, importance_type='gain')
plt.tight_layout()
plt.show()

In [None]:
# The evaluation metric is log loss, but we also look at accuracy.

val_pred = lgb_clf.predict(valid_x, num_iteration=lgb_clf.best_iteration)
val_pred_max = np.round(val_pred).astype(int)  # round probabilities to class labels (0 or 1)
accuracy = sum(valid_y == val_pred_max) / len(valid_y)
print(accuracy)
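
Since the comment above names log loss as the evaluation metric, it can help to see how it is computed. The following sketch implements binary log loss in plain Python on hypothetical labels and probabilities; in the notebook the same quantity could be obtained with the already imported `log_loss(valid_y, val_pred)` from `sklearn.metrics`. The clipping constant `eps` is an assumption chosen here to avoid `log(0)`.

```python
import math


def binary_log_loss(y_true, y_prob, eps=1e-15):
    """Mean negative log-likelihood of the true labels under the predicted probabilities."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1 - eps)  # keep p strictly inside (0, 1)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)


# Hypothetical labels and predicted probabilities of the positive class.
labels = [1, 0, 1, 0]
probs = [0.9, 0.1, 0.6, 0.4]
print(binary_log_loss(labels, probs))
```

Log loss penalizes confident wrong predictions heavily, so it is evaluated on the predicted probabilities rather than on the rounded class labels used for accuracy.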

In [None]:
_conf_options = {"normalize": None,}
_plot_options = {
        "cmap": "Blues",
        "annot": True
    }

conf = confusion_matrix(y_true=valid_y,
                        y_pred=val_pred_max,
                        **_conf_options)

fig, ax = plt.subplots(figsize=(8, 8))
sns.heatmap(conf, ax=ax, **_plot_options)
ax.set_ylabel("Label")
ax.set_xlabel("Predict")

# Predict

In [None]:
test_pred = lgb_clf.predict(test_x, num_iteration=lgb_clf.best_iteration)

In [None]:
submission_df["is_laugh"] = test_pred
submission_df.head()

In [None]:
submission_df.to_csv('submission.csv', index=False)