In [None]:
import os, importlib, sys
from collections import defaultdict
from stream.client import draw_minimap, frames_to_map
import numpy as np
from PIL import Image

In [None]:
"""
manually captured frames of coast map in 4k
- the frames are grouped by key, where each key represents a separate instance of the layout
- first frame of a group is always centered at map entrance, we track origin position based on this first frame
- the frame groups are further grouped by layout ID, which was assigned by human inspection as ground truth

current approach to train a vision transformer:

1. for each set of N frames, assemble N cumulative minimaps:
 {frame 0}, {frame 0 + frame 1}, ..., {frame 0 + ... + frame N-1}
 where {frame + frame} represents assembling a composite minimap from minimap slices in these frames
 
2. then extract minimap feature masks, crop to minimize surrounding blank space

3. each minimap mask is an input, output is softmax probabilities for K classes, where K is the number of unique layouts
"""

In [None]:
# Load the 4k screenshots into a nested mapping:
#   layouts[layout_id][instance_key][frame_number] -> image array
# Files are expected at data/train/<layout_id>/<instance_key>screenshot-<frame_number>.png
layouts = {}
# Sorted so that insertion order (and therefore any "first instance" selection
# made from these dicts) does not depend on the OS directory ordering.
for layout in sorted(os.listdir("data/train")):
    data_dir = os.path.join("data/train", layout)
    if not os.path.isdir(data_dir):
        # Stray files (e.g. .DS_Store) are not layout directories.
        continue
    frames = defaultdict(dict)
    for file in sorted(os.listdir(data_dir)):
        if file.endswith(".png"):
            # Everything before "screenshot-" identifies the layout instance,
            # the digits after it (minus ".png") are the frame number.
            key = file.split("screenshot-")[0]
            full_path = os.path.join(data_dir, file)
            number = int(file.split("screenshot-")[1][:-4])
            # Image.open keeps the file handle open; copy the pixels out and
            # close it right away instead of leaking one handle per frame.
            with Image.open(full_path) as image:
                frames[key][number] = np.array(image)
    layouts[int(layout)] = frames

In [None]:
# First extract the middle of each 4k frame
# the player icon is at x=1920, y=1060 in a 4k frame

In [None]:
# Half the side length, in pixels, of the square later cropped around the player icon.
box_radius = 600

# Gather the frames of a single instance of layout 0, ordered by frame number.
frames = []
instance = next(iter(layouts[0]), None)  # only the first instance is used
if instance is not None:
    frame_ids = sorted(layouts[0][instance])
    for frame_id in frame_ids:
        print(frame_id)
        frames.append(layouts[0][instance][frame_id])

In [None]:
# Convert the list of per-frame arrays into one NumPy array so all frames can be
# sliced together. Note this rebinds `frames` from a list to an ndarray.
# NOTE(review): presumably every frame has the same shape, giving a regular
# 4-D array — confirm.
frames = np.array(frames)

In [None]:
# Sanity check: show the shape of the stacked frames array.
frames.shape

In [None]:
# Pixel position of the player icon inside a 4k frame, as (row, column).
player_row, player_col = 1060, 1920
row_window = slice(player_row - box_radius, player_row + box_radius)
col_window = slice(player_col - box_radius, player_col + box_radius)
# Square window of side 2 * box_radius around the player icon, for every frame,
# keeping all channels.
cropped = frames[:, row_window, col_window, :]

In [None]:
# Sanity check: show the shape of the cropped frames.
cropped.shape

In [None]:
# Optional visual check of each cropped frame. The display call is currently
# commented out, so this loop does nothing; uncomment it to preview the crops.
for i in range(cropped.shape[0]):
    #display(Image.fromarray(cropped[i]))
    pass

In [None]:
# Re-execute the already-imported stream.client module so that edits made to it
# on disk are picked up without restarting the kernel.
client_module = sys.modules['stream.client']
importlib.reload(client_module)
# Re-import so these names are bound to the reloaded definitions rather than
# the ones imported at the top of the notebook.
from stream.client import draw_minimap, frames_to_map

In [None]:
# stitch frames together, tracking origin
minimap, origin = frames_to_map(cropped)

In [None]:
# Sanity check: show the shape of the stitched minimap.
minimap.shape

In [None]:
# Cut out a window of up to 200x200 pixels of the stitched minimap around the origin.
ent_radius = 100
# Clamp the window start at 0: a negative slice start would be interpreted
# relative to the END of the axis and yield an empty or unrelated crop when the
# origin lies closer than `ent_radius` to the top/left edge. (An end index past
# the array bounds is already truncated by NumPy.)
ent_row_start = max(origin[0] - ent_radius, 0)
ent_col_start = max(origin[1] - ent_radius, 0)
ent = minimap[ent_row_start:origin[0] + ent_radius, ent_col_start:origin[1] + ent_radius, :]
ent.shape

In [None]:
# Render the crop around the origin inline.
Image.fromarray(ent)

In [None]:
from helpers import pad_to_square_multiple, shrink_image

In [None]:
# Shrink and pad the stitched minimap while keeping track of the origin.

dims = minimap.shape[0:2]
max_dim_idx = dims.index(max(dims))
longest_side = dims[max_dim_idx]
# Target size is half of the longest side.
new_size = longest_side // 2
# Scale the origin coordinates by the same factor as the image.
# NOTE(review): assumes shrink_image scales both axes by new_size / longest_side — confirm.
shrunk_origin = tuple(int(x * new_size / longest_side) for x in origin)
shrunk = shrink_image(minimap, new_size)
# Use mask to track origin position: mark the origin in an extra trailing
# channel so its location can be read back after padding moves things around.
mask = np.zeros((*shrunk.shape[0:2], 1))
mask[shrunk_origin] = 1
shrunk = np.concatenate([shrunk, mask], axis=-1)
padded = pad_to_square_multiple(shrunk, 32)
# The mask was appended last, so read it from the LAST channel rather than a
# hard-coded index 3 (which is only correct when the minimap has 3 channels).
marked = np.where(padded[..., -1] == 1)
if marked[0].size == 0:
    raise ValueError("origin marker was lost while padding the minimap")
# First marked pixel, as plain Python ints.
shrunk_origin = tuple(int(x[0]) for x in marked)
# Drop the mask (and anything beyond the first three channels) and go back to
# 8-bit pixels; the concatenation with the float mask upcast the array.
padded = padded[:, :, 0:3].astype(np.uint8)

In [None]:
#Image.fromarray(padded)
#Image.fromarray(padded[shrunk_origin[0]-100:shrunk_origin[0]+100, shrunk_origin[1]-100:shrunk_origin[1]+100, :])

In [None]:
# Run inference to extract mask of minimap features

In [None]:
from models import AttentionUNet

In [None]:
# Build the attention U-Net used to extract the minimap feature mask.
# NOTE(review): the name presumably selects which saved checkpoint load()
# restores — confirm in models.AttentionUNet.
model_name = "AttentionUNet_4"
model = AttentionUNet(model_name)
model.load()

In [None]:
# Run the model over the padded minimap.
# NOTE(review): chunk_size=32 equals the multiple the image was padded to
# above; presumably the two are meant to stay in sync — confirm.
pred = model.batch_inference(padded, chunk_size=32)

In [None]:
# Show the model input followed by the predicted feature mask.
display(Image.fromarray(padded))
# `mode="L"` only tells Pillow how to READ the buffer; it does not convert the
# data. Scale to 0-255 and cast to uint8 explicitly so a float/bool/int64
# prediction is not misinterpreted as raw 8-bit pixels.
pred_image = np.asarray(pred * 255).astype(np.uint8)
display(Image.fromarray(pred_image, mode="L"))