# Test-Time Adaptation

In [1]:
import torch
import torchvision.transforms as T
import torchvision.models as models
import torchvision
from test_methods.test import Tester
from test_time_adaptation.resnet50_dropout import ResNet50Dropout

In [3]:
# Dataset roots, given relative to the notebook's working directory.
imagenet_a_path = "imagenet-a"
# NOTE(review): despite the "_b" suffix this points at an ImageNet-V2 folder — confirm the intended name.
imagenet_b_path = "imagenetv2-matched-frequency-format-val/"

In [4]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
import torchvision.transforms as T

# Pool of single-image torchvision transforms. Every transform that accepts a
# probability is built with p=1, so it is applied on each call.
augmentations = [
    T.RandomHorizontalFlip(p=1),
    T.RandomVerticalFlip(p=1),
    T.RandomRotation(degrees=30),
    T.RandomRotation(degrees=60),
    T.ColorJitter(brightness=0.2),
    T.ColorJitter(contrast=0.2),
    T.ColorJitter(saturation=0.2),
    T.ColorJitter(hue=0.2),
    T.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    T.RandomRotation(degrees=15),
    T.RandomAdjustSharpness(sharpness_factor=2, p=1),
    T.RandomGrayscale(p=1),
    T.RandomInvert(p=1),
    T.RandomAutocontrast(p=1),
    T.GaussianBlur(kernel_size=5),
]

# AugMix pool: 18 independent transforms that cycle through three strengths
# (severity == mixture_width, taking the values 3, 2, 4 in that order) while
# chain_depth and alpha stay fixed.
augmix_augmentations = [
    T.AugMix(severity=level, mixture_width=level, chain_depth=3, alpha=1.0)
    for _ in range(6)
    for level in (3, 2, 4)
]

## Resnet50

In [6]:
# Directory handed to Tester as `exp_path`.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
exp_path_a = "/home/sagemaker-user/Domain-Shift-Computer-Vision/experiments/Resnet50_ImagenetA_SGD"

In [1]:
# Settings forwarded as the `MC` argument of `tester_resnet50.test`; the
# `use_dropout` flag is also read when the tester's model is chosen.
MC = dict(
    dropout_rate=0.5,
    num_samples=10,
    use_dropout=True,
)

In [7]:
# Build the evaluation harness for ResNet-50 with an SGD optimizer class.
# NOTE(review): the two branches of the conditional hand Tester different kinds
# of object — an instantiated ResNet50Dropout versus the un-called
# `models.resnet50` factory. Confirm which form Tester expects.
# NOTE(review): MC['dropout_rate'] is not passed to ResNet50Dropout() here.
tester_resnet50 = Tester(
    model = ResNet50Dropout() if MC['use_dropout'] else models.resnet50,
    optimizer = torch.optim.SGD,
    exp_path = exp_path_a,
    device = device
)

In [8]:
# Alternative per-parameter-group setting, kept for reference (commented out):
#lr_setting = [{
#    "classifier" : [["fc.weight", "fc.bias"], 0.00025]
#}, 0]
lr_setting_sgd = [0.00025] # learning rate used in the MEMO paper for SGD
lr_setting_adam = [0.0001] # learning rate used in the MEMO paper for Adam

In [9]:
# torchvision ships two pretrained weight sets for ResNet-50.
imagenetV1_weights = models.ResNet50_Weights.IMAGENET1K_V1 # MEMO paper used these weights
# "V2" here names the second torchvision weight release, not the ImageNet-V2 dataset.
imagenetV2_weights = models.ResNet50_Weights.IMAGENET1K_V2

In [None]:
# Evaluate on ImageNet-A with MEMO switched off, TTA switched on, and the MC
# settings forwarded as-is.
# NOTE(review): 16 augmentations are requested from a pool of 18 AugMix
# transforms — presumably selected using `seed_augmentations`; confirm in Tester.
tester_resnet50.test(
    augmentations=augmix_augmentations,
    num_augmentations=16,
    seed_augmentations=42,
    batch_size=64,
    img_root=imagenet_a_path,
    MEMO=False,
    lr_setting=None,
    top_augmentations=8,
    weights_imagenet=imagenetV1_weights,
    prior_strength=0.,
    TTA=True,
    MC=MC,
)

-----

In [None]:
!pip install ollama

In [None]:
!pip install diffusers

In [33]:
! pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-4xzrocls
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-4xzrocls
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting ftfy (from clip==1.0)
  Using cached ftfy-6.2.3-py3-none-any.whl.metadata (7.8 kB)
Using cached ftfy-6.2.3-py3-none-any.whl (43 kB)
Building wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25ldone
[?25h  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369492 sha256=248665e3e232817c7bbfbff0485258b40b84b09df5d0833b1a194b61c074ead2
  Stored in directory: /tmp/pip-ephem-wheel-cache-hfdf6mod/wheels/da/2b/4c/d6691fa9597aac8bb85d2ac13b112deb897d5b50f5ad9a37e4
Successfully built clip
Installing collected packages: ftfy, clip
Successfully installed

In [2]:
import torch

In [2]:
from test_time_adaptation.image_generation.image_generator import ImageGenerator

In [3]:
# Generator object whose `generate_prompts` and `generate_images` methods are
# called in the following cells.
imagenetA_generator = ImageGenerator()

In [None]:
# generate prompts
# NOTE(review): `path` and `context_llm` are absolute, machine-specific paths.
# The result is kept as `skipped_classes`; nothing in the visible cells reads it.
skipped_classes = imagenetA_generator.generate_prompts(
    num_prompts=2,
    style_of_picture="photograph",
    path="/home/sagemaker-user/Domain-Shift-Computer-Vision/imagenetA_generated",
    context_llm = "/home/sagemaker-user/Domain-Shift-Computer-Vision/test_time_adaptation/image_generation/llm_context.json"
)

In [None]:
# generate images
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, DPMSolverMultistepScheduler

# Text-to-image Stable Diffusion pipeline, loaded with float16 weights.
model_id = "runwayml/stable-diffusion-v1-5"
pipet2i = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
# Swap the pipeline's scheduler for DPM-Solver multistep, built from the existing scheduler config.
pipet2i.scheduler = DPMSolverMultistepScheduler.from_config(pipet2i.scheduler.config)
# NOTE(review): "cuda" is hard-coded, so this cell requires a GPU.
pipet2i = pipet2i.to("cuda")

In [None]:
# Run image generation with the text-to-image pipeline, writing under the same
# directory that was given to `generate_prompts`.
imagenetA_generator.generate_images(
    path="/home/sagemaker-user/Domain-Shift-Computer-Vision/imagenetA_generated",
    num_images=1,
    image_generation_pipeline=pipet2i,
    num_inference_steps=25,
    guidance_scale=9,
    strength=1,
)

------

In [69]:
from utility.data.get_data import get_data
import clip
import torch
import torch.nn.functional as F
from PIL import Image

In [5]:
# Loader over the ImageNet-A folder, requested without a data split.
dataloader = get_data(
    batch_size=32,
    img_root="imagenet-a",
    split_data=False,
)



In [31]:
# Hand-picked sample; `[0]` keeps the first element of the dataset item.
# NOTE(review): index 5200 is presumably a "candle" image (going by the variable name) — confirm.
candle_img = dataloader.dataset[5200][0]

In [None]:
import os
import torch

# Rescale every stored prompt embedding to unit L2 norm, rewriting the file on
# disk only when its norm is not already (numerically) 1.
# Layout walked here: <path_gen>/<class>/<prompt dir>/prompt_clip_embedding.pt
# NOTE(review): confirm this layout against what ImageGenerator writes.
path_gen = "/home/sagemaker-user/Domain-Shift-Computer-Vision/imagenetA_generated"
for class_name in sorted(os.listdir(path_gen)):
    class_path = os.path.join(path_gen, class_name)
    if not os.path.isdir(class_path):
        continue  # stray file next to the class folders
    for prompt_dir_name in sorted(os.listdir(class_path)):
        prompt_embedding_path = os.path.join(class_path, prompt_dir_name, "prompt_clip_embedding.pt")
        if not os.path.isfile(prompt_embedding_path):
            continue  # no stored prompt embedding for this entry
        prompt_embedding_clip = torch.load(prompt_embedding_path)
        embedding_norm = prompt_embedding_clip.norm()
        # Compare against a tensor with the same dtype/device as the norm so the
        # check also works for half-precision or GPU-resident embeddings.
        if not torch.isclose(embedding_norm, torch.ones_like(embedding_norm)).item():
            torch.save(prompt_embedding_clip / embedding_norm, prompt_embedding_path)

In [None]:
# Load the CLIP checkpoint named below; `clip.load` returns the model together
# with the matching image preprocessing transform.
clip_image_encoder = "ViT-L/14"
clip_model, preprocess = clip.load(clip_image_encoder)

In [82]:
def retrieve_gen_images(img, data_path, num_images, clip_model, use_t2i_similarity = False, t2i_images = True, i2i_images = False):
    """Find the generated images whose stored CLIP embeddings best match ``img``.

    Directory layout walked by this function::

        <data_path>/<class>/<prompt group>/prompt_clip_embedding.pt
        <data_path>/<class>/<prompt group>/<entry>/t2i_gen_images/<id>/image_embedding.pt
        <data_path>/<class>/<prompt group>/<entry>/i2i_gen_images/<id>/image_embedding.pt

    NOTE(review): the prompt embedding is read from the ``<prompt group>``
    directory; confirm that this is where the generator stores it. It is only
    read when ``use_t2i_similarity`` is true.

    Args:
        img: Query image, passed to the module-level ``preprocess`` transform
            (the one returned by ``clip.load``).
        data_path: Root directory of the generated images.
        num_images: Maximum number of images to return.
        clip_model: Object exposing ``encode_image``. When it has parameters,
            their device is used for the computation; otherwise the CPU is used.
        use_t2i_similarity: If true, rank by the mean of the image-to-image
            similarity and the similarity between ``img`` and the prompt
            embedding; otherwise rank by image-to-image similarity only.
        t2i_images: Search the ``t2i_gen_images`` folders.
        i2i_images: Search the ``i2i_gen_images`` folders.

    Returns:
        ``(paths, similarities)``: a list of ``image.png`` paths, best match
        first, and a 1-D tensor with the matching scores. Fewer than
        ``num_images`` entries are returned when fewer candidates exist.

    Raises:
        AssertionError: if both ``t2i_images`` and ``i2i_images`` are false.
    """
    assert t2i_images or i2i_images, "One of t2i_images and i2i_images must be true"

    image_folders = [
        folder_name
        for folder_name, enabled in (("t2i_gen_images", t2i_images), ("i2i_gen_images", i2i_images))
        if enabled
    ]

    # Run on whatever device the model lives on instead of assuming CUDA.
    try:
        device = next(clip_model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cpu")

    def load_embedding(path):
        # Flatten so that (1, D) and (D,) embeddings compare the same way, and
        # cast to float32 so half-precision files can be mixed with the query.
        return torch.load(path, map_location=device).flatten().float()

    with torch.no_grad():
        image_embedding = clip_model.encode_image(preprocess(img).unsqueeze(0).to(device))
    # Cosine similarity is scale-invariant, so no explicit normalisation is needed.
    image_embedding = image_embedding.flatten().float()

    candidates = []  # (similarity, path to image.png)
    for class_name in sorted(os.listdir(data_path)):
        class_path = os.path.join(data_path, class_name)
        if not os.path.isdir(class_path):
            continue
        for gen_images_class in sorted(os.listdir(class_path)):
            gen_images_class_path = os.path.join(class_path, gen_images_class)
            if not os.path.isdir(gen_images_class_path):
                continue

            t2i_similarity = None
            if use_t2i_similarity:
                gen_prompt_embedding = load_embedding(
                    os.path.join(gen_images_class_path, "prompt_clip_embedding.pt")
                )
                t2i_similarity = F.cosine_similarity(image_embedding, gen_prompt_embedding, dim=0).item()

            for gen_images in sorted(os.listdir(gen_images_class_path)):
                gen_image_path = os.path.join(gen_images_class_path, gen_images)
                for folder_name in image_folders:
                    folder_path = os.path.join(gen_image_path, folder_name)
                    if not os.path.isdir(folder_path):
                        continue  # this entry has no images of the requested kind
                    for image_id in sorted(os.listdir(folder_path)):
                        image_dir = os.path.join(folder_path, image_id)
                        embedding_path = os.path.join(image_dir, "image_embedding.pt")
                        if not os.path.isfile(embedding_path):
                            continue
                        gen_image_embedding = load_embedding(embedding_path)
                        similarity = F.cosine_similarity(image_embedding, gen_image_embedding, dim=0).item()
                        if use_t2i_similarity:
                            similarity = (similarity + t2i_similarity) / 2  # avg similarity
                        candidates.append((similarity, os.path.join(image_dir, "image.png")))

    # Stable sort: ties keep the order in which they were discovered.
    best = sorted(candidates, key=lambda candidate: candidate[0], reverse=True)[:max(num_images, 0)]
    retrieved_images = [path for _, path in best]
    retrieved_images_similarity = torch.tensor([score for score, _ in best])
    return retrieved_images, retrieved_images_similarity