# 環境

我們在這次競賽所需的套件都可以直接在 colab 上import，除了 timm。

timm 是一個 pre-trained 好的圖形分類模型，很適合用在本次競賽較小的資料集。

以下我們先下載最新版本的 timm。

這裡要直接從 github 下載，因為我們使用的模型需要比較新的版本才有，

因此如果直接用 pip sintall timm 就沒辦法使用之後會用到的模型。

In [1]:
# import需要用到的套件
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from tqdm.notebook import tqdm

In [2]:
!pip install git+https://github.com/rwightman/pytorch-image-models.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/rwightman/pytorch-image-models.git
  Cloning https://github.com/rwightman/pytorch-image-models.git to /tmp/pip-req-build-_rywint4
  Running command git clone -q https://github.com/rwightman/pytorch-image-models.git /tmp/pip-req-build-_rywint4
Building wheels for collected packages: timm
  Building wheel for timm (setup.py) ... [?25l[?25hdone
  Created wheel for timm: filename=timm-0.6.2.dev0-py3-none-any.whl size=498066 sha256=640089da97c20501c2319397c799b8c7ec6093805d34302fc6f8a8ff6ee601b8
  Stored in directory: /tmp/pip-ephem-wheel-cache-prtops1i/wheels/a0/ec/5f/289118b747739bb1e02e36cf3d7e759721e881c183653719dc
Successfully built timm
Installing collected packages: timm
Successfully installed timm-0.6.2.dev0


# 準備資料

In [3]:
# 從雲端下載 training.zip
import gdown

url = "https://drive.google.com/u/4/uc?id=1KT_mJEdYtOXF79gdwgQsjmZQfzQS3ApU&export=download"
output = "training.zip"
gdown.download(url, output)

Downloading...
From: https://drive.google.com/u/4/uc?id=1KT_mJEdYtOXF79gdwgQsjmZQfzQS3ApU&export=download
To: /content/training.zip
100%|██████████| 90.7M/90.7M [00:02<00:00, 39.6MB/s]


'training.zip'

In [4]:
# 解壓縮 training.zip 到 data 資料夾
!unzip "training.zip" -d "data"

Archive:  training.zip
  inflating: data/02a91mzn84.jpg     
  inflating: data/02nehv1tf6.jpg     
  inflating: data/02sx9ijfd6.jpg     
  inflating: data/032p7z15ol.jpg     
  inflating: data/03jezhu9i8.jpg     
  inflating: data/03ol9zqjn7.jpg     
  inflating: data/04d6xbohpg.jpg     
  inflating: data/04kxsd2rf9.jpg     
  inflating: data/05tdemvcqx.jpg     
  inflating: data/06uasyi3nt.jpg     
  inflating: data/07ufd3njrv.jpg     
  inflating: data/07uqc9hdnt.jpg     
  inflating: data/084dzytmfe.jpg     
  inflating: data/09de32ybos.jpg     
  inflating: data/0a1h7votc5.jpg     
  inflating: data/0a7yscrh49.jpg     
  inflating: data/0clzd3aqyg.jpg     
  inflating: data/0cwbio2hfn.jpg     
  inflating: data/0d9ucgfmy6.jpg     
  inflating: data/0dhgmwk8ri.jpg     
  inflating: data/0dwec2hx45.jpg     
  inflating: data/0e1u74idof.jpg     
  inflating: data/0eac5fkrg6.jpg     
  inflating: data/0evyz4ogbj.jpg     
  inflating: data/0f2wec1id9.jpg     
  inflating: data/0gawub3e7

In [5]:
# 將訓練資料分成 train/test
# 其中 test size 設為 0.2

from pathlib import Path
import pandas as pd
import cv2
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split

def split_dataset(img_dir: str = "data",
                  csv_dir: str = "label.csv", 
                  save_dir: str = "orchid_dataset", 
                  test_size: float = 0.2,
                  random_state: int = None
                  ) -> None:

    # 建立 train/test 資料夾
    Path.mkdir(Path(save_dir))
    Path.mkdir(Path(save_dir) / "train")
    Path.mkdir(Path(save_dir) / "test")

    # 將資料分成 orchid_dataset/train 和 orchid_dataset/test
    df = pd.read_csv(Path(img_dir) / csv_dir)
    train, test = train_test_split(
        df, test_size=test_size, stratify=df["category"], shuffle=True, random_state=random_state
    )
    train_save_path = Path(save_dir) / "train"
    test_save_path = Path(save_dir) / "test"

    for df_, save_dir_ in zip([train, test], [train_save_path, test_save_path]):
        label_list = df_["category"].unique()

        for label in tqdm(label_list):
            label_dir = Path(save_dir_) / str(label)
            if not Path.exists(label_dir):
                Path.mkdir(label_dir)

            sub_df = df_[df_["category"] == label]
            for f in sub_df["filename"]:
                img_path = str(Path(img_dir) / f)
                save_dir = str(Path(save_dir_) / str(label) / f)
                img = cv2.imread(img_path)
                cv2.imwrite(save_dir, img)

    train.to_csv(str(train_save_path) + ".csv", index=False)
    test.to_csv(str(test_save_path) + ".csv", index=False)

# 使用固定的random_state是為了在固定樣本之下去改進模型
# 也可以用random seed，但由於test的樣本數很少，隨機性影響很大
random_state = None
test_size = 0.2
split_dataset(test_size=test_size, random_state=random_state)

  0%|          | 0/219 [00:00<?, ?it/s]

  0%|          | 0/219 [00:00<?, ?it/s]

# 資料處理 (Data Augmentation)

本次競賽中

In [6]:
"""
3Augment implementation
Data-augmentation (DA) based on dino DA (https://github.com/facebookresearch/dino)
and timm DA(https://github.com/rwightman/pytorch-image-models)
"""

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

import random
from torchvision import transforms
from PIL import ImageFilter, ImageOps

class Solarization(object):
    """
    Apply Solarization to the PIL image.
    """

    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            return ImageOps.solarize(img)
        else:
            return img


class GaussianBlur(object):
    """
    Apply Gaussian Blur to the PIL image.
    """

    def __init__(self, p=0.1, radius_min=0.1, radius_max=2.0):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        do_it = random.random() <= self.prob
        if not do_it:
            return img

        img = img.filter(
            ImageFilter.GaussianBlur(
                radius=random.uniform(self.radius_min, self.radius_max)
            )
        )
        return img


class gray_scale(object):
    """
    Apply Solarization to the PIL image.
    """

    def __init__(self, p=0.2):
        self.p = p
        self.transf = transforms.Grayscale(3)

    def __call__(self, img):
        if random.random() < self.p:
            return self.transf(img)
        else:
            return img

# Help Function

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, f1_score

@torch.no_grad()
def evaluate(model: nn.Module, dataloader: DataLoader, device: str):

    model.eval()
    criterion = nn.CrossEntropyLoss()

    total_loss, total_acc = 0, 0
    outputs, labels = [], []

    for _, (data, label) in enumerate(dataloader):
        data, label = data.to(device), label.to(device)
        output = model(data)
        loss = criterion(output, label)
        y_true = label.cpu().numpy()
        y_pred = torch.argmax(output, 1).cpu().numpy()
        acc = accuracy_score(y_true, y_pred)
        total_loss += loss
        total_acc += acc
        outputs.append(y_pred)
        labels.append(y_true)

    avg_loss = total_loss / len(dataloader)
    avg_acc = total_acc / len(dataloader)

    outputs = np.concatenate(outputs)
    labels = np.concatenate(labels)
    macro_f1 = f1_score(outputs, labels, average="macro")

    final_score = 0.5 * (avg_acc + macro_f1)

    return avg_loss, avg_acc, macro_f1, final_score

# 訓練模型

## Swin-v2 192

In [None]:
import time
from torchvision.datasets import ImageFolder 
from torch.utils.data import DataLoader
from torchvision import transforms as T
import timm
from timm.loss import LabelSmoothingCrossEntropy
from timm.scheduler import CosineLRScheduler
from timm.utils.clip_grad import dispatch_clip_grad
from timm.models.swin_transformer_v2 import swinv2_base_window12_192_22k

# data preprocessing configuration
TRAIN_DIR = "/content/orchid_dataset/train"
TEST_DIR = "/content/orchid_dataset/test"
IMAGE_SIZE = 192
TEST_RESIZE = int((256 / 224) * IMAGE_SIZE)

# training configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHS = 200
BATCH_SIZE = 32
MODEL_NAME = "swinv2_base_window12_192_22k"

# optimizer configuration
LEARNING_RATE = 3e-3

# lr scheduler configuration
T_INITIAL = 10
WARMUP_T = 5
WARMUP_LR_INIT = 1e-5
K_DECAY = 0.75
LR_MIN = 1e-5

# data augmentation
three_augment = T.Compose(
    [T.RandomChoice([gray_scale(p=0.5), Solarization(p=0.5), GaussianBlur(p=0.5)])]
)
color_jitter = 0.2
normalize = T.Normalize(mean=[0.4909, 0.4216, 0.3703], std=[0.2459, 0.2420, 0.2489])
transform = {
    "train": T.Compose(
        [
            T.Resize(IMAGE_SIZE, interpolation=T.InterpolationMode.BICUBIC),
            T.RandomCrop(IMAGE_SIZE, padding=4, padding_mode="reflect"),
            T.RandomHorizontalFlip(),
            three_augment,
            T.ColorJitter(color_jitter, color_jitter, color_jitter),
            T.ToTensor(),
            normalize,
        ]
    ),
    "val": T.Compose(
        [
            T.Resize(IMAGE_SIZE, interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(IMAGE_SIZE),
            T.ToTensor(),
            normalize,
        ]
    ),
}



def main():

    train_dataset = ImageFolder(root=TRAIN_DIR, transform=transform["train"])
    val_dataset = ImageFolder(root=TEST_DIR, transform=transform["val"])
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)

    num_classes = len(train_dataset.classes)
    model = swinv2_base_window12_192_22k(
        pretrained=True, 
        num_classes=num_classes, 
        drop_rate=0.1, 
        attn_drop_rate=0.1, 
        drop_path_rate=0.1)
    
    # only train MSA(multi-head self-attention) and head
    for name_p, p in model.named_parameters():
        if ".attn." in name_p:
            p.requires_grad = True
        else:
            p.requires_grad = False
    
    model.head.weight.requires_grad = True
    model.head.bias.requires_grad = True
    
    try:
        for p in model.patch_embed.parameters():
                p.requires_grad = False
    except:
        print('no patch embed')

    print(
        f"number of params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
    )

    model.to(DEVICE)

    criterion = LabelSmoothingCrossEntropy(0.1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.02)
    scheduler = CosineLRScheduler(
        optimizer,
        t_initial=T_INITIAL,
        warmup_t=WARMUP_T,
        warmup_lr_init=WARMUP_LR_INIT,
        k_decay=K_DECAY,
        lr_min=LR_MIN,
    )

    for epoch in range(EPOCHS):
        start_time = time.time()
        for i, (data, label) in tqdm(enumerate(train_loader), total=len(train_loader)):
            model.train()
            # forward pass
            data, label = data.to(DEVICE), label.to(DEVICE)
            output = model(data)
            loss = criterion(output, label)

            # backward pass
            optimizer.zero_grad()
            loss.backward()

            # clip gradient
            dispatch_clip_grad(model.parameters(), 5.0)

            # gradient decent or adam step
            optimizer.step()

        
        # update scheduler
        scheduler.step_update(epoch)

        if (epoch+1) % 2 == 0:
            
            # Print current metric
            train_loss, train_acc, train_macro_f1, train_final_scroe = evaluate(model, train_loader, DEVICE)
            val_loss, val_acc, val_macro_f1, val_final_scroe= evaluate(model, val_loader, DEVICE)

            print(
                f"【Epoch={epoch+1}】 train:【loss={train_loss:.3f}, acc={100*train_acc:.2f}%, f1={train_final_scroe:.3f}, final={train_final_scroe:.3f}】 \
                val: 【loss={val_loss:.3f}, acc={100*val_acc:.2f}%, f1={val_macro_f1:.3f}, final={val_final_scroe:.3f}】 {(time.time() - start_time):.2f}/s"
            )

            torch.save(
                {
                    "epoch": epoch,
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                },
                f=f"{MODEL_NAME}.pt",
            )

            

if __name__ == "__main__":
    main()

number of params: 28408659


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=2】 train:【loss=5.196, acc=2.50%, f1=0.021, final=0.021】                 val: 【loss=5.170, acc=2.66%, f1=0.016, final=0.021】 86.22/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=4】 train:【loss=4.470, acc=12.18%, f1=0.108, final=0.108】                 val: 【loss=4.470, acc=12.80%, f1=0.089, final=0.108】 87.68/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=6】 train:【loss=3.452, acc=28.92%, f1=0.270, final=0.270】                 val: 【loss=3.508, acc=28.41%, f1=0.231, final=0.257】 87.92/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=8】 train:【loss=2.538, acc=48.24%, f1=0.472, final=0.472】                 val: 【loss=2.688, acc=39.96%, f1=0.349, final=0.374】 86.57/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=10】 train:【loss=1.900, acc=62.56%, f1=0.617, final=0.617】                 val: 【loss=2.089, acc=54.53%, f1=0.491, final=0.518】 90.09/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=12】 train:【loss=1.450, acc=71.95%, f1=0.714, final=0.714】                 val: 【loss=1.693, acc=62.80%, f1=0.581, final=0.604】 87.67/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=14】 train:【loss=1.092, acc=79.64%, f1=0.793, final=0.793】                 val: 【loss=1.366, acc=67.67%, f1=0.638, final=0.657】 86.58/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=16】 train:【loss=0.908, acc=82.86%, f1=0.827, final=0.827】                 val: 【loss=1.189, acc=73.92%, f1=0.707, final=0.723】 86.44/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=18】 train:【loss=0.773, acc=85.95%, f1=0.859, final=0.859】                 val: 【loss=1.072, acc=76.06%, f1=0.735, final=0.748】 86.62/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=20】 train:【loss=0.633, acc=90.08%, f1=0.899, final=0.899】                 val: 【loss=0.964, acc=79.30%, f1=0.771, final=0.782】 87.05/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=22】 train:【loss=0.608, acc=90.17%, f1=0.901, final=0.901】                 val: 【loss=0.965, acc=78.08%, f1=0.761, final=0.771】 86.41/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=24】 train:【loss=0.555, acc=92.50%, f1=0.925, final=0.925】                 val: 【loss=0.932, acc=79.85%, f1=0.779, final=0.789】 87.18/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=26】 train:【loss=0.445, acc=93.71%, f1=0.936, final=0.936】                 val: 【loss=0.836, acc=82.16%, f1=0.798, final=0.810】 86.24/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=28】 train:【loss=0.448, acc=93.54%, f1=0.935, final=0.935】                 val: 【loss=0.845, acc=83.06%, f1=0.809, final=0.820】 86.26/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=30】 train:【loss=0.410, acc=95.13%, f1=0.952, final=0.952】                 val: 【loss=0.836, acc=83.20%, f1=0.812, final=0.822】 86.30/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=32】 train:【loss=0.376, acc=96.52%, f1=0.965, final=0.965】                 val: 【loss=0.812, acc=84.98%, f1=0.835, final=0.843】 86.29/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=34】 train:【loss=0.342, acc=96.93%, f1=0.969, final=0.969】                 val: 【loss=0.796, acc=84.54%, f1=0.832, final=0.839】 86.80/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=36】 train:【loss=0.314, acc=97.42%, f1=0.974, final=0.974】                 val: 【loss=0.788, acc=85.09%, f1=0.836, final=0.843】 86.26/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=38】 train:【loss=0.279, acc=98.07%, f1=0.981, final=0.981】                 val: 【loss=0.757, acc=87.54%, f1=0.863, final=0.869】 87.10/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=40】 train:【loss=0.243, acc=98.92%, f1=0.989, final=0.989】                 val: 【loss=0.712, acc=88.96%, f1=0.878, final=0.884】 86.34/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=42】 train:【loss=0.251, acc=98.45%, f1=0.984, final=0.984】                 val: 【loss=0.745, acc=86.75%, f1=0.856, final=0.862】 87.16/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=44】 train:【loss=0.217, acc=98.75%, f1=0.987, final=0.987】                 val: 【loss=0.702, acc=88.66%, f1=0.878, final=0.882】 86.28/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=46】 train:【loss=0.238, acc=98.28%, f1=0.983, final=0.983】                 val: 【loss=0.753, acc=87.54%, f1=0.866, final=0.871】 86.38/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=48】 train:【loss=0.205, acc=98.64%, f1=0.986, final=0.986】                 val: 【loss=0.700, acc=88.43%, f1=0.879, final=0.881】 86.58/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=50】 train:【loss=0.203, acc=98.67%, f1=0.987, final=0.987】                 val: 【loss=0.689, acc=90.20%, f1=0.892, final=0.897】 86.50/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=52】 train:【loss=0.184, acc=98.98%, f1=0.989, final=0.989】                 val: 【loss=0.672, acc=90.30%, f1=0.893, final=0.898】 86.45/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=54】 train:【loss=0.168, acc=99.22%, f1=0.992, final=0.992】                 val: 【loss=0.679, acc=89.43%, f1=0.885, final=0.890】 86.35/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=56】 train:【loss=0.175, acc=99.32%, f1=0.993, final=0.993】                 val: 【loss=0.675, acc=90.85%, f1=0.900, final=0.904】 86.39/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=58】 train:【loss=0.162, acc=99.41%, f1=0.994, final=0.994】                 val: 【loss=0.664, acc=89.65%, f1=0.887, final=0.892】 86.38/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=60】 train:【loss=0.170, acc=99.09%, f1=0.991, final=0.991】                 val: 【loss=0.673, acc=89.65%, f1=0.889, final=0.893】 87.10/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=62】 train:【loss=0.168, acc=99.49%, f1=0.995, final=0.995】                 val: 【loss=0.692, acc=89.98%, f1=0.889, final=0.895】 86.27/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=64】 train:【loss=0.158, acc=99.47%, f1=0.995, final=0.995】                 val: 【loss=0.686, acc=90.56%, f1=0.902, final=0.904】 86.96/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=66】 train:【loss=0.148, acc=99.43%, f1=0.994, final=0.994】                 val: 【loss=0.654, acc=91.31%, f1=0.902, final=0.908】 86.24/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=68】 train:【loss=0.135, acc=99.94%, f1=0.999, final=0.999】                 val: 【loss=0.644, acc=90.77%, f1=0.903, final=0.906】 87.03/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=70】 train:【loss=0.156, acc=99.60%, f1=0.996, final=0.996】                 val: 【loss=0.670, acc=91.64%, f1=0.909, final=0.913】 86.18/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=72】 train:【loss=0.147, acc=99.77%, f1=0.998, final=0.998】                 val: 【loss=0.662, acc=91.07%, f1=0.901, final=0.906】 86.33/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=74】 train:【loss=0.138, acc=99.77%, f1=0.998, final=0.998】                 val: 【loss=0.657, acc=90.65%, f1=0.897, final=0.902】 86.44/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=76】 train:【loss=0.146, acc=99.66%, f1=0.997, final=0.997】                 val: 【loss=0.676, acc=92.19%, f1=0.916, final=0.919】 86.31/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=78】 train:【loss=0.130, acc=99.83%, f1=0.998, final=0.998】                 val: 【loss=0.648, acc=90.85%, f1=0.898, final=0.903】 86.34/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=80】 train:【loss=0.139, acc=99.83%, f1=0.998, final=0.998】                 val: 【loss=0.679, acc=90.77%, f1=0.902, final=0.905】 86.12/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=82】 train:【loss=0.135, acc=99.77%, f1=0.998, final=0.998】                 val: 【loss=0.663, acc=91.19%, f1=0.905, final=0.908】 85.98/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=84】 train:【loss=0.128, acc=99.70%, f1=0.997, final=0.997】                 val: 【loss=0.649, acc=91.42%, f1=0.909, final=0.912】 85.95/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=86】 train:【loss=0.132, acc=99.72%, f1=0.997, final=0.997】                 val: 【loss=0.669, acc=90.44%, f1=0.898, final=0.901】 85.98/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=88】 train:【loss=0.123, acc=99.77%, f1=0.998, final=0.998】                 val: 【loss=0.639, acc=91.64%, f1=0.910, final=0.913】 85.96/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=90】 train:【loss=0.138, acc=99.89%, f1=0.999, final=0.999】                 val: 【loss=0.690, acc=90.10%, f1=0.897, final=0.899】 86.99/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=92】 train:【loss=0.126, acc=99.83%, f1=0.998, final=0.998】                 val: 【loss=0.665, acc=91.11%, f1=0.908, final=0.910】 86.57/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=94】 train:【loss=0.131, acc=99.72%, f1=0.997, final=0.997】                 val: 【loss=0.661, acc=91.31%, f1=0.910, final=0.911】 86.75/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=96】 train:【loss=0.135, acc=99.77%, f1=0.998, final=0.998】                 val: 【loss=0.675, acc=90.52%, f1=0.898, final=0.902】 86.05/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=98】 train:【loss=0.126, acc=99.72%, f1=0.997, final=0.997】                 val: 【loss=0.661, acc=91.21%, f1=0.908, final=0.910】 86.84/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=100】 train:【loss=0.130, acc=99.89%, f1=0.999, final=0.999】                 val: 【loss=0.684, acc=90.52%, f1=0.897, final=0.901】 86.21/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=102】 train:【loss=0.120, acc=99.77%, f1=0.998, final=0.998】                 val: 【loss=0.649, acc=91.07%, f1=0.904, final=0.907】 86.92/s


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=104】 train:【loss=0.120, acc=99.89%, f1=0.999, final=0.999】                 val: 【loss=0.653, acc=90.87%, f1=0.903, final=0.906】 86.09/s


  0%|          | 0/55 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

## Swin-v2 384

In [None]:
import time
from torchvision.datasets import ImageFolder 
from torch.utils.data import DataLoader
from torchvision import transforms as T
import timm
from timm.loss import LabelSmoothingCrossEntropy
from timm.scheduler import CosineLRScheduler
from timm.utils.clip_grad import dispatch_clip_grad
from timm.models.swin_transformer_v2 import swinv2_base_window12to24_192to384_22kft1k

# data preprocessing configuration
TRAIN_DIR = "/content/orchid_dataset/train"
TEST_DIR = "/content/orchid_dataset/test"
IMAGE_SIZE = 384
TEST_RESIZE = int((256 / 224) * IMAGE_SIZE)

# training
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHS = 100
BATCH_SIZE = 2

MODEL_NAME = "swinv2_base_window12to24_192to384_22kft1k"

# optimizer
LEARNING_RATE = 2e-3

# lr scheduler
T_INITIAL = 10
WARMUP_T = 5
WARMUP_LR_INIT = 1e-5
K_DECAY = 0.75
LR_MIN = 1e-5

# data augmentation
normalize = T.Normalize(mean=[0.4909, 0.4216, 0.3703], std=[0.2459, 0.2420, 0.2489])
transform = {
    "train": T.Compose(
        [
            T.Resize(IMAGE_SIZE, interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(IMAGE_SIZE),
            T.ToTensor(),
            normalize,
        ]
    ),
    "val": T.Compose(
        [
            T.Resize(IMAGE_SIZE, interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(IMAGE_SIZE),
            T.ToTensor(),
            normalize,
        ]
    ),
}



def main():

    train_dataset = ImageFolder(root=TRAIN_DIR, transform=transform["train"])
    val_dataset = ImageFolder(root=TEST_DIR, transform=transform["val"])
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)

    num_classes = len(train_dataset.classes)
    model = swinv2_base_window12to24_192to384_22kft1k( 
                              pretrained=False, 
                              num_classes=num_classes,
                              img_size=IMAGE_SIZE,
                              drop_rate=0.1, 
                              attn_drop_rate=0.1, 
                              drop_path_rate=0.1)
    
    state_dict = model.state_dict()
    checkpoint = torch.load("swinv2_base_window12_192_22k.pt", map_location='cpu')
    checkpoint_model = checkpoint['model']

    pre_trained_layers = {}
    for k, v in checkpoint_model.items():
        if checkpoint_model[k].shape == state_dict[k].shape:
            pre_trained_layers[k] = v 

    model.load_state_dict(pre_trained_layers, strict=False)

    for name_p, p in model.named_parameters():
        if ".attn." in name_p:
            p.requires_grad = True
        else:
            p.requires_grad = False
    
    model.head.weight.requires_grad = True
    model.head.bias.requires_grad = True

    try:
        model.pos_embed.requires_grad = True
    except:
        print('no position encoding')
    
    try:
        for p in model.patch_embed.parameters():
                p.requires_grad = False
    except:
        print('no patch embed')

    print(
        f"number of params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
    )

    model.to(DEVICE)

    criterion = LabelSmoothingCrossEntropy(0.1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.05)
    scheduler = CosineLRScheduler(
        optimizer,
        t_initial=T_INITIAL,
        warmup_t=WARMUP_T,
        warmup_lr_init=WARMUP_LR_INIT,
        k_decay=K_DECAY,
        lr_min=LR_MIN,
    )

    for epoch in range(EPOCHS):
        start_time = time.time()
        for i, (data, label) in tqdm(enumerate(train_loader), total=len(train_loader)):
            model.train()
            # forward pass
            data, label = data.to(DEVICE), label.to(DEVICE)
            output = model(data)
            loss = criterion(output, label)

            # backward pass
            optimizer.zero_grad()
            loss.backward()

            # clip gradient
            dispatch_clip_grad(model.parameters(), 5.0)

            # gradient decent or adam step
            optimizer.step()

        
        # update scheduler
        scheduler.step_update(epoch)

        train_loss, train_acc, train_macro_f1, train_final_scroe = evaluate(model, train_loader, DEVICE)
        val_loss, val_acc, val_macro_f1, val_final_scroe= evaluate(model, val_loader, DEVICE)

        print(
            f"【Epoch={epoch+1}】 train:【loss={train_loss:.3f}, acc={100*train_acc:.2f}%, f1={train_final_scroe:.3f}, final={train_final_scroe:.3f}】 \
            val: 【loss={val_loss:.3f}, acc={100*val_acc:.2f}%, f1={val_macro_f1:.3f}, final={val_final_scroe:.3f}】 {(time.time() - start_time):.2f}/s"
        )

        torch.save(
                {
                    "epoch": epoch,
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                },
                f=f"{MODEL_NAME}.pt",
            )


            

if __name__ == "__main__":
    main()