In [1]:
# Preliminary setup of execution environment
import os
import subprocess
import sys
from pathlib import Path

# Directory used as the NNTile root: the parent of the current working directory
nntile_dir = Path.cwd() / ".."

# General process settings
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "0",  # Limit CUDA visibility
    "OMP_NUM_THREADS": "1",  # Disable BLAS parallelism
    "PYTHONPATH": str(nntile_dir / "build" / "wrappers" / "python"),  # Binary dir of NNTile Python wrappers
})

# StarPU settings. The full list of StarPU environment variables is available at
# https://files.inria.fr/starpu/doc/html/ExecutionConfigurationThroughEnvironmentVariables.html
os.environ.update({
    "STARPU_NCPU": "1",  # Use only 1 CPU core
    "STARPU_NCUDA": "1",  # Use only 1 CUDA device
    "STARPU_SILENT": "1",  # Do not show lots of StarPU outputs
    "STARPU_SCHED": "dmdasd",  # Name of the StarPU scheduler to be used
    "STARPU_FXT_TRACE": "0",  # Do not generate FXT traces
    "STARPU_WORKERS_NOBIND": "1",  # Do not bind workers (helps if several StarPU instances run in parallel)
    "STARPU_PROFILING": "1",  # Log performance of workers and bandwidth of memory nodes
    "STARPU_HOME": str(Path.cwd() / "starpu"),  # Main directory in which StarPU stores its configuration files
    "STARPU_PERF_MODEL_HOMOGENEOUS_CPU": "1",  # Assume all CPU cores are equal
    "STARPU_PERF_MODEL_HOMOGENEOUS_CUDA": "1",  # Assume all CUDA devices are equal
    "STARPU_HOSTNAME": "GPT2_example",  # Force the hostname used when managing performance model files
})

# StarPU locations derived from STARPU_HOME, which was set just above
os.environ.update({
    "STARPU_PERF_MODEL_DIR": str(Path(os.environ["STARPU_HOME"]) / "sampling"),  # Performance model files
    "STARPU_FXT_PREFIX": str(Path(os.environ["STARPU_HOME"]) / "fxt"),  # FXT traces, if enabled
})

# NNTile-related settings
os.environ.update({
    "NNTILE_LOGGER": "1",  # Enable logger
    "NNTILE_LOGGER_SERVER_ADDR": "127.0.0.1",  # Logger will be running on the localhost
    "NNTILE_LOGGER_SERVER_PORT": "5001",  # Port for logger server
    "NNTILE_LOGGER_CLIENT_PORT": "6006",  # Port for client web interface of the logger
    "NNTILE_LOGGER_SERVER_DIR": str(Path.cwd() / "logs"),  # Directory to store logs on the logger server
})

In [2]:
# Launch the logger server if needed.
# The server is started as a child process only when NNTILE_LOGGER is set to something other than "0".
# Re-running this cell while a server started by an earlier run is still alive does nothing;
# otherwise a second server would be spawned on the same port and the first handle would be lost.
if os.getenv("NNTILE_LOGGER", "0") != "0":
    running_proc = globals().get("logger_proc")
    if running_proc is not None and running_proc.poll() is None:
        print(f"Logger server is already running (pid {running_proc.pid})")
    else:
        # The child inherits the notebook environment plus the variables below,
        # which are presumably read by logger/server.py (verify against that script)
        logger_env = os.environ.copy()
        logger_env.update({
            "LOG_DIR": os.getenv("NNTILE_LOGGER_SERVER_DIR"),
            "SPLIT_HOURS": "720",
            "CLEAR_LOGS": "0",
            "SERVER_PORT": os.getenv("NNTILE_LOGGER_SERVER_PORT")
        })
        # sys.executable runs the server with the same interpreter as this kernel
        # instead of whichever "python" happens to be first on PATH
        logger_proc = subprocess.Popen(
            [sys.executable, str(nntile_dir / "logger" / "server.py")],
            env=logger_env,
        )

2024-08-09 10:59:55.660722: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-08-09 10:59:55.694668: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [5]:
# Launch an external Python process that runs the GPT2 training example script,
# intended to finetune a pretrained GPT2 model on WikiText-103, for 100 NNTile epochs.
# The script path is relative, so it is resolved against the current working directory.
# Logging is requested with --nntile-logger and directed to the logger server at 127.0.0.1.
# If the logger server is launched, TensorBoard results can be accessed at localhost:6006.
# NOTE(review): the Namespace echoed by a recorded run of this command shows lr=0.0; with a zero
# learning rate the weights would presumably not change - confirm whether an explicit --lr is needed.
!python ../wrappers/python/examples/gpt2_training.py --nntile-nepochs=100 --nntile-logger --nntile-logger-server-addr=127.0.0.1

Namespace(model='gpt2', pretrained='remote', checkpoint_path='', config_path='', save_checkpoint_path='.model', optimizer='adam', model_path='.model', seq_len_tile=1024, batch_size=1, minibatch_size=1, minibatch_size_tile=1, n_embd_tile=384, n_inner_tile=1536, n_head_tile=-1, torch_device='cpu', torch_dtype='fp32', torch_compile=False, nntile_dtype='fp32', check=False, check_fp64=False, torch_nforward=0, torch_nforward_warmup=0, torch_nbackward=0, torch_nbackward_warmup=0, nntile_restrict=None, nntile_flashattention=False, nntile_use_redux=False, nntile_nforward=0, nntile_nforward_warmup=0, nntile_nbackward=0, nntile_nbackward_warmup=0, dataset='WikiText-103', dataset_path='.data', dataset_select=100, lr=0.0, torch_nepochs=0, torch_nepochs_warmup=0, nntile_nepochs=100, nntile_nepochs_warmup=0, nntile_logger=True, nntile_logger_server_addr='127.0.0.1', nntile_logger_server_port=5001)
Trying to connect to 127.0.0.1:5001
WORKER COUNT: 2
BUS COUNT: 2
MEMNODES COUNT: 2
IS initialized : 1
St