In [None]:
"""
Adapted from https://github.com/milesial/Pytorch-UNet
and from https://github.com/tinygrad/tinygrad/blob/master/examples/stable_diffusion.py
"""

In [None]:
from tinygrad import Tensor
from tinygrad.nn import Conv2d, ConvTranspose2d, BatchNorm2d

In [None]:
def doubleconv(in_chan, out_chan):
    """Build the layer list for one U-Net "double convolution" stage.

    The list holds two repetitions of
    3x3 Conv2d (padding=1) -> BatchNorm2d -> ReLU.
    The first conv maps in_chan -> out_chan; the second keeps out_chan.
    The entries are plain callables meant to be applied in order.
    """
    layers = []
    for src_chan in (in_chan, out_chan):
        layers += [
            Conv2d(src_chan, out_chan, kernel_size=3, padding=1),
            BatchNorm2d(out_chan),
            Tensor.relu,
        ]
    return layers

class UNet:
    """Small two-level U-Net built from tinygrad layers.

    Layout (channel counts in parentheses):
      encoder:    doubleconv(in_channels, 64)            -> skip 1
                  max_pool2d, doubleconv(64, 128)        -> skip 2
      bottleneck: max_pool2d, doubleconv(128, 256), 2x2 stride-2 ConvTranspose2d (256 -> 128)
      decoder:    cat(skip 2, x), doubleconv(256, 128), 2x2 stride-2 ConvTranspose2d (128 -> 64)
                  cat(skip 1, x), doubleconv(128, 64), 1x1 Conv2d (64 -> out_channels)

    Input is expected as (N, C, H, W) with C == in_channels; skips are
    concatenated along dim=1.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 3):
        # Encoder stages; the output of each stage is kept as a skip connection.
        self.save_intermediates = [
            doubleconv(in_channels, 64),
            [Tensor.max_pool2d, *doubleconv(64, 128)],
        ]
        # Bottleneck: downsample, widen to 256 channels, then upsample 2x back to 128.
        self.middle = [
            Tensor.max_pool2d, *doubleconv(128, 256),
            ConvTranspose2d(256, 128, kernel_size=2, stride=2),
        ]
        # Decoder stages; each begins by concatenating the most recent unused skip.
        self.consume_intermediates = [
            [*doubleconv(256, 128), ConvTranspose2d(128, 64, kernel_size=2, stride=2)],
            [*doubleconv(128, 64), Conv2d(64, out_channels, kernel_size=1)],
        ]

    @staticmethod
    def _pad_to_match(x, ref):
        """Zero-pad the last two (spatial) dims of x so they equal those of ref.

        Max pooling floors odd sizes, so a 2x-upsampled decoder tensor can be
        one pixel smaller than its skip connection. The missing pixels are
        split between both sides, as Pytorch-UNet's Up block does.
        """
        dh = ref.shape[-2] - x.shape[-2]
        dw = ref.shape[-1] - x.shape[-1]
        if dh == 0 and dw == 0:
            return x
        leading = ((0, 0),) * (len(x.shape) - 2)
        return x.pad(leading + ((dh // 2, dh - dh // 2), (dw // 2, dw - dw // 2)))

    def __call__(self, x):
        """Run the network on x and return the result.

        H and W need not be multiples of 4: upsampled feature maps are padded
        to their skip connection's size before concatenation.
        """
        skips = []
        for stage in self.save_intermediates:
            for layer in stage:
                x = layer(x)
            skips.append(x)
        for layer in self.middle:
            x = layer(x)
        for stage in self.consume_intermediates:
            skip = skips.pop()  # deepest unused skip first (LIFO)
            x = skip.cat(self._pad_to_match(x, skip), dim=1)
            for layer in stage:
                x = layer(x)
        return x

In [None]:
# Construct the network defined above; all of its layers are created in UNet.__init__.
unet = UNet()

In [None]:
# Smoke test: run one random (N, C, H, W) batch through the model and check
# that the output shape matches the input shape.
x = Tensor.randn(1,3,100,100)
# tinygrad builds computations lazily; realize() forces the forward pass to
# actually execute instead of only checking shape inference.
y = unet(x).realize()
assert x.shape == y.shape, f"shape mismatch: input {x.shape} vs output {y.shape}"