In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.nn.utils import prune
from torch.amp import autocast, GradScaler

import torchvision
from torchvision.datasets import ImageNet, CIFAR10
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.transforms._presets import ImageClassification

from pathlib import Path
from tqdm.notebook import tqdm
from einops import einsum, rearrange, reduce
from typing import Union, Tuple, Any, Generator



In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

In [None]:
# Pretrained ImageNet weights and the preprocessing pipeline bundled with them.
weights = ResNet18_Weights.IMAGENET1K_V1
transforms = weights.transforms()  # resize/crop/normalize preset for these weights
transforms

In [None]:
# ResNet-18 initialized from the weights selected above.
model = torchvision.models.resnet18(weights=weights)
model

In [None]:
# Inspect the pretrained classification head (the layer swapped out for CIFAR-10 later).
model.fc

In [None]:
model.__class__

In [None]:
# List every learnable tensor with its qualified name. These are parameter
# tensors (weights, biases, normalization terms), not layers, so name them.
for param_name, param in model.named_parameters():
    print(f'{param_name}: {tuple(param.shape)}')

In [None]:
BATCH_SIZE = 1024

# On Colab, keep the dataset on Google Drive; everywhere else use ./data under
# the current working directory. Only a missing Colab module triggers the
# fallback -- other failures (e.g. an interrupted mount) are not swallowed.
try:
    from google.colab import drive
except ImportError:  # not running on Colab
    datapath = Path().cwd() / 'data'
else:
    drive.mount('/content/drive')
    datapath = Path().cwd() / 'drive' / 'MyDrive' / 'data'
print(f'Data from {datapath}')

# Training split: used for fine-tuning. It must not be the same split that is
# used for evaluation, otherwise the reported accuracy is leaked.
cifar10 = CIFAR10(
    root=datapath,
    train=True,
    transform=transforms,
    download=True
)

# Held-out test split: used only for evaluation.
cifar10_test = CIFAR10(
    root=datapath,
    train=False,
    transform=transforms,
    download=True
)

dataloader = torch.utils.data.DataLoader(
    cifar10,
    batch_size=BATCH_SIZE,
    shuffle=True
)

dataloader_test = torch.utils.data.DataLoader(
    cifar10_test,
    batch_size=BATCH_SIZE,
    shuffle=False
)


In [None]:
# Class index -> human-readable class name.
label_mapping = {index: class_name for index, class_name in enumerate(cifar10.classes)}
label_mapping

In [None]:
def visualize_normalized_image(image: torch.Tensor,
                               transform: ImageClassification,
                               batch_size: int,
                               ncols: int = 8) -> None:
    """
    Undo normalization on the first ``batch_size`` images of a batch, tile them
    into one grid image and plot it.

    Args:
        image: batch of normalized images laid out as (b, c, h, w). It may live
            on any device and may require grad.
        transform: preset whose ``mean`` and ``std`` (per-channel sequences)
            were used to normalize ``image``.
        batch_size: number of leading images of ``image`` to show. Must be a
            positive multiple of ``ncols`` and at most ``image.shape[0]``.
        ncols: number of images per grid row.
    """
    assert isinstance(image, torch.Tensor), f'image type is {type(image)}'
    assert ncols > 0, f'ncols must be positive, got {ncols}'
    assert batch_size > 0 and batch_size % ncols == 0, \
        f'batch_size {batch_size} is not a positive multiple of {ncols}'
    assert batch_size <= image.shape[0], \
        f'batch_size {batch_size} exceeds the {image.shape[0]} images provided'

    # Keep only the requested images; detach/cpu so GPU or autograd tensors work.
    images = image[:batch_size].detach().cpu().numpy()
    mean = np.asarray(transform.mean)
    std = np.asarray(transform.std)

    # Move channels last, then invert Normalize per channel: x * std + mean.
    images = rearrange(images, 'b c h w -> b h w c') * std + mean
    # Tile into a grid with `ncols` images per row (row count is inferred).
    grid = rearrange(images, '(rows cols) h w c -> (rows h) (cols w) c', cols=ncols)
    grid = np.clip(grid, 0, 1)

    plt.imshow(grid)
    plt.axis('off')
    plt.show()

    print('Shape: ', grid.shape)

In [None]:
# n_of_batch_images_to_show = 4
# current_label = None

# for i, (image, label) in enumerate(dataloader):
#     label0 = int(label[0])

#     if label0 != current_label:
#         visualize_normalized_image(image, transforms, 32)
#         n_of_batch_images_to_show -= 1
#         current_label = label0

#     if n_of_batch_images_to_show == 0:
#         break


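# Initial CIFAR-10 prediction without any fine-tuning

At this point `model` still has its pretrained ImageNet classification head, so its predicted class indices do not correspond to CIFAR-10 labels; this accuracy serves only as a sanity baseline.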

In [None]:
def evaluate_model(model, dataloader, stop_at=None):
    """
    Compute top-1 accuracy of ``model`` over ``dataloader``.

    The model is moved to the global ``device`` and switched to eval mode; it
    is left in eval mode afterwards.

    Args:
        model: classifier whose output holds per-class scores along dim 1.
        dataloader: iterable of (images, labels) batches to evaluate on.
        stop_at: if truthy, evaluate only the first ``stop_at`` batches.

    Returns:
        Tuple ``(total_correct, total_samples, accuracy)``. ``accuracy`` is
        NaN when no samples were seen (e.g. an empty dataloader).
    """
    model = model.to(device)
    model.eval()
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        # Iterate the loader that was passed in, not a global one.
        for batch_idx, (images, labels) in enumerate(tqdm(dataloader), start=1):
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(dim=1)

            total_samples += labels.size(0)
            total_correct += (predicted == labels).sum().item()
            if stop_at and batch_idx >= stop_at:
                break

    # Avoid ZeroDivisionError when nothing was evaluated.
    accuracy = total_correct / total_samples if total_samples else float('nan')
    return total_correct, total_samples, accuracy

In [None]:
(total_correct, total_samples, accuracy) = evaluate_model(model=model, dataloader=dataloader_test)

In [None]:
# Report the most recent evaluation result.
print(f'Total Correct: {total_correct}\n')
print(f'Total Samples: {total_samples}\n')
print(f'Accuracy: {accuracy}\n')

# Model Modification

In [None]:
# Replace the ImageNet classification head with a fresh linear layer that keeps
# the backbone's feature width but outputs one score per CIFAR-10 class.
model_fc_in, model_fc_out = model.fc.in_features, len(label_mapping)
model.fc = nn.Linear(in_features=model_fc_in, out_features=model_fc_out)

In [None]:
# Freeze every parameter of the network.
for tensor in model.parameters():
    tensor.requires_grad_(False)

In [None]:
# Make only the classification head trainable.
for tensor in model.fc.parameters():
    tensor.requires_grad_(True)

In [None]:
(total_correct, total_samples, accuracy) = evaluate_model(model=model, dataloader=dataloader_test)

In [None]:
# Report the most recent evaluation result.
print(f'Total Correct: {total_correct}\n')
print(f'Total Samples: {total_samples}\n')
print(f'Accuracy: {accuracy}\n')

In [None]:
criterion = nn.CrossEntropyLoss()
# Only the head's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(params=model.fc.parameters(), lr=1e-3)

In [None]:
def train_model(model, dataloader, loss, optimizer, num_epochs=10):
    """
    Train ``model`` for ``num_epochs`` passes and print the mean loss per epoch.

    The model is moved to the global ``device`` and put in train mode. On CUDA
    the forward pass runs under float16 autocast with a ``GradScaler``; on any
    other device training runs in full precision.

    Args:
        model: network to train; only parameters held by ``optimizer`` change.
        dataloader: iterable of (images, labels) batches.
        loss: callable ``loss(outputs, labels)`` returning a scalar tensor,
            e.g. ``nn.CrossEntropyLoss()``.
        optimizer: optimizer over the parameters to be trained.
        num_epochs: number of passes over ``dataloader``.
    """
    model = model.to(device)
    # NOTE(review): train() also puts the frozen backbone's BatchNorm layers in
    # training mode, so their running statistics keep updating on this data.
    # Call model.eval() on the backbone afterwards if that is not intended.
    model.train()

    use_amp = device.type == 'cuda'
    # fp16 autocast without loss scaling can underflow small gradients; with
    # enabled=False the scaler is a pass-through (plain backward/step).
    scaler = GradScaler(device.type, enabled=use_amp)

    for epoch in tqdm(range(num_epochs)):
        total_loss = 0.0
        n_batches = 0

        for images, labels in tqdm(dataloader, leave=False):
            optimizer.zero_grad()
            images = images.to(device)
            labels = labels.to(device)

            with autocast(device_type=device.type, enabled=use_amp):
                outputs = model(images)
                # Use the loss function passed in, not a global.
                batch_loss = loss(outputs, labels)

            scaler.scale(batch_loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += batch_loss.item()
            n_batches += 1

        avg_loss = total_loss / max(n_batches, 1)
        print(f'Epoch: {epoch + 1}, loss = {avg_loss:.4f}')


In [None]:
train_model(model=model, dataloader=dataloader, loss=criterion, optimizer=optimizer, num_epochs=10)

In [None]:
evaluate_model(model=model, dataloader=dataloader_test)

In [None]:
# Randomly mask 30% of the head's currently unpruned weights; the mask is
# exposed as `model.fc.weight_mask`.
prune.random_unstructured(module=model.fc, name='weight', amount=0.3)

In [None]:
train_model(model=model, dataloader=dataloader, loss=criterion, optimizer=optimizer, num_epochs=10)

In [None]:
evaluate_model(model=model, dataloader=dataloader_test)

In [None]:
# Second round: mask a further 30% of the weights that are still unpruned.
prune.random_unstructured(module=model.fc, name='weight', amount=0.3)

In [None]:
train_model(model=model, dataloader=dataloader, loss=criterion, optimizer=optimizer, num_epochs=10)

In [None]:
evaluate_model(model=model, dataloader=dataloader_test)

In [None]:
# Number of head weights that survive pruning (mask entries equal to 1).
torch.sum(model.fc.weight_mask)