In [1]:
import os
import re
from tqdm import tqdm
import os
import sys
from PIL import Image
import yaml
from pathlib import Path
import torchvision
import torchvision.transforms.functional as TF
import torch
import numpy as np

root_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))

if root_dir not in sys.path:
    sys.path.append(root_dir)

from utils.dir_utils import mkdir, get_last_path
from utils.model_utils import load_checkpoint
from model.URSCT_model import URSCT



  Referenced from: <B3E58761-2785-34C6-A89B-F37110C88A05> /Users/christian/Dev/JBG060/URSCT-SESR/venv/lib/python3.9/site-packages/torchvision/image.so
  Expected in:     <CA0A91CD-08B1-3B88-A2D5-BD93563ECA22> /Users/christian/Dev/JBG060/URSCT-SESR/venv/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib
  warn(f"Failed to load image Python extension: {e}")
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])


def get_all_images_in_directory(directory):
    images = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            if is_image_file(file):
                images.append(os.path.join(root, file))
    return images

habitat_frame_regex = re.compile(r'(\d{4})_(.+)_(f\d+)[_.](?=jpg|png)')

# add all images in the directory to the list

original_dataset_path = "../../eda/data/DeepFish/"

yolo_images = get_all_images_in_directory("../../yolo/datasets/DeepFish-2/test")
original_images = (
    get_all_images_in_directory(original_dataset_path + "Classification")
    + get_all_images_in_directory(original_dataset_path + "Localization/images")
    + get_all_images_in_directory(original_dataset_path + "Segmentation/images")
)


len(yolo_images), len(original_images)

(645, 43586)

In [3]:
yolo_images[:5]

['../../yolo/datasets/DeepFish-2/test/images/7117_Caranx_sexfasciatus_juvenile_f001410_jpg.rf.2f7f30b91471cf33eceabed996e36cc6.jpg',
 '../../yolo/datasets/DeepFish-2/test/images/9866_no_fish_f000009_jpg.rf.1b691c7ed3043cedc5f1eb3e4f4e7b21.jpg',
 '../../yolo/datasets/DeepFish-2/test/images/7398_NF2_f000181_jpg.rf.06beabe2773409a1ba4b5305832a3051.jpg',
 '../../yolo/datasets/DeepFish-2/test/images/7268_F1_f000301_jpg.rf.de9e00afda3d3f8a673439ae8fedde39.jpg',
 '../../yolo/datasets/DeepFish-2/test/images/7434_NF2_f000061_jpg.rf.c3ff9344882f61c38483d01fa08ec2ff.jpg']

In [4]:
original_images_dict = {
    re.match(habitat_frame_regex, image.split("/")[-1]).groups(0): image
    for image in original_images
}
# lowercase keys in the dictionary
original_images_dict = {
    tuple(
        key_part.lower() if isinstance(key_part, str) else key_part for key_part in key
    ): value
    for key, value in original_images_dict.items()
}


len(original_images_dict), len(
    original_images
)  # duplicate keys are removed, i hope this is not bad

(39766, 43586)

In [5]:
list(original_images_dict.items())[:4]

[(('9892', 'acanthopagrus_palmaris', 'f000038'),
  '../../eda/data/DeepFish/Classification/9892/valid/9892_acanthopagrus_palmaris_f000038.jpg'),
 (('9892', 'acanthopagrus_palmaris', 'f000010'),
  '../../eda/data/DeepFish/Segmentation/images/valid/9892_acanthopagrus_palmaris_f000010.jpg'),
 (('9892', 'acanthopagrus_palmaris', 'f000004'),
  '../../eda/data/DeepFish/Classification/9892/valid/9892_acanthopagrus_palmaris_f000004.jpg'),
 (('9892', 'acanthopagrus_palmaris_2', 'f000008'),
  '../../eda/data/DeepFish/Classification/9892/valid/9892_Acanthopagrus_palmaris_2_f000008.jpg')]

In [6]:
with open('../configs/Enh_opt.yaml', 'r') as config:
    opt = yaml.safe_load(config)
    opt_test = opt['DEEPFISH']

device = opt_test['DEVICE']
model_detail_opt = opt['MODEL_DETAIL']
result_dir = os.path.join(opt_test['SAVE_DIR'], opt['TRAINING']['MODEL_NAME'], 'test_results')
mkdir(result_dir)

model = URSCT(model_detail_opt).to(device)
path_chk_rest = get_last_path(os.path.join(opt_test['SAVE_DIR'], opt['TRAINING']['MODEL_NAME'], 'models'), '_bestSSIM.pth')
load_checkpoint(model, path_chk_rest, device)
model.eval()

patch_size = opt_test['TEST_PS']

Model loading successfully!


In [7]:
def resize_with_letterbox(image, target_size=(640, 640)):
    orig_w, orig_h = image.size
    
    scale = min(target_size[0] / orig_h, target_size[1] / orig_w)
    
    new_w = int(orig_w * scale)
    new_h = int(orig_h * scale)
    
    resized_img = image.resize((new_w, new_h), Image.BILINEAR)
    
    letterbox_img = Image.new('RGB', target_size, (0, 0, 0))
    
    top_left_x = (target_size[1] - new_w) // 2
    top_left_y = (target_size[0] - new_h) // 2
    
    letterbox_img.paste(resized_img, (top_left_x, top_left_y))
    
    return letterbox_img

In [18]:
# goal:
# for each image with yolo annotation, get original image
# for each filtered image, run swin transformer
# write the output to a new directory with division test train val
# transform yolo annotations to the new image size (multiply y axis by 16/9)

for yolo_image in tqdm(yolo_images):
    habitat, fish_type, frame = re.match(
        habitat_frame_regex, yolo_image.split("/")[-1]
    ).groups(0)
    original_image_path = original_images_dict.get((habitat, fish_type.lower(), frame))

    if original_image_path is None:
        raise ValueError(
            f"Original image not found for {yolo_image} with key {(habitat, fish_type.lower(), frame)}"
        )

    yolo_image_path = Path(yolo_image)
    model_stage = str(yolo_image_path.parent.parent.name)
    target_path = os.path.join(result_dir, model_stage, "images", yolo_image_path.stem + ".png")

    if os.path.exists(target_path):
        continue

    # generate swin image
    inp_img = Image.open(original_image_path)

    orig_w, orig_h = inp_img.size

    # letterboxed_img = resize_with_letterbox(inp_img, (640, 640))
    inp_img = TF.to_tensor(inp_img)
    inp_img = TF.resize(inp_img, (patch_size[0], patch_size[1]))

    image_gpu = inp_img.to(device).unsqueeze(0)

    with torch.no_grad():
        restored_SR = model(image_gpu)

    restored_SR = TF.resize(restored_SR, (int(patch_size[1] / (orig_w / orig_h)), patch_size[0]), interpolation=TF.InterpolationMode.BILINEAR)
    
    # save swin image to output directory
    # create directories if they don't exist
    mkdir(os.path.join(result_dir, model_stage, "images"))

    torchvision.utils.save_image(
        restored_SR[0],
        # image_gpu[0],
        target_path,
    )

100%|██████████| 645/645 [08:07<00:00,  1.32it/s]


In [72]:
import cv2
from matplotlib import pyplot as plt
from tqdm import tqdm


for input_video in os.listdir("../../eda/videos"):
    output_video_path = f"swin/{input_video}"
    input_video_path = f"../../eda/videos/{input_video}"

    mkdir(os.path.dirname(output_video_path))

    if not input_video_path.endswith(".mp4"):
        continue

    if not input_video.split(".")[0] in [
        "9908",
        "9907",
        "9898",
        "9894",
        "9892",
        "9862",
        "9866",
    ]:
        continue

    print(f"Processing {input_video}")
    cap = cv2.VideoCapture(input_video_path)

    fps = int(cap.get(cv2.CAP_PROP_FPS))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    output_dim = (360, 640)
    out = cv2.VideoWriter(
        output_video_path, fourcc, fps, (output_dim[1], output_dim[0])
    )

    for frame_number in tqdm(range(total_frames), desc="Processing frames"):
        ret, frame = cap.read()
        if not ret:
            break  # End of video

        orig_h, orig_w, _ = frame.shape
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        inp_img = TF.to_tensor(frame_rgb)
        inp_img = TF.resize(inp_img, (patch_size[0], patch_size[1]))

        image_gpu = inp_img.to(device).unsqueeze(0)

        with torch.no_grad():
            restored_SR = model(image_gpu)

        restored_SR = TF.resize(
            restored_SR, output_dim, interpolation=TF.InterpolationMode.BILINEAR
        )

        output_frame = restored_SR.squeeze().permute(1, 2, 0).cpu().numpy()
        output_image = np.clip(output_frame * 255, 0, 255).astype(np.uint8)

        # plot image
        # plt.imshow(output_frame)
        # plt.show()

        out.write(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB))

    cap.release()
    out.release()
    cv2.destroyAllWindows()

Processing Video_0015.mp4
Processing 9894.mp4


Processing frames: 100%|██████████| 157/157 [01:52<00:00,  1.40it/s]


Processing 7463.mp4
Processing 9908.mp4


Processing frames: 100%|██████████| 91/91 [01:04<00:00,  1.40it/s]


Processing .DS_Store
Processing 9892.mp4


Processing frames: 100%|██████████| 55/55 [00:39<00:00,  1.41it/s]


Processing 7117.mp4
Processing 9852.mp4
Processing 7398.mp4
Processing 7434.mp4
Processing 7623.mp4
Processing .mp4
Processing 7393.mp4
Processing 7426.mp4
Processing Video_0051.mp4
Processing 7585.mp4
Processing Video_0034.mp4
Processing 9862.mp4


Processing frames: 100%|██████████| 45/45 [00:31<00:00,  1.43it/s]


Processing 7482.mp4
Processing 9870.mp4
Processing 9866.mp4


Processing frames: 100%|██████████| 113/113 [01:21<00:00,  1.39it/s]


Processing 9907.mp4


Processing frames: 100%|██████████| 186/186 [02:13<00:00,  1.40it/s]


Processing 9898.mp4


Processing frames: 100%|██████████| 57/57 [00:40<00:00,  1.41it/s]

Processing Video_0030.mp4
Processing 7490.mp4



