In [None]:
from init_notebook import *

In [None]:
# NOTE(review): this first assignment is overwritten two lines below, so only the
# kali patch dataset is actually used afterwards. It looks like a leftover
# alternative -- confirm and remove or comment it out.
dataset = datasets.rpg_tile_dataset_3x32x32(shape=(3, 64, 64))
#dataset = datasets.fmnist_dataset(train=True)
dataset = datasets.kali_patch_dataset((3, 64, 64))

# Take the first batch (batch_size=64, no shuffling requested) as the sample
# patches that the demo cells below operate on.
patches = next(iter(DataLoader(dataset, batch_size=64)))
print(patches.shape)
# Render the batch as one image grid. `resize` and `make_grid` come from the
# star import; presumably `resize(..., 2)` upscales by a factor of 2 -- confirm.
VF.to_pil_image(resize(make_grid(patches), 2))

In [None]:
class RandomCropHalfImage(nn.Module):
    """
    Randomly blanks out one half (left, right, top or bottom) of an image.

    Despite the name, the image size does not change: the selected half is
    overwritten with ``null_value`` in a copy of the image.

    Args:
        prob: Probability that an image is modified at all. Each image of a
            batch gets its own random decision.
        null_value: Value written into the blanked half.
    """

    def __init__(
            self,
            prob: float = 1.,
            null_value: float = 0.,
    ):
        super().__init__()
        self.prob = prob
        self.null_value = null_value

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Apply the transform to a single image or to every image of a batch.

        Args:
            input: Tensor with 3 dimensions (one image) or 4 dimensions
                (a batch of images along the first dimension).

        Returns:
            Tensor of the same shape as ``input``.

        Raises:
            ValueError: If ``input`` has neither 3 nor 4 dimensions.
        """
        if input.ndim == 3:
            return self._crop_half(input)
        elif input.ndim == 4:
            return torch.stack([self._crop_half(img) for img in input])
        else:
            # Report the dimensionality, not the tensor contents.
            raise ValueError(
                f"input must have 3 or 4 dimensions, got {input.ndim} (shape {tuple(input.shape)})"
            )

    def _crop_half(self, image: torch.Tensor) -> torch.Tensor:
        """
        Blank one randomly chosen half of a single image.

        Returns ``image`` itself (no copy) when the random draw decides to
        leave it unchanged, otherwise a modified copy.
        """
        if random.uniform(0, 1) >= self.prob:
            return image

        half_h = image.shape[-2] // 2
        half_w = image.shape[-1] // 2
        side = random.randrange(4)

        # Work on a copy so the caller's tensor is never modified.
        new_image = image.clone()
        if side == 0:
            new_image[..., :, :half_w] = self.null_value  # left half
        elif side == 1:
            new_image[..., :, half_w:] = self.null_value  # right half
        elif side == 2:
            new_image[..., :half_h, :] = self.null_value  # top half
        else:
            new_image[..., half_h:, :] = self.null_value  # bottom half
        return new_image
        

# Demo: blank one random half of every sample patch (prob=1.) and show the grid.
noise_patches = RandomCropHalfImage(prob=1.)(patches)
VF.to_pil_image(resize(make_grid(noise_patches), 2))

In [None]:
class ImageNoise(nn.Module):
    """
    Adds gaussian noise of random strength to images and clamps to [0, 1].

    The noise amount per image is
    ``amt_min + (amt_max - amt_min) * u ** amt_power`` with ``u`` drawn
    uniformly from [0, 1).

    Args:
        amt_min: Smallest noise amount.
        amt_max: Largest noise amount.
        amt_power: Exponent applied to the uniform draw; values above 1 make
            small amounts more likely.
        grayscale_prob: Probability of using the same noise plane for all
            channels instead of independent noise per channel.
        prob: Probability that an image is modified at all.
    """

    def __init__(
            self,
            amt_min: float = .01,
            amt_max: float = .15,
            amt_power: float = 2.,
            grayscale_prob: float = .1,
            prob: float = 1.,
    ):
        super().__init__()
        self.amt_min = amt_min
        self.amt_max = amt_max
        self.amt_power = amt_power
        self.grayscale_prob = grayscale_prob
        self.prob = prob

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Apply the noise to a single image or to every image of a batch.

        Args:
            input: Tensor with 3 dimensions (one image) or 4 dimensions
                (a batch of images along the first dimension).

        Returns:
            Tensor of the same shape as ``input``.

        Raises:
            ValueError: If ``input`` has neither 3 nor 4 dimensions.
        """
        if input.ndim == 3:
            return self._noise(input)
        elif input.ndim == 4:
            # Every image gets its own amount and grayscale decision.
            return torch.stack([self._noise(img) for img in input])
        else:
            # Report the dimensionality, not the tensor contents.
            raise ValueError(
                f"input must have 3 or 4 dimensions, got {input.ndim} (shape {tuple(input.shape)})"
            )

    def _noise(self, image: torch.Tensor) -> torch.Tensor:
        """
        Add noise to a single image.

        Returns ``image`` itself when the random draw decides to leave it
        unchanged, otherwise a new tensor clamped to [0, 1].
        """
        if random.uniform(0, 1) >= self.prob:
            return image

        amt = math.pow(random.uniform(0, 1), self.amt_power)
        amt = self.amt_min + (self.amt_max - self.amt_min) * amt

        if random.uniform(0, 1) < self.grayscale_prob:
            # One noise plane; the addition below broadcasts it over the
            # channel dimension, so all channels receive identical noise.
            noise = torch.randn_like(image[..., :1, :, :])
        else:
            noise = torch.randn_like(image)

        return (image + amt * noise).clamp(0, 1)


# Demo: fixed noise amount (amt_min == amt_max == .1), half of the patches with
# channel-identical ("grayscale") noise.
noise_patches = ImageNoise(amt_min=.1, amt_max=.1, prob=1., grayscale_prob=.5)(patches)
VF.to_pil_image(resize(make_grid(noise_patches), 2))

In [None]:
class ImageMultiNoise(nn.Module):
    """
    Adds randomly configured noise to images and clamps the result to [0, 1].

    For every image a channel mode and a distribution mode are picked at
    random from the configured lists.

    Channel modes:
        - ``"white"``: one noise plane shared by all channels
        - ``"color"``: independent noise per channel

    Distribution modes:
        - ``"gauss"``: ``torch.randn_like``
        - ``"positive"``: ``torch.rand_like`` (values in [0, 1))
        - ``"negative"``: negated ``torch.rand_like``
        - ``"positive-negative"``: difference of two ``torch.rand_like`` draws

    Args:
        amt_min: Smallest noise amount.
        amt_max: Largest noise amount.
        amt_power: Exponent applied to the uniform draw that interpolates
            between ``amt_min`` and ``amt_max``.
        blur_sigma_min: Lower bound of the gaussian blur sigma applied to
            the noise.
        blur_sigma_max: Upper bound of the gaussian blur sigma. The blur is
            skipped when the drawn sigma is not greater than zero.
        prob: Probability that an image is modified at all.
        channel_modes: Channel modes to choose from, defaults to all.
        distribution_modes: Distribution modes to choose from, defaults to all.
    """

    def __init__(
            self,
            amt_min: float = .01,
            amt_max: float = .15,
            amt_power: float = 2.,
            blur_sigma_min: float = 0.,
            blur_sigma_max: float = 1.,
            prob: float = 1.,
            channel_modes: Optional[List[str]] = None,
            distribution_modes: Optional[List[str]] = None,
    ):
        super().__init__()
        self.amt_min = amt_min
        self.amt_max = amt_max
        self.amt_power = amt_power
        self.blur_sigma_min = blur_sigma_min
        self.blur_sigma_max = blur_sigma_max
        self.prob = prob
        if channel_modes is None:
            channel_modes = ["white", "color"]
        self.channel_modes = channel_modes
        if distribution_modes is None:
            distribution_modes = ["gauss", "positive", "negative", "positive-negative"]
        self.distribution_modes = distribution_modes

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Apply the noise to a single image or to every image of a batch.

        Args:
            input: Tensor with 3 dimensions (one image) or 4 dimensions
                (a batch of images along the first dimension).

        Returns:
            Tensor of the same shape as ``input``.

        Raises:
            ValueError: If ``input`` has neither 3 nor 4 dimensions, or if a
                configured mode name is unknown.
        """
        if input.ndim == 3:
            return self._noise(input)
        elif input.ndim == 4:
            # Every image gets its own amount, modes and blur.
            return torch.stack([self._noise(img) for img in input])
        else:
            # Report the dimensionality, not the tensor contents.
            raise ValueError(
                f"input must have 3 or 4 dimensions, got {input.ndim} (shape {tuple(input.shape)})"
            )

    @staticmethod
    def _rand_difference_like(x: torch.Tensor) -> torch.Tensor:
        """Difference of two uniform draws, i.e. values in (-1, 1) centred on 0."""
        return torch.rand_like(x) - torch.rand_like(x)

    def _noise(self, image: torch.Tensor) -> torch.Tensor:
        """
        Add noise to a single image.

        Returns ``image`` itself when the random draw decides to leave it
        unchanged, otherwise a new tensor clamped to [0, 1].
        """
        if random.uniform(0, 1) >= self.prob:
            return image

        amt = math.pow(random.uniform(0, 1), self.amt_power)
        amt = self.amt_min + (self.amt_max - self.amt_min) * amt

        channel_mode = random.choice(self.channel_modes)
        distribution_mode = random.choice(self.distribution_modes)

        if distribution_mode == "gauss":
            rand_func = torch.randn_like
        elif distribution_mode in ("positive", "negative"):
            # "negative" draws the same values; the sign is flipped after the blur
            rand_func = torch.rand_like
        elif distribution_mode == "positive-negative":
            rand_func = self._rand_difference_like
        else:
            raise ValueError(f"Unknown distribution_mode '{distribution_mode}'")

        if channel_mode == "white":
            # One noise plane copied to every channel.
            noise = rand_func(image[..., :1, :, :]).expand_as(image).contiguous()
        elif channel_mode == "color":
            noise = rand_func(image)
        else:
            raise ValueError(f"Unknown channel_mode '{channel_mode}'")

        blur_sigma = random.uniform(self.blur_sigma_min, self.blur_sigma_max)
        if blur_sigma > 0.:
            noise = VF.gaussian_blur(noise, 5, blur_sigma)

        if distribution_mode == "negative":
            noise = -noise

        return (image + amt * noise).clamp(0, 1)


# Demo: default channel/distribution modes with noise amounts between 0 and .3.
noise_patches = ImageMultiNoise(amt_min=.0, amt_max=.3)(patches)
VF.to_pil_image(resize(make_grid(noise_patches), 2))

In [None]:
# Compare the value distributions of the two zero-centred noise sources used by
# ImageMultiNoise: the difference of two uniform draws ("positive-negative")
# and standard normal noise ("gauss").
noises = [
    torch.rand(10000) - torch.rand(10000),
    torch.randn(10000),
]
    
# One histogram per noise source.
for data in noises:
    display(px.histogram(data))


In [None]:
class PixelartDataset(Dataset):
    """
    Labeled pixel-art tiles, yielding ``(patch, label_id)`` tuples.

    The data is read lazily on first access from three files in the dataset
    directory: ``tiles.json`` (metadata with ``"shape"`` and ``"count"``),
    ``tiles.png`` (the tile sheet) and ``tiles.csv`` (one row per tile with a
    ``label`` column).

    Raw labels that are not listed in ``LABELS`` are mapped to the first entry
    of ``LABELS`` that occurs as a substring of the raw label (for example
    ``"grass/sand"`` becomes ``"grass"``). Everything else becomes ``"other"``.

    Args:
        shape: Requested patch shape. Only the channel count (first entry) is
            used; the spatial size is taken from ``tiles.json``.
        root: Directory containing the dataset files. Strings are accepted as
            well. Defaults to ``DEFAULT_ROOT``.
    """

    DEFAULT_ROOT = "~/prog/python/github/pixelart-dataset/datasets/v2/"

    # Order matters: it defines the label ids and the priority of the
    # substring fallback.
    LABELS = [
        'creature',
        'wall',
        'tree',
        'carpet',
        'grass',
        'rock',
        'water',
        'wood',
        'sand',
        'roof',
        'sword',
        'cobblestone',
        'plant',
        'platform',
        'stairs',
        'shelf',
        'block',
        'axe',
        'food',
        'door',
        'dirt',
        'window',
        'pipe',
        'floor',
        'table',
        'bridge',
        'stone',
        'bed',
        'fire',
        'other',
    ]

    def __init__(
        self,
        shape: Tuple[int, int, int] = (3, 32, 32),
        root: Optional[Path] = None,
    ):
        self._out_shape = shape
        self._root = root
        self._patch_dataset = None
        self._label_to_id = {l: i for i, l in enumerate(self.LABELS)}
        self._fallback_id = self._label_to_id["other"]

    def __len__(self):
        self._lazy_load()
        return self._meta["count"]

    def _lazy_load(self):
        """Read metadata, tile sheet and label table on first use."""
        if self._patch_dataset is None:
            path = Path(self.DEFAULT_ROOT if self._root is None else self._root).expanduser()
            meta = json.loads((path / "tiles.json").read_text())
            # Channel count from the constructor, spatial size from the
            # metadata. The previous hard-coded (3, 32, 32) silently ignored
            # the `shape` argument.
            patch_shape = (self._out_shape[0], *meta["shape"])
            patch_dataset = ImagePatchDataset(patch_shape, path / "tiles.png")
            patch_df = pd.read_csv(path / "tiles.csv")

            # Publish everything at the end, so a failure while loading does
            # not leave a half-initialised instance behind.
            self._meta = meta
            self._patch_df = patch_df
            self._patch_dataset = patch_dataset

    def _label_id(self, label) -> int:
        """Map a raw label to the id of its top-level label."""
        if not isinstance(label, str):
            # e.g. NaN for an empty cell in the CSV
            return self._fallback_id

        label_id = self._label_to_id.get(label)
        if label_id is None:
            label_id = next(
                (self._label_to_id[base_label] for base_label in self.LABELS if base_label in label),
                self._fallback_id,
            )
            # Remember the result (including the fallback), so the scan runs
            # only once per raw label.
            self._label_to_id[label] = label_id
        return label_id

    def __getitem__(self, index: int):
        """
        Returns:
            Tuple of the patch at ``index`` and the integer id of its label
            (an index into ``LABELS``).
        """
        self._lazy_load()
        item = self._patch_dataset[index]
        label = self._patch_df["label"].iloc[index]
        return item, self._label_id(label)

ds = PixelartDataset()

# Count how often each top-level label (after the dataset's label collapsing)
# occurs across the whole dataset.
counts = {}
for image, label in tqdm(ds):
    # the dataset yields label ids; translate back to the label name
    label = ds.LABELS[label]
    counts[label] = counts.get(label, 0) + 1

counts

In [None]:
# Number of distinct top-level labels that actually occur in the dataset.
len(counts)

In [None]:
# Bar chart of the label counts, most frequent label first.
pd.DataFrame(counts, index=["count"]).T.sort_values("count", ascending=False).plot.bar()

In [None]:
# Raw label -> number of tiles, listed in descending count order. Combined
# labels are written as 'a/b'.
# NOTE(review): looks like a pasted snapshot of an earlier counting run over
# the un-collapsed labels -- confirm where it came from before relying on it.
label_dist = {'creature': 2003,
    'wall': 1577,
    'tree': 734,
    'carpet': 645,
    'grass': 593,
    'rock': 530,
    'water': 515,
    'wood': 432,
    'sand': 369,
    'roof': 349,
    'sword': 319,
    'cobblestone': 285,
    'plant': 273,
    'platform': 270,
    'stairs': 257,
    'shelf': 254,
    'block': 234,
    'axe': 211,
    'grass/sand': 209,
    'food': 198,
    'door': 185,
    'dirt': 183,
    'window': 173,
    'pipe': 167,
    'floor': 166,
    'table': 163,
    'bridge': 157,
    'sceptre': 133,
    'stone': 127,
    'bed': 125,
    'fire': 125,
    'statue': 122,
    'shield': 122,
    'boundary': 119,
    'ice': 114,
    'steel': 111,
    'glow': 109,
    'fish': 109,
    'armour': 101,
    'grass/rock': 100,
    'dirt/sand': 96,
    'book': 89,
    'fence': 88,
    'terrain': 88,
    'chair': 88,
    'symbol': 84,
    'cobblestone/grass': 77,
    'dirt/rock': 76,
    'sand/water': 76,
    'spear': 74,
    'dirt/grass': 72,
    'column': 71,
    'potion': 71,
    'doorway': 69,
    'tombstone': 67,
    'boat': 65,
    'helmet': 61,
    'gear': 56,
    'arrow': 55,
    'jewelry': 54,
    'scroll': 52,
    'fountain': 52,
    'pot': 51,
    'ring': 51,
    'chest': 50,
    'lamp': 50,
    'rock/water': 50,
    'cloth': 50,
    'bow': 45,
    'doorway/wall': 45,
    'boots': 44,
    'rock/wall': 44,
    'grass/tree': 43,
    'sign': 38,
    'explosion': 37,
    'machine': 37,
    'wall/window': 35,
    'cobblestone/dirt': 32,
    'dirt/water': 32,
    'slime': 31,
    'grass/water': 30,
    'wall/water': 30,
    'rock/terrain': 29,
    'barrel': 29,
    'hand': 29,
    'ice/water': 29,
    'well': 28,
    'ladder': 28,
    'sphere': 28,
    'coin': 26,
    'grass/platform': 26,
    'lava': 25,
    'background': 24,
    'leaves': 24,
    'tool': 23,
    'skull': 21,
    'snow': 21,
    'grass/terrain': 20,
    'creature/tree': 20,
    'curtain': 18,
    'ground': 17,
    'house': 17,
    'dirt/wall': 16,
    'anvil': 16,
    'mace': 16,
    'clock': 16,
    'cloud': 14,
    'head': 14,
    'steel/wood': 14,
    'mountain': 14,
    'key': 13,
    'mushroom': 12,
    'spikes': 12,
    'door/wall': 12,
    'bones': 11,
    'sand/tree': 11,
    'barricade': 10,
    'gun': 10,
    'hat': 10,
    'dirt/tree': 10,
    'stone/terrain': 9,
    'stars': 9,
    'background/wall': 9,
    'crystal': 8,
    'ornament': 8,
    'roof/wood': 8,
    'sarcophargus': 7,
    'grass/stone': 7,
    'background/doorway': 6,
    'blood': 6,
    'carpet/wood': 6,
    'robot': 4,
    'clock/tree': 4,
    'bomb': 3,
    'stain': 3,
    'block/steel': 3,
    'hole': 3,
    'fire/glow': 3,
    'cobblestone/water': 3,
    'creature/fire': 3,
    'gloves': 2,
    'plant/pot': 2,
    'paper': 2,
    'tank': 2,
    'cobblestone/sand': 2,
    'boomerang': 2,
    'dirt/grass/rock': 2,
    'block/wood': 2,
    'flag': 2,
    'window/wood': 2,
    'bucket': 1,
    'sand/wall': 1,
    'terrain/wall': 1,
    'tunnel': 1,
    'spaceship': 1,
    'ground/skull': 1,
    'explosion/glow': 1,
    'skull/symbol': 1,
    'background/cobblestone': 1,
    'head/helmet': 1,
    'arrow/sword': 1,
    'lamp/wall': 1,
    'candle': 1,
    'floor/hole/wall': 1,
    'carpet/sword': 1,
    'door/wood': 1
}

In [None]:
# Candidate list of top-level classes that the raw labels are collapsed into.
# The order matters for the substring fallback in the loop below: the first
# entry contained in a raw label wins.
LABELS = [
    'creature',
    'wall',
    'tree',
    'carpet',
    'grass',
    'sand',
    'cobblestone',
    'rock',
    'water',
    'wood',
    'roof',
    'sword',
    'weapon',
    'plant',
    'platform',
    'stairs',
    'shelf',
    'block',
    'food',
    'door',
    'dirt',
    'window',
    'pipe',
    'floor',
    'furniture',
    "fire",
    "statue",
    "armour",
    "boundary",
]
# NOTE(review): this is not the last expression of the cell, so its value is
# presumably not displayed -- move it to the end or into its own cell.
len(LABELS)

# Explicit raw label -> top-level label overrides, looked up before LABELS.
# NOTE(review): 'sword' is both a class in LABELS and remapped to 'weapon'
# here, so a plain 'sword' becomes 'weapon' while e.g. 'arrow/sword' still
# falls back to 'sword' via the substring match -- confirm which is intended.
label_mapping = {
    "bridge": "wood",
    "sword": "weapon",
    "axe": "weapon",
    "sceptre": "weapon",
    "spear": "weapon",
    "stone": "rock",
    "bed": "furniture",
    "table": "furniture",
    "chair": "furniture",
    "shield": "armour",
    "glow": "fire",
    "fish": "food",
    "arrow": "weapon",
}

# Print every raw label (with its tile count) that can be assigned neither via
# label_mapping, nor by exact match, nor by substring match against LABELS.
for label, c in label_dist.items():
    label = label_mapping.get(label, label)
    if label in LABELS:
        continue

    for top_label in LABELS:
        if top_label in label:
            # first class contained in the raw label wins
            label = top_label
            break
    else:
        # no class matched at all
        print(label, c)
            
    

In [None]:
# The first 31 raw labels of label_dist, i.e. the most frequent ones, since the
# dict is written in descending count order.
list(label_dist.keys())[:31]

In [None]:
class ImageNoiseDataset(Dataset):
    """
    Wraps an image dataset and yields noisy versions of each image.

    For every entry of ``amounts_per_arg`` one noisy image is produced. The
    noise accumulates: each image is the previous noisy image plus new
    gaussian noise scaled by ``sub_amt * amt``, clamped to [0, 1]. ``amt`` is
    drawn once per item as ``amt_min + (amt_max - amt_min) * u ** amt_power``
    with ``u`` uniform in [0, 1).

    Items of the wrapped dataset may be a plain image, or a list/tuple whose
    first element is the image. In the latter case the remaining elements are
    appended unchanged after the noisy images.

    Args:
        image_dataset: The dataset to wrap.
        amt_min: Smallest noise amount.
        amt_max: Largest noise amount.
        amt_power: Exponent applied to the uniform draw.
        amounts_per_arg: Multipliers of the noise amount, one per returned
            image. Any iterable is accepted.
        grayscale_prob: Probability of using the same noise plane for all
            channels.
        prob: Probability that noise is applied at all. Otherwise the
            unchanged image is returned once per entry of ``amounts_per_arg``.
    """

    def __init__(
            self,
            image_dataset: Dataset,
            amt_min: float = .01,
            amt_max: float = .15,
            amt_power: float = 1.,
            amounts_per_arg: Iterable[float] = (1,),
            grayscale_prob: float = .0,
            prob: float = 1.,
    ):
        super().__init__()
        self._image_dataset = image_dataset
        self._amt_min = amt_min
        self._amt_max = amt_max
        self._amt_power = amt_power
        # Materialise once: a one-shot iterator (e.g. a generator) would be
        # exhausted after the first item and all later items would come back
        # without any images.
        self._amounts_per_arg = tuple(amounts_per_arg)
        self._grayscale_prob = grayscale_prob
        self._prob = prob

    def __len__(self):
        return len(self._image_dataset)

    def __getitem__(self, index):
        """
        Returns:
            - the single noisy image, if the wrapped item is a plain image and
              there is exactly one amount,
            - otherwise a tuple of the noisy images, followed by the remaining
              elements of the wrapped item if that was a list/tuple.
        """
        item = self._image_dataset[index]

        is_tuple = isinstance(item, (list, tuple))
        if is_tuple:
            image, *rest = item
        else:
            image, rest = item, []

        # All random draws happen in a fixed order, also for unchanged items.
        amt = math.pow(random.uniform(0, 1), self._amt_power)
        amt = self._amt_min + (self._amt_max - self._amt_min) * amt
        is_grayscale = random.uniform(0, 1) < self._grayscale_prob

        if random.uniform(0, 1) >= self._prob:
            noisy_images = [image] * len(self._amounts_per_arg)
        else:
            noisy_images = []
            noisy_image = image
            for sub_amt in self._amounts_per_arg:
                if is_grayscale:
                    # One noise plane; the addition broadcasts it over the
                    # channel dimension.
                    noise = torch.randn_like(image[..., :1, :, :])
                else:
                    noise = torch.randn_like(image)

                noisy_image = (noisy_image + sub_amt * amt * noise).clamp(0, 1)
                noisy_images.append(noisy_image)

        if is_tuple:
            return (*noisy_images, *rest)
        if len(noisy_images) == 1:
            return noisy_images[0]
        return tuple(noisy_images)

# Demo: two noise levels per tile; the second image is the first one with
# additional noise on top.
ds = ImageNoiseDataset(
    PixelartDataset(), 
    amt_min=.1, amt_max=1.,
    amounts_per_arg=(1., 1.3),
)
# Collect both noisy versions of the first 32 tiles. Each item is
# (noisy image 1, noisy image 2, label id); the label id is not used here.
images = []
for i, tup in zip(range(32), ds):
    images.append(tup[0])
    images.append(tup[1])
VF.to_pil_image(resize(make_grid(images), 2))

In [None]:
# Scratch check of RandomSampler on a four-element data source.
# NOTE(review): per the PyTorch docs a sampler yields indices into the data
# source (0..3 here), not the listed values -- confirm that is what is wanted.
sampler = torch.utils.data.sampler.RandomSampler([1, 2, 3, 4])

In [None]:
# Print what one pass over the sampler yields.
for i in sampler:
    print(i)