In [None]:
from minecraft_schematic_generator.converter import SchematicArrayConverter
from minecraft_schematic_generator.modules import (
    LightningTransformerMinecraftStructureGenerator,
)

# Load the trained generator from a specific Lightning run:
#   ../lightning_logs/<experiment_name>/version_<model_version>/checkpoints/last.ckpt
# To evaluate a particular epoch instead, point checkpoint_path at that .ckpt file.
experiment_name = "center_data"
model_version = 13
checkpoint_path = "../lightning_logs/{}/version_{}/checkpoints/last.ckpt".format(
    experiment_name, model_version
)

model = LightningTransformerMinecraftStructureGenerator.load_from_checkpoint(checkpoint_path)
model.eval()  # put the module in evaluation mode before generating

converter = SchematicArrayConverter()

In [None]:
from lightning import Trainer

from minecraft_schematic_generator.converter import SchematicArrayConverter
from minecraft_schematic_generator.modules.data_module import MinecraftDataModule

# Data module over the older v3 dataset, one structure per batch.
# Mirrors the data-module setup used for the v7 dataset further down: a Trainer
# is attached before setup() and the token converter is obtained through
# get_block_token_converter().
# NOTE(review): presumably setup() consults data_module.trainer — confirm
# against MinecraftDataModule.
hdf5_file = "../data/data_v3.h5"
data_module = MinecraftDataModule(
    file_path=hdf5_file,
    batch_size=1,
)
trainer = Trainer()
data_module.trainer = trainer

data_module.setup()
dataloader = data_module.train_dataloader()
block_token_converter = data_module.get_block_token_converter()
schematic_array_converter = SchematicArrayConverter(block_token_converter)

In [None]:
from pathlib import Path

from lightning import Trainer

from minecraft_schematic_generator.converter import SchematicArrayConverter
from minecraft_schematic_generator.model.structure_masker import StructureMasker
from minecraft_schematic_generator.model.structure_transformer import (
    StructureTransformer,
)
from minecraft_schematic_generator.modules.data_module import MinecraftDataModule

# Sanity-check the masking pipeline on the v7 dataset: every training sample is
# passed through StructureMasker / StructureTransformer (default settings).
# For an un-augmented run use e.g.
#   StructureMasker(skip_block_property_transform_chance=0.0,
#                   random_block_property_transform_chance=0.0,
#                   add_blocks_chance=0.0)
structure_masker = StructureMasker()
structure_transformer = StructureTransformer()
# NOTE(review): machine-specific absolute path; consider a configurable DATA_DIR.
hdf5_file = "/mnt/windows/data_v7.h5"
data_module = MinecraftDataModule(
    file_path=hdf5_file,
    structure_masker=structure_masker,
    structure_transformer=structure_transformer,
    crop_sizes={7: 1, 9: 1, 11: 1, 13: 1, 15: 1},
)
# A Trainer is attached before setup(); presumably setup() consults
# data_module.trainer — confirm against MinecraftDataModule.
trainer = Trainer()
data_module.trainer = trainer

data_module.setup()
dataloader = data_module.train_dataloader()
print(len(dataloader))
block_token_converter = data_module.get_block_token_converter()
schematic_array_converter = SchematicArrayConverter(block_token_converter)

# Visualisation tokens, looked up once rather than on every loop iteration.
air_token = block_token_converter.universal_str_to_token("universal_minecraft:air")
pink_stained_glass_token = block_token_converter.universal_str_to_token(
    "universal_minecraft:stained_glass[color=pink]"
)
green_stained_glass_token = block_token_converter.universal_str_to_token(
    "universal_minecraft:stained_glass[color=green]"
)


def save_schematic(structure, path: Path) -> None:
    """Convert a token array to a schematic named "Test" and write it to ``path``.

    Args:
        structure: token tensor accepted by ``schematic_array_converter.array_to_schematic``.
        path: destination file.
    """
    schematic = schematic_array_converter.array_to_schematic(structure)
    schematic.name = "Test"
    # NOTE(review): the meaning of the second argument (2) is not visible here;
    # it is passed through unchanged.
    schematic.save_to_file(path, 2)


# Export the first 10 batches, three schematics each:
#   <i>_full.schem          - the unmasked structure
#   <i>_masked.schem        - masked input; original air drawn pink, token-0 positions as air
#   <i>_masked_masked.schem - as above, with positions selected by `mask` drawn green
for i, batch in enumerate(dataloader):
    if i >= 10:
        break
    print(i)
    full_structure, masked_structure, mask = batch
    full_structure = full_structure.squeeze()
    masked_structure = masked_structure.squeeze()
    mask = mask.squeeze()

    save_schematic(full_structure, Path(f"{i}_full.schem"))

    # Recolour a copy so `masked_structure` keeps its raw tokens for the cells
    # below, which look for token 0 in it and feed it to the model.
    masked_visual = masked_structure.clone()
    # Order matters: recolour air first, then turn token-0 positions into air.
    masked_visual[masked_visual == air_token] = pink_stained_glass_token
    masked_visual[masked_visual == 0] = air_token
    save_schematic(masked_visual, Path(f"{i}_masked.schem"))

    masked_visual[mask] = green_stained_glass_token
    save_schematic(masked_visual, Path(f"{i}_masked_masked.schem"))

In [None]:
from pathlib import Path

import torch
import torch.nn.functional as F

# Visualise which masked (token 0) positions touch at least one placed block
# (any token other than 0 or air) among their 26 neighbours. All masked
# positions are drawn as pink stained glass; those touching a block are then
# overdrawn with a highlight token.
# NOTE(review): assumes `masked_structure` is a 3D token tensor from the data
# loop above, with 0 marking masked positions — confirm.
air_token = block_token_converter.universal_str_to_token("universal_minecraft:air")
pink_stained_glass_token = block_token_converter.universal_str_to_token(
    "universal_minecraft:stained_glass[color=pink]"
)
# NOTE(review): hard-coded token id kept from the original; which block it maps
# to depends on the converter — prefer a universal_str_to_token lookup.
HIGHLIGHT_TOKEN = 456

masked_structure_visual = masked_structure.clone()

# 3x3x3 neighbourhood kernel that ignores the centre voxel, built directly in
# float so it matches the float input to conv3d.
kernel = torch.ones(
    (1, 1, 3, 3, 3),
    dtype=torch.float32,
    device=masked_structure_visual.device,
)
kernel[0, 0, 1, 1, 1] = 0

# Positions holding a real block (neither masked nor air).
is_block = (masked_structure_visual != 0) & (masked_structure_visual != air_token)
print(is_block.sum().item())

# conv3d expects (N, C, D, H, W); padding=1 keeps the spatial size unchanged.
neighbour_block_counts = F.conv3d(is_block[None, None].float(), kernel, padding=1)[0, 0]
has_block_neighbour = neighbour_block_counts >= 1
print(has_block_neighbour.sum().item())

is_masked = masked_structure_visual == 0
print(is_masked.sum().item())

masked_next_to_block = has_block_neighbour & is_masked
print(masked_next_to_block.sum().item())

masked_structure_visual[is_masked] = pink_stained_glass_token
masked_structure_visual[masked_next_to_block] = HIGHLIGHT_TOKEN
masked_structure_schematic = schematic_array_converter.array_to_schematic(
    masked_structure_visual
)
masked_structure_schematic.name = "Test"

# Save the sample
masked_structure_schematic.save_to_file(Path("masked_structure.schem"), 2)

In [None]:
import time
from pathlib import Path

import torch

# Let the model fill the unknown (token 0) positions of the last masked
# structure from the data loop and save the result as a schematic.
air_token = block_token_converter.universal_str_to_token("universal_minecraft:air")
pink_stained_glass_token = block_token_converter.universal_str_to_token(
    "universal_minecraft:stained_glass[color=pink]"
)

filled_structure = masked_structure.clone()

# Remove air: existing air becomes token 0, presumably so the model treats those
# positions as ones to fill as well — confirm against fill_structure.
filled_structure[filled_structure == air_token] = 0

# Generate a sample using the model. Inference only, so autograd is disabled;
# perf_counter is the monotonic clock intended for measuring durations.
start_time = time.perf_counter()
with torch.no_grad():
    for predicted_token, z, y, x in model.fill_structure(
        filled_structure,
        temperature=0.7,
        start_radius=1,
        max_iterations=5,
        max_blocks=50,
        air_probability_iteration_scaling=0.0,
    ):
        print(f"({z}, {y}, {x}): {predicted_token}")
        filled_structure[z, y, x] = predicted_token
end_time = time.perf_counter()
print(f"Generation time: {end_time - start_time}")

# Positions the model left unknown are drawn as pink stained glass.
filled_structure[filled_structure == 0] = pink_stained_glass_token

filled_structure_schematic = schematic_array_converter.array_to_schematic(
    filled_structure
)
filled_structure_schematic.name = "Test"

# Save the sample
filled_structure_schematic.save_to_file(Path("filled_structure.schem"), 2)