In [None]:
import sys
sys.path.append("..")

import random
import math
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.util.image import *
from src.util import *
from src.algo import *
from src.models.cnn import *

In [None]:
# Alternative RGB dataset (swap in to use it):
# SHAPE = (3, 128, 128)
# dataset = TensorDataset(torch.load(f"../datasets/kali-uint8-{SHAPE[-2]}x{SHAPE[-1]}.pt"))

# Expected (channels, height, width) of every stored image; the file name encodes it.
SHAPE = (1, 128, 128)
dataset = TensorDataset(
    torch.load(f"../datasets/pattern-{SHAPE[-3]}x{SHAPE[-2]}x{SHAPE[-1]}-uint.pt")
)

# Sanity check: the stored tensors must match the expected shape.
assert dataset[0][0].shape == SHAPE

In [None]:
class ContrastiveImageDataset(Dataset):
    """
    Returns tuple of two image crops and bool if the crops
    are from the same image.

    Each source image contributes ``num_crops + num_contrastive_crops``
    consecutive items:

    - the first ``num_crops`` items pair two independent random crops of that
      image (``is_same=True``),
    - the remaining ``num_contrastive_crops`` items pair a crop of that image
      with a crop of a different, uniformly chosen source image
      (``is_same=False``).

    Every crop goes through the same augmentation chain: a random crop of
    ``crop_shape``, then each enabled augmentation (horizontal flip, vertical
    flip, hue, saturation, brightness, grayscale). Each ``prob_*`` argument is
    the probability that the augmentation is applied to a given crop; a
    probability of 0 leaves the augmentation out of the chain entirely.

    :param source_dataset: dataset of image tensors, or of tuples/lists whose
        first element is the image; the last two dims are height and width
        (channel-first layout presumably expected by the torchvision
        helpers — TODO confirm).
    :param crop_shape: (height, width) of every crop.
    :param num_crops: same-image pairs per source image.
    :param num_contrastive_crops: different-image pairs per source image;
        these require at least two source images.
    :param prob_h_flip: probability of a horizontal flip.
    :param prob_v_flip: probability of a vertical flip.
    :param prob_hue: probability of a random hue shift in [-0.5, 0.5).
    :param prob_saturation: probability of a random saturation factor in [0, 2).
    :param prob_brightness: probability of a random brightness factor in [0.5, 1.5).
    :param prob_grayscale: probability of converting to grayscale.
    :param generator: generator used for every random draw (crop position,
        augmentation decisions and amounts, negative partner); a fresh
        ``torch.Generator()`` is created when None.
    """
    def __init__(
            self,
            source_dataset: Dataset,
            crop_shape: Tuple[int, int],
            num_crops: int = 2,
            num_contrastive_crops: int = 2,
            prob_h_flip: float = .5,
            prob_v_flip: float = .5,
            prob_hue: float = .0,
            prob_saturation: float = .5,
            prob_brightness: float = .5,
            prob_grayscale: float = 0.,
            generator: Optional[torch.Generator] = None
    ):
        self.source_dataset = source_dataset
        self.crop_shape = crop_shape
        self.num_contrastive_crops = num_contrastive_crops
        self.num_crops = num_crops
        self.prob_h_flip = prob_h_flip
        self.prob_v_flip = prob_v_flip
        self.prob_hue = prob_hue
        self.prob_saturation = prob_saturation
        self.prob_brightness = prob_brightness
        self.prob_grayscale = prob_grayscale
        self.generator = torch.Generator() if generator is None else generator

        # The crop is always applied; the other augmentations only when their
        # probability is non-zero (each then decides per crop whether to fire).
        transforms = [self._crop]
        if prob_h_flip:
            transforms.append(self._h_flip)
        if prob_v_flip:
            transforms.append(self._v_flip)
        if prob_hue:
            transforms.append(self._hue)
        if prob_saturation:
            transforms.append(self._saturation)
        if prob_brightness:
            transforms.append(self._brightness)
        if prob_grayscale:
            transforms.append(self._to_grayscale)
        self.cropper = VT.Compose(transforms)

    def __len__(self) -> int:
        return len(self.source_dataset) * (self.num_crops + self.num_contrastive_crops)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, bool]:
        """Return ``(crop1, crop2, is_same)`` for the flat item ``index``."""
        all_crops = self.num_crops + self.num_contrastive_crops

        true_index = index // all_crops
        crop_index = index % all_crops

        image1 = image2 = self._get_image(true_index)
        is_same = True

        # Indices past `num_crops` within an image's group are negative pairs.
        if crop_index >= self.num_crops:
            image2 = self._get_image(self._random_other_index(true_index))
            is_same = False

        return (
            self.cropper(image1),
            self.cropper(image2),
            is_same,
        )

    def _random_other_index(self, index: int) -> int:
        """Uniformly pick a source index different from ``index``."""
        num_images = len(self.source_dataset)
        if num_images < 2:
            raise ValueError(
                f"contrastive (negative) pairs need at least 2 source images, got {num_images}"
            )
        # Draw from the num_images - 1 other indices and skip over `index`:
        # every other image is reachable and no rejection loop is needed.
        other = torch.randint(0, num_images - 1, (1,), generator=self.generator).item()
        return other + 1 if other >= index else other

    def _get_image(self, index: int):
        """Fetch a source item, unwrapping ``(image, ...)`` tuples/lists."""
        image = self.source_dataset[index]
        if isinstance(image, (tuple, list)):
            image = image[0]
        return image

    def _should_apply(self, probability: float) -> bool:
        """Draw from ``self.generator``; True with the given probability."""
        return torch.rand(1, generator=self.generator).item() < probability

    def _crop(self, image: torch.Tensor) -> torch.Tensor:
        """Random ``crop_shape`` crop over all positions fully inside the image."""
        height, width = image.shape[-2:]
        crop_height, crop_width = self.crop_shape
        if crop_height > height or crop_width > width:
            raise ValueError(
                f"crop_shape {tuple(self.crop_shape)} does not fit into image of size {height}x{width}"
            )
        # `top` ranges over rows (height) and `left` over columns (width);
        # VF.crop takes (top, left, height, width) in that order.
        top = torch.randint(0, height - crop_height + 1, size=(1,), generator=self.generator).item()
        left = torch.randint(0, width - crop_width + 1, size=(1,), generator=self.generator).item()

        return VF.crop(image, top, left, crop_height, crop_width)

    def _h_flip(self, image: torch.Tensor) -> torch.Tensor:
        return VF.hflip(image) if self._should_apply(self.prob_h_flip) else image

    def _v_flip(self, image: torch.Tensor) -> torch.Tensor:
        return VF.vflip(image) if self._should_apply(self.prob_v_flip) else image

    def _hue(self, image: torch.Tensor) -> torch.Tensor:
        if not self._should_apply(self.prob_hue):
            return image
        amt = torch.rand(1, generator=self.generator).item() - .5  # in [-0.5, 0.5)
        return VF.adjust_hue(image, amt)

    def _saturation(self, image: torch.Tensor) -> torch.Tensor:
        if not self._should_apply(self.prob_saturation):
            return image
        amt = torch.rand(1, generator=self.generator).item() * 2.  # in [0, 2)
        return VF.adjust_saturation(image, amt)

    def _brightness(self, image: torch.Tensor) -> torch.Tensor:
        if not self._should_apply(self.prob_brightness):
            return image
        amt = torch.rand(1, generator=self.generator).item() + .5  # in [0.5, 1.5)
        return VF.adjust_brightness(image, amt)

    def _to_grayscale(self, image: torch.Tensor) -> torch.Tensor:
        if not self._should_apply(self.prob_grayscale):
            return image
        # keep the input's channel count (the channel axis is dim -3)
        return VF.rgb_to_grayscale(image, image.shape[-3])
        
                
# Preview the first 30 items: print whether each pair comes from the same
# source image, then show both crops side by side.
cds = ContrastiveImageDataset(
    dataset,
    crop_shape=(64, 64),
    generator=torch.Generator().manual_seed(23),  # fixed seed -> same crops on every run
)
print(f"size: {len(cds):,}")
for i in range(30):
    i1, i2, is_same = cds[i]
    print(is_same)
    display(VF.to_pil_image(make_grid([i1, i2])))

In [None]:
def plot_positive_negative(cds: ContrastiveImageDataset, total: int = 8*8, nrow: int = 8):
    """
    Display a grid of positive (same-image) crop pairs above a grid of
    negative (different-image) crop pairs, sampled in random order from `cds`.

    :param cds: contrastive dataset yielding ``(crop1, crop2, is_same)``.
    :param total: number of crops to collect per group. Pairs are added
        whole, so a group holds at most the smallest even number >= ``total``.
    :param nrow: number of crops per row in each group's grid.
    :raises ValueError: if `cds` yields no positive or no negative pairs.
    """
    positive, negative = [], []
    # DataLoader's default batch_size=1 adds a leading batch dim to each crop
    # and turns `is_same` into a one-element bool tensor.
    for i1, i2, is_same in DataLoader(cds, shuffle=True):
        the_list = positive if bool(is_same) else negative
        if len(the_list) < total:
            the_list.append(i1.squeeze(0))
            the_list.append(i2.squeeze(0))
        # Stop as soon as both groups are full; don't fetch another sample.
        if len(positive) >= total and len(negative) >= total:
            break

    if not positive or not negative:
        raise ValueError(
            f"need both positive and negative pairs, got {len(positive) // 2} positive "
            f"and {len(negative) // 2} negative"
        )

    grids = [
        make_grid(positive, nrow=nrow),
        make_grid(negative, nrow=nrow),
    ]
    if grids[0].shape == grids[1].shape:
        display(VF.to_pil_image(make_grid(grids)))
    else:
        # Groups of different size (dataset exhausted early) cannot be
        # stacked into one grid, so show them one after the other.
        for grid in grids:
            display(VF.to_pil_image(grid))
    
plot_positive_negative(
    ContrastiveImageDataset(
        dataset,
        crop_shape=(64, 64),
        # one same-image pair and one different-image pair per source image
        num_crops=1,
        num_contrastive_crops=1,
        # augmentation settings for this preview
        prob_h_flip=0.5,
        prob_v_flip=0.5,
        prob_hue=0.0,
        prob_saturation=0.0,
        prob_brightness=0.9,
        prob_grayscale=1.0,
    )
)

In [None]:
# Visual check of torchvision's brightness adjustment on the first dataset
# image (a brightness factor of 2 doubles the brightness).
image = dataset[0][0]
image = VF.adjust_brightness(image, 2)
VF.to_pil_image(image)

In [None]:
# Toy batch: 3 random feature vectors of dimension 10.
features = torch.randn((3, 10))
features

In [None]:
# L2-normalize each row (last dim) so every feature vector has unit length.
features = features / features.norm(dim=-1, keepdim=True)

In [None]:
torch.min(features)  # smallest entry of `features`

In [None]:
# Shape probe: a 3-channel 64x64 input through a 7x7 conv (no padding, 64 -> 58)
# and a 5x5 max-pool (stride 5, 58 -> 11); expected output shape (1, 5, 11, 11).
i = torch.rand((1, 3, 64, 64))
nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=5, kernel_size=7),
    nn.MaxPool2d(kernel_size=5),
)(i).shape