In [1]:
# Code referenced from https://huggingface.co/docs/transformers/tasks/image_classification
import evaluate
import numpy as np
import torch
from datasets import ClassLabel, Sequence, get_dataset_split_names, load_dataset
from PIL import Image
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomResizedCrop,
    Resize,
    ToTensor,
)
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    DefaultDataCollator,
    Trainer,
    TrainingArguments,
    ViTForImageClassification,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load dataset

# Fixed seed so the 80/20 train/test partition is the same on every
# Restart & Run All; without it each run is evaluated on a different held-out set.
SPLIT_SEED = 42

# The hub dataset is loaded via its "train" split, so the held-out set is carved out here.
dataset = load_dataset("ashraq/fashion-product-images-small", split="train")
dataset = dataset.train_test_split(test_size=0.2, seed=SPLIT_SEED)

In [3]:
# Show the raw schema before trimming it.
print(dataset['train'].features)
# Keep only the image and subCategory; every other metadata column is dropped.
dataset = dataset.remove_columns(["id", "gender", "masterCategory", "articleType", "baseColour", "season", "year", "usage", "productDisplayName"])
# subCategory is the classification target; the later cells read it as "label".
dataset = dataset.rename_column("subCategory", "label")
# Turn the string categories into a ClassLabel feature (integer class ids).
dataset = dataset.class_encode_column("label")
print(dataset['train'].features)
# NOTE: a stray `dataset.remove_columns("label")` used to sit here. Its return
# value was discarded, so it changed nothing, and assigning it would have
# deleted the training target. It has been removed.
dataset["train"][0]

{'id': Value(dtype='int64', id=None), 'gender': Value(dtype='string', id=None), 'masterCategory': Value(dtype='string', id=None), 'subCategory': Value(dtype='string', id=None), 'articleType': Value(dtype='string', id=None), 'baseColour': Value(dtype='string', id=None), 'season': Value(dtype='string', id=None), 'year': Value(dtype='float64', id=None), 'usage': Value(dtype='string', id=None), 'productDisplayName': Value(dtype='string', id=None), 'image': Image(decode=True, id=None)}


Flattening the indices:   0%|          | 0/35257 [00:00<?, ? examples/s]

Flattening the indices: 100%|██████████| 35257/35257 [00:02<00:00, 16746.01 examples/s]
Casting to class labels: 100%|██████████| 35257/35257 [00:01<00:00, 17892.49 examples/s]
Flattening the indices: 100%|██████████| 8815/8815 [00:00<00:00, 18240.43 examples/s]
Casting to class labels: 100%|██████████| 8815/8815 [00:00<00:00, 17805.88 examples/s]


{'label': ClassLabel(names=['Accessories', 'Apparel Set', 'Bags', 'Bath and Body', 'Beauty Accessories', 'Belts', 'Bottomwear', 'Cufflinks', 'Dress', 'Eyes', 'Eyewear', 'Flip Flops', 'Fragrance', 'Free Gifts', 'Gloves', 'Hair', 'Headwear', 'Home Furnishing', 'Innerwear', 'Jewellery', 'Lips', 'Loungewear and Nightwear', 'Makeup', 'Mufflers', 'Nails', 'Perfumes', 'Sandal', 'Saree', 'Scarves', 'Shoe Accessories', 'Shoes', 'Skin', 'Skin Care', 'Socks', 'Sports Accessories', 'Sports Equipment', 'Stoles', 'Ties', 'Topwear', 'Umbrellas', 'Vouchers', 'Wallets', 'Watches', 'Water Bottle', 'Wristbands'], id=None), 'image': Image(decode=True, id=None)}


{'label': 2,
 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=60x80>}

In [4]:
# Class names in class-id order, taken from the encoded "label" feature.
labels = dataset["train"].features["label"].names
# Two lookup tables passed to the model config: stringified id -> name and
# name -> stringified id.
id2label = {str(index): name for index, name in enumerate(labels)}
label2id = {name: key for key, name in id2label.items()}

In [5]:
# Load image processor
# `checkpoint` identifies the pretrained weights; the same identifier is reused
# further down when the classification model is loaded.
checkpoint = "google/vit-base-patch16-224-in21k"
# The processor provides the values the preprocessing cell reads:
# image_mean, image_std and size.
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

In [6]:

# Preprocess dataset

# Target size comes from the processor config: a single int when it defines
# "shortest_edge", otherwise a (height, width) pair.
if "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
else:
    size = (image_processor.size["height"], image_processor.size["width"])

# Normalise with the processor's per-channel statistics.
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Pipeline used by `transforms`: random crop/resize -> tensor -> normalise.
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])



In [7]:
def transforms(dataset):
    """Replace a batch's images with model-ready tensors.

    ``dataset`` is a mutable mapping whose ``"image"`` entry is an iterable of
    image objects exposing ``.convert`` (the name shadows the notebook-level
    ``dataset``; here it is the batch being transformed). Each image is
    converted to RGB and passed through ``_transforms``. The results are
    stored under ``"pixel_values"`` and the ``"image"`` entry is deleted.

    The mapping is modified in place and also returned.
    """
    pixel_values = []
    for picture in dataset["image"]:
        pixel_values.append(_transforms(picture.convert("RGB")))
    dataset["pixel_values"] = pixel_values
    del dataset["image"]
    return dataset

In [8]:
# Evaluation must not use the random-crop augmentation: it makes the reported
# accuracy vary from run to run and scores the model on random image crops.
# The held-out split therefore gets a deterministic resize + centre crop to the
# same `size`, followed by the same tensor conversion and normalisation.
_eval_transforms = Compose([Resize(size), CenterCrop(size), ToTensor(), normalize])


def eval_transforms(batch):
    """Deterministic counterpart of `transforms` for the evaluation split.

    Converts every image in ``batch["image"]`` to RGB, runs it through
    ``_eval_transforms``, stores the results under ``"pixel_values"`` and
    deletes the ``"image"`` entry. The batch is modified in place and returned.
    """
    batch["pixel_values"] = [
        _eval_transforms(img.convert("RGB")) for img in batch["image"]
    ]
    del batch["image"]
    return batch


# Attach the transforms per split: augmentation for training only.
dataset["train"] = dataset["train"].with_transform(transforms)
dataset["test"] = dataset["test"].with_transform(eval_transforms)

In [9]:
# Collator instance handed to the Trainer below via `data_collator=`.
data_collator = DefaultDataCollator()

In [10]:
# Metric object used by `compute_metrics`.
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Compute accuracy for one evaluation pass.

    ``eval_pred`` unpacks into ``(predictions, labels)``. The index of the
    largest value along axis 1 of the predictions is taken as the predicted
    class id and compared against the labels.

    Returns the result of ``accuracy.compute``.
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=1)
    return accuracy.compute(predictions=predicted_ids, references=references)

In [11]:
# Train
# Load the pretrained checkpoint with a classification head sized for this
# dataset. The loading message recorded below reports that classifier.weight
# and classifier.bias are newly initialised, i.e. the head is trained from scratch.
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),  # one output per class name in `labels`
    id2label=id2label,  # stringified id -> class name
    label2id=label2id,  # class name -> stringified id
    ignore_mismatched_sizes=True
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
# Training configuration passed to the Trainer.
training_args = TrainingArguments(
    output_dir="clothing_category_model",
    # Must stay False: the model input ("pixel_values") is created on the fly by
    # the dataset transform from the "image" column, so that column has to be kept.
    remove_unused_columns=False,
    # Evaluate and checkpoint on the same schedule (once per epoch).
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    # 16 x 4 = 64 examples per optimizer step; with the 35257 training examples
    # this matches the 551 steps shown in the recorded progress bar.
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    warmup_ratio=0.1,
    logging_steps=10,  # the recorded run logs loss every 10 steps
    # Best-checkpoint selection uses the "accuracy" value from compute_metrics.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)

In [13]:
# Wire together the model, configuration, data splits, collator and metric.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],  # the 20% split created by train_test_split
    # The image processor is passed through the `tokenizer` argument.
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)

In [14]:
# Start fine-tuning.
# NOTE(review): the recorded run below ran at roughly 51 s per step and was
# stopped by a KeyboardInterrupt at step 102 of 551, before the single epoch finished.
trainer.train()

  2%|▏         | 10/551 [08:31<7:43:36, 51.42s/it]

{'loss': 3.7589, 'learning_rate': 8.92857142857143e-06, 'epoch': 0.02}


  4%|▎         | 20/551 [17:00<7:30:27, 50.90s/it]

{'loss': 3.6653, 'learning_rate': 1.785714285714286e-05, 'epoch': 0.04}


  5%|▌         | 30/551 [25:24<7:19:52, 50.66s/it]

{'loss': 3.4073, 'learning_rate': 2.6785714285714288e-05, 'epoch': 0.05}


  7%|▋         | 40/551 [33:52<7:14:30, 51.02s/it]

{'loss': 3.0555, 'learning_rate': 3.571428571428572e-05, 'epoch': 0.07}


  9%|▉         | 50/551 [42:30<7:13:31, 51.92s/it]

{'loss': 2.7588, 'learning_rate': 4.464285714285715e-05, 'epoch': 0.09}


 11%|█         | 60/551 [51:08<7:03:17, 51.73s/it]

{'loss': 2.4491, 'learning_rate': 4.9595959595959594e-05, 'epoch': 0.11}


 13%|█▎        | 70/551 [59:55<6:59:08, 52.28s/it]

{'loss': 2.2463, 'learning_rate': 4.858585858585859e-05, 'epoch': 0.13}


 15%|█▍        | 80/551 [1:08:33<6:41:43, 51.17s/it]

{'loss': 2.1231, 'learning_rate': 4.7575757575757576e-05, 'epoch': 0.15}


 16%|█▋        | 90/551 [1:17:03<6:31:15, 50.92s/it]

{'loss': 1.8197, 'learning_rate': 4.656565656565657e-05, 'epoch': 0.16}


 18%|█▊        | 100/551 [1:25:34<6:28:17, 51.66s/it]

{'loss': 1.7809, 'learning_rate': 4.555555555555556e-05, 'epoch': 0.18}


 19%|█▊        | 102/551 [1:27:17<6:26:51, 51.70s/it]

KeyboardInterrupt: 