Hyperparameter Tuning

In [1]:
import os
from datetime import datetime

import optuna
import torch
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from common.file_paths import BASE_DIR
from modules import (GenerateSchematicCallback,
                     LightningTransformerMinecraftStructureGenerator,
                     MinecraftDataModule)

# Per the PyTorch docs, 'medium' lets float32 matmuls use reduced-precision
# internal computation where a fast kernel exists, trading accuracy for speed.
torch.set_float32_matmul_precision('medium')


def objective(trial: optuna.Trial):
    """Train one model with hyperparameters sampled from ``trial`` and score it.

    The model, data module, logger and callbacks are rebuilt from scratch for
    every trial, and the global seed is reset so that trials differ only in
    their sampled hyperparameters.

    Args:
        trial: Optuna trial used to sample ``embedding_dim``,
            ``embedding_dropout``, ``num_heads``, ``num_layers``,
            ``decoder_dropout``, ``learning_rate`` and ``batch_size``.

    Returns:
        float: The value to minimise. In order of preference: the best
        ``val_loss`` tracked by the checkpoint callback, the last ``val_loss``
        in ``trainer.callback_metrics``, or the constant ``0.25`` when neither
        is available.
    """
    seed_everything(0, workers=True)

    # NOTE(review): not every sampled pair below has embedding_dim divisible by
    # num_heads (e.g. 16 with 32 heads). Whether the model requires that is not
    # visible from this notebook -- confirm against the model implementation.
    lightning_model = LightningTransformerMinecraftStructureGenerator(
        num_classes=250,
        max_sequence_length=512,
        embedding_dim=trial.suggest_categorical("embedding_dim", [16, 32, 64]),
        embedding_dropout=trial.suggest_float(
            "embedding_dropout", 0.1, 0.4, step=0.1),
        num_heads=trial.suggest_categorical("num_heads", [8, 16, 32]),
        # Samples from {2, 4, 6}. The previous call passed 6 as a positional
        # `step` with the range 2..4, which Optuna truncates to the single
        # value 2, so the depth was never actually tuned.
        num_layers=trial.suggest_int("num_layers", 2, 6, step=2),
        decoder_dropout=trial.suggest_float(
            "decoder_dropout", 0.1, 0.4, step=0.1),
        freeze_encoder=True,
        learning_rate=trial.suggest_float(
            "learning_rate", 1e-5, 1e-3, log=True)
    )

    hdf5_file = os.path.join(BASE_DIR, 'data.h5')
    data_module = MinecraftDataModule(
        file_path=hdf5_file,
        batch_size=trial.suggest_categorical("batch_size", [32, 64]),
    )

    logger = TensorBoardLogger(
        'lightning_logs', name='minecraft_structure_generator', log_graph=False)
    # Tracks the lowest val_loss seen; its best score is the trial's result.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        mode='min',
        save_last=True,
        save_weights_only=True
    )
    # max_epochs below is only an upper bound; this callback ends the run once
    # val_loss has not improved for `patience` validation checks.
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        patience=20,
        verbose=False,
        mode='min'
    )
    generate_schematic_callback = GenerateSchematicCallback(
        save_path='schematic_viewer/public/schematics/',
        data_module=data_module,
        generate_train=False,
        generate_val=True,
        generate_all_datasets=False,
        generate_every_n_epochs=10
    )

    trainer = Trainer(
        max_epochs=5000,
        logger=logger,
        gradient_clip_val=1.0,
        log_every_n_steps=5,
        callbacks=[
            checkpoint_callback,
            early_stop_callback,
            generate_schematic_callback
        ]
    )

    trainer.fit(lightning_model, datamodule=data_module)

    best_score = checkpoint_callback.best_model_score
    if best_score is not None:
        return best_score.item()
    if "val_loss" in trainer.callback_metrics:
        return trainer.callback_metrics["val_loss"].item()
    # NOTE(review): 0.25 is an arbitrary placeholder for "no validation loss
    # was recorded". Because the study minimises, a run that never validated
    # could rank ahead of real trials if genuine losses exceed 0.25 -- confirm
    # this constant is a safe upper bound for val_loss.
    return 0.25


# A fixed name combined with load_if_exists=True makes re-running this cell
# continue the study persisted in studies.db. Swap in the timestamped name
# below to start a fresh study instead.
# study_name = f"study_{datetime.now().strftime('%Y%m%d%H%M%S')}"
study_name = "study_20231208214359"

study = optuna.create_study(
    study_name=study_name,
    storage='sqlite:///studies.db',
    direction='minimize',
    load_if_exists=True,
)
study.optimize(objective, n_trials=100)

print("Best hyperparameters:", study.best_trial.params)

  from .autonotebook import tqdm as notebook_tqdm
[I 2023-12-09 12:32:26,946] A new study created in RDB with name: study_20231208214359
Seed set to 0
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs





LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                                   | Params
-----------------------------------------------------------------
0 | model | TransformerMinecraftStructureGenerator | 82.3 M
-----------------------------------------------------------------
15.9 M    Trainable params
66.4 M    Non-trainable params
82.3 M    Total params
329.131   Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

c:\Users\mmmfr\Documents\Repositories\minecraft-schematic-generator\.venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


                                                                           

c:\Users\mmmfr\Documents\Repositories\minecraft-schematic-generator\.venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Epoch 64:  15%|█▌        | 83/551 [00:46<04:20,  1.79it/s, v_num=11] 

In [1]:
import os

import torch
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from common.file_paths import BASE_DIR
from modules import (GenerateSchematicCallback,
                     LightningTransformerMinecraftStructureGenerator,
                     MinecraftDataModule)

# Single training run with hand-picked hyperparameters (no Optuna involved).
torch.set_float32_matmul_precision('medium')

# Seed before the model is built so weight initialisation is repeatable.
seed_everything(0, workers=True)

lightning_model = LightningTransformerMinecraftStructureGenerator(
    num_classes=250,
    max_sequence_length=512,
    freeze_encoder=True,
    embedding_dim=32,
    num_heads=16,
    num_layers=2,
    embedding_dropout=0.2,
    decoder_dropout=0.2,
    learning_rate=1e-4,
)
# To resume from a saved run instead of starting fresh, replace the model above with:
# lightning_model = LightningTransformerMinecraftStructureGenerator.load_from_checkpoint(
#     'lightning_logs/minecraft_structure_generator/version_7/checkpoints/last.ckpt')

hdf5_file = os.path.join(BASE_DIR, 'data.h5')
data_module = MinecraftDataModule(file_path=hdf5_file, batch_size=64)

logger = TensorBoardLogger(
    'lightning_logs',
    name='minecraft_structure_generator',
    log_graph=False,
)

# Checkpoint on the lowest val_loss (weights only) and also keep the last epoch.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss', mode='min', save_last=True, save_weights_only=True)

# Stop once val_loss has gone 30 validation checks without improving.
early_stop_callback = EarlyStopping(
    monitor='val_loss', mode='min', patience=30, verbose=False)

# Schematic generation is switched off for this run. To re-enable it, build the
# callback as below and append it to the Trainer's callback list.
# generate_schematic_callback = GenerateSchematicCallback(
#     save_path='schematic_viewer/public/schematics/',
#     data_module=data_module,
#     generate_train=False,
#     generate_val=True,
#     generate_all_datasets=False,
#     generate_every_n_epochs=10
# )

trainer = Trainer(
    max_epochs=5000,
    gradient_clip_val=1.0,
    log_every_n_steps=5,
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback],
)

trainer.fit(lightning_model, datamodule=data_module)

  from .autonotebook import tqdm as notebook_tqdm
Seed set to 0
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs





LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                                   | Params
-----------------------------------------------------------------
0 | model | TransformerMinecraftStructureGenerator | 82.2 M
-----------------------------------------------------------------
15.8 M    Trainable params
66.4 M    Non-trainable params
82.2 M    Total params
328.805   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\mmmfr\Documents\Repositories\minecraft-schematic-generator\.venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


                                                                           

c:\Users\mmmfr\Documents\Repositories\minecraft-schematic-generator\.venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Epoch 3:  19%|█▊        | 51/275 [00:43<03:09,  1.18it/s, v_num=8] 

Running Best Model

In [2]:
import os

import optuna
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.profilers import SimpleProfiler

from common.file_paths import BASE_DIR
from modules import (GenerateSchematicCallback,
                     LightningTransformerMinecraftStructureGenerator,
                     MinecraftDataModule)

# Re-open the previously tuned study from the local SQLite store.
storage_url = 'sqlite:///studies.db'
study_name = 'study_20231205030812'

study = optuna.load_study(
    study_name=study_name,
    storage=storage_url,
)


def get_nth_best_trial(study, n, objective_id=0):
    """Return the hyperparameters of the n-th best trial in ``study``.

    Only trials that have recorded objective values are ranked; trials whose
    ``values`` is ``None`` are ignored. Ranking is ascending, so rank 0 is the
    trial with the smallest objective value (i.e. this assumes the objective
    is being minimised).

    Args:
        study: Object exposing ``trials``, each with ``values`` and ``params``.
        n: Zero-based rank of the trial to return (0 = best). Negative values
            index from the worst trial, following normal list semantics.
        objective_id: Index into each trial's ``values`` to rank by.

    Returns:
        dict: The ``params`` of the selected trial.

    Raises:
        IndexError: If ``n`` is outside the range of trials that have values.
    """
    # Drop trials without values instead of ranking them last, so a large `n`
    # can never silently return the parameters of an unscored trial.
    scored_trials = [t for t in study.trials if t.values is not None]

    if not -len(scored_trials) <= n < len(scored_trials):
        raise IndexError(
            f"Requested trial rank {n}, but only {len(scored_trials)} "
            f"trial(s) have objective values.")

    # sorted() is stable, so ties keep their original trial order.
    ranked_trials = sorted(scored_trials, key=lambda t: t.values[objective_id])
    return ranked_trials[n].params


# Which objective to rank trials by. This used to be set to 1 while the call
# below hard-coded 0; the variable is now the single source of truth and keeps
# the value the call actually used.
objective_id = 0
x = 0  # zero-based rank of the trial to re-train (0 = best)
nth_best_params = get_nth_best_trial(study, x, objective_id=objective_id)
print(nth_best_params)

# The tuning cell samples `batch_size` for the data module, not for the model,
# so route it there instead of forwarding it to the model constructor. Studies
# that did not tune it fall back to 32.
model_params = {
    name: value for name, value in nth_best_params.items()
    if name != 'batch_size'
}
batch_size = nth_best_params.get('batch_size', 32)

seed_everything(1, workers=True)

# NOTE(review): the other cells in this notebook use num_classes=250. Confirm
# that 20 matches the data this study was tuned on.
lightning_model = LightningTransformerMinecraftStructureGenerator(
    num_classes=20,
    max_sequence_length=512,
    freeze_encoder=True,
    **model_params
)

hdf5_file = os.path.join(BASE_DIR, 'data.h5')
data_module = MinecraftDataModule(
    file_path=hdf5_file,
    batch_size=batch_size,
    # num_workers=4
)

logger = TensorBoardLogger(
    'lightning_logs', name='minecraft_structure_generator', log_graph=False)
# Passed to the Trainer below so timings are collected during fit.
profiler = SimpleProfiler()
# Unlike the tuning cell, this saves full checkpoints (save_weights_only is
# left at its default).
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_last=True
)
# max_epochs is only an upper bound; training ends once val_loss has not
# improved for `patience` validation checks.
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=50,
    verbose=False,
    mode='min'
)
generate_schematic_callback = GenerateSchematicCallback(
    save_path='schematic_viewer/public/schematics/',
    data_module=data_module,
    generate_train=False,
    generate_val=True,
    generate_all_datasets=False,
    generate_every_n_epochs=10,
    autoregressive=True
)

trainer = Trainer(
    max_epochs=5000,
    logger=logger,
    profiler=profiler,
    gradient_clip_val=1.0,
    log_every_n_steps=5,
    callbacks=[
        checkpoint_callback,
        early_stop_callback,
        generate_schematic_callback
    ]
)

trainer.fit(lightning_model, datamodule=data_module)

Seed set to 1


{'embedding_dim': 64, 'embedding_dropout': 0.5, 'decoder_dim': 32, 'num_heads': 2, 'num_layers': 3, 'decoder_dropout': 0.2, 'learning_rate': 0.0006395506262517658}


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                                   | Params
-----------------------------------------------------------------
0 | model | TransformerMinecraftStructureGenerator | 66.9 M
-----------------------------------------------------------------
489 K     Trainable params
66.4 M    Non-trainable params
66.9 M    Total params
267.410   Total estimated model params size (MB)


Epoch 322: 100%|██████████| 21/21 [00:03<00:00,  5.85it/s, v_num=22]       Epoch 00323: reducing learning rate of group 0 to 6.3955e-05.
Epoch 369: 100%|██████████| 21/21 [00:06<00:00,  3.08it/s, v_num=22]Epoch 00370: reducing learning rate of group 0 to 6.3955e-06.
Epoch 380: 100%|██████████| 21/21 [00:03<00:00,  5.95it/s, v_num=22]Epoch 00381: reducing learning rate of group 0 to 6.3955e-07.
Epoch 391: 100%|██████████| 21/21 [00:03<00:00,  5.81it/s, v_num=22]Epoch 00392: reducing learning rate of group 0 to 6.3955e-08.
Epoch 402: 100%|██████████| 21/21 [00:03<00:00,  5.80it/s, v_num=22]Epoch 00403: reducing learning rate of group 0 to 6.3955e-09.
Epoch 408: 100%|██████████| 21/21 [00:03<00:00,  5.57it/s, v_num=22]


FIT Profiler Report

-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                               	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                                	|  -     