In [6]:
from datasets import load_dataset
food = load_dataset("food101", split="train[:5000]")

In [7]:
food = food.train_test_split(test_size=0.2)

In [8]:
food["train"][0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
 'label': 6}

In [9]:
labels = food["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label
    

In [10]:
from transformers import AutoImageProcessor

checkpoint = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

In [11]:
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])

In [12]:
def transforms(examples):
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples
food = food.with_transform(transforms)

평가

In [13]:
from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()

In [14]:
import evaluate

accuracy = evaluate.load("accuracy")

In [15]:
import numpy as np


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

훈련

In [16]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [19]:
training_args = TrainingArguments(
    output_dir="food101_model",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=food["train"],
    eval_dataset=food["test"],
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)

trainer.train()

  0%|          | 0/186 [51:01<?, ?it/s]
  5%|▌         | 10/186 [09:14<2:42:56, 55.55s/it]

{'loss': 4.5869, 'learning_rate': 2.6315789473684212e-05, 'epoch': 0.16}


 11%|█         | 20/186 [18:21<2:30:49, 54.52s/it]

{'loss': 4.2936, 'learning_rate': 4.970059880239521e-05, 'epoch': 0.32}


 16%|█▌        | 30/186 [27:36<2:21:28, 54.41s/it]

{'loss': 3.7253, 'learning_rate': 4.670658682634731e-05, 'epoch': 0.48}


 22%|██▏       | 40/186 [36:28<2:09:41, 53.30s/it]

{'loss': 3.3021, 'learning_rate': 4.3712574850299406e-05, 'epoch': 0.64}


 27%|██▋       | 50/186 [45:22<2:00:44, 53.27s/it]

{'loss': 2.9887, 'learning_rate': 4.07185628742515e-05, 'epoch': 0.8}


 32%|███▏      | 60/186 [54:12<1:51:30, 53.10s/it]

{'loss': 2.7254, 'learning_rate': 3.7724550898203595e-05, 'epoch': 0.96}


                                                  
 33%|███▎      | 62/186 [1:01:11<1:49:21, 52.91s/it]

{'eval_loss': 2.5522100925445557, 'eval_accuracy': 0.805, 'eval_runtime': 287.1393, 'eval_samples_per_second': 3.483, 'eval_steps_per_second': 0.219, 'epoch': 0.99}


 38%|███▊      | 70/186 [1:07:52<1:56:50, 60.44s/it] 

{'loss': 2.4883, 'learning_rate': 3.473053892215569e-05, 'epoch': 1.12}


 43%|████▎     | 80/186 [1:16:47<1:35:02, 53.80s/it]

{'loss': 2.2694, 'learning_rate': 3.1736526946107784e-05, 'epoch': 1.28}


 48%|████▊     | 90/186 [1:25:40<1:25:11, 53.25s/it]

{'loss': 2.1444, 'learning_rate': 2.874251497005988e-05, 'epoch': 1.44}


 54%|█████▍    | 100/186 [1:34:39<1:17:01, 53.73s/it]

{'loss': 2.0137, 'learning_rate': 2.5748502994011976e-05, 'epoch': 1.6}


 59%|█████▉    | 110/186 [1:43:33<1:07:50, 53.56s/it]

{'loss': 1.9457, 'learning_rate': 2.275449101796407e-05, 'epoch': 1.76}


 65%|██████▍   | 120/186 [1:52:31<59:05, 53.71s/it]  

{'loss': 1.799, 'learning_rate': 1.9760479041916168e-05, 'epoch': 1.92}


                                                   
 67%|██████▋   | 125/186 [2:02:05<55:12, 54.31s/it]

{'eval_loss': 1.779510498046875, 'eval_accuracy': 0.853, 'eval_runtime': 303.8578, 'eval_samples_per_second': 3.291, 'eval_steps_per_second': 0.207, 'epoch': 2.0}


 70%|██████▉   | 130/186 [2:06:47<1:12:33, 77.73s/it] 

{'loss': 1.752, 'learning_rate': 1.6766467065868263e-05, 'epoch': 2.08}


 75%|███████▌  | 140/186 [2:16:04<43:29, 56.72s/it]  

{'loss': 1.6756, 'learning_rate': 1.377245508982036e-05, 'epoch': 2.24}


 81%|████████  | 150/186 [2:25:22<33:36, 56.03s/it]

{'loss': 1.6643, 'learning_rate': 1.0778443113772455e-05, 'epoch': 2.4}


 86%|████████▌ | 160/186 [2:34:42<24:08, 55.70s/it]

{'loss': 1.6553, 'learning_rate': 7.784431137724551e-06, 'epoch': 2.56}


 91%|█████████▏| 170/186 [2:44:01<14:53, 55.82s/it]

{'loss': 1.6109, 'learning_rate': 4.7904191616766475e-06, 'epoch': 2.72}


 97%|█████████▋| 180/186 [2:53:22<05:40, 56.68s/it]

{'loss': 1.5682, 'learning_rate': 1.7964071856287426e-06, 'epoch': 2.88}


                                                   
100%|██████████| 186/186 [3:03:49<00:00, 55.13s/it]

{'eval_loss': 1.626529335975647, 'eval_accuracy': 0.872, 'eval_runtime': 295.5109, 'eval_samples_per_second': 3.384, 'eval_steps_per_second': 0.213, 'epoch': 2.98}


100%|██████████| 186/186 [3:03:50<00:00, 59.30s/it]

{'train_runtime': 11030.205, 'train_samples_per_second': 1.088, 'train_steps_per_second': 0.017, 'train_loss': 2.429398859700849, 'epoch': 2.98}





TrainOutput(global_step=186, training_loss=2.429398859700849, metrics={'train_runtime': 11030.205, 'train_samples_per_second': 1.088, 'train_steps_per_second': 0.017, 'train_loss': 2.429398859700849, 'epoch': 2.98})

추론

In [20]:
ds = load_dataset("food101", split="validation[:10]")
image = ds["image"][0]

In [21]:
from transformers import pipeline

classifier = pipeline("image-classification", model="my_awesome_food_model")
classifier(image)

OSError: my_awesome_food_model does not appear to have a file named config.json. Checkout 'https://huggingface.co/my_awesome_food_model/None' for available files.

In [None]:
from transformers import AutoImageProcessor
import torch

image_processor = AutoImageProcessor.from_pretrained("my_awesome_food_model")
inputs = image_processor(image, return_tensors="pt")

In [None]:
from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained("my_awesome_food_model")
with torch.no_grad():
    logits = model(**inputs).logits

In [None]:
predicted_label = logits.argmax(-1).item()
model.config.id2label[predicted_label]