In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW
from jukebox.make_models import make_vqvae
import librosa
import os

ModuleNotFoundError: No module named 'jukebox'

In [None]:
# ------------------- Dataset Class -------------------
class AudioTextDataset(Dataset):
    """
    Dataset of (audio, text) pairs.

    The i-th ``.wav`` file of ``audio_dir`` (in sorted path order) is paired
    with the i-th line of ``text_file``.
    """

    def __init__(self, audio_dir, text_file, tokenizer, max_length=512):
        """
        Dataset to load audio and text pairs.
        :param audio_dir: Directory containing audio files.
        :param text_file: File containing corresponding text descriptions (one per line).
        :param tokenizer: Tokenizer for text processing. If it has no pad token but
            has an EOS token, its ``pad_token`` is set to the EOS token (the object
            passed in is modified).
        :param max_length: Maximum token length for text sequences.
        :raises ValueError: If ``text_file`` has fewer lines than there are ``.wav`` files.
        """
        self.audio_files = sorted(
            os.path.join(audio_dir, name)
            for name in os.listdir(audio_dir)
            if name.endswith(".wav")
        )
        with open(text_file, 'r') as f:
            self.text_descriptions = [line.strip() for line in f]

        # Fail fast: a missing description would otherwise only surface as an
        # IndexError in __getitem__, part-way through an epoch.
        if len(self.text_descriptions) < len(self.audio_files):
            raise ValueError(
                f"Found {len(self.audio_files)} .wav files in {audio_dir!r} but only "
                f"{len(self.text_descriptions)} descriptions in {text_file!r}"
            )

        # padding="max_length" in __getitem__ needs a pad token; tokenizers that
        # ship without one (e.g. the stock GPT-2 tokenizer) raise otherwise.
        # Reusing EOS as the pad token is the usual workaround.
        if (
            getattr(tokenizer, "pad_token", None) is None
            and getattr(tokenizer, "eos_token", None) is not None
        ):
            tokenizer.pad_token = tokenizer.eos_token

        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of ``.wav`` files found in ``audio_dir``."""
        return len(self.audio_files)

    def preprocess_audio(self, filepath):
        """
        Load one audio file, resampled to 44.1 kHz and mixed down to mono.
        :param filepath: Path of the audio file to load.
        :return: float32 tensor with a leading dimension of size 1 added in front
            of the loaded samples.
        """
        audio, sr = librosa.load(filepath, sr=44100, mono=True)
        return torch.tensor(audio, dtype=torch.float32).unsqueeze(0)  # Add leading dimension

    def __getitem__(self, idx):
        """
        Return the ``idx``-th pair.
        :param idx: Index into the sorted audio file list.
        :return: Tuple ``(audio, tokens)``: the preprocessed audio tensor and the
            token ids of the description, truncated/padded to ``max_length``.
        """
        audio_path = self.audio_files[idx]
        text = self.text_descriptions[idx]
        audio = self.preprocess_audio(audio_path)
        tokens = self.tokenizer(
            text, return_tensors="pt", max_length=self.max_length, truncation=True, padding="max_length"
        ).input_ids.squeeze(0)  # drop the batch dimension added by return_tensors="pt"
        return audio, tokens


In [None]:
def load_jukebox_encoder(model_name='5b', device=None):
    """
    Load the Jukebox VQ-VAE and return its encoder.

    :param model_name: Model identifier forwarded to ``make_vqvae`` (default ``'5b'``).
    :param device: Device string to build the model on. When ``None``, ``'cuda'``
        is used if available, otherwise ``'cpu'``.
    :return: The ``encoder`` attribute of the object returned by ``make_vqvae``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # NOTE(review): the keyword arguments and the `.encoder` attribute are kept
    # exactly as the notebook wrote them. Upstream Jukebox's `make_vqvae` appears
    # to take an hparams object, and its VQ-VAE appears to expose per-level
    # `encoders` rather than a single `encoder` — TODO confirm against the
    # installed Jukebox release before relying on this.
    vqvae = make_vqvae(model_name=model_name, device=device)
    return vqvae.encoder

In [None]:

# ------------------- Fine-Tuning Function -------------------
def fine_tune(audio_dir, text_file, output_dir, epochs=3, batch_size=8, lr=5e-5):
    """
    Fine-tune a GPT-2 model using Jukebox embeddings and text descriptions.

    :param audio_dir: Directory containing the ``.wav`` files.
    :param text_file: File with one description per line, aligned with the sorted audio files.
    :param output_dir: Directory the fine-tuned model and tokenizer are saved to.
    :param epochs: Number of passes over the dataset.
    :param batch_size: Number of (audio, text) pairs per optimization step.
    :param lr: Learning rate for AdamW.
    :raises ValueError: If no audio files are found in ``audio_dir``.
    """
    # Load Jukebox encoder (frozen: it only runs under no_grad below)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    encoder = load_jukebox_encoder()
    encoder.eval().to(device)

    # Load tokenizer and GPT-2 model
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # GPT-2 ships without a pad token, but the dataset pads every description
    # to max_length; reuse EOS as the pad token so tokenization does not raise.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)

    # Prepare dataset and dataloader
    dataset = AudioTextDataset(audio_dir, text_file, tokenizer)
    if len(dataset) == 0:
        raise ValueError(f"No .wav files found in {audio_dir!r}")
    # Identity collate: each batch stays a list of (audio, tokens) pairs because
    # clips differ in length and cannot be stacked by the default collate.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)

    # Optimizer: the AdamW exported by `transformers` is deprecated in favour of
    # the PyTorch implementation. eps / weight_decay are set to the defaults of
    # the deprecated class so the optimization behaviour stays the same.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, eps=1e-6, weight_decay=0.0)

    # Training loop
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch in dataloader:
            audios, token_ids = zip(*batch)

            # Zero-pad every clip to the longest one in the batch; concatenating
            # clips of different lengths along dim 0 would raise otherwise.
            longest = max(clip.shape[-1] for clip in audios)
            audio_tensors = torch.cat(
                [torch.nn.functional.pad(clip, (0, longest - clip.shape[-1])) for clip in audios],
                dim=0,
            ).to(device)

            # Process audio through encoder
            with torch.no_grad():
                embeddings = encoder(audio_tensors)

            # Flatten embeddings for input into GPT-2
            embeddings = embeddings.flatten(start_dim=1)

            # Prepare text inputs
            token_ids = torch.stack(token_ids).to(device)

            # Forward pass
            # NOTE(review): GPT-2 expects inputs_embeds shaped
            # (batch, seq_len, hidden_size) with seq_len matching the labels;
            # the flattened 2-D tensor above presumably does not satisfy that,
            # so a projection/reshape step is likely still needed — TODO confirm
            # against the real encoder output shape.
            # NOTE(review): padded label positions are not masked with -100, so
            # padding tokens contribute to the loss — TODO decide on masking.
            outputs = model(inputs_embeds=embeddings, labels=token_ids)
            loss = outputs.loss

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(dataloader)}")

    # Save the fine-tuned model
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

In [None]:
# Input locations and the directory that receives the saved checkpoint.
audio_dir = "./audio_files"
text_file = "./descriptions.txt"
output_dir = "./fine_tuned_model"

# Run fine-tuning with explicit keyword arguments.
fine_tune(
    audio_dir=audio_dir,
    text_file=text_file,
    output_dir=output_dir,
    epochs=3,
    batch_size=4,
    lr=5e-5,
)