In [None]:
!pip install -q transformers timm torchvision datasets

from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm
from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTImageProcessor


In [None]:
# Load the preprocessing configuration that ships with the pretrained
# checkpoint; the transforms read its `image_mean` / `image_std`.
# ViTImageProcessor replaces the deprecated ViTFeatureExtractor. The variable
# keeps its original name because other cells reference it.
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

# Root folder of the dataset on the mounted Drive.
DTD_PATH = "/content/drive/MyDrive/dtd"


In [None]:

    train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.RandomRotation(5),
    transforms.ToTensor(),
    transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
])

test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
])




In [None]:

# Two views over the same image folder: one applies the augmenting training
# transform, the other the deterministic evaluation transform.
#
# They have to be separate dataset objects. random_split returns Subset
# wrappers that all point at the *same* underlying dataset, so assigning
# `val_dataset.dataset.transform = test_transform` would also switch the
# training subset to the evaluation transform and silently disable augmentation.
images_root = os.path.join(DTD_PATH, "images")
full_dataset = ImageFolder(root=images_root, transform=train_transform)
eval_dataset = ImageFolder(root=images_root, transform=test_transform)

# Both views must enumerate the files identically for shared indices to be valid.
assert eval_dataset.samples == full_dataset.samples, "dataset views are out of sync"

# 70 / 15 / 15 split; the test split absorbs the rounding remainder.
total_size = len(full_dataset)
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
test_size = total_size - train_size - val_size

# A seeded generator makes the split identical on every run.
SPLIT_SEED = 42
train_dataset, val_split, test_split = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Re-point the validation / test indices at the evaluation view.
val_dataset = torch.utils.data.Subset(eval_dataset, val_split.indices)
test_dataset = torch.utils.data.Subset(eval_dataset, test_split.indices)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)

print("Classes:", full_dataset.classes)


Classes: ['banded', 'blotchy', 'braided', 'bubbly', 'bumpy', 'chequered', 'cobwebbed', 'cracked', 'crosshatched', 'crystalline', 'dotted', 'fibrous', 'flecked', 'freckled', 'frilly', 'gauzy', 'grid', 'grooved', 'honeycombed', 'interlaced', 'knitted', 'lacelike', 'lined', 'marbled', 'matted', 'meshed', 'paisley', 'perforated', 'pitted', 'pleated', 'polka-dotted', 'porous', 'potholed', 'scaly', 'smeared', 'spiralled', 'sprinkled', 'stained', 'stratified', 'striped', 'studded', 'swirly', 'veined', 'waffled', 'woven', 'wrinkled', 'zigzagged']


In [None]:
# Use the GPU when one is available. `device` is read by the model, training
# and evaluation code, so it has to exist before any of them run.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Start from the pretrained checkpoint. Its classifier head has a different
# number of outputs, so `ignore_mismatched_sizes=True` lets the head be
# re-initialised with one output per dataset class.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=len(full_dataset.classes),
    ignore_mismatched_sizes=True,
).to(device)

# Loss and optimizer that train() reads as notebook-level globals.
# NOTE(review): neither was defined anywhere in the notebook, so a fresh
# Restart & Run All failed with NameError. Cross-entropy matches how train()
# calls `criterion(logits, labels)`; the optimizer type and learning rate are
# assumed fine-tuning defaults -- confirm against the settings actually used.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([47]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([47, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
def evaluate(model, loader, device=None):
    """Return the top-1 accuracy of ``model`` over ``loader`` as a percentage.

    Args:
        model: Module whose forward pass returns an object exposing a
            ``logits`` attribute with per-class scores.
        loader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to. When omitted, the device of
            the model's first parameter is used (CPU if it has no parameters),
            so the function does not depend on a notebook-level ``device``.

    Returns:
        Accuracy between 0 and 100. ``0.0`` when ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    # An empty loader would otherwise raise ZeroDivisionError.
    if total == 0:
        return 0.0
    return 100 * correct / total

def train(model, train_loader, val_loader, epochs=20):
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    After every epoch the model is scored with ``evaluate`` on ``val_loader``;
    whenever the score beats the best seen so far, the state dict is written to
    ``best_vit_dtd.pth``.

    The loss function and optimizer are read from the notebook-level names
    ``criterion`` and ``optimizer``; both must be defined before calling.

    Args:
        model: Module whose forward pass returns an object with ``logits``.
        train_loader: Loader yielding ``(images, labels)`` training batches.
        val_loader: Loader passed to ``evaluate`` after each epoch.
        epochs: Number of passes over ``train_loader``.

    Returns:
        The best validation accuracy (percent) observed during training.
    """
    # Take the device from the model itself instead of a hidden global.
    device = next(model.parameters()).device
    num_batches = len(train_loader)

    best_acc = 0
    for epoch in range(epochs):
        model.train()
        running_loss = 0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images).logits
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Guard the mean against an empty training loader.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {mean_loss:.4f}")

        val_acc = evaluate(model, val_loader)
        print(f"Validation Accuracy: {val_acc:.2f}%")

        # Keep only the checkpoint with the best validation score.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "best_vit_dtd.pth")
            print(" Saved best model!")

    return best_acc


In [None]:
# Run the training loop for 10 epochs with the loaders built earlier.
train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=10,
)


Epoch 1/10: 100%|██████████| 124/124 [02:05<00:00,  1.01s/it]

Epoch [1/10], Loss: 0.1536





Validation Accuracy: 74.59%
 Saved best model!


Epoch 2/10: 100%|██████████| 124/124 [02:08<00:00,  1.04s/it]

Epoch [2/10], Loss: 0.0634





Validation Accuracy: 76.12%
 Saved best model!


Epoch 3/10: 100%|██████████| 124/124 [02:08<00:00,  1.04s/it]

Epoch [3/10], Loss: 0.0269





Validation Accuracy: 76.95%
 Saved best model!


Epoch 4/10: 100%|██████████| 124/124 [02:08<00:00,  1.04s/it]

Epoch [4/10], Loss: 0.0123





Validation Accuracy: 77.42%
 Saved best model!


Epoch 5/10: 100%|██████████| 124/124 [02:07<00:00,  1.03s/it]

Epoch [5/10], Loss: 0.0062





Validation Accuracy: 76.83%


Epoch 6/10: 100%|██████████| 124/124 [02:07<00:00,  1.03s/it]

Epoch [6/10], Loss: 0.0030





Validation Accuracy: 78.01%
 Saved best model!


Epoch 7/10: 100%|██████████| 124/124 [02:08<00:00,  1.04s/it]

Epoch [7/10], Loss: 0.0020





Validation Accuracy: 78.25%
 Saved best model!


Epoch 8/10: 100%|██████████| 124/124 [02:07<00:00,  1.03s/it]

Epoch [8/10], Loss: 0.0019





Validation Accuracy: 78.13%


Epoch 9/10: 100%|██████████| 124/124 [02:08<00:00,  1.03s/it]

Epoch [9/10], Loss: 0.0017





Validation Accuracy: 77.90%


Epoch 10/10: 100%|██████████| 124/124 [02:07<00:00,  1.03s/it]

Epoch [10/10], Loss: 0.0016





Validation Accuracy: 78.25%


In [None]:

# Export the checkpoint with the best validation accuracy (written by train()
# to "best_vit_dtd.pth"). Saving `model.state_dict()` here would store the
# weights from the *last* epoch under the name "best_model.pth".
# If the best checkpoint is missing, fall back to the in-memory weights.
if os.path.exists("best_vit_dtd.pth"):
    torch.save(torch.load("best_vit_dtd.pth", map_location="cpu"), "best_model.pth")
else:
    torch.save(model.state_dict(), "best_model.pth")



In [None]:
# Reload the exported weights. `map_location` places the tensors on the
# model's current device, so a checkpoint saved on GPU also loads on a
# CPU-only runtime.
model.load_state_dict(
    torch.load('best_model.pth', map_location=next(model.parameters()).device)
)
# Switch to inference mode; the trailing semicolon keeps the (very long)
# module repr out of the cell output.
model.eval();


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed