# 環境

我們在這次競賽所需的套件都可以直接在 colab 上import，除了 timm。

timm 是一個 pre-trained 好的圖形分類模型，很適合用在本次競賽較小的資料集。

以下我們先下載最新版本的 timm。

這裡要直接從 github 下載，因為我們使用的模型需要比較新的版本才有，

因此如果直接用 pip sintall timm 就沒辦法使用之後會用到的模型。

In [1]:
# import需要用到的套件
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from tqdm.notebook import tqdm

In [2]:
!pip install git+https://github.com/rwightman/pytorch-image-models.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/rwightman/pytorch-image-models.git
  Cloning https://github.com/rwightman/pytorch-image-models.git to /tmp/pip-req-build-55a13pmp
  Running command git clone -q https://github.com/rwightman/pytorch-image-models.git /tmp/pip-req-build-55a13pmp
Building wheels for collected packages: timm
  Building wheel for timm (setup.py) ... [?25l[?25hdone
  Created wheel for timm: filename=timm-0.6.2.dev0-py3-none-any.whl size=498066 sha256=8e29900f6a9cea6f65401574372c0fb7ae016857f40fa2e83b7d0ca8c5cda0f3
  Stored in directory: /tmp/pip-ephem-wheel-cache-9tr69q3z/wheels/a0/ec/5f/289118b747739bb1e02e36cf3d7e759721e881c183653719dc
Successfully built timm
Installing collected packages: timm
Successfully installed timm-0.6.2.dev0


# 準備資料

In [3]:
# 從雲端下載 training.zip
import gdown

url = "https://drive.google.com/u/4/uc?id=1KT_mJEdYtOXF79gdwgQsjmZQfzQS3ApU&export=download"
output = "training.zip"
gdown.download(url, output)

# 從雲端下載 public.zip
url = "https://drive.google.com/u/3/uc?id=18VYedKncZwsru5NgVFDTtHpRZAHf2-zE&export=download"
output = "orchid_public_set.zip"
gdown.download(url, output)


# 從雲端下載 private.zip
url = "https://drive.google.com/u/3/uc?id=1Qt5jcyZYnoykcwbkCjRpTHCJWf-JB1Vm&export=download"
output = "orchid_private_set.zip"
gdown.download(url, output)

Downloading...
From: https://drive.google.com/u/4/uc?id=1KT_mJEdYtOXF79gdwgQsjmZQfzQS3ApU&export=download
To: /content/training.zip
100%|██████████| 90.7M/90.7M [00:00<00:00, 103MB/s]
Downloading...
From: https://drive.google.com/u/3/uc?id=18VYedKncZwsru5NgVFDTtHpRZAHf2-zE&export=download
To: /content/orchid_public_set.zip
100%|██████████| 2.52G/2.52G [00:29<00:00, 86.4MB/s]
Downloading...
From: https://drive.google.com/u/3/uc?id=1Qt5jcyZYnoykcwbkCjRpTHCJWf-JB1Vm&export=download
To: /content/orchid_private_set.zip
100%|██████████| 2.13G/2.13G [00:28<00:00, 75.5MB/s]


'orchid_private_set.zip'

In [4]:
# 解壓縮 training.zip 到 data 資料夾
!unzip "training.zip" -d "data"

Archive:  training.zip
  inflating: data/02a91mzn84.jpg     
  inflating: data/02nehv1tf6.jpg     
  inflating: data/02sx9ijfd6.jpg     
  inflating: data/032p7z15ol.jpg     
  inflating: data/03jezhu9i8.jpg     
  inflating: data/03ol9zqjn7.jpg     
  inflating: data/04d6xbohpg.jpg     
  inflating: data/04kxsd2rf9.jpg     
  inflating: data/05tdemvcqx.jpg     
  inflating: data/06uasyi3nt.jpg     
  inflating: data/07ufd3njrv.jpg     
  inflating: data/07uqc9hdnt.jpg     
  inflating: data/084dzytmfe.jpg     
  inflating: data/09de32ybos.jpg     
  inflating: data/0a1h7votc5.jpg     
  inflating: data/0a7yscrh49.jpg     
  inflating: data/0clzd3aqyg.jpg     
  inflating: data/0cwbio2hfn.jpg     
  inflating: data/0d9ucgfmy6.jpg     
  inflating: data/0dhgmwk8ri.jpg     
  inflating: data/0dwec2hx45.jpg     
  inflating: data/0e1u74idof.jpg     
  inflating: data/0eac5fkrg6.jpg     
  inflating: data/0evyz4ogbj.jpg     
  inflating: data/0f2wec1id9.jpg     
  inflating: data/0gawub3e7

In [5]:
# 解壓縮 orchid_public_set.zip 和 orchid_private_set.zip
!unzip -P "sxRHRQmzmRw8TS!X4Kz23oRvg@" "orchid_public_set.zip" -d "test_dataset"

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: test_dataset/vhbo76sqlc.jpg  
  inflating: test_dataset/vhc1blas5x.jpg  
  inflating: test_dataset/vhcew02zb4.jpg  
  inflating: test_dataset/vhcyjrde2l.jpg  
  inflating: test_dataset/vhebkr1j5t.jpg  
  inflating: test_dataset/vhfxpu4eni.jpg  
  inflating: test_dataset/vhg1c7ea9x.jpg  
  inflating: test_dataset/vhgxed092j.jpg  
  inflating: test_dataset/vhi3fb84t7.jpg  
  inflating: test_dataset/vhiojxqz9w.jpg  
  inflating: test_dataset/vhjbyf7u9g.jpg  
  inflating: test_dataset/vhkar0z521.jpg  
  inflating: test_dataset/vhkxm25u80.jpg  
  inflating: test_dataset/vhnmdja4ki.jpg  
  inflating: test_dataset/vhnuemgz6d.jpg  
  inflating: test_dataset/vhrmxqzju7.jpg  
  inflating: test_dataset/vhsxur8ap1.jpg  
  inflating: test_dataset/vhtbly4g52.jpg  
  inflating: test_dataset/vhtkx2fr94.jpg  
  inflating: test_dataset/vhueatw0lx.jpg  
  inflating: test_dataset/vhxbzlwr4g.jpg  
  inflating: test_dataset/vhxk0q

In [6]:
!unzip -P "Y8vBt&e*AAZ5GREL3#gA9i9j3A" "orchid_private_set.zip" -d "test_dataset"

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: test_dataset/vmgbzn9sae.jpg  
  inflating: test_dataset/vmh9a2xlf8.jpg  
  inflating: test_dataset/vmhnkp5sgb.jpg  
  inflating: test_dataset/vmjlxuc98p.jpg  
  inflating: test_dataset/vmjx52t4an.jpg  
  inflating: test_dataset/vmjz1l8h7u.jpg  
  inflating: test_dataset/vmnxkp1zft.jpg  
  inflating: test_dataset/vmp5xjyiek.jpg  
  inflating: test_dataset/vmrl8t9zew.jpg  
  inflating: test_dataset/vmrtyzecq3.jpg  
  inflating: test_dataset/vmwbo9f0pi.jpg  
  inflating: test_dataset/vmwj7hndog.jpg  
  inflating: test_dataset/vmwx0s8gir.jpg  
  inflating: test_dataset/vmx3ezp45j.jpg  
  inflating: test_dataset/vmxfsnw379.jpg  
  inflating: test_dataset/vmxy6lf03d.jpg  
  inflating: test_dataset/vmzhw2cy05.jpg  
  inflating: test_dataset/vn0k92poy7.jpg  
  inflating: test_dataset/vn1k0jg4r9.jpg  
  inflating: test_dataset/vn1y480xei.jpg  
  inflating: test_dataset/vn2tcp731z.jpg  
  inflating: test_dataset/vn31hr

In [7]:
# 將訓練資料分成 train/test
# 其中 test size 設為 0.2

from pathlib import Path
import pandas as pd
import cv2
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split

def split_dataset(img_dir: str = "data",
                  csv_dir: str = "label.csv", 
                  save_dir: str = "orchid_dataset", 
                  test_size: float = 0.2,
                  random_state: int = None
                  ) -> None:

    # 建立 train/test 資料夾
    Path.mkdir(Path(save_dir))
    Path.mkdir(Path(save_dir) / "train")
    Path.mkdir(Path(save_dir) / "test")

    # 將資料分成 orchid_dataset/train 和 orchid_dataset/test
    df = pd.read_csv(Path(img_dir) / csv_dir)
    train, test = train_test_split(
        df, test_size=test_size, stratify=df["category"], shuffle=True, random_state=random_state
    )
    train_save_path = Path(save_dir) / "train"
    test_save_path = Path(save_dir) / "test"

    for df_, save_dir_ in zip([train, test], [train_save_path, test_save_path]):
        label_list = df_["category"].unique()

        for label in tqdm(label_list):
            label_dir = Path(save_dir_) / str(label)
            if not Path.exists(label_dir):
                Path.mkdir(label_dir)

            sub_df = df_[df_["category"] == label]
            for f in sub_df["filename"]:
                img_path = str(Path(img_dir) / f)
                save_dir = str(Path(save_dir_) / str(label) / f)
                img = cv2.imread(img_path)
                cv2.imwrite(save_dir, img)

    train.to_csv(str(train_save_path) + ".csv", index=False)
    test.to_csv(str(test_save_path) + ".csv", index=False)

# 使用固定的random_state是為了在固定樣本之下去改進模型
# 也可以用random seed，但由於test的樣本數很少，隨機性影響很大
random_state = None
test_size = 0.2
split_dataset(test_size=test_size, random_state=random_state)

  0%|          | 0/219 [00:00<?, ?it/s]

  0%|          | 0/219 [00:00<?, ?it/s]

# 資料處理 (Data Augmentation)

本次競賽中

In [8]:
"""
3Augment implementation
Data-augmentation (DA) based on dino DA (https://github.com/facebookresearch/dino)
and timm DA(https://github.com/rwightman/pytorch-image-models)
"""

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

import random
from torchvision import transforms
from PIL import ImageFilter, ImageOps

class Solarization(object):
    """
    Apply Solarization to the PIL image.
    """

    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            return ImageOps.solarize(img)
        else:
            return img


class GaussianBlur(object):
    """
    Apply Gaussian Blur to the PIL image.
    """

    def __init__(self, p=0.1, radius_min=0.1, radius_max=2.0):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        do_it = random.random() <= self.prob
        if not do_it:
            return img

        img = img.filter(
            ImageFilter.GaussianBlur(
                radius=random.uniform(self.radius_min, self.radius_max)
            )
        )
        return img


class gray_scale(object):
    """
    Apply Solarization to the PIL image.
    """

    def __init__(self, p=0.2):
        self.p = p
        self.transf = transforms.Grayscale(3)

    def __call__(self, img):
        if random.random() < self.p:
            return self.transf(img)
        else:
            return img

# Help Function

In [9]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, f1_score

@torch.no_grad()
def evaluate(model: nn.Module, dataloader: DataLoader, device: str):

    model.eval()
    criterion = nn.CrossEntropyLoss()

    total_loss, total_acc = 0, 0
    outputs, labels = [], []

    for _, (data, label) in enumerate(dataloader):
        data, label = data.to(device), label.to(device)
        output = model(data)
        loss = criterion(output, label)
        y_true = label.cpu().numpy()
        y_pred = torch.argmax(output, 1).cpu().numpy()
        acc = accuracy_score(y_true, y_pred)
        total_loss += loss
        total_acc += acc
        outputs.append(y_pred)
        labels.append(y_true)

    avg_loss = total_loss / len(dataloader)
    avg_acc = total_acc / len(dataloader)

    outputs = np.concatenate(outputs)
    labels = np.concatenate(labels)
    macro_f1 = f1_score(outputs, labels, average="macro")

    final_score = 0.5 * (avg_acc + macro_f1)

    return avg_loss, avg_acc, macro_f1, final_score



# 訓練模型

## Swin-v2 192

In [10]:
import time
from torchvision.datasets import ImageFolder 
from torch.utils.data import DataLoader
from torchvision import transforms as T
import timm
from timm.loss import LabelSmoothingCrossEntropy
from timm.scheduler import CosineLRScheduler
from timm.utils.clip_grad import dispatch_clip_grad
from timm.models.swin_transformer_v2 import swinv2_base_window12_192_22k

# data preprocessing configuration
TRAIN_DIR = "/content/orchid_dataset/train"
TEST_DIR = "/content/orchid_dataset/test"
IMAGE_SIZE = 192
TEST_RESIZE = int((256 / 224) * IMAGE_SIZE)

# training configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHS = 200
BATCH_SIZE = 32
MODEL_NAME = "swinv2_base_window12_192_22k"

# optimizer configuration
LEARNING_RATE = 3e-3

# lr scheduler configuration
T_INITIAL = 10
WARMUP_T = 5
WARMUP_LR_INIT = 1e-5
K_DECAY = 0.75
LR_MIN = 1e-5

# data augmentation
three_augment = T.Compose(
    [T.RandomChoice([gray_scale(p=0.5), Solarization(p=0.5), GaussianBlur(p=0.5)])]
)
color_jitter = 0.2
normalize = T.Normalize(mean=[0.4909, 0.4216, 0.3703], std=[0.2459, 0.2420, 0.2489])
transform = {
    "train": T.Compose(
        [
            T.Resize(IMAGE_SIZE, interpolation=T.InterpolationMode.BICUBIC),
            T.RandomCrop(IMAGE_SIZE, padding=4, padding_mode="reflect"),
            T.RandomHorizontalFlip(),
            three_augment,
            T.ColorJitter(color_jitter, color_jitter, color_jitter),
            T.ToTensor(),
            normalize,
        ]
    ),
    "val": T.Compose(
        [
            T.Resize(IMAGE_SIZE, interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(IMAGE_SIZE),
            T.ToTensor(),
            normalize,
        ]
    ),
}



def main():

    train_dataset = ImageFolder(root=TRAIN_DIR, transform=transform["train"])
    val_dataset = ImageFolder(root=TEST_DIR, transform=transform["val"])
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)

    num_classes = len(train_dataset.classes)
    model = swinv2_base_window12_192_22k(
        pretrained=True, 
        num_classes=num_classes, 
        drop_rate=0.1, 
        attn_drop_rate=0.1, 
        drop_path_rate=0.1)
    
    # only train MSA(multi-head self-attention) and head
    for name_p, p in model.named_parameters():
        if ".attn." in name_p:
            p.requires_grad = True
        else:
            p.requires_grad = False
    
    model.head.weight.requires_grad = True
    model.head.bias.requires_grad = True
    
    try:
        for p in model.patch_embed.parameters():
                p.requires_grad = False
    except:
        print('no patch embed')

    print(
        f"number of params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
    )

    model.to(DEVICE)

    criterion = LabelSmoothingCrossEntropy(0.1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.02)
    scheduler = CosineLRScheduler(
        optimizer,
        t_initial=T_INITIAL,
        warmup_t=WARMUP_T,
        warmup_lr_init=WARMUP_LR_INIT,
        k_decay=K_DECAY,
        lr_min=LR_MIN,
    )

    for epoch in range(EPOCHS):
        start_time = time.time()
        for i, (data, label) in tqdm(enumerate(train_loader), total=len(train_loader)):
            model.train()
            # forward pass
            data, label = data.to(DEVICE), label.to(DEVICE)
            output = model(data)
            loss = criterion(output, label)

            # backward pass
            optimizer.zero_grad()
            loss.backward()

            # clip gradient
            dispatch_clip_grad(model.parameters(), 5.0)

            # gradient decent or adam step
            optimizer.step()

        
        # update scheduler
        scheduler.step_update(epoch)

        train_loss, train_acc, train_macro_f1, train_final_scroe = evaluate(model, train_loader, DEVICE)
        val_loss, val_acc, val_macro_f1, val_final_scroe= evaluate(model, val_loader, DEVICE)

        print(
            f"【Epoch={epoch+1}】 train:【loss={train_loss:.3f}, acc={100*train_acc:.2f}%, f1={train_final_scroe:.3f}, final={train_final_scroe:.3f}】 \
            val: 【loss={val_loss:.3f}, acc={100*val_acc:.2f}%, f1={val_macro_f1:.3f}, final={val_final_scroe:.3f}】 {(time.time() - start_time):.2f}/s"
        )

        torch.save(
            {
                "epoch": epoch,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            },
            f=f"{MODEL_NAME}.pt",
        )

            

if __name__ == "__main__":
    main()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Downloading: "https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth" to /root/.cache/torch/hub/checkpoints/swinv2_base_patch4_window12_192_22k.pth


number of params: 28408659


  0%|          | 0/55 [00:00<?, ?it/s]

【Epoch=1】 train:【loss=5.374, acc=1.36%, f1=0.011, final=0.011】             val: 【loss=5.351, acc=1.34%, f1=0.008, final=0.011】 92.16/s


## Swin-v2 384

In [11]:
import time
from torchvision.datasets import ImageFolder 
from torch.utils.data import DataLoader
from torchvision import transforms as T
import timm
from timm.loss import LabelSmoothingCrossEntropy
from timm.scheduler import CosineLRScheduler
from timm.utils.clip_grad import dispatch_clip_grad
from timm.models.swin_transformer_v2 import swinv2_base_window12to24_192to384_22kft1k

# data preprocessing configuration
TRAIN_DIR = "/content/orchid_dataset/train"
TEST_DIR = "/content/orchid_dataset/test"
IMAGE_SIZE = 384
TEST_RESIZE = int((256 / 224) * IMAGE_SIZE)

# training
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHS = 100
BATCH_SIZE = 2

MODEL_NAME = "swinv2_base_window12to24_192to384_22kft1k"

# optimizer
LEARNING_RATE = 2e-3

# lr scheduler
T_INITIAL = 10
WARMUP_T = 5
WARMUP_LR_INIT = 1e-5
K_DECAY = 0.75
LR_MIN = 1e-5

# data augmentation
normalize = T.Normalize(mean=[0.4909, 0.4216, 0.3703], std=[0.2459, 0.2420, 0.2489])
transform = {
    "train": T.Compose(
        [
            T.Resize(IMAGE_SIZE, interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(IMAGE_SIZE),
            T.ToTensor(),
            normalize,
        ]
    ),
    "val": T.Compose(
        [
            T.Resize(IMAGE_SIZE, interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(IMAGE_SIZE),
            T.ToTensor(),
            normalize,
        ]
    ),
}



def main():

    train_dataset = ImageFolder(root=TRAIN_DIR, transform=transform["train"])
    val_dataset = ImageFolder(root=TEST_DIR, transform=transform["val"])
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)

    num_classes = len(train_dataset.classes)
    model = swinv2_base_window12to24_192to384_22kft1k( 
                              pretrained=False, 
                              num_classes=num_classes,
                              img_size=IMAGE_SIZE,
                              drop_rate=0.1, 
                              attn_drop_rate=0.1, 
                              drop_path_rate=0.1)
    
    state_dict = model.state_dict()
    checkpoint = torch.load("swinv2_base_window12_192_22k.pt", map_location='cpu')
    checkpoint_model = checkpoint['model']

    pre_trained_layers = {}
    for k, v in checkpoint_model.items():
        if checkpoint_model[k].shape == state_dict[k].shape:
            pre_trained_layers[k] = v 

    model.load_state_dict(pre_trained_layers, strict=False)

    for name_p, p in model.named_parameters():
        if ".attn." in name_p:
            p.requires_grad = True
        else:
            p.requires_grad = False
    
    model.head.weight.requires_grad = True
    model.head.bias.requires_grad = True

    try:
        model.pos_embed.requires_grad = True
    except:
        print('no position encoding')
    
    try:
        for p in model.patch_embed.parameters():
                p.requires_grad = False
    except:
        print('no patch embed')

    print(
        f"number of params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
    )

    model.to(DEVICE)

    criterion = LabelSmoothingCrossEntropy(0.1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.05)
    scheduler = CosineLRScheduler(
        optimizer,
        t_initial=T_INITIAL,
        warmup_t=WARMUP_T,
        warmup_lr_init=WARMUP_LR_INIT,
        k_decay=K_DECAY,
        lr_min=LR_MIN,
    )

    for epoch in range(EPOCHS):
        start_time = time.time()
        for i, (data, label) in tqdm(enumerate(train_loader), total=len(train_loader)):
            model.train()
            # forward pass
            data, label = data.to(DEVICE), label.to(DEVICE)
            output = model(data)
            loss = criterion(output, label)

            # backward pass
            optimizer.zero_grad()
            loss.backward()

            # clip gradient
            dispatch_clip_grad(model.parameters(), 5.0)

            # gradient decent or adam step
            optimizer.step()

        
        # update scheduler
        scheduler.step_update(epoch)

        train_loss, train_acc, train_macro_f1, train_final_scroe = evaluate(model, train_loader, DEVICE)
        val_loss, val_acc, val_macro_f1, val_final_scroe= evaluate(model, val_loader, DEVICE)

        print(
            f"【Epoch={epoch+1}】 train:【loss={train_loss:.3f}, acc={100*train_acc:.2f}%, f1={train_final_scroe:.3f}, final={train_final_scroe:.3f}】 \
            val: 【loss={val_loss:.3f}, acc={100*val_acc:.2f}%, f1={val_macro_f1:.3f}, final={val_final_scroe:.3f}】 {(time.time() - start_time):.2f}/s"
        )

        torch.save(
                {
                    "epoch": epoch,
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                },
                f=f"{MODEL_NAME}.pt",
            )


            

if __name__ == "__main__":
    main()

no position encoding
number of params: 28408659


  0%|          | 0/876 [00:00<?, ?it/s]

【Epoch=1】 train:【loss=2.938, acc=41.38%, f1=0.390, final=0.390】             val: 【loss=3.256, acc=31.96%, f1=0.247, final=0.284】 407.29/s


# 推論

In [12]:
# 從雲端下載 class_mapping.json
url = "https://drive.google.com/u/3/uc?id=1NttS-JghMY_fkPFj7wqH8BUkkYm_2Bcv&export=download"
output = "class_mapping.json"
gdown.download(url, output)

Downloading...
From: https://drive.google.com/u/3/uc?id=1NttS-JghMY_fkPFj7wqH8BUkkYm_2Bcv&export=download
To: /content/class_mapping.json
100%|██████████| 2.85k/2.85k [00:00<00:00, 725kB/s]


'class_mapping.json'

In [13]:
# 從雲端下載 submission_template.csv
url = "https://drive.google.com/u/3/uc?id=1ZYeBeTvHM3OW9hvZV0u7zRNKHyUH9LWf&export=download"
output = "submission_template.csv"
gdown.download(url, output)

Downloading...
From: https://drive.google.com/u/3/uc?id=1ZYeBeTvHM3OW9hvZV0u7zRNKHyUH9LWf&export=download
To: /content/submission_template.csv
100%|██████████| 1.39M/1.39M [00:00<00:00, 109MB/s]


'submission_template.csv'

In [14]:
from pathlib import Path
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.notebook import tqdm


class OrchidDataset(Dataset):
    def __init__(self, df: pd.DataFrame, img_dir: str, transform: transforms):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        f = self.df["filename"].iloc[index]
        img_path = Path(self.img_dir) / f
        img = Image.open(img_path)

        if self.transform:
            img = self.transform(img)

        return img

@torch.no_grad()
def predict(model, dataloader, device, class_mapping):

    label_predictions_list = []

    for _, img in tqdm(enumerate(dataloader), total=len(dataloader)):
        img = img.to(device)
        output = model(img)

        y_pred = torch.argmax(output, 1).cpu().numpy()

        # label encode
        for y_ in y_pred:
            y = class_mapping[y_]
            label_predictions_list.append(y)

    return label_predictions_list

In [None]:
import json
import argparse
import warnings
import pandas as pd
from torch.utils.data import DataLoader
import torch
import timm


warnings.filterwarnings("ignore")


# data preprocessing configuration
DATA_DIR = "/content/test_dataset"
IMAGE_SIZE = 384
TEST_RESIZE = int((256 / 224) * IMAGE_SIZE)

# training
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHS = 5
BATCH_SIZE = 2
MODEL_NAME = "swinv2_base_window12to24_192to384_22kft1k"
CHECKPOINT = "swinv2_base_window12to24_192to384_22kft1k.pt"

normalize = transforms.Normalize(mean=[0.4909, 0.4216, 0.3703], std=[0.2459, 0.2420, 0.2489])
TRANSFORM = transforms.Compose(
        [
            transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.ToTensor(),
            normalize,
        ]
    )



def main():

    # load class to index json file
    with open("class_mapping.json", "r") as f:
        class_mapping = json.load(f)
    class_mapping = {int(k): v for k, v in class_mapping.items()}

    device = "cuda" if torch.cuda.is_available() else "cpu"

    df = pd.read_csv("submission_template.csv")
    dataset = OrchidDataset(df=df, img_dir=DATA_DIR, transform=TRANSFORM)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
    )

    # checkpoint = torch.hub.load_state_dict_from_url("https://drive.google.com/u/3/uc?id=1881QBnw6DB8tPg-PelPNYSGI5G3uY_C9&export=download")
    checkpoint = torch.load(CHECKPOINT)
    model = timm.create_model(
        MODEL_NAME,
        pretrained=False,
        num_classes=len(class_mapping),
        img_size=IMAGE_SIZE,
    )
    model.load_state_dict(checkpoint["model"])
    model.to(device)
    model.eval()

    # prediction of public and private dataset
    predictions = predict(model, dataloader, device, class_mapping)

    df["category"] = predictions
    df.to_csv("swinv2_submission", index=False)

if __name__ == "__main__":
    main()