# In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import torch
import torch.nn.functional as F
import cv2
from torch.utils.data import DataLoader

import FloorplanToSTL as stl
import config

from utils.FloorplanToBlenderLib import *
from model import get_model
from utils.loaders import FloorplanSVG, DictToTensor, Compose, RotateNTurns
from utils.plotting import segmentation_plot, polygons_to_image, draw_junction_from_dict
import utils.plotting

from utils.post_prosessing import split_prediction, get_polygons, split_validation
from mpl_toolkits.axes_grid1 import AxesGrid

# Input floor-plan image to segment and convert.
img_path = "c:/Users/end0t/dev/test_floor_plan.png"
# Parameters forwarded to FloorplanToBlenderLib's wall-mesh builders in main().
wall_height = 1
scale = 100
# Trained CubiCasa5k checkpoint (a torch-pickled dict containing 'model_state').
pkl_path = "model_best_val_loss_var.pkl"

# Display labels for the room and icon segmentation outputs; index i labels
# class value i in the rendered segmentation images (used as colorbar ticks).
room_classes = ["Background", "Outdoor", "Wall", "Kitchen", "Living Room", "Bed Room",
                "Bath", "Entry", "Railing", "Storage", "Garage", "Undefined"]
icon_classes = ["No Icon", "Window", "Door", "Closet", "Electrical Appliance",
                "Toilet", "Sink", "Sauna Bench", "Fire Place", "Bathtub", "Chimney"]


def _load_model(n_classes):
    """Build the CubiCasa5k hourglass model and load trained weights on the CPU.

    The final 1x1 conv and the x4 transposed-conv upsampler are replaced so both
    carry ``n_classes`` channels; these must match the layers stored in the
    checkpoint at ``pkl_path``.

    Returns:
        The model in eval mode, on the CPU.
    """
    # Architecture: https://github.com/CubiCasa/CubiCasa5k/tree/master/floortrans/models
    model = get_model('hg_furukawa_original', 51)
    # Overwrite the output head so it emits n_classes channels.
    model.conv4_ = torch.nn.Conv2d(256, n_classes, bias=True, kernel_size=1)
    # Overwrite the upsampling layer to match the new channel count.
    model.upsample = torch.nn.ConvTranspose2d(n_classes, n_classes,
                                              kernel_size=4, stride=4)

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code -- only load checkpoints from a trusted source.
    checkpoint = torch.load(pkl_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state'])
    model.eval()      # switch to inference mode
    model.to('cpu')
    return model


def _load_image(path):
    """Read an image file as a float32 tensor of shape (1, 3, H, W) scaled to [-1, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; model wants RGB
    scaled = 2 * (rgb / 255.0) - 1              # [0, 255] -> [-1, 1]
    chw = np.moveaxis(scaled, -1, 0)            # (H, W, 3) -> (3, H, W)
    return torch.from_numpy(np.expand_dims(chw.astype(np.float32), axis=0))


def _predict_with_rotations(model, img, n_classes):
    """Run ``model`` on four rotations of ``img`` and average the de-rotated outputs.

    Odd spatial dimensions are rounded down to the nearest even number, and every
    prediction is bilinearly resized to that size before averaging.

    Returns:
        (prediction, (height, width)) where ``prediction`` has shape
        (1, n_classes, height, width).
    """
    rot = RotateNTurns()
    height = img.shape[2] - img.shape[2] % 2
    width = img.shape[3] - img.shape[3] % 2

    # (forward, back) turn counts: rotate the input, then undo it on the output.
    rotations = [(0, 0),    # no rotation -> nothing to undo
                 (1, -1),   # rotate 90 degrees -> rotate back -90
                 (2, 2),    # rotate 180 degrees -> rotate back 180
                 (-1, 1)]   # rotate -90 degrees -> rotate back +90
    stacked = torch.zeros([len(rotations), n_classes, height, width])

    # Inference only: disable autograd to save memory.
    with torch.no_grad():
        for i, (forward, back) in enumerate(rotations):
            pred = model(rot(img, 'tensor', forward))
            pred = rot(pred, 'tensor', back)
            # Also restore the meaning of heatmap points (e.g. icon orientation).
            pred = rot(pred, 'points', back)
            pred = F.interpolate(pred, size=(height, width),
                                 mode='bilinear', align_corners=True)
            stacked[i] = pred[0]

    return torch.mean(stacked, 0, True), (height, width)


def _wall_boxes(polygons, types):
    """Select the wall polygons and wrap each point as ``np.array([point])``.

    Presumably yields OpenCV-contour-like arrays of shape (n_points, 1, 2) --
    TODO confirm against get_polygons' output format.
    """
    return [np.array([np.array([point]) for point in polygon])
            for polygon, kind in zip(polygons, types)
            if kind['type'] == 'wall']


def _build_wall_meshes(boxes):
    """Build wall-side and wall-top mesh data via FloorplanToBlenderLib's ``transform``.

    Returns:
        (wall_verts, wall_faces, wall_amount, top_verts, top_faces); each top
        face is a single tuple indexing every vertex of its polygon.
    """
    wall_verts, wall_faces, wall_amount = transform.create_nx4_verts_and_faces(
        boxes, wall_height, scale)
    top_verts = [transform.scale_point_to_vector(box, scale, 0) for box in boxes]
    top_faces = [[tuple(range(len(room)))] for room in top_verts]
    return wall_verts, wall_faces, wall_amount, top_verts, top_faces


def _show_segmentation(seg, cmap, class_names):
    """Show a class-index image with a colorbar labelled by ``class_names``."""
    n = len(class_names)
    plt.figure(figsize=(12, 12))
    ax = plt.subplot(1, 1, 1)
    ax.axis('off')
    im = ax.imshow(seg, cmap=cmap, vmin=0, vmax=n - 0.1)
    # Offset ticks by 0.5 so each label sits roughly mid-band in the colorbar.
    cbar = plt.colorbar(im, ticks=np.arange(n) + 0.5, fraction=0.046, pad=0.01)
    cbar.ax.set_yticklabels(class_names, fontsize=20)
    plt.tight_layout()
    plt.show()


def main():
    """Segment the floor plan at ``img_path``, plot rooms/icons, and export 3D data.

    Steps:
    1. Load the model from ``pkl_path``.
    2. Predict with four-way rotation averaging.
    3. Extract polygons and build wall meshes.
    4. Show the room and icon segmentation plots.
    5. Call ``FloorplanToSTL.createFloorPlan`` on the same image.
    """
    # Create color maps that give each label a distinct color; presumably this
    # registers the 'rooms'/'icons' colormaps used in _show_segmentation.
    utils.plotting.discrete_cmap()

    # Channel layout: 21 heatmaps, then room classes, then icon classes
    # (44 channels with the class lists defined at module level).
    split = [21, len(room_classes), len(icon_classes)]
    n_classes = sum(split)

    model = _load_model(n_classes)
    img = _load_image(img_path)
    prediction, img_size = _predict_with_rotations(model, img, n_classes)
    height, width = img_size

    heatmaps, rooms, icons = split_prediction(prediction, img_size, split)
    polygons, types, room_polygons, room_types = get_polygons(
        (heatmaps, rooms, icons), 0.2, [1, 2])

    # NOTE(review): the mesh data is not consumed below -- createFloorPlan is
    # given only the image path. Kept so the wall geometry is ready for export.
    _wall_meshes = _build_wall_meshes(_wall_boxes(polygons, types))

    # Render room and icon segmentation images from the polygons.
    pol_room_seg, pol_icon_seg = polygons_to_image(
        polygons, types, room_polygons, room_types, height, width)
    _show_segmentation(pol_room_seg, 'rooms', room_classes)
    _show_segmentation(pol_icon_seg, 'icons', icon_classes)

    # Convert to Blender/STL data.
    stl.createFloorPlan(image_path=img_path,
                        target_path="floorplan",
                        SR_Check=True)

if __name__ == '__main__':
    # Run the pipeline only when executed as a script, not when imported.
    main()