In [None]:
from huggingface_hub import notebook_login

# Interactive Hugging Face Hub login (prompts for a token inside the notebook).
# NOTE(review): none of the visible cells push to or pull from the Hub, so this
# step may be optional for a purely local run — confirm before relying on it.
notebook_login()

In [None]:
from encodec import EncodecModel
from encodec.utils import convert_audio

import torch

# Target EnCodec bandwidth. The value is passed to set_target_bandwidth() and is
# also embedded in the checkpoint / output directory names used further down.
bd = 3.0 #bandwidth

# Number of code values stored per frame: codes_to_wav() and DanceDataset both
# (de)interleave flat token sequences in groups of `cd`.
# round() is used instead of int(): int() truncates, so it only yields the
# intended integer when the floating-point product happens to land exactly on it.
cd = round(bd * 2 / 3) #codebook dim

# 48 kHz EnCodec model used both to encode prompts and to decode generated codes.
enc_model = EncodecModel.encodec_model_48khz()
enc_model.set_target_bandwidth(bd)

# Prefer the first CUDA device when available, otherwise run on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
enc_model.to(device)

In [None]:
import os
import torch
import wandb
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Model, GPT2Config, Trainer, TrainingArguments
from transformers import GPT2LMHeadModel, GPT2Config
from tqdm import tqdm

# Maximum sequence length in tokens; used as GPT-2 `n_positions`, as the
# generation `max_length`, and (divided by `cd`) as the training window size.
# NOTE(review): `np` shadows the conventional numpy alias. numpy is not imported
# in the visible cells, but adding `import numpy as np` later would silently
# break every use of this constant.
np = 2688 #(Number of Positional embeddings) - max sequence length

# Checkpoint directory the model is loaded from; the step number is hard-coded.
checkpoint_dir = f"./results-{np}-{bd}-kbps/checkpoint-915000/"

from IPython.display import Audio

def codes_to_wav(codes):
    """Decode a flat, frame-interleaved code sequence back to audio.

    Args:
        codes: 1-D tensor of length ``T * cd`` laid out frame by frame, i.e.
            ``[f0_c0, f0_c1, ..., f1_c0, f1_c1, ...]`` (the layout produced by
            ``sample.mT.flatten()`` in ``DanceDataset``).

    Returns:
        Whatever ``enc_model.decode`` returns for a single frame tuple
        ``(codes, None)`` (no scale is supplied).

    Raises:
        RuntimeError: if ``len(codes)`` is not a multiple of ``cd`` (raised by
            ``reshape``).
    """
    # (T * cd,) -> (T, cd) -> (cd, T) -> (1, cd, T).
    # An explicit 2-D transpose replaces the deprecated 3-D `.T` plus a blanket
    # `.squeeze()`, which collapsed the layout whenever cd == 1 or T == 1.
    frame_codes = codes.reshape(-1, cd).T.unsqueeze(0).to(device)

    # Decoding is inference only; skip building an autograd graph.
    with torch.no_grad():
        return enc_model.decode([(frame_codes, None)])

class DanceDataset(Dataset):
    """Random fixed-length windows over pre-encoded code tensors, one per file.

    Each file in ``data_dir`` is loaded with ``torch.load`` and indexed as
    ``tensor[:, :, start:end]``, so it must be a 3-D tensor whose last axis is
    the frame axis. ``__getitem__`` cuts a random window of
    ``seq_len // n_codebooks`` frames and flattens it frame by frame.

    Args:
        data_dir: directory containing the saved tensors.
        seq_len: window length in tokens. ``None`` (default) uses the
            notebook-level ``np`` at access time.
        n_codebooks: code values per frame. ``None`` (default) uses the
            notebook-level ``cd`` at access time.
    """

    def __init__(self, data_dir, seq_len=None, n_codebooks=None):
        self.data_dir = data_dir
        self.seq_len = seq_len
        self.n_codebooks = n_codebooks
        # os.listdir order is arbitrary and also returns directories (e.g.
        # Jupyter's .ipynb_checkpoints). Keep regular files only, sorted, so an
        # index always maps to the same file across runs.
        self.file_list = sorted(
            name
            for name in os.listdir(data_dir)
            if os.path.isfile(os.path.join(data_dir, name))
        )

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        seq_len = np if self.seq_len is None else self.seq_len
        n_codebooks = cd if self.n_codebooks is None else self.n_codebooks
        n_frames = seq_len // n_codebooks

        file_path = os.path.join(self.data_dir, self.file_list[idx])
        # NOTE: torch.load unpickles its input; only point this at trusted files.
        # map_location keeps loading independent of the device used when saving.
        tensor = torch.load(file_path, map_location="cpu").detach().cpu()
        last_dim = tensor.shape[-1]

        # torch.randint's upper bound is exclusive, so add 1 to make the last
        # full window (start == last_dim - n_frames) reachable. Files shorter
        # than a window always start at 0 and yield a shorter sample.
        start_idx = torch.randint(0, max(last_dim - n_frames, 0) + 1, (1,)).item()
        end_idx = min(start_idx + n_frames, last_dim)
        sample = tensor[:, :, start_idx:end_idx]  #Take a random slice of audio from the current audio file

        # Swap the last two axes, then flatten: tokens are interleaved per frame.
        input_ids = sample.mT.flatten()

        return {
            'input_ids': input_ids,
            'attention_mask': torch.ones_like(input_ids),
            # Separate copy so labels never alias input_ids.
            'labels': input_ids.clone()
        }

# GPT-2 architecture used when training from scratch.
# The vocabulary has 1026 entries: ids 1024 and 1025 are reserved for
# pad / decoder-start and EOS respectively, leaving 0..1023 for audio codes.
_gpt2_hparams = {
    "vocab_size": 1026,
    "n_positions": np,  # max positional embeddings
    "n_layer": 12,
    "n_embd": 768,
    "n_head": 12,
    "n_inner": 3072,
    "decoder_start_token_id": 1024,
    "pad_token_id": 1024,
    "eos_token_id": 1025,
}
custom_config = GPT2Config(**_gpt2_hparams)

In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import GPT2LMHeadModel

resume_from_checkpoint = True #Set to False for fresh training run

#Load trained GPT-2 model with custom configuration
if resume_from_checkpoint:
    #Load from checkpoint
    model = GPT2LMHeadModel.from_pretrained(checkpoint_dir)
else:
    #Fresh training
    model = GPT2LMHeadModel(custom_config)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Build the path with os.path.join: the previous hard-coded "\\" separator only
# resolved on Windows (this yields the identical path there).
dance_dataset = DanceDataset(os.path.join("dance-data", "encoded-44khz-3kbps"))
data_loader = DataLoader(dance_dataset, batch_size=1, shuffle=True)

# Sanity check: pull one random batch and decode it back to audio so that the
# data layout can be verified by ear (played in a later cell via `audio`).
batch = next(iter(data_loader))
print("Batch shape:", batch["input_ids"].shape)
print(batch["input_ids"])
audio = codes_to_wav(batch["input_ids"].squeeze(0))

print(batch["input_ids"].device)

In [None]:
import torchaudio

def evaluate_fn(step, temperatures=(1.5, 2.0, 2.5, 3.0, 3.5), save_dir=f"audio_samples-{np}-{bd}-kbps"):
    """Generate one audio sample per temperature and write each to a WAV file.

    Uses the notebook-level ``model``, ``device`` and ``np``. Files are written
    to ``save_dir`` as ``audio_sample_step_{step}_temp_{temperature}.wav``.

    Args:
        step: training step, used only in the output file names.
        temperatures: iterable of sampling temperatures, one sample each.
        save_dir: output directory (created if missing). The default is
            evaluated once, when this function is defined.

    Returns:
        An empty dict.
    """
    was_training = model.training
    model.to(device)
    model.eval()

    os.makedirs(save_dir, exist_ok=True)

    try:
        #Generate audio samples
        for temperature in tqdm(temperatures, desc="Temperature"):
            input_ids = torch.tensor([[0]], device=device)  #start token
            # min_length == max_length keeps EOS suppressed for every position,
            # so the sequence always has exactly `np` tokens.
            # NOTE(review): assumes ids >= 1024 (EOS/pad) are not valid codec
            # indices; with min_length=np-1 EOS could still be sampled last.
            output = model.generate(input_ids=input_ids, min_length=np, max_length=np, do_sample=True, temperature=temperature, num_return_sequences=1)
            with torch.no_grad():
                audio_sample = codes_to_wav(output.squeeze(0))
            # torchaudio.save takes a CPU tensor directly; no numpy round trip.
            waveform = audio_sample.squeeze().detach().cpu()
            file_path = os.path.join(save_dir, f"audio_sample_step_{step}_temp_{temperature}.wav")
            torchaudio.save(file_path, waveform, 48000)
    finally:
        # Leave the model in the mode it was in when we were called.
        if was_training:
            model.train()

    return {}

In [None]:
# Play the decoded dataset sample produced by the sanity-check cell (`audio`).
# The playback rate is hard-coded to 48000 Hz.
Audio(data=audio.squeeze().detach().cpu().numpy(), rate=48000, autoplay=True)

In [None]:
from torch.utils.data.dataset import random_split
from transformers import TrainerCallback

wandb.init(project="dance-44khz")

# 95 % / 5 % train / eval split. The generator is seeded so the same indices
# are selected on every run; an unseeded split would move eval items into the
# training set each time training is restarted.
train_size = int(0.95 * len(dance_dataset))
eval_size = len(dance_dataset) - train_size
train_dataset, eval_dataset = random_split(
    dance_dataset,
    [train_size, eval_size],
    generator=torch.Generator().manual_seed(0),
)

#Training arguments
training_args = TrainingArguments(
    output_dir=f"./results-{np}-{bd}-kbps",
    num_train_epochs=1000,
    save_total_limit=5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    logging_dir="./logs",
    logging_steps=100,
    save_steps=5000,
    evaluation_strategy="steps",
    eval_steps=5000,
    report_to="wandb",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

#Trainer hook that renders audio samples (one per temperature in evaluate_fn's defaults) every 5000 global steps
class CustomCallback(TrainerCallback):
    """Calls ``evaluate_fn`` whenever the global step is a multiple of 5000."""

    def on_train_begin(self, args, state, control, model=None, tokenizer=None, **kwargs):
        """Nothing to do when training starts."""
        return None

    def on_step_end(self, args, state, control, **kwargs):
        """After each step, generate samples if a 5000-step boundary was reached."""
        current_step = state.global_step
        if current_step % 5000 != 0:
            return
        evaluate_fn(current_step)

custom_callback = CustomCallback()

trainer.add_callback(custom_callback)

#Train the model.
# When resuming, hand the checkpoint to the Trainer as well: loading the weights
# with from_pretrained() alone restarts the optimizer state, the learning-rate
# schedule and the global step counter from zero.
trainer.train(resume_from_checkpoint=checkpoint_dir if resume_from_checkpoint else None)

### INFERENCE

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

#Load the trained model
# Re-reads the weights from `checkpoint_dir`, replacing whatever `model`
# currently refers to (e.g. the in-memory model left over from training).
model_path = checkpoint_dir
#tokenizer = GPT2Tokenizer.from_pretrained(model_path)
# The model is not moved to `device` here; the generation cells do that.
model = GPT2LMHeadModel.from_pretrained(model_path)

### Generation Testing

In [None]:
# Put the model on the best available device, in inference mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Generate up to the full context length.
max_length = np

# Unconditional sampling seeded with the single token id 0.
input_ids = torch.tensor([[0]], device=device)
output = model.generate(
    input_ids=input_ids,
    max_length=max_length,
    do_sample=True,
    temperature=2.5,
    num_return_sequences=1,
)

In [None]:
# Inspect the generated token ids: shape first, then the raw values.
print(output.shape)
print(output)

In [None]:
# Drop the batch axis (codes_to_wav expects a flat 1-D sequence) and decode.
# NOTE(review): `output` is overwritten in place, so later cells see the 1-D view.
output = output.squeeze(0)
audio = codes_to_wav(output)

In [None]:
from IPython.display import Audio

# Play the freshly decoded generation; the playback rate is hard-coded to 48000 Hz.
Audio(data=audio.squeeze().detach().cpu().numpy(), rate=48000, autoplay=True)

In [None]:
# Debug print of the (now 1-D) generated token ids.
print(output)

### Musical Continuation

In [None]:
import librosa
import torch, torchaudio

#Load wav file which you want to extend
file_path = 'example_audio//sample.wav'
wav, sr = torchaudio.load(file_path)

# Resample / remix to the codec's sample rate and channel count, then add a
# batch axis so the tensor is (1, channels, samples).
wav = convert_audio(wav, sr, enc_model.sample_rate, enc_model.channels)
wav = wav.unsqueeze(0).to(device)

#Calculate the start and end index for the audio clip
start_time = 7  #Start time (s)
end_time = 15  #End time (s)
# `wav` is now at enc_model.sample_rate, so seconds must be converted with that
# rate. Using the file's original `sr` selects the wrong span whenever the two
# rates differ.
start_index = int(start_time * enc_model.sample_rate)
end_index = int(end_time * enc_model.sample_rate)

segment = wav[:, :, start_index:end_index]

# Independent copy of the slice (torch.tensor(existing_tensor) warns and is
# discouraged for copy-construction).
segment_tensor = segment.clone().detach()

In [None]:
# Shape of the prompt clip (batch axis first, added by unsqueeze(0) above).
print(segment_tensor.shape)

In [None]:
# Move the prompt clip onto the model's device.
audio = segment_tensor.to(device)

#Run the codec encoder without tracking gradients
with torch.no_grad():
    encoded_frames = enc_model.encode(audio)

# Each encoded frame is a tuple whose first element holds the codes; join them
# along the last axis ([B, n_q, T]) and flatten frame by frame.
frame_codes = [frame[0] for frame in encoded_frames]
codes = torch.cat(frame_codes, dim=-1).mT.flatten()

In [None]:
# Number of prompt tokens after flattening.
print(codes.shape)

In [None]:
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

max_length = np

# Flat list of every token produced so far (prompt included).
all_ids = []

# The encoded prompt, with a batch axis, is the first context.
input_ids = codes.unsqueeze(0).to(device)

w = np // 2 #(window) size

# Sliding-window continuation: each pass fills the context up to `np` tokens,
# then the trailing `np - w` tokens become the context for the next pass.
for i in tqdm(range(4)):
    # min_length == max_length forces every pass to return exactly `np` tokens;
    # the window arithmetic below relies on that, and an early EOS would
    # otherwise shorten the output and misalign the slices.
    output = model.generate(input_ids=input_ids, min_length=max_length, max_length=max_length, do_sample=True, temperature=1.3, num_return_sequences=1)

    if i == 0:
        # First pass: keep prompt + continuation.
        all_ids.extend(output.squeeze(0).tolist())
    else:
        # Later passes: keep only what was generated beyond the context.
        all_ids.extend(output[0, input_ids.shape[1]:].tolist())

    last_ids = output[:, -(np - w):]

    # Already a tensor on `device`; no need to re-wrap it with torch.tensor().
    input_ids = last_ids

In [None]:
# as_tensor converts the list on first run and passes a tensor through
# unchanged, so the cell can be re-executed without a copy-construct warning.
all_ids = torch.as_tensor(all_ids)

output = all_ids.squeeze(0)
audio = codes_to_wav(output)

gen_save_dir = f"cont_audio_samples-{np}-{bd}"

os.makedirs(gen_save_dir, exist_ok=True)

# torchaudio.save accepts the CPU tensor directly; the numpy round trip was redundant.
audio_data = audio.squeeze().detach().cpu()
file_path = os.path.join(gen_save_dir, "audio_sample.wav")
torchaudio.save(file_path, audio_data, 48000)

torch.cuda.empty_cache()

In [None]:
from IPython.display import Audio

# Play the original prompt clip for comparison with the continuation.
# NOTE(review): the rate is hard-coded to 48000 while the clip was resampled to
# enc_model.sample_rate — presumably the same value for the 48 kHz model; confirm.
Audio(data=segment_tensor.squeeze().detach().cpu().numpy(), rate=48000, autoplay=True)

### Longer Generation

In [None]:
from tqdm import tqdm
import torchaudio

model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

max_length = np

# Loop-invariant setup hoisted out of the 200-sample loop.
w = np // 2 #(window) size
gen_save_dir = f"gen_audio_samples-{np}-{bd}"
os.makedirs(gen_save_dir, exist_ok=True)

# Produce 200 independent long samples, each built from 10 sliding-window passes.
for h in range(200):

    all_ids = []

    # Unconditional start: a single token with id 0.
    input_ids = torch.tensor([[0]], device=device)

    for i in tqdm(range(10)):
        # min_length == max_length forces every pass to return exactly `np`
        # tokens; the window arithmetic below relies on that, and an early EOS
        # would otherwise shorten the output and misalign the slices.
        output = model.generate(input_ids=input_ids, min_length=max_length, max_length=max_length, do_sample=True, temperature=1.5, num_return_sequences=1)

        if i == 0:
            # First pass: keep everything.
            all_ids.extend(output.squeeze(0).tolist())
        else:
            # Later passes: keep only what was generated beyond the context.
            all_ids.extend(output[0, input_ids.shape[1]:].tolist())

        # Trailing `np - w` tokens become the next context. This is already a
        # tensor on `device`; no torch.tensor() re-wrap needed.
        last_ids = output[:, -(np - w):]
        input_ids = last_ids

    all_ids = torch.tensor(all_ids)

    output = all_ids.squeeze(0)
    audio = codes_to_wav(output)

    # torchaudio.save accepts the CPU tensor directly; no numpy round trip.
    audio_data = audio.squeeze().detach().cpu()
    file_path = os.path.join(gen_save_dir, f"audio_sample_{h}.wav")
    torchaudio.save(file_path, audio_data, 48000)

    torch.cuda.empty_cache()

In [None]:
from IPython.display import Audio

# Play the most recently generated long sample (`audio` from the last loop pass).
Audio(data=audio.squeeze().detach().cpu().numpy(), rate=48000, autoplay=True)

In [None]:
# Total number of tokens in the last long sample.
print(all_ids.shape)

In [None]:
# `output` was reassigned to the squeezed `all_ids` inside the loop above.
print(output.shape)

In [None]:
# Release cached GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

In [None]:
# Spot-check tokens 2800..2849, i.e. just past the first `np` (2688) tokens,
# which fall in the second generation window.
print(all_ids[2800:2850])