In [1]:
cd ..

d:\project\cls_product


In [None]:
from datasets import load_dataset

# Work on a 5,000-image slice of the Food-101 training split to keep runs short.
subset_spec = "train[:5000]"
food = load_dataset("food101", split=subset_spec)

In [None]:
# Hold out 20% of the subset for evaluation. A fixed seed keeps the split
# identical across kernel restarts, so accuracy numbers are comparable between runs.
food_split = food.train_test_split(test_size=0.2, seed=42)

In [None]:
labels = food_split['train'].features['label'].names

# Bidirectional class-name <-> integer-id maps for the model config.
# Values in label2id must be plain ints (not 1-tuples), and id2label must be
# keyed by the index, not by the builtin `id` function.
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for i, label in enumerate(labels)}

assert len(label2id) == len(id2label) == len(labels), "duplicate class names in labels"

In [None]:
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer,
)

# The image processor provides the resize target (`size`) and the normalisation
# statistics (`image_mean`, `image_std`) consumed by the transform cells below.
processor_checkpoint = "facebook/convnext-base-224"
processor = AutoImageProcessor.from_pretrained(processor_checkpoint)

In [None]:
# Target input size reported by the processor; indexed by key in the transform cell.
size = getattr(processor, "size")

In [None]:
from torchvision.transforms import RandomResizedCrop ,Compose ,Resize ,Normalize ,ToTensor

# Image processors describe their input size either as {"shortest_edge": N}
# (ConvNeXt-style) or as {"height": H, "width": W} (ViT-style). Accept both so
# the augmentation keeps working if the checkpoint is swapped.
if size.get("shortest_edge"):
    crop_size = size["shortest_edge"]
else:
    crop_size = (size["height"], size["width"])

# Training-time augmentation: random crop rescaled to the model input size,
# then normalisation with the processor's own statistics.
transform = Compose([
    RandomResizedCrop(crop_size),
    ToTensor(),
    Normalize(mean=processor.image_mean, std=processor.image_std),
])

In [None]:
from PIL import Image

def process_image(samples):
    """On-the-fly batch transform intended for `with_transform`.

    Converts each image in ``samples['image']`` to RGB, runs it through
    `transform`, stores the resulting tensors under ``'pixel_values'`` and
    removes the ``'image'`` entry. The batch dict is modified in place and
    returned.
    """
    pixel_values = []
    for picture in samples['image']:
        rgb_picture = picture.convert('RGB')
        pixel_values.append(transform(rgb_picture))
    samples['pixel_values'] = pixel_values
    del samples['image']
    return samples

In [None]:
from datasets import DatasetDict
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor

# Evaluation must be deterministic. Applying the RandomResizedCrop-based
# `transform` to the test split would make eval accuracy (and therefore the
# "best model" selection) vary between epochs on the same images.
# Use resize + centre crop to the processor's input size instead.
if size.get("shortest_edge"):
    eval_resize = eval_crop = size["shortest_edge"]
else:
    eval_resize = eval_crop = (size["height"], size["width"])

eval_transform = Compose([
    Resize(eval_resize),
    CenterCrop(eval_crop),
    ToTensor(),
    Normalize(mean=processor.image_mean, std=processor.image_std),
])


def process_eval_image(samples):
    """Deterministic batch transform for the evaluation split.

    Same contract as `process_image`: adds ``'pixel_values'`` (tensors from
    `eval_transform` applied to RGB-converted images) and removes ``'image'``.
    """
    samples["pixel_values"] = [eval_transform(img.convert("RGB")) for img in samples["image"]]
    del samples["image"]
    return samples


# Same {"train", "test"} layout the Trainer cell indexes into; only the
# training split gets random augmentation.
food_split_proc = DatasetDict({
    "train": food_split["train"].with_transform(process_image),
    "test": food_split["test"].with_transform(process_eval_image),
})

In [None]:
from transformers import (
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer,
    default_data_collator,
)
import evaluate
import numpy as np

# Batches are assembled by transformers' stock collator.
data_collator = default_data_collator
# Accuracy metric used by compute_metrics.
metric = evaluate.load('accuracy')

In [None]:
def compute_metrics(predictions_and_labels):
    """Trainer callback: turn raw model outputs into an accuracy score.

    Unpacks ``(predictions, labels)`` from the Trainer, takes the arg-max over
    the last axis as the predicted class, and returns whatever the module-level
    `metric` computes from it.
    """
    logits, references = predictions_and_labels
    predicted_ids = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predicted_ids, references=references)

In [None]:
# Fine-tune the same checkpoint the image processor was loaded from, so the
# input size and normalisation statistics used in `transform` match the
# backbone, and the processor passed to the Trainer as `processing_class`
# belongs to this model.
# The checkpoint ships with its own classification head; ignore_mismatched_sizes
# lets from_pretrained replace it with a freshly initialised len(labels)-way head.
model = AutoModelForImageClassification.from_pretrained(
    "facebook/convnext-base-224",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

In [None]:
training_args = TrainingArguments(
    output_dir="./results",
    report_to=[],                      # no experiment-tracking integrations
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    eval_strategy="epoch",
    save_strategy="epoch",             # must match eval_strategy for load_best_model_at_end
    # Without load_best_model_at_end, metric_for_best_model only records which
    # checkpoint was best; the in-memory model stays at the last epoch.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",  # key returned by compute_metrics
    weight_decay=0.01,
    # The on-the-fly transform (with_transform) needs the raw `image` column,
    # which the Trainer would otherwise drop because model.forward() has no such argument.
    remove_unused_columns=False,
    # NOTE(review): 1e-3 is high for fine-tuning a pretrained backbone with AdamW
    # (around 5e-5 is a common starting point) — confirm this is intentional.
    learning_rate=0.001,
    num_train_epochs=7,
    logging_steps=10,
    push_to_hub=False,
)

In [None]:
# Wire the model, the two processed splits, batching, and metric computation together.
train_ds = food_split_proc['train']
eval_ds = food_split_proc['test']

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=data_collator,
    processing_class=processor,
    compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning; the object returned by trainer.train() is displayed as the cell output.
train_result = trainer.train()
train_result