In [1]:
import os
import shutil
import os.path as osp

import torch
from torch import nn
import torch.nn.functional as F

from accelerate import Accelerator
from accelerate.utils import LoggerType

from torch.optim import AdamW
from transformers import AlbertConfig, AlbertModel
from accelerate import DistributedDataParallelKwargs

from model import MultiTaskModel
from dataloader import build_dataloader
from utils import length_to_mask, scan_checkpoint

from datasets import load_from_disk

from torch.utils.tensorboard import SummaryWriter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import yaml
import pickle

config_path = "Configs/config.yml" # you can change it to anything else

# Parse the YAML config inside a context manager so the file handle is closed
# as soon as parsing finishes instead of being left to the garbage collector.
with open(config_path) as config_file:
    config = yaml.safe_load(config_file)

In [3]:
import pickle

# Location of the pickled token map, taken from the dataset section of the config.
token_maps_path = config['dataset_params']['token_maps']

# NOTE: pickle.load can execute arbitrary code while deserializing; only point
# this at a token-map file that comes from a trusted source.
with open(token_maps_path, 'rb') as token_maps_file:
    token_maps = pickle.load(token_maps_file)

In [4]:
import os

# Set the flag before the tokenizer class is imported and instantiated.
# NOTE(review): presumably required by the deprecated TransfoXL tokenizer,
# which loads a pickled vocabulary -- confirm against the transformers docs.
os.environ.update(TRUST_REMOTE_CODE='True')

from transformers import TransfoXLTokenizer, TransfoXLModel

# Tokenizer name/path is read from the dataset section of the config.
tokenizer_name = config['dataset_params']['tokenizer']
tokenizer = TransfoXLTokenizer.from_pretrained(tokenizer_name)

`TransfoXL` was deprecated due to security issues linked to `pickle.load` in `TransfoXLTokenizer`. See more details on this model's documentation page: `https://github.com/huggingface/transformers/blob/main/docs/source/en/model_doc/transfo-xl.md`.


In [5]:
# Cross-entropy criterion; train() applies it to both the word predictions
# and the masked-token predictions.
criterion = nn.CrossEntropyLoss()

# Module-level bookkeeping values.
best_loss = float('inf')   # best test loss
start_epoch = 0            # start from epoch 0 or last checkpoint epoch
loss_train_record = []
loss_test_record = []

# Training schedule, read once from the config.
num_steps, log_interval, save_interval = (
    config[key] for key in ('num_steps', 'log_interval', 'save_interval')
)

In [6]:
# Fine-tuning settings. Both can be overridden from the YAML config through the
# optional keys `fine_tune_checkpoint` and `fine_tune`; when a key is absent
# the previous hard-coded value is used, so existing configs behave as before.
checkpoint_path = config.get("fine_tune_checkpoint", "/workspace/src/PL-BERT-ID/step_1000000.t7")
fine_tune = config.get("fine_tune", True)

In [7]:
def _strip_ddp_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from every key."""
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _latest_checkpoint(directory):
    """Find the newest ``step_<N>.t7`` file in ``directory``.

    Returns:
        ``(step, path)`` for the file with the largest ``N``, or ``(0, None)``
        when the directory cannot be listed or holds no such file.
    """
    try:
        names = os.listdir(directory)
    except OSError:
        return 0, None

    best_step, best_path = 0, None
    for name in names:
        if not (name.startswith('step_') and name.endswith('.t7')):
            continue
        stem = name[len('step_'):-len('.t7')]
        # Skip names whose step part is not a plain integer instead of aborting the scan.
        if not (stem.isascii() and stem.isdigit()):
            continue
        path = osp.join(directory, name)
        if osp.isfile(path) and (best_path is None or int(stem) > best_step):
            best_step, best_path = int(stem), path
    return best_step, best_path


def _as_float(value):
    """Convert a loss value (tensor or plain Python number) to a float for logging."""
    return value.detach().item() if torch.is_tensor(value) else float(value)


def train():
    """Run one pass over the training data, logging and checkpointing periodically.

    Reads the module-level settings ``config``, ``config_path``, ``token_maps``,
    ``criterion``, ``num_steps``, ``log_interval``, ``save_interval``,
    ``fine_tune`` and ``checkpoint_path``.

    Start-up behaviour:
      * If ``fine_tune`` is true and ``checkpoint_path`` exists, the model
        weights are loaded from it and the step counter starts at 0.
      * Otherwise the newest ``step_<N>.t7`` in the log directory (if any) is
        used to restore the model weights, the optimizer state and the step
        counter.
      * With no checkpoint at all, training starts from scratch at step 0.

    The function returns when the data loader is exhausted or as soon as the
    step counter reaches ``num_steps``. Checkpoints are written to the log
    directory as ``step_<N>.t7`` every ``save_interval`` steps.
    """
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

    # Initialize accelerator first
    accelerator = Accelerator(mixed_precision=config['mixed_precision'], split_batches=True, kwargs_handlers=[ddp_kwargs])

    dataset = load_from_disk(config["data_folder"])

    # Keep a copy of the config next to the checkpoints it produced.
    log_dir = config['log_dir']
    os.makedirs(log_dir, exist_ok=True)
    shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))

    train_loader = build_dataloader(dataset,
                                    batch_size=config["batch_size"],
                                    num_workers=0,
                                    dataset_config=config['dataset_params'])

    bert = MultiTaskModel(AlbertModel(AlbertConfig(**config['model_params'])),
                          num_vocab=1 + max([m['token'] for m in token_maps.values()]),
                          num_tokens=config['model_params']['vocab_size'],
                          hidden_size=config['model_params']['hidden_size'])

    optimizer = AdamW(bert.parameters(), lr=1e-4)

    iters = 0
    if fine_tune and osp.exists(checkpoint_path):
        # Fine-tune from the given checkpoint: weights only, fresh optimizer.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        bert.load_state_dict(_strip_ddp_prefix(checkpoint['net']), strict=False)
        # The step counter deliberately restarts at 0 here, so every call to
        # train() in fine-tune mode begins again from this checkpoint. To
        # continue the step count instead, set `iters` to the checkpoint's step.
        accelerator.print(f'Fine-tuning from checkpoint: {checkpoint_path}, starting from step {iters}')
    else:
        if fine_tune:
            accelerator.print(f'Fine-tune checkpoint not found: {checkpoint_path}. '
                              f'Falling back to the latest checkpoint in {log_dir}, if any.')
        resume_step, resume_file = _latest_checkpoint(log_dir)
        if resume_file is not None:
            checkpoint = torch.load(resume_file, map_location='cpu')
            # Restore the weights together with the optimizer state so the
            # resumed step counter matches the actual training state.
            bert.load_state_dict(_strip_ddp_prefix(checkpoint['net']), strict=False)
            optimizer.load_state_dict(checkpoint['optimizer'])
            iters = resume_step
            accelerator.print(f'Loaded checkpoint from step {iters}')

    bert, optimizer, train_loader = accelerator.prepare(
        bert, optimizer, train_loader
    )

    accelerator.print(f'Start training from step {iters}...')

    running_loss = 0

    for batch in train_loader:
        # Check if we've reached the maximum steps
        if iters >= num_steps:
            accelerator.print(f'Reached maximum steps ({num_steps}). Stopping training.')
            return

        # Move anything that supports `.to` onto the training device; other items pass through.
        batch = [b.to(accelerator.device) if hasattr(b, "to") else b for b in batch]
        words, labels, phonemes, input_lengths, masked_indices = batch

        text_mask = length_to_mask(torch.Tensor(input_lengths)).to(accelerator.device)
        tokens_pred, words_pred = bert(phonemes, attention_mask=(~text_mask).int())

        # Word-prediction loss over the unpadded part of each sample, averaged over the batch.
        loss_vocab = 0
        for pred, target, length in zip(words_pred, words, input_lengths):
            loss_vocab += criterion(pred[:length], target[:length])
        loss_vocab /= words.size(0)

        # Token-prediction loss, evaluated only at the masked positions of each sample.
        # `sizes` starts at 1 (as in the original code), so the divisor is one more
        # than the number of samples that contributed and can never be zero.
        loss_token = 0
        sizes = 1
        for pred, target, length, masked in zip(tokens_pred, labels, input_lengths, masked_indices):
            if len(masked) > 0:
                loss_token += criterion(pred[:length][masked], target[:length][masked])
                sizes += 1
        loss_token /= sizes

        loss = loss_vocab + loss_token

        optimizer.zero_grad()
        accelerator.backward(loss)
        optimizer.step()

        running_loss += loss.item()

        iters += 1  # count the step only after it has completed

        if iters % log_interval == 0:
            # `loss_token` stays a plain number when no sample had masked
            # positions, so convert through a helper that accepts both cases.
            accelerator.print('Step [%d/%d], Loss: %.5f, Vocab Loss: %.5f, Token Loss: %.5f'
                    % (iters, num_steps, running_loss / log_interval, _as_float(loss_vocab), _as_float(loss_token)))
            running_loss = 0

        if iters % save_interval == 0:
            accelerator.print('Saving..')

            state = {
                'net':  bert.state_dict(),
                'step': iters,
                'optimizer': optimizer.state_dict(),
            }

            accelerator.save(state, log_dir + '/step_' + str(iters) + '.t7')

In [None]:
from accelerate import notebook_launcher
# Re-launch train() in a single process, over and over. Each call to train()
# returns after one pass over its data loader, or once the step counter
# reaches `num_steps`, and this loop then starts it again.
# NOTE(review): the loop has no exit condition, so it keeps re-launching even
# after `num_steps` has been reached; the cell must be interrupted manually.
# NOTE(review): when `fine_tune` is true and `checkpoint_path` exists, train()
# starts from step 0 on every call, so each pass begins again from that
# checkpoint -- confirm this is the intended behaviour.
while True:
    notebook_launcher(train, args=(), num_processes=1, use_port=33389)

Launching training on one GPU.
177
Fine-tuning from checkpoint: /workspace/src/PL-BERT-ID/step_1000000.t7, starting from step 0
Start training from step 0...
Step [10/1000000], Loss: 15.40331, Vocab Loss: 6.71846, Token Loss: 4.32557
Step [20/1000000], Loss: 7.33390, Vocab Loss: 2.65711, Token Loss: 3.35914
Step [30/1000000], Loss: 5.58973, Vocab Loss: 2.03578, Token Loss: 3.02255
Step [40/1000000], Loss: 4.95343, Vocab Loss: 1.68984, Token Loss: 2.93480
Step [50/1000000], Loss: 4.42378, Vocab Loss: 1.38884, Token Loss: 2.83454
Step [60/1000000], Loss: 4.14844, Vocab Loss: 1.14820, Token Loss: 2.65726
