In [None]:
import sys
sys.path.append("..")

import random
import math
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid

from src.datasets import *
from src.util.image import * 
from src.util.image_filter import ImageFilter

In [None]:
# Source photos for the dataset. Roots used in earlier runs:
#   ImageFolder(Path("~/Pictures/__diverse").expanduser())
#   ImageFolder(Path("~/Pictures/diffusion/").expanduser(), recursive=True)
target_shape = (3, 64, 64)  # (channels, height, width) of every produced sample
photo_dir = Path("~/Pictures/photos/katjacam").expanduser()
base_images = ImageFolder(photo_dir, recursive=True)
# Scan the folder up front (presumably fills the private `_filenames` list) and show the file count.
base_images._get_filenames()
len(base_images._filenames)

In [None]:
# Quality filter for the random crops produced in `iter_images`.
# Only the blurred-compression-ratio threshold is active
# (min_compression_ratio=0.9 was tried in an earlier run).
image_filter = ImageFilter(min_blurred_compression_ratio=0.3)

def iter_images():
    """Yield image tensors of size ``target_shape`` derived from ``base_images``.

    For every source image this yields:

    * one sample of the whole image, made by the project helper
      ``image_resize_crop`` (not filtered);
    * if the image's shorter side is at least 200 px, additional random
      crops from ``_iter_crops``; only crops accepted by ``image_filter`` are yielded.

    Reads the module-level ``base_images``, ``target_shape`` and ``image_filter``.
    Randomness comes from Python's ``random`` module and torchvision
    transforms, so the output is not reproducible unless those are seeded.
    A progress line is printed roughly every 100 source images.
    """
    crop_h, crop_w = target_shape[-2:]
    cropper = VT.RandomCrop(target_shape[-2:])

    def _iter_crops(image, min_size: int):
        """Yield up to ``count`` filtered random crops of ``image``, with at most ``5 * count`` attempts."""
        # Larger images yield more crops: between 3 and 42.
        count = 2 + max(1, min(40, (min_size - 400) // 200))
        # Smallest down-scale factor; images with a short side under ~380 px are shrunk less.
        min_scale = max(.05, 1. - min_size / 400)
        num_yielded = 0
        num_tried = 0
        # The attempt cap keeps heavily filtered images from looping forever.
        while num_yielded < count and num_tried < count * 5:
            # The 10th power biases `scale` strongly towards `min_scale`.
            scale = min_scale + math.pow(random.random() * (1. - min_scale), 10.)
            img = VF.resize(image, [
                max(crop_h, int(image.shape[-2] * scale)),
                max(crop_w, int(image.shape[-1] * scale)),
            ])
            if random.random() < .5:
                # RandomRotation expects the center as (x, y): x spans the width,
                # y the height (previously both used the height).
                # NOTE(review): the center is drawn only from the top-left
                # target-sized area of the possibly much larger resized image;
                # confirm this is intended.
                center = [random.randrange(crop_w), random.randrange(crop_h)]
                img = VT.RandomRotation(30, center=center)(img)
            img = cropper(img)

            num_tried += 1
            if image_filter(img):
                yield img
                num_yielded += 1

    last_image_idx = 0
    for idx, base_image in enumerate(base_images):
        # Coarse progress report.
        if idx - last_image_idx > 100:
            last_image_idx = idx
            print(f"image: #{idx}")

        # Project helper; presumably converts to target_shape[0] channels.
        image = set_image_channels(base_image, target_shape[0])
        yield image_resize_crop(image, target_shape[-2:])

        min_size = min(image.shape[-2:])
        if min_size >= 200:
            yield from _iter_crops(image, min_size)

def plot_images(iterable, total=16, nrow=16):
    """Collect up to ``total`` images from ``iterable`` and display them as one grid.

    Each image is clamped to [0, 1]. Interrupting the kernel (KeyboardInterrupt)
    stops collection early; whatever was collected so far is still displayed.

    :param iterable: iterable of image tensors; they must share one shape so
        ``make_grid`` can stack them.
    :param total: maximum number of images to collect.
    :param nrow: number of images per grid row.
    :return: list of the collected, clamped image tensors (empty if none).
    """
    samples = []
    try:
        # Iterate the argument; previously a fresh iter_images() was always used.
        for image in tqdm(iterable, total=total):
            samples.append(image.clamp(0, 1))
            if len(samples) >= total:
                break
    except KeyboardInterrupt:
        # Deliberate: interrupting just ends collection early.
        pass
    if samples:
        display(VF.to_pil_image(make_grid(samples, nrow=nrow)))
    else:
        print("plot_images: no images collected")
    return samples

# Keep the preview samples at module level so later cells can inspect them.
samples = plot_images(iter_images(), 16*16)

In [None]:
# Per-channel standard deviation of the 16th preview sample.
sample_std = samples[15].view(3, -1).std(dim=1)
sample_std

In [None]:
# Output path: encodes the sample size (height x width); "bcr03" presumably
# refers to min_blurred_compression_ratio=0.3 of the image filter. Keep in sync.
FILENAME = "../datasets/photos-{}x{}-bcr03.pt".format(target_shape[-2], target_shape[-1])

def store_dataset(
        images: Iterable,
        dtype=torch.float32,
        output_filename=FILENAME,
        max_megabyte=1_000,
):
    """Concatenate images into one tensor and save it with ``torch.save``.

    Each image is clamped to [0, 1] and converted to ``dtype``. 3-dim images
    get a leading batch dimension, so everything is concatenated along dim 0.
    Collection stops when ``images`` is exhausted, when ``max_megabyte`` MiB of
    tensor data have been gathered, or on KeyboardInterrupt (what was collected
    so far is still saved).

    :param images: iterable of image tensors, (C, H, W) or (N, C, H, W), all with
        matching trailing dimensions.
    :param dtype: dtype of the stored tensor (previously this parameter was ignored).
    :param output_filename: destination path passed to ``torch.save``.
    :param max_megabyte: size budget in MiB (1024 * 1024 bytes).
    :raises ValueError: if no image was collected.
    """
    tensor_batch = []
    tensor_size = 0
    last_print_size = 0
    max_bytes = max_megabyte * 1024 * 1024
    try:
        for image in tqdm(images):
            if image.ndim < 4:
                image = image.unsqueeze(0)
            image = image.clamp(0, 1).to(dtype)
            tensor_batch.append(image)
            # Count the real byte size of the stored dtype (was a fixed 4 bytes/element).
            tensor_size += image.numel() * image.element_size()

            # Print the running size roughly every 50 MiB.
            if tensor_size - last_print_size > 1024 * 1024 * 50:
                last_print_size = tensor_size
                print(f"size: {tensor_size:,}")

            if tensor_size >= max_bytes:
                break
    except KeyboardInterrupt:
        # Deliberate: interrupting the kernel saves what was collected so far.
        pass

    if not tensor_batch:
        raise ValueError("store_dataset: no images collected, nothing to save")
    torch.save(torch.cat(tensor_batch), output_filename)

# Build and save the dataset. This runs the whole image pipeline and can take
# long; interrupting the kernel stops early and saves what was collected.
store_dataset(images=iter_images())

In [None]:
# Reload the stored tensor and show one shuffled 16x16 grid as a sanity check.
ds = TensorDataset(torch.load(FILENAME))
dl = DataLoader(ds, shuffle=True, batch_size=16**2)
batch = next(iter(dl))
img = VF.to_pil_image(make_grid(batch[0], nrow=16))
img

In [None]:
# Scratch check: concatenating along dim 0 stacks the rows, giving torch.Size([32, 8]).
t = torch.rand(16, 8)
torch.cat((t, t), dim=0).shape