In [1]:
!pip install torch torchvision torchaudio
!pip install transformers[torch]
!pip install accelerate -U
!pip install datasets
!pip install numpy
!pip install tqdm

Collecting accelerate
  Downloading accelerate-0.32.1-py3-none-any.whl.metadata (18 kB)
Downloading accelerate-0.32.1-py3-none-any.whl (314 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m314.1/314.1 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: accelerate
  Attempting uninstall: accelerate
    Found existing installation: accelerate 0.30.1
    Uninstalling accelerate-0.30.1:
      Successfully uninstalled accelerate-0.30.1
Successfully installed accelerate-0.32.1


In [2]:
import numpy as np
from transformers import ViTFeatureExtractor, ViTForImageClassification, TrainingArguments, Trainer
from datasets import load_dataset, load_metric
import torch
from torch.utils.data import DataLoader
from PIL import Image
from tqdm import tqdm

# Check and print the available device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Model parameters
model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)

# Load the dataset
dataset = load_dataset("zh-plus/tiny-imagenet")

# Preprocessing function to be applied during batching
def preprocess_function(images):
    inputs = feature_extractor(images=[image.convert('RGB') for image in images], return_tensors='pt')
    return inputs

# Custom dataset class to handle image preprocessing on-the-fly
class TinyImageNetDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, feature_extractor):
        self.dataset = dataset
        self.feature_extractor = feature_extractor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        example = self.dataset[idx]
        image = example['image']
        label = example['label']
        inputs = preprocess_function([image])
        pixel_values = inputs['pixel_values'].squeeze()
        return {'pixel_values': pixel_values, 'labels': label}

train_dataset = TinyImageNetDataset(dataset['train'], feature_extractor)
valid_dataset = TinyImageNetDataset(dataset['valid'], feature_extractor)

# DataLoaders for efficient batching and GPU utilization
train_dataloader = DataLoader(train_dataset, batch_size=20, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=20)

# Load metric
metric = load_metric("accuracy")

def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

# Prepare labels for the model
labels = dataset['train'].features['label'].names

model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

# Move the model to the GPU if available
model.to(device)

# Define training arguments
training_args = TrainingArguments(
    output_dir="./vit-base-tiny-imagenet-demo",
    per_device_train_batch_size=20,
    evaluation_strategy="steps",
    num_train_epochs=4,
    fp16=True,
    save_steps=500,
    eval_steps=100,
    logging_steps=10,
    learning_rate=5e-5,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
)

# Custom collate function to handle the batch data correctly
def collate_fn(batch):
    pixel_values = torch.stack([item['pixel_values'] for item in batch])
    labels = torch.tensor([item['labels'] for item in batch])
    return {'pixel_values': pixel_values, 'labels': labels}

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=feature_extractor,
)

# Train the model
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

# Evaluate the model
eval_results = trainer.evaluate()
print(f"Eval Loss: {eval_results['eval_loss']}, Eval Accuracy: {eval_results['eval_accuracy']}")


2024-07-21 21:13:24.170411: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-21 21:13:24.170521: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-21 21:13:24.451689: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Using device: cuda


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]



Downloading readme:   0%|          | 0.00/3.90k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/3.52k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/146M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/14.6M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/100000 [00:00<?, ? examples/s]

Generating valid split:   0%|          | 0/10000 [00:00<?, ? examples/s]

  metric = load_metric("accuracy")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


Downloading builder script:   0%|          | 0.00/1.65k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Accuracy
100,5.0641,5.049663,0.3315
200,4.7494,4.708425,0.6351
300,4.3837,4.341029,0.7232
400,4.0054,3.996261,0.776
500,3.691,3.675108,0.7849
600,3.3519,3.383193,0.7926
700,3.0736,3.113444,0.8077
800,2.9186,2.846425,0.8212
900,2.6339,2.610291,0.8237
1000,2.4865,2.400941,0.8196




***** train metrics *****
  epoch                    =           4.0
  total_flos               = 28919245220GF
  train_loss               =        0.7519
  train_runtime            =    4:57:21.49
  train_samples_per_second =         22.42
  train_steps_per_second   =          0.56




Eval Loss: 0.5175029039382935, Eval Accuracy: 0.8859
