In [1]:
import os
import torchaudio
# Get audio file path
curr_dir = os.getcwd()
file_path = os.path.join(curr_dir, 'sample2.wav')
audio_samples = os.path.join(curr_dir, 'audio_samples')

In [2]:
# Load the waveform
waveform, sample_rate = torchaudio.load(file_path)
print(f'waveform.shape \t{waveform.shape}')
print(f'sample_rate \t{sample_rate}')
print(f'length: \t{waveform.shape[1]/sample_rate} sec')

waveform.shape 	torch.Size([2, 1323000])
sample_rate 	44100
length: 	30.0 sec


In [3]:
from audiolm_pytorch import EncodecWrapper
encodec = EncodecWrapper()

encodec.eval()

soundstream = encodec

INFO:fairseq.tasks.text_to_speech:Please install tensorboardX: pip install tensorboardX


In [8]:
from encodec import EncodecModel
from encodec.utils import convert_audio

import torchaudio
import torch

# Instantiate a pretrained EnCodec model
model = EncodecModel.encodec_model_48khz()
# The number of codebooks used will be determined bythe bandwidth selected.
# E.g. for a bandwidth of 6kbps, `n_q = 8` codebooks are used.
# Supported bandwidths are 1.5kbps (n_q = 2), 3 kbps (n_q = 4), 6 kbps (n_q = 8) and 12 kbps (n_q =16) and 24kbps (n_q=32).
# For the 48 kHz model, only 3, 6, 12, and 24 kbps are supported. The number
# of codebooks for each is half that of the 24 kHz model as the frame rate is twice as much.
model.set_target_bandwidth(12)

# Load and pre-process the audio waveform
wav, sr = torchaudio.load(file_path)
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
wav = wav.unsqueeze(0)


# Extract discrete codes from EnCodec
with torch.no_grad():
    encoded_frames = model.encode(wav)
codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)  # [B, n_q, T]
print(codes)
print(codes.shape)

tensor([[[ 752,  385,  524,  ...,  409,  312,  214],
         [ 852,  789,  888,  ...,  698,  471,  826],
         [  37,  965,  683,  ..., 1009,  977,  971],
         ...,
         [ 521,  115,  130,  ...,  945,  762,    0],
         [  11,  179,  564,  ...,  681, 1014, 1014],
         [ 666,  581,  957,  ...,  997,  694,  686]]])
torch.Size([1, 8, 4545])


In [9]:
embeddings, codes, _ = encodec(waveform, return_encoded=True)
print(embeddings)
print(codes)

tensor([[[-1.2382,  9.5159, -6.6285,  ...,  0.8754, -3.2991,  1.7266],
         [-0.2811, 10.4031, -4.3174,  ..., -0.4926, -1.1932,  3.5006],
         [-0.1946, 11.2920, -4.1249,  ..., -0.1044, -1.0520,  4.0774],
         ...,
         [ 0.6942,  9.6216, -2.8055,  ...,  0.4207, -1.3194,  2.5025],
         [ 0.1435,  9.0620, -3.1732,  ..., -0.2560, -1.8630,  1.7947],
         [-0.2630,  8.9826, -3.0226,  ..., -0.2029, -1.1596,  2.3247]],

        [[-1.2382,  9.5159, -6.6285,  ...,  0.8754, -3.2991,  1.7266],
         [-0.2811, 10.4031, -4.3174,  ..., -0.4926, -1.1932,  3.5006],
         [-0.1946, 11.2920, -4.1249,  ..., -0.1044, -1.0520,  4.0774],
         ...,
         [ 0.6942,  9.6216, -2.8055,  ...,  0.4207, -1.3194,  2.5025],
         [ 0.1435,  9.0620, -3.1732,  ..., -0.2560, -1.8630,  1.7947],
         [-0.2630,  8.9826, -3.0226,  ..., -0.2029, -1.1596,  2.3247]]])
tensor([[[ 670,  901,  301,  ...,  724,  851,  896],
         [ 957,  669,  483,  ...,   19,  782,  472],
         [

In [10]:
print(embeddings.shape) #batch_size, num_frames, emedding_dim
print(codes.shape) # batch_size, num_frames, num_quantizers

torch.Size([2, 4135, 128])
torch.Size([2, 4135, 8])


In [11]:
# Create semantic tokens
from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
    kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6,
    flash_attn = True
).cpu()

trainer = SemanticTransformerTrainer(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    folder = audio_samples,
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1
)

trainer.train() # saved into results folder


AssertionError: only one Trainer can be instantiated at a time for training

In [8]:
from audiolm_pytorch import HubertWithKmeans

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
    kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

In [None]:
import torch
from audiolm_pytorch import HubertWithKmeans, SoundStream, CoarseTransformer, CoarseTransformerTrainer



coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = 1024,
    num_coarse_quantizers = 3,
    dim = 512,
    depth = 6,
    flash_attn = True
)

trainer = CoarseTransformerTrainer(
    transformer = coarse_transformer,
    codec = soundstream,
    wav2vec = wav2vec,
    folder = audio_samples,
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1_000_000
)

trainer.train()

training with dataset of 156 samples and validating with randomly splitted 9 samples
0: loss: 69.8819808959961
0: valid loss 58.27711868286133
0: saving model to results
1: loss: 74.16581726074219
2: loss: 65.0303726196289
3: loss: 59.086605072021484
4: loss: 58.262535095214844
5: loss: 56.15260696411133
6: loss: 27.63350486755371
7: loss: 40.65735626220703
8: loss: 57.455135345458984
9: loss: 36.98195266723633
10: loss: 53.835548400878906
11: loss: 50.02263641357422
12: loss: 32.02082443237305
13: loss: 42.12492370605469
14: loss: 45.14252471923828
15: loss: 45.72853088378906
16: loss: 44.671546936035156
17: loss: 26.326539993286133
18: loss: 42.56187057495117
19: loss: 31.45322036743164
20: loss: 37.24658966064453
21: loss: 38.65011215209961
22: loss: 30.86720085144043
23: loss: 30.712915420532227
24: loss: 33.37001037597656
25: loss: 26.17238998413086
26: loss: 28.163972854614258
27: loss: 17.022890090942383
28: loss: 11.208353042602539
29: loss: 28.43684959411621
30: loss: 30.52349

KeyboardInterrupt: 

In [None]:
import torch
from audiolm_pytorch import FineTransformer, FineTransformerTrainer

fine_transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,
    codebook_size = 1024,
    dim = 512,
    depth = 6,
    flash_attn = True
)

trainer = FineTransformerTrainer(
    transformer = fine_transformer,
    codec = soundstream,
    folder = audio_samples,
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 6_000
)

trainer.train()

training with dataset of 156 samples and validating with randomly splitted 9 samples
0: loss: 75.78882598876953
0: valid loss 73.59510803222656
0: saving model to results
1: loss: 59.42388153076172
2: loss: 72.45049285888672
3: loss: 52.032752990722656
4: loss: 31.2309513092041
5: loss: 71.10739135742188
6: loss: 40.21917724609375
7: loss: 37.598655700683594
8: loss: 50.03717803955078
9: loss: 61.48522186279297
10: loss: 43.96183776855469
11: loss: 39.361324310302734
12: loss: 36.05011749267578
13: loss: 51.9398193359375
14: loss: 40.307037353515625
15: loss: 41.769371032714844
16: loss: 44.776649475097656
17: loss: 44.575462341308594
18: loss: 43.00227355957031
19: loss: 43.78857421875
20: loss: 39.390403747558594
21: loss: 38.168724060058594
22: loss: 36.61262512207031
23: loss: 38.80466079711914
24: loss: 39.066864013671875
25: loss: 14.242904663085938
26: loss: 38.4583740234375
27: loss: 21.84329605102539
28: loss: 36.355995178222656
29: loss: 9.721351623535156
30: loss: 35.2554550

In [None]:
from audiolm_pytorch import SemanticTransformer, FineTransformer, CoarseTransformer

semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6,
    flash_attn = True
).cpu()

semantic_transformer.load_state_dict(torch.load('./results/semantic.transformer.0.pt'))

coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = 1024,
    num_coarse_quantizers = 3,
    dim = 512,
    depth = 6,
    flash_attn = True
)

coarse_transformer.load('./results/coarse.transformer.6000.pt')

fine_transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,
    codebook_size = 1024,
    dim = 512,
    depth = 6,
    flash_attn = True,
)

fine_transformer.load('./results/fine.transformer.6000.pt')

print(fine_transformer)

RuntimeError: Error(s) in loading state_dict for SemanticTransformer:
	Missing key(s) in state_dict: "start_token", "semantic_embedding.weight", "proj_text_embed.weight", "transformer.layers.0.0.norm.gamma", "transformer.layers.0.0.norm.beta", "transformer.layers.0.0.to_q.weight", "transformer.layers.0.0.to_kv.weight", "transformer.layers.0.0.to_out.0.weight", "transformer.layers.0.2.0.gamma", "transformer.layers.0.2.0.beta", "transformer.layers.0.2.1.weight", "transformer.layers.0.2.3.gamma", "transformer.layers.0.2.3.beta", "transformer.layers.0.2.5.weight", "transformer.layers.1.0.norm.gamma", "transformer.layers.1.0.norm.beta", "transformer.layers.1.0.to_q.weight", "transformer.layers.1.0.to_kv.weight", "transformer.layers.1.0.to_out.0.weight", "transformer.layers.1.2.0.gamma", "transformer.layers.1.2.0.beta", "transformer.layers.1.2.1.weight", "transformer.layers.1.2.3.gamma", "transformer.layers.1.2.3.beta", "transformer.layers.1.2.5.weight", "transformer.layers.2.0.norm.gamma", "transformer.layers.2.0.norm.beta", "transformer.layers.2.0.to_q.weight", "transformer.layers.2.0.to_kv.weight", "transformer.layers.2.0.to_out.0.weight", "transformer.layers.2.2.0.gamma", "transformer.layers.2.2.0.beta", "transformer.layers.2.2.1.weight", "transformer.layers.2.2.3.gamma", "transformer.layers.2.2.3.beta", "transformer.layers.2.2.5.weight", "transformer.layers.3.0.norm.gamma", "transformer.layers.3.0.norm.beta", "transformer.layers.3.0.to_q.weight", "transformer.layers.3.0.to_kv.weight", "transformer.layers.3.0.to_out.0.weight", "transformer.layers.3.2.0.gamma", "transformer.layers.3.2.0.beta", "transformer.layers.3.2.1.weight", "transformer.layers.3.2.3.gamma", "transformer.layers.3.2.3.beta", "transformer.layers.3.2.5.weight", "transformer.layers.4.0.norm.gamma", "transformer.layers.4.0.norm.beta", "transformer.layers.4.0.to_q.weight", "transformer.layers.4.0.to_kv.weight", "transformer.layers.4.0.to_out.0.weight", "transformer.layers.4.2.0.gamma", "transformer.layers.4.2.0.beta", "transformer.layers.4.2.1.weight", "transformer.layers.4.2.3.gamma", "transformer.layers.4.2.3.beta", "transformer.layers.4.2.5.weight", "transformer.layers.5.0.norm.gamma", "transformer.layers.5.0.norm.beta", "transformer.layers.5.0.to_q.weight", "transformer.layers.5.0.to_kv.weight", "transformer.layers.5.0.to_out.0.weight", "transformer.layers.5.2.0.gamma", "transformer.layers.5.2.0.beta", "transformer.layers.5.2.1.weight", "transformer.layers.5.2.3.gamma", "transformer.layers.5.2.3.beta", "transformer.layers.5.2.5.weight", "transformer.norm.gamma", "transformer.norm.beta", "to_logits.weight", "to_logits.bias". 
	Unexpected key(s) in state_dict: "model", "optim", "version". 

In [None]:
from audiolm_pytorch import AudioLM

audiolm = AudioLM(
    wav2vec = wav2vec,
    codec = soundstream,
    semantic_transformer = torch.load('./results/semantic.transformer.0.pt'),
    coarse_transformer = torch.load('./results/coarse.transformer.0.pt'),
    fine_transformer = fine_transformer
)

generated_wav = audiolm(batch_size = 1)

# # or with priming

# generated_wav_with_prime = audiolm(prime_wave = torch.randn(1, 320 * 8))

# # or with text condition, if given

# generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the distant echos of bells'])


In [19]:
from torch.utils.data import DataLoader, Dataset

class AudioDataset(Dataset):
    def __init__(self, file_list):
        self.file_list = file_list

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        waveform, sample_rate = torchaudio.load(self.file_list[idx])
        return waveform

# Get audio sound files

file_list_path = [os.path.join(curr_dir, 'audio_samples', f) for f in os.listdir(os.path.join(curr_dir, 'audio_samples')) if f.endswith('.wav')]

dataset = AudioDataset(file_list_path)
data_loader = DataLoader(dataset, batch_size=4, shuffle=True)

In [None]:
import torch
from soundstorm_pytorch import SoundStorm, ConformerWrapper

conformer = ConformerWrapper(
    codebook_size = 1024,
    num_quantizers = 8,
    conformer = dict(
        dim = 512,
        depth = 2
    ),
)

model = SoundStorm(
    conformer,
    steps = 18,          # 18 steps, as in original maskgit paper
    schedule = 'cosine'  # currently the best schedule is cosine
)

# do the below in a loop for a ton of data

loss, _ = model(codes)
loss.backward()

# The SoundStorm model is now trained and can be used to generate audio
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for epoch in range(10):  # Example: 10 epochs\n",
    for batch in data_loader:  # Assuming data_loader is defined\n",

        loss, _ = model(codes)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print(f'Epoch {epoch + 1} completed')

# model can now generate in 18 steps. ~2 seconds sounds reasonable

generated = model.generate(1024, batch_size = 2) # (2, 1024)
print(generated.shape)

IndexError: too many indices for tensor of dimension 3

In [13]:
import torchaudio
from IPython.display import Audio

# Assuming the waveform is a tensor of shape [2, 8192]
waveform = generated_audio # Example waveform
sample_rate = 16000

# Play the waveform directly
Audio(waveform.numpy(), rate=sample_rate)
