In [1]:
#pytorch lib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms,models
from torch.utils.data import Dataset,random_split
import torch.optim as optim


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os,shutil,warnings
import glob
import random 
import rasterio
import json

from transformers import ViTForImageClassification
from tqdm import tqdm
from PIL import Image
from IPython.display import display
from torchvision import transforms
from typing import Tuple, Dict, List

from sklearn.metrics import precision_score, recall_score, confusion_matrix, ConfusionMatrixDisplay


random.seed(69)


## Data Preperation
Since the image already formatted in "Classname/imagename.png", we can use utility function from PyTorch called "Image Folder" to prepare the image for us

In [22]:
#get all the classes
root_dir = "datasets/EuroSAT/"
transform = transforms.Compose([
    transforms.Resize([224,224]),
    transforms.ToTensor(),  # Converts to a tensor and scales values to [0, 1]
    #transforms.ConvertImageDtype(torch.float32),
])
dataset = datasets.ImageFolder(root=root_dir, transform=transform, )

train_size = int(0.9*len(dataset))
test_size  = len(dataset) - train_size

In [23]:
train_dataset,val_dataset = random_split(dataset,[train_size,test_size])


dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


In [31]:
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", attn_implementation="sdpa", torch_dtype=torch.float32)
model.classifier = nn.Linear(model.classifier.in_features,10)

In [32]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [36]:
num_epochs = 100

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    
    progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch')

    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        #print("test")
        outputs = model(images)
        
        logits = outputs.logits
        loss = criterion(logits, labels)
        
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()

        progress_bar.set_postfix(loss=running_loss/len(progress_bar))
        
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader)}")
    if (epoch + 1) % 10 == 0:
        # 
        torch.save(model.state_dict(), f'weights/vit_model_{epoch+1}.pth')

print("Finished Training")

Epoch 1/100:  16%|███████████████████████████████████████████▉                                                                                                                                                                                                                               | 125/760 [00:31<02:41,  3.93batch/s, loss=0.367]


KeyboardInterrupt: 