In [None]:
# Load the trained structure generator from the `last.ckpt` checkpoint of a
# given experiment/version and put it in evaluation mode for inference.
from converter import SchematicArrayConverter
from modules import LightningTransformerMinecraftStructureGenerator

model_version = 10
experiment_name = 'real_run'
# NOTE(review): output_dir is not referenced by any cell visible here.
output_dir = 'schematic_viewer/public/schematics/'

# Path layout: lightning_logs/<experiment>/version_<n>/checkpoints/last.ckpt
checkpoint_path = (
    f'lightning_logs/{experiment_name}/version_{model_version}/checkpoints/last.ckpt'
)
model = LightningTransformerMinecraftStructureGenerator.load_from_checkpoint(checkpoint_path)
model.eval()

# NOTE(review): later cells use `schematic_array_converter` instead of this instance.
converter = SchematicArrayConverter()

In [None]:
# Build the data module over BASE_DIR/data.h5 and fetch its validation
# dataloaders, plus a converter for turning arrays into schematics.
import os

from common.file_paths import BASE_DIR
from converter import SchematicArrayConverter
from modules.data_module import MinecraftDataModule

hdf5_file = os.path.join(BASE_DIR, 'data.h5')

# (optionally pass num_workers=4 to the constructor)
data_module = MinecraftDataModule(file_path=hdf5_file, batch_size=1)
data_module.setup()

# Presumably one dataloader per validation dataset — the next cell indexes
# val_dataloaders and data_module.val_datasets in parallel.
val_dataloaders = data_module.val_dataloader()

schematic_array_converter = SchematicArrayConverter()

In [None]:
import random
from pathlib import Path

# Pick a random validation set. val_dataloaders and data_module.val_datasets
# are indexed in parallel, so only the index is needed. randrange draws from
# the RNG exactly as random.choice would, without the O(n) .index() lookup.
dataloader_index = random.randrange(len(val_dataloaders))
# NOTE(review): val_datasets entries look like (name, dataset) pairs; [1] is the dataset.
dataset = data_module.val_datasets[dataloader_index][1]

# Pick a random sample from that dataset (separate name so the dataset
# index above is not overwritten).
sample_index = random.randint(0, len(dataset) - 1)
full_structure, masked_structure = dataset[sample_index]

# Work on a copy so masked_structure itself stays untouched. Optional
# highlight: uncomment to paint masked (== 0) positions with block id 2173
# (intended as pink stained glass).
masked_structure_visual = masked_structure.clone()
# masked_structure_visual[masked_structure == 0] = 2173

masked_structure_schematic = schematic_array_converter.array_to_schematic(masked_structure_visual)
masked_structure_schematic.name = 'Test'

# Save the sample under a random file name, creating the directory if needed.
masked_dir = 'schematic_viewer/public/schematics/masked/'
os.makedirs(masked_dir, exist_ok=True)
filename = f'sample_{random.random()}.schem'
filepath = os.path.join(masked_dir, filename)
masked_structure_schematic.save_to_file(Path(filepath), 2)

In [None]:
import torch

# One-shot fill: feed the whole masked structure through the model in a single
# forward pass and take the highest-scoring token at every position.
input_tokens = masked_structure.clone().to(model.device).view(-1).unsqueeze(0)

# Inference only: no_grad avoids building an autograd graph for this pass.
with torch.no_grad():
    logits = model(input_tokens).squeeze(0)

# NOTE(review): argmax over dim=0 assumes the squeezed output is
# (vocab, positions) — confirm against the model's forward().
filled_structure = torch.argmax(logits, dim=0).view(masked_structure.size())
# Bring the result back to host memory before conversion; array_to_schematic
# may need a CPU tensor (e.g. for NumPy conversion) — TODO confirm.
filled_structure = filled_structure.cpu()

filled_structure_schematic = schematic_array_converter.array_to_schematic(
    filled_structure)
filled_structure_schematic.name = 'Test'

# Save the sample under a random file name, creating the directory if needed.
filled_dir = 'schematic_viewer/public/schematics/filled/'
os.makedirs(filled_dir, exist_ok=True)
filename = f'sample_{random.random()}.schem'
filepath = os.path.join(filled_dir, filename)
filled_structure_schematic.save_to_file(Path(filepath), 2)

In [None]:
import time

# Autoregressive fill: let the model predict positions one at a time via
# model.fill_structure, writing each prediction back into the structure.
filled_structure = masked_structure.clone()
# Replace every token equal to 1 with 0 before generation.
# NOTE(review): presumably 0 marks positions to fill (the sample cell treats
# `== 0` as masked) — confirm the meaning of token 1.
filled_structure[filled_structure == 1] = 0

# perf_counter is monotonic and meant for measuring elapsed time; time.time()
# can jump if the system clock is adjusted.
start_time = time.perf_counter()
# The same tensor passed to fill_structure is updated inside the loop.
for predicted_token, z, y, x in model.fill_structure(filled_structure, temperature=0.7, fill_order='random'):
    print(f'({z}, {y}, {x}): {predicted_token}')
    filled_structure[z, y, x] = predicted_token
end_time = time.perf_counter()
print(f'Generation time: {end_time - start_time}')

filled_structure_schematic = schematic_array_converter.array_to_schematic(
    filled_structure)
filled_structure_schematic.name = 'Test'

# Save the sample under a random file name, creating the directory if needed.
filled_dir = 'schematic_viewer/public/schematics/filled/'
os.makedirs(filled_dir, exist_ok=True)
filename = f'sample_{random.random()}.schem'
filepath = os.path.join(filled_dir, filename)
filled_structure_schematic.save_to_file(Path(filepath), 2)

In [None]:
import time
import json
import requests

# Fill the structure remotely: send it to the FastAPI streaming endpoint and
# apply each streamed {"value": ..., "position": [z, y, x]} line as it arrives.
filled_structure = masked_structure.clone()
filled_structure[filled_structure == 1] = 0

# The URL of the FastAPI streaming endpoint
url = 'http://127.0.0.1:8000/complete-structure/'

# The input data to send to the server
input_data = {
    'temperature': 0.7,
    'structure': filled_structure.tolist()
}

start_time = time.perf_counter()

request_ok = False
# json= serialises the payload and sets Content-Type: application/json.
# stream=True keeps the connection open until the body is consumed; the
# `with` block guarantees it is closed. Only the connect phase gets a timeout,
# because the server keeps streaming for as long as generation takes.
with requests.post(url, json=input_data, stream=True, timeout=(10, None)) as response:
    if response.status_code == 200:
        for line in response.iter_lines():
            # Skip keep-alive empty lines
            if not line:
                continue
            json_data = json.loads(line.decode('utf-8'))
            z, y, x = json_data['position']
            filled_structure[z, y, x] = json_data['value']
        request_ok = True
    else:
        print(f"Error: {response.status_code}")

end_time = time.perf_counter()
print(f'Request time: {end_time - start_time}')

# Only save when the server actually filled the structure.
if request_ok:
    filled_structure_schematic = schematic_array_converter.array_to_schematic(
        filled_structure)
    filled_structure_schematic.name = 'Test'

    # Save under a random file name (the previous name referenced an undefined
    # `dataset_name`, raising NameError), creating the directory if needed.
    filled_dir = 'schematic_viewer/public/schematics/filled/'
    os.makedirs(filled_dir, exist_ok=True)
    filename = f'sample_{random.random()}.schem'
    filepath = os.path.join(filled_dir, filename)
    filled_structure_schematic.save_to_file(Path(filepath), 2)