In [1]:
%load_ext autoreload
%autoreload 2

Generate data

In [9]:
import random

from common import constants
from schematic_generator import generator

random.seed(0)

configs = [
    # Simple shapes
    {
        "generator_type": ["shape"],
        "shape_type": ["sphere"],
        "radius": [lambda: random.randint(1, (constants.region_size[0] // 2) - 1)] * 5,
        "structure_block_types": [[block] for block in constants.simple_block_types] + [lambda: random.sample(constants.simple_block_types, 3)] * len(constants.simple_block_types),
        "background_block_types": [["minecraft:air"]],
        "position_offset": [lambda: (random.randint(-100, 100), random.randint(-100, 100), random.randint(-100, 100))],
        "random_seed": [lambda: random.randint(0, 2**32 - 1)],
        "region_size": [constants.region_size]
    },
    {
        "generator_type": ["shape"],
        "shape_type": ["cube"],
        "side_length": [lambda: random.randint(1, constants.region_size[0] - 1)] * 5,
        "structure_block_types": [[block] for block in constants.simple_block_types] + [lambda: random.sample(constants.simple_block_types, 3)] * len(constants.simple_block_types),
        "background_block_types": [["minecraft:air"]],
        "position_offset": [lambda: (random.randint(-100, 100), random.randint(-100, 100), random.randint(-100, 100))],
        "random_seed": [lambda: random.randint(0, 2**32 - 1)],
        "region_size": [constants.region_size]
    },
    # Filled
    {
        "generator_type": ["shape"],
        "shape_type": ["sphere"],
        "radius": [lambda: random.randint(3, (constants.region_size[0] // 2) - 1)] * 3,
        "structure_block_types": [[block] for block in constants.simple_block_types] + [lambda: random.sample(constants.simple_block_types, 3)] * (len(constants.simple_block_types) // 3),
        "structure_fill_block_types": [["minecraft:air"], lambda: random.sample(constants.simple_block_types, 1), lambda: random.sample(constants.simple_block_types, 3)],
        "thickness": [lambda: random.randint(1, 3)],
        "background_block_types": [["minecraft:air"]],
        "position_offset": [lambda: (random.randint(-100, 100), random.randint(-100, 100), random.randint(-100, 100))],
        "random_seed": [lambda: random.randint(0, 2**32 - 1)],
        "region_size": [constants.region_size]
    },
    {
        "generator_type": ["shape"],
        "shape_type": ["cube"],
        "side_length": [lambda: random.randint(7, constants.region_size[0] - 1)] * 3,
        "structure_block_types": [[block] for block in constants.simple_block_types] + [lambda: random.sample(constants.simple_block_types, 3)] * (len(constants.simple_block_types) // 3),
        "structure_fill_block_types": [["minecraft:air"], lambda: random.sample(constants.simple_block_types, 1), lambda: random.sample(constants.simple_block_types, 3)],
        "thickness": [lambda: random.randint(1, 3)],
        "background_block_types": [["minecraft:air"]],
        "position_offset": [lambda: (random.randint(-100, 100), random.randint(-100, 100), random.randint(-100, 100))],
        "random_seed": [lambda: random.randint(0, 2**32 - 1)],
        "region_size": [constants.region_size]
    }
]

simple_cubes = [
    {
        "generator_type": ["shape"],
        "shape_type": ["cube"],
        "side_length": [1, 2, 3, 4, 6, 7, 8],
        "structure_block_types": [[block] for block in random.sample(constants.simple_block_types, 12)],
        "background_block_types": [["minecraft:air"]],
        "position_offset": [lambda: (random.randint(-100, 100), random.randint(-100, 100), random.randint(-100, 100))] * 5,
        "random_seed": [lambda: random.randint(0, 2**32 - 1)],
        "region_size": [constants.region_size]
    }
]

simple_spheres = [
    {
        "generator_type": ["shape"],
        "shape_type": ["sphere"],
        "radius": [1, 2, 3],
        "structure_block_types": [[block] for block in random.sample(constants.simple_block_types, 30)],
        "background_block_types": [["minecraft:air"]],
        "position_offset": [lambda: (random.randint(-100, 100), random.randint(-100, 100), random.randint(-100, 100))] * 5,
        "random_seed": [lambda: random.randint(0, 2**32 - 1)],
        "region_size": [constants.region_size]
    }
]

# generator.generate_samples_from_configurations(configs, dry_run=False)
generator.generate_samples_from_configurations(simple_cubes, 'simple_cubes')
generator.generate_samples_from_configurations(simple_spheres, 'simple_spheres')

Generating samples: 100%|██████████| 420/420 [00:00<00:00, 19085.05it/s]
Generating samples: 100%|██████████| 450/450 [00:00<00:00, 18740.75it/s]


Prepare data

In [11]:
from data_preparer import load_schematics

schematics_dir = 'data/schematics'
hdf5_path = 'data/data.h5'
load_schematics(schematics_dir, hdf5_path)

Loading schematics from data/schematics into data/data.h5
Processing generator type: simple_cubes
Split data into 289 training samples, 64 validation samples, and 67 test samples.


Updating set: train for generator: simple_cubes: 100%|██████████| 289/289 [00:00<?, ?it/s]
Updating set: validation for generator: simple_cubes: 100%|██████████| 64/64 [00:00<00:00, 128253.92it/s]
Updating set: test for generator: simple_cubes: 100%|██████████| 67/67 [00:00<00:00, 133945.84it/s]


Processing generator type: simple_spheres
Split data into 305 training samples, 64 validation samples, and 81 test samples.


Updating set: train for generator: simple_spheres: 100%|██████████| 305/305 [00:00<?, ?it/s]
Updating set: validation for generator: simple_spheres: 100%|██████████| 64/64 [00:00<?, ?it/s]
Updating set: test for generator: simple_spheres: 100%|██████████| 81/81 [00:00<?, ?it/s]

Finished updating HDF5 file.





In [2]:
from pathlib import Path

import h5py
from schempy import Schematic
from tqdm import tqdm


# Assuming this is the function you want to apply to each sample name
def process_sample_name(generator_type: str, sample_name: str) -> str:
    schematic_path = f"data/schematics/{generator_type}/{sample_name}.schem"
    schematic = Schematic.from_file(Path(schematic_path))
    return schematic.name

# Function to add a new dataset to each sample group
def add_dataset_to_samples(hdf5_path: str) -> None:
    with h5py.File(hdf5_path, 'a') as hdf5_file:
        # Iterate over splits (val, train, test)
        for split in hdf5_file.keys():
            # Iterate over generator types within each split
            for generator_type in hdf5_file[split].keys():
                # Iterate over samples within each generator type
                sample_names = hdf5_file[split][generator_type].keys()
                sample_names_bar = tqdm(sample_names, desc=f"Processing samples in split '{split}' and generator type '{generator_type}'")
                for sample_name in sample_names_bar:
                    sample_group = hdf5_file[split][generator_type][sample_name]
                    # Apply the function to the sample name
                    description = process_sample_name(generator_type, sample_name)
                    # Create a new dataset with the result of the function
                    sample_group.create_dataset('description', data=description)
                    # print(f"Added description '{description}' to sample '{sample_name}'")

# Replace 'your_hdf5_file.hdf5' with the path to your actual HDF5 file
add_dataset_to_samples('data.h5')

In [4]:
import os

import h5py

from common.file_paths import BASE_DIR

with h5py.File(os.path.join(BASE_DIR, 'data.h5'), 'r') as hf:
    # Iterate over dataset splits (train, val, test)
    for split in hf:
        print(f"Split: {split}")

        group = hf[split]

        # Iterate over generator types
        for generator_type in group:
            print(f"  Generator Type: {generator_type}")

            # Get the first sample in the group
            sample = list(group[generator_type].keys())[0]
            print(f"    Sample: {sample}")

            # Print properties attribute
            print(f"      Properties: {group[generator_type][sample].attrs['properties']}")

            # Iterate over individual datasets
            for dataset in group[generator_type][sample]:
                print(f"      Dataset: {dataset}")

                # Print the shape of the dataset if it is not a scalar
                if group[generator_type][sample][dataset].shape != ():
                    print(f"        Shape: {group[generator_type][sample][dataset].shape}")
                else:
                    print(f"        Value: {group[generator_type][sample][dataset][()]}")
            
            # Print the total number of samples in the group
            print(f"    Total samples: {len(group[generator_type].keys())}")

        # Print the total number of generator types in the split
        print(f"  Total generator types: {len(group.keys())}")
    
    # Print the total number of splits in the file
    print(f"Total splits: {len(hf)}")

Split: test
  Generator Type: simple_cubes
    Sample: ddd06043d254ba45b645ab627bfe55ab20a7c6cd572f101a9f7bc2cc1349fc4b
      Properties: {"Hash": "ddd06043d254ba45b645ab627bfe55ab20a7c6cd572f101a9f7bc2cc1349fc4b", "Properties": {"generator_type": "shape", "shape_type": "cube", "side_length": 8, "structure_block_types": ["minecraft:oak_planks"], "background_block_types": ["minecraft:air"], "position_offset": [-65, -91, 29], "random_seed": 2496062452, "region_size": [16, 16, 16]}}
      Dataset: description
        Value: b'A perfect solid cube with a side length of 8 blocks. It is composed of oak planks. It is floating within an empty void.'
      Dataset: features
        Shape: (1536,)
      Dataset: target
        Shape: (16, 16, 16)
    Total samples: 17
  Total generator types: 1
Split: train
  Generator Type: simple_cubes
    Sample: 03033caa8958c87b10524cba5f7dee8a3d4948337d87e5fee6a0e756f655faa0
      Properties: {"Hash": "03033caa8958c87b10524cba5f7dee8a3d4948337d87e5fee6a0e75

Inference

In [None]:
import os
import re

import torch
from openai import OpenAI

from common.file_paths import TRAINING_DATA_DIR
from model.model import MinecraftStructureGenerator
from converter.converter import RegionTensorConverter

# Initialize the model
model = TransformerMinecraftStructureGenerator(INPUT_EMBEDDING_SIZE, NUM_CLASSES, OUTPUT_SIZE)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define the directory where checkpoints are saved
experiment_name = 'test14'
checkpoint_dir = f'checkpoints/{experiment_name}'

# List all checkpoint files
checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.startswith('checkpoint_') and f.endswith('.pth')]

# Extract epochs from file names and sort them
epochs = [int(re.search(r'checkpoint_(\d+).pth', f).group(1)) for f in checkpoint_files]
latest_epoch = max(epochs, default=0)  # Use default=0 to handle the case when the list is empty

# Load the trained model weights
latest_checkpoint_file = f'checkpoint_{latest_epoch}.pth'
print(f"Loading checkpoint '{latest_checkpoint_file}'...")
checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file)
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])

# Set the model to evaluation mode
model.eval()

converter = RegionTensorConverter()

# Loop to take user input and perform inference
while True:
    user_input = input("Enter your text input (or type 'exit' to stop): ")
    if user_input.lower() == 'exit':
        break
    print(f"Input: {user_input}")

    # Get the embedding
    print("Getting embedding...")
    client = OpenAI()
    embedding = client.embeddings.create(input=user_input, model="text-embedding-ada-002").data[0].embedding
    input_tensor = torch.tensor(embedding).unsqueeze(0)  # Add batch dimension
    input_tensor = input_tensor.float()
    input_tensor = input_tensor.to(device)
    print(f"Embedding: {input_tensor.shape}")

    # Perform inference
    with torch.no_grad():
        print("Performing inference...")
        output = model(input_tensor)
        print(f"Output: {output.shape}")

    # Process result
    predicted_tokens = torch.argmax(output, dim=1)
    predicted_tokens = predicted_tokens.squeeze(0)
    print(f"Predicted Tokens: {predicted_tokens.shape}")

    # Convert the output tensor to a schematic
    print("Converting output tensor to schematic...")
    region = converter.tensor_to_region(predicted_tokens)
    print("Conversion complete.")

    # Save the schematic to a file
    print("Saving schematic to file...")
    # try:
    #     schematic = region.as_schematic()
    #     schematic.save('test.litematic')
    # except:
    #     print("Failed to save litematica schematic to file.")
    # try:
    #     structure_nbt = region.to_structure_nbt()
    #     structure_nbt.save('test.nbt')
    # except:
    #     print("Failed to save NBT schematic to file.")
    sponge_nbt = region.to_sponge_nbt()
    sponge_nbt.save(f'{user_input.lower().replace(" ", "")}.schem')
    print("Schematic saved to file.")

In [13]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

from schempy.schematic import Block, Schematic, BlockEntity

schematic = Schematic.from_file(Path('sponge.3.schem'))
print(schematic.metadata)

# Usage example
# schematic = Schematic(width=10, height=10, length=10)
schematic.metadata['Description'] = "This is a schematic generated by SchemPy"

# Set a block at coordinates (x=1, y=2, z=3) to a specific value, e.g., 42
block = Block("minecraft:andesite")
schematic.set_block(1, 2, 3, block)
block = Block("minecraft:oak_planks")
schematic.set_block(0, 0, 0, block)

# Retrieve the block value at coordinates (x=1, y=2, z=3)
block = schematic.get_block(8, 9, 0)
print(block)
block_entity = BlockEntity("minecraft:chest", 0, 0, 0, {"LootTable": "minecraft:chests/simple_dungeon"})
schematic.add_block_entity(block_entity)

schematic.save_to_file(Path('example.schem'), 3)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Compound({'Date': Long(1700278591692), 'WorldEdit': Compound({'Version': String('(unknown)'), 'EditingPlatform': String('enginehub:fabric'), 'Origin': IntArray([Int(0), Int(0), Int(0)]), 'Platforms': Compound({'enginehub:fabric': Compound({'Name': String('Fabric-Official'), 'Version': String('7.3.0-beta-02+e11f161')})})})})
minecraft:air
