In [None]:
"""
Adapted from:
https://github.com/milesial/Pytorch-UNet
https://github.com/tinygrad/tinygrad/blob/master/examples/stable_diffusion.py
https://docs.tinygrad.org/mnist/
"""

In [None]:
from tinygrad import Tensor, TinyJit, nn
from tinygrad.nn import Conv2d, ConvTranspose2d, BatchNorm2d
from tinygrad.dtype import dtypes
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, load_state_dict
from PIL import Image
import numpy as np
import os, random

In [None]:
def doubleconv(in_chan, out_chan):
    """Return the layers of two stacked 3x3 conv -> batch-norm -> ReLU stages.

    The first convolution maps `in_chan` to `out_chan` channels, the second
    maps `out_chan` to `out_chan`; both use `kernel_size=3, padding=1`.
    The result is a flat list of six callables to be applied in order.
    """
    layers = []
    # The two stages differ only in the number of input channels.
    for stage_in in (in_chan, out_chan):
        layers.append(Conv2d(stage_in, out_chan, kernel_size=3, padding=1))
        layers.append(BatchNorm2d(out_chan))
        layers.append(Tensor.relu)
    return layers

class UNet:
    """Two-level U-Net assembled from plain lists of callables.

    The network takes a 3-channel input (first layer is `Conv2d(3, 64, ...)`)
    and ends in a 1x1 convolution with 2 output channels. Outputs of the
    encoder stages are kept and concatenated (along dim 1) with the decoder
    activations in reverse order.
    """

    def __init__(self):
        # Encoder stages; the output of each one is stored for a skip connection.
        encoder_full_res = doubleconv(3, 64)
        encoder_pooled = [Tensor.max_pool2d] + doubleconv(64, 128)
        self.save_intermediates = [encoder_full_res, encoder_pooled]

        # Bottleneck: pool, double conv, then upsample with a transposed conv.
        bottleneck = [Tensor.max_pool2d] + doubleconv(128, 256)
        bottleneck.append(ConvTranspose2d(256, 128, kernel_size=2, stride=2))
        self.middle = bottleneck

        # Decoder stages; each one first receives a stored encoder output
        # concatenated in front of the current activations (hence 256 / 128 inputs).
        decoder_pooled = doubleconv(256, 128)
        decoder_pooled.append(ConvTranspose2d(128, 64, kernel_size=2, stride=2))
        decoder_full_res = doubleconv(128, 64)
        decoder_full_res.append(Conv2d(64, 2, kernel_size=1))
        self.consume_intermediates = [decoder_pooled, decoder_full_res]

    @staticmethod
    def _run(layers, x):
        """Apply every callable in `layers` to `x`, in order, and return the result."""
        for layer in layers:
            x = layer(x)
        return x

    def __call__(self, x):
        """Run the forward pass on `x` and return the output of the last 1x1 convolution."""
        skips = []
        for stage in self.save_intermediates:
            x = self._run(stage, x)
            skips.append(x)

        x = self._run(self.middle, x)

        for stage in self.consume_intermediates:
            # Most recently stored encoder output is consumed first.
            x = self._run(stage, skips.pop().cat(x, dim=1))
        return x

In [None]:
class ImageMaskPair:
    """Paths of one image archive and the mask archive that belongs to it.

    Both files are read as NumPy `.npz` archives whose array is stored under
    the key 'data'.
    """

    def __init__(self, image_path, mask_path):
        self.image_path = image_path
        self.mask_path = mask_path

    @staticmethod
    def _read_data(path):
        """Return the 'data' array stored in the archive at `path`.

        The archive is opened in a `with` block so its file handle is closed
        as soon as the array has been read, instead of lingering until
        garbage collection.
        """
        with np.load(path) as archive:
            return archive['data']

    def load_image(self):
        """Load and return the image array from `image_path`."""
        return self._read_data(self.image_path)

    def load_mask(self):
        """Load and return the mask array from `mask_path`."""
        return self._read_data(self.mask_path)

class DataLoader:
    """Samples random, augmented training patches from image/mask `.npz` pairs.

    Layout read by `get_image_mask_pairs`: every sub-directory `<name>` of
    `image_dir` holds image files, and all of them are paired with the single
    mask file `<mask_dir>/<name>.npz`.

    The class also offers helpers to turn a whole image into model input
    (`prep`, `image_to_model_input`) and to stitch per-chunk output back
    together (`synthesize_image_from_chunks`).

    Args:
        image_dir: directory containing one sub-directory of images per mask.
        mask_dir: directory containing `<sub-directory name>.npz` mask files.
        patch_size: `(height, width)` of the sampled patches.
        normalize: if true, images are per-channel standardized before cropping.
        flip_prob: probability of a left/right flip per patch.
        rotate_prob: probability of a rotation per patch.
        noise_prob: probability of adding Gaussian noise to the image patch.
    """

    def __init__(self, image_dir, mask_dir, patch_size=(64, 64), normalize=True,
                 flip_prob=0.5, rotate_prob=0.5, noise_prob=0):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.patch_size = patch_size
        self.normalize = normalize
        self.flip_prob = flip_prob
        self.rotate_prob = rotate_prob
        self.noise_prob = noise_prob
        self.image_mask_pairs = self.get_image_mask_pairs()

    def get_image_mask_pairs(self):
        """Scan `image_dir` and return a list of `ImageMaskPair`, sorted by mask path.

        Entries of `image_dir` that are not directories are skipped so a stray
        file does not abort the scan. Pairs that share a mask are adjacent in
        the result; the image path is used as a tie-breaker so the order does
        not depend on the order in which the filesystem lists entries.
        """
        pairs = []
        for subdir in os.listdir(self.image_dir):
            subdir_path = os.path.join(self.image_dir, subdir)
            if not os.path.isdir(subdir_path):
                continue
            mask_file = os.path.join(self.mask_dir, subdir + ".npz")
            for file in os.listdir(subdir_path):
                pairs.append(ImageMaskPair(os.path.join(subdir_path, file), mask_file))

        pairs.sort(key=lambda pair: (pair.mask_path, pair.image_path))
        return pairs

    def _sample_counts(self, batch_size):
        """Randomly split `batch_size` samples across all image/mask pairs.

        Shares are drawn from a flat Dirichlet distribution and converted to
        integers with the largest-remainder method, so every count is
        non-negative and the counts sum to exactly `batch_size`.

        Raises:
            ValueError: if there are no image/mask pairs to sample from.
        """
        n_images = len(self.image_mask_pairs)
        if n_images == 0:
            raise ValueError("DataLoader has no image/mask pairs to sample from")

        shares = np.random.dirichlet(np.ones(n_images))
        exact = shares * batch_size
        counts = np.floor(exact).astype(int)
        remainder = int(batch_size - counts.sum())
        if remainder > 0:
            # Hand the left-over samples to the images with the largest fractional share.
            winners = np.argsort(exact - counts)[-remainder:]
            counts[winners] += 1
        return counts

    def get_batch(self, batch_size):
        """Return `(image_patches, mask_patches)` tensors holding `batch_size` random patches.

        Image patches are permuted with `(0, 3, 1, 2)`, i.e. the last axis of
        the stacked patches is moved to position 1; mask patches are stacked
        as-is.
        """
        counts = self._sample_counts(batch_size)

        image_patches, mask_patches = [], []
        # Pairs are sorted by mask path, so remembering only the most recent
        # mask is enough to avoid re-reading a mask shared by adjacent images.
        cached_mask_path, cached_mask = None, None
        for imp, num_samples in zip(self.image_mask_pairs, counts):
            if num_samples <= 0:
                continue  # nothing drawn from this image: skip the disk read entirely
            if imp.mask_path != cached_mask_path:
                cached_mask_path, cached_mask = imp.mask_path, imp.load_mask()
            image = imp.load_image()
            if self.normalize:
                image = self._normalize(image)
            for _ in range(num_samples):
                ip, mp = self._random_crop(image, cached_mask)
                ip, mp = self._apply_augmentations(ip, mp)
                image_patches.append(ip)
                mask_patches.append(mp)
        image_patches = Tensor(image_patches).permute(0,3,1,2)
        return image_patches, Tensor(mask_patches)

    def _apply_augmentations(self, image, mask):
        """Apply flip, rotation and noise, each with its configured probability."""
        if random.random() < self.flip_prob:
            image, mask = self._random_flip(image, mask)
        if random.random() < self.rotate_prob:
            image, mask = self._random_rotate(image, mask)
        if random.random() < self.noise_prob:
            image = self._random_noise(image)  # Apply noise only to the image, not the mask
        return image, mask

    def _random_crop(self, image, mask):
        """Cut the same randomly placed `patch_size` window out of `image` and `mask`.

        Raises:
            ValueError: if the image is smaller than the patch in either dimension.
        """
        h, w = image.shape[:2]
        new_h, new_w = self.patch_size
        if h < new_h or w < new_w:
            raise ValueError(f"image of size {h}x{w} is smaller than patch size {new_h}x{new_w}")

        # `randint`'s upper bound is exclusive: +1 makes the last valid offset
        # reachable and lets an image exactly as large as the patch through.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image_patch = image[top:top+new_h, left:left+new_w]
        mask_patch = mask[top:top+new_h, left:left+new_w]

        return image_patch, mask_patch

    def _random_flip(self, image, mask):
        """Mirror image and mask left/right."""
        return np.fliplr(image), np.fliplr(mask)

    def _random_rotate(self, image, mask):
        """Rotate image and mask by the same random multiple of 90 degrees.

        Non-square patches are only rotated by 180 degrees, because a quarter
        turn would swap height and width and the patches of a batch could no
        longer be stacked.
        """
        is_square = image.shape[0] == image.shape[1]
        k = random.choice([1, 2, 3]) if is_square else 2  # 90, 180, or 270 degrees
        return np.rot90(image, k), np.rot90(mask, k)

    def _random_noise(self, image):
        """Add zero-mean Gaussian noise (sigma 0.05) and clip the result to [0, 1]."""
        # NOTE(review): the clip to [0, 1] only makes sense for images scaled to
        # that range; output of `_normalize` is presumably not - confirm before
        # enabling `noise_prob`.
        noise = np.random.normal(0, 0.05, image.shape)
        return np.clip(image + noise, 0, 1)

    def _normalize(self, image):
        """Standardize each channel (last axis) to zero mean and unit std; returns float32."""
        normalized = np.zeros_like(image, dtype=np.float32)
        for i in range(image.shape[2]):
            channel = image[:,:,i]
            mean = np.mean(channel)
            std = np.std(channel)
            normalized[:,:,i] = (channel - mean) / (std + 1e-8)  # adding small epsilon to avoid division by zero
        return normalized

    def prep(self, image):
        """Normalize one image and return it as a tensor permuted with (2, 0, 1) plus a leading axis of size 1."""
        return Tensor(self._normalize(image)).permute(2,0,1).unsqueeze(0)

    def split_image_into_chunks(self, image, chunk_size=64):
        """Split `image` into square tiles, ordered row by row.

        Height and width must be multiples of `chunk_size`; otherwise the
        reshape raises `ValueError`. Returns an array of shape
        `(n_chunks, chunk_size, chunk_size, channels)`.
        """
        height, width, channels = image.shape
        chunks_h = height // chunk_size
        chunks_w = width // chunk_size
        reshaped = image.reshape(chunks_h, chunk_size, chunks_w, chunk_size, channels)
        return reshaped.transpose(0, 2, 1, 3, 4).reshape(-1, chunk_size, chunk_size, channels)

    def image_to_model_input(self, image, chunk_size=64):
        """Normalize `image`, tile it, and return the tiles as a tensor permuted with (0, 3, 1, 2)."""
        chunks = self.split_image_into_chunks(self._normalize(image), chunk_size)
        return Tensor(chunks).permute(0,3,1,2)

    def synthesize_image_from_chunks(self, chunks, original_shape):
        """Inverse of `split_image_into_chunks`: reassemble row-ordered tiles.

        Args:
            chunks: array of shape `(n_chunks, chunk_size, chunk_size, channels)`.
            original_shape: `(height, width, channels)` of the image to rebuild.
        """
        height, width, channels = original_shape
        chunk_size = chunks.shape[1]  # Assuming chunks are square
        chunks_h = height // chunk_size
        chunks_w = width // chunk_size
        reshaped = chunks.reshape(chunks_h, chunks_w, chunk_size, chunk_size, channels)
        transposed = reshaped.transpose(0, 2, 1, 3, 4)
        return transposed.reshape(height, width, channels)

In [None]:
# Patch sampler over the images under data/auto_crop and the masks under data/mask.
dl = DataLoader(image_dir="data/auto_crop", mask_dir="data/mask", patch_size=(64, 64))

In [None]:
# Visual sanity check: show raw (un-normalized) patches next to their masks,
# but only for patches whose mask has at least one non-zero pixel.
dl.normalize = False
try:
    for a, b in zip(*dl.get_batch(8)):
        a = a.numpy().astype(np.uint8).transpose(1, 2, 0)
        b = b.numpy().astype(np.uint8) * 255
        if np.any(b > 0):
            display(Image.fromarray(a))
            display(Image.fromarray(b, mode="L"))
finally:
    # Always switch normalization back on, even if the preview fails,
    # so later training cells never see un-normalized input.
    dl.normalize = True

In [None]:
# Smoke test: build a fresh model and evaluate the loss on one small batch.
model = UNet()
X, Y = dl.get_batch(8)
pred = model(X)
s = pred.shape
# cross_entropy needs flat inputs: move axis 1 last, then make one row of s[1] values per pixel.
logits = pred.permute(0, 2, 3, 1).reshape(-1, s[1])
labels = Y.reshape(-1)
logits.cross_entropy(labels).item()

In [None]:
optim = nn.optim.Adam(nn.state.get_parameters(model))
batch_size = 128

def step(X=None, Y=None):
    """Run one optimizer step and return the loss tensor.

    Args:
        X, Y: image and mask tensors of one batch. When omitted, a fresh batch
            is drawn from `dl` (only meaningful when called outside the JIT).
    """
    Tensor.training = True
    if X is None or Y is None:
        X, Y = dl.get_batch(batch_size)
    optim.zero_grad()
    pred = model(X)
    s = pred.shape
    # Need to flatten for cross_entropy to work
    loss = pred.permute(0,2,3,1).reshape(-1, s[1]).cross_entropy(Y.reshape(-1)).backward()
    optim.step()
    return loss

# Only the tensor computation is jitted. TinyJit replays captured kernels and
# does not re-run the Python body, so a batch loaded inside the jitted function
# would be frozen after capture and training would reuse that one batch forever.
_jit_train_step = TinyJit(step)

def jit_step():
    """Draw a fresh batch in Python, then run the jitted training step on it."""
    X, Y = dl.get_batch(batch_size)
    # JIT inputs are passed as realized, contiguous tensors.
    return _jit_train_step(X.contiguous().realize(), Y.contiguous().realize())

In [None]:
# Training loop. The counter is deliberately not called `step`, which would
# overwrite the `step` training function defined in the previous cell.
for i in range(1000):
    loss = jit_step()
    if i % 10 == 0:
        # Every 10 iterations, report loss and pixel accuracy on a fresh random batch.
        Tensor.training = False
        X_test, Y_test = dl.get_batch(batch_size)
        acc = (model(X_test).argmax(axis=1) == Y_test).mean().item()
        print(f"step {i:4d}, loss {loss.item():.2f}, acc {acc*100.:.2f}%")

In [None]:
# Show predicted mask (first) and ground-truth mask (second) for a fresh batch.
x, y = dl.get_batch(10)

y_pred = model(x).argmax(axis=1).cast(dtypes.uint8).numpy()
y = y.cast(dtypes.uint8).numpy()
for predicted, expected in zip(y_pred, y):
    # To show only patches with a non-empty mask, guard with `np.any(expected > 0)`.
    for mask in (predicted, expected):
        display(Image.fromarray(mask * 255, mode="L"))
    print("---------------------------------")

In [None]:
# Save the model weights; create the target directory first so the cell also
# works on a fresh checkout where data/model does not exist yet.
model_name = "unet"
os.makedirs("data/model", exist_ok=True)
safe_save(get_state_dict(model), f"data/model/{model_name}.safetensors")

In [None]:
# Restore previously saved weights into the existing `model` object.
model_name = "unet"
weights_path = f"data/model/{model_name}.safetensors"
state_dict = safe_load(weights_path)
load_state_dict(model, state_dict)

In [None]:
def pad_to_square_multiple(array, square_size):
    """Zero-pad the bottom and right edges so height and width become multiples of `square_size`.

    Works for any array with at least two dimensions (e.g. `(h, w)` masks as
    well as `(h, w, c)` images); only the first two axes are padded.

    Args:
        array: NumPy array whose first two axes are height and width.
        square_size: positive integer tile size.

    Returns:
        A padded copy of `array`; dimensions that are already multiples of
        `square_size` are left at their size.

    Raises:
        ValueError: if `array` has fewer than two dimensions.
    """
    if array.ndim < 2:
        raise ValueError("array must have at least 2 dimensions (height, width)")
    h, w = array.shape[:2]
    # Integer arithmetic: distance from each size up to the next multiple (0 if already aligned).
    pad_h = -h % square_size
    pad_w = -w % square_size
    pad_width = ((0, pad_h), (0, pad_w)) + ((0, 0),) * (array.ndim - 2)
    return np.pad(array, pad_width, mode='constant')

In [None]:
# Ground-truth mask for sample 0, scaled by 255 for display as a grayscale image.
gt_mask = np.load("data/mask/0.npz")['data']
gt = Image.fromarray(gt_mask * 255, mode="L")
gt

In [None]:
# Whole-image inference: feed the padded image through the model in one piece.
raw = np.load("data/auto_crop/0/3.npz")['data']
# Alternative input: raw = Image.open("data/layout/raw/coast-1.png")
rgb = np.array(raw)[:, :, :3]  # keep only the first three channels
padded_image = pad_to_square_multiple(rgb, 64)
display(Image.fromarray(padded_image))

x = dl.prep(padded_image)
logits = model(x)
y = logits.argmax(axis=1).cast(dtypes.uint8).numpy().squeeze(0)
display(Image.fromarray(y * 255, mode="L"))

In [None]:
# Chunked inference: split the padded image into 64x64 tiles, predict each tile,
# then stitch the per-tile masks back into one image.
x = np.load("data/auto_crop/0/3.npz")['data']
# Alternative input: x = Image.open("data/layout/raw/coast-9ex.png")
# Slice to three channels first, then pad once (padding twice was redundant).
x = pad_to_square_multiple(np.array(x)[:,:,:3], 64)
original_shape = x.shape
x = dl.image_to_model_input(x, chunk_size=64)
# keepdim=True keeps the reduced axis; permute moves it last so each tile has a trailing axis of size 1.
y = model(x).argmax(axis=1, keepdim=True).cast(dtypes.uint8).permute(0,2,3,1).numpy()
y = dl.synthesize_image_from_chunks(y, (*original_shape[0:2], 1)).squeeze(-1)
display(Image.fromarray(y * 255, mode="L"))