In [None]:
import os, importlib, sys
from collections import defaultdict
from stream.client import draw_minimap, frames_to_map
import numpy as np
from PIL import Image

In [None]:
"""
manually captured frames of coast map in 4k
- the frames are grouped by key, where each key represents a separate instance of the layout
- first frame of a group is always centered at map entrance, we track origin position based on this first frame
- the frame groups are further grouped by layout ID, which was assigned by human inspection as ground truth

current approach to train a vision transformer:

1. for each set of N frames (frame 0 .. frame N-1), assemble N minimaps:
 {frame 0}, {frame 0 + frame 1}, ..., {frame 0 + .. + frame N-1}
 where {frame + frame} represents assembling a composite minimap from minimap slices in these frames
 
2. then extract minimap feature masks, crop to minimize surrounding blank space

3. each minimap mask is an input, output is softmax probabilities for K classes, where K is the number of unique layouts
"""

In [None]:
# Load 4k screenshots
# Resulting structure: layouts[layout_id][instance_key][frame_number] -> image array
layouts = {}
# os.listdir returns entries in arbitrary order; sorting keeps dict insertion order
# (and therefore any "first instance" selection in later cells) reproducible.
for layout in sorted(os.listdir("data/train")):
    data_dir = os.path.join("data/train", layout)
    # Skip stray entries (.DS_Store, .ipynb_checkpoints, loose files) that would
    # otherwise crash int(layout) or os.listdir(data_dir).
    if not os.path.isdir(data_dir):
        continue
    try:
        layout_id = int(layout)
    except ValueError:
        continue
    frames = defaultdict(dict)
    for file in sorted(os.listdir(data_dir)):
        if file.endswith(".png"):
            # File names have the form "<key>screenshot-<number>.png"
            key = file.split("screenshot-")[0]
            full_path = os.path.join(data_dir, file)
            number = int(file.split("screenshot-")[1][:-4])
            # Context manager releases the file handle once the pixels are copied
            with Image.open(full_path) as img:
                frames[key][number] = np.array(img)
    layouts[layout_id] = frames

In [None]:
# First extract middle of 4k frame
# player icon is at 1920, 1060, in 4k

In [None]:
box_radius = 600
frames = []
# Only the first instance of layout 0 is used, hence the break after one pass.
for instance, instance_frames in layouts[0].items():
    frame_ids = sorted(instance_frames)
    for frame_id in frame_ids:
        print(frame_id)
    # Collect the frames in ascending frame-number order
    frames.extend(instance_frames[fid] for fid in frame_ids)
    break

In [None]:
frames = np.array(frames)

In [None]:
frames.shape

In [None]:
# Crop a square of side 2*box_radius from every frame, centred on row 1060 / column 1920
# (the player icon position in a 4k frame). The four-axis index implies frames is
# laid out as (frame, row, column, channel).
cropped = frames[:, 1060-box_radius: 1060+box_radius, 1920-box_radius: 1920+box_radius, :]

In [None]:
cropped.shape

In [None]:
for i in range(cropped.shape[0]):
    #display(Image.fromarray(cropped[i]))
    pass

In [None]:
# Reload stream.client so edits to the module are picked up without restarting the kernel
client_module = sys.modules['stream.client']
importlib.reload(client_module)
# Re-bind the imported names; the earlier `from ... import` still points at the pre-reload objects
from stream.client import draw_minimap, frames_to_map

In [None]:
# stitch frames together, tracking origin
minimap, origin = frames_to_map(cropped)

In [None]:
minimap.shape

In [None]:
# Window of up to 200x200 pixels around the tracked origin of the stitched minimap
# NOTE(review): assumes origin is at least 100px from the top/left edge; a negative
# start index would wrap around instead of clamping — confirm.
ent = minimap[origin[0]-100:origin[0]+100, origin[1]-100:origin[1]+100, :]
ent.shape

In [None]:
Image.fromarray(ent)

In [None]:
from helpers import pad_to_square_multiple, shrink_image

In [None]:
# shrink and pad image

#Image.fromarray(minimap)
dims = minimap.shape[0:2]
# Target size is half of the longer side of the stitched minimap
max_dim_idx = dims.index(max(dims))
new_size = dims[max_dim_idx] // 2
# Scale the origin by the same factor (new_size / longer side)
shrunk_origin = tuple(int(x * new_size / max(dims)) for x in origin)
# presumably resizes so the longer side equals new_size, keeping aspect ratio — confirm in helpers
shrunk = shrink_image(minimap, new_size)
# Use mask to track origin position
# Single-pixel marker channel: 1 at the scaled origin, 0 everywhere else
mask = np.zeros((*shrunk.shape[0:2], 1))
mask[shrunk_origin] = 1
# Append the marker as an extra last channel so it is padded together with the image
shrunk = np.concatenate([shrunk, mask], axis=-1)
# presumably pads to a square whose side is a multiple of 32 — confirm in helpers
padded = pad_to_square_multiple(shrunk, 32)
# Recover the origin in padded coordinates by locating the marker pixel
# NOTE(review): channel index 3 assumes the minimap has exactly 3 colour channels — confirm
shrunk_origin = np.where(padded[..., 3] == 1)
shrunk_origin = tuple(int(x[0]) for x in shrunk_origin)
# Drop the marker channel and convert back to uint8 for display and inference
padded = padded[:,:,0:3].astype(np.uint8)

In [None]:
#Image.fromarray(padded)
#Image.fromarray(padded[shrunk_origin[0]-100:shrunk_origin[0]+100, shrunk_origin[1]-100:shrunk_origin[1]+100, :])

In [None]:
# Run inference to extract mask of minimap features

In [None]:
from models import AttentionUNet

In [None]:
# Model used to extract the mask of minimap features
model_name = "AttentionUNet_4"
model = AttentionUNet(model_name)
# presumably loads saved weights identified by model_name — confirm in models.AttentionUNet
model.load()

In [None]:
pred = model.batch_inference(padded, chunk_size=32)

In [None]:
# Show the padded minimap followed by the predicted mask, scaled by 255 for viewing
# NOTE(review): mode="L" expects 8-bit data; assumes pred is a uint8 array of 0/1 — confirm
display(Image.fromarray(padded))
display(Image.fromarray(pred * 255, mode="L"))

In [None]:
pred.shape

In [None]:
from scipy.ndimage import convolve

def crop_to_content(image):
    """Crop a 2-D mask to the bounding box of its pixels equal to 1.

    Args:
        image: 2-D array; pixels with value exactly 1 count as content.

    Returns:
        (cropped, (row_offset, col_offset)) where the offsets are the position
        of the crop's top-left corner in the input array.

    Raises:
        AssertionError: if no pixel equals 1.
    """
    content = np.argwhere(image == 1)
    assert len(content) > 0

    # Per-axis extremes of the content coordinates give the bounding box
    top, left = content.min(axis=0)
    bottom, right = content.max(axis=0)
    # +1 because the max coordinates are inclusive
    return image[top:bottom + 1, left:right + 1], (top, left)

def clean_sparse_pixels(image, threshold=3, neighborhood_size=3):
    """Zero out non-zero pixels that have too few non-zero neighbours.

    Args:
        image: 2-D mask (expected to hold 0/1 values).
        threshold: minimum neighbour sum a non-zero pixel needs to survive.
        neighborhood_size: side length of the square window used for counting.

    Returns:
        Array of the same shape as `image` with sparse pixels set to 0.
    """
    centre = neighborhood_size // 2
    # All-ones window with the centre zeroed so a pixel never counts itself
    weights = np.ones((neighborhood_size, neighborhood_size))
    weights[centre, centre] = 0
    # mode='constant' treats everything beyond the border as 0
    counts = convolve(image.astype(int), weights, mode='constant')
    # A pixel is dropped only if it is set AND under-supported
    isolated = (image != 0) & (counts < threshold)
    return image * ~isolated

In [None]:
# Drop mask pixels whose neighbour count inside a 40x40 window is below 20
clean = clean_sparse_pixels(pred, threshold=20, neighborhood_size=40)
# Trim to the bounding box of the remaining content; offsets is the (row, col)
# of the crop's top-left corner in the uncropped mask
clean, offsets = crop_to_content(clean)
display(Image.fromarray(padded))
display(Image.fromarray(clean * 255, mode="L"))

In [None]:
# Translate the origin into the cropped mask's coordinates by subtracting the crop offsets
clean_origin = tuple(int(val - offset) for val, offset in zip(shrunk_origin, offsets))

In [None]:
# clean_origin is (row, col), i.e. (y, x) — the same convention get_patches uses
y, x = clean_origin
# Clamp the window bounds at 0: after cropping to content the origin can be closer
# than 50px to the top/left edge, and negative slice bounds would wrap around and
# produce an empty or wrong window.
Image.fromarray(clean[max(y-50, 0):max(y+50, 0), max(x-50, 0):max(x+50, 0)] * 255, mode="L")

In [None]:
clean.shape

In [None]:
clean_origin

In [None]:
# Chunk the map into square patches, label each patch with y,x positions relative to origin
# We will use the y,x positions for token position embeddings
def get_patches(array, origin, ps=32):
    """Align a 2-D map to a patch grid anchored at `origin` and label every pixel.

    The map is trimmed so that it contains only whole `ps` x `ps` patches and
    `origin` falls on the top-left corner of patch (0, 0).

    Args:
        array: 2-D map.
        origin: (row, col) position of the origin inside `array`.
        ps: patch side length in pixels.

    Returns:
        Array of shape (rows, cols, 3): channel 0 is the map value, channel 1 the
        patch row index and channel 2 the patch column index, both relative to
        the origin patch.
    """
    assert len(array.shape) == 2
    height, width = array.shape
    oy, ox = origin
    # Whole patches that fit on each side of the origin
    n_up, n_left = oy // ps, ox // ps
    n_down, n_right = (height - oy) // ps, (width - ox) // ps
    window = array[oy - ps * n_up: oy + ps * n_down, ox - ps * n_left: ox + ps * n_right]

    # Per-pixel patch coordinates, shifted so the origin patch is (0, 0)
    rows, cols = np.indices(window.shape)
    patch_rows = rows // ps - n_up
    patch_cols = cols // ps - n_left

    return np.stack([window, patch_rows, patch_cols], axis=-1)

In [None]:
patches = get_patches(clean, clean_origin)

In [None]:
Image.fromarray(patches[:,:,0].astype(np.uint8) * 255, mode="L")

In [None]:
# Remove completely black patches
def get_tokens(patches, ps=32):
    """Split a labelled map into square patches and keep the non-empty ones.

    Args:
        patches: array of shape (rows, cols, channels); channel 0 holds the map
            values, remaining channels are carried along unchanged.
        ps: patch side length in pixels. Defaults to 32, the same default as
            get_patches, so existing calls behave as before.

    Returns:
        Array of shape (num_tokens, ps, ps, channels) with the patches, in
        row-major patch order, whose channel 0 has at least one value > 0.
        When nothing qualifies the result still has shape (0, ps, ps, channels),
        so downstream indexing such as tokens[:, 0, 0, 1] stays valid.
    """
    height, width, channels = patches.shape[0], patches.shape[1], patches.shape[2]
    y_patches, x_patches = height // ps, width // ps
    # Cut off any partial patches, then regroup pixels as (patch, row, col, channel)
    tiles = (
        patches[: y_patches * ps, : x_patches * ps]
        .reshape(y_patches, ps, x_patches, ps, channels)
        .transpose(0, 2, 1, 3, 4)
        .reshape(y_patches * x_patches, ps, ps, channels)
    )
    # Remove completely black patches
    keep = (tiles[..., 0] > 0).any(axis=(1, 2))
    return tiles[keep]

In [None]:
tokens = get_tokens(patches)

In [None]:
tokens.shape

In [None]:
for p in tokens:
    #display(Image.fromarray(p[:,:,0].astype(np.uint8) * 255, mode="L"))
    pass

In [None]:
# 2-D sinusoidal position embeddings built from each token's patch coordinates
dim = 256
# Each frequency uses four slots: sin/cos of the row plus sin/cos of the column
assert dim % 4 == 0
num_tokens = tokens.shape[0]
# get_patches stores the patch row (y) index in channel 1 and the patch column (x)
# index in channel 2; every pixel of a token shares them, so pixel (0, 0) is enough.
y_coords = tokens[:,0,0,1].reshape(num_tokens, 1)
x_coords = tokens[:,0,0,2].reshape(num_tokens, 1)
embeds = np.zeros((num_tokens, dim))
# Frequencies 10000^(-4k/dim) for k = 0 .. dim/4 - 1
denoms = np.exp(np.arange(0, dim, 4) / dim * -np.log(10000.0)).reshape(1, dim // 4)
# Slot layout is unchanged (channel 1 in slots 0/1, channel 2 in slots 2/3);
# only the variable names were corrected to match the row/column meaning.
embeds[:, 0::4] = np.sin(y_coords * denoms)
embeds[:, 1::4] = np.cos(y_coords * denoms)
embeds[:, 2::4] = np.sin(x_coords * denoms)
embeds[:, 3::4] = np.cos(x_coords * denoms)

In [None]:
# Reload the submodule before the package so `models` re-imports the fresh models.vit
for module in ('models.vit', 'models'):
    client_module = sys.modules[module]
    importlib.reload(client_module)
# Re-bind ViT to the reloaded class
from models import ViT

In [None]:
# Freshly constructed ViT; no weights are loaded in this cell. This rebinds `model`,
# which previously held the AttentionUNet instance.
# NOTE(review): the first argument (9) is presumably the number of layout classes — confirm against models.ViT
model = ViT(9, max_tokens=128, layers=3, embed_dim=256, num_heads=4)
#model_name = "AttentionUNet_1"
#model = ViT(model_name)
#model.train()

In [None]:
tokens.shape

In [None]:
y = model([tokens])

In [None]:
y = y.numpy().ravel()

In [None]:
y.shape

In [None]:
y.argmax()