In [None]:
from PIL import Image
import numpy as np
import os, glob
from dataloader import DataLoader
from models import UNet, AttentionUNet
from tinygrad import dtypes
from helpers import pad_to_square_multiple
#from training_data import clean_sparse_pixels, get_patches, get_tokens, extract_map_features
from training_data import tokenize_minimap, ViTDataLoader

In [None]:
# Patch-level loader: pairs images from data/auto_crop with masks from data/mask,
# requesting 64x64 patches.
PATCH_SIZE = (64, 64)
dl = DataLoader(image_dir="data/auto_crop", mask_dir="data/mask", patch_size=PATCH_SIZE)

# Compare raw data vs. desired map features (true mask)

In [None]:
# Show un-normalized input patches next to their ground-truth masks.
# Normalization is switched off only for this preview; the `finally` guarantees
# it is switched back on even if a display/conversion step raises, so later
# cells never silently receive un-normalized batches.
dl.normalize = False
try:
    for img_t, mask_t in zip(*dl.get_batch(16)):
        # Channel-first -> channel-last for PIL (transpose(1, 2, 0)).
        img = img_t.numpy().astype(np.uint8).transpose(1, 2, 0)
        # Scale the mask by 255 so positive pixels render white.
        true_mask = mask_t.numpy().astype(np.uint8) * 255
        # Skip patches whose mask has no positive pixels.
        if np.any(true_mask > 0):
            display(Image.fromarray(img))
            display(Image.fromarray(true_mask, mode="L"))
finally:
    dl.normalize = True

# Compare predicted mask vs. true mask

In [None]:
# `model` was not bound by any earlier cell, so this cell raised NameError on a
# fresh kernel (Restart & Run All). Load it explicitly here.
# NOTE(review): presumably the same checkpoint used by the cells further down
# ("AttentionUNet8_8600", depth=3) - confirm this is the model meant for the
# patch-level comparison.
model = AttentionUNet("AttentionUNet8_8600", depth=3).load()

x, y = dl.get_batch(10)

# Class index per pixel (argmax over axis 1), converted to uint8 numpy for PIL.
y_pred = model(x).argmax(axis=1).cast(dtypes.uint8).numpy()
y = y.cast(dtypes.uint8).numpy()

for a, b in zip(y_pred, y):
    # Only show patches whose true mask has at least one positive pixel.
    if np.any(b > 0):
        display(Image.fromarray(a * 255, mode="L"))  # prediction
        display(Image.fromarray(b * 255, mode="L"))  # ground truth
        print("---------------------------------")

In [None]:
import glob, time
from IPython.display import clear_output

In [None]:
# Every PNG under data/train/<a>/<b>/ except the "*_mask.png" companions, sorted by path.
all_pngs = set(glob.glob("data/train/*/*/*.png"))
mask_pngs = set(glob.glob("data/train/*/*/*_mask.png"))
layouts = sorted(all_pngs - mask_pngs)

In [None]:
# Number of layout images found (the "*_mask.png" files were excluded when building `layouts`).
len(layouts)

In [None]:
# Checkpoints to compare, keyed by name; each is an AttentionUNet with depth=3.
checkpoint_names = ["AttentionUNet8_8600"]
models = {ckpt: AttentionUNet(ckpt, depth=3) for ckpt in checkpoint_names}

# Load every registered model in place.
for model in models.values():
    model.load()

In [None]:
# `clean_sparse_pixels` is called below, but its import at the top of the
# notebook is commented out, so this cell raised NameError on a fresh kernel.
# NOTE(review): presumably it still lives in `training_data` (that is where the
# commented-out import took it from) - confirm.
from training_data import clean_sparse_pixels

layout_n = 4  # number of layouts shown per group
x = 8         # index of the group to show (earlier note listed 8, 9, 29, 32)

first = layout_n * x
for layout in layouts[first:first + layout_n]:
    # Read the pixels inside a `with` so the file handle is closed right away
    # instead of lingering until garbage collection.
    with Image.open(layout) as img:
        test = pad_to_square_multiple(np.array(img), 32)

    for name, model in models.items():
        print(name)
        pred = model.batch_inference(test, chunk_size=32)
        # Post-process the predicted mask before display.
        pred = clean_sparse_pixels(pred, threshold=10, neighborhood_size=15)
        display(Image.fromarray(pred * 255, mode="L"))

In [None]:
#Image.fromarray(clean_sparse_pixels(pred, threshold=10, neighborhood_size=15) * 255, mode="L")

# Recompute layout masks for vit training

In [None]:
# Rebuild the list of layout images (all PNGs minus their "*_mask.png" companions) ...
png_paths = set(glob.glob("data/train/*/*/*.png"))
mask_paths = set(glob.glob("data/train/*/*/*_mask.png"))
layouts = sorted(png_paths - mask_paths)

# ... and load the segmentation model that is passed to tokenize_minimap below.
model = AttentionUNet("AttentionUNet8_8600", depth=3).load()

In [None]:
# Tokenize the first layout only (the slice already limits the loop to one item).
for layout in layouts[:1]:
    # Close the image file as soon as its pixels are copied into an array.
    with Image.open(layout) as img:
        minimap = np.array(img)

    wd = os.path.dirname(layout)
    num = os.path.splitext(os.path.basename(layout))[0]

    # The origin is stored next to the layout as "<num>_origin.npz" under key 'data'.
    # Using the NpzFile as a context manager closes the archive after the read.
    with np.load(os.path.join(wd, f"{num}_origin.npz")) as npz:
        origin = npz['data']

    tokens, mask = tokenize_minimap(minimap, origin, model)
    display(Image.fromarray(mask * 255, mode="L"))

In [None]:
# Inspect the origin array loaded from "<num>_origin.npz" for the layout tokenized above.
origin

In [None]:
# Show a window of up to 200x200 pixels of the minimap centred on the origin.
# The lower bounds are clamped to 0: an unclamped `origin[k] - 100` goes negative
# (or wraps, for unsigned dtypes) when the origin is within 100 px of the border,
# which yields an empty or wrong crop. Upper bounds past the edge are already
# handled by slicing.
# NOTE(review): the `* 255` assumes `minimap` holds 0/1 values in a 2-D uint8
# array - confirm.
half_window = 100
row, col = int(origin[0]), int(origin[1])
row_start, col_start = max(row - half_window, 0), max(col - half_window, 0)
Image.fromarray(minimap[row_start: row + half_window, col_start: col + half_window] * 255, mode="L")

In [None]:
# Display the "<num>_mask.png" stored alongside every layout image.
for layout in layouts:
    wd, filename = os.path.split(layout)
    num, _ext = os.path.splitext(filename)
    mask = Image.open(os.path.join(wd, num + "_mask.png"))
    display(mask)

# Simulate player exploration of layout

The layout has been divided up into 2D square patch tokens

To train the ViT, we randomly sample tokens, with the sampling biased toward tokens nearer the entrance.

In [None]:
# Decide how many tokens have been seen.
# Skew heavily toward smaller numbers to focus training on sample sizes useful to the player,
# because the player wants to know the layout ASAP after entering, with minimal tokens.
num_samples = 8 + np.random.beta(1.3, 1.3 * 3, size=1) * 120
num_samples = np.round(num_samples).astype(np.uint32)

# Filter out tokens that are too far from the origin to be reached within the
# number of tokens seen (threshold applied to the last element of each token).
filtered = tokens[tokens[:, -1, -1, -1] <= num_samples**2]
num_filtered = filtered.shape[0]

# A small layout can have fewer reachable tokens than the drawn sample count.
# Clamp BEFORE subtracting: `num_filtered - num_samples` is evaluated in uint32,
# so a negative difference would wrap around to ~4e9, producing a huge sampling
# window (and out-of-range indices) instead of an error.
num_samples = np.minimum(num_samples, num_filtered)

# From the filtered set of tokens we could theoretically have traveled to,
# randomly sample tokens, skewed toward being close to the origin (entrance).
#
# In theory, if we traveled in a straight line from the origin and encountered map
# features there, we allow those farthest map features to be sampled only if we get
# the max value from this beta distribution. If we draw the min value, we sample the
# `num_samples` closest tokens to the origin.
#
# The alpha/beta params below simulate typical exploration, which is rarely a
# perfect straight line from the origin.
diff_samples = (num_filtered - num_samples) * np.random.beta(2, 2 * 3, size=1)
sample_pools = np.round(num_samples + diff_samples).astype(np.uint32)

print(f"num tokens: {num_samples[0]}")
print(f"filtered token limit: {filtered.shape[0]}")
print(f"total tokens: {tokens.shape[0]}")
print(f"max patch length: {(tokens[-1, -1, -1, -1])}")
display(sample_pools)
print()

for max_token_idx in sample_pools:
    print(f"max_token_idx: {max_token_idx}")
    # Sample without replacement inside the window [0, max_token_idx).
    sel = np.random.choice(max_token_idx, size=num_samples, replace=False)
    print(sorted([int(i) for i in sel]))
    # NOTE(review): `sel` indexes `tokens`, not `filtered`; the two agree only if
    # `tokens` is ordered by ascending distance - confirm.
    print(sorted([int(d) for d in tokens[sel, -1, -1, -1]]))
    print()

In [None]:
# Shape of the token subset picked by the most recent sampling draw (`sel`).
tokens[sel].shape

In [None]:
# Size of the first sampling window computed in the sampling cell.
sample_pools[0]

In [None]:
# Draw a fresh set of `num_samples` distinct indices from the first sampling window.
np.random.choice(sample_pools[0], size=num_samples, replace=False)

In [None]:
# Shape of the distance-filtered token array.
filtered.shape

In [None]:
# Shape of the full token array returned by tokenize_minimap.
tokens.shape

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import beta

# Parameters
# With beta = 3 * alpha the mean is alpha / (alpha + beta) = 1/4 = 0.25 for any
# alpha (the earlier "mean = 0.2" label was incorrect).
# Alternatives tried previously: alpha = 1.3 with beta = 3 * alpha; beta = 1.08.
alpha = 2
beta_param = 3 * alpha
beta_mean = alpha / (alpha + beta_param)

# Define the domain
x = np.linspace(0, 1, 1000)

# Compute the PDF of the Beta distribution
pdf = beta.pdf(x, alpha, beta_param)

# Plot the distribution; the title reports the mean computed from the
# parameters so it cannot drift out of sync with them.
plt.figure(figsize=(8, 5))
plt.plot(x, pdf, label=f'Beta PDF (α={alpha}, β={beta_param})', color='blue')
plt.title(f'Skewed Beta Distribution with Mean = {beta_mean:g}')
plt.xlabel('x')
plt.ylabel('Probability Density')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Draw ten values from Beta(alpha, 1.5 * alpha) and list them in ascending order.
alpha = 2
beta_param = alpha * 1.5
samples = np.random.beta(a=alpha, b=beta_param, size=10)
sorted(samples.tolist())

# Visualize layout samples used in training

In [None]:
# Rebinds `dl` (previously the patch-level DataLoader) to a ViTDataLoader over data/train.
dl = ViTDataLoader(data_dir="data/train")

In [None]:
X, Y = dl.get_training_data(max_tokens=128)
x = X[-1]  # visualize the last sample of the batch
print(x.shape)

# `ps` (patch side length in pixels) was never defined anywhere in the notebook,
# so this cell raised NameError. Derive it from the token array itself: each
# token's first channel is pasted into a ps x ps canvas cell in the next cell,
# so ps must equal the patch height/width.
assert x.shape[1] == x.shape[2], "patch tokens are expected to be square"
ps = x.shape[1]

# Shift the grid coordinates stored in channels 1 and 2 so the smallest
# row/column becomes 0; channel 0 is left untouched.
sub = np.zeros(x.shape, dtype=np.int64)
sub[:, :, :, 1] = x[:, 0, 0, 1].min()
sub[:, :, :, 2] = x[:, 0, 0, 2].min()
x = x - sub

# Largest shifted grid row/column index; the canvas needs one extra patch in
# each direction to hold the patch located at that maximum index.
h, w = int(x[:, 0, 0, 1].max()), int(x[:, 0, 0, 2].max())
canv = np.zeros((ps + h * ps, ps + w * ps), dtype=np.uint8)

In [None]:
# Paste channel 0 of every patch into the canvas cell addressed by the grid
# position stored in channels 1 (row) and 2 (column), then render the result.
for p in x:
    row, col = p[0, 0, 1], p[0, 0, 2]
    top, left = row * ps, col * ps
    canv[top:top + ps, left:left + ps] = p[:, :, 0].astype(np.uint8)
Image.fromarray(canv * 255, mode="L")