In [None]:
from init_notebook import *
from typing import Literal
from src.models.fractal import FractalBaseLayer

In [None]:
def coord_grid(
    width: int = 256,
    height: int = 256,
    min_x: float = -2.,
    max_x: float = 2.,
    min_y: float = -2.,
    max_y: float = 2.,
) -> torch.Tensor:
    """
    Create a two-channel grid of pixel coordinates.

    :param width: number of samples along x (last dimension)
    :param height: number of samples along y (middle dimension)
    :param min_x: first x coordinate (inclusive)
    :param max_x: last x coordinate (inclusive)
    :param min_y: first y coordinate (inclusive)
    :param max_y: last y coordinate (inclusive)
    :return: tensor of shape (2, height, width); channel 0 holds the y
        coordinate of every pixel and channel 1 the x coordinate
    """
    grid_y, grid_x = torch.meshgrid(
        torch.linspace(min_y, max_y, height),
        torch.linspace(min_x, max_x, width),
        # "ij" is the historical default; passing it explicitly silences the
        # warning torch.meshgrid emits when `indexing` is omitted.
        indexing="ij",
    )
    return torch.stack([grid_y, grid_x])

In [None]:
class MandelbrotLayer(FractalBaseLayer):

    def __init__(
            self,
            param: Union[torch.Tensor, Iterable[float]] = (0, 0),
            axis: int = -1,
            iterations: int = 1,
            scale: Union[None, float, Iterable[float], torch.Tensor] = None,
            offset: Union[None, Iterable[float], torch.Tensor] = None,
            exponent: Optional[float] = None,
            mixer: Union[None, torch.Tensor, List[List[float]]] = None,
            learn_param: bool = False,
            learn_mixer: bool = False,
            learn_scale: bool = False,
            learn_offset: bool = False,
    ):
        """
        A layer that calculates the mandelbrot-set.

        Think of the input as coordinates: the `axis` dimension of the input
        holds two channels. Channel 1 is used as the real part and channel 0 as
        the imaginary part of the complex constant c. This matches `coord_grid`,
        whose channel 0 is y and channel 1 is x.

        The fractal output has three channels along `axis`: the mean tanh of
        the real part, of the imaginary part, and again of the real part of the
        orbit, scaled by 6 and shifted by 0.5 for display.

        :param param: complex starting value z_0 = param[0] + i * param[1];
            must have shape exactly (2,)
        :param axis: int, axis of the channels
        :param iterations: number of iterations
        :param scale: forwarded to FractalBaseLayer
        :param offset: forwarded to FractalBaseLayer
        :param exponent: accepted for signature compatibility, currently unused
        :param mixer: forwarded to FractalBaseLayer (presumably a (2, 2)
            channel-mixing matrix — confirm against FractalBaseLayer)
        :param learn_param: make `param` a trainable parameter
        :param learn_mixer: forwarded to FractalBaseLayer
        :param learn_scale: forwarded to FractalBaseLayer
        :param learn_offset: forwarded to FractalBaseLayer
        :raises ValueError: if `param` does not have shape (2,)
        """
        if not isinstance(param, torch.Tensor):
            # list() also accepts one-shot iterables such as generators
            param = torch.as_tensor(list(param), dtype=torch.get_default_dtype())
        if param.shape != torch.Size([2]):
            raise ValueError(f"Expected `param` to have shape (2,), got shape {param.shape}")

        super().__init__(
            num_channels=2,
            axis=axis,
            scale=scale,
            offset=offset,
            mixer=mixer,
            learn_mixer=learn_mixer,
            learn_scale=learn_scale,
            learn_offset=learn_offset,
        )
        self.iterations = iterations
        self.param = nn.Parameter(param, requires_grad=learn_param)

    def fractal(self, x: torch.Tensor, axis: int) -> torch.Tensor:
        """
        Iterate z <- z^2 + c and average the tanh of the finite orbit values.

        :param x: coordinate tensor with (assumed) two channels along `axis`
            — TODO confirm what FractalBaseLayer passes here
        :param axis: dimension of `x` holding the channels; negative values
            count from the end
        :return: tensor shaped like `x` but with three channels along `axis`
        """
        # select() indexes along `axis` itself, for positive and negative values
        c_re = x.select(axis, 1)
        c_im = x.select(axis, 0)
        c = torch.complex(c_re, c_im)

        # z_0 is built from the parameter tensor (not Python scalars) so that
        # gradients reach `self.param` when learn_param=True.
        ones = torch.ones_like(c_re)
        z = torch.complex(ones * self.param[0], ones * self.param[1])
        accum = torch.zeros_like(z)

        for _ in range(self.iterations):
            z = z ** 2 + c
            # skip diverged points (nan or inf in either component)
            mask = torch.isfinite(z)
            finite = z[mask]
            accum[mask] += torch.complex(torch.tanh(finite.real), torch.tanh(finite.imag))

        output = torch.stack([accum.real, accum.imag, accum.real], dim=axis)
        output = output / self.iterations
        return output * 6 + .5

# axis=-3 is the channel dimension of the (2, H, W) grid returned by coord_grid
model = MandelbrotLayer(param=(0, 0), iterations=150, axis=-3)
VF.to_pil_image(model(coord_grid(height=512, width=512)).clamp(0, 1))

In [None]:
s = 0.01  # side length of the square zoom window anchored at (x, y) = (0, -0.75)
VF.to_pil_image(
    model(
        coord_grid(
            width=512,
            height=512,
            min_x=0.0,
            max_x=0.0 + s,
            min_y=-0.75,
            max_y=-0.75 + s,
        )
    ).clamp(0, 1)
)

In [None]:
# NOTE: this cell redefines MandelbrotLayer, shadowing the class from an earlier cell.
class MandelbrotLayer(FractalBaseLayer):

    def __init__(
            self,
            axis: int = -1,
            iterations: int = 1,
            scale: Union[None, float, Iterable[float], torch.Tensor] = None,
            offset: Union[None, Iterable[float], torch.Tensor] = None,
            exponent: Optional[float] = None,
            mixer: Union[None, torch.Tensor, List[List[float]]] = None,
            learn_mixer: bool = False,
            learn_scale: bool = False,
            learn_offset: bool = False,
    ):
        """
        A layer that calculates the mandelbrot-set using real arithmetic.

        Think of the input as coordinates: the `axis` dimension of the input
        holds two channels. Channel 0 is the real part and channel 1 the
        imaginary part of the complex constant c. With `coord_grid` input,
        channel 0 is y, so the real axis runs vertically.

        The fractal output has the same two channels and holds the mean over
        the iterations of the orbit values z_1 ... z_n (no activation); entries
        that became nan/inf contribute 0 for that iteration.

        :param axis: int, axis of the channels
        :param iterations: number of iterations
        :param scale: forwarded to FractalBaseLayer
        :param offset: forwarded to FractalBaseLayer
        :param exponent: accepted for signature compatibility, currently unused
        :param mixer: forwarded to FractalBaseLayer (presumably a (2, 2)
            channel-mixing matrix — confirm against FractalBaseLayer)
        :param learn_mixer: forwarded to FractalBaseLayer
        :param learn_scale: forwarded to FractalBaseLayer
        :param learn_offset: forwarded to FractalBaseLayer
        """
        super().__init__(
            num_channels=2,
            axis=axis,
            scale=scale,
            offset=offset,
            mixer=mixer,
            learn_mixer=learn_mixer,
            learn_scale=learn_scale,
            learn_offset=learn_offset,
        )
        self.iterations = iterations

    def fractal(self, x: torch.Tensor, axis: int) -> torch.Tensor:
        """
        Iterate z <- z^2 + c on separate real/imaginary tensors.

        :param x: coordinate tensor with (assumed) two channels along `axis`
            — TODO confirm what FractalBaseLayer passes here
        :param axis: dimension of `x` holding the channels; negative values
            count from the end
        :return: tensor with the two channels along `axis` holding the mean
            real and imaginary orbit values
        """
        # select() indexes along `axis` itself, for positive and negative values
        c_re = x.select(axis, 0)
        c_im = x.select(axis, 1)

        z_re = torch.zeros_like(c_re)
        z_im = torch.zeros_like(c_im)
        accum = torch.zeros_like(x.narrow(axis, 0, 2))
        for _ in range(self.iterations):
            # (a + bi)^2 = (a^2 - b^2) + 2ab i; new tensors are created each
            # step (no in-place writes into z) so autograd stays valid.
            z_re, z_im = (
                z_re ** 2 - z_im ** 2 + c_re,
                2. * z_re * z_im + c_im,
            )
            z = torch.stack([z_re, z_im], dim=axis)
            mask = torch.isfinite(z)
            accum[mask] += z[mask]
        return accum / self.iterations

# Render the redefined MandelbrotLayer: 20 iterations on a 512x512 coordinate grid
model = MandelbrotLayer(iterations=20, axis=-3)
VF.to_pil_image(model(coord_grid(height=512, width=512)).clamp(0, 1))