In [None]:
from datasets import load_dataset, Image
"""
.venv/Scripts/activate

python -m image_process
"""
# Run-specific directory; preview images are written under it by later cells and
# the evaluation cell reloads the fine-tuned model from it.
base_output_dir = "models/may13_VIT1"

dataset = load_dataset("potato_train/train")

# Re-use the dataset that was just loaded instead of reading it from disk a second
# time. With decode=False the "image" column yields dicts carrying the source
# "path" rather than decoded PIL images, which is all that is needed here.
filenames_ds = dataset.cast_column("image", Image(decode=False))

# Keep only the basename. Both '\\' and '/' are treated as separators so the
# result is the same on Windows and POSIX (splitting on '\\' alone leaves the
# whole path in place on POSIX, which later breaks os.path.join).
filename_col = [
    x['image']['path'].replace('\\', '/').rsplit('/', 1)[-1]
    for x in filenames_ds['train']
]
dataset['train'] = dataset['train'].add_column("filename", filename_col)

print(dataset['train'][0])
base_output_dir

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1500x1500 at 0x21EA5823B90>, 'label': 0, 'filename': 'b0.jpeg'}


'models/may13_VIT1'

In [20]:
from transformers import ViTImageProcessor

# Hub checkpoint used both for the image processor here and for the model
# weights loaded further down.
model_id = 'google/vit-base-patch16-224-in21k'

# The processor carries the resize / rescale / normalisation settings that the
# checkpoint expects (displayed below).
feature_extractor = ViTImageProcessor.from_pretrained(model_id)
feature_extractor

ViTImageProcessor {
  "do_convert_rgb": null,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

In [21]:
import torch
import numpy as np

# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [23]:
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomVerticalFlip,
    RandomRotation,
    Resize,
    ToTensor,
    ColorJitter,
    RandomAffine
)
from PIL import Image  # Import PIL for RandomAffine's resample
import torch

# torchvision documents `interpolation` as an InterpolationMode enum; using it
# (instead of the legacy PIL integer constant) also removes the dependency on the
# name `Image`, which this notebook binds to both datasets.Image and PIL.Image.
from torchvision.transforms import InterpolationMode

# Fix the torch RNG so the random augmentations below are repeatable.
torch.manual_seed(42)

# Normalise with the same per-channel statistics the image processor uses, and
# resize to the resolution it reports.
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
size = (feature_extractor.size["height"], feature_extractor.size["width"])

# Training-time augmentation pipeline: PIL image in, normalised float tensor out.
training_transforms = Compose([
    Resize(size),
    # Resize already yields exactly `size`, so this crop does not change the image;
    # it is kept so the pipeline stays identical to the one used for the saved run.
    CenterCrop(size),
    RandomHorizontalFlip(),
    RandomVerticalFlip(),
    ColorJitter(brightness=0.3, contrast=0.2, saturation=0.1, hue=0.05),
    RandomAffine(
        degrees=10,
        translate=(0.05, 0.05),
        scale=(0.95, 1.05),
        interpolation=InterpolationMode.BILINEAR,
    ),
    ToTensor(),
    normalize
])

def training_image_preprocess(batch):
    """Apply the training augmentations to a batch of PIL images.

    Args:
        batch: dict of column lists; ``batch["image"]`` holds the PIL images.

    Returns:
        The same dict with an added ``"pixel_values"`` entry containing the
        stacked, normalised image tensors. All other columns are left in place.

    Every image is converted to RGB first: the normalisation step uses
    three-channel statistics and ``torch.stack`` needs equally shaped tensors,
    so a grayscale, palette or RGBA image would otherwise raise.
    """
    batch["pixel_values"] = torch.stack(
        [training_transforms(img.convert("RGB")) for img in batch["image"]]
    )
    return batch

def preprocess(batch):
    """Turn a batch of PIL images into model inputs without augmentation.

    Args:
        batch: dict of column lists with ``"image"`` (PIL images) and ``"label"``.

    Returns:
        The image processor's output (containing ``pixel_values`` as PyTorch
        tensors) with the batch's ``label`` values attached under ``"label"``.
    """
    # The processor's do_convert_rgb setting is unset, so normalise the colour
    # mode here; non-RGB inputs would not match its three-channel mean/std.
    rgb_images = [img.convert("RGB") for img in batch['image']]
    inputs = feature_extractor(
        rgb_images,
        return_tensors='pt'
    )
    inputs['label'] = batch['label']
    return inputs

In [8]:
# Hold out 20% of the labelled images for evaluation. The fixed seed keeps the
# shuffled split identical from one run to the next.
train_test_split = dataset["train"].train_test_split(
    test_size=0.2,
    shuffle=True,
    seed=42,
)
dataset_train, dataset_test = train_test_split["train"], train_test_split["test"]

In [9]:
# Read the class inventory from the dataset schema rather than from the label
# values that happen to be present in the train split: the classification head
# is sized from this same feature, and it avoids loading the whole label column.
labels = dataset_train.features['label']
num_classes = labels.num_classes
num_classes, labels

(6,
 ClassLabel(names=['Bacteria', 'Fungi', 'Healthy', 'Pest', 'Phytopthora', 'Virus'], id=None))

In [10]:
# Attach the per-split transforms: random augmentation for training, plain
# image-processor preprocessing for evaluation.
prepared_train, prepared_test = (
    dataset_train.with_transform(training_image_preprocess),
    dataset_test.with_transform(preprocess),
)

Save a sample of the preprocessed images (the first 10 from each of the train and test splits) for visual inspection

In [11]:
import os
from torchvision.transforms.functional import to_pil_image

output_dir = f"{base_output_dir}/preprocessed_train_images"
os.makedirs(output_dir, exist_ok=True)

# The training pipeline ends with Normalize(mean, std), so pixel values are no
# longer in [0, 1]. to_pil_image() scales float tensors by 255 and casts them to
# uint8, so the normalisation has to be undone first or the saved previews come
# out with corrupted colours.
channel_mean = torch.tensor(feature_extractor.image_mean).view(-1, 1, 1)
channel_std = torch.tensor(feature_extractor.image_std).view(-1, 1, 1)

# Export only the first 10 augmented samples.
for index, item in enumerate(prepared_train):
    if index >= 10:
        break
    pixel_values = torch.as_tensor(item["pixel_values"])
    restored = (pixel_values * channel_std + channel_mean).clamp(0.0, 1.0)
    image = to_pil_image(restored)

    # Name the preview after the source file; always written as PNG.
    label_filename = dataset_train[index]["filename"]
    name_without_extension = os.path.splitext(label_filename)[0]
    filename = f"pp_{name_without_extension}.png"

    filepath = os.path.join(output_dir, filename)
    image.save(filepath)

In [12]:
output_dir = f"{base_output_dir}/preprocessed_test_images"
os.makedirs(output_dir, exist_ok=True)

# The image processor normalises with its image_mean / image_std, so pixel values
# are no longer in [0, 1]. to_pil_image() scales float tensors by 255 and casts
# them to uint8, so the normalisation has to be undone first or the saved
# previews come out with corrupted colours.
channel_mean = torch.tensor(feature_extractor.image_mean).view(-1, 1, 1)
channel_std = torch.tensor(feature_extractor.image_std).view(-1, 1, 1)

# Export only the first 10 evaluation samples.
for index, item in enumerate(prepared_test):
    if index >= 10:
        break
    pixel_values = torch.as_tensor(item["pixel_values"])
    restored = (pixel_values * channel_std + channel_mean).clamp(0.0, 1.0)
    image = to_pil_image(restored)

    # The transformed item only carries pixel values and the label, so the
    # source filename is looked up in the untransformed split.
    label_filename = dataset_test[index]["filename"]
    name_without_extension = os.path.splitext(label_filename)[0]
    filename = f"pp_{name_without_extension}.png"

    filepath = os.path.join(output_dir, filename)
    image.save(filepath)

In [13]:
import evaluate

def collate_fn(batch):
    """Assemble a list of examples into the keyword inputs the model receives.

    Args:
        batch: list of dicts, each with ``'pixel_values'`` (a tensor) and
            ``'label'``.

    Returns:
        Dict with ``'pixel_values'`` stacked along a new leading dimension and
        ``'labels'`` as a tensor of the per-example labels, in input order.
    """
    pixel_tensors = []
    label_values = []
    for example in batch:
        pixel_tensors.append(example['pixel_values'])
        label_values.append(example['label'])

    collated = {}
    collated['pixel_values'] = torch.stack(pixel_tensors)
    collated['labels'] = torch.tensor(label_values)
    return collated

# Metric objects are loaded once and shared by every evaluation pass.
accuracy_metric, f1_metric = evaluate.load("accuracy"), evaluate.load("f1")

def compute_metrics(p):
    """Compute accuracy and weighted F1 for one evaluation pass.

    Args:
        p: prediction object exposing ``predictions`` (per-class scores, one row
            per example) and ``label_ids`` (reference class ids).

    Returns:
        Dict merging the accuracy metric's output with the F1 metric's output.
    """
    # Highest-scoring class per example.
    predicted_ids = np.argmax(p.predictions, axis=1)
    accuracy_scores = accuracy_metric.compute(
        predictions=predicted_ids,
        references=p.label_ids,
    )
    f1_scores = f1_metric.compute(
        predictions=predicted_ids,
        references=p.label_ids,
        average="weighted",
    )
    return {**accuracy_scores, **f1_scores}
#

In [15]:
from transformers import ViTForImageClassification, Trainer, TrainingArguments

# Checkpoints and the final model go to the run-specific directory, because the
# evaluation cell reloads the model and image processor from `base_output_dir`.
training_args = TrainingArguments(
    output_dir=base_output_dir,
    per_device_train_batch_size=16,
    eval_strategy="steps",
    num_train_epochs=8,
    # Save and evaluate on the same schedule so the best checkpoint can be
    # restored at the end (load_best_model_at_end).
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=5e-5,
    save_total_limit=2,
    seed=42,
    # The on-the-fly transforms read the raw 'image' column, so the Trainer must
    # not drop columns the model's forward() does not accept.
    remove_unused_columns=False,
    push_to_hub=False,
    load_best_model_at_end=True,
)

# Class names in label-id order, plus both mappings, so the saved model config
# records human-readable class names instead of generic LABEL_<n> placeholders.
labels = dataset_train.features['label'].names
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}

# The checkpoint has no classification head for this task; a new one with
# len(labels) outputs is initialised (see the warning printed below).
model = ViTForImageClassification.from_pretrained(
    model_id,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

model.to(device)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_train,
    eval_dataset=prepared_test,
    processing_class=feature_extractor,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [32]:
# Run the full fine-tuning loop; the returned object carries the summary metrics.
train_results = trainer.train()

# Write the model to the Trainer's output directory.
trainer.save_model()

# Print the training metrics and store them next to the model.
train_metrics = train_results.metrics
trainer.log_metrics("train", train_metrics)
trainer.save_metrics("train", train_metrics)

# Store the trainer state as well.
trainer.save_state()

Step,Training Loss,Validation Loss,Accuracy,F1
100,0.739,0.713668,0.813653,0.804905
200,0.4436,0.5127,0.852399,0.853595
300,0.3979,0.45161,0.861624,0.863346
400,0.3371,0.447209,0.845018,0.843741
500,0.217,0.395547,0.872694,0.872265
600,0.1738,0.413418,0.872694,0.872293
700,0.1306,0.382393,0.883764,0.883073
800,0.1312,0.372025,0.891144,0.891026
900,0.0978,0.384194,0.878229,0.877236
1000,0.1168,0.394922,0.885609,0.88495


***** train metrics *****
  epoch                    =          8.0
  total_flos               = 1250029893GF
  train_loss               =       0.3427
  train_runtime            =   0:30:04.25
  train_samples_per_second =          9.6
  train_steps_per_second   =        0.603


In [18]:
from transformers import Trainer, ViTForImageClassification, ViTImageProcessor

# Reload the fine-tuned model from the run directory.
model = ViTForImageClassification.from_pretrained(base_output_dir)
# ViTFeatureExtractor is deprecated in favour of ViTImageProcessor, which the rest
# of this notebook already uses, so load the saved processor with that class.
feature_extractor = ViTImageProcessor.from_pretrained(base_output_dir)

# Evaluation-only Trainer: same arguments, collator and metrics as for training,
# but no train_dataset.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    processing_class=feature_extractor,
    eval_dataset=prepared_test,
)

# Score the held-out split.
eval_results = trainer.evaluate()

# Print the metrics in the Trainer's table format and store them on disk.
trainer.log_metrics("eval", eval_results)
trainer.save_metrics("eval", eval_results)

# Also print the raw dict, which keeps the full-precision values.
print(eval_results)



***** eval metrics *****
  eval_accuracy               =     0.8911
  eval_f1                     =      0.891
  eval_loss                   =      0.372
  eval_model_preparation_time =     0.0009
  eval_runtime                = 0:00:38.44
  eval_samples_per_second     =     14.096
  eval_steps_per_second       =      1.769
{'eval_loss': 0.3720252215862274, 'eval_model_preparation_time': 0.0009, 'eval_accuracy': 0.8911439114391144, 'eval_f1': 0.8910257759924213, 'eval_runtime': 38.4497, 'eval_samples_per_second': 14.096, 'eval_steps_per_second': 1.769}
