In [2]:
# imports
import math
import wave
import struct
import os 
# import tarfile
import audiolm_pytorch
from audiolm_pytorch import AudioLMSoundStream, SoundStreamTrainer
from audiolm_pytorch import EncodecWrapper
from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer, HubertWithKmeans, CoarseTransformer, CoarseTransformerWrapper, CoarseTransformerTrainer, FineTransformer, FineTransformerWrapper, FineTransformerTrainer, AudioLM
from torch import nn
import torch
import torchaudio
from torch.profiler import profile, record_function, ProfilerActivity, schedule
import datetime
import argparse
import re

import random
import numpy as np
from torch.utils.data import DataLoader

import cocochorales_custom_dataset

In [3]:
# Configuration: where results/checkpoints live and which pretrained HuBERT files to use.
#
# Script-style usage, kept for reference (the cells shown here do not parse any flags):
#   python audiolm_pytorch_demo_laion.py --semantic=/path/to/semantic --coarse=/path/to/coarse --fine=/path/to/fine
# Checkpoint flags are optional. Give full paths; behaviour with relative paths is not guaranteed.
prefix = "/media/checkpoint/audiolm-pytorch-results"

# Both HuBERT files are addressed relative to `prefix` when they are loaded.
hubert_ckpt = "hubert/hubert_base_ls960.pt"
# Quantizer listed in row "HuBERT Base (~95M params)", column Quantizer.
hubert_quantizer = "hubert/hubert_base_ls960_L9_km500.bin"

# Record the installed library version in the cell output.
print("training on audiolm_pytorch version {}".format(audiolm_pytorch.version.__version__))


training on audiolm_pytorch version 1.2.19


In [14]:
# Build the two pretrained components that are later handed to AudioLM:
#   - `codec`: an EncodecWrapper created with its default arguments.
#   - `wav2vec`: a HubertWithKmeans model whose checkpoint and k-means
#     quantizer files are both located under `prefix`.
# `wav2vec.codebook_size` is also read when the semantic and coarse
# transformers are constructed, so this cell must run before those are built.

codec = EncodecWrapper()
wav2vec = HubertWithKmeans(
    # use_mert = True,
    checkpoint_path = f"{prefix}/{hubert_ckpt}",
    # checkpoint_path = None,
    kmeans_path = f"{prefix}/{hubert_quantizer}"
)

/media/checkpoint/audiolm-pytorch-results/hubert/hubert_base_ls960.pt


In [5]:
def get_semantic_transformer():
    """Build a fresh SemanticTransformer and move it to the GPU.

    The number of semantic tokens is read from the module-level ``wav2vec``
    (its ``codebook_size``); the model size is fixed at ``dim=1024`` and
    ``depth=6``. No checkpoint is loaded here.

    Returns:
        The newly constructed SemanticTransformer after ``.cuda()``.
    """
    model_kwargs = dict(
        num_semantic_tokens = wav2vec.codebook_size,
        dim = 1024,
        depth = 6,
    )
    return SemanticTransformer(**model_kwargs).cuda()


In [6]:
def get_coarse_transformer():
    """Build a fresh CoarseTransformer and move it to the GPU.

    The number of semantic tokens is read from the module-level ``wav2vec``
    (its ``codebook_size``). The remaining hyperparameters are fixed:
    a codebook size of 1024, 3 coarse quantizers, ``dim=512`` and ``depth=6``.
    No checkpoint is loaded here.

    Returns:
        The newly constructed CoarseTransformer after ``.cuda()``.
    """
    model_kwargs = dict(
        num_semantic_tokens = wav2vec.codebook_size,
        codebook_size = 1024,
        num_coarse_quantizers = 3,
        dim = 512,
        depth = 6,
    )
    return CoarseTransformer(**model_kwargs).cuda()


In [7]:

def get_fine_transformer():
    """Build a fresh FineTransformer and move it to the GPU.

    All hyperparameters are fixed: 3 coarse quantizers, 5 fine quantizers,
    a codebook size of 1024, ``dim=512`` and ``depth=6``.
    No checkpoint is loaded here.

    Returns:
        The newly constructed FineTransformer after ``.cuda()``.
    """
    model_kwargs = dict(
        num_coarse_quantizers = 3,
        num_fine_quantizers = 5,
        codebook_size = 1024,
        dim = 512,
        depth = 6,
    )
    return FineTransformer(**model_kwargs).cuda()

In [8]:

def get_results_folder_path(transformer_name, prefix, results_folder_slurm_job_id):
    """Return the results-folder path for one transformer stage.

    Args:
        transformer_name: One of "semantic", "coarse" or "fine".
        prefix: Directory that contains the per-stage results folders.
        results_folder_slurm_job_id: Identifier appended to the folder name.

    Returns:
        The string ``<prefix>/<transformer_name>_results_<id>``.

    Raises:
        AssertionError: If ``transformer_name`` is not a known stage.
    """
    known_stages = {"semantic", "coarse", "fine"}
    assert transformer_name in known_stages
    return "{}/{}_results_{}".format(prefix, transformer_name, results_folder_slurm_job_id)

def get_potential_checkpoint_num_steps(results_folder):
    """Return the largest step number among the ``.pt`` files in a results folder.

    A file's step number is the last run of digits in its name, e.g.
    ``semantic.transformer.86400.pt`` -> 86400. ``.pt`` files whose names
    contain no digits (e.g. ``best.pt``) are ignored instead of raising.

    Args:
        results_folder: Path of the directory to scan.

    Returns:
        None if ``results_folder`` is not an existing directory; otherwise the
        maximum step number found, or 0 when no ``.pt`` file carries a number.
    """
    if not os.path.isdir(results_folder):
        return None

    steps = []
    for filename in os.listdir(results_folder):
        if not filename.endswith('.pt'):
            continue
        digit_runs = re.findall(r'\d+', filename)
        # A checkpoint-like file without any number has no step to report; skip it.
        if digit_runs:
            steps.append(int(digit_runs[-1]))
    return max(steps, default=0)

def get_potential_checkpoint_path(transformer_name, prefix, results_folder_slurm_job_id):
    """Return the path of the latest checkpoint for one transformer stage.

    The results folder is derived from ``prefix``, ``transformer_name`` and
    ``results_folder_slurm_job_id``; the latest checkpoint is the one with the
    highest step number found in that folder.

    Args:
        transformer_name: One of "semantic", "coarse" or "fine".
        prefix: Directory that contains the per-stage results folders.
        results_folder_slurm_job_id: Identifier of the results folder to search.

    Returns:
        ``<results_folder>/<transformer_name>.transformer.<max_step>.pt``, or
        None if the results folder does not exist or holds no numbered checkpoint.
    """
    results_folder = get_results_folder_path(transformer_name, prefix, results_folder_slurm_job_id)
    max_step = get_potential_checkpoint_num_steps(results_folder)
    # max_step is None when the folder is missing and 0 when it has no checkpoints;
    # comparing None with `> 0` would raise TypeError, so test truthiness instead.
    if not max_step:
        return None
    return f"{results_folder}/{transformer_name}.transformer.{max_step}.pt"


In [11]:
# One freshly initialised transformer per stage; weights are loaded in a later cell.
semantic_transformer = get_semantic_transformer()
coarse_transformer = get_coarse_transformer()
fine_transformer = get_fine_transformer()

# Identifier suffix of the results folder to read for each stage.
semantic_results_folder_suffix = "1"
coarse_results_folder_suffix = "1"
fine_results_folder_suffix = "1"

# sampling using one gpu only, so just load info about the transformers instead of using the trainer's load method
semantic_ckpt = get_potential_checkpoint_path(
    transformer_name="semantic",
    prefix=prefix,
    results_folder_slurm_job_id=semantic_results_folder_suffix,
)
coarse_ckpt = get_potential_checkpoint_path(
    transformer_name="coarse",
    prefix=prefix,
    results_folder_slurm_job_id=coarse_results_folder_suffix,
)
fine_ckpt = get_potential_checkpoint_path(
    transformer_name="fine",
    prefix=prefix,
    results_folder_slurm_job_id=fine_results_folder_suffix,
)

# Sampling needs all three stages, so stop here if any checkpoint is missing.
assert all(
    found is not None for found in (semantic_ckpt, coarse_ckpt, fine_ckpt)
), "all three checkpoints should exist"


In [12]:
# Show which checkpoint file was selected for each stage, one path per line.
print(semantic_ckpt, coarse_ckpt, fine_ckpt, sep="\n")

/media/checkpoint/audiolm-pytorch-results/semantic_results_1/semantic.transformer.86400.pt
/media/checkpoint/audiolm-pytorch-results/coarse_results_1/coarse.transformer.86400.pt
/media/checkpoint/audiolm-pytorch-results/fine_results_1/fine.transformer.25200.pt


In [15]:

# Load the selected checkpoint into each transformer.
semantic_transformer.load(semantic_ckpt)
coarse_transformer.load(coarse_ckpt)
fine_transformer.load(fine_ckpt)

# The three stages are combined into one AudioLM afterwards, so they must share a device.
assert semantic_transformer.device == coarse_transformer.device == fine_transformer.device, (
    "all three transformers should be on the same device. "
    f"instead got semantic on {semantic_transformer.device}, "
    f"coarse on {coarse_transformer.device}, "
    f"and fine on {fine_transformer.device}"
)


In [16]:
print("loaded checkpoints. sampling now...")
# Assemble the full pipeline: the HuBERT/k-means model, the codec and the three
# loaded transformers are combined into a single AudioLM object.
# This cell only builds the object; sampling and saving happen in the cells below.
audiolm = AudioLM(
    wav2vec = wav2vec,
    codec = codec,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
)

loaded checkpoints. sampling now...


generating semantic: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2048/2048 [00:33<00:00, 61.87it/s]
generating coarse: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 512/512 [00:12<00:00, 41.58it/s]
generating fine: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 512/512 [02:41<00:00,  3.18it/s]


sampled fine token ids shape: torch.Size([1, 2560])
num eos ids: 0
indices of eos id: []
AFTER MASKING OUT AFTER EOS
num eos ids: 0
indices of eos id: []
mask out generated fine tokens is False
coarse_and_fine_ids.shape torch.Size([1, 512, 8]) to be decoded from codebook indices by codec
wav shape after codec.decode_from_codebook_indices: torch.Size([1, 1, 163840])


NameError: name 'args' is not defined

In [None]:


# Unconditional sample: request a single waveform (batch_size = 1) from the assembled AudioLM.
# No random seed is set in the visible cells, so the result differs from run to run.
generated_wav = audiolm(batch_size = 1)

In [None]:
# Sample with a priming waveform. The prime used here is Gaussian noise of
# shape (1, 320 * 8) rather than real audio, so this only exercises the priming path.
generated_wav_with_prime = audiolm(prime_wave = torch.randn(1, 320 * 8))

In [18]:
# Sample while passing a text prompt; this is the waveform written to disk at the end of the notebook.
# NOTE(review): the transformers in this notebook are constructed without any
# text-conditioning arguments, so the prompt may have no effect on the output —
# confirm against the audiolm_pytorch AudioLM documentation.
generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the distant echos of bells'])

generating semantic: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2048/2048 [00:33<00:00, 61.22it/s]
generating coarse: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 512/512 [00:12<00:00, 41.25it/s]
generating fine: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 512/512 [02:40<00:00,  3.18it/s]


sampled fine token ids shape: torch.Size([1, 2560])
num eos ids: 0
indices of eos id: []
AFTER MASKING OUT AFTER EOS
num eos ids: 0
indices of eos id: []
mask out generated fine tokens is False
coarse_and_fine_ids.shape torch.Size([1, 512, 8]) to be decoded from codebook indices by codec
wav shape after codec.decode_from_codebook_indices: torch.Size([1, 1, 163840])


In [19]:

# Resolve each stage's results folder ...
semantic_results_folder = get_results_folder_path("semantic", prefix, semantic_results_folder_suffix)
coarse_results_folder = get_results_folder_path("coarse", prefix, coarse_results_folder_suffix)
fine_results_folder = get_results_folder_path("fine", prefix, fine_results_folder_suffix)

# ... and look up the highest checkpoint step currently present in each one.
semantic_num_steps = get_potential_checkpoint_num_steps(semantic_results_folder)
coarse_num_steps = get_potential_checkpoint_num_steps(coarse_results_folder)
fine_num_steps = get_potential_checkpoint_num_steps(fine_results_folder)

# The output file name records, per stage, the results-folder id and the step count.
output_path = (
    f"{prefix}/out"
    f"_semantic_id_{semantic_results_folder_suffix}_steps_{semantic_num_steps}"
    f"_coarse_id_{coarse_results_folder_suffix}_steps_{coarse_num_steps}"
    f"_fine_id_{fine_results_folder_suffix}_steps_{fine_num_steps}"
    ".wav"
)

# NOTE(review): 24000 is presumably the codec's output sample rate — confirm against EncodecWrapper.
sample_rate = 24000
# Move the sample to the CPU before writing it out as a .wav file.
torchaudio.save(output_path, generated_wav_with_text_condition.cpu(), sample_rate)
print("sampled. exiting.")

sampled. exiting.
