In [None]:
# default_exp core

# core

> Core routines for shazbot

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export 
import torch 
import accelerate
import tqdm

## Audio utils

In [None]:
#export
def is_silence(
    audio,       # torch tensor of multichannel audio
    thresh=-70,  # threshold in dB below which we declare to be silence
    ):
    """Return True if the peak level of the whole clip is below `thresh` dB.

    The level is 20*log10(max |sample|) taken over every channel and sample,
    i.e. dB relative to an amplitude of 1.0. A clip of all zeros has a peak of
    -inf dB and is therefore always silent; an empty clip (no samples) is also
    treated as silent.

    Returns a plain Python `bool`.
    """
    if audio.numel() == 0:  # no samples at all; .max() on an empty tensor would raise
        return True
    # detach so clips that are part of an autograd graph are accepted
    peak = audio.detach().abs().max()
    # .item() works on any device and yields a Python float; log10(0) -> -inf
    dB_max = (20 * torch.log10(peak)).item()
    return dB_max < thresh

In [None]:
# tests for `is_silence` (torch is already imported in the exported cell above)
x = torch.ones((2,10))
assert not is_silence(1e-3*x)            # peak -60 dB > -70 dB: not silent
assert is_silence(1e-5*x)                # peak -100 dB < -70 dB: silent
assert is_silence(1e-3*x, thresh=-50)    # -60 dB is below a higher -50 dB threshold
assert is_silence(torch.zeros((2,10)))   # digital silence: log10(0) -> -inf dB

## Parallelism utils

In [None]:
#|export        
class HostPrinter:
    """Callable that prints only on the main (host) process of an accelerate run.

    The wrapped `accelerator` is kept as-is; every call consults its
    `is_main_process` flag, so non-main processes stay quiet and log lines
    are not duplicated once per process.
    """
    def __init__(self, accelerator):
        self.accelerator = accelerator

    def __call__(self, s:str):
        "Print `s` with an immediate flush, but only on the main process."
        if not self.accelerator.is_main_process:
            return
        print(s, flush=True)

In [None]:
#test hostprinter
# Create the Accelerator for this process; `device` is whatever accelerate selected
# (the recorded output below shows `cuda`, so this cell was last run on a GPU machine).
# `accelerator`, `device` and `hprint` are notebook-level names other cells may reuse.
accelerator = accelerate.Accelerator()
device = accelerator.device
hprint = HostPrinter(accelerator)
hprint(f'Using device: {device}')  # printed only when accelerator.is_main_process is true

Using device: cuda


## Utils for PyTorch models

In [None]:
#|export 
def save(accelerator, args, model, opt=None, epoch=None, step=None):
    """Checkpoint `model` (and optionally optimizer/epoch/step) via `accelerator.save`.

    Args:
        accelerator: an `accelerate.Accelerator`; used for the cross-process
            barrier, the main-process check, unwrapping the model, and saving.
        args: any object with a `.name` attribute, used as the filename stem.
        model: the (possibly accelerator-wrapped) model; its unwrapped
            `state_dict()` is stored under 'model'.
        opt: optional optimizer; its `state_dict()` is stored under 'opt'.
        epoch: optional epoch number, stored under 'epoch'.
        step: optional step number, stored under 'step'. When given, the file is
            `{args.name}_{step:08}.pth`; otherwise it is `{args.name}.pth`.

    `wait_for_everyone()` is a barrier across processes, so every process should
    call this; only the main process prints the "Saving to ..." message.
    """
    accelerator.wait_for_everyone()  # sync all processes before reading the weights
    if step is not None:
        filename = f'{args.name}_{step:08}.pth'
    else:
        filename = f'{args.name}.pth'
    if accelerator.is_main_process:
        # `tqdm` is imported as a module, which has no `write`; the
        # progress-bar-safe writer is the classmethod `tqdm.tqdm.write`.
        tqdm.tqdm.write(f'Saving to {filename}...')
    obj = {'model': accelerator.unwrap_model(model).state_dict()}
    if opt is not None:
        obj['opt'] = opt.state_dict()
    if epoch is not None:
        obj['epoch'] = epoch
    if step is not None:
        obj['step'] = step
    accelerator.save(obj, filename)
    

def n_params(module, trainable_only=False):
    """Count the scalar parameter elements in `module`.

    Args:
        module: a `torch.nn.Module` (anything whose `.parameters()` yields tensors).
        trainable_only: if True, count only parameters with `requires_grad=True`.
            Defaults to False, which counts every parameter (the original behavior).

    Returns:
        int: the sum of `numel()` over the selected parameters (0 if there are none).
    """
    return sum(p.numel() for p in module.parameters()
               if p.requires_grad or not trainable_only)


def freeze(model):
    """Turn off gradient tracking for every parameter of `model`, in place.

    Each tensor yielded by `model.parameters()` gets `requires_grad=False`, so
    autograd no longer computes gradients for it. Returns None.
    """
    for p in model.parameters():
        p.requires_grad_(False)