In [None]:
import os
from glob import glob
import urllib
from tqdm import tqdm
import zipfile
# Remote artifacts on Hugging Face. CKPT_NAME is the file name from CKPT_URL with
# the percent-encoded "=" (%3D) decoded; the resume cell parses "epoch=N" out of
# checkpoint file names.
CKPT_URL = "https://huggingface.co/therealvul/nuwave2/resolve/main/multi0_40_0/nuwave2_01_28_20_epoch%3D631.ckpt?download=true"
CKPT_NAME = "nuwave2_01_28_20_epoch=631.ckpt"
HPARAMS_URL = "https://huggingface.co/therealvul/nuwave2/resolve/main/multi0_40_0/hparameter.yaml?download=true"
DATASET_ZIP_URL = "https://huggingface.co/datasets/therealvul/StyleTTS2MLPAcousticReconstruction/resolve/main/Multi0Epoch40AcousticRecData.zip?download=true"
# Name of the directory the dataset archive's .wav files are extracted into.
DATASET_TARGET_DIR = "data_dir"

# Local layout. The base directory defaults to the original hard-coded
# /root/nuwave2 but can be overridden with the NUWAVE2_BASEDIR environment
# variable (e.g. when not running as root, or on a different machine).
nuwave2_basedir = os.environ.get("NUWAVE2_BASEDIR", "/root/nuwave2")
dataset_path = os.path.join(nuwave2_basedir, "dataset.zip")
hparams_path = os.path.join(nuwave2_basedir, "hparameter.yaml")
checkpoints_dir = os.path.join(nuwave2_basedir, "checkpoint")

def download_file(url, dest, chunk_size=1 << 20):
    """Stream ``url`` to the local file ``dest`` with a tqdm progress bar.

    The response body is written to ``<dest>.part`` and renamed to ``dest``
    only after the whole body has been read. A failed or interrupted download
    therefore never leaves a truncated file at ``dest``. This matters because
    callers treat "the file exists" (e.g. any ``*.ckpt`` in the checkpoint
    directory) as "already downloaded".

    Args:
        url: URL to fetch; anything ``urllib.request.urlopen`` accepts.
        dest: Local destination path. Its parent directory must already exist.
        chunk_size: Number of bytes read per iteration (default 1 MiB).

    Raises:
        Any exception from opening the URL, reading the response, or writing
        the file propagates unchanged, after the partial ``.part`` file has
        been removed.
    """
    # `import urllib` alone does not guarantee the `request` submodule is loaded.
    import urllib.request

    dest = os.fspath(dest)
    tmp_dest = dest + ".part"
    print(f"Downloading {url} to {dest}")
    try:
        with urllib.request.urlopen(url) as response, open(tmp_dest, 'wb') as out_file:
            # Without a Content-Length header, pass total=None so tqdm shows a
            # plain byte counter instead of a bogus total of 0.
            total_size = int(response.info().get('Content-Length', 0)) or None
            with tqdm(total=total_size, unit='B', unit_scale=True) as pbar:
                while True:
                    data = response.read(chunk_size)
                    if not data:
                        break
                    out_file.write(data)
                    pbar.update(len(data))
    except BaseException:
        # BaseException so that a notebook KeyboardInterrupt also cleans up.
        if os.path.exists(tmp_dest):
            os.remove(tmp_dest)
        raise
    # Publish the finished file in one step.
    os.replace(tmp_dest, dest)

os.makedirs(checkpoints_dir, exist_ok=True)

# Download model and hparams. Each artifact is checked on its own: if the
# checkpoint download succeeded but the hparams download failed, a re-run
# still fetches the missing hparameter.yaml, which the hparams cell below reads.
if not glob(os.path.join(checkpoints_dir, "*.ckpt")):
    download_file(CKPT_URL, os.path.join(checkpoints_dir, CKPT_NAME))
if not os.path.exists(hparams_path):
    download_file(HPARAMS_URL, hparams_path)

# Download the dataset archive and extract its .wav files into
# <nuwave2_basedir>/<DATASET_TARGET_DIR>, the directory the hparams cell sets as
# `data.dir`.
#
# Completion is recorded with a marker file that is written only after
# extraction finishes. An interrupted download or extraction is therefore
# resumed on the next run instead of being silently skipped.
dataset_dir = os.path.join(nuwave2_basedir, DATASET_TARGET_DIR)
dataset_done_marker = os.path.join(nuwave2_basedir, ".dataset_extracted")

if not os.path.exists(dataset_done_marker):
    # is_zipfile() is False for a missing *or* truncated archive, so a partial
    # download left by an interrupted run is fetched again.
    if not zipfile.is_zipfile(dataset_path):
        download_file(DATASET_ZIP_URL, dataset_path)
    os.makedirs(dataset_dir, exist_ok=True)
    with zipfile.ZipFile(dataset_path, 'r') as dataset_zip:
        wav_members = [info for info in dataset_zip.infolist()
                       if info.filename.endswith('.wav')]
        for info in tqdm(wav_members, desc="Extracting files", unit="file"):
            target = os.path.join(dataset_dir, info.filename)
            # Skip members already fully extracted by an earlier run. A size
            # mismatch means the earlier extraction was cut off, so redo it.
            if os.path.isfile(target) and os.path.getsize(target) == info.file_size:
                continue
            dataset_zip.extract(info, dataset_dir)
    open(dataset_done_marker, 'w').close()

In [None]:
# Adjust hparams: rewrite hparameter.yaml in place with settings for this machine.
# NOTE(review): "8 -> 16 GB" presumably means a batch of 8 needs about 16 GB of
# GPU memory. Confirm before raising BATCH_SIZE on smaller cards.
BATCH_SIZE = 24 # 8 -> 16 GB

import yaml
import os
import torch

with open(hparams_path) as hparams_file:
    hparams = yaml.safe_load(hparams_file)

hparams['batch_size'] = BATCH_SIZE
# Point training at the directory the dataset archive was extracted into.
hparams['data']['dir'] = os.path.join(nuwave2_basedir, DATASET_TARGET_DIR)
# os.cpu_count() returns None when the count cannot be determined; that would
# be written to the YAML as `null`, so fall back to a single worker.
hparams['train']['num_workers'] = os.cpu_count() or 1
hparams['gpus'] = torch.cuda.device_count()

# NOTE(review): this cell writes num_workers under `train`, so the config keeps
# training settings there, and trainer.py may read batch_size/gpus from `train`
# as well rather than from the top level. Mirror the two overrides into `train`
# whenever that section already defines them, so they take effect either way.
for key in ('batch_size', 'gpus'):
    if key in hparams['train']:
        hparams['train'][key] = hparams[key]

with open(hparams_path, 'w') as hparams_file:
    yaml.dump(hparams, hparams_file, default_flow_style=False)
    print("hparams written")

In [None]:
last_ckpt = sorted(glob(os.path.join(checkpoints_dir, "*.ckpt")))[-1]
assert(os.path.exists(last_ckpt))
import re
pattern=r'epoch=(\d+)'
match = re.search(pattern, last_ckpt)
if match:
    epoch_num = int(match.group(1))
    print(f"Resume epoch {epoch_num}")
    !python trainer.py -r {epoch_num}
else:
    print("Starting fresh train")
    !python trainer.py