In [1]:
import glob
import random
import pickle

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import yaml
import torchmetrics
from transformers import get_cosine_schedule_with_warmup

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# seedの固定
def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

SEED = 0
fix_seed(SEED)

In [3]:
#YAMLファイルを読み込む
with open('../config/config.yaml', 'r') as f:
    config = yaml.safe_load(f)


In [4]:
class MyDataset(Dataset):

    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):

        # 画像を読みこんで、指定の方法でtransform
        img_path = self.file_list[index]
        img = Image.open(img_path)
        img_transformed = self.transform(img)

        # pathに含まれる文字を使用してラベリングを実施
        if 'ants' in img_path:
            label = 0
        else:
            label = 1

        return img_transformed, label

In [5]:
class CreateDataModule(pl.LightningDataModule):

    def __init__(self, train_path, val_path, test_path, img_size=224,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
                 batch_size=config['batchsize']):
        super().__init__()
        self.train_path = train_path
        self.val_path = val_path
        self.test_path = test_path
        self.batch_size = batch_size

        # train時、val/test時の前処理をそれぞれ定義
        self.train_transforms = transforms.Compose([
            transforms.RandomResizedCrop(img_size, scale=(0.5, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        self.val_test_transforms = transforms.Compose([
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

    # データのダウンロードなどを行う場合は定義、今回は不要
    def prepare_data(self):
        pass

    # Trainer.fit()ではtrain/valのDatasetを、Trainer.test()ではtestのDatasetを生成
    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            self.train_dataset = MyDataset(self.train_path, self.train_transforms)
            self.val_dataset = MyDataset(self.val_path, self.val_test_transforms)

        if stage == 'test' or stage is None:
            self.test_dataset = MyDataset(self.test_path, self.val_test_transforms)

    # こちらもTrainer.fit()ではtrain/valのDataLoaderを、Trainer.test()ではtestのDataLoaderを生成
    # trainはshuffleあり、val/testはshuffleなし
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)


In [6]:
# seedを固定
fix_seed(SEED)

# valフォルダはtestとして使用
test_path = [path for path in glob.glob("../data/input/hymenoptera_data/val/*/*.jpg")]

# trainフォルダの画像を7:3でtrain:validに分割
modeling_path = [path for path in glob.glob("../data/input/hymenoptera_data/train/*/*.jpg")]
train_path, val_path = train_test_split(modeling_path, train_size=0.7)

# インスタンスを作成
data_module = CreateDataModule(train_path,val_path,test_path)

In [7]:

from typing import Any


class mymodel(pl.LightningModule):
    def __init__(self,model_name: str,pretrained: bool,
                 hidden_dim: int, out_dim: int,ratio=0.5
    ):
        super(mymodel,self).__init__()
        self.backbone=timm.create_model(model_name,pretrained=pretrained,num_classes=0)
        self.in_features=self.backbone.num_features
        self.head=nn.Sequential(
            nn.Linear(self.in_features,hidden_dim),
            nn.ReLU(),
            nn.Dropout(ratio),
            nn.Linear(hidden_dim,out_dim)
        )
        #accuracy測定用
        self.train_top1_acc = torchmetrics.Accuracy(top_k=1,task='multiclass', num_classes=7)
        self.train_top3_acc = torchmetrics.Accuracy(top_k=3,task='multiclass', num_classes=7)
        self.valid_top1_acc = torchmetrics.Accuracy(top_k=1,task='multiclass', num_classes=7)
        self.valid_top3_acc = torchmetrics.Accuracy(top_k=3,task='multiclass', num_classes=7)
        self.test_top1_acc= torchmetrics.Accuracy(top_k=1,task='multiclass', num_classes=7)
        self.test_top3_acc= torchmetrics.Accuracy(top_k=3,task='multiclass', num_classes=7)
    
    def forward(self,x):
        h=self.backbone(x)
        y=self.head(h)
        return y
    
    def training_step(self,batch,batch_idx):
        x,y=batch
        pred=self(x)
        loss=nn.functional.cross_entropy(pred,y)
        top1_acc = self.train_top1_acc(pred, y)
        top3_acc = self.train_top3_acc(pred, y)
        self.log('train_loss', loss)
        self.log('train_top1_acc', top1_acc)
        self.log('train_top3_acc', top3_acc)
        return {'loss':loss, 'train_top1_accuracy':top1_acc, 'train_top3_accuracy':top3_acc}
    
    def configure_optimizers(self):
        optimizer=optim.Adam(self.parameters(),lr=config['lr'])
        scheduler=get_cosine_schedule_with_warmup(optimizer=optimizer,num_warmup_steps=20,num_training_steps=30)
        return [optimizer],[scheduler]
    
    def validation_step(self,batch,batch_idx):
        with torch.no_grad():
            x,y=batch
            pred = self(x)
            loss=nn.functional.cross_entropy(pred,y)
            top1_acc = self.valid_top1_acc(pred, y)
            top3_acc = self.valid_top3_acc(pred, y)
            self.log('valid_loss',loss)
            self.log('valid_top1_acc', top1_acc)
            self.log('valid_top3_acc', top3_acc)
        return {'loss':loss,'valid_top1_accuracy':top1_acc,'valid_top3_accuracy':top3_acc}
    
    def test_step(self,batch,batch_idx):
        x,y=batch
        pred=self(x)
        loss=nn.functional.cross_entropy(pred,y)
        top1_acc = self.test_top1_acc(pred, y)
        top3_acc = self.test_top3_acc(pred, y)
        self.log('test_loss', loss)
        self.log('test_top1_acc', top1_acc)
        self.log('test_top3_acc', top3_acc)
        return {'loss':loss, 'test_top1_accuracy':top1_acc, 'test_top3_accuracy':top3_acc}


    
model=mymodel(model_name='resnet18.a2_in1k',pretrained=True,hidden_dim=1000,out_dim=7)



In [8]:
for item in timm.list_models(pretrained=True):
    print(item)

bat_resnext26ts.ch_in1k
beit_base_patch16_224.in22k_ft_in22k
beit_base_patch16_224.in22k_ft_in22k_in1k
beit_base_patch16_384.in22k_ft_in22k_in1k
beit_large_patch16_224.in22k_ft_in22k
beit_large_patch16_224.in22k_ft_in22k_in1k
beit_large_patch16_384.in22k_ft_in22k_in1k
beit_large_patch16_512.in22k_ft_in22k_in1k
beitv2_base_patch16_224.in1k_ft_in1k
beitv2_base_patch16_224.in1k_ft_in22k
beitv2_base_patch16_224.in1k_ft_in22k_in1k
beitv2_large_patch16_224.in1k_ft_in1k
beitv2_large_patch16_224.in1k_ft_in22k
beitv2_large_patch16_224.in1k_ft_in22k_in1k
botnet26t_256.c1_in1k
caformer_b36.sail_in1k
caformer_b36.sail_in1k_384
caformer_b36.sail_in22k
caformer_b36.sail_in22k_ft_in1k
caformer_b36.sail_in22k_ft_in1k_384
caformer_m36.sail_in1k
caformer_m36.sail_in1k_384
caformer_m36.sail_in22k
caformer_m36.sail_in22k_ft_in1k
caformer_m36.sail_in22k_ft_in1k_384
caformer_s18.sail_in1k
caformer_s18.sail_in1k_384
caformer_s18.sail_in22k
caformer_s18.sail_in22k_ft_in1k
caformer_s18.sail_in22k_ft_in1k_384
c

In [9]:
# EarlyStoppingの設定
# 3epochで'val_loss'が0.05以上減少しなければ学習をストップ
early_stop_callback = EarlyStopping(
    monitor='valid_top1_acc', min_delta=0.05, patience=5, mode='max')

# モデルの保存先
# epoch数に応じて、「epoch=0.ckpt」のような形で保存
checkpoint_callback = ModelCheckpoint(
    dirpath=config['checkpoint_path'], monitor='valid_top1_acc', mode='max', verbose=True,save_last=True)

# trainerの設定
trainer = pl.Trainer(max_epochs=config['epoch'],
                     callbacks=[checkpoint_callback, early_stop_callback],
                     log_every_n_steps=10)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [10]:

# 訓練開始
trainer.fit(model, data_module)

/home/bdr/kaggle/kaggle/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:652: Checkpoint directory /home/bdr/kaggle/tutorial/checkpoint exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type               | Params | Mode 
--------------------------------------------------------------
0 | backbone       | ResNet             | 11.2 M | train
1 | head           | Sequential         | 520 K  | train
2 | train_top1_acc | MulticlassAccuracy | 0      | train
3 | train_top3_acc | MulticlassAccuracy | 0      | train
4 | valid_top1_acc | MulticlassAccuracy | 0      | train
5 | valid_top3_acc | MulticlassAccuracy | 0      | train
6 | test_top1_acc  | MulticlassAccuracy | 0      | train
7 | test_top3_acc  | MulticlassAccuracy | 0      | train
--------------------------------------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params
46.786    Total estimated model params size (MB)

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/bdr/kaggle/kaggle/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.


                                                                           

/home/bdr/kaggle/kaggle/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
/home/bdr/kaggle/kaggle/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████| 1/1 [00:01<00:00,  0.67it/s, v_num=8]

Epoch 0, global step 1: 'valid_top1_acc' reached 0.00000 (best 0.00000), saving model to '/home/bdr/kaggle/tutorial/checkpoint/epoch=0-step=1.ckpt' as top 1


Epoch 1: 100%|██████████| 1/1 [00:01<00:00,  0.85it/s, v_num=8]

Epoch 1, global step 2: 'valid_top1_acc' reached 0.01370 (best 0.01370), saving model to '/home/bdr/kaggle/tutorial/checkpoint/epoch=1-step=2.ckpt' as top 1


Epoch 2: 100%|██████████| 1/1 [00:01<00:00,  0.84it/s, v_num=8]

Epoch 2, global step 3: 'valid_top1_acc' reached 0.28767 (best 0.28767), saving model to '/home/bdr/kaggle/tutorial/checkpoint/epoch=2-step=3.ckpt' as top 1


Epoch 3: 100%|██████████| 1/1 [00:01<00:00,  0.88it/s, v_num=8]

Epoch 3, global step 4: 'valid_top1_acc' reached 0.52055 (best 0.52055), saving model to '/home/bdr/kaggle/tutorial/checkpoint/epoch=3-step=4.ckpt' as top 1


Epoch 4: 100%|██████████| 1/1 [00:01<00:00,  0.99it/s, v_num=8]

Epoch 4, global step 5: 'valid_top1_acc' reached 0.63014 (best 0.63014), saving model to '/home/bdr/kaggle/tutorial/checkpoint/epoch=4-step=5.ckpt' as top 1


Epoch 5: 100%|██████████| 1/1 [00:00<00:00,  1.00it/s, v_num=8]

Epoch 5, global step 6: 'valid_top1_acc' was not in top 1


Epoch 6: 100%|██████████| 1/1 [00:00<00:00,  1.02it/s, v_num=8]

Epoch 6, global step 7: 'valid_top1_acc' reached 0.64384 (best 0.64384), saving model to '/home/bdr/kaggle/tutorial/checkpoint/epoch=6-step=7.ckpt' as top 1


Epoch 7: 100%|██████████| 1/1 [00:01<00:00,  0.95it/s, v_num=8]

Epoch 7, global step 8: 'valid_top1_acc' reached 0.65753 (best 0.65753), saving model to '/home/bdr/kaggle/tutorial/checkpoint/epoch=7-step=8.ckpt' as top 1


Epoch 8: 100%|██████████| 1/1 [00:00<00:00,  1.01it/s, v_num=8]

Epoch 8, global step 9: 'valid_top1_acc' was not in top 1


Epoch 9: 100%|██████████| 1/1 [00:01<00:00,  0.84it/s, v_num=8]

Epoch 9, global step 10: 'valid_top1_acc' reached 0.73973 (best 0.73973), saving model to '/home/bdr/kaggle/tutorial/checkpoint/epoch=9-step=10.ckpt' as top 1


Epoch 10: 100%|██████████| 1/1 [00:01<00:00,  0.99it/s, v_num=8]

Epoch 10, global step 11: 'valid_top1_acc' reached 0.80822 (best 0.80822), saving model to '/home/bdr/kaggle/tutorial/checkpoint/epoch=10-step=11.ckpt' as top 1


Epoch 11: 100%|██████████| 1/1 [00:00<00:00,  1.01it/s, v_num=8]

Epoch 11, global step 12: 'valid_top1_acc' reached 0.84932 (best 0.84932), saving model to '/home/bdr/kaggle/tutorial/checkpoint/epoch=11-step=12.ckpt' as top 1


Epoch 12: 100%|██████████| 1/1 [00:00<00:00,  1.01it/s, v_num=8]

Epoch 12, global step 13: 'valid_top1_acc' reached 0.86301 (best 0.86301), saving model to '/home/bdr/kaggle/tutorial/checkpoint/epoch=12-step=13-v1.ckpt' as top 1


Epoch 13: 100%|██████████| 1/1 [00:01<00:00,  1.00it/s, v_num=8]

Epoch 13, global step 14: 'valid_top1_acc' was not in top 1


Epoch 14: 100%|██████████| 1/1 [00:01<00:00,  0.97it/s, v_num=8]

Epoch 14, global step 15: 'valid_top1_acc' was not in top 1


Epoch 15: 100%|██████████| 1/1 [00:01<00:00,  0.98it/s, v_num=8]

Epoch 15, global step 16: 'valid_top1_acc' was not in top 1


Epoch 16: 100%|██████████| 1/1 [00:01<00:00,  0.97it/s, v_num=8]

Epoch 16, global step 17: 'valid_top1_acc' was not in top 1


Epoch 17: 100%|██████████| 1/1 [00:01<00:00,  0.98it/s, v_num=8]

Epoch 17, global step 18: 'valid_top1_acc' was not in top 1


Epoch 17: 100%|██████████| 1/1 [00:01<00:00,  0.64it/s, v_num=8]


In [11]:
# tensorboardでの確認
'''
%load_ext tensorboard
%tensorboard --logdir /content/lightning_logs
'''

'\n%load_ext tensorboard\n%tensorboard --logdir /content/lightning_logs\n'

In [12]:
result=trainer.test(model,datamodule=data_module,ckpt_path=checkpoint_callback.best_model_path)
result

Restoring states from the checkpoint path at /home/bdr/kaggle/tutorial/checkpoint/epoch=12-step=13-v1.ckpt
/home/bdr/kaggle/kaggle/lib/python3.10/site-packages/lightning_fabric/utilities/cloud_io.py:57: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on G

Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 19.01it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.3695000112056732
      test_top1_acc         0.8888888955116272
      test_top3_acc                 1.0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.3695000112056732,
  'test_top1_acc': 0.8888888955116272,
  'test_top3_acc': 1.0}]

In [13]:
# 最良モデルの保存
best_model = model.load_from_checkpoint(checkpoint_callback.best_model_path)

with open(config['checkpoint_path'], mode='wb') as fp:
    pickle.dump(best_model, fp)

TypeError: The classmethod `mymodel.load_from_checkpoint` cannot be called on an instance. Please call it on the class type and make sure the return value is used.