In [1]:
import random
import numpy as np
import torch  # Assuming PyTorch is the framework used
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer, AutoImageProcessor
from datasets import load_dataset, load_metric

# Set to True to (re)fine-tune the models; False skips training and only evaluates saved checkpoints.
TRAIN = False

# Global seed for the Python, NumPy and PyTorch RNGs so a fresh "Restart & Run All" is reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Which study plate to work on; must be a key of PLATE_NAMES below.
PLATE_NUMBER = 6

BASE_DIR = '/home/isacc/bite_acquisition/task_planning_tests/study_plates/'

# Plate number -> dataset/output folder name (plate 5 has no entry here).
PLATE_NAMES = {
    1: 'spaghetti_meatballs',
    2: 'fettuccine_chicken_broccoli',
    3: 'mashed_potato_sausage',
    4: 'oatmeal_strawberry',
    6: 'dessert',
}
# Fail fast with a clear message instead of a NameError on PLATE_NAME further down.
if PLATE_NUMBER not in PLATE_NAMES:
    raise ValueError(
        f"Unknown PLATE_NUMBER {PLATE_NUMBER}; expected one of {sorted(PLATE_NAMES)}"
    )
PLATE_NAME = PLATE_NAMES[PLATE_NUMBER]

# Strip the trailing slash once so the joined paths do not contain '//'.
_plates_root = BASE_DIR.rstrip('/')
TRAIN_DIR = f"{_plates_root}/log/{PLATE_NAME}/classification_format/train"
TEST_DIR = f"{_plates_root}/log/{PLATE_NAME}/classification_format/test"
OUTPUT_DIR = f"{_plates_root}/outputs/swin_transformers/{PLATE_NAME}"

In [3]:
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

# Accuracy metric. NOTE: datasets' load_metric is deprecated and warns about trust_remote_code.
metric = load_metric("accuracy")

model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"  # pre-trained model from which to fine-tune
batch_size = 32  # batch size for training and evaluation

# The image processor provides the checkpoint's expected input size and normalisation statistics.
image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# image_processor.size is either an explicit {"height", "width"} pair or a {"shortest_edge"}
# spec (optionally with "longest_edge"); derive the resize and crop sizes from whichever is present.
if "height" in image_processor.size:
    size = (image_processor.size["height"], image_processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
    crop_size = (size, size)
    # NOTE(review): max_size is computed but not passed to Resize below.
    max_size = image_processor.size.get("longest_edge")
else:
    # Without this, `size`/`crop_size` would be undefined and the transforms below would NameError.
    raise ValueError(f"Unsupported image_processor.size format: {image_processor.size!r}")

# Training: random resized crop + horizontal flip for augmentation.
train_transforms = Compose(
    [
        RandomResizedCrop(crop_size),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ]
)

# Validation/test: deterministic resize + center crop, no augmentation.
val_transforms = Compose(
    [
        Resize(size),
        CenterCrop(crop_size),
        ToTensor(),
        normalize,
    ]
)

def preprocess_train(example_batch):
    """Attach augmented ``pixel_values`` to a batch of training examples.

    Every image in ``example_batch["image"]`` is converted to RGB and passed
    through ``train_transforms``. The batch dict is updated in place and
    returned, as ``Dataset.set_transform`` expects.
    """
    rgb_images = (img.convert("RGB") for img in example_batch["image"])
    example_batch["pixel_values"] = list(map(train_transforms, rgb_images))
    return example_batch

def preprocess_val(example_batch):
    """Attach non-augmented ``pixel_values`` to a batch of evaluation examples.

    Each image in ``example_batch["image"]`` is converted to RGB and run
    through ``val_transforms``; the batch dict is updated in place and returned.
    """
    pixel_values = []
    for img in example_batch["image"]:
        pixel_values.append(val_transforms(img.convert("RGB")))
    example_batch["pixel_values"] = pixel_values
    return example_batch

# The Trainer calls compute_metrics with an EvalPrediction named tuple:
#   predictions -- the model's logits as a NumPy array (or a tuple whose first item is the logits),
#   label_ids   -- the ground-truth class ids as a NumPy array.
def compute_metrics(eval_pred):
    """Compute top-1 accuracy for a set of predictions.

    Accuracy is computed directly with NumPy rather than through the deprecated
    ``datasets.load_metric("accuracy")``, which warns that it will require
    ``trust_remote_code``. The result keeps the same ``{"accuracy": float}``
    shape, so ``metric_for_best_model="accuracy"`` and ``eval_accuracy``
    lookups keep working.

    Args:
        eval_pred: object with ``predictions`` (logits; class scores on axis 1,
            or a tuple whose first element is those logits) and ``label_ids``.

    Returns:
        ``{"accuracy": float}``; the value is ``nan`` when there are no examples.
    """
    logits = eval_pred.predictions
    if isinstance(logits, tuple):  # models may return (logits, *extra_outputs)
        logits = logits[0]
    predictions = np.argmax(logits, axis=1)
    references = np.asarray(eval_pred.label_ids)
    if references.size == 0:
        # Accuracy is undefined on an empty set; avoid NumPy's mean-of-empty warning.
        return {"accuracy": float("nan")}
    return {"accuracy": float(np.mean(predictions == references))}

def collate_fn(examples):
    """Batch a list of preprocessed examples for the Trainer.

    Stacks each example's ``pixel_values`` tensor along a new leading batch
    dimension and collects the ``label`` fields into a tensor stored under
    the ``labels`` key.
    """
    images, targets = [], []
    for ex in examples:
        images.append(ex["pixel_values"])
        targets.append(ex["label"])
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(targets)}

  metric = load_metric("accuracy")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [4]:
if TRAIN:
    from collections import Counter

    # Training images are in imagefolder layout; datasets exposes them as a single "train" split.
    train_dataset = load_dataset("imagefolder", data_dir=TRAIN_DIR)
    label_feature = train_dataset["train"].features["label"]
    print(label_feature)

    # Number of training images per class, keyed by class name, to spot class imbalance.
    label_counts = Counter(train_dataset["train"]["label"])
    print({label_feature.names[label_id]: count for label_id, count in sorted(label_counts.items())})

In [5]:
if TRAIN:
    # The label vocabulary comes from the training set's ClassLabel feature and is the same for
    # every seed, so the id<->name maps and the model name are built once.
    labels = train_dataset["train"].features["label"].names
    label2id = {label: i for i, label in enumerate(labels)}
    id2label = {i: label for i, label in enumerate(labels)}
    model_name = model_checkpoint.split("/")[-1]

    # Held-out test set, evaluated with the non-augmenting val transforms. It does not depend
    # on the seed, so it is loaded once instead of once per run.
    test_dataset = load_dataset("imagefolder", data_dir=TEST_DIR)
    test_dataset.set_transform(preprocess_val)
    test_dataset = test_dataset["train"]

    for seed in [0, 1, 2, 3, 4]:
        print(f"----------Running seed {seed}----------")

        # Seed the global RNGs *before* the model is built, so weights initialised inside
        # from_pretrained (e.g. a resized classification head) also depend on `seed`.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        # 60/40 train/validation split. train_test_split shuffles itself; seeding it makes each
        # run's split reproducible and distinct per seed without mutating `train_dataset`.
        splits = train_dataset["train"].train_test_split(test_size=0.4, seed=seed)
        train_ds = splits["train"]
        val_ds = splits["test"]

        # Class ids in each split (printed before the transforms are attached).
        print(train_ds["label"])
        print(val_ds["label"])

        train_ds.set_transform(preprocess_train)
        val_ds.set_transform(preprocess_val)

        model = AutoModelForImageClassification.from_pretrained(
            model_checkpoint,
            label2id=label2id,
            id2label=id2label,
            ignore_mismatched_sizes=True,  # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
        )

        args = TrainingArguments(
            OUTPUT_DIR + f"/checkpoints/{model_name}-finetuned-{PLATE_NAME}-{seed}",
            remove_unused_columns=False,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            save_total_limit=2,
            learning_rate=5e-5,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=4,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=10,
            warmup_ratio=0.1,
            logging_steps=10,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            push_to_hub=False,
            # Without this every run used TrainingArguments' default seed (42) for the Trainer.
            seed=seed,
        )

        trainer = Trainer(
            model,
            args,
            train_dataset=train_ds,
            eval_dataset=val_ds,
            tokenizer=image_processor,
            compute_metrics=compute_metrics,
            data_collator=collate_fn,
        )

        train_results = trainer.train()
        # Writes the (best, since load_best_model_at_end=True) model to args.output_dir,
        # the same path the evaluation cell reloads from.
        trainer.save_model()
        trainer.log_metrics("train", train_results.metrics)
        trainer.save_metrics("train", train_results.metrics)
        trainer.save_state()

        val_metrics = trainer.evaluate()
        trainer.log_metrics("eval", val_metrics)

        # Test-set metrics get a "test_" key prefix so they cannot be confused with validation metrics.
        test_metrics = trainer.evaluate(test_dataset, metric_key_prefix="test")
        trainer.log_metrics("test", test_metrics)

In [6]:
# for each seed, load the best model and evaluate on the test set
test_dataset = load_dataset("imagefolder", data_dir=TEST_DIR)
test_dataset.set_transform(preprocess_val)
test_dataset = test_dataset["train"]

seed_results = {}

# Iterate over each seed, load the corresponding model and evaluate it on the test dataset
for seed in [0, 1, 2, 3, 4]:
    model_path = OUTPUT_DIR + f"/checkpoints/{model_checkpoint.split('/')[-1]}-finetuned-{PLATE_NAME}-{seed}"
    model = AutoModelForImageClassification.from_pretrained(model_path)
    trainer = Trainer(
        model=model,
        args=TrainingArguments(
            output_dir="./results",  # Adjust this to wherever your models are stored
            remove_unused_columns=False,
            per_device_eval_batch_size=batch_size
        ),
        compute_metrics=compute_metrics,
        data_collator=collate_fn,
    )
    print(f"Evaluating the model for seed {seed}...")
    metrics = trainer.evaluate(test_dataset)
    seed_results[seed] = metrics
    print(f"Results for seed {seed}: {metrics}")

for seed, results in seed_results.items():
    print(f"Seed {seed}: {results}")

average_metric = np.mean([result["eval_accuracy"] for result in seed_results.values()])
print(f"Average accuracy over all seeds: {average_metric}")

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Evaluating the model for seed 0...


100%|██████████| 2/2 [00:00<00:00, 28.33it/s]
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Results for seed 0: {'eval_loss': 0.5586597323417664, 'eval_accuracy': 0.8461538461538461, 'eval_runtime': 0.6005, 'eval_samples_per_second': 86.598, 'eval_steps_per_second': 3.331}
Evaluating the model for seed 1...


100%|██████████| 2/2 [00:00<00:00, 27.32it/s]
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Results for seed 1: {'eval_loss': 0.7040219902992249, 'eval_accuracy': 0.5, 'eval_runtime': 0.4736, 'eval_samples_per_second': 109.799, 'eval_steps_per_second': 4.223}
Evaluating the model for seed 2...


100%|██████████| 2/2 [00:00<00:00, 27.79it/s]
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Results for seed 2: {'eval_loss': 0.5732942819595337, 'eval_accuracy': 0.7884615384615384, 'eval_runtime': 0.4732, 'eval_samples_per_second': 109.893, 'eval_steps_per_second': 4.227}
Evaluating the model for seed 3...


100%|██████████| 2/2 [00:00<00:00, 26.00it/s]
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Results for seed 3: {'eval_loss': 0.5847002863883972, 'eval_accuracy': 0.7884615384615384, 'eval_runtime': 0.5042, 'eval_samples_per_second': 103.139, 'eval_steps_per_second': 3.967}
Evaluating the model for seed 4...


100%|██████████| 2/2 [00:00<00:00, 28.89it/s]

Results for seed 4: {'eval_loss': 0.5641040802001953, 'eval_accuracy': 0.75, 'eval_runtime': 0.4927, 'eval_samples_per_second': 105.545, 'eval_steps_per_second': 4.059}
Seed 0: {'eval_loss': 0.5586597323417664, 'eval_accuracy': 0.8461538461538461, 'eval_runtime': 0.6005, 'eval_samples_per_second': 86.598, 'eval_steps_per_second': 3.331}
Seed 1: {'eval_loss': 0.7040219902992249, 'eval_accuracy': 0.5, 'eval_runtime': 0.4736, 'eval_samples_per_second': 109.799, 'eval_steps_per_second': 4.223}
Seed 2: {'eval_loss': 0.5732942819595337, 'eval_accuracy': 0.7884615384615384, 'eval_runtime': 0.4732, 'eval_samples_per_second': 109.893, 'eval_steps_per_second': 4.227}
Seed 3: {'eval_loss': 0.5847002863883972, 'eval_accuracy': 0.7884615384615384, 'eval_runtime': 0.5042, 'eval_samples_per_second': 103.139, 'eval_steps_per_second': 3.967}
Seed 4: {'eval_loss': 0.5641040802001953, 'eval_accuracy': 0.75, 'eval_runtime': 0.4927, 'eval_samples_per_second': 105.545, 'eval_steps_per_second': 4.059}
Averag


