
### Model description
The notebook references the following:
- [ViT (Vision Transformer) model by Google](https://huggingface.co/google/vit-base-patch16-224)
- [Dataset for the project](https://github.com/Alib)


### Install necessary libraries
This code block installs all the required libraries for running the notebook. These libraries include:
- `transformers` for downloading pre-trained models and their preprocessors (tokenizers, image processors) from Hugging Face.
- `torch` (PyTorch) for building and training neural networks.
- `torchvision` for image processing utilities and datasets.
- `datasets` for loading, splitting, and transforming datasets.

Libraries:
- `transformers`: A Hugging Face library of state-of-the-art pre-trained models, originally focused on NLP and now also covering vision models such as ViT. [Documentation](https://huggingface.co/transformers/)
- `torch`: PyTorch, an open-source machine learning library for Python, used for applications such as computer vision and natural language processing. [Documentation](https://pytorch.org/docs/stable/index.html)
- `torchvision`: A library with tools and datasets for computer vision tasks. [Documentation](https://pytorch.org/vision/stable/index.html)
- `datasets`: A library by Hugging Face to easily share and use datasets. [Documentation](https://huggingface.co/docs/datasets/)

In [1]:
%pip install transformers torch torchvision datasets

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Note: you may need to restart the kernel to use updated packages.
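Several APIs used later in this notebook have changed across releases (for example, `ViTFeatureExtractor` and `transformers.AdamW` are deprecated in recent versions), so it can help to record which versions were actually installed. A minimal sketch using only the standard library's `importlib.metadata`; the helper name `installed_versions` is hypothetical:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return a {package: version} mapping; packages that are not installed map to None."""
    versions = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


# Print the versions of the libraries installed above
for name, ver in installed_versions(["transformers", "torch", "torchvision", "datasets"]).items():
    print(f"{name}: {ver or 'not installed'}")
```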


### Import libraries and set up data transformations
This block imports the necessary libraries and prepares the data for the model. It performs the following tasks:
- Imports the libraries for handling data and models.
- Loads the image dataset and splits it into training (80%) and evaluation (20%) sets.
- Defines a transform that converts each image to RGB and preprocesses it into the format the model expects.
- Sets up data loaders that iterate over the datasets in batches during training and evaluation.

Libraries:
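- `torch`: Core PyTorch library (tensors, automatic differentiation, devices).
- `torch.utils.data.DataLoader`: Utility to load data in batches.
- `torchvision.datasets`: Contains many standard vision datasets (imported but not used in this notebook).
- `torchvision.transforms`: Common image transformations (imported but not used in this notebook).
- `datasets.load_dataset`: Function to load datasets from the Hugging Face Hub or from local files.
- `ViTFeatureExtractor`: Preprocessor for ViT models (resizing and normalization). Recent versions of `transformers` deprecate it in favor of `ViTImageProcessor`.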

In [5]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from datasets import load_dataset
from transformers import ViTFeatureExtractor, ViTForImageClassification, TrainingArguments, Trainer

# Load dataset
dataset = load_dataset('training', data_dir='data')

# Split the dataset into training and validation sets
train_test_split = dataset['train'].train_test_split(test_size=0.2)
train_dataset = train_test_split['train']
eval_dataset = train_test_split['test']

# Inspect the raw data: PIL images with integer labels
print(dataset)
print(f"train_dataset[0] {train_dataset[0]}")
print(f"train_dataset[1] {train_dataset[1]}")
print(f"train_dataset[2] {train_dataset[2]}")
print(f"eval_dataset[0] {eval_dataset[0]}")
print(f"eval_dataset[1] {eval_dataset[1]}")

# Define the feature extractor
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')

# Transform function
def transform(example_batch):
    # Convert images to RGB and apply feature extraction
    images = [image.convert("RGB") for image in example_batch['image']]
    inputs = feature_extractor(images, return_tensors='pt')
    inputs['labels'] = example_batch['label']
    return inputs

# Apply transformations to datasets
train_dataset.set_transform(transform)
eval_dataset.set_transform(transform)

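# Collate function: stack per-example tensors into a batch.
# With set_transform, indexing returns one example at a time, so each
# item['pixel_values'] already has shape (3, 224, 224) and needs no squeeze.
def collate_fn(batch):
    pixel_values = torch.stack([item['pixel_values'] for item in batch])
    labels = torch.tensor([item['labels'] for item in batch])
    return {
        'pixel_values': pixel_values,
        'labels': labels
    }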

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
eval_loader = DataLoader(eval_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)

# Inspect the same samples after set_transform: images are now normalized pixel_values tensors
print(dataset)
print(f"train_dataset[0] {train_dataset[0]}")
print(f"train_dataset[1] {train_dataset[1]}")
print(f"train_dataset[2] {train_dataset[2]}")
print(f"eval_dataset[0] {eval_dataset[0]}")
print(f"eval_dataset[1] {eval_dataset[1]}")

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 10
    })
})
train_dataset[0] {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=250x333 at 0x7F6DF6F6AE90>, 'label': 1}
train_dataset[1] {'image': <PIL.WebPImagePlugin.WebPImageFile image mode=RGB size=640x853 at 0x7F6DF7BE3AD0>, 'label': 1}
train_dataset[2] {'image': <PIL.WebPImagePlugin.WebPImageFile image mode=RGB size=399x266 at 0x7F6DF7BB7890>, 'label': 1}
eval_dataset[0] {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=570x400 at 0x7F6E0C0BC290>, 'label': 0}
eval_dataset[1] {'image': <PIL.WebPImagePlugin.WebPImageFile image mode=RGB size=1000x667 at 0x7F6E0CDE79D0>, 'label': 1}
DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 10
    })
})
train_dataset[0] {'pixel_values': tensor([[[-0.9294, -0.9294, -0.9294,  ..., -0.5216, -0.5216, -0.5294],
         [-0.9294, -0.9294, -0.9294,  ..., -0.4980, -0.4980, -0.4980],
       



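The `pixel_values` printed above fall in the range [-1, 1]. That follows from the preprocessing configured for `google/vit-base-patch16-224`: images are resized to 224x224, 8-bit pixel values are rescaled by 1/255, and each channel is normalized with mean 0.5 and standard deviation 0.5. A worked sketch of the per-pixel arithmetic in plain Python (the function name `vit_normalize` is hypothetical); a pixel value of 9, for example, maps to the -0.9294 seen in the output:

```python
def vit_normalize(pixel, mean=0.5, std=0.5, rescale_factor=1 / 255):
    """Map an 8-bit pixel value (0-255) to the normalized value fed to ViT."""
    return (pixel * rescale_factor - mean) / std


# Pixel values 0, 9, 128 and 255 map to about -1.0, -0.9294, 0.0039 and 1.0
for pixel in (0, 9, 128, 255):
    print(pixel, round(vit_normalize(pixel), 4))
```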
### Load the ViT model
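This block loads the pre-trained Vision Transformer (ViT) model from Hugging Face and replaces its classification head with one that has the specified number of output labels (2: empty and full). It also defines a `TrainingArguments` configuration.

Libraries:
- `ViTForImageClassification`: Vision Transformer model with an image classification head.
- `from_pretrained`: Loads a pre-trained model from Hugging Face's model hub.
- `num_labels`: Number of output labels.
- `ignore_mismatched_sizes`: Allows loading checkpoint weights whose shapes differ from the new model; here it lets the 1000-class ImageNet head be replaced by a newly initialized 2-class head.
- `TrainingArguments`: Training configuration for the Hugging Face `Trainer`. It is not used by the manual training loop below.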

In [6]:
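model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', num_labels=2, ignore_mismatched_sizes=True)

# Training configuration for the Hugging Face Trainer. The manual training loop
# below does not use it. Older transformers releases call `eval_strategy`
# `evaluation_strategy` instead.
training_args = TrainingArguments(
    output_dir='./results',
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_dir='./logs',
)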

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
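The shapes in the warning above follow from ViT-Base's 768-dimensional hidden size: the classifier is a single linear layer from the [CLS] token's 768 features to one logit per class. A plain-Python sketch of the related arithmetic for 16x16 patches at 224x224 resolution (helper names are hypothetical):

```python
def vit_sequence_length(image_size=224, patch_size=16):
    """Number of tokens ViT feeds to the encoder: one per patch plus the [CLS] token."""
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + 1


def linear_head_params(hidden_size, num_labels):
    """Parameter count of a linear classifier: weight (num_labels x hidden_size) plus bias."""
    return num_labels * hidden_size + num_labels


print(vit_sequence_length())          # 197 = 14 * 14 patches + 1 [CLS] token
print(linear_head_params(768, 1000))  # 769000 for the original ImageNet head
print(linear_head_params(768, 2))     # 1538 for the new empty/full head
```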


### Set up the optimizer and learning rate scheduler, and train the model
This block sets up the optimizer and learning rate scheduler, then fine-tunes the model. The optimizer updates the model parameters to minimize the loss, and the scheduler decays the learning rate linearly from its initial value (5e-5) to zero over the training steps. Each step moves a batch to the model's device, runs a forward pass, backpropagates the loss, and updates the weights.

Libraries:
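- `AdamW`: Adam with decoupled weight decay, the optimizer commonly used to fine-tune transformers. The `AdamW` class in `transformers` is deprecated in recent versions; `torch.optim.AdamW` is the replacement.
- `get_scheduler`: Utility that creates a learning rate scheduler by name (here `"linear"`, with no warmup steps).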
- `tqdm.auto.tqdm`: Progress bar library.
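With the 10-image dataset shown earlier, `test_size=0.2` leaves 8 training images, so a batch size of 8 gives one batch per epoch and `num_training_steps` is 3, which matches the `0/3` progress bar. The sketch below reproduces, in plain Python, the multiplier that a linear schedule with warmup applies to the base learning rate, following the behavior described in the `transformers` documentation for `get_linear_schedule_with_warmup`. The function `linear_lr_multiplier` is a hypothetical illustration, not the library code:

```python
import math


def linear_lr_multiplier(step, num_warmup_steps, num_training_steps):
    """Linear warmup from 0 to 1, then linear decay from 1 to 0 at num_training_steps."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))


# Values used in this notebook (10 images, 80/20 split, batch size 8, 3 epochs)
num_train_images = 10 - math.ceil(0.2 * 10)        # 8 training images
steps_per_epoch = math.ceil(num_train_images / 8)   # 1 batch per epoch
num_training_steps = 3 * steps_per_epoch            # 3 optimizer steps in total

base_lr = 5e-5
for step in range(num_training_steps):
    lr = base_lr * linear_lr_multiplier(step, num_warmup_steps=0, num_training_steps=num_training_steps)
    print(f"optimizer step {step + 1}: lr = {lr:.2e}")
```

Because the scheduler is stepped after each optimizer update, the three updates use learning rates of about 5.00e-05, 3.33e-05 and 1.67e-05.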

In [7]:
from torch.optim import AdamW  # transformers.AdamW is deprecated in recent releases
from transformers import get_scheduler
from tqdm.auto import tqdm

# Train on a GPU if one is available; each batch is moved to the same device below
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=5e-5)

num_epochs = 3
num_training_steps = num_epochs * len(train_loader)
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

progress_bar = tqdm(range(num_training_steps))

model.train()
for epoch in range(num_epochs):
    for batch in train_loader:
        # Move batch to the same device as the model
        batch = {k: v.to(device) for k, v in batch.items()}

        # Forward pass; because 'labels' is in the batch, the model also returns the cross-entropy loss
        outputs = model(**batch)
        loss = outputs.loss

        # Backward pass
        loss.backward()

        # Update weights, advance the learning rate schedule, and reset gradients
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        # Advance the progress bar and display the current loss
        progress_bar.update(1)
        progress_bar.set_postfix({"loss": loss.item()})



  0%|          | 0/3 [00:00<?, ?it/s]


### Evaluation loop
This block sets the model to evaluation mode, runs it over the evaluation set, and computes accuracy by comparing each predicted class (the index of the highest logit) with the true label. Gradient calculation is disabled to speed up the process and reduce memory usage.

Libraries:
- `model.eval()`: Sets the model to evaluation mode.
- `torch.no_grad()`: Disables gradient calculation for faster evaluation.
- `batch.items()`: Iterates over the items in a batch.
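The accuracy computation reduces to taking the index of the highest logit in each row and counting matches with the true labels. A plain-Python sketch with illustrative logits (made-up values, not model outputs; helper names are hypothetical). With only 2 evaluation images (20% of the 10-image dataset), the only possible accuracies are 0%, 50%, and 100%, so the 100.00% reported below is not a reliable estimate of performance.

```python
def argmax(values):
    """Index of the largest value (the first one if there are ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(logits_batch, labels):
    """Percentage of rows whose highest logit matches the true label."""
    predictions = [argmax(row) for row in logits_batch]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return 100.0 * correct / len(labels)


# Illustrative logits for two images over classes 0 and 1
example_logits = [[2.1, -1.3], [-0.4, 0.9]]
print(f"Accuracy: {accuracy(example_logits, [0, 1]):.2f}%")  # Accuracy: 100.00%
```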

In [10]:
model.eval()
total_correct = 0
total_samples = 0

for batch in eval_loader:
    with torch.no_grad():
        # Move batch to device (GPU or CPU)
        batch = {k: v.to(model.device) for k, v in batch.items()}

        # Forward pass
        outputs = model(**batch)

        # Access logits
        logits = outputs.logits

        # Get predictions
        predictions = torch.argmax(logits, dim=-1)

        # Get true labels (assuming they are in 'labels' key of batch)
        labels = batch['labels']

        # Count correct predictions
        correct = (predictions == labels).sum().item()

        # Update totals
        total_correct += correct
        total_samples += labels.size(0)

# Calculate accuracy
accuracy = (total_correct / total_samples) * 100
print(f"Accuracy: {accuracy:.2f}%")

Accuracy: 100.00%


### Save the model and feature extractor
This block saves the trained model and the feature extractor to disk so that they can be loaded and used later without retraining.

Libraries:
- `save_pretrained`: Saves the model and feature extractor for later use.

In [11]:
model.save_pretrained('models/cat_bowl_model')
feature_extractor.save_pretrained('models/cat_bowl_model')

['models/cat_bowl_model/preprocessor_config.json']

### Load and test the model with an image
This block loads the saved model and image processor, preprocesses each test image, and prints the predicted label (empty or full) for each of the 14 images in `training/test`.

Libraries:
- `PIL.Image`: Python Imaging Library, used to open and manipulate images.
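- `ViTImageProcessor`: Preprocesses images for the ViT model; it is the current replacement for `ViTFeatureExtractor` and reads the `preprocessor_config.json` saved above.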
- `model(**inputs).logits`: Forward pass to get logits (raw model outputs).
- `logits.argmax(-1)`: Gets the index of the highest logit, representing the predicted class.
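`predict` returns only the winning class index. If a confidence score is also wanted, the logits can be turned into probabilities with a softmax; in the notebook this would be `torch.softmax(logits, dim=-1)`. A dependency-free sketch of the same computation (the function name and logits are hypothetical):

```python
import math


def softmax(logits):
    """Convert raw logits to probabilities that sum to 1 (shifted by the max for numerical stability)."""
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


class_labels = ["empty", "full"]
probs = softmax([1.5, -0.5])  # illustrative logits, not real model output
best = max(range(len(probs)), key=probs.__getitem__)
print(f"{class_labels[best]} ({probs[best]:.1%})")  # empty (88.1%)
```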

In [12]:
import torch
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification

# Load the fine-tuned model and image processor
model = ViTForImageClassification.from_pretrained('models/cat_bowl_model')
processor = ViTImageProcessor.from_pretrained('models/cat_bowl_model')

# Function to load and preprocess the image
def load_and_preprocess_image(image_path):
    # Convert to RGB, matching the transform used during training
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    return inputs

# Function to predict if the bowl is full or empty
def predict(image_path):
    inputs = load_and_preprocess_image(image_path)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}  # Move to device
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    return predicted_class_idx

# Class labels, assumed to match the label ids used during training (0 = empty, 1 = full)
class_labels = ["empty", "full"]

# Main function for CLI: classify each test image and print the result
def main():
    test_images = [f"training/test/empty_{i:02d}.jpg" for i in range(1, 8)]
    test_images += [f"training/test/full_{i:02d}.jpg" for i in range(1, 8)]
    for image_path in test_images:
        predicted_class_idx = predict(image_path)
        print(f"The bowl in {image_path} is {class_labels[predicted_class_idx]}.")

if __name__ == "__main__":
    main()

The bowl in training/test/empty_01.jpg is empty.
The bowl in training/test/empty_02.jpg is empty.
The bowl in training/test/empty_03.jpg is empty.
The bowl in training/test/empty_04.jpg is empty.
The bowl in training/test/empty_05.jpg is empty.
The bowl in training/test/empty_06.jpg is full.
The bowl in training/test/empty_07.jpg is full.
The bowl in training/test/full_01.jpg is full.
The bowl in training/test/full_02.jpg is full.
The bowl in training/test/full_03.jpg is full.
The bowl in training/test/full_04.jpg is full.
The bowl in training/test/full_05.jpg is full.
The bowl in training/test/full_06.jpg is full.
The bowl in training/test/full_07.jpg is full.
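Because each test filename starts with its true class (`empty_` or `full_`), the printed predictions can be scored automatically. A minimal sketch, assuming that naming convention holds for every file; the helpers `true_label_from_filename` and `score` are hypothetical. Applied to the 14 predictions above, it counts 12 correct, since `empty_06.jpg` and `empty_07.jpg` were predicted as full.

```python
from pathlib import Path


def true_label_from_filename(image_path):
    """Infer the true class from a name such as 'empty_01.jpg' -> 'empty'."""
    return Path(image_path).stem.split("_")[0]


def score(predictions):
    """Return (correct, total) for a {image_path: predicted_label} mapping."""
    correct = sum(true_label_from_filename(path) == label for path, label in predictions.items())
    return correct, len(predictions)


# Predictions copied from the notebook output above
predictions = {f"training/test/empty_{i:02d}.jpg": "empty" for i in range(1, 6)}
predictions.update({f"training/test/empty_{i:02d}.jpg": "full" for i in (6, 7)})
predictions.update({f"training/test/full_{i:02d}.jpg": "full" for i in range(1, 8)})

correct, total = score(predictions)
print(f"{correct}/{total} correct ({correct / total:.1%})")  # 12/14 correct (85.7%)
```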
