# Vision Transformer를 활용한 식물 질환 분류

- [주의] drive에 남은 용량이 1~2GB 정도 있어야 함

- [관련 논문]

    https://arxiv.org/pdf/2010.11929.pdf (ViT 제안)

    https://www.frontiersin.org/articles/10.3389/fpls.2016.01419/full (간접 관련)
    
    https://www.sciencedirect.com/science/article/pii/S2666285X22000218 (간접 관련)

In [None]:
# Mount Google Drive at /content/drive (Colab only); save_dir further down
# points below /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# 모델 준비

## 필요 라이브러리 설치

In [None]:
!pip install torch==2.0.1 torchvision~=0.15.1 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl
!pip install torchaudio==2.0.1

!pip install datasets
!pip install transformers
!pip install torchmetrics
!pip install accelerate -U

Collecting torch-xla==2.0
  Downloading https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl (162.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.9/162.9 MB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch==2.0.1
  Downloading torch-2.0.1-cp310-cp310-manylinux1_x86_64.whl (619.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m619.9/619.9 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision~=0.15.1
  Downloading torchvision-0.15.2-cp310-cp310-manylinux1_x86_64.whl (6.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.0/6.0 MB[0m [31m64.7 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu11==11.7.99 (from torch==2.0.1)
  Downloading nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl (21.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.0/21.0 MB[0m [31m50.9 MB/s[0m eta [36m0:00:00[0m


Collecting torchaudio==2.0.1
  Downloading torchaudio-2.0.1-cp310-cp310-manylinux1_x86_64.whl (4.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.4/4.4 MB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch==2.0.0 (from torchaudio==2.0.1)
  Downloading torch-2.0.0-cp310-cp310-manylinux1_x86_64.whl (619.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m619.9/619.9 MB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch, torchaudio
  Attempting uninstall: torch
    Found existing installation: torch 2.0.1
    Uninstalling torch-2.0.1:
      Successfully uninstalled torch-2.0.1
  Attempting uninstall: torchaudio
    Found existing installation: torchaudio 2.1.0+cu118
    Uninstalling torchaudio-2.1.0+cu118:
      Successfully uninstalled torchaudio-2.1.0+cu118
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the

## 데이터셋 다운로드



- 총 38개의 카테고리가 있음. 모든 데이터의 비율이 균등하지 않다는 점에 유의

- [데이터셋 분석] https://www.dropbox.com/scl/fi/hgv4fgnuc5kskcxpcws8e/.xlsx?rlkey=pfh6bpp25jd4fk6nf7c1qap8w&dl=1

- [관련 논문] Hughes, David P., and Marcel Salathe. “An Open Access Repository of Images on Plant Health to Enable the Development of Mobile Disease Diagnostics.” ArXiv:1511.08060 [Cs], Apr. 2016. arXiv.org, http://arxiv.org/abs/1511.08060.

- [출처] https://github.com/spMohanty/PlantVillage-Dataset

In [None]:
import os
import io
import shutil
import requests
from zipfile import ZipFile

def generate_dataset_dir(data_dir, test_rate=0.2, valid_rate=0.2, gap=1, seed=42):
    """
    made by hope04302

    Download the dataset zip and split it into train / validation / test folders.

    The archive is extracted into ``data_dir`` (any existing ``data_dir`` is
    deleted first). For every class folder the ``jpg`` images are moved into
    ``{data_dir}/train/{cls}``, ``{data_dir}/validation/{cls}`` and
    ``{data_dir}/test/{cls}``; the extracted source folder is removed at the end.

    Args:
        data_dir: Destination directory; it is wiped and recreated.
        test_rate: Fraction of each class reserved for the test split.
        valid_rate: Fraction of the remaining (non-test) images used for validation.
        gap: Stride applied inside each split; ``gap=1`` keeps every image,
            ``gap=k`` keeps every k-th one and discards the rest.
        seed: Seed for the per-class shuffle. With a fixed seed the split is
            reproducible; ``None`` seeds from the system and is not.

    Raises:
        requests.HTTPError: If the download does not return a success status.
    """
    import random

    zip_url = "https://data.mendeley.com/public-files/datasets/tywbtsjrjv/files/d5652a28-c1d8-4b76-97f3-72fb80f94efc/file_downloaded"
    response = requests.get(zip_url)
    # Surface a failed download here instead of as an opaque zip error below.
    response.raise_for_status()

    if os.path.isdir(data_dir):
        shutil.rmtree(data_dir)

    with ZipFile(io.BytesIO(response.content), 'r') as zf:
        zf.extractall(f"{data_dir}")

    in_path = f"{data_dir}/Plant_leave_diseases_dataset_without_augmentation"
    rng = random.Random(seed)

    # os.listdir order is arbitrary, so sort to keep the split reproducible.
    for cls in sorted(os.listdir(in_path)):
        cls_path = f"{in_path}/{cls}"
        if not os.path.isdir(cls_path):
            # Ignore stray files sitting next to the class folders.
            continue

        for folder in ("train", "validation", "test"):
            os.makedirs(f"{data_dir}/{folder}/{cls}")

        # Filter before splitting so the ratios apply to actual images only.
        images = sorted(
            name for name in os.listdir(cls_path) if name[-3:].lower() == "jpg"
        )
        rng.shuffle(images)

        train_idx = int(len(images) * (1 - test_rate) * (1 - valid_rate))
        valid_idx = int(len(images) * (1 - test_rate))

        splits = {
            "train": images[:train_idx:gap],
            "validation": images[train_idx:valid_idx:gap],
            "test": images[valid_idx::gap],
        }
        for folder, names in splits.items():
            for name in names:
                shutil.move(f"{cls_path}/{name}", f"{data_dir}/{folder}/{cls}")

    # Drops the extracted source tree together with any files that were not moved.
    shutil.rmtree(in_path)

In [None]:
from datasets import load_dataset

# Download and split the images, then load the train/validation/test folders
# with the `imagefolder` builder.
# NOTE: test_rate=0.95 puts roughly 95% of every class into the test split,
# leaving only a small train/validation set.
generate_dataset_dir(data_dir='/content/plantvillage', test_rate=0.95, gap=1)
dataset = load_dataset("imagefolder", data_dir='/content/plantvillage')

Resolving data files:   0%|          | 0/2158 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/540 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/51606 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/2158 [00:00<?, ?it/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Downloading data files:   0%|          | 0/540 [00:00<?, ?it/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Downloading data files:   0%|          | 0/51606 [00:00<?, ?it/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

## 모델, 이미지 프로세서 다운로드(중요)

In [None]:
# Use different values for real experiments.
# Do not change NUM_LABELS.

# Size of the classifier head and width of the one-hot label vectors.
NUM_LABELS = 39
# Per-device batch size for both training and evaluation.
BATCH_SIZE = 32
# Number of training epochs passed to TrainingArguments.
EPOCHS = 200

In [None]:
import torch
from transformers import AutoImageProcessor

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# Edit 1 =====================================

# model_ckpt: search for the model you want at the link below and paste its name
# (ctrl + click opens the link)
# https://huggingface.co/models?pipeline_tag=image-classification&sort=trending

# model_name: name under which the model is stored on Drive
# NOTE(review): no `model_name` variable is defined in this cell; save_dir is
# derived from model_ckpt instead.
# save_dir: path where the model is saved

model_ckpt = 'google/vit-base-patch16-224-in21k'
save_dir = f'/content/drive/MyDrive/saki_model/{model_ckpt}'

# ============================================

# XLA device that the model is moved to.
device = xm.xla_device()

# Image preprocessor loaded from the same checkpoint as the model.
processor = AutoImageProcessor.from_pretrained(model_ckpt)

Downloading (…)rocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

In [None]:
from transformers import AutoModelForImageClassification

# Edit 2-1 =====================================

# [Caution] Comment out one of edit 2-1 / edit 2-2.
# Pick this one to use the stock classification model as-is.

def Model():
    """Load `model_ckpt` with a NUM_LABELS-way classification head and move it to `device`."""
    classifier = AutoModelForImageClassification.from_pretrained(
        model_ckpt,
        num_labels=NUM_LABELS,
        ignore_mismatched_sizes=True,
    )
    return classifier.to(device)

# ==============================================

In [None]:
# import torch.nn as nn
# from transformers import AutoModel

# # 수정 2-2 =====================================

# # [주의] 수정 2-1, 2-2 중 하나는 주석 처리할 것.
# # 원래 코드의 classifier를 고치고 싶다면 선택

# class Model(nn.Module):
#     def __init__(self):
#         super(Model, self).__init__()
#         self.model = AutoModel.from_pretrained(model_ckpt)
#         self.classifier = nn.Linear(self.model.config.hidden_size, NUM_LABELS)
#         self.loss = nn.BCEWithLogitsLoss()

#     def forward(self, pixel_values, labels=None):

#         # hidden.last_hidden_state도 있음...

#         hidden = self.model(pixel_values=pixel_values).pooler_output
#         logits = self.classifier(hidden)
#         loss = self.loss(logits, labels)

#         model_output = {
#             'loss': loss,
#             'logits': logits
#         }

#         return model_output

# # ==============================================

## 데이터셋, 모델 상태 확인

In [None]:
dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 2158
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 540
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 51606
    })
})

In [None]:
dataset["train"].features["label"].names

['Apple___Apple_scab',
 'Apple___Black_rot',
 'Apple___Cedar_apple_rust',
 'Apple___healthy',
 'Blueberry___healthy',
 'Cherry___Powdery_mildew',
 'Cherry___healthy',
 'Corn___Cercospora_leaf_spot Gray_leaf_spot',
 'Corn___Common_rust',
 'Corn___Northern_Leaf_Blight',
 'Corn___healthy',
 'Grape___Black_rot',
 'Grape___Esca_(Black_Measles)',
 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
 'Grape___healthy',
 'Orange___Haunglongbing_(Citrus_greening)',
 'Peach___Bacterial_spot',
 'Peach___healthy',
 'Pepper,_bell___Bacterial_spot',
 'Pepper,_bell___healthy',
 'Potato___Early_blight',
 'Potato___Late_blight',
 'Potato___healthy',
 'Raspberry___healthy',
 'Soybean___healthy',
 'Squash___Powdery_mildew',
 'Strawberry___Leaf_scorch',
 'Strawberry___healthy',
 'Tomato___Bacterial_spot',
 'Tomato___Early_blight',
 'Tomato___Late_blight',
 'Tomato___Leaf_Mold',
 'Tomato___Septoria_leaf_spot',
 'Tomato___Spider_mites Two-spotted_spider_mite',
 'Tomato___Target_Spot',
 'Tomato___Tomato_Yellow_Lea

In [None]:
processor

ViTImageProcessor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

In [None]:
Model()

Downloading pytorch_model.bin:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=7

## 모델 학습 준비

In [None]:
import torch.nn.functional as F

def transform(example_batch):
    """Run the image processor over a batch and attach the batch labels.

    Reads the 'image' and 'label' entries of ``example_batch`` and returns the
    processor output (PyTorch tensors) with an extra 'labels' entry.
    """
    images = list(example_batch['image'])
    encoded = processor(images, return_tensors='pt')
    encoded['labels'] = example_batch['label']
    return encoded

def collate_fn(batch):
    """Collate transformed examples into one batch dict.

    Stacks each example's 'pixel_values' and turns the integer 'labels' into a
    float one-hot matrix with NUM_LABELS columns.
    """
    pixel_values = torch.stack([example['pixel_values'] for example in batch])
    class_ids = torch.tensor([example['labels'] for example in batch])
    one_hot_labels = F.one_hot(class_ids, num_classes=NUM_LABELS).type(torch.float)
    return {'pixel_values': pixel_values, 'labels': one_hot_labels}

In [None]:
import torch
import torch.nn.functional as F
from torchmetrics import Accuracy, F1Score, HammingDistance, AUROC, ExactMatch

def compute_metrics(eval_pred):
    """Compute evaluation metrics from the model outputs.

    Args:
        eval_pred: Pair ``(logits, labels)``. ``labels`` are expected to be the
            one-hot rows built by ``collate_fn``.

    Returns:
        Dict with accuracy, macro/micro/weighted F1, AUROC, exact match and
        hamming loss, plus one per-class F1 entry keyed "0" .. str(NUM_LABELS - 1).
    """

    # Edit 3 ==============================================

    # eval_pred: the output values produced by the model.
    # Change this section to compute the metrics you want.

    logits, labels = eval_pred
    logits = torch.Tensor(logits)
    labels = torch.Tensor(labels).long()

    # Continuous scores for AUROC; scoring it on hard 0/1 predictions would
    # throw away the ranking information the metric is built on.
    probs = torch.sigmoid(logits)

    # Hard top-1 prediction per sample as a float one-hot row (same values and
    # dtype as the previous per-row loop, computed in one vectorized call).
    preds = F.one_hot(logits.argmax(dim=-1), num_classes=logits.shape[-1]).float()

    accuracy = Accuracy(task='multilabel', num_labels=NUM_LABELS)
    f1_macro = F1Score(task="multilabel", num_labels=NUM_LABELS, average='macro')
    f1_micro = F1Score(task="multilabel", num_labels=NUM_LABELS, average='micro')
    f1_weight = F1Score(task="multilabel", num_labels=NUM_LABELS, average='weighted')
    em = ExactMatch(task='multiclass', num_classes=2)
    auroc = AUROC(task='multilabel', num_labels=NUM_LABELS, average='micro')
    hamming = HammingDistance(task="multiclass", num_classes=2)

    each_f1_score = F1Score(task='multilabel', num_labels=NUM_LABELS, average=None)

    metrics = {'accuracy': accuracy(preds, labels),
               'f1_macro': f1_macro(preds, labels),
               'f1_micro': f1_micro(preds, labels),
               'f1_weighted': f1_weight(preds, labels),
               'auroc': auroc(probs, labels),
               'em': em(preds, labels),
               'hamming_loss': hamming(preds, labels)} | dict(zip(map(str, range(NUM_LABELS)), each_f1_score(preds, labels)))

    # =======================================================

    return metrics

In [None]:
# Bind each split by name instead of relying on the key order of the DatasetDict;
# `transform` is applied lazily when examples are accessed.
train, validation, test = (
    dataset[split].with_transform(transform)
    for split in ("train", "validation", "test")
)

# 모델 학습

In [None]:
from transformers import TrainingArguments, Trainer

# Training configuration: checkpoints go to /content/model, logging happens once per epoch.
training_args = TrainingArguments(
    output_dir="/content/model",
    do_train=True,
    do_eval=True,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    logging_strategy="epoch",
    logging_dir='./logs',
    learning_rate=2e-5,
    run_name="v1",
    seed=42,
    # presumably keeps the raw 'image' column available to `transform` — TODO confirm
    remove_unused_columns=False,

    # Edit 4 =============================

    # Less accurate, but faster and lighter on disk:
    # evaluation_strategy="epoch", save_strategy="no"

    # More accurate training:
    evaluation_strategy="epoch", save_strategy="epoch", load_best_model_at_end=True, metric_for_best_model='eval_f1_macro', save_total_limit=1,

    # Even more accurate training (numbers can be adjusted):
    # NOTE(review): compute_metrics in this file reports f1_macro / f1_micro /
    # f1_weighted, not f1_score; update the metric name before enabling this line.
    #evaluation_strategy="steps", eval_steps=100, save_strategy="steps", load_best_model_at_end=True, metric_for_best_model='eval_f1_score', save_total_limit=1,

    # ====================================

)

In [None]:
def _mp_train_fn(index=None, frags=None):
    """Build the model and a Trainer, run training, and always save the model.

    ``index`` and ``frags`` are accepted but not used inside the function.
    """
    model = Model().to(device)

    trainer_kwargs = dict(
        model=model,
        args=training_args,
        train_dataset=train,
        eval_dataset=validation,
        compute_metrics=compute_metrics,
        tokenizer=processor,
        data_collator=collate_fn,
    )
    trainer = Trainer(**trainer_kwargs)

    # save_model() runs even when train() raises, so partial progress is kept.
    try:
        trainer.train()
    finally:
        trainer.save_model()

In [None]:
# Run the training function through torch_xla's spawn helper as a single forked process.
xmp.spawn(_mp_train_fn, args=(), nprocs=1, start_method='fork')

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,F1 Micro,F1 Weighted,Auroc,Em,Hamming Loss,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,Runtime,Samples Per Second,Steps Per Second
1,0.5268,0.439112,0.95305,0.021313,0.107955,0.031578,0.541923,0.107955,0.04695,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.631579,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.178322,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,17.3703,31.318,0.979
2,0.4062,0.376088,0.953216,0.026151,0.111111,0.034762,0.543544,0.111111,0.046784,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.818182,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.175559,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,9.9109,54.889,1.715
3,0.3508,0.326305,0.953021,0.023025,0.107407,0.032079,0.541642,0.107407,0.046979,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.7,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.174957,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,7.3162,74.355,2.324
4,0.3051,0.284664,0.954094,0.03211,0.127778,0.058066,0.552102,0.127778,0.045906,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.814815,0.0,0.0,0.0,0.0,0.0,0.0,0.225806,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.179577,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,6.8535,79.376,2.48
5,0.2672,0.250348,0.959259,0.064282,0.225926,0.176477,0.602502,0.225926,0.040741,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.705882,0.0,0.153846,0.0,0.0,0.0,0.0,0.741573,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.200393,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.641026,0.0,0.0,13.6034,39.99,1.25
6,0.236,0.222232,0.964717,0.09987,0.32963,0.250166,0.655756,0.32963,0.035283,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.648649,0.0,0.909091,0.0,0.0,0.0,0.0,0.955752,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.241706,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.097561,0.0,0.942308,0.0,0.0,9.1472,59.472,1.858
7,0.2105,0.199395,0.965302,0.101136,0.340741,0.240415,0.661461,0.340741,0.034698,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.727273,0.0,0.96,0.0,0.0,0.0,0.0,0.916667,0.083333,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.248175,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.095238,0.0,0.8125,0.0,0.0,7.5196,72.344,2.261
8,0.1898,0.180841,0.965887,0.102996,0.351852,0.236553,0.667167,0.351852,0.034113,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.631579,0.0,0.96,0.0,0.0,0.0,0.0,0.80292,0.516129,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.280992,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.722222,0.0,0.0,6.6691,81.57,2.549
9,0.1731,0.165813,0.966764,0.113589,0.368519,0.24397,0.675726,0.368519,0.033236,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.615385,0.0,0.96,0.0,0.0,0.0,0.0,0.691824,0.722222,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.317757,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.193548,0.0,0.697987,0.0,0.117647,6.7817,80.216,2.507
10,0.1593,0.153384,0.96998,0.193741,0.42963,0.337137,0.707107,0.42963,0.03002,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.774194,0.75,0.96,0.0,0.0,0.0,0.0,0.797101,0.904762,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.320755,0.5,0.0,0.0,0.32,0.0,0.190476,0.0,0.0,0.2,0.315789,0.66242,0.0,0.666667,9.4064,57.833,1.807


Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,F1 Micro,F1 Weighted,Auroc,Em,Hamming Loss,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,Runtime,Samples Per Second,Steps Per Second
1,0.5268,0.439112,0.95305,0.021313,0.107955,0.031578,0.541923,0.107955,0.04695,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.631579,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.178322,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,17.3703,31.318,0.979
2,0.4062,0.376088,0.953216,0.026151,0.111111,0.034762,0.543544,0.111111,0.046784,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.818182,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.175559,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,9.9109,54.889,1.715
3,0.3508,0.326305,0.953021,0.023025,0.107407,0.032079,0.541642,0.107407,0.046979,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.7,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.174957,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,7.3162,74.355,2.324
4,0.3051,0.284664,0.954094,0.03211,0.127778,0.058066,0.552102,0.127778,0.045906,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.814815,0.0,0.0,0.0,0.0,0.0,0.0,0.225806,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.179577,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,6.8535,79.376,2.48
5,0.2672,0.250348,0.959259,0.064282,0.225926,0.176477,0.602502,0.225926,0.040741,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.705882,0.0,0.153846,0.0,0.0,0.0,0.0,0.741573,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.200393,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.641026,0.0,0.0,13.6034,39.99,1.25
6,0.236,0.222232,0.964717,0.09987,0.32963,0.250166,0.655756,0.32963,0.035283,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.648649,0.0,0.909091,0.0,0.0,0.0,0.0,0.955752,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.241706,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.097561,0.0,0.942308,0.0,0.0,9.1472,59.472,1.858
7,0.2105,0.199395,0.965302,0.101136,0.340741,0.240415,0.661461,0.340741,0.034698,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.727273,0.0,0.96,0.0,0.0,0.0,0.0,0.916667,0.083333,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.248175,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.095238,0.0,0.8125,0.0,0.0,7.5196,72.344,2.261
8,0.1898,0.180841,0.965887,0.102996,0.351852,0.236553,0.667167,0.351852,0.034113,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.631579,0.0,0.96,0.0,0.0,0.0,0.0,0.80292,0.516129,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.280992,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.722222,0.0,0.0,6.6691,81.57,2.549
9,0.1731,0.165813,0.966764,0.113589,0.368519,0.24397,0.675726,0.368519,0.033236,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.615385,0.0,0.96,0.0,0.0,0.0,0.0,0.691824,0.722222,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.317757,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.193548,0.0,0.697987,0.0,0.117647,6.7817,80.216,2.507
10,0.1593,0.153384,0.96998,0.193741,0.42963,0.337137,0.707107,0.42963,0.03002,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.774194,0.75,0.96,0.0,0.0,0.0,0.0,0.797101,0.904762,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.320755,0.5,0.0,0.0,0.32,0.0,0.190476,0.0,0.0,0.2,0.315789,0.66242,0.0,0.666667,9.4064,57.833,1.807


# 모델 평가

In [None]:
# Rebind training_args for the evaluation run: same settings as the training
# cell minus the evaluation/save strategy options, with output_dir switched
# to save_dir (the Drive path).
training_args = TrainingArguments(
    output_dir=save_dir,
    do_train=True,
    do_eval=True,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    logging_strategy="epoch",
    logging_dir='./logs',
    learning_rate=2e-5,
    run_name="v1",
    seed=42,
    remove_unused_columns=False,
)

In [None]:
def _mp_test_fn(index=None, frags=None):
    """Load the model saved in /content/model, evaluate it on `test`, and save it again.

    ``index`` and ``frags`` are accepted but not used inside the function.
    """
    checkpoint_dir = "/content/model"
    model = AutoModelForImageClassification.from_pretrained(checkpoint_dir).to(device)

    # Alternative for a custom Model class: load the raw state dict instead.
    # state_dict = torch.load(f"/content/model/pytorch_model.bin")
    # model = Model().to(device)
    # model.load_state_dict(state_dict=state_dict)

    trainer_kwargs = dict(
        model=model,
        args=training_args,
        train_dataset=train,
        eval_dataset=test,
        compute_metrics=compute_metrics,
        tokenizer=processor,
        data_collator=collate_fn,
    )
    trainer = Trainer(**trainer_kwargs)

    # The model is saved through the trainer whether or not evaluate() succeeds.
    try:
        results = trainer.evaluate()
        print(results)
    finally:
        trainer.save_model()

In [None]:
# Run the evaluation function through torch_xla's spawn helper as a single forked process.
xmp.spawn(_mp_test_fn, args=(), nprocs=1, start_method='fork')

In [None]:
import pandas as pd
from torch import tensor

a = {'eval_loss': 0.007293897680938244, 'eval_accuracy': 0.9979521036148071, 'eval_f1_score': 0.9610897898674011, 'eval_auroc': 0.9715272784233093, 'eval_recall': 0.9610897898674011, 'eval_precision': 0.9610897898674011, 'eval_0': 0.9348914623260498, 'eval_1': 0.9439002871513367, 'eval_2': 0.978723406791687, 'eval_3': 0.9713168144226074, 'eval_4': 0.9909090995788574, 'eval_5': 0.9720101952552795, 'eval_6': 0.968137264251709, 'eval_7': 0.8056460618972778, 'eval_8': 0.9977924823760986, 'eval_9': 0.9023883938789368, 'eval_10': 0.998643159866333, 'eval_11': 0.9737556576728821, 'eval_12': 0.9800676703453064, 'eval_13': 0.9874030947685242, 'eval_14': 0.966625452041626, 'eval_15': 0.9981871843338013, 'eval_16': 0.9843288660049438, 'eval_17': 0.9479768872261047, 'eval_18': 0.9586426019668579, 'eval_19': 0.9671787619590759, 'eval_20': 0.9717513918876648, 'eval_21': 0.9022801518440247, 'eval_22': 0.8054607510566711, 'eval_23': 0.9841726422309875, 'eval_24': 0.9904054403305054, 'eval_25': 0.9991401433944702, 'eval_26': 0.9816778898239136, 'eval_27': 0.9838337302207947, 'eval_28': 0.9446237683296204, 'eval_29': 0.7343904972076416, 'eval_30': 0.8848314881324768, 'eval_31': 0.9044075608253479, 'eval_32': 0.9190462231636047, 'eval_33': 0.9388272762298584, 'eval_34': 0.864098846912384, 'eval_35': 0.9904612302780151, 'eval_36': 0.9533527493476868, 'eval_37': 0.9835957884788513, 'eval_runtime': 840.3154, 'eval_samples_per_second': 61.425, 'eval_steps_per_second': 1.92}
# Rebind `a` from the metrics dict to a one-column DataFrame indexed by metric
# name, then display it.
a  = pd.DataFrame(a.values(), index=a.keys())
a