In [None]:
import sys
sys.path.append("..")

import random
import math
import time
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union, Dict

import PIL.Image
import PIL.ImageDraw

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.util.image import *
from src.util import *
from src.algo import *
from src.models.encoder import *
from src.models.decoder import *
from src.models.util import *

In [None]:
# Conv encoder for (3, 32, 32) inputs. The second positional argument (128) is presumably
# the code size (a later cell passes it as `code_size=`) -- TODO confirm against EncoderConv2d.
model = EncoderConv2d(
    (3, 32, 32),
    128,
    channels=(8, 12, 14),
    kernel_size=(3, 4, 5),
)
model

In [None]:
# VQVAE built with all of its default constructor arguments; the bare name below
# lets Jupyter render the module summary.
model2 = VQVAE()
model2

In [None]:
@torch.no_grad()
def get_model_weight_images(
        model: nn.Module,
        grad_only: bool = True,
        max_channels: int = 16,
        min_size: int = 2,
        max_size: int = 128,
        normalize: str = "all",  # "each", "shape", "all", "none"
        size_to_scale: Dict[int, float] = {10: 4, 20: 2},
) -> torch.Tensor:
    """
    Render the 2-D weight slices of ``model`` into a single image tensor.

    2-D parameters are used as they are. For 4-D parameters, up to ``max_channels``
    slices ``param[ch, 0]`` and up to ``max_channels`` slices ``param[0, ch]`` are taken.
    Parameters of any other rank are skipped.

    The slices are grouped by their final shape. Each group becomes one ``make_grid``
    block, narrower blocks are right-padded to the widest one, and the blocks are stacked
    vertically with 2 pixels of padding below each.

    :param model: module whose ``parameters()`` are rendered.
    :param grad_only: skip parameters whose ``requires_grad`` is False.
    :param max_channels: maximum number of slices taken along each of the first two
        axes of a 4-D parameter.
    :param min_size: skip slices that have any side smaller than this.
    :param max_size: crop slices to at most ``max_size`` x ``max_size``.
    :param normalize: where ``signed_to_image`` is applied. ``"each"`` applies it per
        slice, ``"shape"`` per shape group, ``"all"`` once on the final image, and
        ``"none"`` not at all.
    :param size_to_scale: maps a size threshold to an up-scaling factor. A slice whose
        sides all fit into the smallest matching threshold is enlarged with
        nearest-neighbour interpolation by that factor. The dict is only read.
    :return: the stacked image tensor (``make_grid`` output, optionally passed through
        ``signed_to_image``).
    :raises ValueError: if ``normalize`` is not a known mode, or if no parameter slice
        qualifies for rendering.
    """
    from torchvision.utils import make_grid
    from src.util.image import signed_to_image

    # previously an unknown mode silently behaved like "none"
    if normalize not in ("each", "shape", "all", "none"):
        raise ValueError(
            f"Unknown normalize mode '{normalize}', expected one of 'each', 'shape', 'all', 'none'"
        )

    def _iter_params():
        """Yield 2-D tensors: whole 2-D parameters and 2-D slices of 4-D parameters."""
        for param in model.parameters():
            if grad_only and not param.requires_grad:
                continue
            if param.ndim == 2:
                yield param
            elif param.ndim == 4:
                # first slices along axis 0 (with axis-1 index 0), then along axis 1
                for ch in range(min(max_channels, param.shape[0])):
                    yield param[ch, 0]
                for ch in range(min(max_channels, param.shape[1])):
                    yield param[0, ch]

    def _upscale_factor(shape) -> Optional[float]:
        """Return the factor of the smallest threshold that all sides fit into, else None."""
        for threshold in sorted(size_to_scale):
            if all(s <= threshold for s in shape):
                return size_to_scale[threshold]
        return None

    shape_dict: Dict[torch.Size, List[torch.Tensor]] = {}
    for param in _iter_params():
        if any(s < min_size for s in param.shape):
            continue
        param = param[:max_size, :max_size]

        scale = _upscale_factor(param.shape)
        if scale:
            # VF.resize requires integer sizes, but `size_to_scale` values may be floats
            new_size = [max(1, int(round(s * scale))) for s in param.shape]
            param = VF.resize(
                param.unsqueeze(0), new_size, VF.InterpolationMode.NEAREST, antialias=False
            ).squeeze(0)

        shape_dict.setdefault(param.shape, []).append(param)

    if not shape_dict:
        raise ValueError(
            f"No parameter of the model yields a 2-D slice with all sides >= {min_size}; nothing to render"
        )

    grids = []
    for shape in sorted(shape_dict):
        params = shape_dict[shape]
        # about 2 * sqrt(n) tiles per grid row
        nrow = max(1, int(math.sqrt(len(params)) * 2))
        if normalize == "each":
            tiles = [signed_to_image(p) for p in params]
        else:
            tiles = [p.unsqueeze(0) for p in params]
        grids.append(make_grid(tiles, nrow=nrow))

    # VF.pad takes [left, top, right, bottom]: widen narrower blocks on the right
    max_width = max(g.shape[-1] for g in grids)
    grids = [
        VF.pad(g, [0, 0, max_width - g.shape[-1], 0]) if g.shape[-1] < max_width else g
        for g in grids
    ]

    if normalize == "shape":
        grids = [signed_to_image(g) for g in grids]

    # 2 pixels of spacing below every block, then stack the blocks vertically
    image = torch.concat([VF.pad(g, [0, 0, 0, 2]) for g in grids], dim=-2)

    if normalize == "all":
        image = signed_to_image(image)

    return image
    
    
# Encoder with large 11x11 kernels, rendered as one weight image (single global normalization).
model3 = EncoderConv2d(
    (3, 32, 32),
    code_size=2,
    channels=(24, 32, 48),
    kernel_size=11,
)
VF.to_pil_image(
    get_model_weight_images(model3, normalize="all")
)

In [None]:
# VF.pad takes a 4-element padding as [left, top, right, bottom]. get_model_weight_images
# relies on [0, 0, dx, 0] widening only on the right and [0, 0, 0, 2] adding rows at the bottom.
assert VF.pad(torch.zeros(3, 4, 5), [0, 0, 2, 0]).shape == (3, 4, 7)
assert VF.pad(torch.zeros(3, 4, 5), [0, 0, 0, 2]).shape == (3, 6, 5)

In [None]:
from experiments import datasets

# Both datasets are requested at shape (3, 32, 32), the input shape used by the encoders above.
ds1 = datasets.rpg_tile_dataset_3x32x32((3, 32, 32))  # RPG tiles
ds2 = datasets.mnist_dataset((3, 32, 32))             # MNIST

In [None]:
class CombineImageAugmentIterableDataset(IterableDataset):
    """
    Augments images by pasting random crops taken from other images of the same batch.

    ``dataset`` is consumed in batches of ``batch_size`` via ``iter_batches``. Each batch
    is passed through ``random_combine_image_crops``, and the images are then yielded one
    by one as ``(image, is_aug)`` tuples. A ``DataLoader`` therefore collates them to
    ``[images, is_aug]``.

    ``is_aug`` is True when any pixel of the image changed. A pasted crop that happens to
    equal the region it covers (e.g. a uniform background) is reported as not augmented.

    :param dataset: source dataset. NOTE(review): assumes ``iter_batches`` yields either
        an image tensor of shape (B, C, H, W) or a list/tuple whose first entry is that
        tensor -- confirm.
    :param ratio: probability that an image receives a pasted crop.
    :param crop_ratio: crop side length as a fraction of the image side. Either a single
        value or a ``(min, max)`` range.
    :param batch_size: number of images pooled together. Crops only come from the same
        pool, so this must be > 1.
    """

    def __init__(
            self,
            dataset: Union[Dataset, IterableDataset],
            ratio: float = .5,
            crop_ratio: Union[float, Tuple[float, float]] = .5,
            batch_size: int = 128,
    ):
        assert batch_size > 1, f"`batch_size` must be > 1, got {batch_size}"

        self.dataset = dataset
        self.ratio = ratio
        self.batch_size = batch_size
        self.crop_ratio = (crop_ratio, crop_ratio) if isinstance(crop_ratio, (float, int)) else tuple(crop_ratio)

    def __iter__(self):
        for batch in iter_batches(self.dataset, self.batch_size):
            images = batch[0] if isinstance(batch, (list, tuple)) else batch

            if images.shape[0] < 2:
                # e.g. a short final batch: there is no other image to crop from
                augmented = images
            else:
                # augment a copy so `images` stays a clean reference for the is_aug flags
                augmented = random_combine_image_crops(
                    images.clone(), ratio=self.ratio, crop_ratio=self.crop_ratio,
                )

            is_aug = (augmented != images).flatten(1).any(dim=1)
            yield from zip(augmented, is_aug)

            
# Show the first augmented batch, marking augmented images with "X".
# NOTE: `random_combine_image_crops` is defined in a later cell; on a fresh kernel run that cell first.
for batch in DataLoader(CombineImageAugmentIterableDataset(ds2, ratio=.2), batch_size=64):
    images, is_aug = batch[0], batch[1]
    display(
        VF.to_pil_image(
            make_grid_labeled(
                images,
                nrow=8,
                labels=["X" if flag else "" for flag in is_aug],
            )
        )
    )
    # one batch is enough for a visual check
    break

In [None]:
def random_combine_image_crops(
        images: torch.Tensor,
        ratio: float = .5,
        crop_ratio: Union[float, Tuple[float, float]] = .5,
) -> torch.Tensor:
    """
    Paste a random crop of another image of the batch onto randomly selected images.

    Each image is selected with probability ``ratio``. For a selected image:

    - a different image of the batch is chosen uniformly;
    - a crop is cut from it at a random position and pasted at a random position of the
      selected image.

    Crop height and width are drawn independently as ``random.uniform(*crop_ratio)`` times
    the image side. They are clamped to ``[1, side - 1]``, or 1 for a side of length 1.

    Crops are always taken from the unmodified input images, and ``images`` itself is not
    changed.

    :param images: batch of images, indexed as ``images[b, c, y, x]``.
    :param ratio: probability that an image is augmented.
    :param crop_ratio: fraction of the image side used as the crop side. Either fixed or a
        ``(min, max)`` range.
    :return: a new tensor with the same shape as ``images``. Batches with fewer than two
        images are returned as an unchanged copy.
    """
    crop_ratio = (crop_ratio, crop_ratio) if isinstance(crop_ratio, (float, int)) else tuple(crop_ratio)
    num_images = images.shape[0]
    if num_images < 2:
        # no *other* image to take a crop from (a retry loop would never terminate here)
        return images.clone()

    ret_images = []
    for image_idx, image in enumerate(images):
        if random.random() > ratio:
            ret_images.append(image)
            continue

        # uniform over the other images: draw one of num_images - 1 slots and skip self
        other_idx = random.randrange(num_images - 1)
        if other_idx >= image_idx:
            other_idx += 1
        other_image = images[other_idx]

        height, width = image.shape[-2:]
        crop_h, crop_w = (
            max(1, min(int(random.uniform(*crop_ratio) * side), side - 1))
            for side in (height, width)
        )

        # valid start positions are 0 .. side - crop (inclusive), hence the + 1
        src_y = random.randrange(height - crop_h + 1)
        src_x = random.randrange(width - crop_w + 1)
        dst_y = random.randrange(height - crop_h + 1)
        dst_x = random.randrange(width - crop_w + 1)

        # `image` is a view into `images`: clone so neither the caller's tensor nor
        # crops taken for later images see this paste
        image = image.clone()
        image[:, dst_y: dst_y + crop_h, dst_x: dst_x + crop_w] = \
            other_image[:, src_y: src_y + crop_h, src_x: src_x + crop_w]
        ret_images.append(image)

    return torch.stack(ret_images)

# `ds` is not defined anywhere in this notebook (it only existed as stale kernel state).
# Use the MNIST dataset `ds2`, as the augmentation demo does; swap in `ds1` for RPG tiles.
images = next(iter(DataLoader(ds2, batch_size=64)))
VF.to_pil_image(make_grid(random_combine_image_crops(images[0])))

In [None]:
# Scratch check: `clone()` copies storage, so writing into the copy leaves the source intact.
a = torch.arange(5.)
b = a.clone()
b[1] = 5
a, b