In [None]:
from init_notebook import *
from typing import Literal
from src.models.fractal import KaliSetLayer

In [None]:
def coord_grid(
    width: int = 256,
    height: int = 256,
    min_x: float = -1.,
    max_x: float = 1.,
    min_y: float = -1.,
    max_y: float = 1.,
    z: float = 1.,
):
    """Build a 3-channel coordinate image of shape ``(3, height, width)``.

    Channel 0 holds the y coordinate (``min_y`` in the top row up to ``max_y``
    in the bottom row), channel 1 the x coordinate (``min_x`` in the left
    column up to ``max_x`` in the right column) and channel 2 the constant
    ``z`` everywhere.
    """
    ys = torch.linspace(min_y, max_y, height)
    xs = torch.linspace(min_x, max_x, width)
    # "ij" indexing: first output varies along rows (y), second along columns (x)
    yy, xx = torch.meshgrid(ys, xs, indexing="ij")
    zz = torch.ones(height, width) * z
    return torch.stack((yy, xx, zz))

In [None]:

# Compare every accumulation mode of KaliSetLayer on the default (z=1) grid;
# make_grid lays the six renders out three per row.
grid = []
for accum in ("none", "mean", "max", "min", "submin", "alternate"):
    model = KaliSetLayer(
        (.5, .6, .7),
        axis=0,
        iterations=7,
        exponent=1.,
        accumulate=accum,
    )
    grid.append(
        model(coord_grid()).clamp(0, 1)
    )
VF.to_pil_image(make_grid(grid, nrow=3))

In [None]:
# Sweep `scale` x `offset` on a z=0 grid. With nrow=3 and three offsets in the
# inner loop, each row of the output corresponds to one `scale` value.
grid = []
for scale in (None, .5, 2.):
    for offset in (None, (.5, 0, 0), (0, .5, 0)):
        model = KaliSetLayer(
            (.5, .6, .7),
            axis=0,
            iterations=7,
            exponent=1.,
            scale=scale,
            offset=offset,
        )
        grid.append(
            model(coord_grid(z=0.)).clamp(0, 1)
        )
VF.to_pil_image(make_grid(grid, nrow=3))

In [None]:
# Select the compute device via the init_notebook helper and show it.
# NOTE(review): "auto" presumably picks a GPU when one is available — confirm in init_notebook.
device = to_torch_device("auto")
device

In [None]:
# Photo target: load as RGB, rescale by factor .3 (init_notebook `resize`
# helper), crop a 150x120 (w x h) patch, preview it, then convert to a float
# tensor on `device`.
# Alternative source used earlier: "/home/bergi/Pictures/DSCN0010.jpg"
# NOTE(review): absolute local path — not portable to other machines.
target_image = PIL.Image.open(
    "/home/bergi/Pictures/__diverse/Screenshot_2025-03-18_16-21-29.png"
).convert("RGB")
target_image = resize(target_image, .3).crop((100, 150, 250, 270))
display(target_image)
target_image = VF.to_tensor(target_image).to(device)
print(target_image.shape)

In [None]:
# Synthetic target: draw "YO!" (default ink, 120 px Noto Sans ExtraBold) onto a
# black 200x120 RGB canvas, preview it, then convert to a tensor on `device`.
target_image2 = PIL.Image.new("RGB", (200, 120))
draw = PIL.ImageDraw.Draw(target_image2)
draw.text(
    (0, -20),
    "YO!",
    font=PIL.ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoSans-ExtraBold.ttf", 120),
)
display(target_image2)
target_image2 = VF.to_tensor(target_image2).to(device)

In [None]:
class MultiModel(nn.Module):
    """Sum of several randomly configured, fully learnable ``KaliSetLayer`` instances.

    Each layer's ``param``, ``iterations``, ``accumulate`` mode and ``offset``
    are drawn from a private ``random.Random(seed)``, so the same ``seed``
    always yields the same initial configuration.

    Args:
        num: number of layers to create.
        seed: seed for the configuration RNG.
    """

    def __init__(self, num: int = 4, seed: int = 23):
        super().__init__()
        self.rng = random.Random(seed)
        self.models = nn.ModuleList()
        # The order of the rng draws below fixes the configuration for a given
        # seed; keep it unchanged to stay reproducible.
        for _ in range(num):
            self.models.append(KaliSetLayer(
                param=tuple(self.rng.uniform(.1, .9) for _ in range(3)),
                # -3 is the channel axis of the (3, H, W) grids from coord_grid
                axis=-3,
                iterations=3 + 2 * self.rng.randrange(5),  # odd: 3, 5, ..., 11
                accumulate=self.rng.choice(KaliSetLayer.ACCUMULATION_TYPES),
                offset=tuple(self.rng.uniform(-.5, .5) for _ in range(3)),
                learn_param=True,
                learn_mixer=True,
                learn_offset=True,
                learn_scale=True,
            ))

    def forward(self, x, num: Optional[int] = None):
        """Return the element-wise sum of the selected layers' outputs.

        Args:
            x: input passed unchanged to every selected layer.
            num: use only ``self.models[:num]``; ``None`` uses all layers.

        Raises:
            ValueError: if the selection contains no layers (e.g. ``num=0``).
        """
        models = self.models if num is None else self.models[:num]
        if len(models) == 0:
            raise ValueError(
                f"no sub-models selected (num={num!r}, available={len(self.models)})"
            )
        # Out-of-place addition on purpose: ``y += o`` would modify the first
        # layer's output tensor in place, which breaks autograd if that layer's
        # backward needs it, and mutates ``x`` itself if a layer returns its input.
        y = models[0](x)
        for m in models[1:]:
            y = y + m(x)
        return y

def train_model(
    model,
    t_image,
    steps: int = 10000,
    lr: float = 0.01,
    grow_every: int = 1000,
    learn_image: bool = False,
):
    """Fit ``model`` so that it maps a coordinate grid onto ``t_image``.

    The input is a ``coord_grid`` with the target's width/height. Sub-models
    are enabled progressively: training starts with one and adds one every
    ``grow_every`` steps until all of ``model.models`` are active. A preview
    (input grid, output, target) is displayed every ``grow_every`` steps, and a
    final render on the default 256x256 ``coord_grid`` is displayed at the end.
    Interrupting the kernel stops training early; the model keeps the
    parameters reached so far and the final render is still shown.

    Args:
        model: module with a ``models`` list and ``forward(x, num=...)``
            (e.g. ``MultiModel``); it is moved to the global ``device`` and
            trained in place.
        t_image: target image tensor, presumably (3, H, W) in [0, 1] — confirm.
        steps: total number of optimisation steps (passed to the LR scheduler).
        lr: AdamW learning rate.
        grow_every: positive number of steps between enabling another sub-model
            (also the preview interval).
        learn_image: also optimise the input coordinate grid.
    """
    from src.scheduler import CosineAnnealingWarmupLR

    model.to(device)
    t_image = t_image.to(device)
    src = coord_grid(t_image.shape[-1], t_image.shape[-2]).to(device)
    # requires_grad stays on so the input gradient can be reported even when
    # the grid itself is not optimised.
    image_param = nn.Parameter(src, requires_grad=True)

    params = [*model.parameters()]
    if learn_image:
        params.append(image_param)
    optimizer = torch.optim.AdamW(params, lr=lr)
    scheduler = CosineAnnealingWarmupLR(optimizer, steps, warmup_steps=50)
    loss_func = nn.L1Loss()

    num_models = 1
    try:
        with tqdm(range(scheduler.T_max), ncols=115) as progress:
            for i in progress:
                output = model(image_param, num=num_models)
                if i % grow_every == 0:
                    # detach: the preview must not extend the autograd graph
                    display(VF.to_pil_image(
                        make_grid([image_param.detach(), output.detach(), t_image]).clamp(0, 1)
                    ))
                    # skip step 0 so the first `grow_every` steps really train a single sub-model
                    if i > 0 and num_models < len(model.models):
                        num_models += 1
                loss = loss_func(output, t_image)
                optimizer.zero_grad()
                # image_param is only in the optimizer when learn_image=True, so
                # zero_grad() alone would let its .grad accumulate across steps.
                image_param.grad = None
                loss.backward()
                optimizer.step()
                scheduler.step()
                grad_max = image_param.grad.max().item()
                progress.set_postfix({"lr": scheduler.get_last_lr()[0], "grad_max": grad_max, "loss": loss.item()})
    except KeyboardInterrupt:
        # Deliberate: allow stopping training early and still show the final render.
        pass

    with torch.no_grad():
        final = model(coord_grid().to(device))
    display(VF.to_pil_image(final.clamp(0, 1)))

In [None]:
# Fit eight random Kali-set layers to the photo crop. train_model moves and
# trains the module in place, so keep a module-level handle to inspect or
# render the trained model afterwards (instead of a stale earlier `model`).
model = MultiModel(8, seed=1001)
train_model(
    model,
    resize(target_image, 1),
)

In [None]:
# Fit two random Kali-set layers to the "YO!" text target. Bound as `model`
# (trained in place by train_model) so the render cell below shows this
# trained model rather than a stale `model` left over from an earlier cell.
model = MultiModel(2, seed=123)
train_model(
    model,
    resize(target_image2, 1),
)

In [None]:
# Render the module-level `model` (whatever an earlier cell bound last) on the
# default 256x256 grid. `.to(device)` guards against `model` still living on
# the CPU while the input is on `device`; no_grad skips building an autograd
# graph for a pure preview.
with torch.no_grad():
    rendered = model.to(device)(coord_grid().to(device))
VF.to_pil_image(rendered.clamp(0, 1))