In [2]:
!pip install datasets transformers


Collecting datasets
  Downloading datasets-3.0.0-py3-none-any.whl.metadata (19 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.0.0-py3-none-any.whl (474 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.3/474.3 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)
[2K  

In [None]:
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from datasets import load_dataset
from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTImageProcessor

# Step 0: Seed torch's global RNG so the freshly initialised classifier head and the
# DataLoader shuffling order are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Step 1: Load CIFAR-10 first so the model's classification head can be sized from its labels
dataset = load_dataset("cifar10")
label_names = dataset["train"].features["label"].names

# Step 2: Load the pre-trained Vision Transformer and its image processor.
# The checkpoint ships an ImageNet-1k (1000-way) head; CIFAR-10 has only len(label_names)
# classes, so replace the head with a correctly sized one. ignore_mismatched_sizes=True lets
# from_pretrained drop the incompatible head weights and initialise a new head instead of failing.
model_name = "google/vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=len(label_names),
    id2label=dict(enumerate(label_names)),
    label2id={name: idx for idx, name in enumerate(label_names)},
    ignore_mismatched_sizes=True,
)
# ViTImageProcessor supersedes the deprecated ViTFeatureExtractor; the variable name is kept
# because the preprocessing code below refers to `feature_extractor`.
feature_extractor = ViTImageProcessor.from_pretrained(model_name)

# Step 3: Preprocess images to make them compatible with the Vision Transformer
def transform_images(batch):
    """Attach ViT-ready ``pixel_values`` to a batch of examples.

    ``batch`` is a column dict whose ``'img'`` entry holds image objects exposing
    ``.convert`` (PIL-style). Every image is converted to RGB, the whole list is
    encoded by the notebook-level ``feature_extractor`` in a single call, and the
    resulting ``pixel_values`` tensor is written back onto ``batch``, which is returned.
    """
    rgb_images = list(map(lambda picture: picture.convert("RGB"), batch['img']))
    encoded = feature_extractor(images=rgb_images, return_tensors='pt')
    batch['pixel_values'] = encoded['pixel_values']
    return batch

# Register transform_images as an on-the-fly transform on every split of the DatasetDict:
# it runs whenever examples are accessed rather than precomputing pixel values up front.
dataset.set_transform(transform=transform_images)

# Step 4: Custom collate function to handle batches properly
def custom_collate_fn(batch):
    """Merge a list of per-example dicts into one batch dict.

    Each example must provide a ``'pixel_values'`` tensor (all the same shape) and an
    integer-like ``'label'``. Returns ``{'pixel_values': stacked tensor with a new
    leading batch dimension, 'label': 1-D tensor of labels}``.
    """
    stacked_pixels = torch.stack([example['pixel_values'] for example in batch])
    label_tensor = torch.tensor([example['label'] for example in batch])
    return dict(pixel_values=stacked_pixels, label=label_tensor)

# Step 5: Training configuration — tunable knobs collected in one place instead of inline magic numbers
BATCH_SIZE = 32
LEARNING_RATE = 5e-5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cpu":
    # Fine-tuning ViT-Base on CPU takes on the order of a minute per batch, i.e. many hours
    # per epoch; in Colab switch the runtime to a GPU before running the training loop.
    print("WARNING: CUDA is not available; training will run on CPU and be very slow.")

# Step 6: DataLoaders for training and testing with the custom collate function.
# pin_memory speeds up host-to-GPU copies and is only useful when a GPU is present.
use_pinned_memory = device.type == "cuda"
train_loader = DataLoader(
    dataset['train'],
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=custom_collate_fn,
    pin_memory=use_pinned_memory,
)
test_loader = DataLoader(
    dataset['test'],
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=custom_collate_fn,
    pin_memory=use_pinned_memory,
)

model.to(device)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

# Training function
def train(model, dataloader, optimizer=None, device=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: classifier whose forward accepts ``pixel_values`` and ``labels`` keyword
            arguments and returns an object with a ``.loss`` tensor.
        dataloader: sized iterable of dicts holding ``'pixel_values'`` and ``'label'``
            tensors (e.g. the DataLoaders built with ``custom_collate_fn``).
        optimizer: optimizer to step after each batch. Defaults to the notebook-level
            ``optimizer``, so the original call ``train(model, loader)`` keeps working.
        device: device each batch is moved to. Defaults to the device of the model's
            parameters, so batches always land wherever the model lives.

    Returns:
        float: sum of batch losses divided by the number of batches.

    Raises:
        ValueError: if ``dataloader`` contains no batches.
    """
    if optimizer is None:
        # Fall back to the optimizer created in the setup cell (notebook globals).
        optimizer = globals()["optimizer"]
    if device is None:
        device = next(model.parameters()).device

    num_batches = len(dataloader)
    if num_batches == 0:
        # Fail clearly instead of with a ZeroDivisionError at the end of the epoch.
        raise ValueError("train() received an empty dataloader")

    model.train()  # Set model to training mode
    total_loss = 0.0

    for batch in tqdm(dataloader, desc="train"):
        pixel_values = batch['pixel_values'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()  # Clear gradients left over from the previous step
        outputs = model(pixel_values=pixel_values, labels=labels)  # Forward pass
        loss = outputs.loss  # The model computes the loss itself when labels are passed
        loss.backward()  # Backpropagate the loss
        optimizer.step()  # Update model parameters

        total_loss += loss.item()

    return total_loss / num_batches

# Training loop for multiple epochs
NUM_EPOCHS = 3  # Number of passes over the training set; adjust as needed
epoch_losses = []  # Mean training loss per epoch, kept for later inspection/plotting
for epoch in range(NUM_EPOCHS):
    loss = train(model, train_loader)
    epoch_losses.append(loss)
    print(f"Epoch {epoch+1}, Loss: {loss}")

# Step 7: Evaluation function
def evaluate(model, dataloader, device=None):
    """Compute top-1 accuracy of ``model`` over ``dataloader``.

    Args:
        model: classifier whose forward accepts ``pixel_values`` and returns an object
            with a ``.logits`` tensor whose last dimension indexes classes.
        dataloader: iterable of dicts holding ``'pixel_values'`` and ``'label'`` tensors.
        device: device each batch is moved to. Defaults to the device of the model's
            parameters instead of relying on a notebook-level ``device`` variable.

    Returns:
        float: fraction of examples whose argmax prediction equals the label.

    Raises:
        ValueError: if ``dataloader`` yields no examples.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()  # Set model to evaluation mode
    correct = 0
    total = 0

    with torch.no_grad():  # No gradients are needed for evaluation
        for batch in tqdm(dataloader, desc="eval"):
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['label'].to(device)
            logits = model(pixel_values=pixel_values).logits
            preds = logits.argmax(dim=-1)  # Predicted class index per example
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        # Fail clearly instead of with a ZeroDivisionError.
        raise ValueError("evaluate() received an empty dataloader")
    return correct / total

# Evaluate the model on the test dataset
accuracy = evaluate(model, test_loader)
print(f"Test Accuracy: {accuracy * 100:.2f}%")

# Step 8: Save the trained model together with its image preprocessor so that the
# directory is self-contained: both can later be restored with from_pretrained(OUTPUT_DIR).
OUTPUT_DIR = "./vit_model"
model.save_pretrained(OUTPUT_DIR)
feature_extractor.save_pretrained(OUTPUT_DIR)
print("Model saved successfully!")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]



README.md:   0%|          | 0.00/5.16k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/120M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/23.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

  7%|▋         | 113/1563 [1:59:05<25:20:25, 62.91s/it]