In [1]:
import torch
from torch import nn
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torch import optim
import json
import random
from PIL import Image
import os
from matplotlib import pyplot as plt

In [2]:
SEED = 97120
random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7e76c4fe85b0>

In [3]:
!pip install pytorch_pretrained_vit

Collecting pytorch_pretrained_vit
  Downloading pytorch-pretrained-vit-0.0.7.tar.gz (13 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->pytorch_pretrained_vit)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->pytorch_pretrained_vit)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->pytorch_pretrained_vit)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch->pytorch_pretrained_vit)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch->pytorch_pretrained_vit)
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 

In [6]:
# import
from pytorch_pretrained_vit import ViT

#- pretrained model
model_name = 'B_16_imagenet1k'
model = ViT(model_name, pretrained=True)

Loaded pretrained weights.


In [9]:
# SECTION 2: Download Datasets
# Use kagglehub to download the datasets
florencetushabe_path = kagglehub.dataset_download('florencetushabe/sickle-cell-disease-dataset')
fenicxs_path = kagglehub.dataset_download('fenicxs/sickle-cell-anaemia')
data_dir = fenicxs_path

In [8]:
import kagglehub

In [10]:
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import json
import random
from PIL import Image
from matplotlib import pyplot as plt
import os

In [11]:
# SECTION 3: Dataset Preparation
# Define data transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Adjust to ViT input size
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalizing images
])

In [12]:
SEED = 97120
random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7e76c4fe85b0>

In [13]:
full_dataset = datasets.ImageFolder(root=data_dir, transform=transform)

In [15]:
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

In [16]:
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# SECTION 4: Define ViT Model
from pytorch_pretrained_vit import ViT

In [17]:
model_name = 'B_16_imagenet1k'
vit = ViT(model_name, pretrained=True)

# Adjust classifier for binary classification
vit.fc = nn.Linear(vit.fc.in_features, 2) 

# Move model to device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vit = vit.to(device)

Loaded pretrained weights.


In [18]:
#Training and Validation Functions
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    epoch_loss, correct = 0, 0

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        correct += (outputs.argmax(1) == labels).sum().item()

    accuracy = correct / len(dataloader.dataset)
    return epoch_loss / len(dataloader), accuracy


def validate_epoch(model, dataloader, criterion, device):
    model.eval()
    epoch_loss, correct = 0, 0

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            epoch_loss += loss.item()
            correct += (outputs.argmax(1) == labels).sum().item()

    accuracy = correct / len(dataloader.dataset)
    return epoch_loss / len(dataloader), accuracy
        

In [19]:
import torch.nn.functional as F

# Section 6: Adjust positional embeddings
def resize_positional_embedding(vit_model, new_grid_size):
    old_embedding = vit_model.positional_embedding.pos_embedding  # Extract positional embedding
    cls_token = old_embedding[:, :1]  # Class token
    old_grid_embedding = old_embedding[:, 1:]  # Grid embeddings

    # Compute old grid size
    old_grid_size = int(old_grid_embedding.shape[1] ** 0.5)
    old_grid_embedding = old_grid_embedding.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)

    # Resize grid embeddings
    new_grid_embedding = F.interpolate(
        old_grid_embedding, size=(new_grid_size, new_grid_size), mode="bilinear", align_corners=False
    )
    new_grid_embedding = new_grid_embedding.permute(0, 2, 3, 1).reshape(1, new_grid_size * new_grid_size, -1)

    # Combine class token with resized grid embeddings
    new_pos_embedding = torch.cat([cls_token, new_grid_embedding], dim=1)
    vit_model.positional_embedding.pos_embedding = torch.nn.Parameter(new_pos_embedding)



In [20]:
resize_positional_embedding(vit, new_grid_size=14)
print(f"Resized positional embedding shape: {vit.positional_embedding.pos_embedding.shape}")


Resized positional embedding shape: torch.Size([1, 197, 768])


In [21]:
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


In [22]:
for images, _ in train_loader:
    images = images.to(device)
    patch_embeddings = vit.patch_embedding(images)
    print(f"Patch embedding shape: {patch_embeddings.shape}")
    break


Patch embedding shape: torch.Size([32, 768, 14, 14])


In [23]:

batch_size, hidden_size, grid_h, grid_w = patch_embeddings.shape
patch_embeddings = patch_embeddings.view(batch_size, hidden_size, -1).permute(0, 2, 1)



In [24]:
#Prepare for training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vit = vit.to(device)

optimizer = torch.optim.Adam(vit.parameters(), lr=0.0001)
criterion = torch.nn.CrossEntropyLoss()


In [25]:
# Define training and validation functions
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss, total_correct = 0, 0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_correct += (outputs.argmax(1) == labels).sum().item()

    avg_loss = total_loss / len(dataloader)
    accuracy = total_correct / len(dataloader.dataset)
    return avg_loss, accuracy

def validate_epoch(model, dataloader, criterion, device):
    model.eval()
    total_loss, total_correct = 0, 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            total_correct += (outputs.argmax(1) == labels).sum().item()

    avg_loss = total_loss / len(dataloader)
    accuracy = total_correct / len(dataloader.dataset)
    return avg_loss, accuracy

In [26]:
# Section 9: Training loop
epochs = 10
for epoch in range(epochs):
    train_loss, train_acc = train_epoch(vit, train_loader, optimizer, criterion, device)
    val_loss, val_acc = validate_epoch(vit, val_loader, criterion, device)
    print(f"Epoch {epoch + 1}/{epochs} - Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} - Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

Epoch 1/10 - Train Loss: 0.4011, Train Acc: 0.8418 - Val Loss: 0.5147, Val Acc: 0.8421
Epoch 2/10 - Train Loss: 0.2443, Train Acc: 0.9121 - Val Loss: 0.3878, Val Acc: 0.8596
Epoch 3/10 - Train Loss: 0.1816, Train Acc: 0.9209 - Val Loss: 0.6924, Val Acc: 0.7719
Epoch 4/10 - Train Loss: 0.1267, Train Acc: 0.9538 - Val Loss: 0.4328, Val Acc: 0.8947
Epoch 5/10 - Train Loss: 0.1805, Train Acc: 0.9231 - Val Loss: 0.5188, Val Acc: 0.8684
Epoch 6/10 - Train Loss: 0.1133, Train Acc: 0.9604 - Val Loss: 0.4526, Val Acc: 0.8421
Epoch 7/10 - Train Loss: 0.1327, Train Acc: 0.9516 - Val Loss: 0.5395, Val Acc: 0.9035
Epoch 8/10 - Train Loss: 0.0881, Train Acc: 0.9736 - Val Loss: 0.4324, Val Acc: 0.8333
Epoch 9/10 - Train Loss: 0.0697, Train Acc: 0.9780 - Val Loss: 0.4409, Val Acc: 0.8860
Epoch 10/10 - Train Loss: 0.0231, Train Acc: 0.9934 - Val Loss: 0.4013, Val Acc: 0.9123


In [36]:
# Save the trained model
torch.save(vit.state_dict(), "vit_sickle_cell_classifier.pth")


In [37]:
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

# SECTION 2: Download Datasets
# Use kagglehub to download the datasets
florencetushabe_path = kagglehub.dataset_download('florencetushabe/sickle-cell-disease-dataset')
fenicxs_path = kagglehub.dataset_download('fenicxs/sickle-cell-anaemia')
data_dir = fenicxs_path

# Define the data transformation
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to match ViT input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalization
])

# Load the full dataset
full_dataset = datasets.ImageFolder(root=data_dir, transform=transform)

# Split the dataset into training, validation, and test sets
train_size = int(0.7 * len(full_dataset))
val_size = int(0.2 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size])

# Create DataLoaders
batch_size = 32

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"Train dataset size: {len(train_dataset)} samples")
print(f"Validation dataset size: {len(val_dataset)} samples")
print(f"Test dataset size: {len(test_dataset)} samples")


Train dataset size: 398 samples
Validation dataset size: 113 samples
Test dataset size: 58 samples


In [38]:
test_loss, test_acc = validate_epoch(vit, test_loader, criterion, device)
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")


Test Loss: 0.1959, Test Accuracy: 0.9483


In [41]:
# Section 10: Test Function
def test_model(model, dataloader, criterion, device):
    model.eval()  # Set model to evaluation mode
    test_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():  # No gradients required during testing
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            # Get predictions and calculate accuracy
            _, preds = torch.max(outputs, 1)
            correct_predictions += (preds == labels).sum().item()
            total_predictions += labels.size(0)

    # Calculate average loss and accuracy
    avg_loss = test_loss / len(dataloader)
    accuracy = correct_predictions / total_predictions
    return avg_loss, accuracy
