## Imports & preparatory steps

In [None]:
import os
import shutil
import os.path as osp
import yaml

import torch
from torch import nn
import torch.nn.functional as F

from accelerate import Accelerator
# from accelerate.utils import LoggerType

from transformers import AdamW
from transformers import AlbertConfig, AlbertModel
from accelerate import DistributedDataParallelKwargs

from model import MultiTaskModel
from dataloader import build_dataloader
from utils import length_to_mask# , scan_checkpoint

from datasets import load_from_disk
# from torch.utils.tensorboard import SummaryWriter

from torch import __version__ as torch_version
from platform import python_version

# Training requires a GPU. Raise explicitly rather than `assert`, because
# assertions are removed when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError("CPU training is not allowed.")

# Number of CPUs, read from the PBS_NUM_PPN environment variable
# (raises KeyError when the variable is not set)
N_CPUS = int(os.environ["PBS_NUM_PPN"])

# Limit CPU operation in pytorch to `N_CPUS`
torch.set_num_threads(N_CPUS)
try:
    torch.set_num_interop_threads(N_CPUS)
except RuntimeError as err:
    # PyTorch only allows the inter-op thread count to be set once, before any
    # inter-op parallel work has started. Re-running this cell in a live kernel
    # would otherwise abort here, so keep the existing setting and carry on.
    print(f" > Inter-op thread count left unchanged: {err}")

# Set username
USER = os.environ["USER"]

# GPU
n_gpus = torch.cuda.device_count()
# nvidia_smi.nvmlInit()

print(" > Computational resources...")
print(f" | > Number of CPUs: {N_CPUS}")
print(f" | > Number of GPUs: {n_gpus}")
# for idx in range(n_gpus):
#    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(idx)
#    print(f" | > Device {idx}: {nvidia_smi.nvmlDeviceGetName(handle)}")
print(" > Python & module versions...")
print(f" | > Python:    {python_version()}")
print(f" | > PyTorch:   {torch_version}")

## Settings

In [None]:
# True when the PBS job name contains "JupyterLab".
# The `in` test on strings already yields a bool, so no conversion is needed.
INTERACTIVE_MODE = "JupyterLab" in os.environ["PBS_JOBNAME"]

In [None]:
# --- Run-level options -------------------------------------------------------
log_dir = "Checkpoint"                        # directory for checkpoints and the config copy
mixed_precision = "fp16"
data_folder = "cz-phon-sentences.processed"   # dataset directory
batch_size = 32
save_interval = 100                           # steps between checkpoints
log_interval = 10                             # steps between log lines
num_process = 1                               # number of GPUs
num_steps = 1_000_000

# --- Dataset / masking options -----------------------------------------------
dataset_params = dict(
    # alternative tokenizer: "transfo-xl-wt103"
    tokenizer="/storage/plzen4-ntis/home/jmatouse/experimenty/BERT_cs_phn_ipa",
    token_separator=" ",       # token used for phoneme separator (space)
    token_mask="M",            # token used for phoneme mask (M)
    word_separator=291686,     # token used for word separator (ʃajmijef)
    # alternative word_separator: 3039 (<formula>)
    # token_maps: "token_maps.pkl" (token map path, currently unused)
    max_mel_length=512,        # max phoneme length
    word_mask_prob=0.15,       # probability to mask the entire word
    phoneme_mask_prob=0.1,     # probability to mask each phoneme
    replace_prob=0.2,          # probability to replace phonemes
)

# --- Model options (later unpacked into AlbertConfig) ------------------------
model_params = dict(
    vocab_size=179,  # 178
    hidden_size=768,
    num_attention_heads=12,
    intermediate_size=2048,
    max_position_embeddings=512,
    num_hidden_layers=12,
    dropout=0.1,
)

## Copy data to scratch dir

In [None]:
scratch_dir = os.environ["SCRATCHDIR"]


def _copy_into_scratch(source_dir):
    """Copy the tree `source_dir` into `scratch_dir` under its base name and return the new path."""
    target_dir = osp.join(scratch_dir, osp.basename(source_dir))
    shutil.copytree(source_dir, target_dir)
    return target_dir


if not INTERACTIVE_MODE:
    print(f"> Copying data to local scratch: {scratch_dir}")
    # Dataset first, then the tokenizer directory
    local_data_folder = _copy_into_scratch(data_folder)
    local_tokenizer_folder = _copy_into_scratch(dataset_params["tokenizer"])
    # From here on, the scratch copies are the ones used for training
    data_folder = local_data_folder
    dataset_params["tokenizer"] = local_tokenizer_folder

## Create/update config file

In [None]:
# Collect all settings into one dict; it is written to YAML below and read back
# by the training code.
config = {
    "log_dir": log_dir,
    "mixed_precision": mixed_precision,
    "data_folder": data_folder,
    "batch_size": batch_size,
    "save_interval": save_interval,
    "log_interval": log_interval,
    "num_process": num_process, # number of GPUs
    # Take the value from the settings cell rather than repeating the literal,
    # so that changing `num_steps` there actually takes effect.
    "num_steps": num_steps,
    "dataset_params": dataset_params,
    "model_params": model_params,
}

config_file = os.path.join(scratch_dir, "config.yml")
# Write to a YAML file
with open(config_file, 'w') as file:
    yaml.dump(config, file)

## Run training script

In [None]:
from transformers import AutoTokenizer # , AutoModel

# Tokenizer location comes from the (possibly scratch-redirected) config
tokenizer_source = config['dataset_params']['tokenizer']
tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
print('Vocab size:', tokenizer.vocab_size)

# Dataset previously saved to disk, loaded from the configured folder
dataset_source = config["data_folder"]
dataset = load_from_disk(dataset_source)
print(f'Dataset:\n{dataset}')

In [None]:
# Cross-entropy criterion; train() applies it to both the word-level and the
# token-level predictions.
criterion = nn.CrossEntropyLoss()

best_loss = float('inf')  # best test loss
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Read the loop-control values back from the config so that the config is the
# single source used by train().
num_steps, log_interval, save_interval = (
    config[key] for key in ('num_steps', 'log_interval', 'save_interval')
)

In [None]:
def _latest_checkpoint_step(log_dir):
    """Return the highest N among the `step_<N>.t7` files in `log_dir`.

    Entries that are not regular files, or whose name does not match
    `step_<integer>.t7`, are skipped. Returns None when nothing matches.
    """
    prefix, suffix = "step_", ".t7"
    steps = []
    for fname in os.listdir(log_dir):
        if not (fname.startswith(prefix) and fname.endswith(suffix)):
            continue
        if not osp.isfile(osp.join(log_dir, fname)):
            continue
        try:
            steps.append(int(fname[len(prefix):-len(suffix)]))
        except ValueError:
            # Not a numbered checkpoint (e.g. "step_final.t7"); ignore it
            continue
    return max(steps) if steps else None


def _strip_module_prefix(state_dict):
    """Return a copy of `state_dict` with a leading `module.` removed from keys.

    Keys without that prefix are kept unchanged, so checkpoints saved from an
    unwrapped model load correctly as well.
    """
    from collections import OrderedDict

    prefix = "module."
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        cleaned[key[len(prefix):] if key.startswith(prefix) else key] = value
    return cleaned


def train():
    """Train the ALBERT-based MultiTaskModel, resuming from the newest checkpoint.

    Uses the module-level `config`, `config_file`, `dataset`, `tokenizer`,
    `criterion`, `num_steps`, `log_interval` and `save_interval`.

    The config file is copied into the log directory. If a `step_<N>.t7`
    checkpoint exists there, model and optimizer state are restored from the
    one with the highest N. A checkpoint is written whenever
    `(iters + 1) % save_interval == 0`.

    Returns:
        None. The function returns once more than `num_steps` batches have
        been processed in this call, or after 999 epochs.
    """
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

    curr_steps = 0  # batches processed in this call (not restored on resume)
    log_dir = config['log_dir']

    os.makedirs(log_dir, exist_ok=True)
    shutil.copy(config_file, osp.join(log_dir, osp.basename(config_file)))

    batch_size = config["batch_size"]
    train_loader = build_dataloader(
        dataset,
        batch_size=batch_size,
        num_workers=0,
        dataset_config=config['dataset_params']
    )

    albert_base_configuration = AlbertConfig(**config['model_params'])

    bert = AlbertModel(albert_base_configuration)
    bert = MultiTaskModel(
        bert,
        # num_vocab=1 + max([m['token'] for m in token_maps.values()]),
        num_vocab=tokenizer.vocab_size,
        num_tokens=config['model_params']['vocab_size'],
        hidden_size=config['model_params']['hidden_size']
    )

    # Resume point: highest numbered checkpoint, or a fresh start if none exists
    latest_step = _latest_checkpoint_step(log_dir)
    load = latest_step is not None
    iters = latest_step if load else 0

    optimizer = AdamW(bert.parameters(), lr=1e-4)
    accelerator = Accelerator(mixed_precision=config['mixed_precision'], split_batches=True, kwargs_handlers=[ddp_kwargs])

    if load:
        checkpoint = torch.load(osp.join(log_dir, "step_" + str(iters) + ".t7"), map_location='cpu')

        # strict=False tolerates key differences, so report them instead of
        # letting a partially (or not at all) restored model go unnoticed.
        incompatible = bert.load_state_dict(_strip_module_prefix(checkpoint['net']), strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            accelerator.print(
                'Warning: checkpoint keys do not fully match the model '
                f'(missing: {len(incompatible.missing_keys)}, '
                f'unexpected: {len(incompatible.unexpected_keys)})'
            )

        accelerator.print('Checkpoint loaded.')
        optimizer.load_state_dict(checkpoint['optimizer'])

    bert, optimizer, train_loader = accelerator.prepare(
        bert, optimizer,
        train_loader
    )

    accelerator.print('Start training...')

    # Training is stopped after the defined number of steps
    # => just set up a high upper bound in the range
    for epoch in range(1, 1000):
        running_loss = 0

        for batch in train_loader:
            curr_steps += 1

            words, labels, phonemes, input_lengths, masked_indices = batch

            # The mask is inverted before being used as the attention mask.
            # NOTE(review): presumably length_to_mask marks padded positions
            # as True -- confirm against utils.length_to_mask.
            text_mask = length_to_mask(torch.Tensor(input_lengths)).to('cuda')
            tokens_pred, words_pred = bert(phonemes, attention_mask=(~text_mask).int())

            # Word-level loss over the first `length` positions of every sample,
            # averaged over the batch.
            loss_vocab = 0
            for sample_pred, sample_target, sample_length, _ in zip(words_pred, words, input_lengths, masked_indices):
                loss_vocab += criterion(sample_pred[:sample_length], sample_target[:sample_length])
            loss_vocab /= words.size(0)

            # Token-level loss, computed only at the masked positions of
            # samples that have any.
            loss_token = 0
            # Starts at 1, so the sum is divided by (contributing samples + 1);
            # this also avoids a division by zero when nothing was masked.
            sizes = 1
            for sample_pred, sample_target, sample_length, sample_masked in zip(tokens_pred, labels, input_lengths, masked_indices):
                if len(sample_masked) > 0:
                    loss_token += criterion(
                        sample_pred[:sample_length][sample_masked],
                        sample_target[:sample_length][sample_masked]
                    )
                    sizes += 1
            loss_token /= sizes

            loss = loss_vocab + loss_token

            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()

            running_loss += loss.item()

            iters = iters + 1
            if (iters+1) % log_interval == 0:
                accelerator.print('Epoch %2d: Step [%d/%d], Loss: %.5f, Vocab Loss: %.5f, Token Loss: %.5f'
                        %(epoch, iters+1, num_steps, running_loss/log_interval, loss_vocab, loss_token))
                running_loss = 0

            if (iters+1) % save_interval == 0:
                accelerator.print('Saving..')

                # The file name carries iters + 1 while the stored 'step' is
                # iters; resuming relies on the number in the file name.
                state = {
                    'net':  bert.state_dict(),
                    'step': iters,
                    'optimizer': optimizer.state_dict(),
                }

                accelerator.save(state, log_dir + '/step_' + str(iters + 1) + '.t7')

            if curr_steps > num_steps:
                return

from accelerate import notebook_launcher

# Launch training once. train() returns by itself after `num_steps` steps;
# wrapping the launcher in an endless loop would restart training forever and
# keep the cells below (including the scratch cleanup) from ever running.
notebook_launcher(train, args=(), num_processes=1, use_port=33389)

In [None]:
# Run train() directly in this process, without notebook_launcher.
# NOTE(review): the preceding cell already starts train() through
# notebook_launcher -- confirm whether both cells are meant to run in one session.
train()

## Cleanup

In [None]:
def _remove_scratch_entry(path):
    """Delete `path`: unlink files and symlinks, remove directories recursively."""
    if os.path.isfile(path) or os.path.islink(path):
        os.unlink(path)  # remove file or symlink
    elif os.path.isdir(path):
        shutil.rmtree(path)  # remove directory


if not INTERACTIVE_MODE:
    # Empty the scratch directory; a failure on one entry is reported and the
    # remaining entries are still processed.
    for filename in os.listdir(scratch_dir):
        file_path = os.path.join(scratch_dir, filename)
        try:
            _remove_scratch_entry(file_path)
        except Exception as e:
            print(f'Failed to delete {file_path}.')
            print(f'Reason: {e}')
