In [1]:
from pathlib import Path
import os
import shutil
import random

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, models
from torchvision.transforms import v2 as transforms
from torchvision.transforms.v2 import functional as F
from tqdm import tqdm
from torchvision import transforms

## Filter POI image folders

In [2]:
IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"}


def dirwalk(folder) -> list:
    """Return the image files under ``folder`` as a sorted list of ``Path``.

    If ``folder`` itself points at an image file (judged by its extension),
    a one-element list containing that path is returned instead of walking.

    Args:
        folder: A directory or image-file path (``str`` or ``os.PathLike``).

    Returns:
        ``list[Path]`` of files found recursively whose case-insensitive
        extension is in ``IMG_FORMATS``, sorted for a stable order.
    """
    path = Path(folder)
    # Return Path objects in both branches so callers can use .parent / .name
    # uniformly.
    if path.suffix[1:].lower() in IMG_FORMATS:
        return [path]
    return sorted(
        p for p in path.rglob("*.*") if p.suffix[1:].lower() in IMG_FORMATS
    )

In [5]:
# Recursively list every image of the original (unfiltered) simulation dataset.
images = dirwalk(folder="/mnt/data/data/Data_Sim_Org")
print(f"{len(images)}")

313880


In [15]:
# Recursively list every image of the working simulation dataset.
images2 = dirwalk(folder="/mnt/data/data/Data_Sim")
print(f"{len(images2)}")

100104


In [5]:
def get_folders(images: list) -> list:
    """Return the distinct parent directories of ``images``.

    Args:
        images: Iterable of ``Path`` objects (e.g. the output of ``dirwalk``).

    Returns:
        Sorted list of unique ``image.parent`` values. Sorting makes the
        order reproducible; iteration order of a ``set`` of paths depends on
        string hashing, which is randomized per interpreter run.
    """
    return sorted({image.parent for image in images})

In [8]:
# Distinct directories that contain at least one image (original dataset).
folders = get_folders(images=images)
len(folders)

29969

In [6]:
# Distinct directories that contain at least one image (working dataset).
folders2 = get_folders(images=images2)
len(folders2)

29734

In [7]:
def get_not_poi(images: list) -> list:
    """Keep only images whose parent directory name does not contain ``"POI"``."""
    kept = []
    for image in images:
        if "POI" in image.parent.name:
            continue
        kept.append(image)
    return kept

In [8]:
# Images living outside any "POI" folder of the working dataset.
not_poi_images2 = get_not_poi(images=images2)
len(not_poi_images2)

0

In [42]:
def get_pa(images: list) -> list:
    """Select images whose parent directory name contains ``"PA"`` (case-sensitive)."""
    return list(filter(lambda image: "PA" in image.parent.name, images))

In [43]:
# Images in "PA" folders of the original dataset.
pa_images = get_pa(images=images)
len(pa_images)

0

In [44]:
def get_sn(images: list) -> list:
    """Select images whose parent directory name contains ``"SN"`` (case-sensitive)."""
    return [image for image in images if image.parent.name.find("SN") != -1]

In [45]:
# Images in "SN" folders of the original dataset.
sn_images = get_sn(images=images)
len(sn_images)

0

In [3]:
def get_aug(images: list) -> list:
    """Select images whose file name starts with the augmentation prefix ``"auga"``."""
    prefix = "auga"
    matches = []
    for image in images:
        if image.name.startswith(prefix):
            matches.append(image)
    return matches

In [9]:
# Previously generated augmented copies ("auga*" files) in the working dataset.
aug_images2 = get_aug(images=images2)
len(aug_images2)

200207

In [None]:
# NOTE(review): dead cell. `pa_images` and `sn_images` are defined above, but
# `aug_images` is never defined in this notebook (only `aug_images2` is), so
# uncommenting this would raise NameError. The output shown below is stale.
# ignore_images = set(pa_images).union(set(sn_images)).union(set(aug_images))
# len(ignore_images)

40587

In [10]:
# Only the augmented copies are scheduled for deletion; the non-POI filter
# (`not_poi_images2`, empty for this dataset) is not applied here.
ignore_images2 = aug_images2
len(ignore_images2)

200207

In [24]:
def get_ignore_folders(folders: list, ignore_images: list) -> list:
    """Return the folders that will be empty once ``ignore_images`` are deleted.

    A folder is selected when the number of ignored images directly inside it
    is at least its number of entries, i.e. nothing would remain after the
    deletion. Paths that do not exist or are not directories are skipped.

    Args:
        folders: Candidate directories (``Path`` or ``str``).
        ignore_images: Image paths scheduled for deletion. Images whose parent
            is not among ``folders`` are simply not counted.

    Returns:
        The subset of ``folders`` (original objects, original order) that can
        be removed with ``Path.rmdir`` after the images are deleted.
    """
    # Count scheduled deletions per parent directory.
    removed_per_folder = {}
    for image in ignore_images:
        parent = Path(image).parent
        removed_per_folder[parent] = removed_per_folder.get(parent, 0) + 1

    ignore_folders = []
    for folder in folders:
        folder_path = Path(folder)
        if not folder_path.is_dir():
            continue
        # Count *every* entry (extension-less files, sub-directories, ...), not
        # only "*.*" matches: rmdir succeeds only on a truly empty folder.
        n_entries = sum(1 for _ in folder_path.iterdir())
        if n_entries - removed_per_folder.get(folder_path, 0) <= 0:
            ignore_folders.append(folder)
    return ignore_folders

In [12]:
# Folders that would be left empty after deleting the ignored images.
ignore_folders2 = get_ignore_folders(folders=folders2, ignore_images=ignore_images2)
len(ignore_folders2)

0

In [13]:
def remove_ignore(ignore_images: list, ignore_folders: list) -> None:
    """Delete ``ignore_images``, then remove the (now empty) ``ignore_folders``.

    Safe to re-run: files or folders that are already gone are skipped, so a
    second execution of the calling cell does not raise ``FileNotFoundError``.
    ``Path.rmdir`` still raises ``OSError`` if a folder is not empty, which
    guards against deleting data that was not scheduled for removal.

    Args:
        ignore_images: Image paths (``Path`` or ``str``) to delete.
        ignore_folders: Directory paths to remove after the images are gone.
    """
    # Files first: a folder can only be rmdir'd once its images are deleted.
    for image in ignore_images:
        Path(image).unlink(missing_ok=True)

    for folder in ignore_folders:
        folder_path = Path(folder)
        if folder_path.exists():
            folder_path.rmdir()

In [14]:
# Destructive: permanently deletes the ignored images and the folders they empty.
remove_ignore(ignore_images=ignore_images2, ignore_folders=ignore_folders2)

## (Optional) Create Masked Crop POI

This step is run separately in the `POI_Seg` project, not in this notebook.

## Split train and validation sets

In [16]:
def horizontal_split(root_dir, output_dir, train_ratio=0.8, val_ratio=0.2, seed=None):
    """Copy whole class folders of ``root_dir`` into ``train``/``val`` subsets.

    The split is made at folder level: every immediate sub-directory of
    ``root_dir`` is copied in its entirety to either ``output_dir/train`` or
    ``output_dir/val``, so no folder contributes images to both subsets.

    Args:
        root_dir: Directory whose immediate sub-directories are the split units.
        output_dir: Destination; ``train`` and ``val`` are created inside it.
        train_ratio: Fraction of folders copied to ``train`` (count is floored).
        val_ratio: Fraction for ``val``; must sum with ``train_ratio`` to 1.
        seed: Optional seed for a reproducible shuffle. ``None`` shuffles with
            the module-level ``random`` generator.

    Raises:
        ValueError: If the ratios do not sum to 1 (within float tolerance).
        FileExistsError: From ``shutil.copytree`` if a destination folder
            already exists, e.g. when the split was already run into
            ``output_dir`` (this prevents silently mixing two splits).
    """
    # Tolerance instead of ==: float ratio sums are not always exactly 1.0.
    if abs(train_ratio + val_ratio - 1.0) > 1e-9:
        raise ValueError(f"Ratios must sum to 1.0, got {train_ratio} + {val_ratio}")

    train_dir = os.path.join(output_dir, "train")
    val_dir = os.path.join(output_dir, "val")
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    # os.scandir order is filesystem-dependent; sorting first makes the shuffle
    # (and therefore the split) reproducible for a given seed.
    with os.scandir(root_dir) as entries:
        class_folders = sorted(entry.name for entry in entries if entry.is_dir())
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(class_folders)

    split_idx = int(len(class_folders) * train_ratio)
    _copy_folders(root_dir, class_folders[:split_idx], train_dir)
    _copy_folders(root_dir, class_folders[split_idx:], val_dir)


def _copy_folders(src_root, names, dest_root):
    """Copy each ``src_root/<name>`` directory tree to ``dest_root/<name>``."""
    for name in names:
        shutil.copytree(os.path.join(src_root, name), os.path.join(dest_root, name))

In [17]:
# Seed the module-level RNG that `horizontal_split` shuffles with, so the
# folder-level train/val split is reproducible across kernel restarts
# (given a stable directory listing order).
random.seed(42)
src_path = "/mnt/data/data/Data_Sim"
# NOTE(review): relative to the kernel's CWD; later cells read the split from an
# absolute path under /mnt/data/src/bait/train/sim/ — confirm they match.
dst_path = "../data_sim_processed"
horizontal_split(src_path, dst_path)

## Data augmentation

In [18]:
def augment_and_save_image(image, transform, num_augmentations=2):
    """Write ``num_augmentations`` randomly transformed copies next to ``image``.

    Copy ``i`` is saved as ``auga_{i}__<original name>`` in the same folder.
    A failure while transforming or saving one copy is reported and skipped,
    so a single bad image does not abort a long batch run. Errors raised by
    ``Image.open`` itself still propagate.

    Args:
        image: ``Path`` of the source image.
        transform: Callable mapping a PIL image to an object with ``.save``
            (e.g. a torchvision transform applied to a PIL image).
        num_augmentations: Number of copies to attempt.

    Returns:
        Number of copies successfully written.
    """
    saved = 0
    # The context manager closes the file handle deterministically instead of
    # leaving it to garbage collection across tens of thousands of images.
    with Image.open(str(image)) as img:
        for i in range(num_augmentations):
            augmented_path = image.parent / f"auga_{i}__{image.name}"
            try:
                augmented_image = transform(img)
                augmented_image.save(str(augmented_path))
            except Exception as exc:  # best-effort; never swallow KeyboardInterrupt
                tqdm.write(f"Skipping {augmented_path}: {type(exc).__name__}: {exc}")
                continue
            saved += 1
    return saved


def aug(images, transform, num_augmentations=2):
    """Run ``augment_and_save_image`` over ``images`` with a progress bar.

    Returns:
        Total number of augmented copies written.
    """
    total = 0
    for image in tqdm(images):
        total += augment_and_save_image(image, transform, num_augmentations)
    return total

In [19]:
# One of the listed augmentations is picked at random per call.
# NOTE: `transforms` resolves to the last binding of that name in the import
# cell (`from torchvision import transforms`), which shadows the earlier
# `torchvision.transforms.v2` alias.
transform = transforms.RandomChoice(
    transforms=[
        transforms.ColorJitter(brightness=0.5, contrast=0.5),
        transforms.RandomRotation(30),
        transforms.GaussianBlur(kernel_size=3),
    ]
)

In [21]:
# Presumably the `train` output of the split cell above, given as an absolute path.
# NOTE(review): the split cell wrote to the relative "../data_sim_processed";
# confirm it resolves to this directory from the kernel's CWD.
train_images2_folder = "/mnt/data/src/bait/train/sim/data_sim_processed/train"

In [27]:
# Top-level entries of the training split. With no images scheduled for
# deletion, only entries that already look empty are returned.
train_images2_folders = [*Path(train_images2_folder).iterdir()]
ignore_train_folders2 = get_ignore_folders(folders=train_images2_folders, ignore_images=[])
len(ignore_train_folders2)

1

In [29]:
# Remove the empty training folders found above; no image files are deleted.
remove_ignore(ignore_images=[], ignore_folders=ignore_train_folders2)

In [None]:
# Make the random augmentations reproducible: RandomChoice may draw from
# Python's `random`, the individual transforms from torch's RNG — seed both.
random.seed(0)
torch.manual_seed(0)
# Skip previously generated "auga*" copies (same prefix `get_aug` matches) so
# re-running this cell does not augment already-augmented images.
train_images2 = [
    p for p in dirwalk(train_images2_folder) if not p.name.startswith("auga")
]
aug(train_images2, transform)

100%|██████████| 80719/80719 [16:15<00:00, 82.77it/s]  


In [22]:
# Re-scan the training split after augmentation. Note: this rebinds `images2`
# and `folders2`, which earlier held the un-split Data_Sim listing.
images2 = dirwalk(folder=train_images2_folder)
folders2 = get_folders(images=images2)

print(len(folders2), len(images2))

24039 242156


In [None]:
# Inventory of the validation split: number of folders, number of images.
val_images2_folder = "/mnt/data/src/bait/train/sim/data_sim_processed/val"
val_images2 = dirwalk(folder=val_images2_folder)
val_folders2 = get_folders(images=val_images2)

print(len(val_folders2), len(val_images2))

5695 19385
