In [None]:
import torch
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torchvision.models as models
from torchvision.models.resnet import ResNet50_Weights
import lightning.pytorch as pl

from models import Model_Wrapper, Preprocess
from transforms import Luminance
from utils import View, sample_imgs

In [None]:
# --- Notebook configuration -------------------------------------------------
# Intended number of images to evaluate.
# NOTE(review): this constant is not referenced by any of the visible cells;
# the evaluation size is currently governed by the Trainer's
# `limit_test_batches` setting instead.
NUM_IMG_EVAL = 10_000

# Root directory of the ImageNet (2012) dataset, relative to this notebook.
PATH_TO_IMAGENET = '../../datasets/imagenet/2012/'

In [None]:
# Seed Python, NumPy and PyTorch RNGs (and DataLoader workers) up front.
# The loaders in this notebook are created with shuffle=True and each test run
# only consumes a limited number of batches, so without a fixed seed every
# "Restart & Run All" would evaluate a different random subset.
SEED = 42
pl.seed_everything(SEED, workers=True)

# ImageNet-pretrained ResNet-50 (torchvision's default weights), wrapped in the
# project's Model_Wrapper so it can be handed to `trainer.test(...)`.
resnet50 = models.resnet50(weights=ResNet50_Weights.DEFAULT)
resnet50 = Model_Wrapper(resnet50)

# `limit_test_batches=100` caps every `trainer.test` call at 100 batches.
# NOTE(review): NUM_IMG_EVAL from the config cell is not used here; the number
# of evaluated images is 100 * (loader batch size) — confirm this is intended.
trainer = pl.Trainer(accelerator="auto", limit_test_batches=100)

In [None]:
# Preprocessing helper bound to the ImageNet root. The later cells use it to
# queue transforms (luminance / hist_eq), clear them (reset_trans) and build
# loaders (get_loader).
# (224, 224) is presumably the target image size fed to ResNet-50 — TODO confirm
# against Preprocess.
prep = Preprocess(
    PATH_TO_IMAGENET,
    (224, 224),
    shuffle=True,
)

##### Test on Original Dataset

In [None]:
# Baseline cell: evaluate the model on images with no extra transforms queued.
# reset_trans() presumably clears whatever transforms were previously added to
# `prep`, so the loader below serves the unmodified dataset — TODO confirm.
prep.reset_trans()
imgnet_orig = prep.get_loader()
# Keep a small sample (slice 0..3) of the original images; the later cells pass
# it to View.compare3_color as the reference column.
sample_orig = sample_imgs(imgnet_orig, slice(0,3))

# Run the wrapped ResNet-50 through the Lightning test loop on the original
# images (the trainer caps the run via limit_test_batches).
result_orig = trainer.test(resnet50, imgnet_orig)

##### Test Histogram Eq on Dark Images

In [None]:
# Dark-image experiment: build one loader with reduced luminance and a second
# one with histogram equalization applied afterwards, then compare accuracy.
prep.reset_trans()
# Loader with the luminance transform at factor 1/8 (darkened images).
imgnet_dark = prep.luminance(1/8).get_loader()
# hist_eq() is called on the same `prep` without a reset in between, so it is
# presumably stacked on top of the luminance step (darken, then equalize) —
# TODO confirm against Preprocess.
imgnet_dark_histeq = prep.hist_eq().get_loader()

# Small samples (slice 0..3) from each loader for visual inspection.
sample_dark = sample_imgs(imgnet_dark, slice(0,3))
sample_dark_histeq = sample_imgs(imgnet_dark_histeq, slice(0,3))
# NOTE(review): `prep` was created with shuffle=True, so the three samples may
# not show the same underlying images — confirm how sample_imgs/loaders order.
View.compare3_color(sample_dark, sample_dark_histeq, sample_orig)

# Evaluate the model on the darkened images and on their histogram-equalized
# counterparts.
result_dark = trainer.test(resnet50, imgnet_dark)
result_dark_histeq = trainer.test(resnet50, imgnet_dark_histeq)

##### Test Histogram Eq on Bright Images

In [None]:

# Bright-image experiment: build one loader with increased luminance and a
# second one with histogram equalization applied afterwards, then compare
# accuracy.
#
# Reset BEFORE queuing the bright transforms, exactly as the other test cells
# do. The dark-image cell leaves its luminance(1/8) and hist_eq steps on `prep`,
# so building the bright loaders without resetting first would presumably stack
# the bright transforms on top of the dark ones.
prep.reset_trans()
# Loader with the luminance transform at factor 2 (brightened images).
imgnet_bright = prep.luminance(2).get_loader()
# hist_eq() is queued on the same `prep`, i.e. on top of the luminance step.
imgnet_bright_histeq = prep.hist_eq().get_loader()
# Leave `prep` clean afterwards as well, matching the state the original cell
# left behind for any cells that follow.
prep.reset_trans()

# Small samples (slice 0..3) from each loader for visual inspection next to the
# original-image sample.
sample_bright = sample_imgs(imgnet_bright, slice(0,3))
sample_bright_histeq = sample_imgs(imgnet_bright_histeq, slice(0,3))
View.compare3_color(sample_bright, sample_bright_histeq, sample_orig)

# Evaluate the model on the brightened images and on their
# histogram-equalized counterparts.
result_bright = trainer.test(resnet50, imgnet_bright)
result_bright_histeq = trainer.test(resnet50, imgnet_bright_histeq)