In [1]:
from minecraft_schematic_generator.converter import SchematicArrayConverter
from minecraft_schematic_generator.modules import LightningTransformerMinecraftStructureGenerator

# Training run to load: ../lightning_logs/<experiment_name>/version_<model_version>/
experiment_name = 'center_data'
model_version = 13
run_dir = f'../lightning_logs/{experiment_name}/version_{model_version}'
checkpoint_path = f'{run_dir}/checkpoints/last.ckpt'

# Restore the trained model from its latest checkpoint and put it in eval mode.
model = LightningTransformerMinecraftStructureGenerator.load_from_checkpoint(checkpoint_path)
model.eval()

converter = SchematicArrayConverter()

In [2]:
import os

from minecraft_schematic_generator.converter import SchematicArrayConverter
from minecraft_schematic_generator.modules.data_module import MinecraftDataModule

# Data module over the HDF5 structure dataset, with a batch size of 1.
hdf5_file = '../data/data_v2.h5'
data_module = MinecraftDataModule(file_path=hdf5_file, batch_size=1)
data_module.setup(index=1)

# Validation dataloaders, plus a converter for turning arrays back into schematics.
val_dataloaders = data_module.val_dataloader()
schematic_array_converter = SchematicArrayConverter()

Loading training datasets: 100%|██████████| 1/1 [00:01<00:00,  1.02s/it]
Loading validation datasets: 100%|██████████| 1/1 [00:00<00:00,  7.55it/s]
Loading test datasets: 100%|██████████| 1/1 [00:00<00:00,  7.55it/s]


In [3]:
import random
from pathlib import Path

# Token id painted over masked positions so they stand out in a schematic viewer.
# Same id the mask-visualization cell uses; described there as pink stained glass.
MASK_VISUAL_TOKEN = 2173

# Pick a random validation dataloader. Its index is kept under its own name
# because it is needed again to look up the matching dataset and its name.
dataloader_index = random.randrange(len(val_dataloaders))
dataloader = val_dataloaders[dataloader_index]
dataset_entry = data_module.val_datasets[dataloader_index]
dataset = dataset_entry[1]
# NOTE(review): assumes element 0 of each val_datasets entry is the dataset's name — confirm.
dataset_name = dataset_entry[0]

# Pick a random sample from that dataset.
sample_index = random.randrange(len(dataset))
full_structure, masked_structure = dataset[sample_index]

full_structure_schematic = schematic_array_converter.array_to_schematic(full_structure)
full_structure_schematic.name = 'Test'
full_structure_schematic.save_to_file(Path('full_structure.schem'), 2)

# Masked positions hold token 0; paint them with the visualization token.
# (Token 1 is air, so filling with 1 would make the mask invisible.)
masked_structure_visual = masked_structure.clone()
masked_structure_visual[masked_structure == 0] = MASK_VISUAL_TOKEN

masked_structure_schematic = schematic_array_converter.array_to_schematic(masked_structure_visual)
masked_structure_schematic.name = 'Test'

# Save the sample
masked_structure_schematic.save_to_file(Path('masked_structure.schem'), 2)

In [4]:
import torch
import torch.nn.functional as F

# Draw a fresh mask for the sampled structure (replacing the previous cell's
# mask) and work on a copy for visualization.
masked_structure = dataset._mask_structure(full_structure)
masked_structure_visual = masked_structure.clone()

# 3x3x3 ones kernel with the centre zeroed: convolving with it sums over the
# 26 surrounding voxels but not the voxel itself.
kernel = torch.ones(
    (1, 1, 3, 3, 3),
    dtype=masked_structure_visual.dtype,
    device=masked_structure_visual.device,
)
kernel[0, 0, 1, 1, 1] = 0

# Voxels whose token id is above 1 (0 marks masked positions, 1 is air).
is_block = masked_structure_visual > 1
print(is_block.sum())

# conv3d expects (batch, channel, D, H, W): add two leading unit dims, then
# index them away. A count >= 1 means at least one neighbour is a block.
neighbour_counts = F.conv3d(is_block[None, None].float(), kernel.float(), padding=1)
has_block_neighbour = (neighbour_counts >= 1)[0, 0]
print(has_block_neighbour.sum())

# Masked voxels.
is_masked = masked_structure_visual == 0
print(is_masked.sum())

# Masked voxels adjacent to at least one block: the frontier of the mask.
frontier = has_block_neighbour & is_masked
print(frontier.sum())

# Paint all masked voxels with token 2173, then overwrite the frontier with 456.
masked_structure_visual[is_masked] = 2173
masked_structure_visual[frontier] = 456

masked_structure_schematic = schematic_array_converter.array_to_schematic(masked_structure_visual)
masked_structure_schematic.name = 'Test'

# Save the sample
masked_structure_schematic.save_to_file(Path('masked_structure.schem'), 2)

tensor(806)
tensor(1084)
tensor(446)
tensor(199)


In [5]:
import time

import torch

filled_structure = masked_structure.clone()

# Remove air: turn token 1 into 0, the value used for masked positions
# (presumably so fill_structure may also place blocks there — confirm).
filled_structure[filled_structure == 1] = 0

# Stream predictions from the model and write each one into the structure.
# model.eval() does not disable autograd, so run generation under no_grad to
# avoid building gradient graphs during inference.
start_time = time.perf_counter()
with torch.no_grad():
    for predicted_token, z, y, x in model.fill_structure(
        filled_structure,
        temperature=0.7,
        start_radius=1,
        max_iterations=5,
        max_blocks=50,
        air_probability_iteration_scaling=0.0,
    ):
        print(f'({z}, {y}, {x}): {predicted_token}')
        filled_structure[z, y, x] = predicted_token
end_time = time.perf_counter()
print(f'Generation time: {end_time - start_time}')

# Paint positions still at 0 with token 2173 (the id used for masked positions elsewhere).
filled_structure[filled_structure == 0] = 2173

filled_structure_schematic = schematic_array_converter.array_to_schematic(
    filled_structure)
filled_structure_schematic.name = 'Test'

# Save the sample
filled_structure_schematic.save_to_file(Path('filled_structure.schem'), 2)

Iteration 1/5


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Selected token 62 with probability 99.9%, air probability 0.0%
(6, 6, 4): 62
Filled 1/50 blocks
Selected token 1 with probability 100.0%, air probability 100.0%
Selected token 1 with probability 100.0%, air probability 100.0%
Selected token 1 with probability 98.9%, air probability 98.9%
Selected token 1 with probability 100.0%, air probability 100.0%
Selected token 62 with probability 99.9%, air probability 0.1%
(5, 6, 3): 62
Filled 2/50 blocks
Selected token 7426 with probability 99.5%, air probability 0.0%
(7, 6, 6): 7426
Filled 3/50 blocks
Selected token 1 with probability 70.6%, air probability 70.6%
Selected token 1 with probability 100.0%, air probability 100.0%
Selected token 1 with probability 99.7%, air probability 99.7%
Selected token 62 with probability 99.7%, air probability 0.2%
(4, 6, 3): 62
Filled 4/50 blocks
Selected token 1 with probability 100.0%, air probability 100.0%
Selected token 1 with probability 99.4%, air probability 99.4%
Selected token 62 with probability 

In [None]:
import json
import random
import time
from pathlib import Path

import requests

# The URL of the FastAPI streaming endpoint
url = 'http://127.0.0.1:8000/complete-structure/'

# Same air-removal step as the local generation: token 1 -> 0.
filled_structure = masked_structure.clone()
filled_structure[filled_structure == 1] = 0

# The input data to send to the server
input_data = {
    'temperature': 0.7,
    'structure': filled_structure.tolist()
}

start_time = time.perf_counter()

# Stream the response. The `with` block closes the connection even if parsing
# a line fails. Each non-empty line is a JSON object holding the predicted
# 'value' and its 'position' as [z, y, x].
with requests.post(url, data=json.dumps(input_data), stream=True) as response:
    if response.status_code != 200:
        # Abort rather than save the unfilled structure as if it were a result.
        raise RuntimeError(f"Error: {response.status_code}")
    for line in response.iter_lines():
        # Filter out keep-alive new lines
        if not line:
            continue
        json_data = json.loads(line.decode('utf-8'))
        z, y, x = json_data['position']
        filled_structure[z, y, x] = json_data['value']

end_time = time.perf_counter()
print(f'Request time: {end_time - start_time}')

# Convert the sample to the desired format using the provided function
filled_structure_schematic = schematic_array_converter.array_to_schematic(
    filled_structure)
filled_structure_schematic.name = 'Test'

# Derive the dataset name from the sampled `dataset` so this cell does not
# depend on a name set elsewhere.
# NOTE(review): assumes each val_datasets entry is (name, dataset, ...) — confirm.
dataset_name = next(
    (entry[0] for entry in data_module.val_datasets if entry[1] is dataset),
    'unknown',
)

# Save the sample, creating the viewer's output directory if needed.
output_dir = Path('schematic_viewer/public/schematics/filled')
output_dir.mkdir(parents=True, exist_ok=True)
filename = f'sample_epoch_{random.random()}_dataloader_{dataset_name}.schem'
filled_structure_schematic.save_to_file(output_dir / filename, 2)