In [None]:
import copy
import math
import os

import detectors
import matplotlib.pyplot as plt
import numpy as np
import optuna
import pandas as pd
import timm
import torch
import torch.nn as nn
import torchvision
from sklearn import preprocessing
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms
# %matplotlib notebook
%matplotlib inline

In [None]:
# Show the names of all available GPU devices.
# Built from device_count() so the cell yields an empty list (instead of
# raising) when no CUDA device is visible.
gpu_names = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
gpu_names

In [None]:
# Get the CIFAR10 dataset from 'torchvision':
from torchvision import datasets
from torchvision.transforms import ToTensor

# Number of images per batch, shared by the training and the test loader.
BATCH_SIZE = 64

# Download the training split from open datasets (stored under ./data).
training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Same dataset, held-out test split.
test_data = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Create data loaders:
train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)
# The test loader is not shuffled: the final accuracy does not depend on the
# batch order, and a fixed order keeps the per-step running accuracy
# reproducible and comparable between evaluation runs.
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Build the CIFAR10 ResNet-18 through timm, requesting pretrained weights.
# NOTE(review): the "resnet18_cifar10" name is presumably made available to
# timm by the `detectors` import at the top of the notebook — confirm.
model = timm.create_model(
    "resnet18_cifar10",
    pretrained=True,
)

In [None]:
# Baseline: top-1 accuracy of the unmodified model on the test loader.
model.eval()

# CIFAR10 per-channel mean and standard deviation used to standardize inputs:
mean = [0.4914, 0.4822, 0.4465]
std = [0.2023, 0.1994, 0.201]
normalize = transforms.Normalize(mean, std)

# Running counters, updated once per batch.
correct, total = 0, 0

with torch.no_grad():
    for i, data in enumerate(test_dataloader):
        images, labels = data
        # Standardize the whole batch before the forward pass.
        images = normalize(images)
        outputs = model(images)
        # The predicted class is the index of the largest logit.
        predicted = torch.argmax(outputs, dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
        # Report the running accuracy every tenth batch.
        if i % 10 == 0:
            print(f'For step {i} the accuracy is {100 * correct / total}%')

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')

In [None]:
def quant_tensor(tensor, group_size=16):
    """Replace every run of ``group_size`` consecutive values by the run's mean.

    The tensor is flattened, split into consecutive groups of ``group_size``
    elements, and each element is replaced by the mean of its group. The
    result is reshaped back to the shape of the input, so tensors of any
    shape are supported. If the number of elements is not a multiple of
    ``group_size``, the shorter trailing group is averaged on its own.

    Args:
        tensor: Floating-point tensor to process.
        group_size: Number of consecutive elements sharing one value
            (default 16).

    Returns:
        A new tensor with the same shape as ``tensor``. When ``tensor`` holds
        fewer than ``group_size`` elements it is returned unchanged (the very
        same object).

    Raises:
        ValueError: If ``group_size`` is smaller than 1.
    """
    if group_size < 1:
        raise ValueError(f"group_size must be >= 1, got {group_size}")

    # Too small to form even a single full group: leave it untouched.
    if tensor.numel() < group_size:
        return tensor

    flat = tensor.flatten()
    n_full = (flat.numel() // group_size) * group_size

    # Mean of each complete group, broadcast back to one value per element.
    group_means = flat[:n_full].unfold(0, group_size, group_size).mean(-1)
    quantized = group_means.repeat_interleave(group_size)

    # unfold() silently drops an incomplete last group; average it separately
    # so that no element is lost and the original shape can be restored.
    tail = flat[n_full:]
    if tail.numel() > 0:
        quantized = torch.cat([quantized, tail.mean().expand(tail.numel())])

    return quantized.reshape(tensor.shape)

def quantize_linear_layers(model):
    """Run ``quant_tensor`` over the parameters of every ``nn.Linear`` in ``model``.

    The weight of each linear layer, and its bias when one exists, is replaced
    by a new ``nn.Parameter`` holding the output of ``quant_tensor`` with
    ``requires_grad=False``. The model is modified in place.

    Args:
        model: Module whose ``nn.Linear`` sub-modules should be processed.

    Returns:
        The same ``model`` object that was passed in.
    """
    linear_layers = [layer for layer in model.modules() if isinstance(layer, nn.Linear)]

    for layer in linear_layers:
        # Swap in the processed weight as a frozen parameter.
        layer.weight = nn.Parameter(quant_tensor(layer.weight), requires_grad=False)

        # Layers without a bias have nothing more to process.
        if layer.bias is None:
            continue
        layer.bias = nn.Parameter(quant_tensor(layer.bias), requires_grad=False)

    return model

# quantize_linear_layers() modifies the module it receives, so work on a deep
# copy: `model` keeps its original weights, the baseline evaluation stays
# valid when re-run, and this cell can be re-executed without processing
# already-processed weights.
quant_model = quantize_linear_layers(copy.deepcopy(model))

In [None]:
# Top-1 accuracy of the quantized model on the test loader.
quant_model.eval()

# CIFAR10 per-channel mean and standard deviation used to standardize inputs:
mean = [0.4914, 0.4822, 0.4465]
std = [0.2023, 0.1994, 0.201]
normalize = transforms.Normalize(mean, std)

# Running counters, updated once per batch.
correct, total = 0, 0

with torch.no_grad():
    for i, data in enumerate(test_dataloader):
        images, labels = data
        # Standardize the whole batch before the forward pass.
        images = normalize(images)
        outputs = quant_model(images)
        # The predicted class is the index of the largest logit.
        predicted = torch.argmax(outputs, dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
        # Report the running accuracy every tenth batch.
        if i % 10 == 0:
            print(f'For step {i} the accuracy is {100 * correct / total}%')

print(f'Accuracy of the quantized network on the 10000 test images: {100 * correct / total}%')