# Imports

In [None]:
from os.path import join, isdir
from os import makedirs
import einops
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

import model
from coco_fake_dataset import COCOFakeDataset

# Parameters

In [None]:
# Both dataset roots live under the same shared "datasets" folder two levels up.
_datasets_root = join("..", "..", "datasets")
coco2014_path = join(_datasets_root, "coco2014")
coco_fake_path = join(_datasets_root, "fake_coco")
# Output folder for the rendered figures.
images_path = join(".", "images")
# Checkpoint to visualise; note it is read from the same "images" folder used for output.
pretrained_model_path = join("images", "coco_fake_S_epoch=4-train_acc=0.93-val_acc=0.93.ckpt")

# Dataset loading

In [None]:
# Validation split of the real/fake COCO data. What `mode="single"` and
# `resolution=224` mean is defined by COCOFakeDataset (not visible here).
dataset = COCOFakeDataset(
    split="val",
    mode="single",
    resolution=224,
    coco2014_path=coco2014_path,
    coco_fake_path=coco_fake_path,
)

In [None]:
# Deterministic loader: samples come out in dataset order, one per batch,
# read through a single worker process.
dataloader = DataLoader(
    dataset,
    shuffle=False,
    batch_size=1,
    num_workers=1,
)

# Model creation

In [None]:
# Restore the trained model on the CPU.
net = model.BNext4DFR.load_from_checkpoint(pretrained_model_path, map_location="cpu")
# `load_from_checkpoint` does not switch the module to evaluation mode (Lightning's
# docs call `model.eval()` explicitly after loading). Do it here so any dropout /
# batch-norm layers behave deterministically while the images are produced.
net.eval()

# Plotting

In [None]:
# Creates the images dir. exist_ok makes this a no-op when it already exists,
# which avoids the check-then-create race of a separate isdir() test.
# It still raises FileExistsError if the path exists but is not a directory.
makedirs(images_path, exist_ok=True)

In [None]:
import torch
import timm

# ImageNet normalisation constants, reshaped to (1, C, 1, 1) so they broadcast
# over the batched (B, C, H, W) adapter output. Built once, not on every iteration.
imagenet_mean = torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN).view(1, -1, 1, 1)
imagenet_std = torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_STD).view(1, -1, 1, 1)

# Only the first batch is visualised: the loop breaks after one iteration.
for batch in dataloader:
    # Pure inference: no autograd graph is needed to produce the pictures.
    with torch.no_grad():
        # Adds the new channels to the image. Per the slicing below, channels [:3]
        # are RGB and the last two are the FFT and LBP maps.
        image_augmented = net.add_new_channels(batch["image"])
        image_adapted = net.adapter(image_augmented).detach().cpu()
        image_adapted = (image_adapted - imagenet_mean) / imagenet_std
        # Backbone output for the adapted image (not used by the plots below).
        features = net.base_model(image_adapted)[0]
    # Convert to the channel-last layout matplotlib expects for colour images.
    image_adapted = einops.rearrange(image_adapted[0], "c h w -> h w c")
    rgb_image = einops.rearrange(image_augmented[0, :3], "c h w -> h w c")
    fft_image = image_augmented[0, -2]
    lbp_image = image_augmented[0, -1]
    # NOTE(review): "adapted" is plotted after ImageNet normalisation, so float
    # values outside [0, 1] will be clipped by imshow. Confirm this is intended.
    for image, title in [
        (rgb_image, "rgb"),
        (fft_image, "fft"),
        (lbp_image, "lbp"),
        (image_adapted, "adapted"),
    ]:
        # Use a fresh figure per image. Reusing the implicit current figure would
        # stack every imshow on the same axes and leave that figure open.
        fig, ax = plt.subplots()
        ax.imshow(image)
        ax.axis("off")
        fig.savefig(join(images_path, f"{title}.png"), bbox_inches="tight", pad_inches=0, dpi=300)
        plt.close(fig)
    break