In [None]:
from PIL import Image
import numpy as np
import os
from dataloader import DataLoader
from models import UNet, AttentionUNet
from tinygrad import dtypes
from helpers import pad_to_square_multiple

In [None]:
# Notebook-wide data loader: reads images from data/auto_crop and masks from
# data/mask, and is configured with a (64, 64) patch size.
# NOTE(review): the 64 here presumably has to agree with the multiple passed to
# pad_to_square_multiple in the full-image inference cell — confirm.
dl = DataLoader(
    image_dir="data/auto_crop",
    mask_dir="data/mask",
    patch_size=(64,64),
)

# Compare raw data vs. desired map features (true mask)

In [None]:
# Preview a batch of 16 un-normalized image patches next to their true masks.
# Normalization is disabled only for the duration of this preview. The flag is
# reset in `finally` so that an exception anywhere in the loop (batch loading,
# dtype conversion, display) cannot leave the shared loader un-normalized for
# the cells that run afterwards.
dl.normalize = False
try:
    for a, b in zip(*dl.get_batch(16)):
        # Move axis 0 to the end so PIL can consume the array
        # (assumes the loader yields channel-first images — TODO confirm).
        a = a.numpy().astype(np.uint8).transpose(1, 2, 0)
        # Scale mask values by 255 for an 8-bit grayscale image
        # (assumes a 0/1 mask; larger labels would wrap in uint8 — TODO confirm).
        b = b.numpy().astype(np.uint8) * 255
        # Only show pairs whose mask contains at least one non-zero pixel.
        if np.any(b > 0):
            display(Image.fromarray(a))
            display(Image.fromarray(b, mode="L"))
finally:
    dl.normalize = True

# Train UNet to extract map features (mask) from raw screenshots

In [None]:
# Create a UNet under the name "UNet_3" and call its train() method.
# NOTE(review): train() is defined in models.py; judging by the section heading
# it runs the full training routine rather than toggling a mode flag — confirm.
# This cell binds the notebook-level `model_name` and `model`; the following
# cells rebind the same names, so whichever of them ran last decides what
# `model` refers to further down.
model_name = "UNet_3"
model = UNet(model_name)
model.train()

In [None]:
# Create an AttentionUNet under the name "AttentionUNet_1" and call its train()
# method (presumably the full training routine — confirm in models.py).
# NOTE(review): this overwrites the `model_name` / `model` bound by the UNet
# cell above; run only the training cell for the architecture you want.
model_name = "AttentionUNet_1"
model = AttentionUNet(model_name)
model.train()

In [None]:
# Load saved model if training was already done.
# Rebinds the notebook-level `model` to the result of UNet.load("UNet_3"), so
# on a top-to-bottom run this is the model the comparison cell below calls.
# NOTE(review): where UNet.load looks for the saved weights is defined in
# models.py and not visible here.
model_name = "UNet_3"
model = UNet.load(model_name)

In [None]:
# Registry of saved models keyed by their saved name. Each entry pairs a name
# with the `load` classmethod of the architecture it was trained with, so the
# dict key and the name passed to `load` cannot drift apart.
models = {
    saved_name: load_fn(saved_name)
    for saved_name, load_fn in (
        ("UNet_3", UNet.load),
        ("AttentionUNet_1", AttentionUNet.load),
    )
}

# Compare predicted mask vs. true mask

In [None]:
# Draw a batch of 10 samples and run the currently bound `model` on it.
x, y = dl.get_batch(10)

# Arg-max over axis 1 of the model output picks one label per position; both
# the prediction and the target are then cast to uint8 and pulled into numpy.
y_pred = model(x).argmax(axis=1).cast(dtypes.uint8).numpy()
y = y.cast(dtypes.uint8).numpy()

for predicted_mask, true_mask in zip(y_pred, y):
    # Skip samples whose true mask is entirely zero.
    if not np.any(true_mask > 0):
        continue
    # Predicted mask first, true mask second, then a separator line.
    display(Image.fromarray(predicted_mask * 255, mode="L"))
    display(Image.fromarray(true_mask * 255, mode="L"))
    print("---------------------------------")

In [None]:
# Run every model in `models` on one full screenshot and show each predicted mask.
# The .npz archive is opened in a `with` block so its file handle is closed as
# soon as the "data" array has been read and padded.
# pad_to_square_multiple(..., 64) presumably pads the image to a square whose
# side is a multiple of 64 — confirm in helpers.py.
with np.load("data/auto_crop/3/3.npz") as archive:
    test = pad_to_square_multiple(archive["data"], 64)
display(Image.fromarray(test))

# The loop variable is deliberately NOT called `model`: reusing that name would
# rebind the notebook-level `model` that the batch-comparison cell calls, so
# re-running that cell afterwards would silently use the last entry of `models`.
for name, candidate in models.items():
    print(name)
    pred = candidate.batch_inference(test)
    # Scale the predicted labels by 255 for an 8-bit grayscale image
    # (assumes batch_inference returns a 0/1 uint8 array — TODO confirm).
    display(Image.fromarray(pred * 255, mode="L"))