In [1]:
import torch
from transformers import ViTForImageClassification,ViTFeatureExtractor,Trainer,TrainingArguments
from datasets import load_dataset
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=10)
model.to(device)
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')

  from .autonotebook import tqdm as notebook_tqdm
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Load CIFAR-10 dataset
dataset = load_dataset("cifar10")

# Define image transformation
def transform(example_batch):
    transforms = Compose([
        Resize((224, 224)),  # ViT expects 224x224 images
        ToTensor(),
        Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)  # Apply ViT specific normalization
    ])
    example_batch['pixel_values'] = [transforms(img) for img in example_batch]
    return example_batch

# Apply the transformation to the dataset
prepared_dataset = dataset.with_transform(transform)

# Create PyTorch Dataloaders
train_loader = DataLoader(prepared_dataset['train'], batch_size=32, shuffle=True)
test_loader = DataLoader(prepared_dataset['test'], batch_size=32)

## Fine-tuning the model

In [6]:
#Defining training arguments
training_arguments = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=10,
    load_best_model_at_end=True
)

#Defining evaluation metric
from sklearn.metrics import accuracy_score

def compute_evaluation_metric(prediction):
    labels = prediction.label_ids
    preds = prediction.predictions.argmax(-1)
    acc = accuracy_score(labels, preds)
    return {"accuracy" : acc}

#Defining trainer 
trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=prepared_dataset["train"],
    eval_dataset=prepared_dataset["test"],
    tokenizer=feature_extractor,
    compute_metrics=compute_evaluation_metric
)

#finetuning with the trainer
trainer.train()

  trainer = Trainer(


NameError: name 'img' is not defined