In [1]:
!pip install tensorboard
!pip install -qqq accelerate==0.28.0
!pip install -qqq transformers==4.48.3
!pip install -qqq datasets==3.6.0 # huggingface's lib.

!pip install -U accelerate

import os

gdown_id = "1i6o6daWuQn59S6Q123ZLyTDh2unZxFgs"
output_zip_file = "train.zip"
extract_dir = "data"
train_data_labels_path = os.path.join(extract_dir, "train", "train_labels.csv") # 파일 경로

if not os.path.exists(output_zip_file):
    print(f"{output_zip_file} 파일이 존재하지 않습니다. 다운로드합니다.")
    !gdown --id {gdown_id} --output {output_zip_file}
else:
    print(f"{output_zip_file} 파일이 이미 존재합니다. 다운로드를 건너뜁니다.")

if not os.path.exists(extract_dir) or not os.path.exists(train_data_labels_path):
    print(f"{extract_dir} 디렉토리 또는 필요한 파일({train_data_labels_path})이 존재하지 않습니다. 압축 해제합니다.")
    if not os.path.exists(extract_dir):
        os.makedirs(extract_dir)
    !unzip -q {output_zip_file} -d {extract_dir}
else:
    print(f"{extract_dir} 디렉토리와 필요한 파일이 이미 존재합니다. 압축 해제를 건너뜁니다.")

train_data_labels = train_data_labels_path
train_image_path = os.path.join(extract_dir, "train", "images") + "/"

zsh:1: no matches found: https://docs.googㅁle.com/uc?id=1e7P8XjrkPSKzIrmjt-zi5ndKxpuG07X8
mkdir: data: File exists
unzip:  cannot find or open train.zip, train.zip.zip or train.zip.ZIP.


In [1]:
train_data_labels = "./data/train/train_labels.csv"
train_image_path = "./data/train/images/"

model_output_path = "./output"

fruit_labels = ["apple", "asian pear", "banana", "cherry", "grape", "pineapple"]
style_labels = ["pencil color", "oil painting", "water color"]

In [2]:
from transformers import AutoImageProcessor
from transformers import Trainer, TrainingArguments
from transformers import ViTModel, PreTrainedModel, ViTConfig
from transformers import AutoConfig
import transformers as tf

from datasets import Dataset, load_dataset
from datasets.features import ClassLabel, Image


import torch.nn as nn
import torch


from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from sklearn.model_selection import train_test_split


import matplotlib.pyplot as plt


import numpy as np


import pandas as pd


import json

In [3]:
print("PyTorch 버전:", torch.__version__)
print("MPS 사용 가능 여부:", torch.backends.mps.is_available())
print("MPS 준비 완료 여부:", torch.backends.mps.is_built())

PyTorch 버전: 2.6.0
MPS 사용 가능 여부: True
MPS 준비 완료 여부: True


In [4]:
def load_data(csv_path):
    df = pd.read_csv(csv_path)

    df['image'] = train_image_path + df['file_name']

    ds = Dataset.from_pandas(df)
    ds = ds.cast_column("image", Image())

    return ds

df = pd.read_csv(train_data_labels)

dataset = load_data(train_data_labels)

In [5]:
print(df.head())
print()
print(dataset)

  file_name  style  fruit
0     0.jpg      0      0
1     1.jpg      0      0
2     2.jpg      0      0
3     3.jpg      0      0
4     4.jpg      0      0

Dataset({
    features: ['file_name', 'style', 'fruit', 'image'],
    num_rows: 7200
})


In [6]:
feature_extractor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

def preprocess_images(examples):
    images = [feature_extractor(image.convert("RGB")) for image in examples["image"]]
    examples['pixel_values'] = [image['pixel_values'][0] for image in images]
    return examples

dataset = dataset.map(preprocess_images, batched=True)
dataset.set_format(type='torch', columns=['image', 'pixel_values', 'fruit', 'style'])

Map:   0%|          | 0/7200 [00:00<?, ? examples/s]

In [15]:
train_val_dataset = dataset.train_test_split(test_size=0.2)

train_dataset = train_val_dataset["train"]
val_dataset = train_val_dataset["test"]

print(train_dataset.features)

{'file_name': Value(dtype='string', id=None), 'style': Value(dtype='int64', id=None), 'fruit': Value(dtype='int64', id=None), 'image': Image(mode=None, decode=True, id=None), 'pixel_values': Sequence(feature=Sequence(feature=Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None), length=-1, id=None), length=-1, id=None)}


In [54]:
from transformers.modeling_outputs import SequenceClassifierOutput

class MultiTaskViTConfig(ViTConfig):
    def __init__(self, **kwargs):
        self.num_fruit = kwargs.pop("num_fruit", 0)
        self.num_style = kwargs.pop("num_style", 0)
        super().__init__(**kwargs)

class MultiTaskViT(PreTrainedModel):
    config_class = MultiTaskViTConfig

    def __init__(self, config):
        super().__init__(config)
        self.vit = ViTModel(config)

        # 백본의 출력 feature 크기
        hidden_size = self.vit.config.hidden_size

        # 두 개의 독립 분류 헤드
        self.fruit_classifier = nn.Linear(hidden_size, config.num_fruit)
        self.style_classifier = nn.Linear(hidden_size, config.num_style)

    def forward(self, pixel_values, fruit=None, style=None, **kwargs):
        # ViT forward
        outputs = self.vit(pixel_values=pixel_values)
        pooled_output = outputs.pooler_output  # [batch_size, hidden_size]

        # 각 헤드를 통과시켜 logits 생성
        fruit_logits = self.fruit_classifier(pooled_output)
        style_logits = self.style_classifier(pooled_output)

        # optional: loss 계산
        loss = None
        if fruit is not None and style is not None:
            loss_fn = nn.CrossEntropyLoss()
            fruit_loss = loss_fn(fruit_logits, fruit)
            style_loss = loss_fn(style_logits, style)
            loss = fruit_loss + style_loss

        logits = torch.cat([fruit_logits, style_logits], dim=1)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits
        )

In [55]:
train_epoch = 3
learning_rate = 2e-4
batch_size = 16
weight_decay = 0.01

In [56]:
num_fruit = len(set(int(x) for x in dataset["fruit"]))
num_style = len(set(int(x) for x in dataset["style"]))

print("num_fruit: ", num_fruit)
print("num_style: ", num_style)

config = MultiTaskViTConfig.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_fruit=num_fruit,
    num_style=num_style
)

model = MultiTaskViT.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    config=config
)

num_fruit:  6
num_style:  3


Some weights of MultiTaskViT were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['fruit_classifier.bias', 'fruit_classifier.weight', 'style_classifier.bias', 'style_classifier.weight', 'vit.embeddings.cls_token', 'vit.embeddings.patch_embeddings.projection.bias', 'vit.embeddings.patch_embeddings.projection.weight', 'vit.embeddings.position_embeddings', 'vit.encoder.layer.0.attention.attention.key.bias', 'vit.encoder.layer.0.attention.attention.key.weight', 'vit.encoder.layer.0.attention.attention.query.bias', 'vit.encoder.layer.0.attention.attention.query.weight', 'vit.encoder.layer.0.attention.attention.value.bias', 'vit.encoder.layer.0.attention.attention.value.weight', 'vit.encoder.layer.0.attention.output.dense.bias', 'vit.encoder.layer.0.attention.output.dense.weight', 'vit.encoder.layer.0.intermediate.dense.bias', 'vit.encoder.layer.0.intermediate.dense.weight', 'vit.encoder.layer.0.layernorm_after.bias', 'vit.encoder.la

In [57]:
# Define training arguments
training_args = TrainingArguments(
    output_dir="./vit_fruit_classification",
    eval_strategy="epoch",
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=train_epoch,
    weight_decay=weight_decay,
    logging_dir='./logs',
    report_to=["tensorboard"],
    label_names=["fruit", "style"],
)

# Function to compute metrics
def compute_metrics(eval_pred):
    logits = eval_pred.predictions

    fruit_logits = logits[:, :num_fruit]
    style_logits = logits[:, num_fruit:]

    fruit_labels, style_labels = eval_pred.label_ids

    fruit_predictions = np.argmax(fruit_logits, axis=-1)
    style_predictions = np.argmax(style_logits, axis=-1)

    fruit_precision, fruit_recall, fruit_f1, _ = precision_recall_fscore_support(fruit_labels, fruit_predictions,
                                                               average='weighted')
    fruit_acc = accuracy_score(fruit_labels, fruit_predictions)

    style_precision, style_recall, style_f1, _ = precision_recall_fscore_support(style_labels, style_predictions,
                                                               average='weighted')
    style_acc = accuracy_score(style_labels, style_predictions)

    return {"fruit_acc": fruit_acc, "fruit_precision": fruit_precision, "fruit_recall": fruit_recall, "fruit_f1": fruit_f1,
            "style_acc": style_acc, "style_precision": style_precision, "style_recall": style_recall, "style_f1": style_f1}

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

In [58]:
trainer.train()

Could not estimate the number of tokens of the input, floating-point operations will not be computed


Epoch,Training Loss,Validation Loss,Fruit Acc,Fruit Precision,Fruit Recall,Fruit F1,Style Acc,Style Precision,Style Recall,Style F1
1,No log,,0.161111,0.025957,0.161111,0.04471,0.328472,0.107894,0.328472,0.162433


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


KeyboardInterrupt: 

In [19]:
torch.save(model.state_dict(), "./vit_fruit_cls")
feature_extractor.save_pretrained("./vit_fruit_cls")

RuntimeError: File ./vit_fruit_cls cannot be opened.

In [15]:
from transformers import pipeline

# Load the pipeline with the model and feature extractor
image_classifier = pipeline('image-classification',
                            model='./vit_fruit_cls',
                            feature_extractor='./vit_fruit_cls')

HFValidationError: Repo id must use alphanumeric chars or '-', '_', '.', '--' and '..' are forbidden, '-' and '.' cannot start or end the name, max length is 96: './vit_fruit_cls'.