In [None]:
import numpy as np
import torch
from datasets import load_dataset, load_metric
from transformers import Trainer, TrainingArguments, ViTForImageClassification, ViTImageProcessor

In [None]:
train = "data/training/brands-classification/train"
val = "data/training/brands-classification/val"
test = "data/training/brands-classification/test"

In [None]:
model_name_or_path = "google/vit-base-patch16-224-in21k"
processor = ViTImageProcessor.from_pretrained(model_name_or_path)

In [None]:
dataset_train = load_dataset("imagefolder", data_dir=train)["train"]
dataset_val = load_dataset("imagefolder", data_dir=val)["train"]
dataset_test = load_dataset("imagefolder", data_dir=test)["train"]

Resolving data files:   0%|          | 0/3180 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1056 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1073 [00:00<?, ?it/s]

In [None]:
def process_example(example):
    inputs = processor(example["image"], return_tensors="pt")
    inputs["labels"] = example["label"]
    return inputs

In [None]:
def transform(example_batch):
    inputs = processor([x for x in example_batch["image"]], return_tensors="pt")
    inputs["labels"] = example_batch["label"]
    return inputs

In [None]:
processed_train = dataset_train.with_transform(transform)
processed_val = dataset_val.with_transform(transform)
processed_test = dataset_test.with_transform(transform)

In [None]:
def collate_fn(batch):
    return {
        "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
        "labels": torch.tensor([x["labels"] for x in batch]),
    }

In [None]:
metric = load_metric("f1")


def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids, average="macro")

  metric = load_metric("f1")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [None]:
labels = dataset_train.features["label"].names

model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)},
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
torch.set_float32_matmul_precision("medium")
training_args = TrainingArguments(
    output_dir="./vit-test-haha",
    per_device_train_batch_size=16,
    eval_strategy="steps",
    num_train_epochs=4,
    fp16=True,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to=None,
    load_best_model_at_end=True,
)

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=processed_train,
    eval_dataset=processed_val,
    tokenizer=processor,
)

In [None]:
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mseara[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/796 [00:00<?, ?it/s]

{'loss': 2.2907, 'grad_norm': 1.9101554155349731, 'learning_rate': 0.00019748743718592964, 'epoch': 0.05}
{'loss': 2.1433, 'grad_norm': 1.2307591438293457, 'learning_rate': 0.0001949748743718593, 'epoch': 0.1}
{'loss': 2.1865, 'grad_norm': 1.8704218864440918, 'learning_rate': 0.00019246231155778894, 'epoch': 0.15}
{'loss': 1.9985, 'grad_norm': 1.9788798093795776, 'learning_rate': 0.0001899497487437186, 'epoch': 0.2}
{'loss': 1.8115, 'grad_norm': 2.2659542560577393, 'learning_rate': 0.00018743718592964824, 'epoch': 0.25}
{'loss': 1.683, 'grad_norm': 1.8437142372131348, 'learning_rate': 0.0001849246231155779, 'epoch': 0.3}
{'loss': 1.4762, 'grad_norm': 1.966007947921753, 'learning_rate': 0.00018241206030150754, 'epoch': 0.35}
{'loss': 1.5186, 'grad_norm': 2.184659004211426, 'learning_rate': 0.0001798994974874372, 'epoch': 0.4}
{'loss': 1.3483, 'grad_norm': 2.2431647777557373, 'learning_rate': 0.00017738693467336683, 'epoch': 0.45}
{'loss': 1.4529, 'grad_norm': 1.721116065979004, 'learnin

  0%|          | 0/132 [00:00<?, ?it/s]

{'eval_loss': 1.3831318616867065, 'eval_f1': 0.3315470427743947, 'eval_runtime': 38.7593, 'eval_samples_per_second': 27.245, 'eval_steps_per_second': 3.406, 'epoch': 0.5}
{'loss': 1.4171, 'grad_norm': 3.016906499862671, 'learning_rate': 0.00017236180904522613, 'epoch': 0.55}
{'loss': 1.4334, 'grad_norm': 2.3731658458709717, 'learning_rate': 0.0001698492462311558, 'epoch': 0.6}
{'loss': 1.1782, 'grad_norm': 2.233245372772217, 'learning_rate': 0.00016733668341708543, 'epoch': 0.65}
{'loss': 1.3665, 'grad_norm': 2.1476986408233643, 'learning_rate': 0.0001648241206030151, 'epoch': 0.7}
{'loss': 1.3651, 'grad_norm': 3.164194345474243, 'learning_rate': 0.00016231155778894472, 'epoch': 0.75}
{'loss': 1.055, 'grad_norm': 1.9065033197402954, 'learning_rate': 0.00015979899497487439, 'epoch': 0.8}
{'loss': 1.0722, 'grad_norm': 2.5833590030670166, 'learning_rate': 0.00015728643216080402, 'epoch': 0.85}
{'loss': 1.103, 'grad_norm': 2.480142831802368, 'learning_rate': 0.00015477386934673368, 'epoch'

  0%|          | 0/132 [00:00<?, ?it/s]

{'eval_loss': 0.9603229761123657, 'eval_f1': 0.5694931352073987, 'eval_runtime': 37.505, 'eval_samples_per_second': 28.156, 'eval_steps_per_second': 3.52, 'epoch': 1.01}
{'loss': 0.742, 'grad_norm': 3.5974786281585693, 'learning_rate': 0.00014723618090452262, 'epoch': 1.06}
{'loss': 0.7475, 'grad_norm': 3.379312753677368, 'learning_rate': 0.00014472361809045228, 'epoch': 1.11}
{'loss': 0.8044, 'grad_norm': 1.1737242937088013, 'learning_rate': 0.0001422110552763819, 'epoch': 1.16}
{'loss': 0.7112, 'grad_norm': 2.8028833866119385, 'learning_rate': 0.00013969849246231157, 'epoch': 1.21}
{'loss': 0.6777, 'grad_norm': 1.318703532218933, 'learning_rate': 0.0001371859296482412, 'epoch': 1.26}
{'loss': 0.6579, 'grad_norm': 7.204718112945557, 'learning_rate': 0.00013467336683417087, 'epoch': 1.31}
{'loss': 0.6951, 'grad_norm': 4.691657066345215, 'learning_rate': 0.0001321608040201005, 'epoch': 1.36}
{'loss': 0.5314, 'grad_norm': 3.737499237060547, 'learning_rate': 0.00012964824120603017, 'epoch

  0%|          | 0/132 [00:00<?, ?it/s]

{'eval_loss': 0.7611491084098816, 'eval_f1': 0.6781339812075372, 'eval_runtime': 38.759, 'eval_samples_per_second': 27.245, 'eval_steps_per_second': 3.406, 'epoch': 1.51}
{'loss': 0.5227, 'grad_norm': 0.8119794726371765, 'learning_rate': 0.0001221105527638191, 'epoch': 1.56}
{'loss': 0.6224, 'grad_norm': 3.787081003189087, 'learning_rate': 0.00011959798994974876, 'epoch': 1.61}
{'loss': 0.5572, 'grad_norm': 2.158644676208496, 'learning_rate': 0.00011708542713567841, 'epoch': 1.66}
{'loss': 0.4889, 'grad_norm': 3.384948492050171, 'learning_rate': 0.00011457286432160806, 'epoch': 1.71}
{'loss': 0.4804, 'grad_norm': 3.0189106464385986, 'learning_rate': 0.00011206030150753771, 'epoch': 1.76}
{'loss': 0.6797, 'grad_norm': 2.7612946033477783, 'learning_rate': 0.00010954773869346736, 'epoch': 1.81}
{'loss': 0.5157, 'grad_norm': 2.9752326011657715, 'learning_rate': 0.00010703517587939699, 'epoch': 1.86}
{'loss': 0.5046, 'grad_norm': 3.178652286529541, 'learning_rate': 0.00010452261306532664, '

  0%|          | 0/132 [00:00<?, ?it/s]

{'eval_loss': 0.7172240018844604, 'eval_f1': 0.7165990905139457, 'eval_runtime': 38.4642, 'eval_samples_per_second': 27.454, 'eval_steps_per_second': 3.432, 'epoch': 2.01}
{'loss': 0.3239, 'grad_norm': 2.813413619995117, 'learning_rate': 9.698492462311559e-05, 'epoch': 2.06}
{'loss': 0.2804, 'grad_norm': 2.519469738006592, 'learning_rate': 9.447236180904523e-05, 'epoch': 2.11}
{'loss': 0.2292, 'grad_norm': 0.30940479040145874, 'learning_rate': 9.195979899497488e-05, 'epoch': 2.16}
{'loss': 0.2202, 'grad_norm': 1.8282064199447632, 'learning_rate': 8.944723618090453e-05, 'epoch': 2.21}
{'loss': 0.2516, 'grad_norm': 1.2169932126998901, 'learning_rate': 8.693467336683418e-05, 'epoch': 2.26}
{'loss': 0.262, 'grad_norm': 0.6905980110168457, 'learning_rate': 8.442211055276383e-05, 'epoch': 2.31}
{'loss': 0.2509, 'grad_norm': 0.9855825901031494, 'learning_rate': 8.190954773869348e-05, 'epoch': 2.36}
{'loss': 0.1775, 'grad_norm': 5.226966381072998, 'learning_rate': 7.939698492462313e-05, 'epoch

  0%|          | 0/132 [00:00<?, ?it/s]

{'eval_loss': 0.5569721460342407, 'eval_f1': 0.8077620632678191, 'eval_runtime': 38.5846, 'eval_samples_per_second': 27.368, 'eval_steps_per_second': 3.421, 'epoch': 2.51}
{'loss': 0.1277, 'grad_norm': 0.9726919531822205, 'learning_rate': 7.185929648241206e-05, 'epoch': 2.56}
{'loss': 0.2243, 'grad_norm': 0.2577843964099884, 'learning_rate': 6.93467336683417e-05, 'epoch': 2.61}
{'loss': 0.1326, 'grad_norm': 0.9168246388435364, 'learning_rate': 6.683417085427135e-05, 'epoch': 2.66}
{'loss': 0.1234, 'grad_norm': 0.44683143496513367, 'learning_rate': 6.4321608040201e-05, 'epoch': 2.71}
{'loss': 0.2775, 'grad_norm': 2.5073421001434326, 'learning_rate': 6.180904522613065e-05, 'epoch': 2.76}
{'loss': 0.1137, 'grad_norm': 0.28586140275001526, 'learning_rate': 5.929648241206031e-05, 'epoch': 2.81}
{'loss': 0.1458, 'grad_norm': 1.7539912462234497, 'learning_rate': 5.6783919597989955e-05, 'epoch': 2.86}
{'loss': 0.1662, 'grad_norm': 1.5686002969741821, 'learning_rate': 5.4271356783919604e-05, 'e

  0%|          | 0/132 [00:00<?, ?it/s]

{'eval_loss': 0.4873565435409546, 'eval_f1': 0.8393693805239075, 'eval_runtime': 37.1116, 'eval_samples_per_second': 28.455, 'eval_steps_per_second': 3.557, 'epoch': 3.02}
{'loss': 0.1083, 'grad_norm': 1.5319398641586304, 'learning_rate': 4.673366834170855e-05, 'epoch': 3.07}
{'loss': 0.0622, 'grad_norm': 1.3283040523529053, 'learning_rate': 4.42211055276382e-05, 'epoch': 3.12}
{'loss': 0.0701, 'grad_norm': 0.9043460488319397, 'learning_rate': 4.170854271356784e-05, 'epoch': 3.17}
{'loss': 0.0535, 'grad_norm': 0.1634153425693512, 'learning_rate': 3.919597989949749e-05, 'epoch': 3.22}
{'loss': 0.0504, 'grad_norm': 0.8389099836349487, 'learning_rate': 3.668341708542714e-05, 'epoch': 3.27}
{'loss': 0.0586, 'grad_norm': 0.7816864848136902, 'learning_rate': 3.4170854271356785e-05, 'epoch': 3.32}
{'loss': 0.0496, 'grad_norm': 0.350990355014801, 'learning_rate': 3.1658291457286434e-05, 'epoch': 3.37}
{'loss': 0.0699, 'grad_norm': 0.15425150096416473, 'learning_rate': 2.914572864321608e-05, 'e

  0%|          | 0/132 [00:00<?, ?it/s]

{'eval_loss': 0.4826359748840332, 'eval_f1': 0.8492492861392744, 'eval_runtime': 38.9998, 'eval_samples_per_second': 27.077, 'eval_steps_per_second': 3.385, 'epoch': 3.52}
{'loss': 0.056, 'grad_norm': 0.11402419209480286, 'learning_rate': 2.1608040201005025e-05, 'epoch': 3.57}
{'loss': 0.0343, 'grad_norm': 0.27066585421562195, 'learning_rate': 1.9095477386934673e-05, 'epoch': 3.62}
{'loss': 0.0689, 'grad_norm': 0.12199786305427551, 'learning_rate': 1.6582914572864322e-05, 'epoch': 3.67}
{'loss': 0.0853, 'grad_norm': 1.1182750463485718, 'learning_rate': 1.407035175879397e-05, 'epoch': 3.72}
{'loss': 0.0349, 'grad_norm': 0.11539046466350555, 'learning_rate': 1.1557788944723619e-05, 'epoch': 3.77}
{'loss': 0.0327, 'grad_norm': 0.09080767631530762, 'learning_rate': 9.045226130653267e-06, 'epoch': 3.82}
{'loss': 0.0352, 'grad_norm': 0.07864534109830856, 'learning_rate': 6.532663316582915e-06, 'epoch': 3.87}
{'loss': 0.0417, 'grad_norm': 0.10149459540843964, 'learning_rate': 4.02010050251256

In [None]:
metrics = trainer.evaluate(processed_test)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

  0%|          | 0/135 [00:00<?, ?it/s]

{'eval_loss': 0.4186634421348572, 'eval_f1': 0.8608476791988378, 'eval_runtime': 67.1903, 'eval_samples_per_second': 15.97, 'eval_steps_per_second': 2.009, 'epoch': 0.01}
***** eval metrics *****
  epoch                   =     0.0101
  eval_f1                 =     0.8608
  eval_loss               =     0.4187
  eval_runtime            = 0:01:07.19
  eval_samples_per_second =      15.97
  eval_steps_per_second   =      2.009
