In [None]:
import logging
import os

import evaluate
import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from datasets import DatasetDict, load_dataset
from PIL import Image
from transformers import AutoImageProcessor, Trainer, TrainingArguments, ViTForImageClassification

  from .autonotebook import tqdm as notebook_tqdm


In [14]:
# Image export; `imagefolder` infers the class labels from the sub-directory names
# (presumably one sub-directory per class — see the label list below).
data_root = '/home/veit/Downloads/export__TSV_17108_20250610_1155'
SPLIT_SEED = 42

full_ds = load_dataset("imagefolder", data_dir=data_root, split="train")

# 80/20 train/validation split, shuffled and stratified by class.
# Slice specs such as "train[:80%]" take contiguous, unshuffled ranges of the file list.
# Files from the same class folder are adjacent in that list, so a slice split would put
# the alphabetically last classes (almost) only into validation.
split_ds = full_ds.train_test_split(test_size=0.2, stratify_by_column="label", seed=SPLIT_SEED)
ds = DatasetDict({"train": split_ds["train"], "validation": split_ds["test"]})

In [15]:
# Show the split names, features and row counts of the train/validation splits
ds

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 7980
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 1995
    })
})

In [4]:
# Class names of the `label` feature; labels[i] is the name of class id i
label_feature = ds['train'].features['label']
labels = label_feature.names
labels

['Acantharia',
 'Appendicularia',
 'Asteroidea larvae',
 'Bacillariophyceae',
 'Calanoida',
 'Chaetognatha',
 'Cnidaria_Metazoa',
 'Copepoda_Maxillopoda',
 'Crustacea',
 'Ctenophora_Metazoa',
 'Doliolida',
 'Echinodermata',
 'Mollusca',
 'Noctiluca sp.',
 'Oithona',
 'Ophiuroidea',
 'Polychaeta',
 'Rhizaria',
 'Salpidae',
 'Siphonophorae',
 'Solmundella',
 'Spumellaria',
 'Thaliacea',
 'Tunicata',
 'artefact',
 'body_Appendicularia',
 'bubble',
 'budding_Doliolida',
 'dark_sphere',
 'detritus',
 'gelatinous',
 'house',
 'multiple_Copepoda',
 'multiple_other',
 'noise',
 'nurse',
 'othertocheck',
 'part_Cnidaria',
 'puff',
 'solitaryblack',
 'solitaryblack-like',
 'sphere_othertocheck',
 'streak',
 'tentacle_Cnidaria',
 'tuft']

In [24]:
def resize_to_larger_edge(image, target_size):
    """Resize an image so its longer edge is ``target_size`` pixels, keeping the aspect ratio.

    Args:
        image: PIL image; ``image.size`` is read as ``(width, height)``.
        target_size: Desired length, in pixels, of the longer edge.

    Returns:
        The resized image, or ``None`` if the image has no pixels or resizing raised
        ``ValueError``. Skipped images are logged at WARNING level so they are visible
        with the default logging configuration.
    """
    original_width, original_height = image.size
    larger_edge = max(original_width, original_height)
    if larger_edge == 0:
        logging.warning("Skipping: %s: empty image, size: %s", image, image.size)
        return None

    scale_factor = target_size / larger_edge

    # round() instead of int(): truncation can turn 223.99999... into 223, leaving the long
    # edge one pixel short. max(1, ...) keeps very thin objects (e.g. 1000x2 px) from
    # collapsing to a zero-sized edge, which makes the resize fail and drops the image.
    new_width = max(1, round(original_width * scale_factor))
    new_height = max(1, round(original_height * scale_factor))

    try:
        # torchvision takes (height, width); PIL's .size is (width, height)
        return F.resize(image, (new_height, new_width))
    except ValueError:
        logging.warning(
            "Skipping: %s: image size: %s, new height: %d, new width: %d",
            image, image.size, new_height, new_width,
        )
        return None

def custom_image_processor(image, target_size=(224, 224), padding_color=255, augment=True):
    """Letterbox an image onto a fixed-size canvas and convert it to a normalised tensor.

    The image is scaled (aspect ratio preserved) so it fits inside ``target_size`` and is
    centred on a canvas filled with ``padding_color``. It is then optionally randomly
    rotated, converted to a tensor, and normalised with mean = std = 0.5 per channel.

    Args:
        image: PIL image; converted to RGB if it is in another mode.
        target_size: ``(width, height)`` of the output canvas in pixels.
        padding_color: Fill value for the padding and for corners exposed by the rotation.
        augment: If True (default), apply a random rotation of up to ±180 degrees.
            Pass False for deterministic preprocessing, e.g. for evaluation.

    Returns:
        A tensor of shape ``(3, height, width)``, or ``None`` if the image could not be
        resized (callers are expected to filter these out).
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')

    target_width, target_height = target_size

    # Step 1: scale the longer edge to the shorter canvas side, so the image fits in both
    # dimensions and the padding below can never be negative.
    resized_image = resize_to_larger_edge(image, min(target_width, target_height))
    if resized_image is None:
        return None

    # Step 2: centre the image on the canvas (any odd leftover pixel goes right/bottom).
    new_width, new_height = resized_image.size
    pad_left = (target_width - new_width) // 2
    pad_top = (target_height - new_height) // 2
    pad_right = target_width - new_width - pad_left
    pad_bottom = target_height - new_height - pad_top
    padded_image = transforms.Pad(
        (pad_left, pad_top, pad_right, pad_bottom), fill=padding_color
    )(resized_image)

    # Step 3: optional augmentation, then tensor conversion and normalisation to [-1, 1].
    steps = []
    if augment:
        steps.append(transforms.RandomRotation(degrees=180, fill=padding_color))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    return transforms.Compose(steps)(padded_image)

# Batch transform applied on the fly via `with_transform`
def process_batch(example_batch):
    """Turn a batch of raw examples into model inputs.

    Each image is processed exactly once. Images that ``custom_image_processor`` rejects
    (returns ``None``) are dropped together with their label, so ``pixel_values`` and
    ``label`` stay aligned position by position.

    Args:
        example_batch: Mapping with parallel ``'image'`` and ``'label'`` sequences.

    Returns:
        ``{'pixel_values': stacked image tensor, 'label': list of the kept labels}``.

    Raises:
        ValueError: If none of the images in the batch could be processed.
    """
    pixel_values = []
    kept_labels = []
    for img, label in zip(example_batch['image'], example_batch['label']):
        processed = custom_image_processor(img)
        if processed is None:
            continue
        pixel_values.append(processed)
        kept_labels.append(label)

    if not pixel_values:
        # torch.stack([]) would fail with an opaque error; say what actually went wrong.
        raise ValueError("None of the images in this batch could be processed")

    return {'pixel_values': torch.stack(pixel_values), 'label': kept_labels}


In [25]:
# Attach preprocessing lazily: process_batch runs whenever rows are read; ds itself is not modified.
# NOTE(review): the same transform, including its random rotation, is also applied to the
# validation split, so validation metrics are computed on randomly rotated images.
prepared_ds = ds.with_transform(transform=process_batch)

In [26]:
def collate_fn(batch):
    """Collate a list of examples into a single model batch.

    Args:
        batch: List of dicts, each with a ``'pixel_values'`` tensor and an integer ``'label'``.

    Returns:
        ``{'pixel_values': tensors stacked along a new leading dim, 'labels': label tensor}``.
    """
    pixel_values = torch.stack([example['pixel_values'] for example in batch])
    labels = torch.tensor([example['label'] for example in batch])
    return {'pixel_values': pixel_values, 'labels': labels}

# Accuracy metric loaded through the `evaluate` library
metric = evaluate.load("accuracy")

def compute_metrics(p):
    """Compute accuracy from the Trainer's evaluation output.

    Args:
        p: Prediction object with ``predictions`` (per-class scores, classes along axis 1)
           and ``label_ids`` (presumably ``transformers.EvalPrediction``).

    Returns:
        Whatever the accuracy metric's ``compute`` returns (a dict of metric values).
    """
    scores, references = p.predictions, p.label_ids
    predicted_classes = np.argmax(scores, axis=1)
    return metric.compute(predictions=predicted_classes, references=references)

In [27]:
# Model first, Trainer second: the Trainer takes the model object, so building it in an
# earlier cell fails with a NameError on a fresh kernel (Restart & Run All).
model_name_or_path = 'google/vit-base-patch16-224-in21k'
# To continue from an earlier run, set model_name_or_path to a saved model directory, e.g.
# '/home/plankton/PISCO_Classification/ViT_custom_size_sensitive/best_model/'

# The classification head gets one output per class in `labels`; ignore_mismatched_sizes
# lets from_pretrained initialise a head whose shape differs from the checkpoint's.
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    # Integer ids on both sides so id2label and label2id are exact inverses
    id2label={i: c for i, c in enumerate(labels)},
    label2id={c: i for i, c in enumerate(labels)},
    ignore_mismatched_sizes=True,
)

# Path where all config files and checkpoints will be saved
output_dir = "/home/veit/PIScO_dev/ViT_custom_size_sensitive/"
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=16,
    eval_strategy="epoch",
    save_strategy="epoch",
    fp16=True,
    num_train_epochs=10,
    logging_steps=500,
    learning_rate=2e-4,
    save_total_limit=1,
    # Presumably needed so the Trainer does not drop the 'image' column that process_batch reads
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [29]:
# Train, then save the final model into a `best_model` sub-directory of the Trainer's output
# directory. With load_best_model_at_end=True, this is the best checkpoint found during training.
# Deriving the path from training_args keeps a single source of truth for the output location.
save_dir = os.path.join(training_args.output_dir, "best_model")
train_results = trainer.train()
trainer.save_model(save_dir)
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 