In [None]:
import os
import sys
sys.path.insert(0, os.getcwd())

import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision.transforms import ToTensor
from torchvision.transforms.functional import resize
from torchvision.io import read_image, write_png
from tqdm.auto import tqdm

In [None]:
# Reusable PIL-image -> tensor converter; the evaluation loop multiplies its
# output by 255 before casting the frame to uint8.
to_tensor = ToTensor()

In [None]:
# Root directory of every evaluation dataset, keyed by dataset name.
# All datasets live under one NAS mount. Set the GRAPH_SIM_NAS_ROOT environment
# variable to relocate them; when it is unset the original /media/nas2 paths
# are produced unchanged.
nas_root = os.environ.get("GRAPH_SIM_NAS_ROOT", "/media/nas2").rstrip("/")

ds_vs_ds_path = {
    "video_adv_splc": f"{nas_root}/graph_sim_data/video_advanced_splicing/test",
    "video_vis_aug": f"{nas_root}/graph_sim_data/video_visible_aug/test",
    "video_invis_aug": f"{nas_root}/graph_sim_data/video_invisible_aug/test",
    "video_sham_adobe": f"{nas_root}/Datasets/VideoSham-adobe-research/extracted_frames_ge_1920x1080",
    "video_e2fgvi_davis": f"{nas_root}/Tai/13-e2fgvi-video-inpainting/ds_1920x1080",
    "videomatting": f"{nas_root}/Datasets/VideoMatting/data/dataset",
    "deepfake": f"{nas_root}/deepfakes/cvpr/dataset",
    "deepfake_not_working": f"{nas_root}/deepfakes/cvpr/not_working_examples",
}

In [None]:
# Hand-picked frames to export for each dataset, keyed by dataset name.
# Every entry is a path stem relative to that dataset's root directory; the
# frame (".png") and ground-truth (".mask") files are found by appending the
# suffix to the stem. Entries are grouped per line by source clip / attack.
ds_vs_samples = {
    "video_adv_splc": [
        "manip_05798", "manip_04226", "manip_06094", "manip_06575",
        "manip_06547", "manip_04821", "manip_06826", "manip_04981",
        "manip_06815", "manip_08303",
    ],
    "video_vis_aug": [
        "manip_05798", "manip_06475", "manip_07971", "manip_06769",
        "manip_04630", "manip_08394", "manip_07378", "manip_04333",
        "manip_06848", "manip_04831", "manip_07085",
    ],
    "video_invis_aug": [
        "manip_07143", "manip_08009", "manip_06398", "manip_05589",
        "manip_04427", "manip_05028", "manip_07956", "manip_08294",
        "manip_05499", "manip_08120", "manip_06038",
    ],
    "video_sham_adobe": [
        "attack4/manip_4176_0219", "attack4/manip_1044_0088",
        "attack4/manip_4002_0082", "attack4/manip_4002_0060",
        "attack1/manip_0098_0138", "attack1/manip_0108_0046",
        "attack1/manip_0102_0044", "attack1/manip_0102_0046",
        "attack2/manip_0090_0023", "attack2/manip_0087_0140", "attack2/manip_0084_0280",
        "attack3/manip_1050_0151", "attack3/manip_1050_0244",
        "attack1/manip_4143_0105", "attack1/manip_4143_0213",
    ],
    "video_e2fgvi_davis": [
        "manip_schoolgirls_013",
        "manip_paragliding_046",
        "manip_horsejump-low_048",
        "manip_motorbike_025",
        "manip_breakdance_066",
        "manip_scooter-gray_068", "manip_scooter-gray_039", "manip_scooter-gray_022",
        "manip_tractor-sand_029", "manip_tractor-sand_035", "manip_tractor-sand_012",
        "manip_hockey_033", "manip_hockey_032", "manip_hockey_067",
        "manip_boat_024", "manip_boat_062", "manip_boat_061",
        "manip_bmx-trees_031", "manip_bmx-trees_032", "manip_bmx-trees_016",
        "manip_bmx-bumps_011",
    ],
    "videomatting": [
        "artem_manip_0138", "artem_manip_0149",
        "rain_manip_0043",
        "snow_manip_0129", "snow_manip_0038",
        "slava_manip_0072",
        "vitaliy_manip_0055",
        "concert_manip_0038", "concert_manip_0152",
    ],
    "deepfake": [
        "manip_Zella-Rena_0000", "manip_Zella-Rena_0017",
        "manip_Ruelle-Leah_0019",
        "manip_Ella-Katie_0023",
        "manip_Nicki-Latto_0008", "manip_Nicki-Latto_0024",
        "manip_Marina-Madison_0022",
        "manip_Ed-24KGold_0002",
        "manip_Doja-Kylie_0004",
        "manip_Selena-Tay_0000",
    ],
    "deepfake_not_working": ["manip_Kanye-Kevin_0016"],
}

In [None]:
# Exported images are written below <root_dir>/loc_comparisons, one subfolder
# per dataset.
root_dir = "."
loc_result_dir = "/".join((root_dir, "loc_comparisons"))

In [None]:
# Dataset names available for evaluation. The trailing comment on each entry
# is its list index, which is what ds_choice is selected by.
datasets = [
    "video_adv_splc",        # index 0
    "video_vis_aug",         # index 1
    "video_invis_aug",       # index 2
    "video_sham_adobe",      # index 3
    "video_e2fgvi_davis",    # index 4
    "videomatting",          # index 5
    "deepfake",              # index 6
    "deepfake_not_working",  # index 7
]

In [None]:
# Detector architectures that can be evaluated. The trailing comment on each
# entry is its list index, which is what arch_choice is selected by.
architectures = [
    "video_transformer",  # index 0
    "fsg",                # index 1
    "exif",               # index 2
    "noiseprint",         # index 3
    "mvss",               # index 4
]

In [None]:
# Select what to evaluate by list position:
# architectures[4] is "mvss" and datasets[3] is "video_sham_adobe".
arch_choice, ds_choice = architectures[4], datasets[3]
print(arch_choice, ds_choice)

In [None]:
# Make sure every dataset has an output folder under loc_result_dir.
# exist_ok=True keeps the cell idempotent on re-runs and avoids the
# check-then-create race of testing os.path.exists first.
for d in datasets:
    os.makedirs(f"{loc_result_dir}/{d}", exist_ok=True)

In [None]:
# Path stems (dataset root + sample name, no extension) of the frames picked
# for the selected dataset.
sample_root = ds_vs_ds_path[ds_choice]
eval_samples = [f"{sample_root}/{stem}" for stem in ds_vs_samples[ds_choice]]

In [None]:
# Fail fast if any selected sample is missing its frame (.png) or its
# ground-truth mask (.mask); the error message is the missing path.
for sample_stem in eval_samples:
    for suffix in (".png", ".mask"):
        required_file = f"{sample_stem}{suffix}"
        if not os.path.exists(required_file):
            raise FileNotFoundError(required_file)

In [None]:
# Project-local model factory; presumably resolved through the working
# directory that the first cell puts at the front of sys.path.
from evaluate_model import get_model

# PatchPredictions is only used by the evaluation loop when the selected
# architecture is "video_transformer", so it is imported for that case only.
if arch_choice == "video_transformer":
    from models.video_transformer.patch_predictions import PatchPredictions

In [None]:
# Build the detector for the selected architecture via the project's factory.
model = get_model(arch_choice)

In [None]:
# model = model.cuda()
# for i in tqdm(range(1000)):
#     model(torch.randn(1,3,1080,1920).cuda())

In [None]:
# Run the selected detector on every chosen frame and export its predicted
# localization mask as
#   {loc_result_dir}/{ds_choice}/{arch_choice}_{frame name}_pred_mask.png
# For the "video_transformer" run the resized input frame and the ground-truth
# mask are exported as well.
target_hw = [1080, 1920]  # (height, width) every frame is resized to before inference

for sample_name in tqdm(eval_samples):
    sample_path = f"{sample_name}.png"
    sample_gt_mask = f"{sample_name}.mask"
    # Output files are keyed by the frame's base name (no directory, no extension).
    sample_filename = os.path.splitext(os.path.basename(sample_path))[0]
    out_prefix = f"{loc_result_dir}/{ds_choice}/{arch_choice}_{sample_filename}"

    # Close the file handle deterministically, and force three channels so that
    # grayscale / palette / RGBA frames all give a (3, H, W) input.
    with Image.open(sample_path, mode="r") as frame:
        sample = to_tensor(frame.convert("RGB")) * 255
    # Round before the truncating uint8 cast to guard against float round-off.
    sample = resize(sample.round().to(torch.uint8), target_hw)

    # Inference only: skip building the autograd graph for each full-HD frame.
    with torch.no_grad():
        det, pred_mask = model(sample.unsqueeze(0).float())
    det, pred_mask = det.detach().cpu(), pred_mask.detach().cpu()

    if arch_choice == "video_transformer":
        # PatchPredictions is imported in the earlier import cell under this
        # same condition; re-importing it on every iteration is unnecessary.
        patch_pred_class = PatchPredictions(
            pred_mask,
            model.patch_size,
            model.img_size,
            min_thresh=0.1,
            max_num_regions=3,
            final_thresh=0.26,
        )
        pred_mask = patch_pred_class.get_pixel_preds()

    # write_png needs a channel dimension; add one if the mask is only (H, W).
    if pred_mask.dim() < 3:
        pred_mask = pred_mask.unsqueeze(0)
    # NOTE(review): assumes mask values lie in [0, 1] -- confirm for each architecture.
    pred_mask = (pred_mask * 255).to(torch.uint8)

    if arch_choice == "video_transformer":
        # The input frame and ground truth are only exported for this
        # architecture (presumably so they are saved once, not once per model),
        # so the ground-truth mask is only decoded here.
        gt_mask = read_image(sample_gt_mask)
        if gt_mask.max() < 255:
            # Stretch masks whose peak is below 255 to a 0/255 binary image.
            gt_mask[gt_mask > 0] = 255
        write_png(sample, f"{out_prefix}.png", 0)
        write_png(gt_mask, f"{out_prefix}_gt_mask.png", 0)
    # Third argument 0 = PNG compression level.
    write_png(pred_mask, f"{out_prefix}_pred_mask.png", 0)

In [None]:
# plt.imshow(pred_mask[0])

In [None]:
# pred_mask.sum() / 255