In [None]:
import torch
import torchvision
import torchvision.models as models
from torchvision.models.resnet import ResNet50_Weights
import lightning.pytorch as pl

from mymodels import Model_Wrapper, Preprocess
from myutils import View, sample_imgs_list

In [None]:
# Trade a little float32 precision for speed on GPUs with TensorFloat-32 support.
# This notebook only runs inference, so these settings speed up evaluation, not training.
#
# Note: assigning `torch.backends.cuda.matmul.allow_tf32 = True` resets the float32
# matmul precision to 'high', so the previous `set_float32_matmul_precision('medium')`
# call was silently overridden. Set the effective value once, explicitly.
torch.set_float32_matmul_precision('high')
# Allow TF32 inside cuDNN convolutions as well.
torch.backends.cudnn.allow_tf32 = True

In [None]:
# Root of the ImageNet-2012 data; relative, so it resolves against the notebook's working directory.
PATH_TO_IMAGENET = '../../datasets/imagenet/2012/'
# NOTE(review): no cell shown here reads this constant. The evaluation size is actually
# capped by `limit_test_batches` on the Trainer, so confirm whether this was meant to drive it.
NUM_IMG_EVAL = 10_000

In [None]:
# Pretrained torchvision ResNet-50 wrapped by the project's Model_Wrapper
# (presumably a LightningModule, since it is handed to `trainer.test` below).
resnet50 = Model_Wrapper(models.resnet50(weights=ResNet50_Weights.DEFAULT))

# Every `trainer.test` call in this notebook stops after 100 test batches, so
# reported metrics cover only that prefix of each loader.
trainer = pl.Trainer(
    accelerator="auto",
    limit_test_batches=100,
)

In [None]:
# Shared preprocessing pipeline. (224, 224) is presumably the target image size;
# shuffle=False keeps a fixed image order across all the variants built from it.
prep = Preprocess(PATH_TO_IMAGENET, (224, 224), shuffle=False)
# Preview images at indices 0, 10k, 20k, 30k and 40k.
preview_img_slice = slice(0, 50_000, 10_000)

In [None]:
def compare_ds(model, datasets, labels, view_img_slice, view_only=False, figsize=(12, 6),
               verbose=True, skip_results=None, pl_trainer=None):
    """Preview sample images from several datasets side by side, then evaluate the model on them.

    Args:
        model: Model passed to ``Trainer.test``.
        datasets: Loaders to compare; ``datasets[i]`` is shown and evaluated under ``labels[i]``.
        labels: One display label per dataset; also the keys of the returned dict.
        view_img_slice: Slice selecting which images ``sample_imgs_list`` previews.
        view_only: If True, only draw the preview and return None.
        figsize: Figure size forwarded to ``View.compare_color``.
        verbose: Forwarded to ``Trainer.test``.
        skip_results: Indices of ``datasets`` to preview but not evaluate. Defaults to
            ``(0,)``: callers pass the untouched baseline first, and that baseline is
            evaluated separately.
        pl_trainer: Trainer used for evaluation; defaults to the notebook-level ``trainer``.

    Returns:
        ``{label: Trainer.test(...) output}`` for every non-skipped dataset, or None
        when ``view_only`` is True.

    Raises:
        ValueError: If ``datasets`` and ``labels`` have different lengths.
    """
    if len(datasets) != len(labels):
        raise ValueError(
            f"got {len(datasets)} datasets but {len(labels)} labels; they must match one-to-one"
        )
    # Immutable default instead of a shared mutable list literal.
    if skip_results is None:
        skip_results = (0,)
    # Resolve the global lazily so an explicit trainer works even if `trainer` is undefined.
    runner = trainer if pl_trainer is None else pl_trainer

    # Always display the preview first, so the comparison is visible before the slow evaluation.
    samples = [sample_imgs_list(ds, view_img_slice) for ds in datasets]
    View.compare_color(samples, labels, figsize=figsize)

    if view_only:
        return None

    results = {}
    for i, (ds, label) in enumerate(zip(datasets, labels)):
        if i in skip_results:
            continue
        results[label] = runner.test(model, ds, verbose=verbose)
    return results
    

##### Test on Original Dataset

In [None]:
prep.reset_trans()
imgnet_orig = prep.get_loader()
# Preview samples of the untouched images (also referenced by the disabled sharpening cell).
sample_orig = sample_imgs_list(imgnet_orig, preview_img_slice)
# Baseline: test the model on the original images.
result_orig = trainer.test(resnet50, imgnet_orig)

##### Test on Dark Images

In [None]:
# Darken the images (luminance(1/8), presumably a 1/8 luminance scale), then check
# whether histogram equalisation or Retinex ('SSR', 100) recovers accuracy.
prep.reset_trans()
imgnet_dark = prep.luminance(1/8).get_loader()
# Each enhancement branches off a copy of the darkened pipeline.
imgnet_dark_histeq = prep.copy().hist_eq().get_loader()
imgnet_dark_retinex = prep.copy().retinex('SSR', 100).get_loader()
imgnet_dark_results = compare_ds(
    resnet50,
    datasets=[prep.basic_loader(), imgnet_dark, imgnet_dark_histeq, imgnet_dark_retinex],
    labels=['Original', 'Dark', 'Dark + HistEQ', 'Dark + Retinex'],
    view_img_slice=preview_img_slice,
)

##### Test on Bright Images

In [None]:
# Brighten the images (luminance(2), presumably a 2x luminance scale), then check
# whether histogram equalisation or Retinex ('SSR', 100) recovers accuracy.
prep.reset_trans()
imgnet_bright = prep.luminance(2).get_loader()
# Each enhancement branches off a copy of the brightened pipeline.
imgnet_bright_histeq = prep.copy().hist_eq().get_loader()
imgnet_bright_retinex = prep.copy().retinex('SSR', 100).get_loader()
imgnet_bright_results = compare_ds(
    resnet50,
    datasets=[prep.basic_loader(), imgnet_bright, imgnet_bright_histeq, imgnet_bright_retinex],
    labels=['Original', 'Bright', 'Bright + HistEQ', 'Bright + Retinex'],
    view_img_slice=preview_img_slice,
)

##### Test on Low-Contrast Images (HistEQ and Retinex)

In [None]:
# Reduce contrast via brightness_contrast(0, 0.1) (presumably brightness offset 0,
# contrast factor 0.1), then check whether HistEQ or Retinex recovers accuracy.
prep.reset_trans()
imgnet_low_cont = prep.brightness_contrast(0, 0.1).get_loader()
# Each enhancement branches off a copy of the low-contrast pipeline.
imgnet_low_cont_histeq = prep.copy().hist_eq().get_loader()
imgnet_low_cont_retinex = prep.copy().retinex('SSR', 100).get_loader()
imgnet_low_cont_results = compare_ds(
    resnet50,
    datasets=[prep.basic_loader(), imgnet_low_cont, imgnet_low_cont_histeq, imgnet_low_cont_retinex],
    labels=['Original', 'Low Contrast', 'Low Contrast + HistEQ', 'Low Contrast + Retinex'],
    view_img_slice=preview_img_slice,
)

##### Test on Images Remapped to a Low-Contrast Distribution

In [None]:
# Remap each image's intensity distribution to a narrow one, producing low-contrast
# inputs. dist_remap(20/255, 10/255): presumably (mean, std) on a [0, 1] scale.
prep.reset_trans()
imgnet_mod = prep.dist_remap(20/255, 10/255).get_loader()
# TODO: these enhanced variants are built but not evaluated (their comparison was
# disabled). Either add them and their labels to the call below, or drop them.
imgnet_mod_histeq = prep.copy().hist_eq().get_loader()
imgnet_mod_retinex = prep.copy().retinex('SSR', 100).get_loader()

imgnet_mod_results = compare_ds(
    resnet50,
    datasets=[prep.basic_loader(), imgnet_mod],
    labels=['Original', 'Low Contrast'],
    view_img_slice=preview_img_slice,
)


##### Test Sharpening on Blurry Images (currently disabled)

In [None]:
# NOTE(review): this experiment is disabled and stale. `sample_imgs` is not imported
# (the notebook imports `sample_imgs_list`), and `View.compare3_color` is not the
# `View.compare_color` used by `compare_ds` (confirm that it still exists). Port this
# cell to `compare_ds` before re-enabling it. Also, `prep.sharpen(5)` is called on
# `prep` itself, not on `prep.copy()`, so `imgnet_mod_filt` is presumably
# blur + sharpen; confirm this is intended.
# prep.reset_trans()
# imgnet_mod = prep.blur(5).get_loader()
# imgnet_mod_filt = prep.sharpen(5).get_loader()

# sample_mod = sample_imgs(imgnet_mod, preview_img_slice)
# sample_mod_filter = sample_imgs(imgnet_mod_filt, preview_img_slice)

# View.compare3_color(sample_mod, sample_mod_filter, sample_orig)

# result_mod = trainer.test(resnet50, imgnet_mod)
# result_mod_filt = trainer.test(resnet50, imgnet_mod_filt)