<a href="https://colab.research.google.com/github/buganart/unagan/blob/master/unagan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# @title Setup
# @markdown 1. Before starting please save the notebook in your drive by clicking on `File -> Save a copy in drive`
# @markdown 2. Check GPU, should be a Tesla V100 if you want to train it as fast as possible.
# @markdown 3. Mount google drive.
# @markdown 4. Log in to wandb.


!nvidia-smi -L
import os

print(f"We have {os.cpu_count()} CPU cores.")
print()

try:
    from google.colab import drive, output

    IN_COLAB = True
except ImportError:
    from IPython.display import clear_output

    IN_COLAB = False

from pathlib import Path

if IN_COLAB:
    drive.mount("/content/drive/")

    if not Path("/content/drive/My Drive/IRCMS_GAN_collaborative_database").exists():
        raise RuntimeError(
            "Shortcut to our shared drive folder doesn't exits.\n\n"
            "\t1. Go to the google drive web UI\n"
            '\t2. Right click shared folder IRCMS_GAN_collaborative_database and click "Add shortcut to Drive"'
        )

clear = output.clear if IN_COLAB else clear_output


def clear_on_success(msg="Ok!"):
    """Replace the cell output with `msg` if the last shell command succeeded.

    Reads IPython's `_exit_code`, the exit status of the most recent shell
    command run from the notebook (e.g. a `!cmd` line). On a non-zero status
    the function does nothing, so the command's error output stays visible.

    Args:
        msg: Text printed after the output has been cleared.
    """
    if _exit_code != 0:
        return
    clear()
    print(msg)


print()
print("Wandb installation and login ...")
%pip install -q wandb

wandb_drive_netrc_path = Path("drive/My Drive/colab/.netrc")
wandb_local_netrc_path = Path("/root/.netrc")
if wandb_drive_netrc_path.exists():
    import shutil

    print("Wandb .netrc file found, will use that to log in.")
    shutil.copy(wandb_drive_netrc_path, wandb_local_netrc_path)
else:
    print(
        f"Wandb config not found at {wandb_drive_netrc_path}.\n"
        f"Using manual login.\n\n"
        f"To use auto login in the future, finish the manual login first and then run:\n\n"
        f"\t!mkdir -p '{wandb_drive_netrc_path.parent}'\n"
        f"\t!cp {wandb_local_netrc_path} '{wandb_drive_netrc_path}'\n\n"
        f"Then that file will be used to login next time.\n"
    )

!wandb login

GPU 0: Tesla V100-SXM2-16GB (UUID: GPU-2106e880-2d86-b3d7-13cc-a33d2c4b13f6)
We have 2 CPU cores.



KeyboardInterrupt: ignored

# **Description and training**

This notebook trains UnaGAN and logs the results to the wandb project "demiurge/unagan". The [buganart/unagan](https://github.com/buganart/unagan) code is a modification of the [ciaua/unagan repository](https://github.com/ciaua/unagan). To start training, set **audio_db** to the path of a folder of sound files (.wav) in the mounted Google Drive. All files at the top level of that folder (subfolders are ignored) are used both for training and for evaluating the training process. If a run stops and you want to resume it, enter its wandb run ID in the **resume_run_id** field. For a description of all training arguments, see the [ciaua/unagan repository](https://github.com/ciaua/unagan).



In [None]:
#@title CONFIGURATION

# Fill in the configuration, then select `Runtime` -> `Run all` and let it ride!

#@markdown ###Training
# NOTE: this rebinds `drive`, which the setup cell bound to the google.colab
# drive module. Only the Path is used in the cells below.
drive = Path('/content/drive/MyDrive')
print(f"Google drive at {drive}")

drive_audio_db_root = drive
collaborative_database = drive / "IRCMS_GAN_collaborative_database"
violingan_experiment_dir = collaborative_database / "Experiments" / "colab-violingan"
# Passed to the training script as --wandb-dir in the TRAIN cell.
experiment_dir = violingan_experiment_dir / "unagan"

#@markdown The path of the Audio Database (folder containing .wav files) you'd like to work with
audio_db = "/content/drive/MyDrive/AUDIO DATABASE/TESTING/" #@param {type:"string"}
audio_db_dir = Path(audio_db)
if not audio_db_dir.exists():
    raise RuntimeError(f"The audio_db_dir {audio_db_dir} does not exist.")
# The copy cell runs `find` on this path and copies the files inside it,
# so it has to be a folder; fail early instead of copying nothing.
if not audio_db_dir.is_dir():
    raise RuntimeError(f"The audio_db_dir {audio_db_dir} is not a directory.")

#@markdown Use a wandb run ID to resume a previous run, or leave empty to start from scratch
resume_run_id = "" #@param {type: "string"}

#@markdown ###Training arguments
# All values below are passed to scripts/train.hierarchical_with_cycle.py in
# the TRAIN cell; feat_dim is also the mel channel count for extract_mel.py.
feat_dim = 80 #@param {type: "integer"}
z_dim = 20 #@param {type: "integer"}
# z_scale_factors = 2 #@param {type: "integer"}
num_va = 200 #@param {type: "integer"}

gamma = 1.0 #@param {type: "number"}
lambda_k = 0.01 #@param {type: "number"}
init_k = 0.0 #@param {type: "number"}

init_lr = 0.001 #@param {type: "number"}
num_epochs = 200 #@param {type: "integer"}

lambda_cycle = 1 #@param {type: "integer"}
max_grad_norm = 3 #@param {type: "integer"}
save_rate = 20 #@param {type: "integer"}
batch_size = 10 #@param {type: "integer"}

def check_wandb_id(run_id):
    """Validate the wandb run ID entered in the configuration form.

    An empty (falsy) value is accepted and means "start a new run".
    Otherwise the ID must be exactly 8 characters, each an ASCII lowercase
    letter a-z or an ASCII digit 0-9.

    Args:
        run_id: The run ID string from the form, or an empty string.

    Raises:
        RuntimeError: If ``run_id`` is non-empty and not a valid 8-char ID.
    """
    import re

    # fullmatch, unlike match with "^...$", rejects a trailing "\n" ("$" also
    # matches just before a final newline). An explicit [0-9] class is used
    # because "\d" also accepts non-ASCII Unicode digits.
    if run_id and not re.fullmatch(r"[0-9a-z]{8}", run_id):
        raise RuntimeError(
            "Run ID needs to be 8 characters long and contain only letters a-z and digits.\n"
            f"Got \"{run_id}\""
        )

check_wandb_id(resume_run_id)

# z_scale_factors = [z_scale_factor, z_scale_factor, z_scale_factor, z_scale_factor]

# Gather the form values in one mapping and echo them, one per line, so the
# cell output records exactly which settings this run used.
config = {
    "audio_db_dir": audio_db_dir,
    "resume_run_id": resume_run_id,
    "feat_dim": feat_dim,
    "z_dim": z_dim,
    "num_va": num_va,
    "gamma": gamma,
    "lambda_k": lambda_k,
    "init_k": init_k,
    "init_lr": init_lr,
    "num_epochs": num_epochs,
    "lambda_cycle": lambda_cycle,
    "max_grad_norm": max_grad_norm,
    "save_rate": save_rate,
    "batch_size": batch_size,
}
for setting, setting_value in config.items():
    print(f"=> {setting:30}: {setting_value}")

In [None]:
#@title CLONE UNAGAN REPO AND INSTALL DEPENDENCIES

# os.environ["WANDB_MODE"] = "dryrun"
if IN_COLAB:
    !git clone https://github.com/buganart/unagan
    %cd "/content/unagan/"
    # !git checkout dev
    %pip install -r requirements.txt

    clear_on_success("Repo cloned! Dependencies installed!")

In [None]:
#@title COPY FILES TO LOCAL RUNTIME
local_wav_dir = Path("data")
local_wav_dir.mkdir(exist_ok=True)
!find "{audio_db_dir}"/ -maxdepth 1 -type f | xargs -t -d "\n" -I'%%' -P 10 -n 1 rsync -a '%%' "$local_wav_dir"/
clear_on_success("All files copied to this runtime.")

audio_paths = sorted(list(local_wav_dir.glob("*")))
num_files = len(audio_paths)
print(f"Database has {num_files} files in total.")


In [None]:
#@title COLLECT AUDIO CLIPS
!python scripts/collect_audio_clips.py --audio-dir "$local_wav_dir" --extension wav
clear_on_success(f"Done.")

In [None]:
#@title EXTRACT MEL SPECTROGRAMS
!python scripts/extract_mel.py --n_mel_channels "$feat_dim"
clear_on_success("Done!")

In [None]:
#@title GENERATE DATASET
!python scripts/make_dataset.py
clear_on_success("Done!")

In [None]:
#@title COMPUTE MEAN AND STANDARD DEVIATION
!python scripts/compute_mean_std.mel.py
clear_on_success("Done!")

In [None]:
#@title TRAIN

!env PYTHONPATH="." python scripts/train.hierarchical_with_cycle.py \
    --model-id "$resume_run_id" \
    --audio_db_dir "$audio_db_dir" \
    --wandb-dir "$experiment_dir" \
    --feat_dim "$feat_dim" \
    --z_dim "$z_dim" \
    --num_va "$num_va" \
    --gamma "$gamma" \
    --lambda_k "$lambda_k" \
    --init_k "$init_k" \
    --init_lr "$init_lr" \
    --num_epochs "$num_epochs" \
    --lambda_cycle "$lambda_cycle" \
    --max_grad_norm "$max_grad_norm" \
    --save_rate "$save_rate" \
    --batch_size "$batch_size"