<a href="https://colab.research.google.com/github/hope04302/sciencipia-plant-disease/blob/main/vit_new_code_gpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Vision Transformer를 활용한 식물 질환 분류

- [주의] 학습된 모델을 Google Drive에 저장하므로, drive에 1~2GB 정도의 여유 공간이 있어야 함

- [관련 논문]

    https://arxiv.org/pdf/2010.11929.pdf (ViT 제안)

    https://www.frontiersin.org/articles/10.3389/fpls.2016.01419/full (간접 관련)
    
    https://www.sciencedirect.com/science/article/pii/S2666285X22000218 (간접 관련)

In [None]:
from google.colab import drive
# Mount Google Drive; save_dir (where the final model is saved) points under /content/drive/MyDrive.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# 모델 준비

## 필요 라이브러리 설치

In [None]:
# Install the libraries used below (Colab shell commands).
# NOTE(review): versions are unpinned, so results may change across library releases.
!pip install datasets
!pip install transformers
!pip install torchmetrics
!pip install accelerate -U



## 데이터셋 다운로드



- 총 39개의 카테고리(38개의 작물·질환 클래스 + 잎이 없는 배경 클래스 `Background_without_leaves` 1개)가 있음. 클래스별 이미지 수가 균등하지 않다는 점에 유의

- [데이터셋 분석] https://www.dropbox.com/scl/fi/hgv4fgnuc5kskcxpcws8e/.xlsx?rlkey=pfh6bpp25jd4fk6nf7c1qap8w&dl=1

- [관련 논문] Hughes, David P., and Marcel Salathe. “An Open Access Repository of Images on Plant Health to Enable the Development of Mobile Disease Diagnostics.” ArXiv:1511.08060 [Cs], Apr. 2016. arXiv.org, http://arxiv.org/abs/1511.08060.

- [출처] https://github.com/spMohanty/PlantVillage-Dataset

In [None]:
import os
import io
import shutil
import requests
from zipfile import ZipFile

def generate_dataset_dir(data_dir, test_rate=0.2, valid_rate=0.2, gap=1, seed=42):
    """
    made by hope04302
    Download the PlantVillage zip archive and split it into train/validation/test folders.

    The archive is extracted into ``data_dir``. Every class folder under
    ``Plant_leave_diseases_dataset_without_augmentation`` is split into
    ``{data_dir}/{train,validation,test}/{class}``. The extracted source folder
    (including any non-.jpg files) is deleted afterwards.

    Args:
        data_dir: Output directory. It is deleted and recreated if it already exists.
        test_rate: Fraction of each class's .jpg images placed in the test split.
        valid_rate: Fraction of the remaining (non-test) images placed in validation.
        gap: Keep every ``gap``-th image within each split (1 keeps all of them).
        seed: Seed for shuffling each class's file list before splitting, so the
            split is reproducible. ``None`` keeps the sorted filename order.

    Raises:
        requests.HTTPError: If the download request fails.
    """
    import random

    zip_url = "https://data.mendeley.com/public-files/datasets/tywbtsjrjv/files/d5652a28-c1d8-4b76-97f3-72fb80f94efc/file_downloaded"
    response = requests.get(zip_url)
    # Fail with a clear HTTP error instead of a confusing BadZipFile further down.
    response.raise_for_status()

    if os.path.isdir(data_dir):
        shutil.rmtree(data_dir)

    with ZipFile(io.BytesIO(response.content), 'r') as zf:
        zf.extractall(f"{data_dir}")

    in_path = f"{data_dir}/Plant_leave_diseases_dataset_without_augmentation"
    rng = random.Random(seed)

    # Sorted iteration (classes and files) makes the split independent of the
    # filesystem's arbitrary os.listdir order.
    for cls in sorted(os.listdir(in_path)):
        src_dir = f"{in_path}/{cls}"
        if not os.path.isdir(src_dir):
            continue

        # Filter before splitting so the ratios are computed over images only.
        images = sorted(name for name in os.listdir(src_dir) if name.lower().endswith("jpg"))
        if seed is not None:
            rng.shuffle(images)

        train_idx = int(len(images) * (1 - test_rate) * (1 - valid_rate))
        valid_idx = int(len(images) * (1 - test_rate))
        splits = {
            "train": images[:train_idx:gap],
            "validation": images[train_idx:valid_idx:gap],
            "test": images[valid_idx::gap],
        }

        for split, names in splits.items():
            dst_dir = f"{data_dir}/{split}/{cls}"
            os.makedirs(dst_dir, exist_ok=True)
            for name in names:
                shutil.move(f"{src_dir}/{name}", dst_dir)

    shutil.rmtree(in_path)

In [None]:
from datasets import load_dataset

# test_rate=0.95 sends 95% of each class to test; the remaining 5% is split by the
# default valid_rate=0.2, i.e. roughly 4% train / 1% validation / 95% test.
generate_dataset_dir(data_dir='/content/plantvillage', test_rate=0.95, gap=1)
# Load the train/validation/test folders as a DatasetDict; labels come from the class folder names.
dataset = load_dataset("imagefolder", data_dir='/content/plantvillage')

Resolving data files:   0%|          | 0/2203 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/552 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/52691 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/2203 [00:00<?, ?it/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Downloading data files:   0%|          | 0/552 [00:00<?, ?it/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Downloading data files:   0%|          | 0/52691 [00:00<?, ?it/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [None]:
# Class names in the order of the dataset's integer label ids (ClassLabel.names),
# so clss[label] is always the name of `label`. Re-sorting could misalign the two.
clss = list(dataset["test"].features["label"].names)
# Species = text before the first '_' of a class name,
# e.g. 'Apple___healthy' -> 'Apple', 'Pepper,_bell___healthy' -> 'Pepper,'.
sps = sorted({name.split('_')[0] for name in clss})

cls2label = {name: label for label, name in enumerate(clss)}
sps2label = {sp: sp_label for sp_label, sp in enumerate(sps)}

# show_sp[label] -> species id of disease class `label`.
show_sp = [sps2label[name.split('_')[0]] for name in clss]
# show_cls[species id] -> disease class ids of that species, in ascending label order.
show_cls = [
    [label for label, sp_label in enumerate(show_sp) if sp_label == target_sp]
    for target_sp in range(len(sps))
]

## 모델, 이미지 프로세서 다운로드(중요)

In [None]:
# Experiment settings: choose BATCH_SIZE / EPOCHS for the actual experiment.
# Do not change NUM_LABELS; it must equal the number of dataset classes.

NUM_LABELS = 39
BATCH_SIZE = 32
EPOCHS = 200

# Fail early if the hard-coded head size disagrees with the dataset's class list.
assert NUM_LABELS == len(clss), f"NUM_LABELS={NUM_LABELS}, but the dataset has {len(clss)} classes"

In [None]:
import torch
from transformers import AutoImageProcessor

# Edit 1 =====================================

# model_ckpt: id of the Hugging Face checkpoint to fine-tune. Search the page below
# and paste the model name (ctrl + click opens the link).
# https://huggingface.co/models?pipeline_tag=image-classification&sort=trending

# save_dir: Google Drive path where the final model is saved (derived from model_ckpt).

model_ckpt = 'google/vit-base-patch16-224-in21k'
save_dir = f'/content/drive/MyDrive/saki_model/{model_ckpt}'

# ============================================

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Image preprocessing (resize / rescale / normalize) matching the chosen checkpoint.
processor = AutoImageProcessor.from_pretrained(model_ckpt)

Downloading (…)rocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

In [None]:
# from transformers import AutoModelForImageClassification

# # 수정 2-1 =====================================

# # [주의] 수정 2-1, 2-2 중 하나는 주석 처리할 것.
# # 만약 원래 코드를 그대로 쓰고 싶으면 선택

# model = (AutoModelForImageClassification
#          .from_pretrained(model_ckpt, num_labels=NUM_LABELS, ignore_mismatched_sizes=True)
#          .to(device))

# # ==============================================

In [None]:
import torch.nn as nn
from transformers import AutoModel

# 수정 2-2 =====================================

# [주의] 수정 2-1, 2-2 중 하나는 주석 처리할 것.
# 원래 코드의 classifier를 고치고 싶다면 선택

class Model(nn.Module):
    """Two-head classifier on a shared pretrained backbone loaded from ``model_ckpt``.

    - Species head: linear layer on the first token of the last hidden state,
      producing ``len(sps)`` logits.
    - Disease head: linear layer on the mean over all tokens,
      producing ``NUM_LABELS`` logits.

    The training loss is the sum of the two cross-entropy losses.
    """

    def __init__(self):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_ckpt)
        hidden_size = self.model.config.hidden_size
        self.sp_classifier = nn.Linear(hidden_size, len(sps))
        self.classifier = nn.Linear(hidden_size, NUM_LABELS)
        self.sp_loss = nn.CrossEntropyLoss()
        self.loss = nn.CrossEntropyLoss()

    @staticmethod
    def _to_class_index(targets):
        """Return class indices for (N, C) one-hot/score targets (argmax) or (N,) index targets."""
        return targets.argmax(-1) if targets.dim() > 1 else targets.long()

    def forward(self, pixel_values, sp_labels=None, labels=None):
        """Run the backbone and both heads.

        Args:
            pixel_values: Batch of preprocessed images (output of ``processor``).
            sp_labels: Optional species targets, either one-hot rows or class indices.
            labels: Optional disease targets, either one-hot rows or class indices.

        Returns:
            dict with 'di_logits' (N, NUM_LABELS) and 'sp_logits' (N, len(sps)).
            'loss' (sum of the available cross-entropy terms) is included, as the
            first key, only when at least one target is given.
        """
        hidden = self.model(pixel_values=pixel_values).last_hidden_state

        sp_logits = self.sp_classifier(hidden[:, 0, :])
        di_logits = self.classifier(hidden.mean(dim=1))

        losses = []
        if sp_labels is not None:
            losses.append(self.sp_loss(sp_logits, self._to_class_index(sp_labels)))
        if labels is not None:
            losses.append(self.loss(di_logits, self._to_class_index(labels)))

        model_output = {}
        # Only add 'loss' when it exists, so no None-valued entry is returned at inference.
        if losses:
            model_output['loss'] = sum(losses)
        # Key order matters: the remaining entries are read back as (di_logits, sp_logits).
        model_output['di_logits'] = di_logits
        model_output['sp_logits'] = sp_logits

        return model_output

# Build the two-head model; its backbone is loaded with AutoModel.from_pretrained(model_ckpt).
model = Model()

# ==============================================

## 데이터셋, 모델 상태 확인

In [None]:
dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 2203
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 552
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 52691
    })
})

In [None]:
dataset["train"].features["label"].names

['Apple___Apple_scab',
 'Apple___Black_rot',
 'Apple___Cedar_apple_rust',
 'Apple___healthy',
 'Background_without_leaves',
 'Blueberry___healthy',
 'Cherry___Powdery_mildew',
 'Cherry___healthy',
 'Corn___Cercospora_leaf_spot Gray_leaf_spot',
 'Corn___Common_rust',
 'Corn___Northern_Leaf_Blight',
 'Corn___healthy',
 'Grape___Black_rot',
 'Grape___Esca_(Black_Measles)',
 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
 'Grape___healthy',
 'Orange___Haunglongbing_(Citrus_greening)',
 'Peach___Bacterial_spot',
 'Peach___healthy',
 'Pepper,_bell___Bacterial_spot',
 'Pepper,_bell___healthy',
 'Potato___Early_blight',
 'Potato___Late_blight',
 'Potato___healthy',
 'Raspberry___healthy',
 'Soybean___healthy',
 'Squash___Powdery_mildew',
 'Strawberry___Leaf_scorch',
 'Strawberry___healthy',
 'Tomato___Bacterial_spot',
 'Tomato___Early_blight',
 'Tomato___Late_blight',
 'Tomato___Leaf_Mold',
 'Tomato___Septoria_leaf_spot',
 'Tomato___Spider_mites Two-spotted_spider_mite',
 'Tomato___Target_Spot'

In [None]:
processor

ViTImageProcessor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

In [None]:
model

Model(
  (model): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3

## 모델 학습 준비

In [None]:
import torch.nn.functional as F

def transform(example_batch):
    """Turn a batch of dataset rows into model inputs (applied lazily via with_transform).

    Args:
        example_batch: dict with 'image' (PIL images) and 'label' (disease class ids).

    Returns:
        The processor output ('pixel_values') plus 'labels' (disease ids) and
        'sp_labels' (species ids looked up through show_sp).
    """
    # Force 3-channel RGB: grayscale or CMYK JPEGs would otherwise reach the
    # processor with the wrong number of channels.
    images = [image.convert('RGB') for image in example_batch['image']]
    inputs = processor(images, return_tensors='pt')
    inputs['labels'] = example_batch['label']
    inputs['sp_labels'] = [show_sp[label] for label in example_batch['label']]
    return inputs

def collate_fn(batch):
    """Stack per-example tensors into a batch and one-hot encode both targets.

    Args:
        batch: list of dicts with 'pixel_values' (tensor), 'labels' (disease id)
            and 'sp_labels' (species id).

    Returns:
        dict with 'pixel_values' (stacked), 'labels' as float one-hot of width
        NUM_LABELS, and 'sp_labels' as float one-hot of width len(sps).
    """
    pixel_values = torch.stack([example['pixel_values'] for example in batch])
    labels = torch.tensor([example['labels'] for example in batch])
    sp_labels = torch.tensor([example['sp_labels'] for example in batch])
    return {
        'pixel_values': pixel_values,
        'labels': F.one_hot(labels, num_classes=NUM_LABELS).float(),
        # Width must match the species head (len(sps) logits), not NUM_LABELS.
        'sp_labels': F.one_hot(sp_labels, num_classes=len(sps)).float(),
    }

In [None]:
import torch
import torch.nn.functional as F
from torchmetrics import Accuracy, F1Score, HammingDistance, AUROC, ExactMatch

def compute_metrics(eval_pred):
    """Compute evaluation metrics for the two-head (species + disease) model.

    The disease prediction is hierarchical: the species is predicted first, then
    the highest-scoring disease class among that species' classes is chosen.

    Args:
        eval_pred: EvalPrediction from the Trainer. ``predictions`` is unpacked as
            (di_logits, sp_logits); ``label_ids`` as one-hot (labels, sp_labels),
            matching ``label_names=['labels', 'sp_labels']``.

    Returns:
        dict of metric name -> value, followed by per-class F1 under keys '0'..str(NUM_LABELS - 1).
    """

    # Edit 3 ==============================================

    # eval_pred: outputs/labels gathered by the Trainer from the model.
    # Edit this function to change which metrics are reported.

    di_logits, sp_logits = eval_pred.predictions
    di_labels, sp_labels = eval_pred.label_ids

    di_logits = torch.as_tensor(di_logits, dtype=torch.float)
    sp_logits = torch.as_tensor(sp_logits, dtype=torch.float)

    di_labels = torch.as_tensor(di_labels).long()
    sp_labels = torch.as_tensor(sp_labels).long()

    sp_preds = sp_logits.argmax(dim=-1)

    # species_mask[s, c] is True iff disease class c belongs to species s.
    species_mask = torch.zeros(len(show_cls), NUM_LABELS, dtype=torch.bool)
    for sp_idx, cls_ids in enumerate(show_cls):
        species_mask[sp_idx, cls_ids] = True
    # Out-of-species classes get -inf. A fixed sentinel such as -1 would be wrong:
    # logits are unbounded, and if every candidate were below it no class would be chosen.
    restricted_logits = di_logits.masked_fill(~species_mask[sp_preds], float('-inf'))
    preds = F.one_hot(restricted_logits.argmax(dim=-1), num_classes=NUM_LABELS).float()

    # NOTE: 'accuracy' is multilabel, i.e. averaged element-wise over the one-hot
    # vectors; 'em' (exact match) is the usual top-1 accuracy.
    accuracy = Accuracy(task='multilabel', num_labels=NUM_LABELS)
    f1_macro = F1Score(task="multilabel", num_labels=NUM_LABELS, average='macro')
    f1_micro = F1Score(task="multilabel", num_labels=NUM_LABELS, average='micro')
    f1_weight = F1Score(task="multilabel", num_labels=NUM_LABELS, average='weighted')
    em = ExactMatch(task='multiclass', num_classes=2)
    auroc = AUROC(task='multilabel', num_labels=NUM_LABELS, average='micro')
    hamming = HammingDistance(task="multiclass", num_classes=2)
    sp_accuracy = Accuracy(task='multiclass', num_classes=len(sps))

    each_f1_score = F1Score(task='multilabel', num_labels=NUM_LABELS, average=None)

    metrics = {'accuracy': accuracy(preds, di_labels),
               'f1_macro': f1_macro(preds, di_labels),
               'f1_micro': f1_micro(preds, di_labels),
               'f1_weighted': f1_weight(preds, di_labels),
               # AUROC needs continuous scores; hard 0/1 predictions collapse the ROC curve.
               'auroc': auroc(di_logits.softmax(dim=-1), di_labels),
               'em': em(preds, di_labels),
               'hamming_loss': hamming(preds, di_labels),
               'species_acc': sp_accuracy(sp_preds, sp_labels.argmax(-1))}
    metrics |= dict(zip(map(str, range(NUM_LABELS)), each_f1_score(preds, di_labels)))

    # =======================================================

    return metrics

In [None]:
# Apply `transform` lazily on access, and select splits by name instead of relying
# on the DatasetDict's key order.
dataset_transformed = dataset.with_transform(transform)
train, validation, test = dataset_transformed["train"], dataset_transformed["validation"], dataset_transformed["test"]

# 모델 학습

In [None]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="/content/model",
    do_train=True,
    do_eval=True,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    logging_strategy="epoch",
    logging_dir='./logs',
    learning_rate=2e-5,
    run_name="v1",
    seed=42,
    # Keep the raw 'image'/'label' columns: `transform` reads them on the fly.
    remove_unused_columns=False,
    # Order matters: compute_metrics unpacks label_ids as (labels, sp_labels).
    label_names=['labels', 'sp_labels'],

    # Edit 4 =============================

    # Fast, low-disk option (evaluate every epoch, never save checkpoints):
    # evaluation_strategy="epoch", save_strategy="no"

    # Evaluate and checkpoint every epoch; restore the epoch with the best validation macro-F1 at the end:
    evaluation_strategy="epoch", save_strategy="epoch", load_best_model_at_end=True, metric_for_best_model='eval_f1_macro', save_total_limit=1,

    # Evaluate and checkpoint every `eval_steps` steps (the number can be changed):
    #evaluation_strategy="steps", eval_steps=100, save_strategy="steps", load_best_model_at_end=True, metric_for_best_model='eval_f1_macro', save_total_limit=1,

    # ====================================

)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=validation,
    compute_metrics=compute_metrics,
    tokenizer=processor,
    data_collator=collate_fn,
)

In [None]:
try:
    trainer.train()
finally:
    # Write the model to output_dir even if training is interrupted or raises.
    trainer.save_model()

Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,F1 Micro,F1 Weighted,Auroc,Em,Hamming Loss,Species Acc,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,Runtime,Samples Per Second,Steps Per Second
1,5.5065,4.812342,0.968785,0.166072,0.391304,0.313418,0.687643,0.391304,0.031215,0.668478,0.0,0.0,0.0,0.583333,0.285714,0.0,0.0,0.0,0.0,0.916667,0.375,0.8,0.454545,0.0,0.0,0.0,0.964912,0.230769,0.0,0.0,0.5,0.0,0.0,0.0,0.0,0.859649,0.0,0.0,0.0,0.193548,0.0,0.0,0.0,0.0,0.0,0.0,0.312684,0.0,0.0,5.468,100.95,3.292
2,4.2622,3.841035,0.980491,0.407549,0.619565,0.563082,0.804777,0.619565,0.019509,0.896739,0.0,0.0,0.0,0.641509,0.736842,0.421053,0.823529,0.4,0.0,1.0,0.7,0.827586,0.564103,0.758621,0.333333,0.0,0.964912,0.901961,0.0,0.181818,0.756757,0.705882,0.428571,0.0,0.0,0.93578,0.875,0.533333,0.4,0.509091,0.0,0.24,0.0,0.347826,0.071429,0.190476,0.527363,0.0,0.117647,5.3259,103.645,3.38
3,3.4235,3.160238,0.988852,0.560143,0.782609,0.740647,0.888444,0.782609,0.011148,0.960145,0.0,0.0,0.0,0.68,0.736842,0.965517,0.947368,0.933333,0.0,1.0,0.727273,0.888889,0.628571,0.928571,0.571429,0.0,0.990826,0.92,0.0,0.181818,0.8,0.952381,0.75,0.0,0.0,0.971429,1.0,0.869565,0.4,0.727273,0.0,0.666667,0.0,0.709677,0.604651,0.538462,0.821705,0.0,0.933333,5.4092,102.049,3.328
4,2.7945,2.640204,0.99201,0.627935,0.844203,0.805975,0.920051,0.844203,0.00799,0.963768,0.0,0.0,0.0,0.666667,0.8,1.0,0.888889,1.0,0.0,1.0,0.833333,0.96,0.740741,0.933333,0.947368,0.333333,0.990826,0.92,0.0,0.666667,0.848485,0.777778,0.823529,0.0,0.0,0.990291,1.0,0.909091,0.545455,0.777778,0.0,0.765957,0.0,0.829268,0.8,0.785714,0.954955,0.0,1.0,5.6938,96.948,3.161
5,2.3108,2.22048,0.993032,0.693471,0.86413,0.831508,0.930277,0.86413,0.006968,0.974638,0.285714,0.0,0.0,0.708333,0.956522,1.0,1.0,1.0,0.0,1.0,0.8,1.0,0.88,0.933333,1.0,0.571429,0.990826,0.92,0.0,0.75,0.933333,0.823529,0.888889,0.0,0.857143,0.990291,1.0,0.952381,0.888889,0.857143,0.0,0.790698,0.2,0.829268,0.717949,0.636364,0.883333,0.0,1.0,5.3753,102.692,3.349
6,1.9288,1.893468,0.994519,0.738459,0.893116,0.866572,0.945152,0.893116,0.005481,0.985507,0.285714,0.0,0.0,0.708333,0.956522,0.967742,0.947368,1.0,0.0,1.0,0.8,1.0,0.869565,0.933333,1.0,0.888889,0.990826,0.901961,0.0,0.947368,0.965517,1.0,0.9,0.0,1.0,1.0,1.0,0.952381,1.0,0.875,0.0,0.809524,0.714286,0.85,0.780488,0.782609,0.972477,0.0,1.0,5.4035,102.156,3.331
7,1.6326,1.646674,0.995262,0.766123,0.907609,0.887769,0.952589,0.907609,0.004738,0.981884,0.5,0.25,0.0,0.73913,0.956522,0.967742,1.0,1.0,0.0,1.0,0.8,1.0,0.916667,0.965517,0.952381,0.888889,0.990826,0.92,0.0,0.888889,0.965517,1.0,0.842105,0.0,1.0,1.0,1.0,0.952381,1.0,0.976744,0.333333,0.782609,0.714286,0.871795,0.842105,0.888889,0.972477,0.0,1.0,5.4236,101.778,3.319
8,1.3874,1.444351,0.996005,0.798272,0.922101,0.908032,0.960026,0.922101,0.003995,0.978261,0.8,0.727273,0.0,0.809524,0.956522,0.967742,0.947368,1.0,0.0,1.0,0.8,1.0,0.956522,0.965517,1.0,1.0,0.990826,0.92,0.0,0.888889,0.965517,0.888889,0.888889,0.0,1.0,1.0,1.0,0.952381,1.0,1.0,0.571429,0.791667,0.666667,0.878049,0.888889,0.928571,0.981481,0.0,1.0,5.3993,102.235,3.334
9,1.1818,1.255512,0.99712,0.861248,0.943841,0.932945,0.971181,0.943841,0.00288,0.985507,0.8,1.0,0.0,0.894737,0.956522,1.0,1.0,1.0,0.0,1.0,0.8,1.0,0.956522,0.965517,1.0,1.0,0.990826,0.938776,0.4,0.947368,1.0,1.0,0.9,1.0,1.0,1.0,1.0,0.952381,1.0,0.976744,0.666667,0.837209,0.888889,0.923077,0.888889,0.923077,0.981481,0.0,1.0,5.5415,99.612,3.248
10,0.9913,1.071746,0.997677,0.887883,0.95471,0.947443,0.976759,0.95471,0.002323,0.985507,0.833333,1.0,0.0,0.944444,0.956522,1.0,1.0,1.0,0.0,1.0,0.8,1.0,0.956522,0.965517,1.0,1.0,0.990826,0.958333,0.666667,1.0,1.0,0.947368,0.9,1.0,1.0,1.0,1.0,0.952381,1.0,1.0,0.777778,0.857143,0.888889,0.972973,0.914286,0.962963,0.981481,0.4,1.0,5.4117,102.002,3.326


Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,F1 Micro,F1 Weighted,Auroc,Em,Hamming Loss,Species Acc,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,Runtime,Samples Per Second,Steps Per Second
1,5.5065,4.812342,0.968785,0.166072,0.391304,0.313418,0.687643,0.391304,0.031215,0.668478,0.0,0.0,0.0,0.583333,0.285714,0.0,0.0,0.0,0.0,0.916667,0.375,0.8,0.454545,0.0,0.0,0.0,0.964912,0.230769,0.0,0.0,0.5,0.0,0.0,0.0,0.0,0.859649,0.0,0.0,0.0,0.193548,0.0,0.0,0.0,0.0,0.0,0.0,0.312684,0.0,0.0,5.468,100.95,3.292
2,4.2622,3.841035,0.980491,0.407549,0.619565,0.563082,0.804777,0.619565,0.019509,0.896739,0.0,0.0,0.0,0.641509,0.736842,0.421053,0.823529,0.4,0.0,1.0,0.7,0.827586,0.564103,0.758621,0.333333,0.0,0.964912,0.901961,0.0,0.181818,0.756757,0.705882,0.428571,0.0,0.0,0.93578,0.875,0.533333,0.4,0.509091,0.0,0.24,0.0,0.347826,0.071429,0.190476,0.527363,0.0,0.117647,5.3259,103.645,3.38
3,3.4235,3.160238,0.988852,0.560143,0.782609,0.740647,0.888444,0.782609,0.011148,0.960145,0.0,0.0,0.0,0.68,0.736842,0.965517,0.947368,0.933333,0.0,1.0,0.727273,0.888889,0.628571,0.928571,0.571429,0.0,0.990826,0.92,0.0,0.181818,0.8,0.952381,0.75,0.0,0.0,0.971429,1.0,0.869565,0.4,0.727273,0.0,0.666667,0.0,0.709677,0.604651,0.538462,0.821705,0.0,0.933333,5.4092,102.049,3.328
4,2.7945,2.640204,0.99201,0.627935,0.844203,0.805975,0.920051,0.844203,0.00799,0.963768,0.0,0.0,0.0,0.666667,0.8,1.0,0.888889,1.0,0.0,1.0,0.833333,0.96,0.740741,0.933333,0.947368,0.333333,0.990826,0.92,0.0,0.666667,0.848485,0.777778,0.823529,0.0,0.0,0.990291,1.0,0.909091,0.545455,0.777778,0.0,0.765957,0.0,0.829268,0.8,0.785714,0.954955,0.0,1.0,5.6938,96.948,3.161
5,2.3108,2.22048,0.993032,0.693471,0.86413,0.831508,0.930277,0.86413,0.006968,0.974638,0.285714,0.0,0.0,0.708333,0.956522,1.0,1.0,1.0,0.0,1.0,0.8,1.0,0.88,0.933333,1.0,0.571429,0.990826,0.92,0.0,0.75,0.933333,0.823529,0.888889,0.0,0.857143,0.990291,1.0,0.952381,0.888889,0.857143,0.0,0.790698,0.2,0.829268,0.717949,0.636364,0.883333,0.0,1.0,5.3753,102.692,3.349
6,1.9288,1.893468,0.994519,0.738459,0.893116,0.866572,0.945152,0.893116,0.005481,0.985507,0.285714,0.0,0.0,0.708333,0.956522,0.967742,0.947368,1.0,0.0,1.0,0.8,1.0,0.869565,0.933333,1.0,0.888889,0.990826,0.901961,0.0,0.947368,0.965517,1.0,0.9,0.0,1.0,1.0,1.0,0.952381,1.0,0.875,0.0,0.809524,0.714286,0.85,0.780488,0.782609,0.972477,0.0,1.0,5.4035,102.156,3.331
7,1.6326,1.646674,0.995262,0.766123,0.907609,0.887769,0.952589,0.907609,0.004738,0.981884,0.5,0.25,0.0,0.73913,0.956522,0.967742,1.0,1.0,0.0,1.0,0.8,1.0,0.916667,0.965517,0.952381,0.888889,0.990826,0.92,0.0,0.888889,0.965517,1.0,0.842105,0.0,1.0,1.0,1.0,0.952381,1.0,0.976744,0.333333,0.782609,0.714286,0.871795,0.842105,0.888889,0.972477,0.0,1.0,5.4236,101.778,3.319
8,1.3874,1.444351,0.996005,0.798272,0.922101,0.908032,0.960026,0.922101,0.003995,0.978261,0.8,0.727273,0.0,0.809524,0.956522,0.967742,0.947368,1.0,0.0,1.0,0.8,1.0,0.956522,0.965517,1.0,1.0,0.990826,0.92,0.0,0.888889,0.965517,0.888889,0.888889,0.0,1.0,1.0,1.0,0.952381,1.0,1.0,0.571429,0.791667,0.666667,0.878049,0.888889,0.928571,0.981481,0.0,1.0,5.3993,102.235,3.334
9,1.1818,1.255512,0.99712,0.861248,0.943841,0.932945,0.971181,0.943841,0.00288,0.985507,0.8,1.0,0.0,0.894737,0.956522,1.0,1.0,1.0,0.0,1.0,0.8,1.0,0.956522,0.965517,1.0,1.0,0.990826,0.938776,0.4,0.947368,1.0,1.0,0.9,1.0,1.0,1.0,1.0,0.952381,1.0,0.976744,0.666667,0.837209,0.888889,0.923077,0.888889,0.923077,0.981481,0.0,1.0,5.5415,99.612,3.248
10,0.9913,1.071746,0.997677,0.887883,0.95471,0.947443,0.976759,0.95471,0.002323,0.985507,0.833333,1.0,0.0,0.944444,0.956522,1.0,1.0,1.0,0.0,1.0,0.8,1.0,0.956522,0.965517,1.0,1.0,0.990826,0.958333,0.666667,1.0,1.0,0.947368,0.9,1.0,1.0,1.0,1.0,0.952381,1.0,1.0,0.777778,0.857143,0.888889,0.972973,0.914286,0.962963,0.981481,0.4,1.0,5.4117,102.002,3.326


# 모델 평가

In [None]:
# Test-set evaluation. trainer.save_model() afterwards writes to save_dir (Google Drive).
training_args = TrainingArguments(
    output_dir=save_dir,
    do_train=True,
    do_eval=True,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    logging_strategy="epoch",
    logging_dir='./logs',
    learning_rate=2e-5,
    run_name="v1",
    seed=42,
    remove_unused_columns=False,
    # Same label order as in training: compute_metrics unpacks label_ids as
    # (labels, sp_labels). If omitted, the Trainer infers label names from the
    # forward() parameters containing "label" -> ('sp_labels', 'labels'), which
    # would silently swap the disease and species targets in the metrics.
    label_names=['labels', 'sp_labels'],
)

# Evaluate the model trained above. It is the custom two-head `Model`, so it cannot
# be reloaded with AutoModelForImageClassification.from_pretrained (that class is
# only imported in a commented-out cell). With load_best_model_at_end=True the
# training Trainer has already restored the best checkpoint into `model`.
model = model.to(device)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=test,
    compute_metrics=compute_metrics,
    tokenizer=processor,
    data_collator=collate_fn,
)

In [None]:
try:
    print(trainer.evaluate())
finally:
    # Persist the evaluated model to the trainer's output_dir even if evaluation fails.
    trainer.save_model()

In [None]:
import pandas as pd
from torch import tensor

# NOTE(review): hard-coded metrics pasted from an earlier test-set evaluation. The keys
# (eval_f1_score, eval_recall, eval_precision, eval_0..eval_37) do not match the metric
# names produced by compute_metrics above, so these numbers come from an older setup.
recorded_metrics = {
    'eval_loss': 0.007293897680938244,
    'eval_accuracy': 0.9979521036148071,
    'eval_f1_score': 0.9610897898674011,
    'eval_auroc': 0.9715272784233093,
    'eval_recall': 0.9610897898674011,
    'eval_precision': 0.9610897898674011,
    'eval_0': 0.9348914623260498,
    'eval_1': 0.9439002871513367,
    'eval_2': 0.978723406791687,
    'eval_3': 0.9713168144226074,
    'eval_4': 0.9909090995788574,
    'eval_5': 0.9720101952552795,
    'eval_6': 0.968137264251709,
    'eval_7': 0.8056460618972778,
    'eval_8': 0.9977924823760986,
    'eval_9': 0.9023883938789368,
    'eval_10': 0.998643159866333,
    'eval_11': 0.9737556576728821,
    'eval_12': 0.9800676703453064,
    'eval_13': 0.9874030947685242,
    'eval_14': 0.966625452041626,
    'eval_15': 0.9981871843338013,
    'eval_16': 0.9843288660049438,
    'eval_17': 0.9479768872261047,
    'eval_18': 0.9586426019668579,
    'eval_19': 0.9671787619590759,
    'eval_20': 0.9717513918876648,
    'eval_21': 0.9022801518440247,
    'eval_22': 0.8054607510566711,
    'eval_23': 0.9841726422309875,
    'eval_24': 0.9904054403305054,
    'eval_25': 0.9991401433944702,
    'eval_26': 0.9816778898239136,
    'eval_27': 0.9838337302207947,
    'eval_28': 0.9446237683296204,
    'eval_29': 0.7343904972076416,
    'eval_30': 0.8848314881324768,
    'eval_31': 0.9044075608253479,
    'eval_32': 0.9190462231636047,
    'eval_33': 0.9388272762298584,
    'eval_34': 0.864098846912384,
    'eval_35': 0.9904612302780151,
    'eval_36': 0.9533527493476868,
    'eval_37': 0.9835957884788513,
    'eval_runtime': 840.3154,
    'eval_samples_per_second': 61.425,
    'eval_steps_per_second': 1.92,
}
# One-column table: metric name -> value.
a = pd.Series(recorded_metrics).to_frame()
a