# Imports & preparatory steps

In [None]:
import os
import os.path as osp
import subprocess
import torch
import shutil
from ruamel.yaml import YAML
import nvidia_smi
from torch import __version__ as torch_version
from platform import python_version

# Refuse to run without a GPU. An explicit raise is used instead of `assert`
# so the check is not stripped when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError("CPU training is not allowed.")

# Number of CPUs granted to the job, read from the PBS environment
# (raises KeyError when PBS_NUM_PPN is not set).
# TODO: $PBS_NUM_PPN vs $OMP_NUM_THREADS?
N_CPUS = int(os.environ["PBS_NUM_PPN"])

# Limit CPU operation in PyTorch to `N_CPUS` threads
torch.set_num_threads(N_CPUS)
try:
    # PyTorch allows this only once per process (and only before inter-op
    # parallel work starts); re-running this cell would otherwise crash.
    torch.set_num_interop_threads(N_CPUS)
except RuntimeError as err:
    print(f" ! Could not set inter-op threads ({err}); "
          f"keeping {torch.get_num_interop_threads()}")

# Set username
USER = os.environ["USER"]

# GPU
n_gpus = torch.cuda.device_count()
nvidia_smi.nvmlInit()
try:
    print(" > Computational resources...")
    print(f" | > Number of CPUs: {N_CPUS}")
    print(f" | > Number of GPUs: {n_gpus}")
    for idx in range(n_gpus):
        handle = nvidia_smi.nvmlDeviceGetHandleByIndex(idx)
        gpu_name = nvidia_smi.nvmlDeviceGetName(handle)
        # Some NVML binding versions return the name as bytes; decode it so
        # it does not print as b'...'.
        if isinstance(gpu_name, bytes):
            gpu_name = gpu_name.decode("utf-8", errors="replace")
        print(f" | > Device {idx}: {gpu_name}")
    print(" > Python & module versions...")
    print(f" | > Python:    {python_version()}")
    print(f" | > PyTorch:   {torch_version}")
finally:
    # Release NVML even if one of the queries above fails
    nvidia_smi.nvmlShutdown()

# Settings

In [None]:
# Interactive (JupyterLab) sessions are recognised by their PBS job name.
# Reading the variable raises KeyError when PBS_JOBNAME is not set.
_pbs_job_name = os.environ["PBS_JOBNAME"]
INTERACTIVE_MODE = "JupyterLab" in _pbs_job_name

In [None]:
# Change wav path to scratch
def change_path_to_scratch(scratch_dir, data_dir, inp_data_file):
    """Write a copy of a dataset list with its wav paths re-rooted at ``scratch_dir``.

    Every non-empty line of ``inp_data_file`` must hold exactly three
    ``|``-separated fields, ``wpath|ph|spk_id``. ``wpath`` is joined onto
    ``scratch_dir`` (``os.path.join`` semantics, so an absolute ``wpath`` is
    kept as is); the other two fields are copied verbatim. Blank lines are
    skipped.

    Args:
        scratch_dir: Root directory the wav paths are moved under.
        data_dir: Sub-directory of ``scratch_dir`` that receives the rewritten
            list; it must already exist.
        inp_data_file: Path of the source list file (read as UTF-8).

    Returns:
        Path of the rewritten list,
        ``scratch_dir/data_dir/<basename of inp_data_file>``.

    Raises:
        ValueError: If a non-empty line does not have exactly three fields.
    """
    out_data_file = osp.join(scratch_dir, data_dir, osp.basename(inp_data_file))
    with open(inp_data_file, encoding="utf-8") as finp, \
            open(out_data_file, "w", encoding="utf-8") as fout:
        for line_no, raw_line in enumerate(finp, start=1):
            line = raw_line.strip()
            if not line:
                # Tolerate empty lines (e.g. a trailing newline at EOF) instead
                # of crashing on the tuple unpack below.
                continue
            try:
                wpath, ph, spk_id = line.split("|")
            except ValueError as err:
                raise ValueError(
                    f"{inp_data_file}:{line_no}: expected 3 '|'-separated "
                    f"fields, got {line!r}"
                ) from err
            print(f"{osp.join(scratch_dir, wpath)}|{ph}|{spk_id}", file=fout)
    return out_data_file

# Floats are not written in scientific notation when storing to YAML
def float_representer(representer, data):
    """Represent a float as a fixed-point YAML scalar (no exponent).

    The shortest decimal string that round-trips to ``data`` (Python's
    ``repr``) is expanded to fixed-point, so no precision is lost. A fixed
    number of decimal places would instead turn e.g. ``1e-16`` into ``0``.
    Integral values get a trailing ``.0`` so they still read back as floats.
    Non-finite values use YAML's ``.inf`` / ``-.inf`` / ``.nan`` spellings;
    a bare ``inf``/``nan`` would be loaded back as a string.

    Args:
        representer: ruamel.yaml representer providing ``represent_scalar``.
        data: The float to serialize.

    Returns:
        Whatever ``representer.represent_scalar`` returns for the scalar.
    """
    import math
    from decimal import Decimal

    if math.isnan(data):
        value = ".nan"
    elif math.isinf(data):
        value = ".inf" if data > 0 else "-.inf"
    else:
        # Decimal's 'f' format never uses an exponent and keeps every digit
        # of the shortest round-trip representation.
        value = format(Decimal(repr(data)), "f")
        if "." not in value:
            value += ".0"
    return representer.represent_scalar('tag:yaml.org,2002:float', value)

In [None]:
# --- Run / I/O settings -----------------------------------------------------
log_dir = "checkpoints"
save_freq = 1  # 10
device = "cuda"
epochs = 200
batch_size = 256
pretrained_model = ""
# NOTE(review): the train list ends in ".processed0.csv" while the val list
# ends in ".processed.csv" — confirm the suffix mismatch is intended.
train_data = "data/train_SPT-MGW.processed0.csv"
val_data = "data/val_SPT-MGW.processed.csv"
show_progress = False

# --- Feature extraction -----------------------------------------------------
preprocess_params = dict(
    sr=24000,
    spect_params=dict(n_fft=2048, win_length=1200, hop_length=300),
    mel_params=dict(n_mels=80),
)

# --- Model ------------------------------------------------------------------
# NOTE(review): input_dim presumably has to equal mel_params["n_mels"] — confirm.
model_params = dict(
    input_dim=80,
    hidden_dim=256,
    n_token=81,
    token_embedding_dim=256,
)

# --- Optimizer --------------------------------------------------------------
optimizer_params = dict(lr=0.0005)

# Copy data to scratch dir

In [None]:
# scratch_dir = os.environ["SCRATCHDIR"]
# if not INTERACTIVE_MODE:
#     # Copy dataset
#     # Prepare dataset dir in the scratch
#     print(f" > Copying train data to local scratch: {scratch_dir}")
#     train_data, n_data = copy_data_to_scratch(train_data, scratch_dir)
#     print(f" | > ... {n_data} files copied")
#     print(f" > Copying val data to local scratch: {scratch_dir}")
#     val_data, n_data = copy_data_to_scratch(val_data, scratch_dir)
#     print(f" | > ... {n_data} files copied")

In [None]:
# Set up local scratch directory
scratch_dir = os.environ["SCRATCHDIR"]
# The data are copied only for scheduled (non-interactive) runs; interactive
# JupyterLab sessions skip the copy.
if not INTERACTIVE_MODE:
    import shlex

    # Directory of the dataset lists.
    # NOTE(review): assumes val_data lives in the same directory as
    # train_data — confirm.
    data_dir = osp.dirname(osp.abspath(train_data))
    print(f"> Copying data to local scratch: {scratch_dir}")
    storage_path, storage_dir = osp.split(data_dir.rstrip(os.sep))

    # Remote side: tar the dataset directory to stdout (-h follows symlinks).
    # The remote command string is parsed by the remote shell, so the paths
    # are quoted.
    ssh_command = [
        "ssh",
        "storage-plzen4.kky.zcu.cz",
        f"tar -h -C {shlex.quote(storage_path)} -cf - {shlex.quote(storage_dir)}",
    ]
    # Local side: unpack the archive from stdin into the scratch directory
    tar_command = ["tar", "-xf", "-", "-C", scratch_dir]

    # The tar stream is binary, so the pipe is not opened in text mode
    ssh_process = subprocess.Popen(ssh_command, stdout=subprocess.PIPE)
    tar_process = subprocess.Popen(tar_command, stdin=ssh_process.stdout)
    # Drop our copy of the pipe so ssh receives SIGPIPE if tar exits early
    ssh_process.stdout.close()
    ssh_returncode = ssh_process.wait()
    tar_returncode = tar_process.wait()

    # Abort instead of going on to train on a missing or partial dataset.
    # (Messages read: "SSH/Tar process finished with an error: <code>".)
    copy_errors = []
    if ssh_returncode != 0:
        copy_errors.append(f"SSH proces skončil s chybou: {ssh_returncode}")
    if tar_returncode != 0:
        copy_errors.append(f"Tar proces skončil s chybou: {tar_returncode}")
    if copy_errors:
        raise RuntimeError("; ".join(copy_errors))

    # Point training at the scratch copy of the dataset lists
    train_data = change_path_to_scratch(scratch_dir, storage_dir, train_data)
    val_data = change_path_to_scratch(scratch_dir, storage_dir, val_data)

# Create/update config file

In [None]:
# Training configuration consumed by train.py (keys are written in this order)
config = {
    "log_dir": log_dir,
    "save_freq": save_freq,
    "device": device,
    "epochs": epochs,
    "batch_size": batch_size,
    "pretrained_model": pretrained_model,
    "train_data": train_data,
    "val_data": val_data,
    "show_progress": show_progress,
    "preprocess_params": preprocess_params,
    "model_params": model_params,
    "optimizer_params": optimizer_params,
}

config_file = osp.join(scratch_dir, "config.yml")
# Write to a YAML file
yaml = YAML()
# Emit floats without scientific notation
yaml.representer.add_representer(float, float_representer)
# Block style for nested mappings
yaml.default_flow_style = False
# Explicit UTF-8 so the output does not depend on the job's locale
with open(config_file, "w", encoding="utf-8") as f:
    yaml.dump(config, f)

## Run training script

In [None]:
print( " > Start training...")
print(f" | > Batch size: {batch_size}")
print(f" | > Epochs:     {epochs}")
print(f" | > # workers:  {N_CPUS}")

!python train.py --num_workers={N_CPUS} {config_file}

# Cleanup

In [None]:
if not INTERACTIVE_MODE:
    # Empty the scratch directory: every file, symlink and subdirectory in it
    # is removed, while the directory itself is kept. This is best effort: an
    # entry that cannot be removed is reported and the loop moves on.
    entry_paths = [os.path.join(scratch_dir, name) for name in os.listdir(scratch_dir)]
    for entry_path in entry_paths:
        try:
            if os.path.islink(entry_path) or os.path.isfile(entry_path):
                # Symlinks are unlinked themselves, never followed
                os.unlink(entry_path)
                continue
            if os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except Exception as exc:
            print(f'Failed to delete {entry_path}. Reason: {exc}')