In [1]:
from huggingface_hub import notebook_login
# Show the interactive Hugging Face login widget so later cells can make
# authenticated Hub requests. The token is entered in the widget, never in code.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [8]:
from datasets import load_dataset

# Fetch the fashion-pattern image dataset from the Hugging Face Hub.
# Access requires being logged in first (e.g. via `huggingface-cli login`
# or the notebook login widget).
dataset = load_dataset("yainage90/fashion-pattern-images")

Resolving data files:   0%|          | 0/1900 [00:00<?, ?it/s]

In [9]:
# The "label" feature of the train split carries the ordered class names;
# a class id is the position of its name in this list.
label_feature = dataset["train"].features["label"]
labels = label_feature.names
print("Labels:", labels)

Labels: ['argyle', 'camouflage', 'checked', 'dot', 'floral', 'geometric', 'gradient', 'graphic', 'houndstooth', 'leopard', 'lettering', 'muji', 'paisley', 'snake_skin', 'snow_flake', 'stripe', 'tropical', 'zebra', 'zigzag']


In [11]:
from sklearn.model_selection import train_test_split

# NOTE(review): the sklearn `train_test_split` imported above is not used in
# this cell; the split below is the `datasets.Dataset.train_test_split` method.

# Hold out 20% of the original train split as a test set (fixed seed so the
# partition is reproducible).
split = dataset["train"].train_test_split(test_size=0.2, seed=42)
trainset, testset = split["train"], split["test"]

print("Trainset size:", len(trainset))
print("Testset size:", len(testset))


Trainset size: 1520
Testset size: 380


In [None]:
from datasets import load_dataset
from transformers import ViTForImageClassification, ViTImageProcessor, TrainingArguments, Trainer
import torch
from torchvision import transforms
from sklearn.metrics import accuracy_score, f1_score

# --- Data: load from the Hub and carve out an 80/20 train/test partition ---
dataset = load_dataset("yainage90/fashion-pattern-images")
split = dataset["train"].train_test_split(test_size=0.2, seed=42)
trainset, testset = split["train"], split["test"]

# --- Label bookkeeping ---
# `labels` is the ordered list of class names; the two dicts below are the
# name <-> id lookup tables handed to the model config. `id2label` uses
# string keys, `label2id` maps to integer ids.
labels = trainset.features["label"].names
num_labels = len(labels)
id2label = dict(zip(map(str, range(num_labels)), labels))
label2id = dict(zip(labels, range(num_labels)))

# --- Preprocessing ---
# The ViT image processor turns PIL images into model-ready pixel tensors.
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

# Light augmentation for training images: a random crop covering 80-100% of
# the image, resized to 224x224, followed by a random horizontal flip.
augment = transforms.Compose(
    [
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
    ]
)

def transform_examples(batch):
    """Convert a batch of raw examples into ViT inputs.

    Args:
        batch: Dict of columns with at least "image" (PIL images), "label"
            (class ids) and "__split__" (split marker strings). Only the
            first "__split__" entry is inspected, so the whole batch is
            treated as belonging to that split.

    Returns:
        Dict with "pixel_values" (one tensor per image, as produced by
        `processor`) and "labels" (the input "label" column unchanged).

    NOTE(review): this function is applied through `Dataset.map`, which
    presumably materialises the result once, so the random augmentation is
    likely sampled a single time per image rather than once per epoch.
    Confirm, and consider an on-the-fly transform if per-epoch randomness
    is wanted.
    """
    # Augment only when the batch is marked as coming from the training split.
    is_train = "train" in batch["__split__"][0]

    images = []
    for raw in batch["image"]:
        rgb = raw.convert("RGB")
        images.append(augment(rgb) if is_train else rgb)

    processed = processor(images=images, return_tensors="pt")

    # Split the stacked (batch, ...) tensor back into one tensor per example.
    per_image = list(processed["pixel_values"])
    return {"pixel_values": per_image, "labels": batch["label"]}

def _tag_and_preprocess(ds, split_name):
    """Tag every row of `ds` with `split_name`, then run `transform_examples`.

    The "__split__" column is how `transform_examples` decides whether to
    augment. All original columns (including the tag) are dropped from the
    result, leaving only what `transform_examples` returns.
    """
    tagged = ds.add_column("__split__", [split_name] * len(ds))
    return tagged.map(
        transform_examples,
        batched=True,
        remove_columns=tagged.column_names,
    )


trainset = _tag_and_preprocess(trainset, "train")
testset = _tag_and_preprocess(testset, "test")

# Model
# Load the pretrained ViT checkpoint with a classification head sized for our
# classes. The head (classifier.weight / classifier.bias) is not part of the
# checkpoint and is newly initialised, as the load-time message in this cell's
# output reports, so the model must be fine-tuned before use.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id
)

# Data collator
def collate_fn(batch):
    """Collate preprocessed examples into a single model-ready batch.

    Args:
        batch: Sequence of dicts, each holding "pixel_values" (a nested
            list, NumPy array or tensor; all examples must share one shape)
            and "labels" (an integer class id, plain or 0-dim tensor).

    Returns:
        Dict with "pixel_values" stacked along a new leading batch dimension
        and "labels" as a 1-D int64 tensor.

    Raises:
        RuntimeError: If `batch` is empty or the image shapes differ
            (raised by `torch.stack`).
    """
    # `as_tensor` converts lists exactly like `torch.tensor`, but reuses an
    # existing tensor instead of copy-constructing it, which avoids the
    # copy-construct UserWarning when the dataset already yields tensors.
    pixel_values = torch.stack([torch.as_tensor(x["pixel_values"]) for x in batch])
    # `int(...)` accepts Python ints, NumPy ints and 0-dim tensors alike; the
    # explicit dtype pins the integer label type.
    labels = torch.tensor([int(x["labels"]) for x in batch], dtype=torch.long)
    return {"pixel_values": pixel_values, "labels": labels}

# Metrics
def compute_metrics(eval_pred):
    """Score one evaluation pass.

    Args:
        eval_pred: A (logits, labels) pair; the predicted class for each row
            is the argmax over the last axis of `logits`.

    Returns:
        Dict with "accuracy" and support-weighted "f1".
    """
    logits, labels = eval_pred
    predicted = logits.argmax(-1)

    accuracy = accuracy_score(labels, predicted)
    weighted_f1 = f1_score(labels, predicted, average="weighted")
    return {"accuracy": accuracy, "f1": weighted_f1}

# Training arguments
# Evaluation and checkpointing both happen once per epoch; the training loss
# is logged every 8 optimizer steps.
training_args = TrainingArguments(
    output_dir="./vit-pattern",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,  # Try more epochs for better results
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=8,
    # fp16 mixed precision needs a CUDA device; enabling it unconditionally
    # makes this cell fail on CPU-only machines. Fall back to full precision
    # there so the notebook still runs (slowly) without a GPU.
    # NOTE(review): other accelerators that support fp16 will also fall back
    # to full precision with this guard; widen it if that matters.
    fp16=torch.cuda.is_available(),
    remove_unused_columns=True,
    report_to="none",
    learning_rate=1e-5,  # Lower learning rate for better fine-tuning
)

# Trainer
# Wires the model, hyper-parameters, preprocessed splits, batch collation and
# metric computation together. No tokenizer / processing class is passed:
# `None` is already the default, and the `tokenizer` keyword is deprecated in
# newer transformers releases, so it is omitted.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=trainset,
    eval_dataset=testset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

# Train
# Runs the full fine-tuning loop configured in `training_args`.
trainer.train()

# Evaluate
# Final pass over the Trainer's eval dataset (`testset`); the returned dict
# includes the loss plus the metrics produced by `compute_metrics`.
# NOTE(review): the recorded output of this cell ends in a KeyboardInterrupt
# during training, so this final evaluation was not reached in the saved run.
metrics = trainer.evaluate()
print(metrics)

Resolving data files:   0%|          | 0/1900 [00:00<?, ?it/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/950 [00:00<?, ?it/s]

{'loss': 2.9491, 'grad_norm': 1.5425986051559448, 'learning_rate': 9.894736842105264e-06, 'epoch': 0.11}
{'loss': 2.9435, 'grad_norm': 1.5613164901733398, 'learning_rate': 9.789473684210527e-06, 'epoch': 0.21}
{'loss': 2.9213, 'grad_norm': 1.6143027544021606, 'learning_rate': 9.68421052631579e-06, 'epoch': 0.32}
{'loss': 2.9271, 'grad_norm': 1.7591242790222168, 'learning_rate': 9.578947368421054e-06, 'epoch': 0.42}
{'loss': 2.9126, 'grad_norm': 1.7223918437957764, 'learning_rate': 9.473684210526315e-06, 'epoch': 0.53}
{'loss': 2.9204, 'grad_norm': 1.4843759536743164, 'learning_rate': 9.36842105263158e-06, 'epoch': 0.63}
{'loss': 2.9241, 'grad_norm': 1.6439865827560425, 'learning_rate': 9.263157894736842e-06, 'epoch': 0.74}
{'loss': 2.91, 'grad_norm': 1.5305286645889282, 'learning_rate': 9.157894736842105e-06, 'epoch': 0.84}
{'loss': 2.9036, 'grad_norm': 1.703059196472168, 'learning_rate': 9.05263157894737e-06, 'epoch': 0.95}


  0%|          | 0/24 [00:00<?, ?it/s]

{'eval_loss': 2.8934314250946045, 'eval_accuracy': 0.1736842105263158, 'eval_f1': 0.15399213333823342, 'eval_runtime': 126.7072, 'eval_samples_per_second': 2.999, 'eval_steps_per_second': 0.189, 'epoch': 1.0}
{'loss': 2.8582, 'grad_norm': 1.7762049436569214, 'learning_rate': 8.947368421052632e-06, 'epoch': 1.05}
{'loss': 2.8327, 'grad_norm': 1.7582260370254517, 'learning_rate': 8.842105263157895e-06, 'epoch': 1.16}
{'loss': 2.8278, 'grad_norm': 1.68967866897583, 'learning_rate': 8.736842105263158e-06, 'epoch': 1.26}
{'loss': 2.8191, 'grad_norm': 1.7700130939483643, 'learning_rate': 8.631578947368422e-06, 'epoch': 1.37}
{'loss': 2.7909, 'grad_norm': 1.6738005876541138, 'learning_rate': 8.526315789473685e-06, 'epoch': 1.47}
{'loss': 2.8011, 'grad_norm': 1.7791657447814941, 'learning_rate': 8.421052631578948e-06, 'epoch': 1.58}
{'loss': 2.7644, 'grad_norm': 1.8091455698013306, 'learning_rate': 8.315789473684212e-06, 'epoch': 1.68}
{'loss': 2.7703, 'grad_norm': 1.8883768320083618, 'learnin

  0%|          | 0/24 [00:00<?, ?it/s]

{'eval_loss': 2.8095755577087402, 'eval_accuracy': 0.2894736842105263, 'eval_f1': 0.25121845488089367, 'eval_runtime': 124.641, 'eval_samples_per_second': 3.049, 'eval_steps_per_second': 0.193, 'epoch': 2.0}
{'loss': 2.6931, 'grad_norm': 1.941855788230896, 'learning_rate': 7.894736842105265e-06, 'epoch': 2.11}
{'loss': 2.6998, 'grad_norm': 1.9155468940734863, 'learning_rate': 7.789473684210526e-06, 'epoch': 2.21}
{'loss': 2.6807, 'grad_norm': 1.9378896951675415, 'learning_rate': 7.68421052631579e-06, 'epoch': 2.32}
{'loss': 2.645, 'grad_norm': 1.9952360391616821, 'learning_rate': 7.578947368421054e-06, 'epoch': 2.42}
{'loss': 2.6384, 'grad_norm': 1.9649302959442139, 'learning_rate': 7.473684210526316e-06, 'epoch': 2.53}
{'loss': 2.6132, 'grad_norm': 1.9654693603515625, 'learning_rate': 7.368421052631579e-06, 'epoch': 2.63}
{'loss': 2.6476, 'grad_norm': 2.2477834224700928, 'learning_rate': 7.263157894736843e-06, 'epoch': 2.74}
{'loss': 2.6108, 'grad_norm': 1.9668926000595093, 'learning_

  0%|          | 0/24 [00:00<?, ?it/s]

{'eval_loss': 2.713651418685913, 'eval_accuracy': 0.3815789473684211, 'eval_f1': 0.3567176872269258, 'eval_runtime': 69.4733, 'eval_samples_per_second': 5.47, 'eval_steps_per_second': 0.345, 'epoch': 3.0}
{'loss': 2.5511, 'grad_norm': 1.953678011894226, 'learning_rate': 6.947368421052632e-06, 'epoch': 3.05}
{'loss': 2.549, 'grad_norm': 2.0125861167907715, 'learning_rate': 6.842105263157896e-06, 'epoch': 3.16}
{'loss': 2.519, 'grad_norm': 1.9681227207183838, 'learning_rate': 6.736842105263158e-06, 'epoch': 3.26}
{'loss': 2.527, 'grad_norm': 2.1301472187042236, 'learning_rate': 6.631578947368421e-06, 'epoch': 3.37}
{'loss': 2.5103, 'grad_norm': 2.097996234893799, 'learning_rate': 6.526315789473685e-06, 'epoch': 3.47}
{'loss': 2.4873, 'grad_norm': 2.1301376819610596, 'learning_rate': 6.421052631578948e-06, 'epoch': 3.58}
{'loss': 2.5112, 'grad_norm': 2.075584650039673, 'learning_rate': 6.31578947368421e-06, 'epoch': 3.68}
{'loss': 2.471, 'grad_norm': 2.2883036136627197, 'learning_rate': 6

  0%|          | 0/24 [00:00<?, ?it/s]

{'eval_loss': 2.611682653427124, 'eval_accuracy': 0.4789473684210526, 'eval_f1': 0.4624292456146372, 'eval_runtime': 125.613, 'eval_samples_per_second': 3.025, 'eval_steps_per_second': 0.191, 'epoch': 4.0}
{'loss': 2.3912, 'grad_norm': 2.0000555515289307, 'learning_rate': 5.8947368421052634e-06, 'epoch': 4.11}
{'loss': 2.3837, 'grad_norm': 2.0142300128936768, 'learning_rate': 5.789473684210527e-06, 'epoch': 4.21}


KeyboardInterrupt: 