# Test-Time Adaptation

In [None]:
!pip install ollama
!pip install diffusers
!pip install git+https://github.com/openai/CLIP.git

In [1]:
import torch
import torchvision.transforms as T
import torchvision.models as models
import torchvision
from test_methods.test import Tester
from test_time_adaptation.resnet50_dropout import ResNet50Dropout

In [2]:
# Root directory of the ImageNet-A test images (relative to the working dir).
imagenet_a_path = "imagenet-a"

In [3]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
import torchvision.transforms as T

# Fixed, hand-picked single transforms (not used by the test call below,
# which uses augmix_augmentations instead).
augmentations = [
    # flips and rotations
    T.RandomHorizontalFlip(p=1),
    T.RandomVerticalFlip(p=1),
    T.RandomRotation(degrees=30),
    T.RandomRotation(degrees=60),
    # one colour property jittered at a time
    T.ColorJitter(brightness=0.2),
    T.ColorJitter(contrast=0.2),
    T.ColorJitter(saturation=0.2),
    T.ColorJitter(hue=0.2),
    # small translation, then a mild rotation
    T.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    T.RandomRotation(degrees=15),
    # pixel-level / photometric operations
    T.RandomAdjustSharpness(sharpness_factor=2, p=1),
    T.RandomGrayscale(p=1),
    T.RandomInvert(p=1),
    T.RandomAutocontrast(p=1),
    T.GaussianBlur(kernel_size=5),
]

# 18 independent AugMix transforms. (severity, mixture_width) cycles through
# (3, 3), (2, 2), (4, 4) six times; chain_depth=3 and alpha=1.0 throughout.
augmix_augmentations = [
    T.AugMix(severity=level, mixture_width=level, chain_depth=3, alpha=1.0)
    for _ in range(6)
    for level in (3, 2, 4)
]

## ResNet-50

In [5]:
# Experiment output directory handed to Tester(exp_path=...).
exp_path_a = "Domain-Shift-Computer-Vision/experiments/Resnet50_ImagenetA_SGD"

In [6]:
# Options passed to Tester.test(MC=...), presumably for Monte-Carlo dropout.
# Key meanings below are inferred from their names -- TODO confirm in Tester.
MC = {
    "dropout_rate": 0.5,  # presumably the dropout probability
    "num_samples": 10,    # presumably the number of stochastic forward passes
    "use_dropout": True,  # selects ResNet50Dropout() over models.resnet50 below
}

In [7]:
# Build the ResNet-50 tester, optimised with SGD, writing to exp_path_a.
# NOTE(review): the two branches pass different kinds of object -- an
# instantiated ResNet50Dropout() vs. the torchvision factory function
# models.resnet50 itself (not called). Confirm Tester accepts both. Also note
# that MC["dropout_rate"] is not forwarded to ResNet50Dropout here.
tester_resnet50 = Tester(
    model=ResNet50Dropout() if MC["use_dropout"] else models.resnet50,
    optimizer=torch.optim.SGD,
    exp_path=exp_path_a,
    device=device,
)

In [8]:
# Alternative (per-parameter-group) form of lr_setting, left commented out:
# lr_setting = [{
#     "classifier": [["fc.weight", "fc.bias"], 0.00025]
# }, 0]

# Single learning rates passed to Tester.test(lr_setting=...).
lr_setting_sgd = [0.00025]  # setting used in MEMO paper for SGD
lr_setting_adam = [0.0001]  # setting used in MEMO paper for ADAM

In [9]:
# torchvision ResNet-50 pretrained-weight enums.
# IMAGENET1K_V1 is the set the MEMO paper used; IMAGENET1K_V2 is an
# alternative that the visible cells do not use.
imagenetV1_weights = models.ResNet50_Weights.IMAGENET1K_V1
imagenetV2_weights = models.ResNet50_Weights.IMAGENET1K_V2

In [10]:
# Settings for augmenting inputs with generated images, passed to
# Tester.test(gen_aug_settings=...). The retrieval demo further down calls
# retrieve_gen_images with the same encoder, image count and threshold.
gen_aug_settings = {
    "clip_img_encoder": "ViT-L/14",
    "num_img": 40,
    "gen_data_path": "Domain-Shift-Computer-Vision/imagenetA_generated",
    "use_t2i_similarity": True,
    "t2i_img": True,
    "i2i_img": True,
    "threshold": 0.45,
}

In [None]:
# Run test-time adaptation on ImageNet-A: MEMO with AugMix views, the
# MEMO-paper SGD learning rate and IMAGENET1K_V1 weights, plus the MC and
# generated-image options defined above.
tester_resnet50.test(
    augmentations=augmix_augmentations,
    num_augmentations=16,
    seed_augmentations=42,
    batch_size=64,
    img_root=imagenet_a_path,
    num_adaptation_steps=2,
    MEMO=True,
    lr_setting=lr_setting_sgd,
    top_augmentations=8,
    weights_imagenet=imagenetV1_weights,
    prior_strength=16,
    TTA=True,
    MC=MC,
    gen_aug_settings=gen_aug_settings,
)

## Image Generation

In [None]:
import torch
from test_time_adaptation.image_generation.image_generator import ImageGenerator

In [None]:
# Shared generator used for both prompt generation and image generation below.
imagenetA_generator = ImageGenerator()

In [None]:
# generate prompts
# Generate 20 "photograph"-style prompts per class using the llama3.1 LLM and
# the given context file, writing them under the generated-data directory.
# NOTE(review): the returned skipped_classes is not used later --
# class_to_skip is set to [] before image generation. Confirm that is intended.
skipped_classes = imagenetA_generator.generate_prompts(
    num_prompts_per_class=20,
    style_of_picture="photograph",
    path="Domain-Shift-Computer-Vision/imagenetA_generated",
    context_llm="Domain-Shift-Computer-Vision/test_time_adaptation/image_generation/llm_context.json",
    llm_model="llama3.1",
    clip_text_encoder="ViT-L/14",
)

In [None]:
# generate images
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, DPMSolverMultistepScheduler
import torch

# Hugging Face model id shared by the text-to-image and image-to-image pipelines.
model_id = "runwayml/stable-diffusion-v1-5"

In [None]:
# Text-to-image Stable Diffusion pipeline in fp16 with a DPM-Solver++ scheduler
# built from the pipeline's own scheduler config.
# NOTE(review): device is hard-coded to "cuda" (not the `device` variable), and
# pipet2i is not passed to the visible generate_images call (pipei2i is).
pipet2i = StableDiffusionPipeline.from_pretrained(
    model_id, torch_dtype=torch.float16
)
pipet2i.scheduler = DPMSolverMultistepScheduler.from_config(
    pipet2i.scheduler.config
)
pipet2i = pipet2i.to("cuda")

# Denoising steps for text-to-image generation.
num_inf_steps = 25

In [None]:
# Image-to-image variant of the same model, fp16 with a DPM-Solver++
# scheduler built from its own scheduler config; hard-coded to "cuda".
pipei2i = StableDiffusionImg2ImgPipeline.from_pretrained(
    model_id, torch_dtype=torch.float16
)
pipei2i.scheduler = DPMSolverMultistepScheduler.from_config(
    pipei2i.scheduler.config
)
pipei2i = pipei2i.to("cuda")

import math

# img2img strength (the diffusers `strength` argument passed to generate_images).
strength = 0.89
# The diffusers img2img pipeline runs only int(num_inference_steps * strength)
# denoising steps, so the requested count is scaled up to get 25 effective
# steps. Truncating (int(25 / 0.89) = 28) would give int(28 * 0.89) = 24 steps;
# rounding up (29) gives int(29 * 0.89) = 25.
# Note: this rebinds num_inf_steps, replacing the text-to-image value of 25.
num_inf_steps = math.ceil(25 / strength)

In [50]:
# Directory holding the generated prompts/images for ImageNet-A classes.
imagenet_a_generated_path = "Domain-Shift-Computer-Vision/imagenetA_generated"

In [None]:
# Classes to leave out of image generation; empty means none are excluded here.
# NOTE(review): skipped_classes from generate_prompts is not reused -- confirm.
class_to_skip = []

In [None]:
# Generate 25 images per class with the image-to-image pipeline, using the
# step count and strength set in the pipei2i cell (guidance scale 12).
imagenetA_generator.generate_images(
    path=imagenet_a_generated_path,
    num_images_per_class=25,
    class_to_skip=class_to_skip,
    image_generation_pipeline=pipei2i,
    num_inference_steps=num_inf_steps,
    guidance_scale=12,
    strength=strength,
)

## Retrieving Images

In [13]:
from utility.data.get_data import get_data
import clip
import torch
import torch.nn.functional as F
from PIL import Image
from test_time_adaptation.image_generation.image_generator import retrieve_gen_images

In [None]:
# Full (unsplit) ImageNet-A loader with batches of 32.
dataloader = get_data(
    batch_size=32,
    img_root="imagenet-a",
    split_data=False,
)

In [323]:
# Element 0 of dataset item 20 -- presumably the image of an (image, label)
# pair, and presumably a stingray sample given the name (TODO confirm).
stingray_img = dataloader.dataset[20][0]

In [46]:
# Load the CLIP model and its preprocessing transform. The encoder name is the
# same as gen_aug_settings["clip_img_encoder"].
clip_image_encoder = "ViT-L/14"
clip_model, clip_preprocess = clip.load(clip_image_encoder)

In [51]:
# Re-declared so this section can run without the image-generation cells.
imagenet_a_generated_path = "Domain-Shift-Computer-Vision/imagenetA_generated"

In [381]:
# Retrieve generated images for the sample loaded above, considering only
# text-to-image generations (i2i_images=False). num_images / threshold
# semantics are defined by retrieve_gen_images.
# Fix: the original passed `img[0]`, but `img` is never defined in this
# notebook; stingray_img (dataset item 20, element 0) is the sample meant here.
retrieved_images = retrieve_gen_images(img=stingray_img,
                                       num_images=40,
                                       data_path=imagenet_a_generated_path,
                                       clip_model=clip_model,
                                       clip_preprocess=clip_preprocess,
                                       t2i_images=True,
                                       i2i_images=False,
                                       use_t2i_similarity=True,
                                       threshold=0.45)

In [None]:
from PIL import Image
import math

def create_image_grid(images, grid_width, save_path=None, cell_size=(100, 100)):
    """
    Create a grid of images from a collection of PIL images.

    Each image is resized to ``cell_size`` and pasted row by row (left to
    right, top to bottom) onto a white RGB canvas; any cells left over in the
    last row stay white.

    Args:
        images (iterable of PIL.Image): Images to arrange in the grid. Any
            iterable is accepted; it is consumed once.
        grid_width (int): Number of columns in the grid; must be at least 1.
        save_path (str or path-like, optional): If truthy, the grid is also
            saved to this path. A failed save is reported with a printed
            message rather than raised. Defaults to None (do not save).
        cell_size (tuple): Size of each cell in the grid (width, height).

    Returns:
        PIL.Image or None: The grid image, or None if ``images`` is empty.

    Raises:
        ValueError: If ``images`` is non-empty and ``grid_width`` is below 1.
    """
    # Materialize once so generators are accepted and len() is available.
    images = list(images)
    if not images:
        print("No images")
        return None

    if grid_width < 1:
        raise ValueError(f"grid_width must be >= 1, got {grid_width}")

    cell_w, cell_h = cell_size
    grid_height = math.ceil(len(images) / grid_width)  # number of rows needed

    # Blank canvas with a white background
    grid_img = Image.new('RGB', (cell_w * grid_width, cell_h * grid_height), (255, 255, 255))

    # Paste each resized image into its (row, col) cell
    for i, img in enumerate(images):
        row, col = divmod(i, grid_width)
        grid_img.paste(img.resize(cell_size), (col * cell_w, row * cell_h))

    if save_path:
        try:
            grid_img.save(save_path)
        except (OSError, ValueError) as e:
            # Best effort: a failed save should not discard the in-memory grid.
            print(f"Error saving the image: {e}")
        else:
            print(f"Grid image saved to {save_path}")

    return grid_img

In [None]:
# Show the retrieved images 4 per row and save the grid to disk.
create_image_grid(
    retrieved_images,
    grid_width=4,
    save_path="Domain-Shift-Computer-Vision/grid_3.png",
)

## Scraping Images

In [None]:
!pip install bing_image_downloader

In [8]:
from test_time_adaptation.image_generation.web_scrape import scrape_images_imagenetA

In [None]:
# Download up to `limit` web images per class into the generated-data folder.
# NOTE: the keyword is spelled `imgenetA_gen_path` (sic) -- this call relies
# on that exact name, so rename it only together with the function itself.
scrape_images_imagenetA(
    img_style="a photo of",
    imgenetA_gen_path="Domain-Shift-Computer-Vision/imagenetA_generated",
    limit=10,
)