<a href="https://colab.research.google.com/github/iammuhammad41/Knowledge-Distillation/blob/main/knowledge_distillation_in_computer_vision.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# %%capture
# Dependency installs, left commented out. Uncomment them on a fresh runtime
# (e.g. Colab): later cells import super_gradients and diffusers.
# NOTE(review): consider `%pip install` with pinned versions instead of `!pip`,
# so the install targets the running kernel and results are reproducible.
# ! pip install super_gradients
# !pip install diffusers --upgrade
# !pip install invisible_watermark transformers accelerate safetensors

In [None]:
import torch
import os
import time
from IPython.core.display import display, HTML

from matplotlib import pyplot as plt
from tqdm.notebook import tqdm
from diffusers import StableDiffusionPipeline

import os
import torch
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm
import time


In [None]:
from super_gradients.training import Trainer, models, dataloaders
from super_gradients.training.metrics import Accuracy, Top5
from torchvision import transforms

# NOTE(review): this plain Python variable is never read anywhere in this notebook.
# If the intent is to disable SuperGradients' crash handler, it presumably has to be
# an environment variable set before `super_gradients` is imported — confirm.
CRASH_HANDLER = False

# A plain Trainer, used here only to evaluate models (nothing is trained with it).
# NOTE(review): "content/checkpoints" is relative to the working directory; on Colab
# "/content/checkpoints" may have been intended.
trainer = Trainer(experiment_name="beit_tester", ckpt_root_dir="content/checkpoints")

# CIFAR-10 validation split, converted to tensors and upscaled to 224 pixels,
# the input size the BEiT teacher is configured for below.
test_dataloader = dataloaders.get(
    "cifar10_val",
    dataloader_params={"batch_size": 64},
    dataset_params={
        "transforms": [
            transforms.ToTensor(),
            transforms.Resize(224, antialias=True),
        ]
    },
)

# BEiT-Base (16x16 patches, 224x224 input) with the CIFAR-10 pretrained weights;
# this is the teacher for the distillation below.
pretrained_beit = models.get(
    "beit_base_patch16_224",
    num_classes=10,
    arch_params={"image_size": [224, 224], "patch_size": [16, 16]},
    pretrained_weights="cifar10",
)

# Top-1 / Top-5 accuracy of the teacher on the validation split.
metrics = trainer.test(
    model=pretrained_beit,
    test_loader=test_dataloader,
    test_metrics_list=[Accuracy(), Top5()],
)

In [None]:
# Display the BEiT teacher's test metrics computed by trainer.test in the previous cell.
metrics

In [None]:
from super_gradients.training import KDTrainer

# Trainer that runs the teacher -> student distillation. The best-checkpoint path
# built later in this notebook assumes files under <checkpoint_dir>/<experiment_name>/.
experiment_name = "my_first_kd_run"
checkpoint_dir = "kd_checkpoints"
kd_trainer = KDTrainer(ckpt_root_dir=checkpoint_dir, experiment_name=experiment_name)

In [None]:
from super_gradients.training import dataloaders, models

# Student-side data: CIFAR-10 train/val splits with the library's default dataset
# settings (no dataset_params override), at different batch sizes.
train_dataloader = dataloaders.get("cifar10_train", dataloader_params={"batch_size": 128})
val_dataloader = dataloaders.get("cifar10_val", dataloader_params={"batch_size": 512})

# Student: the CIFAR ResNet-18 variant with 10 outputs; no pretrained weights are requested.
student_resnet18 = models.get("resnet18_cifar", num_classes=10)

In [None]:
from matplotlib import pyplot as plt

def show(images, labels, classes, rows=6, columns=5, mean=None, std=None):
    """Plot a grid of images, each titled with its class name.

    Args:
        images: Batch of CHW image tensors (each item is permuted to HWC for imshow).
        labels: Class indices, one per image.
        classes: Sequence mapping a class index to its display name.
        rows, columns: Grid shape; at most ``rows * columns`` images are drawn.
        mean, std: Optional per-channel statistics used to undo a ``Normalize``
            transform before display. When omitted, pixel values are only clamped
            to [0, 1], so normalized images are shown with shifted colors.
    """
    fig = plt.figure(figsize=(10, 10))

    # Draw only as many panels as there are images, so batches smaller than the
    # grid do not raise IndexError.
    n_panels = min(rows * columns, len(images))
    for idx in range(n_panels):
        ax = fig.add_subplot(rows, columns, idx + 1)
        img = images[idx].permute(1, 2, 0)  # CHW -> HWC for imshow
        if mean is not None and std is not None:
            # Broadcast over the trailing channel axis to undo (x - mean) / std.
            img = img * img.new_tensor(std) + img.new_tensor(mean)
        ax.imshow(img.clamp(0, 1))
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(f"{classes[int(labels[idx])]}")

In [None]:
# Peek at one training batch: draw a grid of its images and report the tensor shapes.
vis_images_train, vis_labels_train = next(iter(train_dataloader))
show(images=vis_images_train, labels=vis_labels_train, classes=train_dataloader.dataset.classes)

print(vis_images_train.shape, vis_labels_train.shape)

In [None]:
from super_gradients.training import training_hyperparams
from super_gradients.training.losses import KDLogitsLoss, LabelSmoothingCrossEntropyLoss

# Length of the distillation run, kept in one place so the comments below cannot drift from it.
NUM_EPOCHS = 10

# Overrides applied on top of the recipe hyperparameters loaded below.
kd_params = {
    "max_epochs": NUM_EPOCHS,
    # No LR cooldown or warmup: for a run this short (NUM_EPOCHS epochs), either
    # phase would take up a large share of the whole schedule.
    "lr_cooldown_epochs": 0,
    "lr_warmup_epochs": 0,
    # Combined objective: a label-smoothed cross-entropy task loss plus a distillation
    # term on the teacher's logits, with distillation_loss_coeff=0.8 weighting the
    # distillation part (NOTE(review): confirm the exact weighting in KDLogitsLoss).
    "loss": KDLogitsLoss(distillation_loss_coeff=0.8, task_loss_fn=LabelSmoothingCrossEntropyLoss()),
    # Names for the logged loss components; presumably they must match the items the
    # loss reports — confirm against KDLogitsLoss.
    "loss_logging_items_names": ["Loss", "Task Loss", "Distillation Loss"],
}

# Start from the "imagenet_resnet50_kd" recipe and apply the overrides above.
training_params = training_hyperparams.get("imagenet_resnet50_kd", overriding_params=kd_params)

In [None]:
from pprint import pprint

# Inspect the merged hyperparameters (recipe defaults with the kd_params overrides applied).
pprint(training_params)

In [None]:
# KD architecture params. teacher_input_adapter is (presumably — confirm in the KD module
# docs) applied to the student's batches before they reach the teacher: upsample the
# CIFAR images to 224 pixels, the input size the BEiT teacher is configured for.
# antialias=True matches the Resize used when the teacher was evaluated above.
# NOTE(review): the teacher was evaluated on un-normalized ToTensor() inputs, while the
# default cifar10 loaders used for distillation may apply Normalize. If so, the teacher
# sees differently scaled inputs during KD — confirm.
arch_params = {"teacher_input_adapter": transforms.Resize(224, antialias=True)}

In [None]:
# Distill the BEiT teacher into the ResNet-18 student, validating on the CIFAR-10 val split.
kd_trainer.train(
    training_params=training_params,
    student=student_resnet18,
    teacher=pretrained_beit,
    kd_architecture="kd_module",
    kd_arch_params=arch_params,
    train_loader=train_dataloader,
    valid_loader=val_dataloader,
)

In [None]:
from super_gradients.training.metrics import Accuracy, Top5

# Evaluate the distilled student on the CIFAR-10 validation loader.
# NOTE(review): presumably student_resnet18 now holds the end-of-training weights, which
# need not be the ones saved as ckpt_best.pth — confirm.
metrics = trainer.test(
    model=student_resnet18,
    test_loader=val_dataloader,
    test_metrics_list=[Accuracy(), Top5()],
)
print(f"\nAccuracy: {metrics['Accuracy']:.3f}\nTop 5:    {metrics['Top5']:.3f}")

In [None]:
%%capture
# Load DeciDiffusion v1.0 in float16. The pipeline code comes from the same Hub repo
# (diffusers' custom_pipeline mechanism), and its UNet is then replaced by the weights
# in the repo's "flexible_unet" subfolder.
deci_diffusion_pipeline = StableDiffusionPipeline.from_pretrained(
    "Deci/DeciDiffusion-v1-0",
    custom_pipeline="Deci/DeciDiffusion-v1-0",
    torch_dtype=torch.float16,
)
deci_diffusion_pipeline.unet = deci_diffusion_pipeline.unet.from_pretrained(
    "Deci/DeciDiffusion-v1-0",
    subfolder="flexible_unet",
    torch_dtype=torch.float16,
)

# Run the whole pipeline on the GPU.
deci_diffusion_pipeline = deci_diffusion_pipeline.to("cuda")

def text_to_image(pipeline, prompt, output_dir="/content", seed=42, device="cuda"):
    """Generate one image from ``prompt``, save it as a PNG, display it, and return its path.

    Args:
        pipeline: Diffusion pipeline called as ``pipeline([prompt], generator=...)``; its
            result must expose ``.images`` (the first one is saved).
        prompt (str): Text prompt.
        output_dir (str): Directory the PNG is written to; created if missing.
        seed (int): Seed for the torch generator, so repeated calls are reproducible.
        device (str): Device of the torch generator; should match the pipeline's device.

    Returns:
        str: ``<output_dir>/<PipelineClassName>_<prompt>.png``, where the prompt part has
        unsafe characters replaced by ``_`` and is truncated to 100 characters.
    """
    start_time = time.perf_counter()
    result = pipeline([prompt], generator=torch.Generator(device).manual_seed(seed))
    elapsed_time = time.perf_counter() - start_time
    display(HTML(f'<span style="color: #3264ff; font-weight:bold;font-size: 20px;">Time taken to generate: {elapsed_time:.2f} seconds</span>'))

    img = result.images[0]

    # Keep letters, digits, '-', '_' and '.', and map everything else (spaces, path
    # separators, ...) to '_', so the prompt cannot escape output_dir. Truncate to 100 chars.
    filename = "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in prompt)[:100]

    # Prefix with the pipeline's class name, e.g. "DeciDiffusionPipeline_ship.png": the
    # prediction cells later in this notebook read images back under that naming scheme.
    pipeline_name = pipeline.__class__.__name__
    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, f"{pipeline_name}_{filename}.png")
    img.save(save_path)

    plt.imshow(img)
    plt.axis('off')
    plt.show()
    return save_path

In [None]:
cifar_classes = train_dataloader.dataset.classes

# Generate one image per CIFAR-10 class, using the bare class name as the prompt.
for class_name in cifar_classes:
    text_to_image(deci_diffusion_pipeline, class_name)

In [None]:
import os
from PIL import Image
import numpy as np
import requests
import torch

# Rebuild the student architecture and load the "ckpt_best.pth" file from the KD run.
# NOTE(review): newer super_gradients versions may nest checkpoints in a per-run
# subdirectory under the experiment folder — confirm this path exists.
trained_model = models.get(
    "resnet18_cifar",
    num_classes=10,
    checkpoint_path=f"{checkpoint_dir}/{kd_trainer.experiment_name}/ckpt_best.pth",
)

In [None]:
def predict_and_display(path, model=trained_model, class_list=cifar_classes, device='cuda'):
    """
    Load an image, classify it with ``model``, and show it with the predicted class as its x-label.

    The image is converted to RGB, resized to 32x32 and normalized with CIFAR-10
    channel statistics before being passed to the model.

    Args:
    - path (str): Path to the image file.
    - model (torch.nn.Module): The trained model for prediction. It is switched to eval
      mode and moved to ``device`` in place.
    - class_list (list): Class names, indexed by the model's output position.
    - device (str): Device for running the model. Default is 'cuda'.

    Returns:
    - None. Displays the image with the predicted class label.
    """
    # Close the file handle promptly, and force 3 channels: RGBA or grayscale inputs
    # would otherwise break the 3-channel Normalize below.
    with Image.open(path) as pil_image:
        image = np.asarray(pil_image.convert("RGB"))

    pred_transforms = transforms.Compose([
        transforms.ToTensor(),
        # antialias=True avoids aliasing when shrinking large generated images to 32x32.
        transforms.Resize((32, 32), antialias=True),
        transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
    ])

    model = model.eval().to(device)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        predictions = model(pred_transforms(image).unsqueeze(0).to(device))
    predicted_class = class_list[int(torch.argmax(predictions))]

    plt.imshow(image)
    plt.xlabel(predicted_class)
    plt.show()

In [None]:
# Classify the generated "ship" image with the distilled student.
predict_and_display(path="/content/DeciDiffusionPipeline_ship.png")

In [None]:
# Classify the generated "frog" image with the distilled student.
predict_and_display(path="/content/DeciDiffusionPipeline_frog.png")

In [None]:
# Classify the generated "airplane" image with the distilled student.
predict_and_display(path="/content/DeciDiffusionPipeline_airplane.png")

In [None]:
# Classify the generated "bird" image with the distilled student.
predict_and_display(path="/content/DeciDiffusionPipeline_bird.png")