In [None]:
%load_ext autoreload
%autoreload 2

Generate data

In [None]:
import random

from common import constants
from schematic_generator import generator

configs = [
    # Simple shapes
    {
        "generator_type": ["shape"],
        "shape_type": ["sphere"],
        "radius": [lambda: random.randint(1, (constants.region_size[0] // 2) - 1)] * 5,
        "structure_block_types": [[block] for block in constants.simple_block_types] + [lambda: random.sample(constants.simple_block_types, 3)] * len(constants.simple_block_types),
        "background_block_types": [["minecraft:air"]],
        "position_offset": [lambda: (random.randint(-100, 100), random.randint(-100, 100), random.randint(-100, 100))],
        "random_seed": [lambda: random.randint(0, 2**32 - 1)],
        "region_size": [constants.region_size]
    },
    {
        "generator_type": ["shape"],
        "shape_type": ["cube"],
        "side_length": [lambda: random.randint(1, constants.region_size[0] - 1)] * 5,
        "structure_block_types": [[block] for block in constants.simple_block_types] + [lambda: random.sample(constants.simple_block_types, 3)] * len(constants.simple_block_types),
        "background_block_types": [["minecraft:air"]],
        "position_offset": [lambda: (random.randint(-100, 100), random.randint(-100, 100), random.randint(-100, 100))],
        "random_seed": [lambda: random.randint(0, 2**32 - 1)],
        "region_size": [constants.region_size]
    },
    # Filled
    {
        "generator_type": ["shape"],
        "shape_type": ["sphere"],
        "radius": [lambda: random.randint(3, (constants.region_size[0] // 2) - 1)] * 3,
        "structure_block_types": [[block] for block in constants.simple_block_types] + [lambda: random.sample(constants.simple_block_types, 3)] * (len(constants.simple_block_types) // 3),
        "structure_fill_block_types": [["minecraft:air"], lambda: random.sample(constants.simple_block_types, 1), lambda: random.sample(constants.simple_block_types, 3)],
        "thickness": [lambda: random.randint(1, 3)],
        "background_block_types": [["minecraft:air"]],
        "position_offset": [lambda: (random.randint(-100, 100), random.randint(-100, 100), random.randint(-100, 100))],
        "random_seed": [lambda: random.randint(0, 2**32 - 1)],
        "region_size": [constants.region_size]
    },
    {
        "generator_type": ["shape"],
        "shape_type": ["cube"],
        "side_length": [lambda: random.randint(7, constants.region_size[0] - 1)] * 3,
        "structure_block_types": [[block] for block in constants.simple_block_types] + [lambda: random.sample(constants.simple_block_types, 3)] * (len(constants.simple_block_types) // 3),
        "structure_fill_block_types": [["minecraft:air"], lambda: random.sample(constants.simple_block_types, 1), lambda: random.sample(constants.simple_block_types, 3)],
        "thickness": [lambda: random.randint(1, 3)],
        "background_block_types": [["minecraft:air"]],
        "position_offset": [lambda: (random.randint(-100, 100), random.randint(-100, 100), random.randint(-100, 100))],
        "random_seed": [lambda: random.randint(0, 2**32 - 1)],
        "region_size": [constants.region_size]
    }
]

generator.generate_samples_from_configurations(configs, dry_run=False)

Prepare data

In [None]:
from model.data_preparer import prepare_data

prepare_data()

In [None]:
import os

import h5py

from common.file_paths import TRAINING_DATA_DIR

with h5py.File(os.path.join(TRAINING_DATA_DIR, 'data.h5'), 'r') as hf:
    # Iterate over all groups in the root of the file
    for group_name in hf:
        group = hf[group_name]
        
        # Check if both "input_embedding" and "target_tensor" datasets exist in the current group
        if "input_embedding" in group and "target_tensor" in group:
            input_dataset = group["input_embedding"]
            target_dataset = group["target_tensor"]
            
            # Optionally, print their shapes to verify
            print(f"Group '{group_name}':")
            print(f"  Input Embedding Shape: {input_dataset.shape}")
            print(f"  Target Tensor Shape: {target_dataset.shape}")
        else:
            # If either dataset is missing, print a warning
            print(f"Warning: Group '{group_name}' is missing one or both of the required datasets.")
    
    # Optionally, print the total number of groups in the file
    print(f"Total Groups: {len(hf)}")

Train

In [1]:
import os
import re
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from common.file_paths import TRAINING_DATA_DIR
from model.dataset import MinecraftDataset
from model.model import MinecraftStructureGenerator

random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

experiment_name = 'test14'

INPUT_EMBEDDING_SIZE = 1536
NUM_CLASSES = 345
OUTPUT_SIZE = [64, 64, 64]
model = MinecraftStructureGenerator(INPUT_EMBEDDING_SIZE, NUM_CLASSES, OUTPUT_SIZE)
print(model)
print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=10, verbose=True)

# Define the directory where checkpoints are saved
checkpoint_dir = f'checkpoints/{experiment_name}'
os.makedirs(checkpoint_dir, exist_ok=True)

# List all checkpoint files
checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.startswith('checkpoint_') and f.endswith('.pth')]

# Extract epochs from file names and sort them
epochs = [int(re.search(r'checkpoint_(\d+).pth', f).group(1)) for f in checkpoint_files]
latest_epoch = max(epochs, default=0)  # Use default=0 to handle the case when the list is empty

# Load the latest checkpoint
if epochs:  # Check if the list is not empty
    latest_checkpoint_file = f'checkpoint_{latest_epoch}.pth'
    checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file)
    checkpoint = torch.load(checkpoint_path)

    # Restore the model and optimizer state
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # If you saved the epoch number and loss in the checkpoint, you can restore these too
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    print(f"Checkpoint '{latest_checkpoint_file}' loaded. Resuming training from epoch {epoch}.")
else:
    print("No checkpoints found. Starting training from scratch.")
    epoch = 0  # Start from the first epoch
    loss = None  # Loss will be initialized during training

    # Initialize the weights of the model
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            # He initialization for layers with ReLU activation
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # Default initialization for batch normalization
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            # Xavier initialization for fully connected layers
            nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0)

# Load the dataset from the HDF5 file
hdf5_file = os.path.join(TRAINING_DATA_DIR, 'data.h5')

# Assuming you have already defined your complete dataset
dataset = MinecraftDataset(hdf5_file)

# # Define the lengths of the splits - 80% training, 10% validation, and 10% testing
# train_size = int(0.8 * len(dataset))
# val_test_size = len(dataset) - train_size
# val_size = int(0.5 * val_test_size)
# test_size = val_test_size - val_size

# # Randomly split the dataset into training, validation, and testing
# train_dataset, val_test_dataset = random_split(dataset, [train_size, val_test_size])
# val_dataset, test_dataset = random_split(val_test_dataset, [val_size, test_size])

train_dataset, val_dataset, test_dataset, _ = random_split(dataset, [100, 100, 100, len(dataset) - 300])

# Create DataLoaders for all datasets
BATCH_SIZE = 10
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Print 1 batch of training data
for inputs, targets, name in train_dataloader:
    print(f"Inputs Shape: {inputs.shape}")
    print(f"Targets Shape: {targets.shape}")
    print(f"Name: {name}")
    break

# Log the model architecture to TensorBoard
example_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
example_inputs, _, _ = next(iter(example_dataloader))
example_inputs = example_inputs.to(device)
writer = SummaryWriter(f'runs/{experiment_name}')
writer.add_graph(model, example_inputs)

# Move the example inputs back to CPU to free up GPU memory
example_inputs = example_inputs.to('cpu')
del example_inputs  # Explicitly delete the example inputs
torch.cuda.empty_cache()  # Clear the cache if necessary

# Training loop
NUM_EPOCHS = 5000
for epoch in range(epoch, NUM_EPOCHS):
    model.train()
    train_loss = 0.0
    val_loss = 0.0

    train_bar = tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}/{NUM_EPOCHS}")
    for inputs, targets, prompt in train_bar:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)

        # Bring the classes dimension to the end before flattening
        loss = F.cross_entropy(logits, targets)
        train_loss += loss.item()

        loss.backward()

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()

        # Update the progress bar with the latest loss
        train_bar.set_postfix(loss=loss.item())

        # Log the loss to TensorBoard
        writer.add_scalar('Loss/train', loss, epoch)

    # Calculate average training loss for the epoch
    train_loss /= len(train_dataloader)

    scheduler.step(loss)

    # Validation loop
    val_bar = tqdm(val_dataloader, desc=f"Validation Epoch {epoch + 1}/{NUM_EPOCHS}")
    model.eval()
    with torch.no_grad():
        for inputs, targets, _ in val_bar:
            inputs, targets = inputs.to(device), targets.to(device)

            logits = model(inputs)

            loss = F.cross_entropy(logits, targets)
            val_loss += loss.item()

            # Update the progress bar with the latest loss
            val_bar.set_postfix(loss=loss.item())

            # Log the loss to TensorBoard
            writer.add_scalar('Loss/val', loss, epoch)

    # Calculate average validation loss for the epoch
    val_loss /= len(val_dataloader)

    # Print training and validation loss
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

    # Save checkpoint
    if epoch % 5 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, os.path.join(checkpoint_dir, f"checkpoint_{epoch}.pth"))

writer.close()

MinecraftStructureGenerator(
  (fc1): Linear(in_features=1536, out_features=65536, bias=True)
  (bn1): BatchNorm1d(65536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv1): Conv3d(128, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv3d(64, 32, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(3, 3, 3))
  (bn3): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv3d(32, 345, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(5, 5, 5))
)
Number of parameters: 105124441
No checkpoints found. Starting training from scratch.
Inputs Shape: torch.Size([10, 1536])
Targets Shape: torch.Size([10, 64, 64, 64])
Name: ('72ccf061e5200bb34100ca424e78a005a0f2a318acdf9c8b6c608eac680686f4', 'f30f5603fe2f9ba393ddf4974f8ae2c0831402f0158b4502d452c953bc2ad3e5', 'a4fc69d7ef4898a6e123e5b0426537af30ae336441b0dc036505e94b04bba4f

Training Epoch 1/5000: 100%|██████████| 10/10 [00:03<00:00,  2.88it/s, loss=1.46]
Validation Epoch 1/5000: 100%|██████████| 10/10 [00:01<00:00,  8.00it/s, loss=2.56]


Epoch 1/5000, Training Loss: 2.9716, Validation Loss: 2.6701


Training Epoch 2/5000: 100%|██████████| 10/10 [00:02<00:00,  3.58it/s, loss=1.63]
Validation Epoch 2/5000: 100%|██████████| 10/10 [00:01<00:00,  9.98it/s, loss=2.81]


Epoch 2/5000, Training Loss: 1.2910, Validation Loss: 3.0719


Training Epoch 3/5000: 100%|██████████| 10/10 [00:02<00:00,  3.65it/s, loss=2.21]
Validation Epoch 3/5000: 100%|██████████| 10/10 [00:01<00:00,  9.62it/s, loss=3.11]


Epoch 3/5000, Training Loss: 1.1455, Validation Loss: 3.5490


Training Epoch 4/5000: 100%|██████████| 10/10 [00:02<00:00,  3.66it/s, loss=0.647]
Validation Epoch 4/5000: 100%|██████████| 10/10 [00:01<00:00,  9.88it/s, loss=3.27]


Epoch 4/5000, Training Loss: 1.0479, Validation Loss: 3.8092


Training Epoch 5/5000: 100%|██████████| 10/10 [00:02<00:00,  3.59it/s, loss=0.984]
Validation Epoch 5/5000: 100%|██████████| 10/10 [00:01<00:00,  9.80it/s, loss=3.75]


Epoch 5/5000, Training Loss: 0.9741, Validation Loss: 4.3110


Training Epoch 6/5000: 100%|██████████| 10/10 [00:02<00:00,  3.50it/s, loss=1.1] 
Validation Epoch 6/5000: 100%|██████████| 10/10 [00:01<00:00,  9.56it/s, loss=3.4]


Epoch 6/5000, Training Loss: 0.8966, Validation Loss: 4.0911


Training Epoch 7/5000: 100%|██████████| 10/10 [00:02<00:00,  3.51it/s, loss=0.672]
Validation Epoch 7/5000: 100%|██████████| 10/10 [00:01<00:00,  9.80it/s, loss=3.64]


Epoch 7/5000, Training Loss: 0.8360, Validation Loss: 4.3249


Training Epoch 8/5000: 100%|██████████| 10/10 [00:02<00:00,  3.58it/s, loss=0.617]
Validation Epoch 8/5000: 100%|██████████| 10/10 [00:01<00:00,  9.31it/s, loss=3.33]


Epoch 8/5000, Training Loss: 0.7891, Validation Loss: 3.8551


Training Epoch 9/5000: 100%|██████████| 10/10 [00:02<00:00,  3.58it/s, loss=0.78]
Validation Epoch 9/5000: 100%|██████████| 10/10 [00:01<00:00,  9.71it/s, loss=3.74]


Epoch 9/5000, Training Loss: 0.7654, Validation Loss: 4.3297


Training Epoch 10/5000: 100%|██████████| 10/10 [00:02<00:00,  3.63it/s, loss=1.13]
Validation Epoch 10/5000: 100%|██████████| 10/10 [00:01<00:00,  9.81it/s, loss=3.3]


Epoch 10/5000, Training Loss: 0.7379, Validation Loss: 3.5850


Training Epoch 11/5000: 100%|██████████| 10/10 [00:02<00:00,  3.56it/s, loss=0.433]
Validation Epoch 11/5000: 100%|██████████| 10/10 [00:01<00:00,  9.86it/s, loss=3.22]


Epoch 11/5000, Training Loss: 0.7417, Validation Loss: 3.3053


Training Epoch 12/5000: 100%|██████████| 10/10 [00:02<00:00,  3.59it/s, loss=0.377]
Validation Epoch 12/5000: 100%|██████████| 10/10 [00:01<00:00,  9.71it/s, loss=3.26]


Epoch 12/5000, Training Loss: 0.6931, Validation Loss: 3.5370


Training Epoch 13/5000: 100%|██████████| 10/10 [00:02<00:00,  3.51it/s, loss=0.663]
Validation Epoch 13/5000: 100%|██████████| 10/10 [00:01<00:00,  9.53it/s, loss=3.3]


Epoch 13/5000, Training Loss: 0.6692, Validation Loss: 3.4366


Training Epoch 14/5000: 100%|██████████| 10/10 [00:02<00:00,  3.59it/s, loss=0.549]
Validation Epoch 14/5000: 100%|██████████| 10/10 [00:01<00:00,  9.69it/s, loss=3.49]


Epoch 14/5000, Training Loss: 0.6258, Validation Loss: 3.8715


Training Epoch 15/5000: 100%|██████████| 10/10 [00:02<00:00,  3.44it/s, loss=0.636]
Validation Epoch 15/5000: 100%|██████████| 10/10 [00:01<00:00,  9.59it/s, loss=3.55]


Epoch 15/5000, Training Loss: 0.6883, Validation Loss: 3.6693


Training Epoch 16/5000: 100%|██████████| 10/10 [00:02<00:00,  3.63it/s, loss=0.588]
Validation Epoch 16/5000: 100%|██████████| 10/10 [00:01<00:00,  9.66it/s, loss=3.68]


Epoch 16/5000, Training Loss: 0.6047, Validation Loss: 3.8847


Training Epoch 17/5000: 100%|██████████| 10/10 [00:02<00:00,  3.55it/s, loss=0.544]
Validation Epoch 17/5000: 100%|██████████| 10/10 [00:01<00:00,  9.90it/s, loss=3.38]


Epoch 17/5000, Training Loss: 0.6268, Validation Loss: 3.5772


Training Epoch 18/5000: 100%|██████████| 10/10 [00:02<00:00,  3.51it/s, loss=0.299]
Validation Epoch 18/5000: 100%|██████████| 10/10 [00:01<00:00,  9.71it/s, loss=3.69]


Epoch 18/5000, Training Loss: 0.5819, Validation Loss: 3.8064


Training Epoch 19/5000: 100%|██████████| 10/10 [00:02<00:00,  3.50it/s, loss=0.791]
Validation Epoch 19/5000: 100%|██████████| 10/10 [00:01<00:00,  9.76it/s, loss=3.47]


Epoch 19/5000, Training Loss: 0.5089, Validation Loss: 3.9202


Training Epoch 20/5000: 100%|██████████| 10/10 [00:02<00:00,  3.53it/s, loss=0.399]
Validation Epoch 20/5000: 100%|██████████| 10/10 [00:01<00:00,  9.72it/s, loss=4.43]


Epoch 20/5000, Training Loss: 0.4934, Validation Loss: 4.5824


Training Epoch 21/5000: 100%|██████████| 10/10 [00:02<00:00,  3.58it/s, loss=0.719]
Validation Epoch 21/5000: 100%|██████████| 10/10 [00:01<00:00,  9.71it/s, loss=3.75]


Epoch 21/5000, Training Loss: 0.4620, Validation Loss: 3.8315


Training Epoch 22/5000: 100%|██████████| 10/10 [00:02<00:00,  3.44it/s, loss=0.641]
Validation Epoch 22/5000: 100%|██████████| 10/10 [00:01<00:00,  9.63it/s, loss=3.54]


Epoch 22/5000, Training Loss: 0.4731, Validation Loss: 3.9110


Training Epoch 23/5000: 100%|██████████| 10/10 [00:02<00:00,  3.63it/s, loss=0.415]
Validation Epoch 23/5000: 100%|██████████| 10/10 [00:01<00:00,  9.63it/s, loss=3.39]


Epoch 23/5000, Training Loss: 0.4378, Validation Loss: 4.1047


Training Epoch 24/5000: 100%|██████████| 10/10 [00:02<00:00,  3.56it/s, loss=0.496]
Validation Epoch 24/5000: 100%|██████████| 10/10 [00:01<00:00,  9.67it/s, loss=3.92]


Epoch 24/5000, Training Loss: 0.3987, Validation Loss: 4.0567


Training Epoch 25/5000: 100%|██████████| 10/10 [00:02<00:00,  3.51it/s, loss=0.448]
Validation Epoch 25/5000: 100%|██████████| 10/10 [00:01<00:00,  9.50it/s, loss=3.61]


Epoch 25/5000, Training Loss: 0.4405, Validation Loss: 4.1252


Training Epoch 26/5000: 100%|██████████| 10/10 [00:02<00:00,  3.49it/s, loss=0.559]
Validation Epoch 26/5000: 100%|██████████| 10/10 [00:01<00:00,  9.65it/s, loss=4.04]


Epoch 26/5000, Training Loss: 0.3629, Validation Loss: 4.5377


Training Epoch 27/5000: 100%|██████████| 10/10 [00:02<00:00,  3.39it/s, loss=0.527]
Validation Epoch 27/5000: 100%|██████████| 10/10 [00:01<00:00,  9.87it/s, loss=3.47]


Epoch 27/5000, Training Loss: 0.3533, Validation Loss: 4.2347


Training Epoch 28/5000: 100%|██████████| 10/10 [00:02<00:00,  3.60it/s, loss=0.468]
Validation Epoch 28/5000: 100%|██████████| 10/10 [00:01<00:00,  9.57it/s, loss=3.98]


Epoch 28/5000, Training Loss: 0.3461, Validation Loss: 4.6341


Training Epoch 29/5000: 100%|██████████| 10/10 [00:02<00:00,  3.55it/s, loss=0.155]
Validation Epoch 29/5000: 100%|██████████| 10/10 [00:01<00:00,  9.50it/s, loss=3.94]


Epoch 29/5000, Training Loss: 0.4515, Validation Loss: 4.4243


Training Epoch 30/5000: 100%|██████████| 10/10 [00:02<00:00,  3.51it/s, loss=0.51]
Validation Epoch 30/5000: 100%|██████████| 10/10 [00:01<00:00,  9.76it/s, loss=3.94]


Epoch 30/5000, Training Loss: 0.3258, Validation Loss: 4.4223


Training Epoch 31/5000: 100%|██████████| 10/10 [00:02<00:00,  3.57it/s, loss=0.603]
Validation Epoch 31/5000: 100%|██████████| 10/10 [00:01<00:00,  9.75it/s, loss=3.59]


Epoch 31/5000, Training Loss: 0.3111, Validation Loss: 4.4154


Training Epoch 32/5000: 100%|██████████| 10/10 [00:02<00:00,  3.50it/s, loss=0.223]
Validation Epoch 32/5000: 100%|██████████| 10/10 [00:01<00:00,  9.71it/s, loss=3.82]


Epoch 32/5000, Training Loss: 0.2773, Validation Loss: 4.5723


Training Epoch 33/5000: 100%|██████████| 10/10 [00:02<00:00,  3.56it/s, loss=0.219]
Validation Epoch 33/5000: 100%|██████████| 10/10 [00:01<00:00,  9.92it/s, loss=4.01]


Epoch 33/5000, Training Loss: 0.2890, Validation Loss: 4.8078


Training Epoch 34/5000: 100%|██████████| 10/10 [00:02<00:00,  3.42it/s, loss=0.233]
Validation Epoch 34/5000: 100%|██████████| 10/10 [00:01<00:00,  9.81it/s, loss=3.97]


Epoch 34/5000, Training Loss: 0.2891, Validation Loss: 4.6840


Training Epoch 35/5000: 100%|██████████| 10/10 [00:02<00:00,  3.60it/s, loss=0.231]
Validation Epoch 35/5000: 100%|██████████| 10/10 [00:01<00:00,  9.64it/s, loss=3.89]


Epoch 35/5000, Training Loss: 0.2461, Validation Loss: 4.7446


Training Epoch 36/5000: 100%|██████████| 10/10 [00:02<00:00,  3.58it/s, loss=0.533]
Validation Epoch 36/5000: 100%|██████████| 10/10 [00:01<00:00,  9.88it/s, loss=3.87]


Epoch 36/5000, Training Loss: 0.2596, Validation Loss: 4.8397


Training Epoch 37/5000: 100%|██████████| 10/10 [00:02<00:00,  3.44it/s, loss=0.273]
Validation Epoch 37/5000: 100%|██████████| 10/10 [00:01<00:00,  9.69it/s, loss=3.87]


Epoch 37/5000, Training Loss: 0.2474, Validation Loss: 4.3835


Training Epoch 38/5000: 100%|██████████| 10/10 [00:02<00:00,  3.57it/s, loss=0.328]
Validation Epoch 38/5000: 100%|██████████| 10/10 [00:01<00:00,  9.58it/s, loss=3.92]


Epoch 38/5000, Training Loss: 0.2669, Validation Loss: 4.6041


Training Epoch 39/5000: 100%|██████████| 10/10 [00:02<00:00,  3.53it/s, loss=0.354]
Validation Epoch 39/5000: 100%|██████████| 10/10 [00:01<00:00,  9.61it/s, loss=3.74]


Epoch 39/5000, Training Loss: 0.2573, Validation Loss: 4.4246


Training Epoch 40/5000: 100%|██████████| 10/10 [00:02<00:00,  3.49it/s, loss=0.155]
Validation Epoch 40/5000: 100%|██████████| 10/10 [00:01<00:00,  9.78it/s, loss=3.58]


Epoch 40/5000, Training Loss: 0.2949, Validation Loss: 4.4688


Training Epoch 41/5000: 100%|██████████| 10/10 [00:02<00:00,  3.48it/s, loss=0.222]
Validation Epoch 41/5000: 100%|██████████| 10/10 [00:01<00:00,  9.26it/s, loss=4.47]


Epoch 41/5000, Training Loss: 0.2323, Validation Loss: 4.9435


Training Epoch 42/5000: 100%|██████████| 10/10 [00:02<00:00,  3.38it/s, loss=0.544]
Validation Epoch 42/5000: 100%|██████████| 10/10 [00:01<00:00,  9.73it/s, loss=4.42]


Epoch 42/5000, Training Loss: 0.2670, Validation Loss: 4.7078


Training Epoch 43/5000: 100%|██████████| 10/10 [00:02<00:00,  3.46it/s, loss=0.132]
Validation Epoch 43/5000: 100%|██████████| 10/10 [00:01<00:00,  9.59it/s, loss=3.81]


Epoch 43/5000, Training Loss: 0.2689, Validation Loss: 4.4533


Training Epoch 44/5000: 100%|██████████| 10/10 [00:02<00:00,  3.57it/s, loss=0.305]
Validation Epoch 44/5000: 100%|██████████| 10/10 [00:01<00:00,  9.75it/s, loss=4.05]


Epoch 44/5000, Training Loss: 0.2235, Validation Loss: 4.9466


Training Epoch 45/5000: 100%|██████████| 10/10 [00:02<00:00,  3.55it/s, loss=0.179]
Validation Epoch 45/5000: 100%|██████████| 10/10 [00:01<00:00,  9.81it/s, loss=3.96]


Epoch 45/5000, Training Loss: 0.2184, Validation Loss: 4.5350


Training Epoch 46/5000: 100%|██████████| 10/10 [00:02<00:00,  3.60it/s, loss=0.0935]
Validation Epoch 46/5000: 100%|██████████| 10/10 [00:01<00:00,  9.41it/s, loss=4]  


Epoch 46/5000, Training Loss: 0.2104, Validation Loss: 4.6957


Training Epoch 47/5000: 100%|██████████| 10/10 [00:02<00:00,  3.49it/s, loss=0.266]
Validation Epoch 47/5000: 100%|██████████| 10/10 [00:01<00:00,  9.82it/s, loss=4]  


Epoch 47/5000, Training Loss: 0.2026, Validation Loss: 4.8916


Training Epoch 48/5000: 100%|██████████| 10/10 [00:02<00:00,  3.45it/s, loss=0.248]
Validation Epoch 48/5000: 100%|██████████| 10/10 [00:01<00:00,  9.59it/s, loss=3.82]


Epoch 48/5000, Training Loss: 0.1878, Validation Loss: 4.8700


Training Epoch 49/5000: 100%|██████████| 10/10 [00:02<00:00,  3.56it/s, loss=0.161]
Validation Epoch 49/5000: 100%|██████████| 10/10 [00:01<00:00,  9.49it/s, loss=3.82]


Epoch 49/5000, Training Loss: 0.1823, Validation Loss: 4.9030


Training Epoch 50/5000: 100%|██████████| 10/10 [00:02<00:00,  3.37it/s, loss=0.2]  
Validation Epoch 50/5000: 100%|██████████| 10/10 [00:01<00:00,  9.55it/s, loss=3.75]


Epoch 50/5000, Training Loss: 0.1805, Validation Loss: 4.8570


Training Epoch 51/5000: 100%|██████████| 10/10 [00:02<00:00,  3.51it/s, loss=0.351]
Validation Epoch 51/5000: 100%|██████████| 10/10 [00:01<00:00,  9.35it/s, loss=4.01]


Epoch 51/5000, Training Loss: 0.1909, Validation Loss: 4.9619


Training Epoch 52/5000: 100%|██████████| 10/10 [00:02<00:00,  3.46it/s, loss=0.106]
Validation Epoch 52/5000: 100%|██████████| 10/10 [00:01<00:00,  9.29it/s, loss=3.88]


Epoch 52/5000, Training Loss: 0.1798, Validation Loss: 4.7907


Training Epoch 53/5000: 100%|██████████| 10/10 [00:03<00:00,  3.16it/s, loss=0.24] 
Validation Epoch 53/5000: 100%|██████████| 10/10 [00:01<00:00,  8.83it/s, loss=3.94]


Epoch 53/5000, Training Loss: 0.1853, Validation Loss: 4.7774


Training Epoch 54/5000: 100%|██████████| 10/10 [00:03<00:00,  3.19it/s, loss=0.163]
Validation Epoch 54/5000: 100%|██████████| 10/10 [00:01<00:00,  8.85it/s, loss=4.3]


Epoch 54/5000, Training Loss: 0.1719, Validation Loss: 4.8271


Training Epoch 55/5000: 100%|██████████| 10/10 [00:03<00:00,  3.18it/s, loss=0.108]
Validation Epoch 55/5000: 100%|██████████| 10/10 [00:01<00:00,  8.66it/s, loss=4]  


Epoch 55/5000, Training Loss: 0.1716, Validation Loss: 4.7616


Training Epoch 56/5000: 100%|██████████| 10/10 [00:03<00:00,  3.20it/s, loss=0.126]
Validation Epoch 56/5000: 100%|██████████| 10/10 [00:01<00:00,  8.79it/s, loss=3.72]


Epoch 56/5000, Training Loss: 0.1679, Validation Loss: 4.5350


Training Epoch 57/5000: 100%|██████████| 10/10 [00:03<00:00,  3.18it/s, loss=0.148]


Epoch 00057: reducing learning rate of group 0 to 1.0000e-04.


Validation Epoch 57/5000: 100%|██████████| 10/10 [00:01<00:00,  8.81it/s, loss=3.87]


Epoch 57/5000, Training Loss: 0.1557, Validation Loss: 4.7548


Training Epoch 58/5000: 100%|██████████| 10/10 [00:03<00:00,  3.30it/s, loss=0.217]
Validation Epoch 58/5000: 100%|██████████| 10/10 [00:01<00:00,  9.53it/s, loss=3.88]


Epoch 58/5000, Training Loss: 0.1612, Validation Loss: 4.7678


Training Epoch 59/5000: 100%|██████████| 10/10 [00:02<00:00,  3.53it/s, loss=0.0877]
Validation Epoch 59/5000: 100%|██████████| 10/10 [00:01<00:00,  9.77it/s, loss=3.87]


Epoch 59/5000, Training Loss: 0.1426, Validation Loss: 4.7670


Training Epoch 60/5000: 100%|██████████| 10/10 [00:02<00:00,  3.44it/s, loss=0.0526]
Validation Epoch 60/5000: 100%|██████████| 10/10 [00:01<00:00,  9.58it/s, loss=3.87]


Epoch 60/5000, Training Loss: 0.1381, Validation Loss: 4.7857


Training Epoch 61/5000: 100%|██████████| 10/10 [00:02<00:00,  3.41it/s, loss=0.0919]
Validation Epoch 61/5000: 100%|██████████| 10/10 [00:01<00:00,  8.68it/s, loss=3.89]


Epoch 61/5000, Training Loss: 0.1343, Validation Loss: 4.8159


Training Epoch 62/5000: 100%|██████████| 10/10 [00:02<00:00,  3.38it/s, loss=0.187]
Validation Epoch 62/5000: 100%|██████████| 10/10 [00:01<00:00,  9.64it/s, loss=3.86]


Epoch 62/5000, Training Loss: 0.1324, Validation Loss: 4.7537


Training Epoch 63/5000: 100%|██████████| 10/10 [00:02<00:00,  3.56it/s, loss=0.147]
Validation Epoch 63/5000: 100%|██████████| 10/10 [00:01<00:00,  9.35it/s, loss=3.93]


Epoch 63/5000, Training Loss: 0.1340, Validation Loss: 4.8520


Training Epoch 64/5000: 100%|██████████| 10/10 [00:02<00:00,  3.46it/s, loss=0.179]
Validation Epoch 64/5000: 100%|██████████| 10/10 [00:01<00:00,  9.11it/s, loss=3.92]


Epoch 64/5000, Training Loss: 0.1302, Validation Loss: 4.8015


Training Epoch 65/5000: 100%|██████████| 10/10 [00:02<00:00,  3.37it/s, loss=0.0969]
Validation Epoch 65/5000: 100%|██████████| 10/10 [00:01<00:00,  9.42it/s, loss=3.94]


Epoch 65/5000, Training Loss: 0.1247, Validation Loss: 4.8628


Training Epoch 66/5000: 100%|██████████| 10/10 [00:02<00:00,  3.48it/s, loss=0.154]
Validation Epoch 66/5000: 100%|██████████| 10/10 [00:01<00:00,  9.58it/s, loss=3.95]


Epoch 66/5000, Training Loss: 0.1247, Validation Loss: 4.8768


Training Epoch 67/5000: 100%|██████████| 10/10 [00:02<00:00,  3.42it/s, loss=0.0568]
Validation Epoch 67/5000: 100%|██████████| 10/10 [00:01<00:00,  9.71it/s, loss=4.02]


Epoch 67/5000, Training Loss: 0.1240, Validation Loss: 4.9323


Training Epoch 68/5000: 100%|██████████| 10/10 [00:02<00:00,  3.42it/s, loss=0.143]
Validation Epoch 68/5000: 100%|██████████| 10/10 [00:01<00:00,  9.49it/s, loss=4.03]


Epoch 68/5000, Training Loss: 0.1245, Validation Loss: 4.9955


Training Epoch 69/5000: 100%|██████████| 10/10 [00:02<00:00,  3.41it/s, loss=0.154]
Validation Epoch 69/5000: 100%|██████████| 10/10 [00:01<00:00,  9.62it/s, loss=4.04]


Epoch 69/5000, Training Loss: 0.1201, Validation Loss: 4.9787


Training Epoch 70/5000: 100%|██████████| 10/10 [00:02<00:00,  3.56it/s, loss=0.0402]
Validation Epoch 70/5000: 100%|██████████| 10/10 [00:01<00:00,  9.60it/s, loss=4.05]


Epoch 70/5000, Training Loss: 0.1265, Validation Loss: 4.9926


Training Epoch 71/5000: 100%|██████████| 10/10 [00:02<00:00,  3.46it/s, loss=0.248]
Validation Epoch 71/5000: 100%|██████████| 10/10 [00:01<00:00,  9.52it/s, loss=4.07]


Epoch 71/5000, Training Loss: 0.1248, Validation Loss: 5.0252


Training Epoch 72/5000: 100%|██████████| 10/10 [00:02<00:00,  3.46it/s, loss=0.102]
Validation Epoch 72/5000: 100%|██████████| 10/10 [00:01<00:00,  9.74it/s, loss=4.08]


Epoch 72/5000, Training Loss: 0.1244, Validation Loss: 5.0264


Training Epoch 73/5000: 100%|██████████| 10/10 [00:02<00:00,  3.53it/s, loss=0.13] 
Validation Epoch 73/5000: 100%|██████████| 10/10 [00:01<00:00,  9.61it/s, loss=4.07]


Epoch 73/5000, Training Loss: 0.1233, Validation Loss: 5.0436


Training Epoch 74/5000: 100%|██████████| 10/10 [00:02<00:00,  3.53it/s, loss=0.119]
Validation Epoch 74/5000: 100%|██████████| 10/10 [00:01<00:00,  9.06it/s, loss=4.08]


Epoch 74/5000, Training Loss: 0.1211, Validation Loss: 5.0726


Training Epoch 75/5000: 100%|██████████| 10/10 [00:02<00:00,  3.43it/s, loss=0.0351]
Validation Epoch 75/5000: 100%|██████████| 10/10 [00:01<00:00,  9.66it/s, loss=4.08]


Epoch 75/5000, Training Loss: 0.1225, Validation Loss: 5.0616


Training Epoch 76/5000: 100%|██████████| 10/10 [00:02<00:00,  3.54it/s, loss=0.302]
Validation Epoch 76/5000: 100%|██████████| 10/10 [00:01<00:00,  9.60it/s, loss=4.12]


Epoch 76/5000, Training Loss: 0.1276, Validation Loss: 5.1267


Training Epoch 77/5000: 100%|██████████| 10/10 [00:02<00:00,  3.38it/s, loss=0.105]
Validation Epoch 77/5000: 100%|██████████| 10/10 [00:01<00:00,  9.83it/s, loss=4.15]


Epoch 77/5000, Training Loss: 0.1194, Validation Loss: 5.0817


Training Epoch 78/5000: 100%|██████████| 10/10 [00:02<00:00,  3.53it/s, loss=0.0977]
Validation Epoch 78/5000: 100%|██████████| 10/10 [00:01<00:00,  9.75it/s, loss=4.22]


Epoch 78/5000, Training Loss: 0.1244, Validation Loss: 5.1433


Training Epoch 79/5000: 100%|██████████| 10/10 [00:03<00:00,  3.28it/s, loss=0.108]
Validation Epoch 79/5000: 100%|██████████| 10/10 [00:01<00:00,  9.62it/s, loss=4.21]


Epoch 79/5000, Training Loss: 0.1163, Validation Loss: 5.1433


Training Epoch 80/5000: 100%|██████████| 10/10 [00:02<00:00,  3.54it/s, loss=0.195]
Validation Epoch 80/5000: 100%|██████████| 10/10 [00:01<00:00,  9.74it/s, loss=4.24]


Epoch 80/5000, Training Loss: 0.1193, Validation Loss: 5.1791


Training Epoch 81/5000: 100%|██████████| 10/10 [00:02<00:00,  3.40it/s, loss=0.145]
Validation Epoch 81/5000: 100%|██████████| 10/10 [00:01<00:00,  9.78it/s, loss=4.22]


Epoch 81/5000, Training Loss: 0.1182, Validation Loss: 5.1320


Training Epoch 82/5000: 100%|██████████| 10/10 [00:03<00:00,  3.32it/s, loss=0.147]
Validation Epoch 82/5000: 100%|██████████| 10/10 [00:01<00:00,  9.55it/s, loss=4.24]


Epoch 82/5000, Training Loss: 0.1153, Validation Loss: 5.1854


Training Epoch 83/5000: 100%|██████████| 10/10 [00:02<00:00,  3.54it/s, loss=0.0886]
Validation Epoch 83/5000: 100%|██████████| 10/10 [00:01<00:00,  9.82it/s, loss=4.25]


Epoch 83/5000, Training Loss: 0.1185, Validation Loss: 5.1678


Training Epoch 84/5000: 100%|██████████| 10/10 [00:02<00:00,  3.54it/s, loss=0.128]
Validation Epoch 84/5000: 100%|██████████| 10/10 [00:01<00:00,  9.17it/s, loss=4.28]


Epoch 84/5000, Training Loss: 0.1190, Validation Loss: 5.2466


Training Epoch 85/5000: 100%|██████████| 10/10 [00:02<00:00,  3.46it/s, loss=0.0758]
Validation Epoch 85/5000: 100%|██████████| 10/10 [00:01<00:00,  9.31it/s, loss=4.25]


Epoch 85/5000, Training Loss: 0.1175, Validation Loss: 5.1981


Training Epoch 86/5000: 100%|██████████| 10/10 [00:02<00:00,  3.43it/s, loss=0.148]


Epoch 00086: reducing learning rate of group 0 to 1.0000e-05.


Validation Epoch 86/5000: 100%|██████████| 10/10 [00:01<00:00,  8.97it/s, loss=4.27]


Epoch 86/5000, Training Loss: 0.1153, Validation Loss: 5.2005


Training Epoch 87/5000: 100%|██████████| 10/10 [00:02<00:00,  3.37it/s, loss=0.0492]
Validation Epoch 87/5000: 100%|██████████| 10/10 [00:01<00:00,  9.21it/s, loss=4.28]


Epoch 87/5000, Training Loss: 0.1190, Validation Loss: 5.2261


Training Epoch 88/5000: 100%|██████████| 10/10 [00:02<00:00,  3.54it/s, loss=0.13] 
Validation Epoch 88/5000: 100%|██████████| 10/10 [00:01<00:00,  9.91it/s, loss=4.27]


Epoch 88/5000, Training Loss: 0.1166, Validation Loss: 5.2002


Training Epoch 89/5000: 100%|██████████| 10/10 [00:02<00:00,  3.39it/s, loss=0.0621]
Validation Epoch 89/5000: 100%|██████████| 10/10 [00:01<00:00,  9.51it/s, loss=4.27]


Epoch 89/5000, Training Loss: 0.1194, Validation Loss: 5.2099


Training Epoch 90/5000: 100%|██████████| 10/10 [00:03<00:00,  3.33it/s, loss=0.191]
Validation Epoch 90/5000: 100%|██████████| 10/10 [00:01<00:00,  9.21it/s, loss=4.3]


Epoch 90/5000, Training Loss: 0.1177, Validation Loss: 5.2637


Training Epoch 91/5000: 100%|██████████| 10/10 [00:02<00:00,  3.48it/s, loss=0.0636]
Validation Epoch 91/5000: 100%|██████████| 10/10 [00:01<00:00,  9.62it/s, loss=4.31]


Epoch 91/5000, Training Loss: 0.1183, Validation Loss: 5.2646


Training Epoch 92/5000: 100%|██████████| 10/10 [00:03<00:00,  3.30it/s, loss=0.163]
Validation Epoch 92/5000: 100%|██████████| 10/10 [00:01<00:00,  9.66it/s, loss=4.28]


Epoch 92/5000, Training Loss: 0.1163, Validation Loss: 5.2540


Training Epoch 93/5000: 100%|██████████| 10/10 [00:02<00:00,  3.54it/s, loss=0.0879]
Validation Epoch 93/5000: 100%|██████████| 10/10 [00:01<00:00,  9.53it/s, loss=4.25]


Epoch 93/5000, Training Loss: 0.1162, Validation Loss: 5.1817


Training Epoch 94/5000: 100%|██████████| 10/10 [00:02<00:00,  3.55it/s, loss=0.23] 
Validation Epoch 94/5000: 100%|██████████| 10/10 [00:01<00:00,  9.62it/s, loss=4.28]


Epoch 94/5000, Training Loss: 0.1167, Validation Loss: 5.2285


Training Epoch 95/5000: 100%|██████████| 10/10 [00:02<00:00,  3.55it/s, loss=0.0902]
Validation Epoch 95/5000: 100%|██████████| 10/10 [00:01<00:00,  9.77it/s, loss=4.31]


Epoch 95/5000, Training Loss: 0.1172, Validation Loss: 5.2447


Training Epoch 96/5000: 100%|██████████| 10/10 [00:03<00:00,  3.33it/s, loss=0.064]
Validation Epoch 96/5000: 100%|██████████| 10/10 [00:01<00:00,  9.65it/s, loss=4.29]


Epoch 96/5000, Training Loss: 0.1182, Validation Loss: 5.2225


Training Epoch 97/5000: 100%|██████████| 10/10 [00:02<00:00,  3.35it/s, loss=0.106]


Epoch 00097: reducing learning rate of group 0 to 1.0000e-06.


Validation Epoch 97/5000: 100%|██████████| 10/10 [00:01<00:00,  8.89it/s, loss=4.28]


Epoch 97/5000, Training Loss: 0.1195, Validation Loss: 5.2000


Training Epoch 98/5000: 100%|██████████| 10/10 [00:02<00:00,  3.43it/s, loss=0.0813]
Validation Epoch 98/5000: 100%|██████████| 10/10 [00:01<00:00,  9.39it/s, loss=4.24]


Epoch 98/5000, Training Loss: 0.1175, Validation Loss: 5.1385


Training Epoch 99/5000: 100%|██████████| 10/10 [00:02<00:00,  3.55it/s, loss=0.151]
Validation Epoch 99/5000: 100%|██████████| 10/10 [00:01<00:00,  9.52it/s, loss=4.26]


Epoch 99/5000, Training Loss: 0.1174, Validation Loss: 5.1965


Training Epoch 100/5000: 100%|██████████| 10/10 [00:02<00:00,  3.50it/s, loss=0.0539]
Validation Epoch 100/5000: 100%|██████████| 10/10 [00:01<00:00,  9.78it/s, loss=4.28]


Epoch 100/5000, Training Loss: 0.1170, Validation Loss: 5.2050


Training Epoch 101/5000: 100%|██████████| 10/10 [00:02<00:00,  3.55it/s, loss=0.199]
Validation Epoch 101/5000: 100%|██████████| 10/10 [00:01<00:00,  9.68it/s, loss=4.28]


Epoch 101/5000, Training Loss: 0.1161, Validation Loss: 5.2237


Training Epoch 102/5000: 100%|██████████| 10/10 [00:02<00:00,  3.49it/s, loss=0.123]
Validation Epoch 102/5000: 100%|██████████| 10/10 [00:01<00:00,  8.87it/s, loss=4.32]


Epoch 102/5000, Training Loss: 0.1208, Validation Loss: 5.2512


Training Epoch 103/5000: 100%|██████████| 10/10 [00:02<00:00,  3.48it/s, loss=0.0852]
Validation Epoch 103/5000: 100%|██████████| 10/10 [00:01<00:00,  9.01it/s, loss=4.29]


Epoch 103/5000, Training Loss: 0.1162, Validation Loss: 5.1918


Training Epoch 104/5000: 100%|██████████| 10/10 [00:03<00:00,  3.32it/s, loss=0.112]
Validation Epoch 104/5000: 100%|██████████| 10/10 [00:01<00:00,  8.97it/s, loss=4.27]


Epoch 104/5000, Training Loss: 0.1174, Validation Loss: 5.2074


Training Epoch 105/5000: 100%|██████████| 10/10 [00:02<00:00,  3.49it/s, loss=0.0619]
Validation Epoch 105/5000: 100%|██████████| 10/10 [00:01<00:00,  9.51it/s, loss=4.31]


Epoch 105/5000, Training Loss: 0.1177, Validation Loss: 5.2075


Training Epoch 106/5000: 100%|██████████| 10/10 [00:03<00:00,  3.32it/s, loss=0.155]
Validation Epoch 106/5000: 100%|██████████| 10/10 [00:01<00:00,  9.57it/s, loss=4.28]


Epoch 106/5000, Training Loss: 0.1147, Validation Loss: 5.2151


Training Epoch 107/5000: 100%|██████████| 10/10 [00:02<00:00,  3.37it/s, loss=0.0525]
Validation Epoch 107/5000: 100%|██████████| 10/10 [00:01<00:00,  9.75it/s, loss=4.3]


Epoch 107/5000, Training Loss: 0.1185, Validation Loss: 5.2482


Training Epoch 108/5000: 100%|██████████| 10/10 [00:02<00:00,  3.49it/s, loss=0.144]


Epoch 00108: reducing learning rate of group 0 to 1.0000e-07.


Validation Epoch 108/5000: 100%|██████████| 10/10 [00:01<00:00,  8.92it/s, loss=4.26]


Epoch 108/5000, Training Loss: 0.1153, Validation Loss: 5.2153


Training Epoch 109/5000: 100%|██████████| 10/10 [00:02<00:00,  3.37it/s, loss=0.147]
Validation Epoch 109/5000: 100%|██████████| 10/10 [00:01<00:00,  8.81it/s, loss=4.3]


Epoch 109/5000, Training Loss: 0.1202, Validation Loss: 5.2615


Training Epoch 110/5000: 100%|██████████| 10/10 [00:02<00:00,  3.45it/s, loss=0.0794]
Validation Epoch 110/5000: 100%|██████████| 10/10 [00:01<00:00,  9.52it/s, loss=4.26]


Epoch 110/5000, Training Loss: 0.1195, Validation Loss: 5.2162


Training Epoch 111/5000: 100%|██████████| 10/10 [00:02<00:00,  3.42it/s, loss=0.0682]
Validation Epoch 111/5000: 100%|██████████| 10/10 [00:01<00:00,  9.46it/s, loss=4.29]


Epoch 111/5000, Training Loss: 0.1177, Validation Loss: 5.2690


Training Epoch 112/5000: 100%|██████████| 10/10 [00:02<00:00,  3.38it/s, loss=0.121]
Validation Epoch 112/5000: 100%|██████████| 10/10 [00:01<00:00,  9.42it/s, loss=4.28]


Epoch 112/5000, Training Loss: 0.1160, Validation Loss: 5.2267


Training Epoch 113/5000: 100%|██████████| 10/10 [00:02<00:00,  3.50it/s, loss=0.111]
Validation Epoch 113/5000: 100%|██████████| 10/10 [00:01<00:00,  9.60it/s, loss=4.24]


Epoch 113/5000, Training Loss: 0.1204, Validation Loss: 5.1746


Training Epoch 114/5000: 100%|██████████| 10/10 [00:02<00:00,  3.54it/s, loss=0.152]
Validation Epoch 114/5000: 100%|██████████| 10/10 [00:01<00:00,  9.61it/s, loss=4.32]


Epoch 114/5000, Training Loss: 0.1340, Validation Loss: 5.2896


Training Epoch 115/5000: 100%|██████████| 10/10 [00:02<00:00,  3.42it/s, loss=0.149]
Validation Epoch 115/5000: 100%|██████████| 10/10 [00:01<00:00,  9.44it/s, loss=4.29]


Epoch 115/5000, Training Loss: 0.1177, Validation Loss: 5.2270


Training Epoch 116/5000: 100%|██████████| 10/10 [00:02<00:00,  3.54it/s, loss=0.0956]
Validation Epoch 116/5000: 100%|██████████| 10/10 [00:01<00:00,  9.07it/s, loss=4.28]


Epoch 116/5000, Training Loss: 0.1160, Validation Loss: 5.2181


Training Epoch 117/5000: 100%|██████████| 10/10 [00:02<00:00,  3.39it/s, loss=0.0894]
Validation Epoch 117/5000: 100%|██████████| 10/10 [00:01<00:00,  9.74it/s, loss=4.34]


Epoch 117/5000, Training Loss: 0.1221, Validation Loss: 5.2437


Training Epoch 118/5000:  70%|███████   | 7/10 [00:02<00:00,  3.54it/s, loss=0.118] 

Inference

In [1]:
import os
import re

import torch
from openai import OpenAI

from common.file_paths import TRAINING_DATA_DIR
from model.model import MinecraftStructureGenerator
from converter.converter import RegionTensorConverter

# Assuming the constants are defined as in the training script
INPUT_EMBEDDING_SIZE = 1536
NUM_CLASSES = 345
OUTPUT_SIZE = [64, 64, 64]

# Initialize the model
model = MinecraftStructureGenerator(INPUT_EMBEDDING_SIZE, NUM_CLASSES, OUTPUT_SIZE)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define the directory where checkpoints are saved
experiment_name = 'test14'
checkpoint_dir = f'checkpoints/{experiment_name}'

# List all checkpoint files
checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.startswith('checkpoint_') and f.endswith('.pth')]

# Extract epochs from file names and sort them
epochs = [int(re.search(r'checkpoint_(\d+).pth', f).group(1)) for f in checkpoint_files]
latest_epoch = max(epochs, default=0)  # Use default=0 to handle the case when the list is empty

# Load the trained model weights
latest_checkpoint_file = f'checkpoint_{latest_epoch}.pth'
print(f"Loading checkpoint '{latest_checkpoint_file}'...")
checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file)
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])

# Set the model to evaluation mode
model.eval()

converter = RegionTensorConverter()

# Loop to take user input and perform inference
while True:
    user_input = input("Enter your text input (or type 'exit' to stop): ")
    if user_input.lower() == 'exit':
        break
    print(f"Input: {user_input}")

    # Get the embedding
    print("Getting embedding...")
    client = OpenAI()
    embedding = client.embeddings.create(input=user_input, model="text-embedding-ada-002").data[0].embedding
    input_tensor = torch.tensor(embedding).unsqueeze(0)  # Add batch dimension
    input_tensor = input_tensor.float()
    input_tensor = input_tensor.to(device)
    print(f"Embedding: {input_tensor.shape}")

    # Perform inference
    with torch.no_grad():
        print("Performing inference...")
        output = model(input_tensor)
        print(f"Output: {output.shape}")

    # Process result
    predicted_tokens = torch.argmax(output, dim=1)
    predicted_tokens = predicted_tokens.squeeze(0)
    print(f"Predicted Tokens: {predicted_tokens.shape}")

    # Convert the output tensor to a schematic
    print("Converting output tensor to schematic...")
    region = converter.tensor_to_region(predicted_tokens)
    print("Conversion complete.")

    # Save the schematic to a file
    print("Saving schematic to file...")
    # try:
    #     schematic = region.as_schematic()
    #     schematic.save('test.litematic')
    # except:
    #     print("Failed to save litematica schematic to file.")
    # try:
    #     structure_nbt = region.to_structure_nbt()
    #     structure_nbt.save('test.nbt')
    # except:
    #     print("Failed to save NBT schematic to file.")
    sponge_nbt = region.to_sponge_nbt()
    sponge_nbt.save(f'{user_input.lower().replace(" ", "")}.schem')
    print("Schematic saved to file.")

In [13]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

from schempy.schematic import Block, Schematic, BlockEntity

schematic = Schematic.from_file(Path('sponge.3.schem'))
print(schematic.metadata)

# Usage example
# schematic = Schematic(width=10, height=10, length=10)
schematic.metadata['Description'] = "This is a schematic generated by SchemPy"

# Set a block at coordinates (x=1, y=2, z=3) to a specific value, e.g., 42
block = Block("minecraft:andesite")
schematic.set_block(1, 2, 3, block)
block = Block("minecraft:oak_planks")
schematic.set_block(0, 0, 0, block)

# Retrieve the block value at coordinates (x=1, y=2, z=3)
block = schematic.get_block(8, 9, 0)
print(block)
block_entity = BlockEntity("minecraft:chest", 0, 0, 0, {"LootTable": "minecraft:chests/simple_dungeon"})
schematic.add_block_entity(block_entity)

schematic.save_to_file(Path('example.schem'), 3)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Compound({'Date': Long(1700278591692), 'WorldEdit': Compound({'Version': String('(unknown)'), 'EditingPlatform': String('enginehub:fabric'), 'Origin': IntArray([Int(0), Int(0), Int(0)]), 'Platforms': Compound({'enginehub:fabric': Compound({'Name': String('Fabric-Official'), 'Version': String('7.3.0-beta-02+e11f161')})})})})
minecraft:air
