In [1]:
from datasets import load_dataset
import numpy as np
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer
import torch
import evaluate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = load_dataset("imagefolder", data_dir="../data/img/vit")
dataset

Resolving data files: 100%|██████████| 6910/6910 [00:00<00:00, 158374.22it/s]
Resolving data files: 100%|██████████| 1728/1728 [00:00<00:00, 183929.89it/s]


DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 6910
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 1728
    })
})

In [3]:
print(f"Training samples: {len(dataset['train'])}")
print(f"Validation samples: {len(dataset['validation'])}")

Training samples: 6910
Validation samples: 1728


In [4]:
model_name_or_path = 'google/vit-base-patch16-224-in21k'

In [5]:
processor = ViTImageProcessor.from_pretrained(model_name_or_path)
processor

ViTImageProcessor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

In [6]:
def transform(batch):
  # Take a list of PIL images and turn them to pixel values
  inputs = processor([x for x in batch['image']], return_tensors="pt")

  # Include labels
  inputs['labels'] = batch['label']
  return inputs

In [7]:
train = dataset["train"].with_transform(transform)
validation = dataset["validation"].with_transform(transform)

In [8]:
def collate_fn(batch):
  return {
    'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
    'labels': torch.tensor([x['labels'] for x in batch])
  }

In [9]:
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
  logits, labels = eval_pred
  predictions = np.argmax(logits, axis=-1)
  return metric.compute(predictions=predictions, references=labels)

In [10]:
model = ViTForImageClassification.from_pretrained(model_name_or_path, num_labels=2)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
epochs = 10
warmup_steps = 100
weight_decay = 0.01

training_args = TrainingArguments(
  output_dir='../results',
  num_train_epochs=epochs,
  per_device_train_batch_size=24,
  per_device_eval_batch_size=12,
  evaluation_strategy="epoch",
  warmup_steps=warmup_steps,
  weight_decay=weight_decay,
  logging_dir='../logs',
  remove_unused_columns=False
)

In [12]:
trainer = Trainer(
  model=model,
  args=training_args,
  train_dataset=train,
  eval_dataset=validation,
  tokenizer=processor,
  data_collator=collate_fn,
  compute_metrics=compute_metrics
)

In [13]:
trainer.train()

                                                    
 10%|█         | 288/2880 [06:32<56:14,  1.30s/it]

{'eval_loss': 0.10311707109212875, 'eval_accuracy': 0.9652777777777778, 'eval_runtime': 31.5663, 'eval_samples_per_second': 54.742, 'eval_steps_per_second': 4.562, 'epoch': 1.0}


 17%|█▋        | 500/2880 [10:50<46:48,  1.18s/it]  

{'loss': 0.1779, 'learning_rate': 4.280575539568346e-05, 'epoch': 1.74}


                                                    
 20%|██        | 576/2880 [12:54<44:55,  1.17s/it]

{'eval_loss': 0.06288963556289673, 'eval_accuracy': 0.9814814814814815, 'eval_runtime': 29.2092, 'eval_samples_per_second': 59.16, 'eval_steps_per_second': 4.93, 'epoch': 2.0}


                                                    
 30%|███       | 864/2880 [19:04<37:09,  1.11s/it]

{'eval_loss': 0.09144069254398346, 'eval_accuracy': 0.9756944444444444, 'eval_runtime': 29.1705, 'eval_samples_per_second': 59.238, 'eval_steps_per_second': 4.936, 'epoch': 3.0}


 35%|███▍      | 1000/2880 [21:45<36:41,  1.17s/it] 

{'loss': 0.0259, 'learning_rate': 3.3812949640287773e-05, 'epoch': 3.47}


                                                   
 40%|████      | 1152/2880 [25:18<32:34,  1.13s/it]

{'eval_loss': 0.07018504291772842, 'eval_accuracy': 0.9820601851851852, 'eval_runtime': 29.7057, 'eval_samples_per_second': 58.171, 'eval_steps_per_second': 4.848, 'epoch': 4.0}


                                                     
 50%|█████     | 1440/2880 [31:39<27:05,  1.13s/it]

{'eval_loss': 0.06399541348218918, 'eval_accuracy': 0.9849537037037037, 'eval_runtime': 29.3115, 'eval_samples_per_second': 58.953, 'eval_steps_per_second': 4.913, 'epoch': 5.0}


 52%|█████▏    | 1500/2880 [32:51<28:18,  1.23s/it]  

{'loss': 0.0063, 'learning_rate': 2.482014388489209e-05, 'epoch': 5.21}


                                                   
 60%|██████    | 1728/2880 [37:58<21:54,  1.14s/it]

{'eval_loss': 0.07573655247688293, 'eval_accuracy': 0.984375, 'eval_runtime': 29.4315, 'eval_samples_per_second': 58.713, 'eval_steps_per_second': 4.893, 'epoch': 6.0}


 69%|██████▉   | 2000/2880 [43:25<17:11,  1.17s/it]  

{'loss': 0.0011, 'learning_rate': 1.5827338129496403e-05, 'epoch': 6.94}


                                                   
 70%|███████   | 2016/2880 [44:14<16:10,  1.12s/it]

{'eval_loss': 0.07601729035377502, 'eval_accuracy': 0.9855324074074074, 'eval_runtime': 29.0312, 'eval_samples_per_second': 59.522, 'eval_steps_per_second': 4.96, 'epoch': 7.0}


                                                     
 80%|████████  | 2304/2880 [50:21<10:41,  1.11s/it]

{'eval_loss': 0.07795415073633194, 'eval_accuracy': 0.9855324074074074, 'eval_runtime': 29.0759, 'eval_samples_per_second': 59.431, 'eval_steps_per_second': 4.953, 'epoch': 8.0}


 87%|████████▋ | 2500/2880 [54:10<07:22,  1.16s/it]  

{'loss': 0.0008, 'learning_rate': 6.83453237410072e-06, 'epoch': 8.68}


                                                   
 90%|█████████ | 2592/2880 [56:27<05:22,  1.12s/it]

{'eval_loss': 0.07931925356388092, 'eval_accuracy': 0.9855324074074074, 'eval_runtime': 28.8256, 'eval_samples_per_second': 59.947, 'eval_steps_per_second': 4.996, 'epoch': 9.0}


                                                     
100%|██████████| 2880/2880 [1:02:33<00:00,  1.30s/it]

{'eval_loss': 0.07990724593400955, 'eval_accuracy': 0.9855324074074074, 'eval_runtime': 28.8007, 'eval_samples_per_second': 59.998, 'eval_steps_per_second': 5.0, 'epoch': 10.0}
{'train_runtime': 3753.6133, 'train_samples_per_second': 18.409, 'train_steps_per_second': 0.767, 'train_loss': 0.036923396742592256, 'epoch': 10.0}





TrainOutput(global_step=2880, training_loss=0.036923396742592256, metrics={'train_runtime': 3753.6133, 'train_samples_per_second': 18.409, 'train_steps_per_second': 0.767, 'train_loss': 0.036923396742592256, 'epoch': 10.0})

In [14]:
trainer.evaluate()

100%|██████████| 144/144 [00:28<00:00,  5.03it/s]


{'eval_loss': 0.07990724593400955,
 'eval_accuracy': 0.9855324074074074,
 'eval_runtime': 29.7362,
 'eval_samples_per_second': 58.111,
 'eval_steps_per_second': 4.843,
 'epoch': 10.0}

In [15]:
trainer.save_model("../models/vit_binary_classifier_20231205")

In [16]:
model = ViTForImageClassification.from_pretrained("../models/vit_binary_classifier_20231205")

In [18]:
i = -103
input = torch.tensor(processor(dataset['validation'][i]['image'], return_tensors="pt")['pixel_values'])

outputs = model(input)
predictions = outputs.logits.argmax(dim=-1)
predictions

  input = torch.tensor(processor(dataset['validation'][i]['image'], return_tensors="pt")['pixel_values'])


tensor([1])