In [None]:
# Install PyTorch and torchvision
!pip install torchvision --upgrade

Collecting torchvision
  Downloading torchvision-0.20.1-cp310-cp310-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting torch==2.5.1 (from torchvision)
  Downloading torch-2.5.1-cp310-cp310-manylinux1_x86_64.whl.metadata (28 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.5.1->torchvision)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.5.1->torchvision)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch==2.5.1->torchvision)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.5.1->torchvision)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch==2.5.1->torchvision)
  Dow

## Device Agnostic Code , it checks if the gpu is available or not , and puts the device to GPU if available otherwise CPU.

In [None]:
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"using device {device}")

using device cuda


import torch
import os
import shutil
import random


data_dir = '/content/drive/MyDrive/archive/animals'

train_dir = '/content/drive/MyDrive/data2/train'
val_dir = '/content/drive/MyDrive/data2/val'

os.makedirs(train_dir , exist_ok=True)
os.makedirs(val_dir , exist_ok=True)

train_split = 0.8

for animals in os.listdir(data_dir):
  animal_folder = os.path.join(data_dir,animals)

  if os.path.isdir(animal_folder):

    images = os.listdir(animal_folder)

    random.shuffle(images)

    split_point = int(len(images) * train_split)

    train_images = images[:split_point]
    val_images = images[split_point:]

    train_animal_folder = os.path.join(train_dir,animals)
    val_animal_folder = os.path.join(val_dir,animals)
    os.makedirs(train_animal_folder,exist_ok=True)
    os.makedirs(val_animal_folder,exist_ok=True)

    for image in train_images:
      src = os.path.join(animal_folder,image)
      dst = os.path.join(train_animal_folder,image)
      shutil.copy(src,dst)


    for image in val_images:
      src = os.path.join(animal_folder,image)
      dst = os.path.join(val_animal_folder,image)
      shutil.copy(src,dst)

    print(f"Processed {animals}: {len(train_images)} images in train, {len(val_images)} in val.")






# Data Loading and Preprocessing

In [None]:
from torchvision import datasets
from torchvision import transforms

# Define paths for the train and validation datasets that we created above
train_dir = '/content/drive/MyDrive/data2/train'
val_dir = '/content/drive/MyDrive/data2/val'

# Define transformations/data augmentations that we are going to apply on the pictures

train_transform = transforms.Compose([
    transforms.Resize((224,224)), # resize all images to 224X224 , makes sure all images have same dimension
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(), #converts images from PIL format into pyTorch tensors
    transforms.Normalize(mean =[0.485,0.456,0.406],
                         std =[0.229,0.224,0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],
                         std=[0.229,0.224,0.225])
])

train_dataset = datasets.ImageFolder(train_dir,transform = train_transform)

val_dataset = datasets.ImageFolder(val_dir,transform = val_transform)



In [None]:
from torch.utils.data import DataLoader

batch_size =32

train_loader = DataLoader(train_dataset,batch_size = batch_size,num_workers =2 , pin_memory = True,shuffle = True) # num_workers =2 , means more process to fetch data and faster
val_loader = DataLoader(val_dataset, batch_size = batch_size,num_workers = 2 , pin_memory = True, shuffle= False)


In [None]:
train_loader

<torch.utils.data.dataloader.DataLoader at 0x7ccc8ed4d690>

In [None]:
import torch
import torch.nn as nn
from torchvision.models import vit_b_16, ViT_B_16_Weights

In [None]:
# Load the pre-trained ViT-B-16 model

weights = ViT_B_16_Weights.DEFAULT
model = vit_b_16(weights = weights)

num_classes = len(train_dataset.classes)
#Replace the classification head
# - model.heads.head is the final classification layer.
# - We replace it with a new nn.Linear layer.
# - in_features is the size of the input to the head (768 for vit_b_16).
# - out_features is the number of classes in your dataset.
model.heads.head =  nn.Linear(in_features = model.heads.head.in_features, out_features = num_classes)

model = model.to(device)

Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:05<00:00, 65.7MB/s]


## Not freezing layers , so that the model can change its whole parameters a little bit to get better at our task of classification , its already good at image classification but we are making it best for wildlife images , so thats why I want to keep it like this

In [None]:
# Setting up loss function and optimizer
import torch.optim as optim

loss_fn = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(),lr=1e-5) # smaller learning rates are recommended for fine tuning , changing paramters drastically can disturb pre-trained weights



num_epochs = 20 # the number of times entire training dataset will be passed through model

for epoch in range(num_epochs):
  print(f"Epoch {epoch+1}/{num_epochs}")
  print('-' * 10)

  model.train()

  running_loss = 0.0 # 0.0 to accumualte the loss over all batches for avg loss at end of epoch
  running_corrects = 0.0 # correct predictions #0.0 resets the accumulator for the new epoch

  for inputs,labels in train_loader:
    inputs = inputs.to(device) # batchh of input images
    labels = labels.to(device) # true labels for input images

    #zero parameter gradients
    optimizer.zero_grad()

    outputs = model(inputs)
    _,preds = torch.max(outputs,1) # - we are ignoring first value
    loss = loss_fn(outputs,labels)


    loss.backward()
    optimizer.step()

    running_loss += loss.item() * inputs.size(0)
    running_corrects += torch.sum(preds == labels.data)


  epoch_loss = running_loss/ len(train_dataset)
  epoch_acc = running_corrects.double()/len(train_dataset)

  print(f'Training loss:{epoch_loss:4f}Acc:{epoch_acc:.4f}')

        # Validation Phase
  model.eval()  # Set model to evaluation mode
  val_running_loss = 0.0
  val_running_corrects = 0

  with torch.no_grad():
      for inputs, labels in val_loader:
          inputs = inputs.to(device)
          labels = labels.to(device)

          outputs = model(inputs)
          _, preds = torch.max(outputs, 1)
          loss = loss_fn(outputs, labels)

          val_running_loss += loss.item() * inputs.size(0)
          val_running_corrects += torch.sum(preds == labels.data)

  val_epoch_loss = val_running_loss / len(val_dataset)
  val_epoch_acc = val_running_corrects.double() / len(val_dataset)

  print(f'Validation Loss: {val_epoch_loss:.4f} Acc: {val_epoch_acc:.4f}')







In [None]:
import time
import copy
from tqdm import tqdm  # For progress bars
import torch
from torch.utils.tensorboard import SummaryWriter  # For logging

def train_model(model, train_loader, val_loader, loss_fn, optimizer,
                num_epochs=10, device='cuda', scheduler=None):

    # tnesorBoard
    writer = SummaryWriter('runs/wildlife_classifier')

    # Initialize best model tracking
    best_model_wts = copy.deepcopy(model.state_dict())
    best_val_acc = 0.0

    # Training history
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': []
    }

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print('-' * 10)
        start_time = time.time()

        # Training Phase
        model.train()
        running_loss = 0.0
        running_corrects = 0

        # Progress bar for training
        train_pbar = tqdm(train_loader, desc='Training')
        for inputs, labels in train_pbar:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            with torch.cuda.amp.autocast():  # Mixed precision training
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = loss_fn(outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            # Statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

            # Update progress bar
            train_pbar.set_postfix({'loss': loss.item()})

        # Calculate epoch statistics
        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = running_corrects.double() / len(train_loader.dataset)

        # Store in history
        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc.item())

        # Log to tensorboard
        writer.add_scalar('Loss/train', epoch_loss, epoch)
        writer.add_scalar('Accuracy/train', epoch_acc, epoch)

        print(f'Training Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        # Validation Phase
        model.eval()
        val_running_loss = 0.0
        val_running_corrects = 0

        # Progress bar for validation
        val_pbar = tqdm(val_loader, desc='Validation')
        with torch.no_grad():
            for inputs, labels in val_pbar:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = loss_fn(outputs, labels)

                val_running_loss += loss.item() * inputs.size(0)
                val_running_corrects += torch.sum(preds == labels.data)

                val_pbar.set_postfix({'loss': loss.item()})

        val_epoch_loss = val_running_loss / len(val_loader.dataset)
        val_epoch_acc = val_running_corrects.double() / len(val_loader.dataset)

        # Store in history
        history['val_loss'].append(val_epoch_loss)
        history['val_acc'].append(val_epoch_acc.item())

        # Log to tensorboard
        writer.add_scalar('Loss/val', val_epoch_loss, epoch)
        writer.add_scalar('Accuracy/val', val_epoch_acc, epoch)

        print(f'Validation Loss: {val_epoch_loss:.4f} Acc: {val_epoch_acc:.4f}')

        # Save best model
        if val_epoch_acc > best_val_acc:
            best_val_acc = val_epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            # Save checkpoint
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_acc': best_val_acc,
                'history': history
            }, 'best_model_checkpoint.pth')
            print("Model improved! Checkpoint saved.")

        # Step the scheduler if provided
        if scheduler is not None:
            scheduler.step(val_epoch_loss)

        time_elapsed = time.time() - start_time
        print(f"Epoch completed in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")

    # Load best model weights
    model.load_state_dict(best_model_wts)
    writer.close()

    return model, history



In [None]:



optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                      patience=3, factor=0.1)
loss_fn = torch.nn.CrossEntropyLoss()

model, history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=10,
    device=device
)



Epoch 1/10
----------


  with torch.cuda.amp.autocast():  # Mixed precision training
Training: 100%|██████████| 152/152 [14:21<00:00,  5.66s/it, loss=0.664]


Training Loss: 1.5152 Acc: 0.7238


Validation: 100%|██████████| 62/62 [05:19<00:00,  5.15s/it, loss=0.0844]


Validation Loss: 0.3781 Acc: 0.9344
Model improved! Checkpoint saved.
Epoch completed in 19m 45s

Epoch 2/10
----------


Training: 100%|██████████| 152/152 [01:33<00:00,  1.63it/s, loss=0.23]


Training Loss: 0.2896 Acc: 0.9371


Validation: 100%|██████████| 62/62 [00:36<00:00,  1.71it/s, loss=0.0241]


Validation Loss: 0.2039 Acc: 0.9573
Model improved! Checkpoint saved.
Epoch completed in 2m 15s

Epoch 3/10
----------


Training: 100%|██████████| 152/152 [01:31<00:00,  1.66it/s, loss=0.323]


Training Loss: 0.1178 Acc: 0.9765


Validation: 100%|██████████| 62/62 [00:36<00:00,  1.72it/s, loss=0.0162]


Validation Loss: 0.1657 Acc: 0.9593
Model improved! Checkpoint saved.
Epoch completed in 2m 14s

Epoch 4/10
----------


Training: 100%|██████████| 152/152 [01:29<00:00,  1.70it/s, loss=0.155]


Training Loss: 0.1532 Acc: 0.9646


Validation: 100%|██████████| 62/62 [00:36<00:00,  1.68it/s, loss=0.0172]


Validation Loss: 0.1754 Acc: 0.9588
Epoch completed in 2m 6s

Epoch 5/10
----------


Training: 100%|██████████| 152/152 [01:29<00:00,  1.69it/s, loss=0.126]


Training Loss: 0.0782 Acc: 0.9845


Validation: 100%|██████████| 62/62 [00:35<00:00,  1.73it/s, loss=0.00463]


Validation Loss: 0.1003 Acc: 0.9771
Model improved! Checkpoint saved.
Epoch completed in 2m 10s

Epoch 6/10
----------


Training: 100%|██████████| 152/152 [01:30<00:00,  1.69it/s, loss=0.0629]


Training Loss: 0.0652 Acc: 0.9850


Validation: 100%|██████████| 62/62 [00:36<00:00,  1.69it/s, loss=0.00618]


Validation Loss: 0.1675 Acc: 0.9624
Epoch completed in 2m 7s

Epoch 7/10
----------


Training: 100%|██████████| 152/152 [01:32<00:00,  1.64it/s, loss=0.203]


Training Loss: 0.1372 Acc: 0.9639


Validation: 100%|██████████| 62/62 [00:36<00:00,  1.72it/s, loss=0.00694]


Validation Loss: 0.1778 Acc: 0.9532
Epoch completed in 2m 9s

Epoch 8/10
----------


Training: 100%|██████████| 152/152 [01:32<00:00,  1.64it/s, loss=0.0196]


Training Loss: 0.0801 Acc: 0.9798


Validation: 100%|██████████| 62/62 [00:35<00:00,  1.76it/s, loss=0.006]


Validation Loss: 0.2024 Acc: 0.9522
Epoch completed in 2m 8s

Epoch 9/10
----------


Training: 100%|██████████| 152/152 [01:31<00:00,  1.65it/s, loss=0.073]


Training Loss: 0.0829 Acc: 0.9794


Validation: 100%|██████████| 62/62 [00:35<00:00,  1.74it/s, loss=0.00881]


Validation Loss: 0.1556 Acc: 0.9619
Epoch completed in 2m 7s

Epoch 10/10
----------


Training: 100%|██████████| 152/152 [01:29<00:00,  1.69it/s, loss=0.0044]


Training Loss: 0.0193 Acc: 0.9965


Validation: 100%|██████████| 62/62 [00:37<00:00,  1.66it/s, loss=0.00297]

Validation Loss: 0.0957 Acc: 0.9771
Epoch completed in 2m 7s





In [None]:
import torch
import torch.nn as nn
from torchvision.models import vit_b_16

# Initialize the model architecture
model = vit_b_16(weights=None)  # No pre-trained weights since we'll load our own
num_classes = len(train_dataset.classes)
model.heads.head = nn.Linear(in_features=model.heads.head.in_features, out_features=num_classes)

# Load the saved weights
model.load_state_dict(torch.load('best_model_checkpoint.pth', map_location='cpu'))
model.eval()  # Set the model to evaluation mode

  model.load_state_dict(torch.load('best_model_checkpoint.pth', map_location='cpu'))


RuntimeError: Error(s) in loading state_dict for VisionTransformer:
	Missing key(s) in state_dict: "class_token", "conv_proj.weight", "conv_proj.bias", "encoder.pos_embedding", "encoder.layers.encoder_layer_0.ln_1.weight", "encoder.layers.encoder_layer_0.ln_1.bias", "encoder.layers.encoder_layer_0.self_attention.in_proj_weight", "encoder.layers.encoder_layer_0.self_attention.in_proj_bias", "encoder.layers.encoder_layer_0.self_attention.out_proj.weight", "encoder.layers.encoder_layer_0.self_attention.out_proj.bias", "encoder.layers.encoder_layer_0.ln_2.weight", "encoder.layers.encoder_layer_0.ln_2.bias", "encoder.layers.encoder_layer_0.mlp.0.weight", "encoder.layers.encoder_layer_0.mlp.0.bias", "encoder.layers.encoder_layer_0.mlp.3.weight", "encoder.layers.encoder_layer_0.mlp.3.bias", "encoder.layers.encoder_layer_1.ln_1.weight", "encoder.layers.encoder_layer_1.ln_1.bias", "encoder.layers.encoder_layer_1.self_attention.in_proj_weight", "encoder.layers.encoder_layer_1.self_attention.in_proj_bias", "encoder.layers.encoder_layer_1.self_attention.out_proj.weight", "encoder.layers.encoder_layer_1.self_attention.out_proj.bias", "encoder.layers.encoder_layer_1.ln_2.weight", "encoder.layers.encoder_layer_1.ln_2.bias", "encoder.layers.encoder_layer_1.mlp.0.weight", "encoder.layers.encoder_layer_1.mlp.0.bias", "encoder.layers.encoder_layer_1.mlp.3.weight", "encoder.layers.encoder_layer_1.mlp.3.bias", "encoder.layers.encoder_layer_2.ln_1.weight", "encoder.layers.encoder_layer_2.ln_1.bias", "encoder.layers.encoder_layer_2.self_attention.in_proj_weight", "encoder.layers.encoder_layer_2.self_attention.in_proj_bias", "encoder.layers.encoder_layer_2.self_attention.out_proj.weight", "encoder.layers.encoder_layer_2.self_attention.out_proj.bias", "encoder.layers.encoder_layer_2.ln_2.weight", "encoder.layers.encoder_layer_2.ln_2.bias", "encoder.layers.encoder_layer_2.mlp.0.weight", "encoder.layers.encoder_layer_2.mlp.0.bias", "encoder.layers.encoder_layer_2.mlp.3.weight", "encoder.layers.encoder_layer_2.mlp.3.bias", "encoder.layers.encoder_layer_3.ln_1.weight", "encoder.layers.encoder_layer_3.ln_1.bias", "encoder.layers.encoder_layer_3.self_attention.in_proj_weight", "encoder.layers.encoder_layer_3.self_attention.in_proj_bias", "encoder.layers.encoder_layer_3.self_attention.out_proj.weight", "encoder.layers.encoder_layer_3.self_attention.out_proj.bias", "encoder.layers.encoder_layer_3.ln_2.weight", "encoder.layers.encoder_layer_3.ln_2.bias", "encoder.layers.encoder_layer_3.mlp.0.weight", "encoder.layers.encoder_layer_3.mlp.0.bias", "encoder.layers.encoder_layer_3.mlp.3.weight", "encoder.layers.encoder_layer_3.mlp.3.bias", "encoder.layers.encoder_layer_4.ln_1.weight", "encoder.layers.encoder_layer_4.ln_1.bias", "encoder.layers.encoder_layer_4.self_attention.in_proj_weight", "encoder.layers.encoder_layer_4.self_attention.in_proj_bias", "encoder.layers.encoder_layer_4.self_attention.out_proj.weight", "encoder.layers.encoder_layer_4.self_attention.out_proj.bias", "encoder.layers.encoder_layer_4.ln_2.weight", "encoder.layers.encoder_layer_4.ln_2.bias", "encoder.layers.encoder_layer_4.mlp.0.weight", "encoder.layers.encoder_layer_4.mlp.0.bias", "encoder.layers.encoder_layer_4.mlp.3.weight", "encoder.layers.encoder_layer_4.mlp.3.bias", "encoder.layers.encoder_layer_5.ln_1.weight", "encoder.layers.encoder_layer_5.ln_1.bias", "encoder.layers.encoder_layer_5.self_attention.in_proj_weight", "encoder.layers.encoder_layer_5.self_attention.in_proj_bias", "encoder.layers.encoder_layer_5.self_attention.out_proj.weight", "encoder.layers.encoder_layer_5.self_attention.out_proj.bias", "encoder.layers.encoder_layer_5.ln_2.weight", "encoder.layers.encoder_layer_5.ln_2.bias", "encoder.layers.encoder_layer_5.mlp.0.weight", "encoder.layers.encoder_layer_5.mlp.0.bias", "encoder.layers.encoder_layer_5.mlp.3.weight", "encoder.layers.encoder_layer_5.mlp.3.bias", "encoder.layers.encoder_layer_6.ln_1.weight", "encoder.layers.encoder_layer_6.ln_1.bias", "encoder.layers.encoder_layer_6.self_attention.in_proj_weight", "encoder.layers.encoder_layer_6.self_attention.in_proj_bias", "encoder.layers.encoder_layer_6.self_attention.out_proj.weight", "encoder.layers.encoder_layer_6.self_attention.out_proj.bias", "encoder.layers.encoder_layer_6.ln_2.weight", "encoder.layers.encoder_layer_6.ln_2.bias", "encoder.layers.encoder_layer_6.mlp.0.weight", "encoder.layers.encoder_layer_6.mlp.0.bias", "encoder.layers.encoder_layer_6.mlp.3.weight", "encoder.layers.encoder_layer_6.mlp.3.bias", "encoder.layers.encoder_layer_7.ln_1.weight", "encoder.layers.encoder_layer_7.ln_1.bias", "encoder.layers.encoder_layer_7.self_attention.in_proj_weight", "encoder.layers.encoder_layer_7.self_attention.in_proj_bias", "encoder.layers.encoder_layer_7.self_attention.out_proj.weight", "encoder.layers.encoder_layer_7.self_attention.out_proj.bias", "encoder.layers.encoder_layer_7.ln_2.weight", "encoder.layers.encoder_layer_7.ln_2.bias", "encoder.layers.encoder_layer_7.mlp.0.weight", "encoder.layers.encoder_layer_7.mlp.0.bias", "encoder.layers.encoder_layer_7.mlp.3.weight", "encoder.layers.encoder_layer_7.mlp.3.bias", "encoder.layers.encoder_layer_8.ln_1.weight", "encoder.layers.encoder_layer_8.ln_1.bias", "encoder.layers.encoder_layer_8.self_attention.in_proj_weight", "encoder.layers.encoder_layer_8.self_attention.in_proj_bias", "encoder.layers.encoder_layer_8.self_attention.out_proj.weight", "encoder.layers.encoder_layer_8.self_attention.out_proj.bias", "encoder.layers.encoder_layer_8.ln_2.weight", "encoder.layers.encoder_layer_8.ln_2.bias", "encoder.layers.encoder_layer_8.mlp.0.weight", "encoder.layers.encoder_layer_8.mlp.0.bias", "encoder.layers.encoder_layer_8.mlp.3.weight", "encoder.layers.encoder_layer_8.mlp.3.bias", "encoder.layers.encoder_layer_9.ln_1.weight", "encoder.layers.encoder_layer_9.ln_1.bias", "encoder.layers.encoder_layer_9.self_attention.in_proj_weight", "encoder.layers.encoder_layer_9.self_attention.in_proj_bias", "encoder.layers.encoder_layer_9.self_attention.out_proj.weight", "encoder.layers.encoder_layer_9.self_attention.out_proj.bias", "encoder.layers.encoder_layer_9.ln_2.weight", "encoder.layers.encoder_layer_9.ln_2.bias", "encoder.layers.encoder_layer_9.mlp.0.weight", "encoder.layers.encoder_layer_9.mlp.0.bias", "encoder.layers.encoder_layer_9.mlp.3.weight", "encoder.layers.encoder_layer_9.mlp.3.bias", "encoder.layers.encoder_layer_10.ln_1.weight", "encoder.layers.encoder_layer_10.ln_1.bias", "encoder.layers.encoder_layer_10.self_attention.in_proj_weight", "encoder.layers.encoder_layer_10.self_attention.in_proj_bias", "encoder.layers.encoder_layer_10.self_attention.out_proj.weight", "encoder.layers.encoder_layer_10.self_attention.out_proj.bias", "encoder.layers.encoder_layer_10.ln_2.weight", "encoder.layers.encoder_layer_10.ln_2.bias", "encoder.layers.encoder_layer_10.mlp.0.weight", "encoder.layers.encoder_layer_10.mlp.0.bias", "encoder.layers.encoder_layer_10.mlp.3.weight", "encoder.layers.encoder_layer_10.mlp.3.bias", "encoder.layers.encoder_layer_11.ln_1.weight", "encoder.layers.encoder_layer_11.ln_1.bias", "encoder.layers.encoder_layer_11.self_attention.in_proj_weight", "encoder.layers.encoder_layer_11.self_attention.in_proj_bias", "encoder.layers.encoder_layer_11.self_attention.out_proj.weight", "encoder.layers.encoder_layer_11.self_attention.out_proj.bias", "encoder.layers.encoder_layer_11.ln_2.weight", "encoder.layers.encoder_layer_11.ln_2.bias", "encoder.layers.encoder_layer_11.mlp.0.weight", "encoder.layers.encoder_layer_11.mlp.0.bias", "encoder.layers.encoder_layer_11.mlp.3.weight", "encoder.layers.encoder_layer_11.mlp.3.bias", "encoder.ln.weight", "encoder.ln.bias", "heads.head.weight", "heads.head.bias". 
	Unexpected key(s) in state_dict: "epoch", "model_state_dict", "optimizer_state_dict", "best_val_acc", "history". 

In [None]:
# Load the entire checkpoint
checkpoint = torch.load('best_model_checkpoint.pth', map_location='cpu')


  checkpoint = torch.load('best_model_checkpoint.pth', map_location='cpu')


In [None]:
# Extract the model's state_dict from the checkpoint
state_dict = checkpoint['model_state_dict']

In [None]:
# Load the state_dict into your model
model.load_state_dict(state_dict)

<All keys matched successfully>

In [None]:
model.eval()
print("Model loaded successfully.")

Model loaded successfully.


In [None]:
# Apply dynamic quantization
quantized_model = torch.quantization.quantize_dynamic(
    model,  # the model instance
    {nn.Linear},  # layers to quantize
    dtype=torch.qint8  # data type for quantized weights
)

In [None]:
import os

def print_size_of_model(model, label=''):
    torch.save(model.state_dict(), 'temp.p')
    size_mb = os.path.getsize('temp.p') / 1e6
    print(f'{label} Model Size: {size_mb:.2f} MB')
    os.remove('temp.p')
    return size_mb

# Original model size
print_size_of_model(model, 'Original')

# Quantized model size
print_size_of_model(quantized_model, 'Quantized')

Original Model Size: 343.52 MB
Quantized Model Size: 173.46 MB


173.462068

In [None]:
from torch.utils.data import DataLoader

# Ensure your validation DataLoader is set up
# val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

correct = 0
total = 0

quantized_model.eval()
with torch.no_grad():
    for inputs, labels in val_loader:
        inputs = inputs.to('cpu')  # Quantized model is on CPU
        outputs = quantized_model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f'Quantized Model Accuracy: {accuracy:.4f}')

Quantized Model Accuracy: 0.9654


In [None]:
import time

# Function to measure inference time
def measure_inference_time(model, data_loader):
    start_time = time.time()
    with torch.no_grad():
        for inputs, _ in data_loader:
            inputs = inputs.to('cpu')
            outputs = model(inputs)
    end_time = time.time()
    total_time = end_time - start_time
    return total_time

# Measure time for original model
original_time = measure_inference_time(model, val_loader)
print(f'Original Model Inference Time: {original_time:.2f} seconds')

# Measure time for quantized model
quantized_time = measure_inference_time(quantized_model, val_loader)
print(f'Quantized Model Inference Time: {quantized_time:.2f} seconds')

# Calculate speed-up
speed_up = original_time / quantized_time
print(f'Speed-up: {speed_up:.2f}x')

Original Model Inference Time: 1433.81 seconds
Quantized Model Inference Time: 1230.51 seconds
Speed-up: 1.17x


In [None]:

normal_model_save_path = '/content/drive/MyDrive/best_model_checkpoint.pth'

In [None]:
torch.save(model.state_dict(), normal_model_save_path)
print(f"Normal model saved to {normal_model_save_path}")

Normal model saved to /content/drive/MyDrive/best_model_checkpoint.pth


In [None]:
quantized_model_save_path = '/content/drive/MyDrive/quantized_model.pth'

In [None]:
torch.save(quantized_model, quantized_model_save_path)
print(f"Quantized model saved to {quantized_model_save_path}")

Quantized model saved to /content/drive/MyDrive/quantized_model.pth
