<a href="https://colab.research.google.com/github/buganart/unagan/blob/master/Unagan_generate.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title Setup
# @markdown 1. Before starting please save the notebook in your drive by clicking on `File -> Save a copy in drive`
# @markdown 2. Check GPU, should be a Tesla V100 if you want to train it as fast as possible.
# @markdown 3. Mount google drive.
# @markdown 4. Log in to wandb.

!nvidia-smi -L
import os
os.environ["WANDB_MODE"] = "dryrun"

print(f"We have {os.cpu_count()} CPU cores.")
print()

try:
    from google.colab import drive, output

    IN_COLAB = True
except ImportError:
    from IPython.display import clear_output

    IN_COLAB = False

from pathlib import Path

if IN_COLAB:
    drive.mount("/content/drive/")

    if not Path("/content/drive/My Drive/IRCMS_GAN_collaborative_database").exists():
        raise RuntimeError(
            "Shortcut to our shared drive folder doesn't exits.\n\n"
            "\t1. Go to the google drive web UI\n"
            '\t2. Right click shared folder IRCMS_GAN_collaborative_database and click "Add shortcut to Drive"'
        )

clear = output.clear if IN_COLAB else clear_output

def clear_on_success(msg="Ok!"):
    """Replace the cell output with ``msg`` if the last shell command succeeded.

    Reads the notebook-global ``_exit_code`` (which IPython sets after a
    ``!command``) and the notebook-global ``clear`` callable. When the exit
    code is non-zero nothing is cleared or printed, so the command's own
    output stays visible.
    """
    if _exit_code != 0:
        return
    clear()
    print(msg)

print()
print("Wandb installation and login ...")
%pip install -q wandb

wandb_drive_netrc_path = Path("drive/My Drive/colab/.netrc")
wandb_local_netrc_path = Path("/root/.netrc")
if wandb_drive_netrc_path.exists():
    import shutil

    print("Wandb .netrc file found, will use that to log in.")
    shutil.copy(wandb_drive_netrc_path, wandb_local_netrc_path)
else:
    print(
        f"Wandb config not found at {wandb_drive_netrc_path}.\n"
        f"Using manual login.\n\n"
        f"To use auto login in the future, finish the manual login first and then run:\n\n"
        f"\t!mkdir -p '{wandb_drive_netrc_path.parent}'\n"
        f"\t!cp {wandb_local_netrc_path} '{wandb_drive_netrc_path}'\n\n"
        f"Then that file will be used to login next time.\n"
    )

!wandb login

In [None]:
#@title Configuration

# Colab form cell: each `#@param` line below is rendered as an input field and
# each `#@markdown` line as the help text shown above it.

#@markdown Wandb run id for `melgan` run.
melgan_run_id = "h6tospcl" #@param {type: "string"}

#@markdown Wandb run id for `unagan` run.
unagan_run_id = "2o3gbv1z" #@param {type: "string"}

#@markdown Duration of the generated samples in seconds.
duration = 10 #@param {type: "integer"}

#@markdown Number of samples to generate.
num_samples = 10 #@param {type: "integer"}

#@markdown Random seed for sample generation.
seed = 123 #@param {type: "integer"}

#@markdown The path of the directory where a directory for the generated files is created.
output_dir = "/content/drive/MyDrive/IRCMS_GAN_collaborative_database/Experiments/colab-violingan/unagan-outputs" #@param {type:"string"}
# Convert the form string to a Path (`Path` is imported in the Setup cell) and
# create the directory, including missing parents; an existing directory is fine.
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

def check_wandb_id(run_id):
    """Validate the format of a wandb run ID.

    A falsy ``run_id`` (empty string or ``None``) is accepted without any
    check. Any other value must consist of exactly 8 characters, each an
    ASCII digit or a lowercase letter a-z.

    Args:
        run_id: The run ID string to validate.

    Raises:
        RuntimeError: If ``run_id`` is non-empty and malformed.
    """
    import re

    # `fullmatch` with an explicit ASCII class: `^...$` with `re.match` would
    # accept a trailing newline, and `\d` would accept non-ASCII digits.
    if run_id and not re.fullmatch(r"[0-9a-z]{8}", run_id):
        raise RuntimeError(
            "Run ID needs to be 8 characters long and contain only letters a-z and digits.\n"
            f"Got \"{run_id}\""
        )

# Reject malformed run IDs before anything is downloaded.
check_wandb_id(unagan_run_id)
check_wandb_id(melgan_run_id)

# Gather the form values in one mapping so they can be echoed back below.
config = {
    "melgan_run_id": melgan_run_id,
    "unagan_run_id": unagan_run_id,
    "duration": duration,
    "num_samples": num_samples,
    "seed": seed,
    "output_dir": output_dir,
}

# Print each setting with its name padded to 20 characters.
for k, v in config.items():
    print(f"=> {k:20}: {v}")

In [None]:
#@title Clone `buganart/unagan` repo.
if IN_COLAB:
    %cd /content
    !git clone https://github.com/buganart/unagan
    clear_on_success("Repo cloned!")

In [None]:
#@title Install dependencies
if IN_COLAB:
    %cd /content/unagan
    %pip install -q -r requirements.txt

clear_on_success("Dependencies installed!")

In [None]:
#@title Download files from wandb
# `download_weights` is a module of the cloned repo; the import relies on the
# working directory being the repo root.
import download_weights

# Pass the two run IDs from the Configuration cell. `model_dir` is a relative
# path, so it resolves against the current working directory.
download_weights.main(
    model_dir=Path('models/custom'),
    unagan_run_id=unagan_run_id,
    melgan_run_id=melgan_run_id,
)

In [None]:
#@title Generate
# `generate` is a module of the cloned repo, like `download_weights`.
import generate

# All values except `gid` come from the Configuration cell.
# NOTE(review): `gid=0` presumably selects the first GPU -- confirm in generate.py.
generate.main(
    melgan_run_id=melgan_run_id,
    unagan_run_id=unagan_run_id,
    num_samples=num_samples,
    duration=duration,
    seed=seed,
    gid=0,
    output_folder=output_dir,
)
print(f"Samples saved to {output_dir}")