In [None]:
from torchvision.io import read_image
from torchvision.models import vit_b_16, ViT_B_16_Weights, list_models
from torchvision.datasets import ImageNet, ImageFolder
from torch.utils.data import DataLoader
from torchmetrics.classification import MulticlassAccuracy
import torch
import time
import csv
import os
from torchvision.transforms import transforms

## Imports

In [None]:
from torchvision.io import read_image
from torchvision.models import vit_b_16, ViT_B_16_Weights, list_models
from torchvision.datasets import ImageNet, ImageFolder
from torch.utils.data import DataLoader
from torchmetrics.classification import MulticlassAccuracy
import torch
import time
import csv
import os
from torchvision.transforms import transforms

## Define everything needed for the run

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device used is: {device!s}")

# Accuracy metric over 1000 classes, placed on the selected device.
metric = MulticlassAccuracy(num_classes=1000).to(device)

# Inference transforms that belong to the chosen pretrained weights.
weights = ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
preprocess = weights.transforms(antialias=True)

# Single-channel images are expanded to three channels before the standard
# preprocessing is applied; every other image is passed through untouched.
preprocess_w_gray2rgb = transforms.Compose([
    lambda x: x if x.shape[0] != 1 else x.expand(3, -1, -1),
    preprocess,
])

imagenet_val_dir = '/kaggle/input/imagenet-val/imagenet_val'
dataset = ImageFolder(
    root=imagenet_val_dir,
    loader=read_image,
    transform=preprocess_w_gray2rgb,
)

# Inverse of ImageFolder's class_to_idx: dataset index -> folder name.
class_dict = {idx: folder for folder, idx in dataset.class_to_idx.items()}

# Load the images in fixed-order batches.
dataloader = DataLoader(dataset, batch_size=8, shuffle=False)

## Inference for a single image (sanity check)

In [None]:
# Sanity check: classify one validation image with the pretrained ViT.
img = read_image("/kaggle/input/imagenet-val/imagenet_val/10/ILSVRC2012_val_00010763.JPEG")
print(img.shape[0])
# Expand a single-channel image to three channels, like the dataset transform does.
if img.shape[0] == 1:
    img = img.expand(3, -1, -1)
print(img.size())

# Step 1: Initialize model with the best available weights
weights = ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
model = vit_b_16(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms.
# antialias=True matches the dataset pipeline, so this check exercises the
# same preprocessing that the batched evaluation uses.
preprocess = weights.transforms(antialias=True)

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and print the predicted category.
# Gradients are not needed for a forward-only check, so skip autograd tracking.
with torch.no_grad():
    prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
print(class_id)
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")

## Functions

In [None]:
def check_label_name(predictions, weights):
    """Print the top-scoring class of every prediction.

    For each entry of ``predictions`` the index of its largest value is
    printed, followed by the matching name from
    ``weights.meta["categories"]`` and that value scaled by 100 and shown as
    a percentage.
    """
    for row in predictions:
        top_idx = row.argmax().item()
        print(top_idx)
        percent = 100 * row[top_idx].item()
        label = weights.meta["categories"][top_idx]
        print("{}: {:.1f}%".format(label, percent))
        print("\n")
        
        
def collect_data(model_name, accuracy, duration, device):
    """Append one result row for a saved model to ``model_data.csv``.

    ``model_name`` is the path of the saved model file; its on-disk size is
    converted from bytes to MiB and stored next to the other values. The row
    layout is: model_name, accuracy, duration, size, device.
    """
    size_mib = os.path.getsize(model_name) / (1024**2)
    row = [model_name, accuracy, duration, size_mib, device]

    # Append mode: earlier rows in the file are preserved.
    with open('model_data.csv', 'a', newline='') as out:
        csv.writer(out).writerow(row)
        
def model_quantization(model, backend='x86', save=False, qtype='int'):
    """Dynamically quantize every ``torch.nn.Linear`` layer of ``model``.

    Args:
        model: Float model to quantize. It is switched to eval mode.
        backend: Currently unused; kept so existing callers keep working.
        save: When True, the quantized model is compiled with
            ``torch.jit.script`` and written to ``vit_scripted_quantized.pt``.
        qtype: ``'int'`` or ``'uint'`` for 8-bit integer weights, ``'float'``
            for float16 weights.

    Returns:
        The quantized model produced by ``quantize_dynamic``.

    Raises:
        ValueError: If ``qtype`` is not one of the supported names.
    """
    # quantize_dynamic documents only qint8 and float16 as weight dtypes, so
    # 'uint' keeps resolving to qint8, which is what this function has always
    # passed regardless of qtype.
    dtype_by_name = {
        'int': torch.qint8,
        'uint': torch.qint8,
        'float': torch.float16,
    }
    if not isinstance(qtype, str) or qtype not in dtype_by_name:
        # Raise instead of calling exit(): exiting would take the whole
        # notebook kernel down with it.
        raise ValueError(
            "invalid qtype {!r}; expected one of {}".format(qtype, sorted(dtype_by_name))
        )

    model.eval()
    quantized_model = torch.quantization.quantize_dynamic(
        model,
        qconfig_spec={torch.nn.Linear},
        dtype=dtype_by_name[qtype],
    )

    if save:
        # Script only when a file is requested; the returned model stays eager.
        torch.jit.script(quantized_model).save("vit_scripted_quantized.pt")

    return quantized_model
        
        
def labels_process(labels, class_dict):
    """Translate dataset label indices into integer class ids.

    Each entry of ``labels`` is converted to ``int`` and looked up in
    ``class_dict``; the looked-up values are then converted to ``int`` and
    returned together as a tensor.
    """
    # First pass: resolve every index through the mapping.
    looked_up = [class_dict[int(raw)] for raw in labels]
    # Second pass: the mapped values become integers.
    return torch.tensor(list(map(int, looked_up)))
    
def inference(model, dataloader, class_dict, device, image_num_stop=40000):
    """Measure top-1 accuracy of ``model`` on the first images of ``dataloader``.

    Args:
        model: Classifier to evaluate; it is moved to ``device`` and put in
            eval mode.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        class_dict: Mapping handed to ``labels_process`` to translate the
            dataset label indices.
        device: Device the model and the batches are moved to.
        image_num_stop: Evaluation stops after the batch with which at least
            this many images have been scored.

    Returns:
        ``(accuracy, duration)``: the fraction of correct predictions (0.0
        when no image was scored) and the elapsed time in minutes.
    """
    total_correct = 0
    total_samples = 0
    start_time = time.time()
    model.to(device)
    model.eval()

    with torch.no_grad():
        for index, (images, labels) in enumerate(dataloader):
            if index % 50 == 0:
                # Report the real running count instead of assuming 8 images per batch.
                print("{} images were processed out of 50,000".format(total_samples))

            images = images.to(device)

            # Change labels because of dataset idx
            labels = labels_process(labels, class_dict).to(device)

            predictions = model(images)
            # Shift the 0-based argmax by one before comparing with the
            # processed labels (presumably 1-based class ids; confirm against
            # the dataset folder names).
            predicted_labels = torch.argmax(predictions, dim=1) + 1

            total_correct += (predicted_labels == labels).sum().item()
            total_samples += labels.size(0)

            # Stop on the number of images actually scored, so the limit is
            # neither overshot by a batch nor tied to a fixed batch size.
            if total_samples >= image_num_stop:
                print("Number of images processed: {} stopping now".format(total_samples))
                print("stopped checking because of errors for the entire dataset \n ")
                break

    # An empty dataloader must not raise ZeroDivisionError.
    accuracy = total_correct / total_samples if total_samples > 0 else 0.0
    duration = (time.time() - start_time) / 60

    return accuracy, duration

    
    
def inference_list(model, dataloader, class_dict, device, label_dict, image_num_stop=40000):
    """Measure per-class top-1 accuracy of ``model`` on the first images of ``dataloader``.

    Args:
        model: Classifier to evaluate; it is moved to ``device`` and put in
            eval mode.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        class_dict: Mapping handed to ``labels_process`` to translate the
            dataset label indices.
        device: Device the model and the batches are moved to.
        label_dict: Sequence of class names; entry ``k - 1`` names label ``k``.
        image_num_stop: Evaluation stops after the batch with which at least
            this many images have been scored.

    Returns:
        ``(accuracies, duration)``: a dict mapping class name to the fraction
        of correctly classified images of that class (0.0 for classes that
        were never seen) and the elapsed time in minutes.
    """
    label_counts = {label: {"correct": 0, "total": 0} for label in range(1, 1001)}
    total_samples = 0
    start_time = time.time()
    model.to(device)
    model.eval()

    with torch.no_grad():
        for index, (images, labels) in enumerate(dataloader):
            if index % 50 == 0:
                # Report the real running count instead of assuming 8 images per batch.
                print("{} images were processed out of 50,000".format(total_samples))

            images = images.to(device)

            # Change labels because of dataset idx
            labels = labels_process(labels, class_dict).to(device)

            predictions = model(images)
            # Shift the 0-based argmax by one before comparing with the
            # processed labels (presumably 1-based class ids; confirm against
            # the dataset folder names).
            predicted_labels = torch.argmax(predictions, dim=1) + 1

            # Convert each batch to plain ints once instead of comparing and
            # unpacking tensors element by element.
            for pred_label, true_label in zip(predicted_labels.tolist(), labels.tolist()):
                counts = label_counts[true_label]
                counts["total"] += 1
                if pred_label == true_label:
                    counts["correct"] += 1
            total_samples += labels.size(0)

            # Stop on the number of images actually scored, so the limit is
            # neither overshot by a batch nor tied to a fixed batch size.
            if total_samples >= image_num_stop:
                print("Number of images processed: {} stopping now".format(total_samples))
                print("stopped checking because of errors for the entire dataset \n ")
                break

    # NOTE(review): names from label_dict are used as keys, so if label_dict
    # contains duplicate names the later class overwrites the earlier one.
    accuracies = {}
    for index, counts in enumerate(label_counts.values()):
        accuracies[label_dict[index]] = counts["correct"] / counts["total"] if counts["total"] > 0 else 0.0

    duration = (time.time() - start_time) / 60

    return accuracies, duration


def write_dict_to_csv(dictionary, filename):
    """Write ``dictionary`` to ``filename`` as a two-column CSV.

    The file is overwritten. It starts with the header row
    ``Label,Accuracy`` and then holds one row per dictionary item, key first
    and value second.
    """
    header = ['Label', 'Accuracy']
    with open(filename, 'w', newline='') as out:
        table = csv.writer(out)
        table.writerow(header)
        table.writerows(dictionary.items())

## Baseline ViT - ran on GPU

In [None]:
# Baseline run: the pretrained ViT-B/16 is evaluated as-is on `device`.
weights = ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
model = vit_b_16(weights=weights)

accuracy, duration = inference(
    model=model,
    dataloader=dataloader,
    class_dict=class_dict,
    device=device,
    image_num_stop=10000,
)

print(f"Inference took {duration} minutes")
print(f"Accuracy for this model is {accuracy}")

# Save the model first: collect_data reads the size of the saved file.
model_name = "vit_baseline.pth"
torch.save(model, model_name)
collect_data(
    model_name=model_name,
    accuracy=accuracy,
    duration=duration,
    device=device,
)

## Dynamic quantization - float16

In [None]:
# Float16 dynamic quantization.
# This cell only runs if model_quantization accepts qtype='float'.
# Step 1: Initialize model with the best available weights
weights = ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
model = vit_b_16(weights=weights)
quantized_vit = model_quantization(model=model, qtype='float')

# The quantized model is evaluated on the CPU.
accuracy, duration = inference(model=quantized_vit, dataloader=dataloader, class_dict=class_dict, device="cpu", image_num_stop=10000)

print("Inference took {} minutes".format(duration))
print("Accuracy for this quantized model is {}".format(accuracy))

model_name = "vit_basic_quantization.pth"
torch.save(quantized_vit, model_name)
# Record the device the measurement actually ran on, not the global `device`.
collect_data(model_name=model_name, accuracy=accuracy, duration=duration, device="cpu")

## Dynamic quantization - int8

In [None]:
# Step 1: Initialize model with the best available weights
weights = ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
model = vit_b_16(weights=weights)
quantized_vit = model_quantization(model=model, qtype='uint')

# Category names from the weights metadata, passed as label_dict below.
# TODO: turn into a function later
label_list = weights.meta["categories"]

# Evaluate the quantized model (not the float `model`), so the CSV written
# below really holds the quantized results.
accuracy_list, duration = inference_list(model=quantized_vit, dataloader=dataloader, class_dict=class_dict, device="cpu", label_dict=label_list, image_num_stop=10000)

print("Inference took {} minutes".format(duration))
print(accuracy_list)

write_dict_to_csv(accuracy_list, 'Vit_quantized_uint.csv')