In [None]:
from os.path import join as pjoin
from pprint import pformat
import yaml
from sklearn.metrics import average_precision_score, roc_auc_score
import numpy as np
from collections import defaultdict
import gc
import copy, math

from personal_VAD.utils import *
from personal_VAD.dataloader_specOnline import get_dataloader as get_online_dataloader
from personal_VAD.dataloader import get_dataloader
from personal_VAD.models.ASPVAD import TotalLoss

import torch
import torch.amp as amp
torch.set_float32_matmul_precision('high')

In [None]:
class Metrics:
    """Accumulate frame-level VAD and utterance-level speaker metrics over batches.

    In ``'train'`` mode only running counts are kept (loss, accuracy, speaker
    accuracy). In any other mode the raw scores and labels are stored as well, so
    that F1, average precision and ROC-AUC can be computed at the end.

    Args:
        mode: ``'train'`` for the lightweight variant; any other value also keeps
            the scores needed for F1 / AP / ROC-AUC.
        thres: decision threshold applied to the VAD scores passed to ``update``.
            NOTE(review): the callers pass tensors named ``logit``; if those are
            pre-sigmoid logits, 0.5 is not the usual 50% cut-off (0.0 would be).
            Confirm against the model's output.
        eps: small constant added to the precision / recall / F1 denominators.
    """

    def __init__(self, mode='train', thres=0.5, eps=1e-8):
        self.eps = eps
        self.mode = mode
        self.thres = thres
        self.reset()

    def reset(self):
        """Clear all accumulated counts (and the stored scores outside train mode)."""
        self.tp, self.fp, self.tn, self.fn = 0, 0, 0, 0
        self.emb_correct = 0
        self.total_spk = 0
        self.loss_sum = 0
        self.loss_count = 0
        if self.mode != 'train':
            self.logits = []
            self.labels = []

    def update(self, loss, logits, labels, spk_logits, spk_labels):
        """Add one batch to the running totals.

        Args:
            loss: scalar loss of the batch; ``compute`` averages it per call.
            logits: array of VAD scores (callers pass only the valid, masked frames).
            labels: 0/1 array with the same shape as ``logits``.
            spk_logits: speaker scores; the argmax over the last axis is the prediction.
            spk_labels: speaker ids, one per row of ``spk_logits``.
        """
        preds = (logits > self.thres).astype(int)
        self.tp += np.sum((preds == 1) & (labels == 1))
        self.fp += np.sum((preds == 1) & (labels == 0))
        self.tn += np.sum((preds == 0) & (labels == 0))
        self.fn += np.sum((preds == 0) & (labels == 1))

        pred_spk = np.argmax(spk_logits, axis=-1)
        self.emb_correct += np.sum(pred_spk == spk_labels)
        self.total_spk += len(spk_labels)

        self.loss_sum += loss
        self.loss_count += 1

        if self.mode != 'train':
            self.logits.append(logits)
            self.labels.append(labels)

    def compute(self):
        """Return the metrics accumulated since the last ``reset``.

        Returns:
            ``defaultdict(float)`` (every key reads as 0.0) if no speaker labels
            were seen. Otherwise a dict with ``loss``, ``acc`` and ``spk_acc``.
            Outside train mode it also has ``f1``, ``ap`` and ``roc``.
            ``acc`` is NaN when no frames were scored. ``ap`` and ``roc`` are NaN
            when the stored labels hold fewer than two classes, because ROC-AUC is
            undefined there and ``roc_auc_score`` would raise.
        """
        if self.total_spk == 0:
            return defaultdict(float)

        avg_loss = self.loss_sum / self.loss_count
        n_frames = self.tp + self.fp + self.tn + self.fn
        # Every frame mask may have been empty; report "undefined" instead of dividing by zero.
        acc = (self.tp + self.tn) / n_frames if n_frames else float('nan')
        spk_acc = self.emb_correct / self.total_spk
        if self.mode == 'train':
            return {'loss': avg_loss, 'acc': acc, 'spk_acc': spk_acc}

        precision = self.tp / (self.tp + self.fp + self.eps)
        recall = self.tp / (self.tp + self.fn + self.eps)
        f1 = 2 * precision * recall / (precision + recall + self.eps)
        logits = np.concatenate(self.logits)
        labels = np.concatenate(self.labels)
        if np.unique(labels).size < 2:
            # Single-class split (e.g. all speech): roc_auc_score raises ValueError,
            # which would abort the whole training run at validation time.
            ap = auc = float('nan')
        else:
            ap = average_precision_score(labels, logits)
            auc = roc_auc_score(labels, logits)
        return {'loss': avg_loss, 'acc': acc, 'spk_acc': spk_acc, 'f1': f1, 'ap': ap, 'roc': auc}

In [None]:
def train(**kwargs):
    """Train a personal-VAD model from a YAML config, then report val/test metrics.

    ``config`` (required) is the path to the YAML file. Any other keyword
    arguments override the YAML values. The merged arguments are written to
    ``<exp_dir>/config.yaml``, and checkpoints go to ``<exp_dir>/models``.

    Training runs in chunks of ``val_interval`` optimizer steps via
    ``train_steps``. After each chunk the model is validated, one table row is
    logged, and the best-by-loss / accuracy / F1 weights are snapshotted. The
    "epoch" values in the log and in checkpoint names are this step counter.
    At the end, the best snapshots are saved and the best-accuracy weights (if
    any) are loaded for the final validation and test reports.
    """
    with open(kwargs['config']) as con_read:
        # FullLoader kept for compatibility with existing configs; yaml.safe_load
        # is preferable if configs never rely on Python-specific tags.
        yaml_config = yaml.load(con_read, Loader=yaml.FullLoader)
    # An empty YAML file loads as None; treat it as "no settings" so kwargs still apply.
    args = dict(yaml_config or {}, **kwargs)
    exp_dir = validate_path(args['exp_dir'], is_dir=True)
    model_dir = validate_path(pjoin(exp_dir, "models"), is_dir=True)
    with open(pjoin(exp_dir, 'config.yaml'), 'w') as fout:
        fout.write(yaml.dump(args))

    logger = get_logger(exp_dir, 'train.log')
    set_seed(args.get('seed', 0))
    device = get_device()
    logger.info(f"device: {device}")
    logger.info("<== Passed Arguments ==>")
    for line in pformat(args).split('\n'):
        logger.info(line)

    # The speaker-classification head is sized from the dataset's speaker count.
    args['model_args']['num_speakers'] = args["dataset_args"]["speaker_num"]
    model, start_epoch = load_model(args['model'], args['model_args'], device=device,
                                    get_epoch=True, do_compile=args.get('do_compile', False))
    logger.info("<== Model ==>")
    logger.info(f'model size: {sum(param.numel() for param in model.parameters())}')
    if start_epoch > 1:
        logger.info(f'checkpoint loaded. start_epoch: {start_epoch}')

    dataloader_args = args['dataloader_args']
    dataset_args = args['dataset_args']
    train_dataloader = get_online_dataloader(dataset_args.get("train_dataset_args", None), dataloader_args)
    val_dataloader = get_dataloader(dataset_args.get("val_dataset_args", None), dataloader_args, eval=True)
    test_dataloader = get_dataloader(dataset_args.get("test_dataset_args", None), dataloader_args, eval=True)
    logger.info("<== Dataloaders ==>")
    # The online training loader's size is not logged (presumably it has no len()).
    logger.info(f"val data num: {len(val_dataloader.dataset)}")
    logger.info(f"test data num: {len(test_dataloader.dataset)}")

    criterion = TotalLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), **args['optimizer_args'])
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **args['scheduler_args'])
    # Mixed precision only on CUDA, and only when the config asks for it.
    scaler = amp.GradScaler(device.type, enabled=(args.get('enable_amp', False) and device.type == 'cuda'))

    logger.info("<========== Training process ==========>")
    logger.info("+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+")
    logger.info(f"|{'Epoch':>10}|{'Lr':>10}|{'Train_Loss':>10}|{'Train_Acc':>10}|{'Spk_Acc':>10}|{'Val_Loss':>10}|{'Val_Acc':>10}|{'Val_AP':>10}|{'Val_AUC':>10}|{'Val_F1':>10}|")
    logger.info("+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+")
    best_loss, best_loss_epoch, best_loss_model_state_dict = float('inf'), 0, None
    best_acc, best_acc_epoch, best_acc_model_state_dict = 0, 0, None
    best_f1, best_f1_epoch, best_f1_model_state_dict = 0, 0, None
    model.train()
    val_interval = args['val_interval']
    # Validation points are multiples of val_interval; a resumed run first trains
    # only as many steps as needed to reach the next multiple.
    first_val = math.ceil(start_epoch / val_interval) * val_interval
    for i, step in enumerate(range(first_val, args['max_steps'] + 1, val_interval)):
        train_len = first_val - start_epoch + 1 if i == 0 else val_interval
        train_metrics = train_steps(model, train_dataloader, criterion, optimizer, scaler, train_len)
        val_metrics = evaluate(model, val_dataloader, criterion)
        train_loss = train_metrics['loss']
        val_loss = val_metrics['loss']
        val_acc = val_metrics['acc']
        val_f1 = val_metrics['f1']
        # NOTE(review): the scheduler is never stepped, so the learning rate stays
        # constant and `scheduler_args` currently have no effect.
        #scheduler.step()

        # Lr in scientific notation: a fixed 4-decimal format prints 0.0000 for lr < 5e-5.
        logger.info(f"|{step:>10}|{scheduler.get_last_lr()[0]:>10.2e}|{train_loss:>10.6f}|{train_metrics['acc']:>10.4f}|{train_metrics['spk_acc']:>10.4f}|{val_loss:>10.6f}|{val_metrics['acc']:>10.4f}|{val_metrics['ap']:>10.4f}|{val_metrics['roc']:>10.4f}|{val_metrics['f1']:>10.4f}|")

        if step % args['save_interval'] == 0:
            save_checkpoint(model, pjoin(model_dir, f'model_{step}.pt'))
        # Snapshot the weights whenever a validation metric improves.
        if best_loss > val_loss:
            best_loss, best_loss_epoch, best_loss_model_state_dict = val_loss, step, copy.deepcopy(model.state_dict())
        if best_acc < val_acc:
            best_acc, best_acc_epoch, best_acc_model_state_dict = val_acc, step, copy.deepcopy(model.state_dict())
        if best_f1 < val_f1:
            best_f1, best_f1_epoch, best_f1_model_state_dict = val_f1, step, copy.deepcopy(model.state_dict())

    logger.info("+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+")

    if best_loss_model_state_dict:
        save_checkpoint(best_loss_model_state_dict, pjoin(model_dir, f'best_loss_model_{best_loss_epoch}.pt'))
    if best_f1_model_state_dict:
        save_checkpoint(best_f1_model_state_dict, pjoin(model_dir, f'best_f1_model_{best_f1_epoch}.pt'))
    if best_acc_model_state_dict:
        save_checkpoint(best_acc_model_state_dict, pjoin(model_dir, f'best_acc_model_{best_acc_epoch}.pt'))
        # The final val/test reports use the best-accuracy weights.
        model.load_state_dict(best_acc_model_state_dict)
        best_title = f"<========== Best validation value : epoch {best_acc_epoch} ==========>"
    else:
        # No accuracy improvement was recorded (e.g. no validation round ran), so
        # the reports below describe the current weights, not a "best" epoch.
        best_title = "<========== Best validation value : final model (no best-accuracy snapshot) ==========>"

    val_metrics = evaluate(model, val_dataloader, criterion)
    _log_summary_table(logger, best_title, 'val', val_metrics)

    test_metrics = evaluate(model, test_dataloader, criterion)
    _log_summary_table(logger, "<========== Test process ==========>", 'test', test_metrics)


def _log_summary_table(logger, title, prefix, metrics):
    """Log `title` and a one-row loss/acc/AP/AUC/F1 table whose headers are `prefix`-ed."""
    border = "+----------+----------+----------+----------+----------+"
    logger.info(title)
    logger.info(border)
    logger.info(f"|{prefix + '_Loss':>10}|{prefix + '_Acc':>10}|{prefix + '_AP':>10}|{prefix + '_AUC':>10}|{prefix + '_F1':>10}|")
    logger.info(border)
    logger.info(f"|{metrics['loss']:>10.6f}|{metrics['acc']:>10.4f}|{metrics['ap']:>10.4f}|{metrics['roc']:>10.4f}|{metrics['f1']:>10.4f}|")
    logger.info(border)


def train_steps(model: torch.nn.Module, dataloader, criterion, optimizer, scaler, max_steps):
    """Run ``max_steps`` optimizer steps on batches from ``dataloader``.

    If the loader is exhausted before ``max_steps`` batches, a fresh pass over it
    is started, so the caller's step counter matches the work done. Only an
    empty loader ends early.

    Args:
        model: model to train; it is switched to train mode.
        dataloader: iterable of dict batches with keys ``simul``, ``enroll``,
            ``label``, ``enroll_len``, ``simul_len`` and ``spk``.
        criterion: loss whose first return value is back-propagated.
        optimizer: optimizer stepped through ``scaler``.
        scaler: ``torch.amp.GradScaler``; its enabled flag also toggles autocast.
        max_steps: number of optimizer steps to take.

    Returns:
        ``Metrics(mode='train').compute()`` over the steps that were run.
    """
    import itertools  # local import keeps this cell self-contained

    device = next(model.parameters()).device
    model.train()
    enable_amp = scaler.is_enabled()

    metrics = Metrics(mode='train')
    steps_done = 0
    while steps_done < max_steps:
        steps_before = steps_done
        # islice stops exactly at the budget without pulling (and discarding) an
        # extra batch, as `enumerate` + `break` would.
        for batch in itertools.islice(dataloader, max_steps - steps_done):
            simul = batch['simul'].to(device, non_blocking=True)    # (B,T,F)
            enroll = batch['enroll'].to(device, non_blocking=True)
            label = batch['label'].float().to(device, non_blocking=True)  # (B,T)
            enroll_length = batch['enroll_len'].to(device, non_blocking=True)
            simul_length = batch['simul_len'].to(device, non_blocking=True)
            spk = batch['spk'].to(device, non_blocking=True)

            with amp.autocast(device.type, enabled=enable_amp):
                logit, tilde_s, spk_logit = model(enroll, simul, enroll_length=enroll_length, simul_length=simul_length, spk_label=spk)
                loss, _, _, _ = criterion(logit, tilde_s, spk_logit, label, spk)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Score only the valid (unpadded) frames of each utterance.
            mask = torch.arange(label.size(1), device=device)[None, :] < simul_length[:, None]  # (B,T)
            # .float(): autocast outputs may be reduced precision; bfloat16 cannot be converted to NumPy.
            metrics.update(loss.item(),
                           logit[mask].detach().float().cpu().numpy(),
                           label[mask].detach().cpu().numpy(),
                           spk_logit.detach().float().cpu().numpy(),
                           spk.detach().cpu().numpy())
            steps_done += 1
        if steps_done == steps_before:
            break  # the loader yielded nothing; stop instead of looping forever

    return metrics.compute()


@torch.no_grad()
def evaluate(model: torch.nn.Module, dataloader, criterion):
    """Score ``model`` on every batch of ``dataloader`` without gradient tracking.

    The model is switched to eval mode and is left there. The loss fed to the
    metrics is the *second* value returned by ``criterion``, whereas
    ``train_steps`` uses the first. NOTE(review): confirm which component of the
    loss each one is, since the train and validation losses are logged side by
    side.

    Returns:
        ``Metrics(mode='test').compute()``: loss/acc/spk_acc/f1/ap/roc, or a
        ``defaultdict(float)`` when the loader yields no speaker labels.
    """
    device = next(model.parameters()).device
    model.eval()
    tracker = Metrics(mode='test')

    for batch in dataloader:
        inputs = {name: batch[name].to(device, non_blocking=True)
                  for name in ('simul', 'enroll', 'enroll_len', 'simul_len', 'spk')}
        target = batch['label'].float().to(device, non_blocking=True)  # (B,T)
        frame_lens = inputs['simul_len']
        speakers = inputs['spk']

        vad_logit, tilde_s, spk_logit = model(
            inputs['enroll'],
            inputs['simul'],
            enroll_length=inputs['enroll_len'],
            simul_length=frame_lens,
            spk_label=speakers,
        )
        _total, eval_loss, _aux_a, _aux_b = criterion(vad_logit, tilde_s, spk_logit, target, speakers)

        # True on the unpadded frames of each utterance, shape (B,T).
        frame_idx = torch.arange(target.size(1), device=device)
        valid = frame_idx.unsqueeze(0) < frame_lens.unsqueeze(1)
        tracker.update(
            eval_loss.item(),
            vad_logit[valid].detach().cpu().numpy(),
            target[valid].detach().cpu().numpy(),
            spk_logit.detach().cpu().numpy(),
            speakers.detach().cpu().numpy(),
        )
    return tracker.compute()

In [None]:
CONFIG_PATH = r'conf/libri.yaml'  # experiment config read by train()
train(config=CONFIG_PATH)