<a href="https://colab.research.google.com/github/komazawa-deep-learning/komazawa-deep-learning.github.io/blob/master/2023notebooks/2023_1123Stroop_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Stroop effect

# References

* J. Ridley Stroop (1935) [STUDIES OF INTERFERENCE IN SERIAL VERBAL REACTIONS](https://psychclassics.yorku.ca/Stroop/), Journal of Experimental Psychology, 18, 643-662.

<center>
<div style="width:88%;font-color:teal;">

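Reading color names printed in an ink color that differs from the word (RCNd) versus reading color names printed in black ink (RCNb); 100 stimuli, times in seconds.<br/>
**RCNd**: reading color names where the print color and the word differ,
**RCNb**: reading color names printed in black ink,
**No. Ss**: number of subjects, **s**: standard deviation, **D**: difference, **D/PEd**: difference divided by the probable error of the difference<br/>
Stroop (1935) Table 1.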
</div>
<img src="https://raw.githubusercontent.com/ShinAsakawa/ShinAsakawa.github.io/master/assets/1935Stroop_tab1.jpg" style="width:77%">
<!-- <img src="1935Stroop_tab1.jpg" width="77%">-->
<div style="width:88%;font-color:teal;">

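Naming the colors of solid squares (■) versus naming the print color of words printed in a color that differs from the word.<br/>
**NCWd**: naming the print color of words where the print color and the word differ, **NC**: naming colors.
Stroop (1935) Table 3.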
</div>    

<img src="https://raw.githubusercontent.com/ShinAsakawa/ShinAsakawa.github.io/master/assets/1935Stroop_tab3.jpg" style="width:77%">
<!-- <img src="1935Stroop_tab3.jpg" width="77%"> -->
</center>

* **NC**: Naming Colors.
* **NCWd**: Naming the Colors of the Print of Words Where the Color of the Print and the Word are Different.
* **RCNb**: Reading Color Names Printed in Black Ink.
* **RCNd**: Reading Color Names Where the Color of the Print and the Word are Different.
* **D**: Difference.
* **D/P Ed**: Difference divided by the probable error of the difference.
* **M & F**: Males and Females.
* **P Ed**: Probable error of the difference.
* **s**: Sigma, the standard deviation.
* **s / m**: Standard deviation divided by the mean.
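
The D/P Ed entry in the glossary above can be worked through numerically. The following is a minimal sketch, assuming the classical definition in which the probable error is 0.6745 times the standard error and the two conditions are treated as uncorrelated; Stroop's own computation may have differed. The function name `d_over_pe` and all numbers are hypothetical placeholders, not values taken from Stroop's tables.

```python
import math

def d_over_pe(mean_a, sd_a, mean_b, sd_b, n):
    """Return (D, PE_d, D / PE_d) for two conditions measured on n subjects each."""
    d = mean_a - mean_b
    # Probable error of each mean: 0.6745 * standard error of the mean
    pe_a = 0.6745 * sd_a / math.sqrt(n)
    pe_b = 0.6745 * sd_b / math.sqrt(n)
    # Probable error of the difference, assuming uncorrelated conditions
    pe_d = math.sqrt(pe_a ** 2 + pe_b ** 2)
    return d, pe_d, d / pe_d

# Hypothetical summary statistics (seconds per 100 stimuli), for illustration only
d, pe_d, ratio = d_over_pe(mean_a=110.0, sd_a=20.0, mean_b=60.0, sd_b=10.0, n=100)
print(f"D = {d:.1f}, PE_d = {pe_d:.3f}, D/PE_d = {ratio:.2f}")
```

A large ratio means the difference between the two conditions is many probable errors wide, which is how the tables express reliability.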


<center>
<img src="https://raw.githubusercontent.com/ShinAsakawa/ShinAsakawa.github.io/master/assets/2003Roelofs_stroop_fig9.jpg" style="width:44%">
<!-- <img src="2003Roelofs_stroop_fig9.jpg" width="44%"><br/> -->
<div style="width:77%;background-color:lavendar;">

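Word planning and executive control in the Stroop task.
Lateral view (top) and medial view (bottom) of the left hemisphere of the human brain.
The word-planning system achieves color naming through color perception (cp), conceptual identification (ci), lemma retrieval (lr), word-form encoding (wfe), and articulatory processing (art).
Word-form perception (wfp) activates lemmas and word forms in parallel.
Word reading minimally involves wfp, wfe, and art.
The executive system, centered on the anterior cingulate, achieves goal and input control.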
<!-- Figure 9. Word planning and executive control in the Stroop task.
Lateral view (top panel) and medial view (bottom panel) of the left hemisphere of the human brain.

The word-planning system achieves color naming through color perception (cp), conceptual identification (ci), lemma retrieval (lr), word-form encoding (wfe), and articulatory processing (art);
word-form perception (wfp) activates lemmas and word forms in parallel.
Word reading minimally involves wfp, wfe, and art.
The executive system centered on the anterior cingulate achieves goal and input control. -->
Source: Roelofs (2003) __Goal-Referenced Selection of Verbal Action: Modeling Attentional Control in the Stroop Task__, Psychological Review, 110(1), 88–125.
</div>
</center>    

# 0 Preparation

## 0.1 Importing libraries

In [None]:
%config InlineBackend.figure_format = 'retina'
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

import IPython
isColab = 'google.colab' in str(IPython.get_ipython())

if isColab:
    !pip install --upgrade termcolor==1.1
from termcolor import colored

try:
    import japanize_matplotlib
except ImportError:
    !pip install japanize_matplotlib
    import japanize_matplotlib

try:
    import bit
except ImportError:
    !pip install ipynbname --upgrade
    !git clone https://github.com/ShinAsakawa/bit.git
    import bit

import os
HOME = os.environ['HOME']

from tqdm.notebook import tqdm

## 0.2 Setting the random seed

In [None]:
# Set the random seeds
import numpy as np
import random
import sys

seed=42
torch.manual_seed(seed=seed)
np.random.seed(seed=seed)
random.seed(seed)
batch_size = 64

# 1 Defining the dataset

In [None]:
from bit import get_text_img
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import os
from glob import  glob
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms import v2


class stroop_Dataset(Dataset):
    def __init__(self,
                 dataset_name:str='train'):

        super().__init__()
        self.dataset_name = 'train' if dataset_name == 'train' else 'val'

        self.width, self.height = 224, 224
        self.bgcolor = [255,255,255]

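        # Ink colors of the stimuli
        self.colors = ['black', 'red', 'green', 'blue', 'yellow']

        # Word stimuli: the first five correspond one-to-one to self.colors;
        # the sixth, '■', is a neutral patch that spells no color name
        self.words  = ['黒', '赤', '緑', '青', '黄','■']

        # Use several font sizes so that recognition generalizes
        self.font_sizes=[42, 56, 70, 84, 98, 112]

        # Use several typefaces as well, taken from the freely available Noto fonts.
        # The loop below registers 14 typefaces at each font size:
        # NotoSerif is roughly a Mincho face and NotoSans a Gothic face,
        # and each of the two comes in 7 weights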
        fonts = []
        for font_size in self.font_sizes:
            _fonts = bit.get_notojp_fonts(fontsize=font_size, verbose=False)
            for _fontname, _font in _fonts.items():
                font_name = str(f'{font_size:03d}')+_fontname

                if dataset_name == 'train':
                    fonts.append((font_name, _font))
                elif 'Regular' in _fontname:
                    fonts.append((font_name, _font))

        self.fonts = fonts

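        # Build the conditions as word x color x size x font.
        # The Stroop effect concerns the word-reading and color-naming tasks.
        # In principle the same design works for size or font identification, but neither is used here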
        cond = []
        for font in fonts:
            size = int(font[0][:3])
            for color in self.colors:
                for word in self.words:
                    # Each condition is a 4-tuple: (color, word, font size, font)
                    cond.append((color, word, size, font))
        self.cond = cond

        self.affine = v2.RandomAffine(degrees=(-5, 5), translate=(0.1, 0.1), fill=list(np.array(self.bgcolor)/255))
        #self.affine = v2.RandomAffine(degrees=(-5, 5), translate=(0.05, 0.05))

    def __len__(self):
        return len(self.colors) * len(self.words) * len(self.fonts)
        #return len(self.cond)

    def __getitem__(self, idx:int):

        color_ = self.cond[idx][0]
        word_  = self.cond[idx][1]
        size_  = self.cond[idx][2]
        font_name  = self.cond[idx][3][0]
        font_  = self.cond[idx][3][1]

        color_idx = self.colors.index(color_)
        word_idx  = self.words.index(word_)

        # Render one image for this condition
        img, draw_canvas, bbox = get_text_img(
            text=word_, color=color_, draw_bbox=False, font=font_)

        # DataLoader cannot handle the image unless it is a torch.Tensor,
        # so convert it to a float tensor of shape (C, H, W) scaled to [0, 1]
        _img = torch.tensor(
            (np.array(img)/255).clip(0,1).transpose(2,0,1),
            device=device,
            dtype=torch.float32,
        )

        # Apply a small random rotation and shift to training images only
        if self.dataset_name == 'train':
            _img = self.affine(_img)

        return _img, {'color':color_idx,
                      'word':word_idx,
                      'font_size':size_, 'font_name':font_name}

stroop_ds = stroop_Dataset()
stroop_val_ds = stroop_Dataset(dataset_name='val')
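
The condition grid built in `__init__` can be inspected without rendering any images. The sketch below rebuilds the color x word x font product with `itertools.product` and labels each stimulus as congruent, incongruent, or neutral. The two font names are hypothetical stand-ins for the (name, font) pairs that the font loader returns, and the `congruency` helper is not part of the notebook.

```python
from itertools import product

colors = ['black', 'red', 'green', 'blue', 'yellow']
words = ['黒', '赤', '緑', '青', '黄', '■']
# Hypothetical stand-ins for the (name, font) pairs returned by the font loader
fonts = ['042NotoSansJP-Regular', '042NotoSerifJP-Regular']

def congruency(color, word):
    """Classify a stimulus as 'neutral', 'congruent', or 'incongruent'."""
    if word == '■':
        return 'neutral'
    return 'congruent' if colors.index(color) == words.index(word) else 'incongruent'

# Same nesting order as the dataset: font, then color, then word
cond = [(color, word, font) for font, color, word in product(fonts, colors, words)]

counts = {}
for color, word, _font in cond:
    kind = congruency(color, word)
    counts[kind] = counts.get(kind, 0) + 1

print(len(cond), counts)
```

For every font, 5 of the 30 color-word pairs are congruent, 5 are neutral, and 20 are incongruent, so incongruent stimuli dominate the dataset.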

## 1.1 Visualizing the dataset

In [None]:
fig, ax = plt.subplots(6, 6, figsize=(14, 10))
i, j = 0, 0
j_max = 6

ds = stroop_ds # or stroop_val_ds
#ds = stroop_val_ds
Ns = np.random.permutation(len(ds))
for idx in Ns[:36]:
#for idx in range(30):
    img, y = ds[idx]
    _img = img.detach().cpu().numpy().transpose(1,2,0) # (C,H,W) -> (H,W,C) for imshow
    print(idx,y)
    ax[i,j].imshow(_img)
    ax[i,j].set_xticks([])
    ax[i,j].set_yticks([])
    j += 1
    if j == j_max:
        i+=1; j=0

#plt.show()

## 1.2 Preparing the training and validation data loaders

In [None]:
train_ds = stroop_ds
val_ds = stroop_val_ds

# Collate function for batching: returns a list of images and a list of label dicts
def _collate_fn(batch):
    inps, tgts = list(zip(*batch))
    inps = list(inps)
    tgts = list(tgts)
    return inps, tgts

# Data loader for the training dataset
train_dl = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=_collate_fn
)

# Data loader for the validation dataset
val_dl = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=_collate_fn
)

# Attach a name to each data loader for later use (train_model reads it to tell the phases apart)
train_dl.name = 'train'
val_dl.name = 'val'
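
What `_collate_fn` does to a mini-batch can be seen on toy data. In this minimal sketch, strings stand in for image tensors and the function name `collate_as_lists` is hypothetical; the logic mirrors the `zip(*batch)` pattern used above.

```python
def collate_as_lists(batch):
    """Turn a list of (input, target) pairs into a list of inputs and a list of targets."""
    inps, tgts = zip(*batch)
    return list(inps), list(tgts)

# Hypothetical mini-batch: strings stand in for image tensors
batch = [
    ('img0', {'color': 1, 'word': 3}),
    ('img1', {'color': 0, 'word': 5}),
    ('img2', {'color': 4, 'word': 4}),
]
inps, tgts = collate_as_lists(batch)

# The training loop can then pick whichever label the task needs
color_labels = [t['color'] for t in tgts]
print(inps, color_labels)
```

Keeping each batch as plain lists leaves the choice between the `'color'` and `'word'` labels to the training loop, at the cost of stacking the images there.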

## 1.3 Loading a publicly available pretrained image recognition (ImageNet) model and replacing its final layer

In [None]:
import torchvision
cnn_col = torchvision.models.mobilenet_v3_small(weights="DEFAULT")
#cnn_col = torchvision.models.resnet50(weights="DEFAULT")
cnn_col.eval()

In [None]:
# Define the models and download the pretrained weights
import copy
import torchvision
from torchvision import models

#cnn_col = models.resnet50(weights="DEFAULT")
#cnn_wrd = models.resnet50(weights="DEFAULT")
#cnn_col = models.resnet18(weights="DEFAULT")
#cnn_wrd = models.resnet18(weights="DEFAULT")
cnn_col = models.efficientnet_v2_s(weights="DEFAULT").to(device)
cnn_wrd = models.efficientnet_v2_s(weights="DEFAULT").to(device)
# cnn_col = torchvision.models.mobilenet_v3_small(weights="DEFAULT")
# cnn_wrd = torchvision.models.mobilenet_v3_small(weights="DEFAULT")

# parameters_col = {name:param for name, param in cnn_col.named_parameters()}
# modules_col = {name:param for name, param in cnn_col.named_modules()}
# parameters_wrd = {name:param for name, param in cnn_wrd.named_parameters()}
# modules_wrd = {name:param for name, param in cnn_wrd.named_modules()}

# Replace the final layer of each CNN
#cnn_col.fc = torch.nn.Linear(in_features=2048, out_features=len(stroop_ds.colors)) # resnet50
#cnn_wrd.fc = torch.nn.Linear(in_features=2048, out_features=len(stroop_ds.words))
# cnn_col.classifier[-1] = torch.nn.Linear(in_features=1024, out_features=len(stroop_ds.colors)) # mobilenet v3 small
# cnn_wrd.classifier[-1] = torch.nn.Linear(in_features=1024, out_features=len(stroop_ds.words))
cnn_col.classifier[-1] = torch.nn.Linear(in_features=1280, out_features=len(stroop_ds.colors)).to(device) # efficientnet v2
cnn_wrd.classifier[-1] = torch.nn.Linear(in_features=1280, out_features=len(stroop_ds.words)).to(device)

# Collect the parameters to train during transfer learning in `params_to_update_*`
params_to_update_wrd = []
params_to_update_col = []

# Names of the parameters to train
update_param_names_wrd = ["classifier.1.weight", "classifier.1.bias"]
update_param_names_col = ["classifier.1.weight", "classifier.1.bias"]
# update_param_names_wrd = ["classifier.3.weight", "classifier.3.bias"]
# update_param_names_col = ["classifier.3.weight", "classifier.3.bias"]

# Freeze every other parameter: no gradients are computed for them, so they do not change
# cnn_col
for name, param in cnn_col.named_parameters():
    if name in update_param_names_col:
        param.requires_grad = True
        params_to_update_col.append(param)
        print(name)
    else:
        param.requires_grad = False

# A second loop over cnn_col.state_dict() is not needed: its tensors are detached from
# autograd, so adding them to params_to_update_col would give the optimizer tensors
# that never receive gradients. The loop over named_parameters() above is sufficient.


# cnn_wrd
for name, param in cnn_wrd.named_parameters():
    if name in update_param_names_wrd:
        param.requires_grad = True
        #params_to_update_wrd.append((name, param))
        params_to_update_wrd.append(param)
        print(name)
    else:
        param.requires_grad = False

# As with cnn_col, a second loop over cnn_wrd.state_dict() is not needed:
# its tensors are detached from autograd and would never receive gradients.

# Inspect the contents of params_to_update
#print(params_to_update_col)
#print(params_to_update_wrd)
# for param in params_to_update_wrd:
#     print(param[0], type(param[1]))
#print(f'id(cnn_wrd):{id(cnn_wrd)}, id(cnn_col):{id(cnn_col)}')

# Sanity check: list the parameters that remain trainable in each model
for (name1, param1), (name2, param2) in zip(cnn_col.named_parameters(), cnn_wrd.named_parameters()):
    if param1.requires_grad or param2.requires_grad:
        print(name1, name2)
        #print(name1, param1.requires_grad, name2, param2.requires_grad)

print(cnn_col.classifier)
params_to_update_col
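
The freezing pattern used in this cell can be checked on a small network. The sketch below is an illustration under stated assumptions, not the notebook's models: a three-layer `torch.nn.Sequential` stands in for the pretrained CNN, and the names in `update_names` are the parameter names of that toy network.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for a pretrained network: a "backbone" layer plus a final classifier
model = torch.nn.Sequential(
    torch.nn.Linear(8, 4),   # parameters '0.weight', '0.bias' (to be frozen)
    torch.nn.ReLU(),
    torch.nn.Linear(4, 5),   # parameters '2.weight', '2.bias' (to be trained)
)
update_names = ['2.weight', '2.bias']

params_to_update = []
for name, param in model.named_parameters():
    param.requires_grad = name in update_names
    if param.requires_grad:
        params_to_update.append(param)

optimizer = torch.optim.Adam(params_to_update, lr=0.001)
frozen_before = model[0].weight.detach().clone()

# One optimization step on random data
x, y = torch.randn(16, 8), torch.randint(0, 5, (16,))
loss = torch.nn.functional.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# The frozen layer received no gradient and is unchanged after the step
print(model[0].weight.grad is None, torch.equal(model[0].weight, frozen_before))
```

Only the parameters listed in `update_names` are handed to the optimizer, so the frozen layer is protected twice: it has no gradient and it is not in the update list.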

## 1.4 Defining the training function

In [None]:
criterion = torch.nn.CrossEntropyLoss()
col_optimizer = torch.optim.Adam(params=params_to_update_col, lr=0.001)

# Train a model and record the loss and accuracy of every epoch
def train_model(
    model:torch.nn.Module=cnn_col,
    target:str='color',  # ['color', 'word']
    train_dl:DataLoader=train_dl,
    val_dl:DataLoader=val_dl,
    criterion:torch.nn.Module=criterion,
    optimizer:torch.optim.Optimizer=col_optimizer,
    epochs:int=5,
    losses:dict=None,
    accs:dict=None):

    if losses is None:
        losses = {'train':[], 'val':[]}
    if accs is None:
        accs = {'train':[],'val':[]}

    for epoch in range(epochs):

        print(f'Epoch {epoch+1:02d}/{epochs:02d}', end=" ")
        for phase in [train_dl, val_dl]:
            if phase.name == 'train':
                model.train()  # put the model in training mode
            else:
                model.eval()   # put the model in evaluation mode

            epoch_loss = 0.     # summed loss over the epoch
            epoch_corrects = 0  # number of correct predictions over the epoch

            # Skip training when epoch == 0 to measure validation performance before any learning
            if (epoch == 0) and (phase.name == 'train'):
                continue

            # Loop over the mini-batches from the data loader
            for inputs, labels in phase:
                inputs = torch.stack(inputs).float().to(device)
                labels = torch.LongTensor([label[target] for label in labels]).to(device)

                optimizer.zero_grad() # reset the gradients
                # Forward pass; gradients are tracked only while training
                with torch.set_grad_enabled(phase.name =='train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)  # compute the loss
                    _, preds = torch.max(outputs, 1)   # predicted labels

                    # Backpropagate and update the weights while training
                    if phase.name == 'train':
                        loss.backward()
                        optimizer.step()

                    epoch_loss += loss.item() * inputs.size(0)
                    # Update the running count of correct predictions
                    epoch_corrects += torch.sum(preds == labels.data)

            # Report the mean loss and the accuracy for the epoch
            N = len(phase.dataset)
            epoch_loss = epoch_loss / N
            epoch_acc = epoch_corrects.double() / N

            losses[phase.name].append(epoch_loss)
            accs[phase.name].append(epoch_acc.detach().cpu().numpy())  #[0])

            print(f'{phase.name:5s}',
                  f'loss: {epoch_loss:.4f}',
                  f'accuracy: {epoch_acc:.3f}', end=" ")
        print()
    return losses, accs
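
The training function accumulates `loss.item() * inputs.size(0)` and divides by the dataset size at the end of the epoch. The arithmetic below illustrates why the batch size enters the sum; the per-batch losses and sizes are hypothetical numbers chosen for the example.

```python
# Hypothetical per-batch results: (mean loss of the batch, number of samples in it)
batches = [(0.90, 64), (0.70, 64), (0.20, 12)]

n_total = sum(n for _, n in batches)

# Sample-weighted mean, as in train_model: sum(loss * batch_size) / N
weighted = sum(loss * n for loss, n in batches) / n_total

# The plain mean of the batch means gives the small final batch too much weight
naive = sum(loss for loss, _ in batches) / len(batches)

print(f"weighted = {weighted:.4f}, naive = {naive:.4f}")
```

The two values agree only when every batch has the same size, which is rarely true for the last batch of an epoch.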

## 1.5 Defining the evaluation function for the validation dataset

In [None]:
def _eval(
    model:torch.nn.Module=cnn_col,
    target:str='color',  # ['color', 'word']
    ds:Dataset=val_ds,
    isDisplay:bool=False):

    model.eval()
    n_corrects = 0
    Outs = []   # one record per misclassified stimulus
    for idx in tqdm(range(len(ds))):
        img, label = ds[idx]
        with torch.no_grad():   # no gradients are needed for evaluation
            out = model(img.unsqueeze(0))
        _, pred = torch.max(out, 1)   # predicted label
        _pred = pred.detach().cpu().numpy()[0]
        _tch = label[target]
        tch_col, tch_wrd = label['color'], label['word']
        isOK = _pred == _tch
        if isOK:
            n_corrects += 1
        else:
            if target == 'color':
                _pred_ = stroop_ds.colors[_pred]
            else:
                _pred_ = stroop_ds.words[_pred]
            Outs.append({'idx':idx, 'correct':isOK, 'output':_pred_, 'color':stroop_ds.colors[tch_col], 'word':stroop_ds.words[tch_wrd]})
            if isDisplay:
                plt.figure(figsize=(3,3))
                plt.title(f'idx:{idx}, predicted:{_pred_}, color:{stroop_ds.colors[tch_col]}, word:{stroop_ds.words[tch_wrd]}')
                plt.imshow(img.detach().cpu().numpy().transpose(1,2,0))
                plt.show()

    print(f'Accuracy: {(n_corrects / len(ds)) * 100:.3f} %')
    #print(stroop_ds.words)
    return Outs
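
The error records returned by `_eval` hold the model output together with the color and the word of each misclassified stimulus. One possible way to read them in Stroop terms is to count intrusions, that is, color-naming errors in which the output names the color spelled by the printed word. The sketch below assumes hypothetical records with the keys `'output'`, `'color'`, and `'word'`; the helper `is_intrusion` is not part of the notebook.

```python
colors = ['black', 'red', 'green', 'blue', 'yellow']
words = ['黒', '赤', '緑', '青', '黄', '■']

def is_intrusion(record):
    """True if a color-naming error names the color that the printed word spells."""
    word = record['word']
    if word == '■':          # the neutral patch spells no color name
        return False
    return colors.index(record['output']) == words.index(word)

# Hypothetical error records from a color-naming evaluation
errors = [
    {'output': 'red',   'color': 'blue',   'word': '赤'},  # names the word: intrusion
    {'output': 'green', 'color': 'blue',   'word': '赤'},  # unrelated error
    {'output': 'black', 'color': 'yellow', 'word': '■'},   # neutral stimulus
]
n_intrusions = sum(is_intrusion(r) for r in errors)
print(f"{n_intrusions} of {len(errors)} errors are word intrusions")
```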

# 2 Running the training

In [None]:
#%%time
print('# Color-naming task')
_optimizer = torch.optim.Adam(params=params_to_update_col, lr=0.001)
#_optimizer = torch.optim.Adam(params=[cnn_col.classifier[1].bias, cnn_col.classifier[1].weight], lr=0.001)
#col_optimizer = torch.optim.Adam(params=params_to_update_col, lr=0.001)
losses_col = {'train':[], 'val':[]}
accs_col = {'train':[],'val':[]}
losses_col, accs_col = train_model(
    model=cnn_col,
    #model=cnn_col,
    target='color',
    optimizer=_optimizer,
    losses=losses_col,
    accs=accs_col,
    epochs=10)

In [None]:
_eval(target='color', model=cnn_col, isDisplay=True)

In [None]:
%%time
print('# Word-reading task')
_optimizer = torch.optim.Adam(params=params_to_update_wrd, lr=0.0001)
#_optimizer = torch.optim.Adam(params=params_to_update_wrd, lr=0.001)
losses_wrd = {'train':[], 'val':[]}
accs_wrd = {'train':[],'val':[]}
# Store the results under *_wrd so that the color-naming results in *_col are not overwritten
losses_wrd, accs_wrd = train_model(
    model=cnn_wrd,
    target='word',
    optimizer=_optimizer,
    losses=losses_wrd,
    accs=accs_wrd,
    epochs=30)

In [None]:
_eval(target='word', model=cnn_wrd, isDisplay=True)

In [None]:
_eval(target='color', model=cnn_col, isDisplay=True)

In [None]:
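print(losses_col)  # loss history for the color-naming task
#print(losses_wrd)  # loss history for the word-reading task
# Training is skipped in the first epoch, so the training curve starts one epoch later
plt.plot(range(1, len(losses_col['train'])+1), losses_col['train'], c='red', label='training data')
plt.plot(losses_col['val'],c='blue', label='validation data')
plt.title('Color-naming task: loss over epochs')
plt.legend()
plt.show()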

In [None]:
plt.title('Word-reading task: loss over epochs')
plt.plot(losses_wrd['train'], label='training data', c='red')
plt.plot(losses_wrd['val'], label='validation data', c='blue')
plt.legend()
plt.show()

In [None]:
outputs_col = _eval(target='color', model=cnn_col, isDisplay=True)
outputs_wrd = _eval(target='word', model=cnn_wrd, isDisplay=True)
print(outputs_wrd)
print(stroop_ds.words)

## 2.4 Saving the models

In [None]:
# Save the full state_dict and reload it into model2.
# cnn_col and cnn_wrd are EfficientNetV2-S models, so model2 must use the same
# architecture (the file names keep their original 'resnet50' label)
_fname = '2023_1114stroop_col_resnet50_full.pt'
torch.save(cnn_col.state_dict(), _fname)
model2 = models.efficientnet_v2_s(weights="DEFAULT")
model2.classifier[-1] = torch.nn.Linear(in_features=1280, out_features=len(stroop_ds.colors))
model2.load_state_dict(torch.load(_fname, map_location=device))
model2 = model2.to(device).eval()
_ = _eval(target='color', model=model2)

_fname = '2023_1114stroop_wrd_resnet50_full.pt'
torch.save(cnn_wrd.state_dict(), _fname)
model2 = models.efficientnet_v2_s(weights="DEFAULT")
model2.classifier[-1] = torch.nn.Linear(in_features=1280, out_features=len(stroop_ds.words))
model2.load_state_dict(torch.load(_fname, map_location=device))
model2 = model2.to(device).eval()
_ = _eval(target='word', model=model2)
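
Saving and reloading a `state_dict` only works when the receiving model has the same layers and shapes as the saved one. The following sketch shows the round trip on a small hypothetical network (`make_model` is a stand-in, not the fine-tuned CNN) and uses an in-memory buffer in place of a file.

```python
import io
import torch

def make_model(n_out):
    """Hypothetical small network standing in for the fine-tuned CNN."""
    return torch.nn.Sequential(
        torch.nn.Linear(8, 4), torch.nn.ReLU(), torch.nn.Linear(4, n_out))

src = make_model(n_out=5)

# Save the state_dict to an in-memory buffer (a file path works the same way)
buffer = io.BytesIO()
torch.save(src.state_dict(), buffer)
buffer.seek(0)

# The receiving model must have the same layers and shapes as the saved one
dst = make_model(n_out=5)
dst.load_state_dict(torch.load(buffer, map_location='cpu'))

x = torch.randn(3, 8)
same = torch.equal(src(x), dst(x))

# A model with a different output size is rejected with a RuntimeError
buffer.seek(0)
try:
    make_model(n_out=6).load_state_dict(torch.load(buffer, map_location='cpu'))
    mismatch_detected = False
except RuntimeError:
    mismatch_detected = True

print(same, mismatch_detected)
```

The same check applies to the models in this notebook: a `state_dict` saved from one architecture cannot be loaded into a different one.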


In [None]:
for (name0, param0), (name1, param1) in zip(cnn_col.named_parameters(), cnn_wrd.named_parameters()):
    if 'classifier' not in name0:
        # The frozen backbone weights should be identical in the two models
        print(name0, torch.equal(param0.data, param1.data))
    else:
        # The replaced output layers differ (5 colors vs. 6 words)
        print(name0, param0, param1)

In [None]:
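# Save only the final layer and reload it into model2.
# Note: BatchNorm running statistics are buffers, not parameters. They are updated in
# train() mode even when the weights are frozen, so a fresh pretrained backbone with
# only the final layer restored may not reproduce cnn_wrd exactly.
# _fname = '2023_1114stroop_col.pt'
# torch.save(cnn_col.classifier[-1].state_dict(), _fname)
# model2 = models.efficientnet_v2_s(weights="DEFAULT")
# model2.classifier[-1] = torch.nn.Linear(in_features=1280, out_features=len(stroop_ds.colors))
# model2.classifier[-1].load_state_dict(torch.load(_fname))
# model2.eval()
# _ = _eval(target='color', model=model2)

_fname = '2023_1114stroop_wrd.pt'
torch.save(cnn_wrd.classifier[-1].state_dict(), _fname)
model2 = models.efficientnet_v2_s(weights="DEFAULT")
model2.classifier[-1] = torch.nn.Linear(in_features=1280, out_features=len(stroop_ds.words))
model2.classifier[-1].load_state_dict(torch.load(_fname, map_location=device))
model2 = model2.to(device).eval()
_ = _eval(target='word', model=model2)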

In [None]:
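# Assumes _fname still points to the final-layer file saved in the previous cell
_dic = torch.load(_fname, map_location=device)
print(_dic)
print(cnn_wrd.classifier[-1].state_dict())
print(model2.classifier[-1].state_dict())
model2.eval()
_ = _eval(target='word', model=model2)
cnn_wrd.eval()
_ = _eval(target='word', model=cnn_wrd)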


In [None]:
# model3 must have the same architecture as cnn_col (EfficientNetV2-S)
model3 = models.efficientnet_v2_s(weights="DEFAULT")
model3.classifier[-1] = torch.nn.Linear(in_features=1280, out_features=len(stroop_ds.colors))

hoge_fname = '2023_1114hoge.pt'
torch.save(cnn_col.state_dict(), hoge_fname)
model3.load_state_dict(torch.load(hoge_fname, map_location=device))
model3 = model3.to(device).eval()

# torch.save(cnn_col.classifier[-1].state_dict(), hoge_fname)
# model3.classifier[-1].load_state_dict(torch.load(hoge_fname))
_ = _eval(target='color', model=cnn_col, isDisplay=False)
_ = _eval(target='color', model=model3, isDisplay=False)


In [None]:
def save_checkpoint(checkpoint_path, model): # , optimizer):
    state = {'state_dict': model.state_dict(),
             #'optimizer' : optimizer.state_dict()
            }
    torch.save(state, checkpoint_path)
    #print('model saved to %s' % checkpoint_path)

def load_checkpoint(checkpoint_path, model): # , optimizer):
    state = torch.load(checkpoint_path)
    model.load_state_dict(state['state_dict'])
    #optimizer.load_state_dict(state['optimizer'])
    #print('model loaded from %s' % checkpoint_path)



# model3 must have the same architecture as cnn_col (EfficientNetV2-S) for the checkpoint to load
model3 = models.efficientnet_v2_s(weights="DEFAULT")
model3.classifier[-1] = torch.nn.Linear(in_features=1280, out_features=len(stroop_ds.colors))

save_checkpoint('2023_1113stroop_col.pt', cnn_col)
#model3.load_state_dict(torch.load(hoge_fname))
load_checkpoint('2023_1113stroop_col.pt', model3)
model3 = model3.to(device)
_ = _eval(target='color', model=cnn_col, isDisplay=False)
_ = _eval(target='color', model=model3, isDisplay=False)


In [None]:
import copy

class stroop_model(torch.nn.Module): # vision.models.resnet): # .Resnet):
    def __init__(self,
                 stroop_ds:torch.utils.data.dataset=stroop_ds)-> None:
        super().__init__()

        # get the pretrained ResNet18 network
        self.cnn = models.resnet18(weights="DEFAULT")

        # Freeze the pretrained backbone
        for name, param in self.cnn.named_parameters():
            param.requires_grad = False

        self.condition = 'color' # ['color', 'word']
        self.cond_layer = torch.nn.Embedding(num_embeddings=2, embedding_dim=2)
        # set_condition() assigns self.cond_vec itself and returns None
        self.set_condition(cond='color')

        self.col_layer = torch.nn.Linear(in_features=512, out_features=len(stroop_ds.colors))
        self.wrd_layer = torch.nn.Linear(in_features=512, out_features=len(stroop_ds.words))
        self.out_layer = torch.nn.Linear(
            in_features=len(stroop_ds.colors)+len(stroop_ds.words)+2, # the final 2 is the condition vector
            out_features=len(stroop_ds.colors))

    def set_condition(self, cond:str='color')-> None:
        if cond == 'color':
            self.cond_vec = self.cond_layer(torch.LongTensor([0]))
        elif cond == 'word':
            self.cond_vec = self.cond_layer(torch.LongTensor([1]))

    def forward(self,
                x:torch.Tensor) -> torch.Tensor:

        size = x.size(0)
        cond_vecs = self.cond_vec.repeat(size,1)

        x = self.cnn.conv1(x)
        x = self.cnn.bn1(x)
        x = self.cnn.relu(x)
        x = self.cnn.maxpool(x)

        x = self.cnn.layer1(x)
        x = self.cnn.layer2(x)
        x = self.cnn.layer3(x)
        x = self.cnn.layer4(x)

        x = self.cnn.avgpool(x)
        x = torch.flatten(x, 1)

        _col = self.col_layer(x)
        _wrd = self.wrd_layer(x)

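        # As written, the output is the 1000-way ImageNet head of the backbone;
        # the commented line is the alternative that combines both heads with the condition vector
        #x = self.out_layer(torch.cat((_col, _wrd, cond_vecs),dim=1))
        x = self.cnn.fc(x)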

        return x
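
The class above defines a color head, a word head, and a condition embedding, with a combining layer left commented out in `forward`. The sketch below shows one way these pieces could fit together. It is an assumption about the intended design, not the notebook's final model: `TinyStroopNet` is a hypothetical name, and a small linear layer replaces the ResNet18 backbone so that the example runs without downloading weights.

```python
import torch

class TinyStroopNet(torch.nn.Module):
    """Hypothetical sketch: shared features, two heads, and a task-condition embedding."""
    def __init__(self, n_features=16, n_colors=5, n_words=6):
        super().__init__()
        # Small stand-in for the frozen CNN backbone
        self.backbone = torch.nn.Sequential(
            torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, n_features), torch.nn.ReLU())
        self.cond_layer = torch.nn.Embedding(num_embeddings=2, embedding_dim=2)
        self.col_layer = torch.nn.Linear(n_features, n_colors)
        self.wrd_layer = torch.nn.Linear(n_features, n_words)
        # The final 2 input features are the condition vector
        self.out_layer = torch.nn.Linear(n_colors + n_words + 2, n_colors)
        self.cond_idx = 0  # 0: color naming, 1: word reading

    def set_condition(self, cond='color'):
        self.cond_idx = {'color': 0, 'word': 1}[cond]

    def forward(self, x):
        feats = self.backbone(x)
        # Look up the condition vector on the same device as the input
        idx = torch.full((x.size(0),), self.cond_idx, dtype=torch.long, device=x.device)
        cond_vecs = self.cond_layer(idx)
        combined = torch.cat(
            (self.col_layer(feats), self.wrd_layer(feats), cond_vecs), dim=1)
        return self.out_layer(combined)

net = TinyStroopNet()
x = torch.randn(4, 3, 8, 8)
net.set_condition('color')
out_color = net(x)
net.set_condition('word')
out_word = net(x)
print(out_color.shape, out_word.shape)
```

Storing the condition as an index and looking up the embedding inside `forward` keeps the condition vector on the input's device and lets gradients reach `cond_layer` during training.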