In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [101]:
import numpy as np

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from matplotlib import colors
from matplotlib.colors import ListedColormap
import webcolors
from PIL import Image, ImageDraw

import networkx as nx

from minimal.imaging import blit_rooms, draw_sep_nask_wireframe
from minimal.layout import NodeType, NODE_COLOR
from minimal.gen import PlanMasks
from minimal.rooms import RoomAreas, extract_rooms, scale_room_mask
from minimal.walls import create_sep_mask, scale_sep_mask, _conv_mask
from minimal.walls import (
    CC_TL,
    CC_TR,
    CC_BR,
    CC_BL,
    CC_T,
    CC_R,
    CC_B,
    CC_L
)
from minimal.doors import extract_face_walls

%matplotlib inline

In [33]:
# Discrete colormap whose first entry is white, followed by the tab10 palette.
# NOTE(review): presumably used for room-label images (0 = background) — confirm.
cmapr = ListedColormap(['white', *plt.get_cmap('tab10').colors])
# Same palette with black inserted as the second entry, before tab10.
# NOTE(review): presumably used for wall/separator masks — confirm.
cmapw = ListedColormap(['white', 'black', *plt.get_cmap('tab10').colors])

In [34]:
%%time
# NOTE(review): torch.load unpickles the file, so only load checkpoints
# that come from a trusted source.
pm = PlanMasks.create_from_state(torch.load("data/plan_masks_05.pth"))
rooms = extract_rooms(pm)
# Order rooms by ascending total area; the label image built from `rooms`
# numbers them in this order.
rooms.sort(key=lambda r: r.total_area())
# Optional visual check of the extracted rooms (disabled):
# blit_rooms(rooms)

CPU times: user 1.18 s, sys: 4.1 ms, total: 1.18 s
Wall time: 1.18 s


In [35]:
# Each room's mask is multiplied by (index + 1) and the results are summed.
# NOTE(review): this yields a label image (0 = no room, k = k-th smallest room)
# only if to_mask() is binary and rooms do not overlap — confirm.
room_mask = sum(room.to_mask() * (i + 1) for i, room in enumerate(rooms))
# Separator mask derived from the label image by the project helper.
sep_mask = create_sep_mask(room_mask)

In [46]:
# Scale factors handed to both scaling helpers below.
sx = 3
sy = 3

# Scale the separator mask and the room label mask with the same factors so
# the two stay aligned.
scaled_sep_mask = scale_sep_mask(sep_mask, sx, sy)
scaled_room_mask = scale_room_mask(room_mask, sx, sy)

In [168]:
# Iterable of wall masks, one per face; the run-collection cell enumerates it
# and treats entries 0 and 3 as horizontal ('h'), the others as vertical ('v').
face_walls = extract_face_walls(scaled_sep_mask)

In [95]:
# Short alias for the scaled room label mask, used by the run-collection cell.
rmask = scaled_room_mask

In [198]:
# Labels of the two rooms under inspection: walls of room `ra` are compared
# against `rmask` and restricted to where they touch room `rb`.
# NOTE(review): labels are presumably (index + 1) in the area-sorted room list — confirm.
ra = 4
rb = 2

In [197]:
# NOTE(review): this cell sits above the cell that builds `all_runs`; on a
# fresh kernel run top-to-bottom it raises NameError.
all_runs

[(100, 80, 16, 'v')]

In [184]:
# Collect the straight wall runs of room `ra` that border room `rb`.
#
# NOTE: this cell calls `restrict_touching` and `extract_walls_runs`, which are
# defined in cells further down the notebook; run those cells first on a fresh
# kernel.
#
# Every entry appended to `all_runs` has the same shape:
#     (x, y, length, orient)
# where (x, y) is the run's start in the index order returned by torch.where
# (first axis, second axis) and `orient` is 'h' or 'v'.
all_runs = []
for i, fw in enumerate(face_walls):
    # Keep only this face's wall pixels that lie inside room `ra` ...
    ws = fw.clone()
    ws[rmask != ra] = 0
    # ... and, of those, only the ones adjacent to room `rb`.
    ws = restrict_touching(rmask, ws, ra, rb)
    lx, ly = torch.where(ws > 0)

    if len(lx) == 0:
        continue

    # Faces 0 and 3 are scanned along the second axis, so their coordinates
    # are swapped before run extraction and swapped back afterwards.
    transpose = (i == 0 or i == 3)
    orient = 'h' if transpose else 'v'

    if transpose:
        lx, ly = ly, lx

    runs = extract_walls_runs(lx, ly)

    # Emit a 4-tuple in both cases so downstream code can rely on one record
    # shape (previously the transposed branch dropped `orient`).
    all_runs.extend(
        (y, x, length, orient) if transpose else (x, y, length, orient)
        for (x, y, length) in runs
    )

In [170]:
def extract_walls_runs(lx, ly, min_len: int=4):
    """Find straight runs of consecutive points along the first coordinate.

    A run is a maximal sequence of points that share the same ``ly`` value and
    whose ``lx`` values increase by exactly 1 from one point to the next.

    Args:
        lx: 1-D integer tensor (or sequence) of first coordinates.
        ly: 1-D integer tensor (or sequence) of second coordinates, paired
            element-wise with ``lx``.
        min_len: Minimum number of points a run needs in order to be reported.

    Returns:
        A list of ``(start_x, y, length)`` tuples of Python ints, ordered by
        ``y`` and then by ``start_x``. Empty input yields an empty list.
    """
    xs = torch.as_tensor(lx).reshape(-1).tolist()
    ys = torch.as_tensor(ly).reshape(-1).tolist()

    if not xs:
        return []

    # Sort by (y, x) so that all points of one line are contiguous. Sorting by
    # x alone interleaves points from different lines and breaks every run as
    # soon as two lines overlap in x.
    points = sorted(zip(ys, xs))

    runs = []
    run_y, run_start = points[0]
    prev_x = run_start

    for y, x in points[1:]:
        if y != run_y or x - prev_x != 1:
            # The current run ended at prev_x; keep it if it is long enough.
            run_len = prev_x - run_start + 1
            if run_len >= min_len:
                runs.append((run_start, run_y, run_len))
            run_y, run_start = y, x
        prev_x = x

    # Flush the final run.
    run_len = prev_x - run_start + 1
    if run_len >= min_len:
        runs.append((run_start, run_y, run_len))

    return runs

In [91]:
# plt.imshow(room_mask == ra)

In [99]:
def restrict_touching(room_mask, ra_walls, ra, rb):
    """Keep the wall pixels of room ``ra`` whose surroundings hold only ``ra``/``rb``.

    Args:
        room_mask: Room label mask; it is not modified (a shifted copy is used).
        ra_walls: Mask of candidate wall pixels; zeros are never kept.
        ra: Label of the room owning the walls.
        rb: Label of the neighbouring room.

    Returns:
        A uint8 mask that is 1 where ``ra_walls`` is non-zero and the
        ``_conv_mask`` response is zero, and 0 elsewhere.
    """
    # Shift every label up by one so that 0 becomes free to mean "ignore";
    # the addition also produces a fresh tensor, leaving the caller's mask intact.
    shifted = room_mask + 1
    label_a = ra + 1
    label_b = rb + 1

    # Blank out both rooms of interest: anything still non-zero is "foreign".
    shifted[(shifted == label_a) | (shifted == label_b)] = 0

    # NOTE(review): presumably _conv_mask sums the 8 neighbours selected by
    # _res_kernel, so a zero response means no foreign neighbour — confirm.
    response = _conv_mask(shifted, _res_kernel)
    touching = (response == 0).byte()
    touching[ra_walls == 0] = 0

    return touching


_res_kernel = torch.tensor([
    [1, 1, 1],
    [1, 0, 1],
    [1, 1, 1],
], dtype=torch.int8)