In [1]:
import pytorch_lightning as pl
import hydra
from hydra import initialize, compose
from omegaconf import OmegaConf
from omegaconf import DictConfig
import remfx.utils as utils

# Logger for this notebook, obtained from the project's remfx.utils helper;
# run_train_test reports each instantiation step through it.
log = utils.get_logger(__name__)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def run_train_test(cfg: DictConfig):
    """Build every component from a Hydra config, train, then test the best checkpoint.

    Instantiates (via ``hydra.utils.instantiate``) the datamodule, model,
    callbacks, logger and trainer described by ``cfg``, logs the
    hyperparameters, fits the model and finally runs the test loop on the
    best checkpoint.

    Args:
        cfg: Composed Hydra config. Reads the optional ``seed`` and
            ``callbacks`` entries and the ``datamodule``, ``model``,
            ``logger`` and ``trainer`` entries.
    """
    # Apply seed for reproducibility. Compare against None rather than relying
    # on truthiness so that seed=0 (a valid seed) is still applied.
    seed = cfg.get("seed")
    if seed is not None:
        pl.seed_everything(seed)

    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
    datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
    log.info(f"Instantiating model <{cfg.model._target_}>.")
    model = hydra.utils.instantiate(cfg.model, _convert_="partial")

    # Init all callbacks. A missing or null `callbacks` section yields no
    # callbacks; entries that are null or lack a `_target_` are skipped.
    callbacks = []
    for cb_conf in (cfg.get("callbacks") or {}).values():
        if cb_conf is not None and "_target_" in cb_conf:
            log.info(f"Instantiating callback <{cb_conf._target_}>.")
            callbacks.append(hydra.utils.instantiate(cb_conf, _convert_="partial"))

    logger = hydra.utils.instantiate(cfg.logger, _convert_="partial")
    log.info(f"Instantiating trainer <{cfg.trainer._target_}>.")
    trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
    )
    log.info("Logging hyperparameters!")
    utils.log_hyperparameters(
        config=cfg,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        callbacks=callbacks,
        logger=logger,
    )
    trainer.fit(model=model, datamodule=datamodule)
    # ckpt_path="best" evaluates the best checkpoint tracked by the trainer's
    # ModelCheckpoint callback, so the config must define one
    # (e.g. callbacks.model_checkpoint).
    trainer.test(model=model, datamodule=datamodule, ckpt_path="best")



In [3]:
# Environment variables for this kernel session (%env sets os.environ).
# WANDB_PROJECT / WANDB_ENTITY select the Weights & Biases project and account runs are logged to.
# NOTE(review): WANDB_ENTITY is a personal account name; change it to your own before running.
# DATASET_ROOT presumably locates the raw datasets for the Hydra config in cfg/ (TODO: confirm).
%env DATASET_ROOT=./data/remfx-data
%env WANDB_PROJECT=RemFX
%env WANDB_ENTITY=mattricesound

env: DATASET_ROOT=./data/remfx-data
env: WANDB_PROJECT=RemFX
env: WANDB_ENTITY=mattricesound


In [4]:
# Hydra overrides appended to every experiment's compose() call below:
# render into a local directory, run on CPU, and cap the number of chunks
# rendered for each dataset split.
global_overrides = [
    "render_root=./data/rendered",
    "accelerator=cpu",
] + [
    f"datamodule.{split}_dataset.total_chunks={num_chunks}"
    for split, num_chunks in (("train", 80), ("val", 10), ("test", 10))
]

## Task 1: Effect Classification

### Model 1: TCN

In [4]:
# TCN effect classifier: compose the experiment config and show what will run.
with initialize(version_base=None, config_path="cfg"):
    cfg = compose(
        config_name="config.yaml",
        overrides=["+exp=distortion", "model=classifier_tcn", *global_overrides],
    )
    print(f"Task: {cfg.model._target_}")
    print(f"Model: {cfg.model.network._target_}")
    print("Config:", OmegaConf.to_yaml(cfg, resolve=True), sep="\n")

Task: remfx.models.FXClassifier
Model: remfx.tcn.TCN
Config:
seed: 12345
train: true
sample_rate: 48000
chunk_size: 262144
logs_dir: ./logs
render_files: true
render_root: /scratch/EffectSet
accelerator: gpu
log_audio: true
num_kept_effects:
- 0
- 4
num_removed_effects:
- 1
- 1
shuffle_kept_effects: true
shuffle_removed_effects: false
num_classes: 5
effects_to_keep:
- compressor
- reverb
- chorus
- delay
effects_to_remove:
- distortion
callbacks:
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: valid_loss
    save_top_k: 1
    save_last: true
    mode: min
    verbose: false
    dirpath: ./logs/ckpts/2023-04-02-21-39-19
    filename: '{epoch:02d}-{valid_loss:.3f}'
  learning_rate_monitor:
    _target_: pytorch_lightning.callbacks.LearningRateMonitor
    logging_interval: step
  audio_logging:
    _target_: remfx.callbacks.AudioCallback
    sample_rate: 48000
    log_audio: true
  metric_logging:
    _target_: remfx.callbacks.MetricCallback
data

In [5]:
# Train the TCN classifier, then test its best checkpoint.
run_train_test(cfg)

Global seed set to 12345


Effect Summary: 
Apply kept effects: ['compressor', 'reverb', 'chorus', 'delay'] (Between 0-4, chosen randomly) -> Dry
Apply remove effects: ['distortion'] (1, chosen in order) -> Wet

Found 2889 files in VocalSet train.
Found 240 files in GuitarSet train.
Found 80 files in DSD100 train.
Found 460 files in IDMT-SMT-Drums train.
Total datasets: 4
Processing files...


InstantiationException: Error in call to target 'remfx.datasets.EffectDataset':
OSError(30, 'Read-only file system')
full_key: datamodule.train_dataset

### Model 2: Demucs (imported from torchaudio)

In [None]:
# Demucs effect classifier: compose the experiment config and show what will run.
with initialize(version_base=None, config_path="cfg"):
    cfg = compose(
        config_name="config.yaml",
        overrides=["+exp=distortion", "model=classifier_demucs", *global_overrides],
    )
    print(f"Task: {cfg.model._target_}")
    print(f"Model: {cfg.model.network._target_}")
    print("Config:", OmegaConf.to_yaml(cfg, resolve=True), sep="\n")

In [None]:
# Train the Demucs classifier, then test its best checkpoint.
run_train_test(cfg)

## Task 2: Effect Removal

### Model 1: TCN

In [None]:
# TCN effect-removal model: compose the experiment config and show what will run.
with initialize(version_base=None, config_path="cfg"):
    cfg = compose(
        config_name="config.yaml",
        overrides=["+exp=distortion", "model=tcn", *global_overrides],
    )
    print(f"Task: {cfg.model._target_}")
    print(f"Model: {cfg.model.network._target_}")
    print("Config:", OmegaConf.to_yaml(cfg, resolve=True), sep="\n")

In [None]:
# Train the TCN effect-removal model, then test its best checkpoint.
run_train_test(cfg)

### Model 2: Demucs (imported from torchaudio)

In [5]:
# Demucs effect-removal model: compose the experiment config and show what will run.
with initialize(version_base=None, config_path="cfg"):
    cfg = compose(
        config_name="config.yaml",
        overrides=["+exp=distortion", "model=demucs", *global_overrides],
    )
    print(f"Task: {cfg.model._target_}")
    print(f"Model: {cfg.model.network._target_}")
    print("Config:", OmegaConf.to_yaml(cfg, resolve=True), sep="\n")

Task: remfx.models.RemFX
Model: remfx.models.DemucsModel
Config:
seed: 12345
train: true
sample_rate: 48000
chunk_size: 262144
logs_dir: ./logs
render_files: true
render_root: ./data/rendered
accelerator: cpu
log_audio: true
num_kept_effects:
- 0
- 4
num_removed_effects:
- 1
- 1
shuffle_kept_effects: true
shuffle_removed_effects: false
num_classes: 5
effects_to_keep:
- compressor
- reverb
- chorus
- delay
effects_to_remove:
- distortion
callbacks:
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: valid_loss
    save_top_k: 1
    save_last: true
    mode: min
    verbose: false
    dirpath: ./logs/ckpts/2023-04-03-00-20-24
    filename: '{epoch:02d}-{valid_loss:.3f}'
  learning_rate_monitor:
    _target_: pytorch_lightning.callbacks.LearningRateMonitor
    logging_interval: step
  audio_logging:
    _target_: remfx.callbacks.AudioCallback
    sample_rate: 48000
    log_audio: true
  metric_logging:
    _target_: remfx.callbacks.MetricCallback
dat

In [6]:
# Train the Demucs effect-removal model, then test its best checkpoint.
run_train_test(cfg)

Global seed set to 12345


Effect Summary: 
Apply kept effects: ['compressor', 'reverb', 'chorus', 'delay'] (Between 0-4, chosen randomly) -> Dry
Apply remove effects: ['distortion'] (1, chosen in order) -> Wet

Found 2889 files in VocalSet train.
Found 240 files in GuitarSet train.
Found 80 files in DSD100 train.
Found 460 files in IDMT-SMT-Drums train.
Found processed files.
Total datasets: 4
Processing files...


100%|██████████| 80/80 [00:05<00:00, 15.22it/s]


Finished rendering
Total chunks: 80
Effect Summary: 
Apply kept effects: ['compressor', 'reverb', 'chorus', 'delay'] (Between 0-4, chosen randomly) -> Dry
Apply remove effects: ['distortion'] (1, chosen in order) -> Wet

Found 364 files in VocalSet val.
Found 60 files in GuitarSet val.
Found 10 files in DSD100 val.
Found 56 files in IDMT-SMT-Drums val.
Found processed files.
Total datasets: 4
Processing files...


100%|██████████| 10/10 [00:00<00:00, 14.21it/s]


Finished rendering
Total chunks: 10
Effect Summary: 
Apply kept effects: ['compressor', 'reverb', 'chorus', 'delay'] (Between 0-4, chosen randomly) -> Dry
Apply remove effects: ['distortion'] (1, chosen in order) -> Wet

Found 360 files in VocalSet test.
Found 60 files in GuitarSet test.
Found 10 files in DSD100 test.
Found 44 files in IDMT-SMT-Drums test.
Found processed files.
Total datasets: 4
Processing files...


100%|██████████| 10/10 [00:01<00:00,  9.32it/s]


Finished rendering
Total chunks: 10


Using cache found in /Users/matthewrice/.cache/torch/hub/harritaylor_torchvggish_master
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmattricesound[0m. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name    | Type        | Params
----------------------------------------
0 | model   | DemucsModel | 83.6 M
1 | metrics | ModuleDict  | 0     
----------------------------------------
83.6 M    Trainable params
0         Non-trainable params
83.6 M    Total params
334.521   Total estimated model params size (MB)


Epoch 4:   0%|          | 0/6 [00:00<?, ?it/s, loss=9.78, v_num=g0t0, Input_SISDR=14.80, Input_STFT=0.659, valid_SISDR=-1.25, valid_STFT=2.980, train_SISDR=-3.25, train_STFT=2.870]         

In [None]:
# Load song and checkpoint
