In [13]:
# import zipfile
# with zipfile.ZipFile("/home/u202194/tmp/dataset.zip","r") as zip_ref:
#     zip_ref.extractall("/home/u202194/tmp/floor plan")

In [25]:
# !pwd

/home/u202194/tmp


In [1]:
# Dataset locations. Raw strings keep the Windows backslashes literal.
# Point _DATASET_ROOT at the folder that holds your copy of the dataset.
_DATASET_ROOT = r"D:\Projects and Training\Projects\floor plan\dataset\dataset"
dataset_dir = _DATASET_ROOT + r"\floorplan_dataset"  # folder with every source PNG floor plan
train_dir = _DATASET_ROOT + r"\train"  # destination of the training split
validate_dir = _DATASET_ROOT + r"\val"  # destination of the validation split

In [2]:
import os
import random
from shutil import copyfile

# import torch
# from torchvision import datasets, transforms

# Split the raw PNG floor plans into train / validation folders.

# Fraction of images that go to the training split; the rest go to validation.
split_ratio = 0.8
# Fixed seed so that re-running this cell reproduces exactly the same split.
# Without it every run shuffles differently, and because files are copied into
# (never cleared from) the output folders, a re-run could put the same image
# into both train and val.
split_seed = 42

# Create the output directories if they don't exist
os.makedirs(train_dir, exist_ok=True)
os.makedirs(validate_dir, exist_ok=True)

# List the PNG files. Sorting first makes the seeded shuffle independent of the
# arbitrary order in which os.listdir returns entries.
image_files = sorted(f for f in os.listdir(dataset_dir) if f.endswith(".png"))

# Shuffle with a private RNG so the global `random` state is left untouched.
random.Random(split_seed).shuffle(image_files)

# Calculate the number of files for training and validation
num_total = len(image_files)
num_train = int(num_total * split_ratio)
num_validate = num_total - num_train

# Split the dataset
train_files = image_files[:num_train]
validate_files = image_files[num_train:]

# Copy each split into its destination folder
for destination_dir, split_files in ((train_dir, train_files), (validate_dir, validate_files)):
    for filename in split_files:
        copyfile(os.path.join(dataset_dir, filename), os.path.join(destination_dir, filename))

# Guard against train/val leakage, e.g. stale files left by an earlier unseeded run.
overlap = set(os.listdir(train_dir)) & set(os.listdir(validate_dir))
assert not overlap, (
    f"{len(overlap)} file(s) are present in both {train_dir} and {validate_dir}; "
    "empty both folders and re-run this cell."
)

print(f"Split the dataset into {len(train_files)} training images and {len(validate_files)} validation images.")


Split the dataset into 64630 training images and 16158 validation images.


In [3]:
# (label id, name) pairs for the values found in the category channel of the
# floor-plan images; the id is simply the position in this list.
room_label = list(enumerate([
    'LivingRoom', 'MasterRoom', 'Kitchen', 'Bathroom', 'DiningRoom',
    'ChildRoom', 'StudyRoom', 'SecondRoom', 'GuestRoom', 'Balcony',
    'Entrance', 'Storage', 'Wall-in', 'External', 'ExteriorWall',
    'FrontDoor', 'InteriorWall', 'InteriorDoor',
]))

# Room categories only: drop the exterior, wall and door labels.
category = [
    entry for entry in room_label
    if entry[1] not in {'External', 'ExteriorWall', 'FrontDoor', 'InteriorWall', 'InteriorDoor'}
]

num_category = len(category)

# Pixel-to-length conversion factor: 18 length units per 256 pixels.
# NOTE(review): presumably metres across a 256-px image — confirm against the dataset docs.
pixel2length = 18/256

def _check_label_range(value, what):
    """Return ``value`` unchanged if it lies in [0, 17]; otherwise raise ``ValueError``.

    ``ValueError`` subclasses ``Exception``, so callers that caught the generic
    ``Exception`` raised previously keep working; the args tuple
    ``("Invalid <what>!", value)`` is unchanged.
    """
    if not 0 <= value <= 17:
        raise ValueError(f"Invalid {what}!", value)
    return value


def label2name(label=0):
    """Return the name for ``label`` from ``room_label`` (e.g. 2 -> 'Kitchen').

    Raises:
        ValueError: if ``label`` is outside 0-17. The explicit check matters
            because a negative label would otherwise silently index from the
            end of ``room_label``.
    """
    return room_label[_check_label_range(label, "label")][1]


def label2index(label=0):
    """Map a label to its index (currently the identity) after range-checking it.

    Raises:
        ValueError: if ``label`` is outside 0-17.
    """
    return _check_label_range(label, "label")


def index2label(index=0):
    """Map an index back to its label (currently the identity) after range-checking it.

    Raises:
        ValueError: if ``index`` is outside 0-17.
    """
    return _check_label_range(index, "index")

def compute_centroid(mask):
    """Return the integer (row, col) centroid of the non-zero pixels of a 2-D mask.

    Each coordinate is the floor of the mean row / column index of the pixels
    whose value is non-zero, returned as plain Python ``int``s.

    Args:
        mask: 2-D array-like; every non-zero entry counts as part of the region.

    Returns:
        ``(row, col)`` tuple of ints.

    Raises:
        ValueError: if ``mask`` has no non-zero pixel, so the centroid is
            undefined. Previously this surfaced as a bare ZeroDivisionError.
    """
    # Local import: this cell runs before the notebook cell that imports numpy.
    import numpy as np

    # Vectorised replacement for the per-pixel Python double loop.
    rows, cols = np.nonzero(np.asarray(mask))
    count = rows.size
    if count == 0:
        raise ValueError("compute_centroid: mask has no non-zero pixels")
    # int(...) keeps the return type identical to the old pure-Python version.
    return (int(rows.sum()) // count, int(cols.sum()) // count)

def log(file, msg='', is_print=True):
    """Optionally echo ``msg`` to stdout, then append it as one line to ``file``.

    The file is flushed after every call so the line is pushed out immediately.
    """
    if is_print:
        print(msg)
    line = msg + '\n'
    file.write(line)
    file.flush()

In [4]:
from tqdm import tqdm
from tqdm.notebook import tqdm_notebook

In [5]:
from PIL import Image
import numpy as np
import shutil
import pickle
# import utils
import os


def write2pickle(train_dir, pkl_dir):
    train_data_path = [os.path.join(train_dir, path) for path in os.listdir(train_dir)]
    print(f'Number of dataset: {len(train_data_path)}')
    for path in tqdm_notebook(train_data_path):
        with Image.open(path) as temp:
            image_array = np.asarray(temp, dtype=np.uint8)
        boundary_mask = image_array[:,:,0]
        category_mask = image_array[:,:,1]
        index_mask = image_array[:,:,2]
        inside_mask = image_array[:,:,3]
        shape_array = image_array.shape
        index_category = []
        room_node = []

        interiorWall_mask = np.zeros(category_mask.shape, dtype=np.uint8)
        interiorWall_mask[category_mask == 16] = 1        
        interiordoor_mask = np.zeros(category_mask.shape, dtype=np.uint8)
        interiordoor_mask[category_mask == 17] = 1

        for h in range(shape_array[0]):  
            for w in range(shape_array[1]):
                index = index_mask[h, w]
                category = category_mask[h, w]
                if index > 0 and category <= 12:
                    if len(index_category):
                        flag = True
                        for i in index_category:
                            if i[0] == index:
                                flag = False
                        if flag:
                            index_category.append((index, category))
                    else:
                        index_category.append((index, category))

        for (index, category) in index_category:
            node = {}
            node['category'] = int(category)
            mask = np.zeros(index_mask.shape, dtype=np.uint8)
            mask[index_mask == index] = 1
            node['centroid'] = compute_centroid(mask)
            room_node.append(node)
        
        pkl_path = path.replace(train_dir, pkl_dir)
        pkl_path = pkl_path.replace('png', 'pkl')
        pkl_file = open(pkl_path, 'wb')
        pickle.dump([inside_mask, boundary_mask, interiorWall_mask, interiordoor_mask, room_node], 
            pkl_file, protocol=pickle.HIGHEST_PROTOCOL)
        pkl_file.close()

if __name__=='__main__':
    print("*******************************************")
    # Image splits produced by the split cell (same values as train_dir / validate_dir there).
    train_dataset_dir = r"D:\Projects and Training\Projects\floor plan\dataset\dataset\train"
    val_dataset_dir = r"D:\Projects and Training\Projects\floor plan\dataset\dataset\val"

    # Output folders for the per-image pickles.
    train_pickle_dir = r"D:\Projects and Training\Projects\floor plan\dataset\dataset\pickle\train"
    val_pickle_dir = r"D:\Projects and Training\Projects\floor plan\dataset\dataset\pickle\val"

    split_dirs = [(train_dataset_dir, train_pickle_dir), (val_dataset_dir, val_pickle_dir)]

    # First empty/create both output folders, then convert both splits.
    for _, pickle_dir in split_dirs:
        # Drop stale pickles from earlier runs so each folder mirrors its image split.
        if os.path.exists(pickle_dir):
            shutil.rmtree(pickle_dir)
        # makedirs (not mkdir) also creates the shared `pickle` parent folder,
        # which does not exist yet on a first run.
        os.makedirs(pickle_dir)
    for image_dir, pickle_dir in split_dirs:
        write2pickle(image_dir, pickle_dir)

*******************************************
Number of dataset: 64630


  0%|          | 0/64630 [00:00<?, ?it/s]

Number of dataset: 16158


  0%|          | 0/16158 [00:00<?, ?it/s]