In [None]:
"""
Adapted from https://github.com/milesial/Pytorch-UNet
 and from https://github.com/tinygrad/tinygrad/blob/master/examples/stable_diffusion.py
"""

In [None]:
from tinygrad import Tensor
from tinygrad.nn import Conv2d, ConvTranspose2d, BatchNorm2d
from PIL import Image
import numpy as np
import os, random

In [None]:
def doubleconv(in_chan, out_chan):
    """Build a flat list of layers: two rounds of 3x3 conv -> batch norm -> ReLU.

    The first convolution maps ``in_chan`` to ``out_chan`` channels; the second
    keeps ``out_chan`` channels. Both use ``padding=1``. The returned list is
    meant to be applied element by element, each entry being a callable that
    takes and returns a tensor.
    """
    layers = []
    # Same layer triple twice; only the input width of the convolution differs.
    for conv_in in (in_chan, out_chan):
        layers.append(Conv2d(conv_in, out_chan, kernel_size=3, padding=1))
        layers.append(BatchNorm2d(out_chan))
        layers.append(Tensor.relu)
    return layers

class UNet:
    """Two-level U-Net assembled from plain Python lists of callables.

    * ``save_intermediates``: encoder stages; the output of each stage is kept
      for a skip connection.
    * ``middle``: bottleneck (pool, double conv, transposed-conv upsampling).
    * ``consume_intermediates``: decoder stages; each one first concatenates
      the most recently saved encoder output onto the current tensor along the
      channel axis (``dim=1``).

    The network takes 3 input channels and its last layer emits 2 channels.
    NOTE(review): the skip concatenations presumably require the input height
    and width to be divisible by 4 (two poolings) - confirm.
    """

    def __init__(self):
        # Encoder: 3 -> 64 at the input resolution, then pool and 64 -> 128.
        enc_full = doubleconv(3, 64)
        enc_pooled = [Tensor.max_pool2d] + doubleconv(64, 128)
        self.save_intermediates = [enc_full, enc_pooled]

        # Bottleneck: pool, 128 -> 256, then upsample back to 128 channels.
        self.middle = (
            [Tensor.max_pool2d]
            + doubleconv(128, 256)
            + [ConvTranspose2d(256, 128, kernel_size=2, stride=2)]
        )

        # Decoder: input widths are doubled (256, 128) by the skip concatenation.
        dec_pooled = doubleconv(256, 128) + [ConvTranspose2d(128, 64, kernel_size=2, stride=2)]
        dec_full = doubleconv(128, 64) + [Conv2d(64, 2, kernel_size=1)]
        self.consume_intermediates = [dec_pooled, dec_full]

    @staticmethod
    def _apply(layers, x):
        """Feed ``x`` through every callable in ``layers``, in order."""
        for layer in layers:
            x = layer(x)
        return x

    def __call__(self, x):
        """Run the encoder, bottleneck and decoder on ``x`` and return the result."""
        skips = []
        for stage in self.save_intermediates:
            x = self._apply(stage, x)
            skips.append(x)

        x = self._apply(self.middle, x)

        for stage in self.consume_intermediates:
            # Skips are consumed last-in, first-out: deepest encoder output first.
            x = self._apply(stage, skips.pop().cat(x, dim=1))
        return x

In [None]:
# Instantiate the network once; the cells below reuse this instance.
unet = UNet()

In [None]:
# Smoke test: push one random 3-channel 100x100 batch through the network.
x = Tensor.randn(1,3,100,100)
y = unet(x)
# Input and output shapes are NOT equal (3 channels in, 2 out), so compare
# only what must be preserved: batch size and spatial dimensions.
assert y.shape == (x.shape[0], 2, *x.shape[2:]), f"unexpected output shape {y.shape}"
y.shape

In [None]:
class DataLoader:
    """Sample random, optionally augmented patches from ``.npz`` images and a mask.

    Every archive is read through its ``'data'`` entry. One mask - the first
    ``.npz`` file of ``mask_dir`` in sorted order - is paired with every image,
    and image and mask are cropped with the same coordinates.
    NOTE(review): this assumes the mask has the same height/width as every
    image - confirm against the data.

    Args:
        image_dir: directory containing the image ``.npz`` archives.
        mask_dir: directory containing the mask ``.npz`` archive(s).
        patch_size: ``(height, width)`` of the sampled patches.
        normalize: if true, standardise each image per channel before cropping.
        flip_prob: probability of a left-right flip per patch.
        rotate_prob: probability of a rotation by a multiple of 90 degrees.
        noise_prob: probability of adding Gaussian noise to the image patch.
    """

    def __init__(self, image_dir, mask_dir, patch_size=(64, 64), normalize=True,
                 flip_prob=0.5, rotate_prob=0.5, noise_prob=0,):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.patch_size = patch_size
        self.normalize = normalize
        self.flip_prob = flip_prob
        self.rotate_prob = rotate_prob
        self.noise_prob = noise_prob
        self.image_files = self._list_npz(image_dir)
        # Mask names are joined onto mask_dir in get_batch, so they have to be
        # listed from mask_dir; sorting makes "the first mask" deterministic.
        self.mask_files = self._list_npz(mask_dir)

    @staticmethod
    def _list_npz(directory):
        """Return the sorted names of all ``.npz`` files in ``directory``."""
        return sorted(f for f in os.listdir(directory) if f.endswith('.npz'))

    @staticmethod
    def _load_data(path):
        """Read the ``'data'`` array of an ``.npz`` archive and close the file."""
        with np.load(path) as archive:
            return archive['data']

    @staticmethod
    def _allocate_counts(shares, batch_size):
        """Split ``batch_size`` into non-negative integer counts proportional to ``shares``.

        Largest-remainder rounding: floor every share, then hand the leftover
        samples to the entries with the biggest fractional parts. Unlike
        round-then-patch, this can never produce a negative count, and the
        counts always add up to ``batch_size``.
        """
        raw = np.asarray(shares, dtype=float) * batch_size
        counts = np.floor(raw).astype(int)
        leftover = int(batch_size - counts.sum())
        if leftover > 0:
            by_fraction = np.argsort(raw - counts)[::-1]
            counts[by_fraction[:leftover]] += 1
        return counts

    def get_batch(self, batch_size):
        """Return ``(image_patches, mask_patches)`` holding ``batch_size`` patches each.

        The batch is spread over the images with random (Dirichlet) shares, so
        any single image may contribute several patches or none.

        Raises:
            ValueError: if no image or no mask archive was found, or if the
                patch size exceeds the image size.
        """
        if not self.image_files:
            raise ValueError(f"no .npz images found in {self.image_dir!r}")
        if not self.mask_files:
            raise ValueError(f"no .npz masks found in {self.mask_dir!r}")

        # Randomly distribute samples across images
        shares = np.random.dirichlet(np.ones(len(self.image_files)), size=1)[0]
        counts = self._allocate_counts(shares, batch_size)

        image_patches, mask_patches = [], []
        mask = self._load_data(os.path.join(self.mask_dir, self.mask_files[0]))
        for file_name, num_samples in zip(self.image_files, counts):
            if num_samples <= 0:
                continue  # do not load/normalise images that contribute nothing
            image = self._load_data(os.path.join(self.image_dir, file_name))
            if self.normalize:
                image = self._normalize(image)
            for _ in range(num_samples):
                ip, mp = self._random_crop(image, mask)
                ip, mp = self._apply_augmentations(ip, mp)
                image_patches.append(ip)
                mask_patches.append(mp)
        return np.array(image_patches), np.array(mask_patches)

    def _apply_augmentations(self, image, mask):
        """Apply flip, rotation and noise, each with its configured probability."""
        if random.random() < self.flip_prob:
            image, mask = self._random_flip(image, mask)
        if random.random() < self.rotate_prob:
            image, mask = self._random_rotate(image, mask)
        if random.random() < self.noise_prob:
            image = self._random_noise(image)  # Apply noise only to the image, not the mask
        return image, mask

    def _random_crop(self, image, mask):
        """Cut the same randomly placed ``patch_size`` window out of image and mask."""
        h, w = image.shape[:2]
        new_h, new_w = self.patch_size
        if new_h > h or new_w > w:
            raise ValueError(
                f"patch size {(new_h, new_w)} does not fit into image of size {(h, w)}"
            )

        # randint's upper bound is exclusive; "+ 1" makes the last valid offset
        # reachable and lets a patch as large as the image work.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image_patch = image[top:top+new_h, left:left+new_w]
        mask_patch = mask[top:top+new_h, left:left+new_w]

        return image_patch, mask_patch

    def _random_flip(self, image, mask):
        """Mirror image and mask left-to-right."""
        return np.fliplr(image), np.fliplr(mask)

    def _random_rotate(self, image, mask):
        """Rotate image and mask by the same random multiple of 90 degrees."""
        k = random.choice([1, 2, 3])  # 90, 180, or 270 degrees
        if image.shape[0] != image.shape[1]:
            # 90/270 degrees would swap height and width, so patches in one
            # batch could no longer be stacked; only 180 keeps the shape.
            k = 2
        return np.rot90(image, k), np.rot90(mask, k)

    def _random_noise(self, image):
        """Add zero-mean Gaussian noise (sigma 0.05) and clip the result to [0, 1].

        NOTE(review): the clip presumably expects images scaled to [0, 1];
        raw or standardised images would be truncated - confirm before
        enabling ``noise_prob``.
        """
        noise = np.random.normal(0, 0.05, image.shape)
        return np.clip(image + noise, 0, 1)

    def _normalize(self, image):
        """Standardise every channel to zero mean / unit variance, as float32.

        Statistics are taken over the first two axes, so both ``(H, W, C)`` and
        plain ``(H, W)`` arrays are accepted.
        """
        # Accumulate in float64 so large float32/integer images do not lose precision.
        mean = np.mean(image, axis=(0, 1), keepdims=True, dtype=np.float64)
        std = np.std(image, axis=(0, 1), keepdims=True, dtype=np.float64)
        # adding small epsilon to avoid division by zero
        return ((image - mean) / (std + 1e-8)).astype(np.float32)

In [None]:
# Patch sampler over the image archives in data/auto_crop/0 and the mask in data/mask.
# normalize=False leaves the loaded pixel values untouched; the preview cell
# below passes the image patches straight to Image.fromarray.
dl = DataLoader(
    image_dir="data/auto_crop/0",
    mask_dir="data/mask",
    normalize=False
)

In [None]:
# Draw a batch of 10 patches and preview only those whose mask has at least
# one positive pixel: the image patch first, then its mask scaled by 255.
x, y = dl.get_batch(10)
for a, b in zip(x, y):
    mask_is_empty = not np.any(b > 0)
    if mask_is_empty:
        continue
    display(Image.fromarray(a))
    display(Image.fromarray(b * 255, mode="L"))
    print()