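# Test-Time Adaptation

Test-time adaptation experiments for ResNet-50 on ImageNet-A (MEMO-style adaptation with AugMix augmentations), plus tooling to generate, retrieve, and scrape additional per-class images.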

In [1]:
import torch
import torchvision.transforms as T
import torchvision.models as models
import torchvision
from test_methods.test import Tester
from test_time_adaptation.resnet50_dropout import ResNet50Dropout

In [2]:
# Dataset roots, relative to the notebook's working directory.
# NOTE(review): despite the "_b" suffix, the second path points at ImageNetV2
# (matched-frequency split); it is not used in the visible cells.
imagenet_a_path, imagenet_b_path = (
    "imagenet-a",
    "imagenetv2-matched-frequency-format-val/",
)

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
import torchvision.transforms as T

# Pool of single-transform augmentations. In the visible cells it is defined
# but not passed to Tester; the run below uses `augmix_augmentations`.
augmentations = [
    T.RandomHorizontalFlip(p=1),
    T.RandomVerticalFlip(p=1),
    T.RandomRotation(degrees=30),
    T.RandomRotation(degrees=60),
    T.ColorJitter(brightness=0.2),
    T.ColorJitter(contrast=0.2),
    T.ColorJitter(saturation=0.2),
    T.ColorJitter(hue=0.2),
    T.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    T.RandomRotation(degrees=15),
    T.RandomAdjustSharpness(sharpness_factor=2, p=1),
    T.RandomGrayscale(p=1),
    T.RandomInvert(p=1),
    T.RandomAutocontrast(p=1),
    T.GaussianBlur(kernel_size=5),
]

# AugMix pool: the levels 3, 2, 4 (each used for both severity and
# mixture_width, with chain_depth=3 and alpha=1.0) cycled six times.
# This gives 18 independently constructed AugMix transforms in the order
# 3, 2, 4, 3, 2, 4, ...
augmix_augmentations = [
    T.AugMix(severity=level, mixture_width=level, chain_depth=3, alpha=1.0)
    for _ in range(6)
    for level in (3, 2, 4)
]

## ResNet-50

In [5]:
# Experiment directory handed to Tester as `exp_path` for the ResNet-50 /
# ImageNet-A / SGD run (presumably where results are written — verify).
# NOTE(review): absolute SageMaker home path; not portable to other machines.
exp_path_a = "/home/sagemaker-user/Domain-Shift-Computer-Vision/experiments/Resnet50_ImagenetA_SGD"

In [6]:
# Monte-Carlo dropout configuration. "use_dropout" selects between
# ResNet50Dropout and plain ResNet-50 when the Tester is built below.
# NOTE(review): confirm where dropout_rate / num_samples are consumed.
MC = dict(
    dropout_rate=0.5,
    num_samples=10,
    use_dropout=False,
)

In [7]:
# Build the Tester for ResNet-50, using SGD as the optimizer class.
# NOTE(review): the two branches give Tester different kinds of objects:
# an already-constructed ResNet50Dropout instance vs. the uninstantiated
# torchvision `models.resnet50` builder. Confirm Tester accepts both.
tester_resnet50 = Tester(
    model=(ResNet50Dropout() if MC["use_dropout"] else models.resnet50),
    optimizer=torch.optim.SGD,
    exp_path=exp_path_a,
    device=device,
)

In [8]:
# Learning rates for test-time adaptation, following the MEMO paper.
# Alternative per-parameter-group format, kept for reference (not used):
#   [{"classifier": [["fc.weight", "fc.bias"], 0.00025]}, 0]
lr_setting_sgd = [2.5e-4]   # MEMO setting for SGD
lr_setting_adam = [1e-4]    # MEMO setting for Adam

In [9]:
# Pretrained ResNet-50 weight sets. V1 is the checkpoint the MEMO paper used;
# V2 is the alternative torchvision ImageNet-1k weight set.
imagenetV1_weights, imagenetV2_weights = (
    models.ResNet50_Weights.IMAGENET1K_V1,
    models.ResNet50_Weights.IMAGENET1K_V2,
)

In [10]:
# Settings for augmenting test images with generated images, passed to
# Tester.test. Field meanings below are inferred from their names.
# NOTE(review): confirm against the Tester implementation.
#   clip_img_encoder   CLIP backbone used to embed images
#   num_img            number of generated images to use
#   gen_data_path      root directory of the generated ImageNet-A images
#   use_t2i_similarity / t2i_img / i2i_img / threshold   retrieval options
gen_aug_settings = dict(
    clip_img_encoder="ViT-L/14",
    num_img=30,
    gen_data_path="/home/sagemaker-user/Domain-Shift-Computer-Vision/imagenetA_generated",
    use_t2i_similarity=True,
    t2i_img=True,
    i2i_img=False,
    threshold=0.7,
)

In [None]:
# Run MEMO-style test-time adaptation of ResNet-50 on ImageNet-A, using the
# AugMix pool plus generated-image augmentation (gen_aug_settings).
# The MC-dropout config is forwarded only when dropout is enabled, so a
# ResNet50Dropout model is presumably paired with its sampling settings.
# With use_dropout=False this still passes None, as before.
# NOTE(review): num_augmentations=16 while augmix_augmentations holds 18
# transforms; confirm the intended count.
tester_resnet50.test(
    augmentations=augmix_augmentations,
    num_augmentations=16,
    seed_augmentations=42,
    batch_size=64,
    img_root=imagenet_a_path,
    num_adaptation_steps=2,
    MEMO=True,
    lr_setting=lr_setting_sgd,
    # Keep at 0 when gen_aug_settings is used; otherwise the generated images
    # might never be selected.
    top_augmentations=0,
    weights_imagenet=imagenetV1_weights,
    prior_strength=16,
    TTA=True,
    MC=MC if MC["use_dropout"] else None,
    gen_aug_settings=gen_aug_settings,
)

## Image Generation

In [None]:
!pip install ollama

In [None]:
!pip install diffusers

In [None]:
! pip install git+https://github.com/openai/CLIP.git

In [2]:
import torch

In [1]:
from test_time_adaptation.image_generation.image_generator import ImageGenerator

In [2]:
# Project image generator; its generate_prompts / generate_images methods are used in the cells below.
imagenetA_generator = ImageGenerator()

In [5]:
# generate prompts
skipped_classes = imagenetA_generator.generate_prompts(
    num_prompts_per_class=20,
    style_of_picture="photograph",
    path="/home/sagemaker-user/Domain-Shift-Computer-Vision/imagenetA_generated",
    context_llm = "/home/sagemaker-user/Domain-Shift-Computer-Vision/test_time_adaptation/image_generation/llm_context.json",
    llm_model = "llama3.1", 
    clip_text_encoder = "ViT-L/14"
)

Processing classes:   4%|▎         | 7/200 [00:30<14:52,  4.62s/it]

Skipping class vulture.


Processing classes:   6%|▌         | 11/200 [01:21<27:50,  8.84s/it]

Skipping class green iguana.


Processing classes:   8%|▊         | 15/200 [01:39<21:04,  6.84s/it]

Skipping class garter snake.


Processing classes:  10%|█         | 21/200 [02:03<16:12,  5.43s/it]

Skipping class lorikeet.


Processing classes:  14%|█▍        | 29/200 [03:16<24:29,  8.60s/it]

Skipping class flatworm.


Processing classes:  17%|█▋        | 34/200 [04:04<24:38,  8.91s/it]

Skipping class great egret.


Processing classes:  18%|█▊        | 36/200 [04:31<29:51, 10.92s/it]

Skipping class pelican.


Processing classes:  20%|██        | 40/200 [04:51<20:27,  7.67s/it]

Skipping class Rottweiler.


Processing classes:  22%|██▏       | 43/200 [05:35<28:38, 10.94s/it]

Skipping class red fox.


Processing classes:  24%|██▎       | 47/200 [05:57<20:21,  7.98s/it]

Skipping class American black bear.


Processing classes:  26%|██▋       | 53/200 [06:28<14:53,  6.08s/it]

Skipping class bee.


Processing classes:  27%|██▋       | 54/200 [06:49<19:49,  8.15s/it]

Skipping class ant.


Processing classes:  28%|██▊       | 55/200 [07:09<24:22, 10.08s/it]

Skipping class grasshopper.


Processing classes:  29%|██▉       | 58/200 [07:30<20:59,  8.87s/it]

Skipping class mantis.


Processing classes:  30%|███       | 61/200 [07:58<20:58,  9.06s/it]

Skipping class monarch butterfly.


Processing classes:  32%|███▏      | 64/200 [08:15<17:33,  7.75s/it]

Skipping class starfish.


Processing classes:  34%|███▎      | 67/200 [08:44<19:04,  8.60s/it]

Skipping class fox squirrel.


Processing classes:  36%|███▌      | 71/200 [09:06<15:26,  7.18s/it]

Skipping class armadillo.


Processing classes:  36%|███▋      | 73/200 [09:33<18:32,  8.76s/it]

Skipping class white-headed capuchin.


Processing classes:  38%|███▊      | 77/200 [10:02<17:53,  8.73s/it]

Skipping class accordion.


Processing classes:  39%|███▉      | 78/200 [10:21<21:14, 10.44s/it]

Skipping class acoustic guitar.


Processing classes:  40%|████      | 81/200 [10:41<17:41,  8.92s/it]

Skipping class apron.


Processing classes:  42%|████▏     | 83/200 [11:06<20:20, 10.43s/it]

Skipping class balloon.


Processing classes:  42%|████▏     | 84/200 [11:25<23:39, 12.24s/it]

Skipping class banjo.


Processing classes:  44%|████▍     | 88/200 [11:42<14:29,  7.76s/it]

Skipping class lighthouse.


Processing classes:  48%|████▊     | 95/200 [12:33<14:34,  8.33s/it]

Skipping class candle.


Processing classes:  48%|████▊     | 97/200 [12:38<13:25,  7.82s/it]


KeyboardInterrupt: 

In [None]:
# generate images
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, DPMSolverMultistepScheduler

# Text-to-image Stable Diffusion v1.5 pipeline with half-precision weights.
# The default scheduler is replaced by DPMSolverMultistepScheduler, built from
# the checkpoint's own scheduler config. float16 weights assume a CUDA device.
model_id = "runwayml/stable-diffusion-v1-5"
pipet2i = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
)
pipet2i.scheduler = DPMSolverMultistepScheduler.from_config(
    pipet2i.scheduler.config
)
pipet2i = pipet2i.to("cuda")

In [50]:
# Root directory passed to generate_images below; same value as gen_aug_settings["gen_data_path"].
imagenet_a_generated_path = "/home/sagemaker-user/Domain-Shift-Computer-Vision/imagenetA_generated"

In [None]:
# Render generated images with the text-to-image pipeline: num_images=1
# (presumably per prompt), 25 inference steps, guidance_scale=9.
# NOTE(review): `strength` is normally an img2img parameter; confirm how
# generate_images uses it with a text-to-image pipeline.
imagenetA_generator.generate_images(
    path=imagenet_a_generated_path,
    image_generation_pipeline=pipet2i,
    num_images=1,
    num_inference_steps=25,
    guidance_scale=9,
    strength=1,
)

## Retrieving Generated Images

In [13]:
from utility.data.get_data import get_data
import clip
import torch
import torch.nn.functional as F
from PIL import Image

In [14]:
# Load ImageNet-A without a train/test split, in batches of 32.
dataloader = get_data(
    img_root="imagenet-a",
    batch_size=32,
    split_data=False,
)



In [323]:
# NOTE(review): index 20 is assumed (from the variable name) to be a candle image; confirm against the dataset ordering.
# `[0]` presumably selects the image from an (image, label) pair.
candle_img = dataloader.dataset[20][0]

In [46]:
# Load the CLIP ViT-L/14 model and its matching image preprocessing transform.
clip_image_encoder = "ViT-L/14"
clip_model, preprocess = clip.load(name=clip_image_encoder)

In [51]:
# Same value as the earlier definition in the Image Generation section; presumably repeated so this section can run on its own.
imagenet_a_generated_path = "/home/sagemaker-user/Domain-Shift-Computer-Vision/imagenetA_generated"

In [381]:
# Retrieve up to 30 generated images for `candle_img` from the generated-data
# directory using the CLIP model loaded above: text-to-image outputs only,
# use_t2i_similarity disabled, similarity threshold 0.7.
# FIXME: `retrieve_gen_images` is not imported or defined anywhere in the
# visible notebook, so this cell raises NameError on a fresh kernel. Import it
# from the module that provides it before running.
retrieved_images = retrieve_gen_images(
    img=candle_img,
    data_path=imagenet_a_generated_path,
    clip_model=clip_model,
    preprocess=preprocess,
    num_images=30,
    t2i_images=True,
    use_t2i_similarity=False,
    threshold=0.7,
)

## Scraping Images

In [3]:
!pip install bing_image_downloader

Collecting bing_image_downloader
  Downloading bing_image_downloader-1.1.2-py3-none-any.whl.metadata (2.8 kB)
Downloading bing_image_downloader-1.1.2-py3-none-any.whl (5.9 kB)
Installing collected packages: bing_image_downloader
Successfully installed bing_image_downloader-1.1.2


In [1]:
from test_time_adaptation.image_generation.web_scrape import scrape_images_imagenetA

In [None]:
# Scrape up to 5 web images (limit=5, presumably per ImageNet-A class) using
# "a photo of" as the query style, into the generated-data directory.
# `imgenetA_gen_path` is the helper's own parameter name and must be spelled as-is.
scrape_images_imagenetA(
    imgenetA_gen_path="/home/sagemaker-user/Domain-Shift-Computer-Vision/imagenetA_generated",
    img_style="a photo of",
    limit=5,
)