<a href="https://colab.research.google.com/github/komazawa-deep-learning/komazawa-deep-learning.github.io/blob/master/2024notebooks/2024_1122Karapetian_AlexNet_transfer_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Karapetian+ (2023), Empirically Identifying and Computationally Modeling the Brain–Behavior Relationship for Human Scene Categorization, Journal of Cognitive Neuroscience 35:11, pp. 1879–1897, doi:10.1162/jocn_a_02043

データは OSF (https://osf.io/4fdky/) から入手し，駒澤大学の共有 Google ドライブに置いている（Colab で実行する場合はこのドライブをマウントして読み込む）。

In [None]:
import torch
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'device:{device}')

import os
import numpy as np
import sys
import zipfile
import glob

import IPython
isColab = 'google.colab' in str(IPython.get_ipython())

if isColab:
    from google.colab import drive
    drive.mount('/content/drive')

    basedir = '/content/drive/Shareddrives/#2024認知心理学研究(1)b/浅川先生/2023Karapetian+OSF/Stimuli'
    fnames = list(sorted(glob.glob(os.path.join(basedir,'*.jpg'))))
else:
    HOME = os.environ['HOME']
    basedir = os.path.join(HOME, 'study/2024Agnessa14_Perceptual-decision-making.git/Stimuli')
    fnames = list(sorted(glob.glob(os.path.join(basedir,'*.jpg'))))

import matplotlib.pyplot as plt
import PIL

try:
    import japanize_matplotlib
except ImportError:
    !pip install japanize_matplotlib
    import japanize_matplotlib

# 刺激画像の表示

In [None]:
# Show every stimulus image in a 6 x 10 grid, each titled with its 1-based file number.
nrows, ncols = 6, 10
fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12,9))

# ax.flat visits the panels row by row, so panel k displays '<k+1>.jpg'.
for k, panel in enumerate(ax.flat):
    stim_no = k + 1
    stim_path = os.path.join(basedir, str(stim_no) + '.jpg')
    panel.imshow(PIL.Image.open(stim_path).convert('RGB'))
    panel.axis('off')
    panel.set_title(f'{stim_no}')

#  1-10: apartment
# 11-20: bed
# 21-30: highway
# 31-40: coast
# 41-50: canyon
# 51-60: forest
# Images 1-30 are man-made scenes; images 31-60 are natural scenes.

In [None]:
# データセットの作成
import torchvision
from torchvision import transforms
from sklearn.model_selection import train_test_split
#import torch
#import numpy as np
import random

# Fix random seeds for reproducibility.
def fix_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also makes cuDNN deterministic.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Use deterministic cuDNN kernels. Disabling benchmark stops cuDNN from
    # choosing algorithms by timing, which can vary from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed = 42
fix_seed(seed)

# Seed NumPy / random inside DataLoader worker subprocesses.
def worker_init_fn(worker_id):
    """DataLoader ``worker_init_fn`` that seeds NumPy and ``random`` per worker.

    Inside a worker, ``torch.initial_seed()`` is already distinct per worker and
    per epoch, and it is derived from the main-process torch seed. Reusing it
    keeps NumPy/``random`` reproducible without every worker (or every epoch)
    replaying the same stream. NumPy seeds must be < 2**32, hence the modulo.

    Args:
        worker_id: index supplied by the DataLoader. It is unused because the
            torch seed already differs per worker.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

class Agressa2023_dataset(torch.utils.data.Dataset):
    """Karapetian+ (2023) scene stimuli as an image-classification dataset.

    Item ``idx`` is the image file ``<basedir>/<idx+1>.jpg`` (files are numbered
    from 1), paired with an integer label that depends on ``task``:

    * ``'cat'``  -- categorization task, man-made vs natural scene:
      0 for images 1-30, 1 for images 31-60.
    * ``'desc'`` -- discrimination task, six scene categories:
      0 apartment (1-10), 1 bed (11-20), 2 highway (21-30),
      3 coast (31-40), 4 canyon (41-50), 5 forest (51-60).

    Any ``task`` other than ``'desc'`` is treated as ``'cat'``. The label
    arithmetic assumes the 60 stimuli numbered 1-60.
    """

    def __init__(self,
                 task:str='cat',
                 basedir=None,
                 ):
        """
        Args:
            task: ``'cat'`` or ``'desc'`` (see the class docstring).
            basedir: directory containing the numbered JPEG files. When None,
                the Colab shared-drive path is used on Colab, otherwise the
                local clone under the user's home directory.
        """
        super().__init__()
        self.task = 'desc' if task == 'desc' else 'cat'

        if basedir is None:
            if isColab:
                basedir = '/content/drive/Shareddrives/#2024認知心理学研究(1)b/浅川先生/2023Karapetian+OSF/Stimuli'
            else:
                basedir = os.path.join(os.path.expanduser('~'), 'study/2024Agnessa14_Perceptual-decision-making.git/Stimuli')
        self.basedir = basedir
        # Glob this instance's own directory instead of relying on module-level globals.
        self.fnames = sorted(glob.glob(os.path.join(self.basedir, '*.jpg')))
        # Alias kept for existing users; __len__ counts these files.
        self.fname = self.fnames

        # ToTensor -> Resize(224) (an int size matches the shorter edge, so the
        # aspect ratio is kept) -> Normalize with the ImageNet channel statistics
        # expected by torchvision's pretrained weights.
        self.Img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(224),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])

    def __len__(self):
        """Number of ``*.jpg`` files found in ``basedir``."""
        return len(self.fname)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` for the stimulus numbered ``idx + 1``."""
        # Open the file by its number rather than via self.fname: that list is
        # sorted lexicographically ('1.jpg', '10.jpg', ...) and does not follow idx.
        with PIL.Image.open(os.path.join(self.basedir, str(idx+1)+'.jpg')) as im:
            img = im.convert('RGB')
        X = self.Img_transform(img)

        if self.task == 'cat':
            label = idx // 30   # 0: man-made (1-30), 1: natural (31-60)
        else:
            label = idx // 10   # 0..5: one label per block of 10 images

        return X, label

# Dataset for the 6-way scene-category ('desc') task.
ds = Agressa2023_dataset(task='desc')

train_dl = torch.utils.data.DataLoader(ds,
                                       batch_size=12,    # batch size
                                       shuffle=True,     # reshuffle every epoch
                                       num_workers=0,    # load in the main process (worker_init_fn is then not called)
                                       pin_memory=(device == 'cuda'),  # pinned host memory only helps CUDA transfers
                                       worker_init_fn=worker_init_fn
                                      )

In [None]:
from torchvision import models

# AlexNet with torchvision's default pretrained weights. eval() switches
# train/eval-dependent layers (e.g. dropout) to inference behaviour.
a_model = models.alexnet(weights=models.AlexNet_Weights.DEFAULT, progress=True)
a_model.eval()

In [None]:
# Lookup tables: parameter name -> Parameter, and module name -> submodule.
a_parameters = dict(a_model.named_parameters())
a_modules = dict(a_model.named_modules())

print(f'パラメータ名:{a_parameters.keys()}')
print(f'モジュール名:{a_modules.keys()}')

In [None]:
# Replace the last classifier layer (index 6) with a freshly initialised head.
# Its width matches the labels the dataset emits: 6 scene categories for
# task 'desc', 2 (man-made / natural) for task 'cat'.
n_classes = 6 if ds.task == 'desc' else 2
a_model.classifier[6] = torch.nn.Linear(in_features=a_model.classifier[6].in_features,
                                        out_features=n_classes)
a_model.eval()

In [None]:
# Fine-tune only the new head. Freeze every parameter except classifier.6's
# weight and bias, collect the trainable ones for the optimiser, and print their names.
update_param_names = ['classifier.6.weight', 'classifier.6.bias']
params_to_update = []
for pname, p in a_model.named_parameters():
    trainable = pname in update_param_names
    p.requires_grad = trainable
    if not trainable:
        continue
    params_to_update.append(p)
    print(pname)


In [None]:
# Optimiser setup.
lr = 0.01
# Move the model to the selected device *before* constructing the optimiser,
# then re-collect the trainable (unfrozen) parameters from the moved model.
a_model.to(device)
params_to_update = [p for p in a_model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params=params_to_update, lr=lr)

# Loss: cross-entropy over the head's logits.
criterion = torch.nn.CrossEntropyLoss()

n_epochs = 5
a_model.train()
for epoch in range(n_epochs):
    epoch_loss = 0.  # sum of the per-batch mean losses for this epoch
    for X, y in train_dl:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        out = a_model(X)
        # Tensors on cuda/mps must be copied to the CPU before .numpy().
        print(f'エポック:{epoch+1}',
              f'教師:{y.cpu().numpy()}',
              f'出力:{out.argmax(dim=1).cpu().numpy()}')
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f'epoch_loss:{epoch_loss}')

