<a href="https://colab.research.google.com/github/hjangir080/EmotionAwareMusicGeneration/blob/main/DL_Project_EmotionAwareMusic_error_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
class MuLanEmbedding(nn.Module):
    """Two-tower (text / audio) encoder mapping both modalities into one joint space.

    Text goes through ``bert-base-uncased``; its [CLS] vector is linearly
    projected. A mel spectrogram goes through a small 1-D conv stack with global
    average pooling, then a linear projection. Both outputs are L2-normalised
    along dim 1.

    Args:
        text_embedding_dim: Hidden size of the BERT text encoder. It must equal
            the checkpoint's ``config.hidden_size`` (768 for bert-base-uncased);
            a mismatch raises ``ValueError`` instead of being silently ignored.
        audio_embedding_dim: Channel count of the last audio conv layer, i.e.
            the width of the pooled audio feature fed to ``audio_projection``.
        joint_embedding_dim: Size of the shared embedding space.
        n_mels: Number of mel bins (input channels) expected by the audio encoder.
    """

    def __init__(self, text_embedding_dim=768, audio_embedding_dim=512, joint_embedding_dim=256,
                 n_mels=80):
        super(MuLanEmbedding, self).__init__()

        # Text encoder (BERT-based)
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
        bert_hidden = self.text_encoder.config.hidden_size
        if text_embedding_dim != bert_hidden:
            raise ValueError(
                f"text_embedding_dim={text_embedding_dim} does not match the BERT "
                f"hidden size ({bert_hidden})"
            )
        self.text_projection = nn.Linear(text_embedding_dim, joint_embedding_dim)

        # Audio encoder: two conv+pool stages (each pool halves time), a final conv
        # to audio_embedding_dim channels, then global average pooling over time.
        self.audio_encoder = nn.Sequential(
            nn.Conv1d(n_mels, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
            nn.Conv1d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
            nn.Conv1d(256, audio_embedding_dim, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1)
        )
        self.audio_projection = nn.Linear(audio_embedding_dim, joint_embedding_dim)

    def encode_text(self, input_ids, attention_mask):
        """Embed a tokenized text batch; returns unit-norm (batch, joint_embedding_dim)."""
        outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls_embedding = outputs.last_hidden_state[:, 0, :]  # Use [CLS] token
        text_embedding = self.text_projection(cls_embedding)
        return F.normalize(text_embedding, p=2, dim=1)

    def encode_audio(self, mel_spectrogram):
        """Embed a mel spectrogram batch; returns unit-norm (batch, joint_embedding_dim)."""
        # mel_spectrogram shape: [batch_size, freq_bins, time_frames]
        # AdaptiveAvgPool1d(1) leaves a trailing time axis of length 1; drop it.
        audio_features = self.audio_encoder(mel_spectrogram).squeeze(-1)
        audio_embedding = self.audio_projection(audio_features)
        return F.normalize(audio_embedding, p=2, dim=1)

    def forward(self, input_ids, attention_mask, mel_spectrogram=None):
        """Return the text embedding, or ``(text, audio)`` embeddings when audio is given."""
        text_embedding = self.encode_text(input_ids, attention_mask)

        if mel_spectrogram is not None:
            audio_embedding = self.encode_audio(mel_spectrogram)
            return text_embedding, audio_embedding

        return text_embedding

In [None]:
class SoundStreamDecoder(nn.Module):
    """Upsample a latent vector into a short waveform with transposed convolutions.

    The latent is projected to a (256 channels x 16 steps) seed. Four stride-2
    transposed convs then each double the length, giving 16 * 2**4 = 256
    samples. A final Tanh bounds the output to [-1, 1].
    """

    def __init__(self, latent_dim=256, output_channels=1):
        super(SoundStreamDecoder, self).__init__()

        self.initial_layer = nn.Linear(latent_dim, 16 * 256)

        widths = (256, 128, 64, 32, output_channels)
        last_stage = len(widths) - 2
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.ConvTranspose1d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            # Hidden stages use ReLU; the output stage squashes into [-1, 1].
            layers.append(nn.Tanh() if stage == last_stage else nn.ReLU())
        self.decoder = nn.Sequential(*layers)

    def forward(self, latent_vector):
        # latent_vector: [batch_size, latent_dim] -> seed: [batch_size, 256, 16]
        seed = self.initial_layer(latent_vector).view(-1, 256, 16)
        return self.decoder(seed)

In [None]:
class CustomMusicLM(nn.Module):
    """Text-to-waveform model: MuLan text embedding -> MLP latent -> SoundStream decoder.

    The 256-wide hidden sizes must agree with ``MuLanEmbedding``'s default
    ``joint_embedding_dim`` and ``SoundStreamDecoder``'s default ``latent_dim``.
    Submodule attribute names determine the state-dict keys, so they must stay
    in sync with any saved checkpoint.
    """

    def __init__(self):
        super(CustomMusicLM, self).__init__()

        self.mulan_model = MuLanEmbedding()
        self.latent_generator = nn.Sequential(
            nn.Linear(256, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.Tanh(),
        )
        self.soundstream_decoder = SoundStreamDecoder()

    def forward(self, input_ids, attention_mask):
        # Text embedding -> latent code -> waveform.
        return self.soundstream_decoder(
            self.latent_generator(
                self.mulan_model.encode_text(input_ids, attention_mask)
            )
        )

In [None]:
def train_custom_musiclm(model, train_dataloader, num_epochs=100, lr=1e-4):
    """Train ``model`` to reproduce target waveforms from tokenized text (MSE loss).

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` returning a
            waveform tensor.
        train_dataloader: Iterable of dict batches with keys ``'input_ids'``,
            ``'attention_mask'`` and ``'audio_waveform'``. The target must have
            exactly the model output's shape.
        num_epochs: Number of passes over ``train_dataloader``.
        lr: Adam learning rate (default is the previously hard-coded 1e-4).

    Returns:
        list[float]: Mean training loss per epoch (NaN for an epoch with no batches).

    Raises:
        ValueError: If a target's shape differs from the model output's shape
            (MSELoss would otherwise broadcast silently).
    """
    reconstruction_loss = nn.MSELoss()

    # `optim` is never imported in this notebook; use the fully qualified name.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    device = next(model.parameters()).device

    history = []
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        n_batches = 0

        for batch in train_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            audio_targets = batch['audio_waveform'].to(device)

            audio_output = model(input_ids, attention_mask)
            if audio_targets.shape != audio_output.shape:
                raise ValueError(
                    f"target shape {tuple(audio_targets.shape)} does not match "
                    f"model output shape {tuple(audio_output.shape)}"
                )

            loss = reconstruction_loss(audio_output, audio_targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1

        # Count batches seen rather than len(): avoids ZeroDivisionError on an
        # empty loader and works for iterables without __len__.
        avg_loss = total_loss / n_batches if n_batches else float("nan")
        history.append(avg_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

    print("Training complete!")
    return history

In [None]:
def use_custom_musiclm(literary_music_generator, text_prompts):
    """Generate one waveform per prompt from the weights saved in ``custom_musiclm.pth``.

    Args:
        literary_music_generator: Unused here; kept for interface compatibility.
        text_prompts: Iterable of prompt strings.

    Returns:
        list of numpy arrays, one per prompt, with size-1 dimensions squeezed out.
    """
    custom_musiclm = CustomMusicLM()

    # map_location='cpu' lets a checkpoint saved from a GPU session load on a
    # CPU-only runtime (the model below is constructed on CPU).
    state_dict = torch.load('custom_musiclm.pth', map_location='cpu')
    custom_musiclm.load_state_dict(state_dict)
    custom_musiclm.eval()

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

    audio_outputs = []
    with torch.no_grad():
        for prompt in text_prompts:
            encoded_input = tokenizer(
                prompt,
                padding='max_length',
                truncation=True,
                max_length=128,
                return_tensors='pt'
            )
            audio_output = custom_musiclm(
                encoded_input['input_ids'],
                encoded_input['attention_mask']
            )
            audio_outputs.append(audio_output.squeeze().cpu().numpy())

    return audio_outputs

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import BertModel, BertTokenizer
from tqdm import tqdm


class MuLANTextEncoder(nn.Module):
    """Text encoder based on BERT with MuLAN-style additions for music understanding.

    The [CLS] vector of BERT is linearly projected to ``embedding_dim``. BERT runs
    under ``torch.no_grad()``, so it is not fine-tuned through this module. For
    each prompt that contains whole words from a small music vocabulary, the mean
    of those words' learned embeddings is added with weight 0.3.
    """

    def __init__(self, bert_model="bert-base-uncased", embedding_dim=512):
        super(MuLANTextEncoder, self).__init__()
        # Load pre-trained BERT model
        self.bert = BertModel.from_pretrained(bert_model)
        self.tokenizer = BertTokenizer.from_pretrained(bert_model)

        # Project BERT outputs to our embedding space
        self.projection = nn.Linear(self.bert.config.hidden_size, embedding_dim)

        # Music-specific token embeddings. 1000 rows are allocated but only the
        # first len(music_vocab) indices are ever looked up.
        self.music_token_embedding = nn.Embedding(1000, embedding_dim)
        self.music_vocab = self._create_music_vocab()

    def _create_music_vocab(self):
        """Create vocabulary mapping for music-specific terms"""
        music_terms = [
            "tempo", "rhythm", "melody", "harmony", "bass", "treble",
            "major", "minor", "piano", "guitar", "drums", "strings",
            "loud", "soft", "fast", "slow", "staccato", "legato",
            # Emotions
            "happy", "sad", "angry", "fearful", "tender", "excited"
            # Add more music-specific terms
        ]
        return {term: i for i, term in enumerate(music_terms)}

    def forward(self, text):
        """Encode a prompt or list of prompts into (batch, embedding_dim) embeddings."""
        # A bare string is one prompt; without this the per-prompt loop in
        # _enhance_music_terms would iterate over its characters.
        if isinstance(text, str):
            text = [text]

        tokens = self.tokenizer(text, return_tensors="pt",
                                padding=True, truncation=True, max_length=512)

        # Move to the same device as the model
        tokens = {k: v.to(self.bert.device) for k, v in tokens.items()}

        # BERT is used as a frozen feature extractor.
        with torch.no_grad():
            outputs = self.bert(**tokens)

        # Use the [CLS] token embedding as sequence representation
        sequence_embedding = outputs.last_hidden_state[:, 0, :]

        projected_embedding = self.projection(sequence_embedding)

        return self._enhance_music_terms(text, projected_embedding)

    def _enhance_music_terms(self, text, embedding):
        """Add the mean embedding of music terms found in each prompt (weight 0.3).

        Args:
            text: A prompt string or a list of prompts, one per row of ``embedding``.
            embedding: (batch, embedding_dim) tensor of base embeddings.

        Returns:
            (batch, embedding_dim) tensor; rows with no music terms are unchanged.
        """
        import re

        if isinstance(text, str):
            text = [text]

        enhanced_rows = []
        for i, prompt in enumerate(text):
            # Match whole words only: substring matching made e.g. "unhappy"
            # count as "happy" and "cloud" as "loud".
            words = set(re.findall(r"[a-z]+", prompt.lower()))
            term_ids = [idx for term, idx in self.music_vocab.items() if term in words]

            if term_ids:
                term_vectors = self.music_token_embedding(
                    torch.tensor(term_ids, device=embedding.device))
                enhanced_rows.append(embedding[i] + term_vectors.mean(dim=0) * 0.3)  # 30% contribution
            else:
                enhanced_rows.append(embedding[i])

        return torch.stack(enhanced_rows)


class SoundStreamEncoder(nn.Module):
    """Audio encoder inspired by SoundStream architecture"""
    def __init__(self, input_channels=1, embedding_dim=512):
        super(SoundStreamEncoder, self).__init__()

        # Convolutional encoder
        self.encoder = nn.Sequential(
            # Initial convolution
            nn.Conv1d(input_channels, 32, kernel_size=7, stride=1, padding=3),
            nn.ReLU(),

            # Downsampling convolutions
            nn.Conv1d(32, 64, kernel_size=4, stride=2, padding=1),  # Downsample 2x
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=4, stride=2, padding=1),  # Downsample 2x
            nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=4, stride=2, padding=1),  # Downsample 2x
            nn.ReLU(),
            nn.Conv1d(256, 512, kernel_size=4, stride=2, padding=1),  # Downsample 2x
            nn.ReLU(),

            # Additional convolutions
            nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

        # Projection to embedding space
        self.projection = nn.Linear(512, embedding_dim)

        # Quantizer (VQ layer)
        self.codebook_size = 1024
        self.codebook = nn.Embedding(self.codebook_size, embedding_dim)

    def forward(self, x):
        """
        Encode audio into latent representations
        x shape: [batch_size, 1, time]
        """
        # Apply convolutional encoder
        encoded = self.encoder(x)  # [batch_size, 512, time/16]

        # Global pooling to get fixed-size representation
        pooled = F.adaptive_avg_pool1d(encoded, 1).squeeze(-1)  # [batch_size, 512]

        # Project to embedding space
        embedding = self.projection(pooled)  # [batch_size, embedding_dim]

        # Nearest-neighbor lookup in the codebook (simplified VQ)
        # In practice, SoundStream uses a more complex residual VQ approach
        distances = torch.sum(embedding.unsqueeze(1)**2, dim=2) + \
                   torch.sum(self.codebook.weight**2, dim=1) - \
                   2 * torch.matmul(embedding, self.codebook.weight.t())

        encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.codebook(encoding_indices)

        # Straight-through estimator
        quantized_st = embedding + (quantized - embedding).detach()

        return quantized_st, encoded


class SoundStreamDecoder(nn.Module):
    """SoundStream-flavoured decoder: embedding -> 512-channel sequence -> waveform.

    ``forward(z, encoded=None, length=16000)``:
    - ``z`` has shape (batch, embedding_dim) and is projected to 512 channels.
    - If ``encoded`` (batch, 512, T) is given, the projection gates it
      element-wise per channel.
    - Otherwise the projection is tiled over ``length // 16`` time steps.
    - Four 2x transposed-conv stages then give (batch, output_channels, 16 * T)
      samples in [-1, 1].
    """

    def __init__(self, embedding_dim=512, output_channels=1):
        super(SoundStreamDecoder, self).__init__()

        self.pre_decoder = nn.Linear(embedding_dim, 512)

        stack = [nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1), nn.ReLU()]
        for c_in, c_out in ((512, 256), (256, 128), (128, 64), (64, 32)):
            stack += [nn.ConvTranspose1d(c_in, c_out, kernel_size=4, stride=2, padding=1), nn.ReLU()]
        stack += [nn.Conv1d(32, output_channels, kernel_size=7, stride=1, padding=3), nn.Tanh()]
        self.decoder = nn.Sequential(*stack)

    def forward(self, z, encoded=None, length=16000):
        seed = self.pre_decoder(z).unsqueeze(-1)   # [batch, 512, 1]

        if encoded is None:
            # 16 = total upsampling of the four stride-2 transposed convs.
            seed = seed.repeat(1, 1, length // 16)
        else:
            # Reuse the encoder's time resolution (adaptive pooling to the same
            # length would be an identity, so it is omitted).
            seed = seed * encoded

        return self.decoder(seed)


class MusicConditionedUNet(nn.Module):
    """1-D UNet that refines a waveform, conditioned on a text embedding.

    Structure:
    - Four encoder blocks each halve the time axis (stride-2 conv).
    - The bottleneck receives the projected text embedding as an additive bias.
    - Four decoder blocks each double the time axis. Each is fed the previous
      decoder output concatenated with the encoder output at the *same*
      resolution.
    - A final 1x1 conv mixes the last decoder features with the input.

    Inputs whose length is not a multiple of 16 are right-padded with zeros
    internally and the output is cropped back, so output length == input length.

    Args:
        input_channels: Audio channels in ``x``.
        output_channels: Audio channels produced.
        base_channels: Width of the first encoder block; must be divisible by 4
            (GroupNorm with 4 groups).
        embedding_dim: Size of the text embedding passed to ``forward``.
    """

    # Total downsampling of the four stride-2 encoder blocks.
    _TIME_FACTOR = 16

    def __init__(self, input_channels=1, output_channels=1, base_channels=32, embedding_dim=512):
        super(MusicConditionedUNet, self).__init__()

        # Encoder (downsampling) blocks: output lengths L/2, L/4, L/8, L/16
        self.enc1 = self._encoder_block(input_channels, base_channels)
        self.enc2 = self._encoder_block(base_channels, base_channels*2)
        self.enc3 = self._encoder_block(base_channels*2, base_channels*4)
        self.enc4 = self._encoder_block(base_channels*4, base_channels*8)

        # Bottleneck with text condition integration (stays at L/16)
        self.bottleneck = nn.Sequential(
            nn.Conv1d(base_channels*8, base_channels*16, kernel_size=3, padding=1),
            nn.GroupNorm(num_groups=8, num_channels=base_channels*16),
            nn.ReLU(),
            nn.Conv1d(base_channels*16, base_channels*16, kernel_size=3, padding=1),
            nn.GroupNorm(num_groups=8, num_channels=base_channels*16),
            nn.ReLU()
        )
        self.text_projection = nn.Linear(embedding_dim, base_channels*16)

        # Decoder (upsampling) blocks. Each input is the previous decoder output
        # plus the encoder skip at that resolution. The old wiring concatenated
        # tensors one resolution apart (e.g. L/8 with L/16), which cannot work.
        self.dec4 = self._decoder_block(base_channels*16, base_channels*8)                  # L/16 -> L/8
        self.dec3 = self._decoder_block(base_channels*8 + base_channels*4, base_channels*4)  # + enc3, -> L/4
        self.dec2 = self._decoder_block(base_channels*4 + base_channels*2, base_channels*2)  # + enc2, -> L/2
        self.dec1 = self._decoder_block(base_channels*2 + base_channels, base_channels)      # + enc1, -> L

        # Final 1x1 convolution over [decoder features, original input]
        self.final_conv = nn.Conv1d(base_channels + input_channels, output_channels, kernel_size=1)

    def _encoder_block(self, in_channels, out_channels):
        """Conv + GroupNorm + ReLU, then a stride-2 conv that halves the length."""
        return nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(num_groups=4, num_channels=out_channels),
            nn.ReLU(),
            nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1, stride=2),  # Downsample
            nn.GroupNorm(num_groups=4, num_channels=out_channels),
            nn.ReLU()
        )

    def _decoder_block(self, in_channels, out_channels):
        """Stride-2 transposed conv that doubles the length, then conv + GroupNorm + ReLU."""
        return nn.Sequential(
            nn.ConvTranspose1d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),  # Upsample
            nn.GroupNorm(num_groups=4, num_channels=out_channels),
            nn.ReLU(),
            nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(num_groups=4, num_channels=out_channels),
            nn.ReLU()
        )

    def forward(self, x, text_embedding):
        """
        Forward pass with text conditioning
        x: [batch_size, input_channels, time] - Initial audio or noise
        text_embedding: [batch_size, embedding_dim] - Text embeddings
        Returns: [batch_size, output_channels, time]
        """
        length = x.size(-1)
        # Pad so every downsampling stage halves the length exactly and the skip
        # connections line up with the decoder outputs.
        pad = (-length) % self._TIME_FACTOR
        if pad:
            x = F.pad(x, (0, pad))

        # Encoder
        enc1 = self.enc1(x)      # L/2
        enc2 = self.enc2(enc1)   # L/4
        enc3 = self.enc3(enc2)   # L/8
        enc4 = self.enc4(enc3)   # L/16

        # Bottleneck
        bottleneck = self.bottleneck(enc4)

        # Text conditioning: per-channel bias broadcast over time, scaled by 0.1
        text_proj = self.text_projection(text_embedding).unsqueeze(-1)
        bottleneck = bottleneck + 0.1 * text_proj

        # Decoder with resolution-matched skip connections
        dec4 = self.dec4(bottleneck)                        # L/8
        dec3 = self.dec3(torch.cat([dec4, enc3], dim=1))    # L/4
        dec2 = self.dec2(torch.cat([dec3, enc2], dim=1))    # L/2
        dec1 = self.dec1(torch.cat([dec2, enc1], dim=1))    # L

        output = self.final_conv(torch.cat([dec1, x], dim=1))

        # Drop the padding added above.
        return output[..., :length]


class MusicLMModel(nn.Module):
    """MusicLM-inspired model combining all components.

    Components:
    - a MuLAN-style text encoder;
    - a SoundStream-style audio encoder and decoder;
    - a text-conditioned UNet that iteratively refines a rough waveform.
    """

    def __init__(self, embedding_dim=512, sample_rate=16000):
        super(MusicLMModel, self).__init__()

        # Text encoder (MuLAN-inspired)
        self.text_encoder = MuLANTextEncoder(embedding_dim=embedding_dim)

        # Audio encoder (SoundStream-inspired)
        self.audio_encoder = SoundStreamEncoder(embedding_dim=embedding_dim)

        # Audio decoder (SoundStream-inspired)
        self.audio_decoder = SoundStreamDecoder(embedding_dim=embedding_dim)

        # High-resolution audio refinement
        self.audio_unet = MusicConditionedUNet(embedding_dim=embedding_dim)

        # Sampling parameters
        self.sample_rate = sample_rate

    def encode_text(self, text_prompts):
        """Encode text prompts to the joint embedding space"""
        return self.text_encoder(text_prompts)

    def encode_audio(self, audio):
        """Encode audio to the joint embedding space"""
        return self.audio_encoder(audio)

    def decode_audio(self, embedding, length=16000):
        """Decode embedding to audio"""
        return self.audio_decoder(embedding, length=length)

    def refine_audio(self, audio, text_embedding):
        """Apply high-resolution refinement to audio based on text embedding"""
        return self.audio_unet(audio, text_embedding)

    def generate_from_text(self, text_prompt, length_seconds=10,
                           refinement_steps=50, noise_level=0.5):
        """Generate audio from a single text prompt.

        Args:
            text_prompt: Prompt string.
            length_seconds: Requested duration; converted to
                ``int(length_seconds * sample_rate)`` samples. The decoder works
                in blocks of 16 samples, so the result is that value rounded
                down to a multiple of 16.
            refinement_steps: Number of noise-and-refine UNet iterations.
            noise_level: Initial noise scale, decayed linearly to zero.

        Returns:
            Tensor of shape (1, samples) (the decoder's single output channel).

        The model is switched to eval mode for the call (restoring the previous
        mode afterwards) and runs without autograd. Otherwise the graph would
        chain through every refinement step and keep all activations alive.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                text_embedding = self.encode_text([text_prompt])

                audio_length = int(length_seconds * self.sample_rate)

                # NOTE: the rough audio comes from encoding random noise, so it
                # does not depend on the prompt; only the UNet refinement below
                # is text-conditioned.
                noise = torch.randn(1, 1, audio_length, device=text_embedding.device)
                z_audio, _ = self.audio_encoder(noise)
                current_audio = self.audio_decoder(z_audio, length=audio_length)

                for step in tqdm(range(refinement_steps), desc="Refining audio"):
                    # Add noise proportional to the remaining steps
                    remaining_factor = (refinement_steps - step) / refinement_steps
                    noise_scale = noise_level * remaining_factor
                    noised_audio = current_audio + torch.randn_like(current_audio) * noise_scale

                    # Refine using UNet with text conditioning
                    refined = self.audio_unet(noised_audio, text_embedding)

                    # Blend toward the refined audio; alpha grows with step, capped at 0.8
                    alpha = min(0.1 + step / (refinement_steps * 0.8), 0.8)
                    current_audio = (1 - alpha) * current_audio + alpha * refined
        finally:
            self.train(was_training)

        return current_audio.squeeze(0)

In [None]:
import os
import torch
import torchaudio
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

class MusicTextPairDataset(Dataset):
    """Dataset of paired music clips and text descriptions.

    Reads ``<data_dir>/<metadata_file>`` (CSV). The CSV must have an
    ``audio_file`` column (path relative to ``data_dir``) and a
    ``text_description`` column; rows missing either value are dropped. All
    other columns are returned unchanged in each item's ``metadata`` dict.

    Each item's audio is:
    - mixed down to mono;
    - resampled to ``sample_rate``;
    - right-padded with zeros or truncated to exactly ``max_audio_length`` samples.
    """

    def __init__(self, data_dir, metadata_file, max_audio_length=160000, sample_rate=16000):
        self.data_dir = data_dir
        self.sample_rate = sample_rate
        self.max_audio_length = max_audio_length

        self.metadata = pd.read_csv(os.path.join(data_dir, metadata_file))

        # Filter out entries with missing audio or text
        self.metadata = self.metadata.dropna(subset=['audio_file', 'text_description'])

        # Resample transforms keyed by source sample rate. Constructing one builds
        # its filter kernel, so reuse it instead of rebuilding it for every item.
        self._resamplers = {}

        print(f"Loaded dataset with {len(self.metadata)} examples")

    def __len__(self):
        return len(self.metadata)

    def _get_resampler(self, orig_sr):
        """Return a cached Resample transform from ``orig_sr`` to ``self.sample_rate``."""
        resampler = self._resamplers.get(orig_sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_sr, self.sample_rate)
            self._resamplers[orig_sr] = resampler
        return resampler

    def __getitem__(self, idx):
        """Return ``{'audio': (1, max_audio_length) tensor, 'text': str, 'metadata': dict}``."""
        # iloc is positional, so the index gaps left by dropna() do not matter.
        item = self.metadata.iloc[idx]

        audio_path = os.path.join(self.data_dir, item['audio_file'])
        waveform, sr = torchaudio.load(audio_path)

        # Convert to mono if needed
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        if sr != self.sample_rate:
            waveform = self._get_resampler(sr)(waveform)

        # Pad (right, with zeros) or trim to exactly max_audio_length samples
        shortfall = self.max_audio_length - waveform.shape[1]
        if shortfall > 0:
            waveform = torch.nn.functional.pad(waveform, (0, shortfall))
        else:
            waveform = waveform[:, :self.max_audio_length]

        metadata = {
            col: item[col]
            for col in self.metadata.columns
            if col not in ['audio_file', 'text_description']
        }

        return {
            'audio': waveform,
            'text': item['text_description'],
            'metadata': metadata
        }

def prepare_musiccaps_dataset(output_path='sample_metadata.csv', num_samples=100):
    """
    Prepare dataset based on MusicCaps or similar dataset structure.

    Note: This is a placeholder implementation. You would need to:
    1. Download the MusicCaps dataset or similar
    2. Create a metadata CSV with audio_file and text_description columns

    This writes a synthetic metadata CSV only; the referenced
    ``audio/sample_<i>.wav`` files are NOT created.

    Args:
        output_path: Where to write the CSV (default: ``sample_metadata.csv``
            in the current working directory, as before).
        num_samples: Number of rows to generate.

    Returns:
        str: ``output_path``.
    """
    # Row i cycles through these lists with the same index, so each mood is
    # always paired with the same instrument and genre.
    moods = ["happy", "sad", "energetic", "calm"]
    instruments = ["piano", "guitar", "drums", "violin"]
    genres = ['classical', 'rock', 'jazz', 'electronic']
    columns = ['audio_file', 'text_description', 'genre', 'tempo']

    sample_data = []
    for i in range(num_samples):
        k = i % len(moods)
        sample_data.append({
            'audio_file': f'audio/sample_{i}.wav',
            'text_description': f'A {moods[k]} music piece with {instruments[k]}',
            'genre': genres[k],
            'tempo': np.random.randint(60, 180)  # upper bound exclusive: 60..179 BPM
        })

    # Explicit columns keep the header even when num_samples == 0.
    metadata = pd.DataFrame(sample_data, columns=columns)
    metadata.to_csv(output_path, index=False)

    print("Created sample metadata file for demonstration")
    print("In a real implementation, you'd need to download the actual audio files")
    return output_path

def create_audio_synthesizer_dataset(text_processor, emotion_extractor, music_generator,
                                    output_dir, num_samples=1000):
    """
    Create a dataset of MIDI files and corresponding text descriptions
    for training the audio synthesis model.

    Five built-in example sentences are repeated (with light random word
    edits) to produce ``num_samples`` texts. For each text, the first emotion
    map is turned into musical features, a MusicLM-style prompt and a MIDI
    file. The MIDI is written to ``<output_dir>/midi/sample_XXXX.mid``; one row
    per written file is collected into ``<output_dir>/metadata.csv``.

    Args:
        text_processor: object whose ``process_text(text)`` returns a pair;
            the second element is used as the list of cleaned segments.
        emotion_extractor: object whose ``extract_emotions(segments)`` returns
            a list of mappings with keys 'joy', 'sadness', 'anger', 'fear',
            'surprise', 'disgust' and 'neutral'. NOTE(review): the
            ``EmotionExtractor`` class defined later in this notebook returns
            integer labels instead, so it does not fit here. Confirm which
            extractor is intended.
        music_generator: object exposing ``emotion_to_music_mapper`` and
            ``music_generator``:
            - ``emotion_to_music_mapper`` is callable on a (1, 7) tensor and
              provides ``map_to_actual_values`` and ``generate_musiclm_prompt``.
            - ``music_generator.create_midi_from_features(...)`` returns a
              buffer with ``getvalue()``.
        output_dir: directory to create and populate.
        num_samples: number of texts to process. Texts whose emotion
            extraction returns nothing are skipped, so the CSV may have fewer
            rows and the MIDI file numbering may have gaps.

    Returns:
        str: path of the written metadata CSV.
    """
    # Output layout: <output_dir>/midi/*.mid plus <output_dir>/metadata.csv
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'midi'), exist_ok=True)

    metadata = []

    # Generate example texts (for demonstration)
    example_texts = [
        "A bright sunny day with birds chirping in the trees",
        "The storm raged through the night, thunder shaking the windows",
        "She felt a deep sadness as she read the letter from her old friend",
        "The excitement of the race filled the stadium with energy",
        "A quiet moment of reflection by the peaceful lake"
    ]

    # Extend examples by repeating and modifying; enough passes are made to
    # reach num_samples, and the surplus is sliced off below.
    expanded_texts = []
    for _ in range(num_samples // len(example_texts) + 1):
        for text in example_texts:
            # Add some variation
            words = text.split()
            # NOTE(review): only texts longer than 10 words are varied. Of the
            # examples above only the third (14 words) qualifies, so the other
            # four are repeated verbatim.
            if len(words) > 10:
                # Randomly remove or duplicate one word (global NumPy RNG)
                if np.random.random() > 0.5:
                    remove_idx = np.random.randint(0, len(words))
                    words.pop(remove_idx)
                else:
                    dup_idx = np.random.randint(0, len(words))
                    words.insert(dup_idx, words[dup_idx])

            expanded_texts.append(" ".join(words))

    expanded_texts = expanded_texts[:num_samples]

    # Process each text; i is the position in expanded_texts and names the MIDI file
    for i, text in enumerate(tqdm(expanded_texts, desc="Generating dataset")):
        # Process text to extract emotions
        _, cleaned_segments = text_processor.process_text(text)
        emotion_maps = emotion_extractor.extract_emotions(cleaned_segments)

        # For simplicity, use only the first segment; texts with no emotion maps are skipped
        if emotion_maps:
            emotion_map = emotion_maps[0]

            # Convert to a (1, 7) emotion vector in this fixed key order
            emotion_vector = torch.tensor([
                emotion_map['joy'], emotion_map['sadness'], emotion_map['anger'],
                emotion_map['fear'], emotion_map['surprise'], emotion_map['disgust'],
                emotion_map['neutral']
            ], dtype=torch.float32).unsqueeze(0)

            # Map to musical features (inference only, no gradients)
            with torch.no_grad():
                features = music_generator.emotion_to_music_mapper(emotion_vector)

            # Convert to actual values (read below as 'tempo', 'key', 'mode', 'intensity')
            actual_values = music_generator.emotion_to_music_mapper.map_to_actual_values(features.squeeze(0))

            # Generate MusicLM prompt
            prompt = music_generator.emotion_to_music_mapper.generate_musiclm_prompt(
                actual_values, emotion_scores=emotion_map
            )

            # Generate MIDI (an in-memory buffer)
            midi_buffer = music_generator.music_generator.create_midi_from_features(actual_values)

            # Save MIDI
            midi_filename = f'sample_{i:04d}.mid'
            midi_path = os.path.join(output_dir, 'midi', midi_filename)

            with open(midi_path, 'wb') as f:
                f.write(midi_buffer.getvalue())

            # Add to metadata; midi_file is relative to output_dir
            metadata.append({
                'midi_file': f'midi/{midi_filename}',
                'text_description': prompt,
                'original_text': text,
                'joy': float(emotion_map['joy']),
                'sadness': float(emotion_map['sadness']),
                'anger': float(emotion_map['anger']),
                'fear': float(emotion_map['fear']),
                'surprise': float(emotion_map['surprise']),
                'disgust': float(emotion_map['disgust']),
                'neutral': float(emotion_map['neutral']),
                'tempo': float(actual_values['tempo']),
                'key': int(actual_values['key']),
                'mode': float(actual_values['mode']),
                'intensity': float(actual_values['intensity'])
            })

    # Save metadata
    metadata_df = pd.DataFrame(metadata)
    metadata_path = os.path.join(output_dir, 'metadata.csv')
    metadata_df.to_csv(metadata_path, index=False)

    print(f"Created dataset with {len(metadata)} examples")
    print(f"Metadata saved to {metadata_path}")

    return metadata_path

In [None]:
# 📍 SETUP & DEPENDENCIES
!pip install transformers torchaudio librosa datasets

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchaudio
import librosa
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F

# 📌 PART 1: MuLanEmbedding
class MuLanEmbedding(nn.Module):
    """Minimal stand-in for a MuLan text embedding: one square linear layer.

    Maps (..., input_dim) tensors to (..., input_dim).
    """

    def __init__(self, input_dim=256):
        super().__init__()
        self.fc = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        projected = self.fc(x)
        return projected

# 📌 PART 2: SoundStreamDecoder
class SoundStreamDecoder(nn.Module):
    """Minimal stand-in decoder: one linear layer from the embedding to 16000 samples."""

    def __init__(self, input_dim=256):
        super().__init__()
        self.decoder = nn.Linear(input_dim, 16000)

    def forward(self, x):
        waveform = self.decoder(x)
        return waveform

# 📌 PART 3: CustomMusicLM
class CustomMusicLM(nn.Module):
    """Lightweight text-to-waveform model.

    Pipeline:
    1. Token ids are embedded.
    2. They are mean-pooled over the positions where ``attention_mask`` is set.
    3. The pooled vector goes through ``MuLanEmbedding``.
    4. ``SoundStreamDecoder`` turns it into a waveform.

    The output therefore depends on the prompt. Feeding ``torch.randn`` noise
    instead, as before, made the output independent of the text and created
    the noise on the CPU regardless of the model's device.

    Args:
        vocab_size: Number of token ids. The default (30522) is the vocabulary
            size of ``bert-base-uncased``, the tokenizer paired with this model
            in this notebook.
    """

    # Width of the pooled text vector; must equal the default input_dim of
    # MuLanEmbedding and SoundStreamDecoder (256).
    _TEXT_DIM = 256

    def __init__(self, vocab_size=30522):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, self._TEXT_DIM)
        self.embedding = MuLanEmbedding()
        self.decoder = SoundStreamDecoder()

    def forward(self, input_ids, attention_mask):
        # assumes (batch, seq) ids/mask as produced by the tokenizer calls in
        # this notebook — TODO confirm for other callers
        token_vectors = self.token_embedding(input_ids)                 # (batch, seq, 256)
        mask = attention_mask.unsqueeze(-1).to(token_vectors.dtype)     # (batch, seq, 1)
        # Masked mean over real tokens; clamp avoids 0/0 for an all-padding row.
        pooled = (token_vectors * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        embedded = self.embedding(pooled)
        audio = self.decoder(embedded)
        return audio

# 📌 PART 5: EmotionExtractor
class EmotionExtractor:
    """Classify text segments with the pretrained j-hartmann DistilRoBERTa emotion model.

    ``extract_emotions`` returns, for each segment, the index of the
    highest-probability class as a Python int.
    """

    def __init__(self):
        checkpoint = "j-hartmann/emotion-english-distilroberta-base"
        self.model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.model.eval()

    def extract_emotions(self, text_segments):
        return [self._top_label(segment) for segment in text_segments]

    def _top_label(self, segment):
        """Return the argmax class index for one segment (input truncated to the model's limit)."""
        encoded = self.tokenizer(segment, return_tensors="pt", truncation=True)
        with torch.no_grad():
            logits = self.model(**encoded).logits
        probabilities = F.softmax(logits, dim=-1)
        return torch.argmax(probabilities, dim=1).item()

# 📌 PART 6: TextProcessor
class TextProcessor:
    """Split text into whitespace-delimited chunks of at most ``segment_length`` words."""

    def split_text(self, text, segment_length=50):
        tokens = text.split()
        segments = []
        for start in range(0, len(tokens), segment_length):
            segments.append(" ".join(tokens[start:start + segment_length]))
        return segments

# 📌 PART 7: EmotionToPromptMapper
class EmotionToPromptMapper:
    """Translate emotion-classifier outputs into short musical style prompts.

    Integer ids follow the label order of the
    ``j-hartmann/emotion-english-distilroberta-base`` classifier used by
    ``EmotionExtractor``: 0 anger, 1 disgust, 2 fear, 3 joy, 4 neutral,
    5 sadness, 6 surprise. The previous table assumed a different order (e.g.
    it mapped 0, i.e. anger, to "sad and slow"). Label names such as "joy" are
    accepted too (case-insensitive). Anything unrecognised maps to "neutral".
    """

    _ID_TO_LABEL = ("anger", "disgust", "fear", "joy", "neutral", "sadness", "surprise")
    _LABEL_TO_PROMPT = {
        "anger": "angry and intense",
        "disgust": "dark and dissonant",
        "fear": "scared and tense",
        "joy": "joyful and bright",
        "neutral": "calm and peaceful",
        "sadness": "sad and slow",
        "surprise": "surprised and whimsical",
    }

    def map_emotions_to_prompts(self, emotions):
        """Return one prompt string per entry of ``emotions`` (ids or label names)."""
        return [self._prompt_for(e) for e in emotions]

    def _prompt_for(self, emotion):
        """Prompt for a single class id or label name; "neutral" when unknown."""
        import numbers

        if isinstance(emotion, str):
            label = emotion.lower()
        elif isinstance(emotion, numbers.Integral) and 0 <= emotion < len(self._ID_TO_LABEL):
            label = self._ID_TO_LABEL[int(emotion)]
        else:
            label = None
        return self._LABEL_TO_PROMPT.get(label, "neutral")

# 📌 PART 10: DummyMusicDataset
class DummyMusicDataset(Dataset):
    """In-memory dataset pairing text prompts with pre-made target waveforms.

    Each item is ``(input_ids, attention_mask, waveform)``. The prompt is
    tokenized on access, padded/truncated to 128 tokens, and all size-1
    dimensions are squeezed out of the token tensors.
    """

    def __init__(self, tokenizer, prompts, waveforms):
        self.tokenizer = tokenizer
        self.prompts = prompts
        self.waveforms = waveforms

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        encoded = self.tokenizer(
            self.prompts[idx],
            padding='max_length',
            truncation=True,
            max_length=128,
            return_tensors='pt',
        )
        ids = encoded['input_ids'].squeeze()
        mask = encoded['attention_mask'].squeeze()
        return ids, mask, self.waveforms[idx]

# 📌 PART 9: Sample Data
prompts = ["A happy forest adventure", "A tragic night alone"]
# Each target must hold as many samples as CustomMusicLM emits per prompt.
# SoundStreamDecoder in this cell is nn.Linear(256, 16000), so targets are
# 16000 samples long. A (1, 256) target cannot be reshaped to the
# (batch, 16000) model output, and training would crash.
SAMPLES_PER_CLIP = 16000
dummy_waveforms = [torch.randn(1, SAMPLES_PER_CLIP) for _ in range(len(prompts))]
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
train_dataset = DummyMusicDataset(tokenizer, prompts, dummy_waveforms)
train_loader = DataLoader(train_dataset, batch_size=2)

# 📌 PART 8: Training Function
def train_custom_musiclm(model, dataloader, num_epochs=3):
    """Train ``model`` to regress target waveforms from tokenized prompts (MSE, Adam lr=1e-4).

    Args:
        model: Module called as ``model(input_ids, attention_mask)``.
        dataloader: Iterable of ``(input_ids, attention_mask, waveforms)``
            batches. ``waveforms`` must contain exactly as many values as the
            model output; it is reshaped to the output's shape.
        num_epochs: Number of passes over ``dataloader``.

    Returns:
        list[float]: Mean loss per epoch (NaN for an epoch with no batches).

    Raises:
        ValueError: If a target batch's element count differs from the model output's.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.MSELoss()
    device = next(model.parameters()).device

    history = []
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        n_batches = 0
        for input_ids, attention_mask, waveforms in dataloader:
            optimizer.zero_grad()
            outputs = model(input_ids.to(device), attention_mask.to(device))
            if waveforms.numel() != outputs.numel():
                raise ValueError(
                    f"target batch has {waveforms.numel()} values but the model produced "
                    f"{outputs.numel()} (output shape {tuple(outputs.shape)}); make the "
                    "training waveforms match the decoder's output length"
                )
            targets = waveforms.to(outputs.device, outputs.dtype).reshape(outputs.shape)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1
        # Report the epoch mean (not just the last batch); an empty loader
        # yields NaN instead of an unbound `loss`.
        avg_loss = running_loss / n_batches if n_batches else float("nan")
        history.append(avg_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")
    return history

# 📌 PART 4: Main Generator Class
class LiteraryMusicGenerator:
    """End-to-end pipeline: text -> segments -> emotions -> prompts -> audio."""

    def __init__(self, custom_musiclm=None):
        """Build the pipeline components.

        Args:
            custom_musiclm: an already-built (e.g. trained) model to generate with.
                When ``None`` (the default) a fresh ``CustomMusicLM()`` is created.
                Passing one avoids constructing a model that is then thrown away.
        """
        self.text_processor = TextProcessor()
        self.emotion_extractor = EmotionExtractor()
        self.mapper = EmotionToPromptMapper()
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        # Compare with None explicitly: some nn.Module containers define __len__
        # and could be falsy when empty.
        self.custom_musiclm = CustomMusicLM() if custom_musiclm is None else custom_musiclm

    def generate_music(self, text):
        """Generate one audio waveform per prompt derived from ``text``.

        Runs ``text_processor.split_text``, then ``emotion_extractor.extract_emotions``,
        then ``mapper.map_emotions_to_prompts``, then ``use_custom_musiclm``. The
        result is also stored on ``self.audio_waveforms``.

        NOTE(review): ``map_emotions_to_prompts`` looks up integer class ids. The
        ``TextProcessor``/``EmotionExtractor`` classes defined in a later cell of
        this notebook have no ``split_text`` and return per-segment score dicts.
        Confirm which definitions are live when this class is instantiated.
        """
        segments = self.text_processor.split_text(text)
        emotions = self.emotion_extractor.extract_emotions(segments)
        text_prompts = self.mapper.map_emotions_to_prompts(emotions)
        self.audio_waveforms = use_custom_musiclm(self, text_prompts)
        return self.audio_waveforms

# 📌 PART 11: Generation Wrapper
def use_custom_musiclm(literary_music_generator, text_prompts):
    """Generate one waveform per text prompt with the generator's CustomMusicLM.

    The model is switched to eval mode for generation, and its previous train/eval
    mode is restored afterwards. ``train_custom_musiclm`` leaves it in train mode,
    which would keep dropout active at inference.

    Args:
        literary_music_generator: object exposing ``custom_musiclm`` (a torch
            module) and ``tokenizer`` (a Hugging Face-style callable tokenizer).
        text_prompts: iterable of prompt strings.

    Returns:
        List of numpy arrays, one per prompt (model output squeezed, on CPU).
    """
    custom_musiclm = literary_music_generator.custom_musiclm
    tokenizer = literary_music_generator.tokenizer

    # Feed inputs on the same device as the model's weights (CPU if it has none).
    first_param = next(custom_musiclm.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = custom_musiclm.training
    custom_musiclm.eval()
    audio_outputs = []
    try:
        with torch.no_grad():
            for prompt in text_prompts:
                encoded_input = tokenizer(prompt, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
                audio_output = custom_musiclm(
                    encoded_input['input_ids'].to(device),
                    encoded_input['attention_mask'].to(device),
                )
                audio_outputs.append(audio_output.squeeze().cpu().numpy())
    finally:
        custom_musiclm.train(was_training)
    return audio_outputs

# ✅ TRAIN AND GENERATE
# Train a fresh CustomMusicLM on the dummy dataset built above (train_loader).
model = CustomMusicLM()
train_custom_musiclm(model, train_loader, num_epochs=3)
generator = LiteraryMusicGenerator()
# Swap in the trained model, replacing the one LiteraryMusicGenerator's constructor created.
generator.custom_musiclm = model

# Run the text -> emotion -> prompt -> audio pipeline on a short passage.
sample_text = "The sun rose gently, casting golden hues over the quiet meadow. A feeling of joy filled the air."
audio_waveforms = generator.generate_music(sample_text)

# generate_music returns one waveform per generated prompt.
print("Generated", len(audio_waveforms), "audio segments 🎵")


Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.6.0->torchaudio)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.6.0->torchaudio)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch==2.6.0->torchaudio)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.6.0->torchaudio)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch==2.6.0->torchaudio)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch==2.6.0->to

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

RuntimeError: shape '[2, 16000]' is invalid for input of size 512

In [None]:
!pip install midiutil

Collecting midiutil
  Downloading MIDIUtil-1.2.1.tar.gz (1.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.0/1.0 MB 12.3 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: midiutil
  Building wheel for midiutil (setup.py) ... done
  Created wheel for midiutil: filename=MIDIUtil-1.2.1-py3-none-any.whl size=54569 sha256=56f0e829aaef33b2ef2c28122f99f805b9f240a3533b0e03b1e6006c67e34497
  Stored in directory: /root/.cache/pip/wheels/6c/42/75/fce10c67f06fe627fad8acd1fd3a004a24e07b0f077761fbbd
Successf

In [None]:
!pip install pretty_midi

Collecting pretty_midi
  Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.6/5.6 MB 18.4 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Collecting mido>=1.1.16 (from pretty_midi)
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 54.6/54.6 kB 4.5 MB/s eta 0:00:00
Building wheels for collected packages: pretty_midi
  Building wheel for pretty_midi (setup.py) ... done
  Created wheel for pretty_midi: filename=pretty_midi-0.2.10-py3-none-any.whl size=5592286 sha256=2fe33c2bcbfa712e0f1c0852b43378bac76a2dc1d8ac9d30dfd44ce99095e665
  Stored in directory: /root/.cache/pip/wheels/e6/95/ac/15ceaeb2823b04d8e638fd1495357adb8d26c00ccac9d7782e
Successfully built pretty_midi
Installing collected packages: mido, pretty_midi
Successf

In [None]:
!pip install pydub

Collecting pydub
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Downloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)
Installing collected packages: pydub
Successfully installed pydub-0.25.1


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import nltk
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from sklearn.preprocessing import MinMaxScaler
import json
import os
from google.colab import files
from midiutil import MIDIFile
import pretty_midi
import IPython.display as ipd
import random
import librosa
from pydub import AudioSegment
from pydub.playback import play
import io
import base64
import tempfile

# Download the NLTK resources used below: tokenizer models, stopword list, WordNet.
# Recent NLTK releases (3.9+) load the Punkt tokenizer used by word_tokenize /
# sent_tokenize from 'punkt_tab' rather than 'punkt', so both are requested.
# Fetching the one an installed version does not use is harmless.
NLTK_RESOURCES = ['punkt', 'punkt_tab', 'stopwords', 'wordnet']
for _nltk_resource in NLTK_RESOURCES:
    nltk.download(_nltk_resource)

# Verify the tokenizer actually works before proceeding.
try:
    word_tokenize("Testing NLTK")
    print("NLTK resources downloaded successfully!")
except LookupError:
    print("Downloading additional NLTK resources...")
    # Retry through the Python API. This works in any Python session, not only in
    # IPython, where `!python -m nltk.downloader ...` would be required.
    for _nltk_resource in NLTK_RESOURCES:
        nltk.download(_nltk_resource)

class TextProcessor:
    """Segment raw text and normalize it (lowercase, strip punctuation, drop stopwords, lemmatize)."""

    def __init__(self):
        # English stopword list; stopwords.words raises LookupError if the corpus is missing.
        try:
            self.stop_words = set(stopwords.words('english'))
        except LookupError:
            nltk.download('stopwords')
            self.stop_words = set(stopwords.words('english'))

        # WordNetLemmatizer loads WordNet lazily, so constructing it never raises.
        # Lemmatize one word now so a missing corpus surfaces (and is downloaded)
        # here rather than later inside clean_text().
        self.lemmatizer = WordNetLemmatizer()
        try:
            self.lemmatizer.lemmatize('test')
        except LookupError:
            nltk.download('wordnet')

    def clean_text(self, text):
        """Remove punctuation, lowercase, remove stopwords, and lemmatize.

        Returns the remaining lemmatized words joined by single spaces.
        """
        # Lowercase the text
        text = text.lower()
        # Remove punctuation (anything that is not a word character or whitespace)
        text = re.sub(r'[^\w\s]', '', text)

        # Tokenize into words, falling back to whitespace splitting if the
        # NLTK tokenizer data is unavailable.
        try:
            words = word_tokenize(text)
        except LookupError:
            words = text.split()

        # Remove stopwords and lemmatize
        cleaned_words = [self.lemmatizer.lemmatize(word) for word in words if word not in self.stop_words]
        return ' '.join(cleaned_words)

    def segment_text(self, text, segment_size=500):
        """Group whole sentences into segments of roughly ``segment_size`` characters.

        Sentences are never split; a single sentence longer than ``segment_size``
        becomes its own segment. Lengths ignore the joining spaces.
        """
        try:
            sentences = sent_tokenize(text)
        except LookupError:
            # Fallback: simple period-based sentence splitting
            sentences = [s.strip() + '.' for s in text.split('.') if s.strip()]

        segments = []
        current_segment = []
        current_length = 0

        for sentence in sentences:
            sentence_length = len(sentence)
            # Start a new segment when adding this sentence would overflow a non-empty one.
            if current_length + sentence_length > segment_size and current_segment:
                segments.append(' '.join(current_segment))
                current_segment = [sentence]
                current_length = sentence_length
            else:
                current_segment.append(sentence)
                current_length += sentence_length

        # Add the last segment if it exists
        if current_segment:
            segments.append(' '.join(current_segment))

        return segments

    def process_text(self, text, segment_size=500):
        """Segment ``text``, then clean each segment.

        Returns:
            ``(segments, cleaned_segments)``: the original segments (for display)
            and their cleaned counterparts, index-aligned.
        """
        segments = self.segment_text(text, segment_size)
        cleaned_segments = [self.clean_text(segment) for segment in segments]
        return segments, cleaned_segments

class EmotionExtractor:
    """Score text segments for seven emotions with a pretrained text classifier."""

    def __init__(self):
        # Load the pre-trained emotion classification pipeline.
        try:
            # return_all_scores=True requests a score for every label; recent
            # transformers versions accept it with a deprecation warning.
            self.emotion_classifier = pipeline(
                "text-classification",
                model="j-hartmann/emotion-english-distilroberta-base",
                return_all_scores=True
            )
        except Exception:
            # If this transformers version rejects return_all_scores, fall back to
            # the equivalent top_k=None. Catching Exception (not a bare except)
            # keeps KeyboardInterrupt/SystemExit working.
            self.emotion_classifier = pipeline(
                "text-classification",
                model="j-hartmann/emotion-english-distilroberta-base",
                top_k=None
            )

        # Emotion labels reported for every segment, in this order.
        self.emotion_categories = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'disgust', 'neutral']

    def extract_emotions(self, text_segments):
        """Return one ``{emotion: score}`` dict per segment.

        Every dict has exactly the keys of ``self.emotion_categories`` (in that
        order). Labels the classifier does not report get a score of 0.

        Args:
            text_segments: iterable of strings.
        """
        emotion_maps = []
        for segment in text_segments:
            # truncation=True stops segments longer than the model's maximum input
            # length from raising inside the pipeline. [0] selects the list of
            # {'label', 'score'} entries for this single input.
            emotion_scores = self.emotion_classifier(segment, truncation=True)[0]
            emotion_dict = {item['label']: item['score'] for item in emotion_scores}
            emotion_maps.append(
                {emotion: emotion_dict.get(emotion, 0) for emotion in self.emotion_categories}
            )
        return emotion_maps

    def create_emotional_progression(self, emotion_maps):
        """Transpose per-segment score dicts into ``{emotion: [score for each segment]}``."""
        return {
            emotion: [emotion_map[emotion] for emotion_map in emotion_maps]
            for emotion in self.emotion_categories
        }

    def get_dominant_emotions(self, emotion_maps, top_n=2):
        """Return, for each segment, its ``top_n`` ``(emotion, score)`` pairs, highest first."""
        return [
            sorted(emotion_map.items(), key=lambda x: x[1], reverse=True)[:top_n]
            for emotion_map in emotion_maps
        ]
class EmotionToMusicMapper(nn.Module):
    """Map emotion score vectors to musical features, and features to text prompts.

    ``forward`` is a small MLP whose sigmoid outputs are normalized values for the
    features listed in ``musical_features`` (in that order). The other methods are
    rule-based helpers that scale those values and turn them, optionally together
    with raw emotion scores, into MusicLM-style prompts.
    """

    # Key names indexed by pitch class (C=0 ... B=11).
    KEY_NAMES = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
    # Instrument groups indexed by the rounded 'instrumentation' feature (0-5).
    INSTRUMENT_TYPES = ['piano', 'strings', 'guitar', 'synth', 'orchestral', 'percussion']

    def __init__(self, input_dim=7, hidden_dim=64, output_dim=10):
        super(EmotionToMusicMapper, self).__init__()

        # Neural network for mapping emotions to musical features
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid()  # Normalized output in (0, 1)
        )

        # Define music feature mappings
        self.musical_features = {
            'tempo': {'min': 60, 'max': 180},     # BPM
            'key': {'min': 0, 'max': 11},         # C=0, C#=1, ..., B=11
            'mode': {'min': 0, 'max': 1},         # Minor=0, Major=1
            'intensity': {'min': 0, 'max': 1},    # Soft to loud
            'instrumentation': {'min': 0, 'max': 5},  # Different instrument groups
            'rhythm_complexity': {'min': 0, 'max': 1},
            'harmonic_complexity': {'min': 0, 'max': 1},
            'melodic_range': {'min': 0, 'max': 1},
            'texture': {'min': 0, 'max': 1},      # Sparse to dense
            'articulation': {'min': 0, 'max': 1}  # Staccato to legato
        }

        # Map features to index in output
        self.feature_to_idx = {feature: i for i, feature in enumerate(self.musical_features.keys())}

        # Per-emotion musical tendencies; 'instrumentation' and 'description' are
        # sampled by get_descriptors_for_emotion_blend.
        self.emotion_correlations = {
            'joy': {
                'tempo': 'high',
                'mode': 'major',
                'intensity': 'moderate-high',
                'rhythm_complexity': 'moderate',
                'harmonic_complexity': 'moderate',
                'instrumentation': ['piano', 'strings', 'synth'],
                'description': ['uplifting', 'bright', 'cheerful', 'buoyant', 'exuberant', 'playful', 'optimistic']
            },
            'sadness': {
                'tempo': 'low',
                'mode': 'minor',
                'intensity': 'low',
                'rhythm_complexity': 'low',
                'harmonic_complexity': 'moderate-high',
                'instrumentation': ['piano', 'strings', 'guitar'],
                'description': ['melancholic', 'somber', 'wistful', 'contemplative', 'haunting', 'bittersweet', 'mournful']
            },
            'anger': {
                'tempo': 'high',
                'mode': 'minor',
                'intensity': 'high',
                'rhythm_complexity': 'high',
                'harmonic_complexity': 'high',
                'instrumentation': ['synth', 'percussion', 'orchestral'],
                'description': ['intense', 'aggressive', 'powerful', 'driving', 'dissonant', 'chaotic', 'turbulent']
            },
            'fear': {
                'tempo': 'variable',
                'mode': 'minor',
                'intensity': 'variable',
                'rhythm_complexity': 'low',
                'harmonic_complexity': 'high',
                'instrumentation': ['strings', 'synth', 'percussion'],
                'description': ['tense', 'suspenseful', 'eerie', 'unsettling', 'mysterious', 'foreboding', 'chilling']
            },
            'surprise': {
                'tempo': 'variable',
                'mode': 'variable',
                'intensity': 'variable',
                'rhythm_complexity': 'high',
                'harmonic_complexity': 'moderate',
                'instrumentation': ['piano', 'synth', 'orchestral'],
                'description': ['unexpected', 'quirky', 'sudden', 'playful', 'whimsical', 'unpredictable', 'startling']
            },
            'disgust': {
                'tempo': 'low-moderate',
                'mode': 'minor',
                'intensity': 'moderate',
                'rhythm_complexity': 'moderate',
                'harmonic_complexity': 'high',
                'instrumentation': ['synth', 'percussion', 'orchestral'],
                'description': ['dissonant', 'unsettling', 'gritty', 'uncomfortable', 'jarring', 'off-kilter', 'distorted']
            },
            'neutral': {
                'tempo': 'moderate',
                'mode': 'variable',
                'intensity': 'moderate',
                'rhythm_complexity': 'moderate',
                'harmonic_complexity': 'moderate',
                'instrumentation': ['piano', 'strings', 'guitar'],
                'description': ['balanced', 'ambient', 'atmospheric', 'calm', 'steady', 'peaceful', 'flowing']
            }
        }

    def forward(self, emotion_vector):
        """Map an emotion vector (last dim ``input_dim``) to normalized musical features."""
        return self.network(emotion_vector)

    def map_to_actual_values(self, normalized_features):
        """Scale a 1-D tensor of normalized features to each feature's [min, max] range.

        'key' and 'instrumentation' are rounded to ints; the others stay floats.
        """
        actual_values = {}

        for feature, feature_range in self.musical_features.items():
            norm_value = normalized_features[self.feature_to_idx[feature]].item()

            min_val = feature_range['min']
            max_val = feature_range['max']
            actual_value = min_val + norm_value * (max_val - min_val)

            # Discrete features become indices.
            if feature in ['key', 'instrumentation']:
                actual_value = round(actual_value)

            actual_values[feature] = actual_value

        return actual_values

    def get_descriptors_for_emotion_blend(self, emotion_scores):
        """Sample descriptors and instruments for the two highest-scoring emotions.

        Emotions scoring below 0.1 are ignored. Each kept emotion contributes
        between 1 and 3 descriptors and between 1 and 2 instruments, in proportion
        to its share of the top-two total.

        Returns:
            ``(descriptors, instruments)``: de-duplicated lists in first-sampled order.
        """
        sorted_emotions = sorted(emotion_scores.items(), key=lambda x: x[1], reverse=True)
        top_emotions = sorted_emotions[:2]

        descriptors = []
        instruments = []

        total_weight = sum(score for _, score in top_emotions)

        for emotion, score in top_emotions:
            # Skip emotions with very low scores (this also avoids dividing by a zero total).
            if score < 0.1:
                continue

            weight = score / total_weight

            emotion_descriptors = self.emotion_correlations[emotion]['description']
            num_descriptors = max(1, int(weight * 3))  # At least 1, up to 3 descriptors
            descriptors.extend(random.sample(emotion_descriptors, min(num_descriptors, len(emotion_descriptors))))

            emotion_instruments = self.emotion_correlations[emotion]['instrumentation']
            num_instruments = max(1, int(weight * 2))  # 1-2 instruments
            instruments.extend(random.sample(emotion_instruments, min(num_instruments, len(emotion_instruments))))

        # dict.fromkeys de-duplicates while keeping first-seen order. The result
        # (and any prompt built from it) is therefore reproducible under a fixed
        # random seed; set() iteration order of strings varies between runs.
        return list(dict.fromkeys(descriptors)), list(dict.fromkeys(instruments))

    def generate_musiclm_prompt(self, musical_features, emotion_scores=None):
        """Convert scaled musical features (see map_to_actual_values) into a text prompt.

        When ``emotion_scores`` is given, emotion-derived descriptors and instruments
        are woven into the prompt. Otherwise a feature-only description is built.
        """
        key_name = self.KEY_NAMES[round(musical_features['key'])]
        mode_name = "major" if musical_features['mode'] > 0.5 else "minor"
        instrument = self.INSTRUMENT_TYPES[round(musical_features['instrumentation'])]

        # Determine tempo description
        tempo = musical_features['tempo']
        if tempo < 80:
            tempo_desc = "slow"
        elif tempo < 120:
            tempo_desc = "moderate"
        else:
            tempo_desc = "fast"

        # Determine intensity description
        intensity = musical_features['intensity']
        if intensity < 0.3:
            intensity_desc = "soft"
        elif intensity < 0.7:
            intensity_desc = "moderate"
        else:
            intensity_desc = "powerful"

        if emotion_scores:
            emotional_descriptors, suggested_instruments = self.get_descriptors_for_emotion_blend(emotion_scores)

            if suggested_instruments:
                # Lead with the feature-mapped instrument, then add up to one suggestion.
                all_instruments = [instrument] + [i for i in suggested_instruments if i != instrument]
                instrument_phrase = " and ".join(all_instruments[:2])
            else:
                instrument_phrase = instrument

            # All emotions may be too weak to yield descriptors; avoid "A  piece".
            if emotional_descriptors:
                prompt = f"A {' and '.join(emotional_descriptors[:2])} piece in {key_name} {mode_name}, "
            else:
                prompt = f"A piece in {key_name} {mode_name}, "
            prompt += f"{tempo_desc} tempo, {intensity_desc} in intensity, "
            prompt += f"featuring {instrument_phrase}, "

        else:
            prompt = f"A {intensity_desc} {tempo_desc} melody in {key_name} {mode_name}, "

            # Add texture
            texture = musical_features['texture']
            if texture < 0.3:
                prompt += "with a sparse arrangement, "
            elif texture > 0.7:
                prompt += "with a dense, layered arrangement, "

            prompt += f"featuring {instrument}, "

        # Add complexity
        rhythm_complexity = musical_features['rhythm_complexity']
        harmonic_complexity = musical_features['harmonic_complexity']

        if rhythm_complexity > 0.7 and harmonic_complexity > 0.7:
            prompt += "with complex rhythms and harmonies, "
        elif rhythm_complexity > 0.7:
            prompt += "with complex rhythms, "
        elif harmonic_complexity > 0.7:
            prompt += "with rich harmonies, "
        elif rhythm_complexity < 0.3 and harmonic_complexity < 0.3:
            prompt += "with simple, straightforward patterns, "

        # Add articulation
        articulation = musical_features['articulation']
        if articulation < 0.3:
            prompt += "played with staccato articulation."
        elif articulation > 0.7:
            prompt += "played with smooth, legato phrasing."
        else:
            prompt += "with balanced articulation."

        # Name the dominant emotion only when it is significant.
        if emotion_scores:
            top_emotion = max(emotion_scores.items(), key=lambda x: x[1])
            if top_emotion[1] > 0.3:
                prompt += f" The music conveys a sense of {top_emotion[0]}."

        return prompt

    def generate_transition_prompt(self, prev_features, current_features, prev_emotions, current_emotions):
        """Generate a prompt describing the transition between two segments.

        A change of dominant emotion (both above 0.3) is named directly. Otherwise
        significant tempo/intensity/mode changes are listed, or "gradual transition"
        is used when there are none.
        """
        prev_top_emotion = max(prev_emotions.items(), key=lambda x: x[1])
        current_top_emotion = max(current_emotions.items(), key=lambda x: x[1])

        if prev_top_emotion[0] != current_top_emotion[0] and prev_top_emotion[1] > 0.3 and current_top_emotion[1] > 0.3:
            transition_type = f"transition from {prev_top_emotion[0]} to {current_top_emotion[0]}"
        else:
            changes = []

            prev_tempo = prev_features['tempo']
            current_tempo = current_features['tempo']
            if abs(prev_tempo - current_tempo) > 20:  # Significant tempo change (BPM)
                changes.append("accelerating" if current_tempo > prev_tempo else "decelerating")

            prev_intensity = prev_features['intensity']
            current_intensity = current_features['intensity']
            if abs(prev_intensity - current_intensity) > 0.3:  # Significant intensity change
                changes.append("building in intensity" if current_intensity > prev_intensity else "becoming more subdued")

            prev_mode = "major" if prev_features['mode'] > 0.5 else "minor"
            current_mode = "major" if current_features['mode'] > 0.5 else "minor"
            if prev_mode != current_mode:
                changes.append(f"shifting from {prev_mode} to {current_mode}")

            transition_type = ", ".join(changes) if changes else "gradual transition"

        prev_descriptors, _ = self.get_descriptors_for_emotion_blend(prev_emotions)
        current_descriptors, _ = self.get_descriptors_for_emotion_blend(current_emotions)

        prompt = f"A {transition_type} that maintains musical coherence"
        if prev_descriptors and current_descriptors:
            prompt += f" while evolving from {prev_descriptors[0]} to {current_descriptors[0]}. "
        else:
            # Without descriptors, end the sentence instead of leaving a dangling "while".
            prompt += ". "

        # Add instrumentation continuity
        prev_instrument = self.INSTRUMENT_TYPES[round(prev_features['instrumentation'])]
        current_instrument = self.INSTRUMENT_TYPES[round(current_features['instrumentation'])]

        if prev_instrument == current_instrument:
            prompt += f"Featuring {prev_instrument} throughout. "
        else:
            prompt += f"Transitioning from {prev_instrument} to {current_instrument}. "

        return prompt
class TemporalCoherenceModel(nn.Module):
    """Smooth a sequence of feature vectors with a stacked LSTM plus a sigmoid head.

    Args:
        input_dim: features per time step in the input.
        hidden_dim: LSTM hidden size.
        output_dim: features per time step in the output (each in (0, 1)).
        num_layers: number of stacked LSTM layers (default 2).
        dropout: dropout applied between LSTM layers (default 0.3). It is ignored
            when ``num_layers == 1``, where PyTorch's LSTM would only warn about it.
    """

    def __init__(self, input_dim=10, hidden_dim=64, output_dim=10, num_layers=2, dropout=0.3):
        super(TemporalCoherenceModel, self).__init__()

        # LSTM for sequence modeling (inputs are batch-first).
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0
        )

        # Per-time-step projection back to normalized feature values.
        self.output_layer = nn.Sequential(
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid()
        )

    def forward(self, sequence):
        """Map ``[batch, seq_len, input_dim]`` to ``[batch, seq_len, output_dim]``."""
        lstm_out, _ = self.lstm(sequence)
        return self.output_layer(lstm_out)

class MusicGenerator:
    """Render scaled musical features into a two-track (melody + accompaniment) MIDI file."""

    def __init__(self):
        # General MIDI program numbers per instrument group
        self.instruments = {
            'piano': 0,       # Acoustic Grand Piano
            'strings': 48,    # String Ensemble 1
            'guitar': 24,     # Acoustic Guitar (nylon)
            'synth': 80,      # Lead 1 (square)
            'orchestral': 48, # String Ensemble 1
            'percussion': 118 # Synth Drum
        }

        # Scales as semitone offsets from the tonic
        self.major_scale = [0, 2, 4, 5, 7, 9, 11]  # Whole, Whole, Half, Whole, Whole, Whole, Half
        self.minor_scale = [0, 2, 3, 5, 7, 8, 10]  # Whole, Half, Whole, Whole, Half, Whole, Whole

        # 0-indexed scale degrees whose diatonic triad is major:
        # major key I, IV, V; natural minor III, VI, VII. Other degrees get a
        # minor third (diminished ii°/vii° are approximated as minor triads).
        self.major_triad_degrees = {
            'major': (0, 3, 4),
            'minor': (2, 5, 6),
        }

        # Common chord progressions (1-indexed scale degrees)
        self.progressions = {
            'major': [
                [1, 4, 5, 1],       # I-IV-V-I
                [1, 6, 4, 5],       # I-vi-IV-V
                [1, 5, 6, 4],       # I-V-vi-IV
                [2, 5, 1, 6]        # ii-V-I-vi
            ],
            'minor': [
                [1, 4, 5, 1],       # i-iv-v-i
                [1, 6, 3, 7],       # i-VI-III-VII
                [1, 7, 6, 5],       # i-VII-VI-v
                [1, 4, 7, 3]        # i-iv-VII-III
            ]
        }

    def create_midi_from_features(self, musical_features, duration_seconds=15, seed=None):
        """Generate a MIDI file from scaled musical features.

        Args:
            musical_features: dict with 'tempo' (BPM), 'key' (0-11), 'mode'
                (> 0.5 means major), 'instrumentation' (0-5 index), and
                'rhythm_complexity', 'harmonic_complexity', 'melodic_range',
                'intensity', 'texture', 'articulation' (each 0-1).
            duration_seconds: target length. The piece is laid out in whole chords
                of at least 4 beats, so the actual length is rounded accordingly.
            seed: optional seed for a private ``random.Random``. ``None`` (default)
                draws from the global ``random`` module.

        Returns:
            io.BytesIO positioned at 0 containing the MIDI data.
        """
        rng = random if seed is None else random.Random(seed)

        tempo = musical_features['tempo']
        key = int(musical_features['key'])
        mode = 'major' if musical_features['mode'] > 0.5 else 'minor'
        instrument_type = ['piano', 'strings', 'guitar', 'synth', 'orchestral', 'percussion'][int(musical_features['instrumentation'])]
        instrument = self.instruments[instrument_type]

        rhythm_complexity = musical_features['rhythm_complexity']
        harmonic_complexity = musical_features['harmonic_complexity']
        melodic_range = musical_features['melodic_range']
        intensity = musical_features['intensity']
        texture = musical_features['texture']
        articulation = musical_features['articulation']

        midi = MIDIFile(2)  # 2 tracks - one for melody, one for accompaniment
        track_melody = 0
        track_accomp = 1

        midi.addTempo(track_melody, 0, tempo)
        midi.addTempo(track_accomp, 0, tempo)

        scale = self.major_scale if mode == 'major' else self.minor_scale
        major_degrees = self.major_triad_degrees[mode]

        total_beats = int((tempo / 60) * duration_seconds)

        midi.addProgramChange(track_melody, 0, 0, instrument)
        midi.addProgramChange(track_accomp, 0, 0, instrument)

        # Higher harmonic complexity selects a later progression in the list.
        progressions = self.progressions[mode]
        prog_idx = min(int(harmonic_complexity * len(progressions)), len(progressions) - 1)
        progression = progressions[prog_idx]

        # Chord duration in beats
        chord_duration = max(4, total_beats // len(progression))
        repetitions = max(1, total_beats // (chord_duration * len(progression)))

        base_octave = 5 if instrument_type != 'piano' else 4

        # Melody range: from C of the base octave up to one or two octaves higher.
        low_note = base_octave * 12
        high_note = low_note + int(12 + 12 * melodic_range)

        melody_velocity = 64 + int(intensity * 63)  # Between 64-127
        chord_velocity = int(melody_velocity * 0.8)  # Slightly quieter

        # Loop-invariant melody rhythm: 1-4 notes per beat, shortened by articulation
        # from 0.5 (staccato) to 1.0 (legato) of the note slot.
        duration_modifier = 0.5 + (articulation * 0.5)
        notes_per_beat = 1 + int(rhythm_complexity * 3)
        note_duration = (1.0 / notes_per_beat) * duration_modifier

        current_beat = 0
        for _ in range(repetitions):
            for chord_root in progression:
                chord_root_idx = chord_root - 1  # Adjust for 0-indexing
                root_note = (key + scale[chord_root_idx]) % 12

                # Diatonic triad quality for this degree in the current mode.
                third_offset = 4 if chord_root_idx in major_degrees else 3
                chord_notes = [
                    root_note + 60,  # Root note (C4 = 60, middle C)
                    root_note + 60 + third_offset,  # Third
                    root_note + 60 + 7   # Fifth
                ]

                if texture < 0.3:
                    # Sparse - just root and fifth
                    midi.addNote(track_accomp, 0, chord_notes[0], current_beat, chord_duration * 0.9, chord_velocity)
                    midi.addNote(track_accomp, 0, chord_notes[2], current_beat, chord_duration * 0.9, chord_velocity)
                elif texture < 0.7:
                    # Medium - broken chord
                    for i, note in enumerate(chord_notes):
                        midi.addNote(track_accomp, 0, note, current_beat + i*0.5, chord_duration * 0.9 - i*0.5, chord_velocity)
                else:
                    # Dense - full chord plus an octave-doubled root
                    for note in chord_notes:
                        midi.addNote(track_accomp, 0, note, current_beat, chord_duration * 0.9, chord_velocity)
                    midi.addNote(track_accomp, 0, chord_notes[0] + 12, current_beat + 1, chord_duration * 0.4, chord_velocity - 10)

                # Random diatonic melody over this chord.
                for beat_offset in range(chord_duration):
                    # Skip some beats randomly for variation
                    if rng.random() < 0.2:
                        continue

                    for note_idx in range(notes_per_beat):
                        note_start = current_beat + beat_offset + (note_idx / notes_per_beat)

                        scale_idx = rng.randint(0, len(scale) - 1)
                        note = key + scale[scale_idx]

                        octave = rng.randint(base_octave, base_octave + 1)
                        note = (note % 12) + (octave * 12)

                        # Clamp into the melodic range
                        note = max(low_note, min(note, high_note))

                        if rng.random() < 0.8:  # 80% chance to add a note (otherwise a rest)
                            midi.addNote(track_melody, 0, note, note_start, note_duration, melody_velocity)

                current_beat += chord_duration

        # Write MIDI file to bytes buffer
        buffer = io.BytesIO()
        midi.writeFile(buffer)
        buffer.seek(0)

        return buffer

# Helper functions for audio playback in Colab
def midi_to_audio(midi_buffer, sr=22050):
    """Convert an in-memory MIDI file to an audio waveform using pretty_midi.

    Args:
        midi_buffer: file-like object exposing ``getvalue()`` (e.g. ``io.BytesIO``)
            that holds the bytes of a complete MIDI file.
        sr: sample rate in Hz passed to ``PrettyMIDI.synthesize``.

    Returns:
        Tuple ``(audio, sr)``: the synthesized waveform and the sample rate used.
    """
    # pretty_midi is given a filesystem path here, so spill the bytes to a temp file.
    with tempfile.NamedTemporaryFile(suffix=".mid", delete=False) as temp_file:
        temp_file.write(midi_buffer.getvalue())
        temp_file_path = temp_file.name

    try:
        midi_data = pretty_midi.PrettyMIDI(temp_file_path)
        audio = midi_data.synthesize(fs=sr)
    finally:
        # Remove the temp file even when parsing/synthesis fails, so errors don't leak files.
        os.remove(temp_file_path)

    return audio, sr

def play_midi_in_colab(midi_buffer):
    """Build an inline IPython audio player for a MIDI buffer.

    The MIDI is rendered to a waveform with ``midi_to_audio`` and wrapped in an
    ``ipd.Audio`` widget, which the caller is expected to display.
    """
    waveform, sample_rate = midi_to_audio(midi_buffer)
    return ipd.Audio(waveform, rate=sample_rate)

def play_all_segments(midi_files):
    """Show an audio player for every MIDI segment, in order, with a label per segment."""
    for segment_number, buffer in enumerate(midi_files, start=1):
        print(f"Playing music for segment {segment_number}...")
        display(play_midi_in_colab(buffer))

# Example usage and demonstration
def demo(music_generator=None):
    """Run the full text-to-music pipeline on a built-in passage and display the results.

    Prints the segmentation, emotional progression, narrative, per-segment prompts and
    transition prompts. It plays each segment's MIDI inline and plots the emotion curves.

    Args:
        music_generator: optional ``LiteraryMusicGenerator`` to use, e.g. one whose
            models were just trained. A fresh instance is created when omitted.
    """
    if music_generator is None:
        music_generator = LiteraryMusicGenerator()

    # Example literary text (a short excerpt)
    example_text = """
    The morning dawned bright and clear, with a crispness in the air that promised a beautiful day ahead. Sarah smiled as she stepped outside, breathing in the fresh scent of wildflowers that carpeted the meadow beyond her cottage. It had been too long since she felt this sense of peace.

    Suddenly, dark clouds appeared on the horizon, rolling in with unexpected speed. The wind picked up, carrying with it the scent of rain and an undercurrent of foreboding. Sarah's smile faded as she watched the storm approach, her heart racing with a nameless anxiety.

    Lightning flashed, followed by a deafening crack of thunder that seemed to shake the very earth. Sarah ran back inside, slamming the door just as the rain began to pound against the windows like an angry fist demanding entry. She pressed her back against the door, breathing heavily, fighting the irrational fear that had gripped her.

    As quickly as it had arrived, the storm began to subside. The violent raindrops slowed to a gentle patter, and the thunder rumbled away into the distance. A ray of sunlight broke through the clouds, casting a golden glow over the rain-soaked landscape. Sarah felt her tension melting away, replaced by a profound sense of wonder at the resilience of nature and the beauty that can follow chaos.
    """

    result = music_generator.process_literary_text(example_text)

    print("Text Processing Complete!")
    print(f"Text divided into {len(result['segments'])} segments")

    print("\nEmotional Progression:")
    for emotion, values in result['emotion_progression'].items():
        print(f"  {emotion}: {[round(v, 2) for v in values]}")

    print("\nOverall Emotional Flow Narrative:")
    print(result['emotional_narrative'])

    print("\nSegment-by-Segment Music Prompts:")
    for i, prompt_data in enumerate(result['musiclm_prompts']):
        print(f"\nSegment {i+1}:")
        print(f"  Context: {prompt_data['context']}")
        print(f"  Dominant emotions: {prompt_data['dominant_emotions']}")
        print(f"  Enhanced MusicLM prompt: {prompt_data['prompt']}")
        print("  Playing music for this segment...")
        display(play_midi_in_colab(result['midi_files'][i]))

    # Transition prompts exist only when there are at least two segments.
    if result.get('transition_prompts'):
        print("\nTransition Prompts Between Segments:")
        for i, transition in enumerate(result['transition_prompts']):
            print(f"\nTransition {i+1}: From Segment {transition['from_segment']+1} to {transition['to_segment']+1}")
            print(f"  Prompt: {transition['prompt']}")

    plot = music_generator.visualize_emotional_progression(result['emotion_progression'])
    plot.show()

    print("\nAll music files generated successfully!")

# Mock training data generator
def generate_mock_training_data(num_samples=100):
    """Create a synthetic (emotion -> musical feature) dataset for smoke-testing training.

    Args:
        num_samples: number of rows to generate.

    Returns:
        ``(X, y)`` float32 tensors.
        - ``X`` has shape ``[num_samples, 7]``: random emotion scores normalized so each
          row sums to 1.
        - ``y`` has shape ``[num_samples, 10]``: uniform random values in ``[0, 1)``.
          They are drawn independently of ``X``, so there is no real mapping to learn.
    """
    emotion_names = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'disgust', 'neutral']
    feature_names = ['tempo', 'key', 'mode', 'intensity', 'instrumentation',
                     'rhythm_complexity', 'harmonic_complexity', 'melodic_range', 'texture', 'articulation']

    raw_scores = np.random.rand(num_samples, len(emotion_names))
    row_totals = raw_scores.sum(axis=1, keepdims=True)
    emotion_matrix = raw_scores / row_totals

    feature_matrix = np.random.rand(num_samples, len(feature_names))

    return (
        torch.tensor(emotion_matrix, dtype=torch.float32),
        torch.tensor(feature_matrix, dtype=torch.float32),
    )

# Training function for the emotion-to-music mapper
def train_emotion_to_music_mapper(model, num_epochs=100, batch_size=16, data=None, lr=0.001):
    """Train the emotion-to-music mapping model with an MSE regression objective.

    Args:
        model: module mapping a batch of emotion vectors to musical-feature vectors.
            The default mock data has 7 emotion inputs and 10 feature targets.
        num_epochs: number of passes over the training data.
        batch_size: mini-batch size for the DataLoader.
        data: optional ``(X, y)`` pair of float tensors to train on. When omitted,
            500 synthetic samples from ``generate_mock_training_data`` are used.
        lr: Adam learning rate.

    Returns:
        The same ``model`` object, trained in place. It is left in eval mode, ready
        for inference.
    """
    if data is None:
        data = generate_mock_training_data(500)
    X_train, y_train = data

    dataset = torch.utils.data.TensorDataset(X_train, y_train)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    # Avoid ZeroDivisionError when reporting on an empty dataset.
    num_batches = max(len(dataloader), 1)

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0

        for inputs, targets in dataloader:
            loss = criterion(model(inputs), targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        if (epoch + 1) % 10 == 0:
            print(f'Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/num_batches:.4f}')

    # Callers run the mapper for inference right after training; switch off training-only
    # behaviour (dropout / batch-norm updates, if the model has any).
    model.eval()
    print("Training complete!")
    return model

# Training function for the temporal coherence model
def train_temporal_coherence_model(model, seq_length=5, num_epochs=100, batch_size=8,
                                   num_samples=100, feature_dim=10, lr=0.001):
    """Train the temporal coherence model on synthetic random-walk feature sequences.

    Each training sequence starts from uniform random features. Each later step adds
    Gaussian noise (std 0.1) to the previous step and clips the result to ``[0, 1]``.

    Args:
        model: module mapping ``[batch, seq_length, feature_dim]`` tensors to the same shape.
        seq_length: number of steps per sequence.
        num_epochs: number of passes over the training data.
        batch_size: mini-batch size for the DataLoader.
        num_samples: number of synthetic sequences to generate.
        feature_dim: width of each step's feature vector.
        lr: Adam learning rate.

    Returns:
        The same ``model`` object, trained in place. It is left in eval mode, ready
        for inference.
    """
    sequences = []
    for _ in range(num_samples):
        steps = [np.random.rand(feature_dim)]
        for _ in range(1, seq_length):
            # Small Gaussian step away from the previous frame, kept within [0, 1].
            steps.append(np.clip(steps[-1] + np.random.normal(0, 0.1, feature_dim), 0, 1))
        sequences.append(steps)

    X_train = torch.tensor(np.array(sequences), dtype=torch.float32)
    # NOTE(review): targets equal inputs, so this is a pure reconstruction objective;
    # the model can satisfy it by learning (close to) the identity.
    y_train = X_train.clone()

    dataset = torch.utils.data.TensorDataset(X_train, y_train)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    # Avoid ZeroDivisionError when reporting on an empty dataset.
    num_batches = max(len(dataloader), 1)

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0

        for inputs, targets in dataloader:
            loss = criterion(model(inputs), targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        if (epoch + 1) % 10 == 0:
            print(f'Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/num_batches:.4f}')

    # The model is used for inference right after training; disable training-only behaviour.
    model.eval()
    print("Training complete!")
    return model

# Main function to run everything
def main():
    """Train both models on synthetic data, save their weights, then run the demo."""
    music_generator = LiteraryMusicGenerator()

    print("Training emotion-to-music mapper...")
    train_emotion_to_music_mapper(music_generator.emotion_to_music_mapper, num_epochs=50)

    print("\nTraining temporal coherence model...")
    train_temporal_coherence_model(music_generator.temporal_coherence_model, num_epochs=50)

    # Save before the demo. The demo renders and displays audio, and a failure there
    # (e.g. outside a notebook) should not discard the freshly trained weights.
    torch.save(music_generator.emotion_to_music_mapper.state_dict(), 'emotion_to_music_mapper.pth')
    torch.save(music_generator.temporal_coherence_model.state_dict(), 'temporal_coherence_model.pth')
    print("\nModels saved successfully!")

    # NOTE(review): demo() called without arguments constructs its own
    # LiteraryMusicGenerator, so it does not use the weights trained above.
    print("\nRunning demonstration...")
    demo()

# Function to save MIDI files to disk
def save_midi_files(midi_files, base_filename="segment"):
    """Write each MIDI buffer to ``{base_filename}_{n}.mid`` and offer it for download.

    Args:
        midi_files: iterable of file-like objects exposing ``getvalue()`` with MIDI bytes.
        base_filename: path prefix for the output files; segments are numbered from 1.

    Returns:
        List of the filenames written, in input order.
    """
    saved_files = []

    for index, midi_buffer in enumerate(midi_files, start=1):
        filename = f"{base_filename}_{index}.mid"

        with open(filename, 'wb') as f:
            f.write(midi_buffer.getvalue())
        saved_files.append(filename)

        # Browser download only works where ``files`` (presumably google.colab.files)
        # is available; elsewhere the call fails and we just report it. Catch
        # ``Exception`` rather than using a bare except so Ctrl-C still interrupts.
        try:
            files.download(filename)
        except Exception:
            print(f"Note: Download of {filename} may not be supported in this environment.")

    return saved_files

# Function to combine all segments into a single music piece
def combine_segments(midi_files, output_filename="combined_music.mid", tempo=120):
    """Concatenate MIDI segments into one continuous MIDI file.

    Each segment is parsed with pretty_midi, whose note times are in seconds. Its notes
    are shifted to start after the previous segment plus a short gap, then written
    into a MIDIUtil ``MIDIFile``, whose note times are in beats at ``tempo`` BPM.
    Seconds are converted to beats so segments keep their real-time length.

    Args:
        midi_files: iterable of file-like objects exposing ``getvalue()`` with MIDI bytes.
        output_filename: path of the combined MIDI file to write.
        tempo: tempo (BPM) written to every output track. It is also used to convert
            segment times from seconds to beats.

    Returns:
        ``output_filename`` on success, or ``None`` if writing the file failed.
    """
    num_tracks = 8               # instruments rotate over these tracks to avoid overlaps
    gap_beats = 2                # silence inserted between consecutive segments
    default_segment_seconds = 15  # assumed length of a segment that contains no notes
    beats_per_second = tempo / 60.0

    combined_midi = MIDIFile(num_tracks)
    for track in range(num_tracks):
        combined_midi.addTempo(track, 0, tempo)

    current_beat = 0.0
    track_offset = 0  # first track used by the next segment

    for midi_buffer in midi_files:
        # pretty_midi is given a filesystem path here, so spill the bytes to a temp file.
        with tempfile.NamedTemporaryFile(suffix=".mid", delete=False) as temp_file:
            temp_file.write(midi_buffer.getvalue())
            temp_file_path = temp_file.name

        try:
            midi_data = pretty_midi.PrettyMIDI(temp_file_path)

            # Segment length in seconds = latest note end (falls back to the default).
            end_time = max(
                (note.end for instr in midi_data.instruments for note in instr.notes),
                default=0,
            ) or default_segment_seconds

            for instr_idx, instr in enumerate(midi_data.instruments):
                track = (track_offset + instr_idx) % num_tracks
                # Program changes are per MIDI channel: give each track its own channel
                # (drums stay on channel 9) so instruments don't overwrite each other.
                channel = 9 if instr.is_drum else track

                combined_midi.addProgramChange(track, channel, current_beat, instr.program)

                for note in instr.notes:
                    start_beat = current_beat + note.start * beats_per_second
                    duration_beats = (note.end - note.start) * beats_per_second
                    if duration_beats <= 0:
                        duration_beats = 0.1  # MIDIUtil needs a positive duration

                    combined_midi.addNote(track, channel, note.pitch, start_beat,
                                          duration_beats, note.velocity)

            track_offset = (track_offset + len(midi_data.instruments)) % num_tracks
            # Advance only after a segment was processed, so a failed segment can't
            # reuse a stale (or unbound) length from a previous iteration.
            current_beat += end_time * beats_per_second + gap_beats

        except Exception as e:
            print(f"Warning: Error processing segment: {e}")
        finally:
            try:
                os.remove(temp_file_path)
            except OSError:
                pass

    try:
        with open(output_filename, 'wb') as output_file:
            combined_midi.writeFile(output_file)
    except Exception as e:
        print(f"Error saving combined file: {e}")
        return None

    print(f"Successfully combined music saved to {output_filename}")

    # Browser download only works where ``files`` (presumably google.colab.files) exists.
    try:
        files.download(output_filename)
    except Exception:
        print(f"Note: Download of {output_filename} may not be supported in this environment.")

    return output_filename

# Function to take user input and process it
def process_user_text():
    """Interactively read a passage, generate music for it and offer to save the results.

    Input is read line by line until a line containing only ``END`` (or end of input).
    Empty input falls back to a short built-in example. Results are printed, each
    segment's MIDI is played inline, and a summary is written to
    ``literary_music_results.json``.
    """
    def _to_jsonable(obj):
        # numpy / torch scalars and arrays expose tolist(); convert them to plain Python.
        if hasattr(obj, 'tolist'):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    print("Please enter or paste your literary text (type 'END' on a new line when finished):")
    lines = []
    while True:
        try:
            line = input()
        except EOFError:
            # Input stream closed without an 'END' marker: use what was read so far.
            break
        if line.strip() == 'END':
            break
        lines.append(line)

    user_text = '\n'.join(lines)

    if not user_text.strip():
        print("No text entered. Using example text instead.")
        user_text = """
        The morning dawned bright and clear, with a crispness in the air that promised a beautiful day ahead.
        Sarah smiled as she stepped outside, breathing in the fresh scent of wildflowers that carpeted the meadow beyond her cottage.
        As she walked along the path, her mind wandered to thoughts of the past, memories both sweet and bitter.
        """

    music_generator = LiteraryMusicGenerator()
    result = music_generator.process_literary_text(user_text)

    print(f"\nProcessed {len(result['segments'])} text segments")

    plot = music_generator.visualize_emotional_progression(result['emotion_progression'])
    plot.show()

    print("\nPlaying music for each segment:")
    for i, prompt_data in enumerate(result['musiclm_prompts']):
        print(f"\nSegment {i+1}:")
        emotions_display = [
            f"{e['emotion']} ({e['intensity']:.2f})" for e in prompt_data['dominant_emotions']
        ]
        print(f"  Dominant emotions: {emotions_display}")
        print(f"  MusicLM prompt: {prompt_data['prompt']}")

        display(play_midi_in_colab(result['midi_files'][i]))

    save_choice = input("\nWould you like to save the MIDI files? (y/n): ")
    if save_choice.strip().lower() == 'y':
        saved_files = save_midi_files(result['midi_files'])
        print(f"Saved {len(saved_files)} MIDI files.")

        combine_choice = input("Would you like to combine all segments into one continuous piece? (y/n): ")
        if combine_choice.strip().lower() == 'y':
            combined_file = combine_segments(result['midi_files'])
            # combine_segments returns None when it could not write the file.
            if combined_file is None:
                print("Combined music could not be saved; skipping playback.")
            else:
                print(f"Combined music saved to {combined_file}")

                print("\nPlaying the combined music piece:")
                with open(combined_file, 'rb') as f:
                    combined_buffer = io.BytesIO(f.read())
                display(play_midi_in_colab(combined_buffer))

    output_data = {
        "segments": result['segments'],
        "emotion_progression": {k: [float(v) for v in vals] for k, vals in result['emotion_progression'].items()},
        "musiclm_prompts": result['musiclm_prompts']
    }

    with open('literary_music_results.json', 'w') as f:
        json.dump(output_data, f, indent=2, default=_to_jsonable)

    print("\nResults saved to 'literary_music_results.json'")
    # Browser download only works where ``files`` (presumably google.colab.files) exists.
    try:
        files.download('literary_music_results.json')
    except Exception:
        print("File download not supported in this environment.")

# Interface for Google Colab
def colab_interface():
    """Print the text-to-music menu once and run the action the user selects."""
    menu_lines = (
        "Literary Text to Music Converter",
        "================================",
        "1. Run demonstration with example text",
        "2. Process your own literary text",
        "3. Train models",
        "4. Exit",
    )
    for menu_line in menu_lines:
        print(menu_line)

    selection = input("\nEnter your choice (1-4): ")

    # Map each menu key to the function it triggers.
    actions = {
        '1': demo,
        '2': process_user_text,
        '3': main,
        '4': lambda: print("Exiting..."),
    }
    action = actions.get(selection)
    if action is None:
        print("Invalid choice. Please try again.")
    else:
        action()

# Run the Colab interface when this file/cell is executed directly.
# NOTE(review): this guard executes before CustomMusicLM, SoundStreamDecoder and
# LiteraryMusicGenerator are (re)defined further down in this same cell. demo() and
# process_user_text() instantiate LiteraryMusicGenerator, so this only works if an
# earlier cell already defined it — consider moving this guard to the end of the cell.
if __name__ == "__main__":
    colab_interface()

class MuLanEmbedding(nn.Module):
    """Two-tower text/audio encoder projecting both modalities into one normalized space.

    The text tower is pretrained ``bert-base-uncased``; its [CLS] vector is linearly
    projected. The audio tower is a small 1-D CNN over mel-spectrogram frames,
    average-pooled over time and then projected. Both outputs are L2-normalized.

    Args:
        text_embedding_dim: nominal width of the text encoder output. The input width
            of ``text_projection`` is read from the BERT config, so this value is
            informational only; it is kept for backward compatibility.
        audio_embedding_dim: channel width of the last audio conv layer, i.e. the
            pooled audio feature size fed to ``audio_projection``.
        joint_embedding_dim: size of the shared, L2-normalized embedding.
        n_mels: number of mel-frequency bins expected on the spectrogram channel axis.
    """

    def __init__(self, text_embedding_dim=768, audio_embedding_dim=512, joint_embedding_dim=256, n_mels=80):
        super(MuLanEmbedding, self).__init__()

        # Text encoder (BERT-based). Use the model's real hidden size (768 for bert-base)
        # rather than a hard-coded constant.
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
        self.text_projection = nn.Linear(self.text_encoder.config.hidden_size, joint_embedding_dim)

        # Audio encoder: two conv + downsampling stages, a final conv to
        # audio_embedding_dim channels, then global average pooling over time.
        self.audio_encoder = nn.Sequential(
            nn.Conv1d(n_mels, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
            nn.Conv1d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
            nn.Conv1d(256, audio_embedding_dim, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1)
        )
        self.audio_projection = nn.Linear(audio_embedding_dim, joint_embedding_dim)

    def encode_text(self, input_ids, attention_mask):
        """Return L2-normalized joint-space embeddings for tokenized text (one row per item)."""
        outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls_embedding = outputs.last_hidden_state[:, 0, :]  # Use [CLS] token
        text_embedding = self.text_projection(cls_embedding)
        return F.normalize(text_embedding, p=2, dim=1)

    def encode_audio(self, mel_spectrogram):
        """Return L2-normalized joint-space embeddings for a batch of mel spectrograms.

        Expects ``[batch_size, n_mels, time_frames]``.
        """
        # AdaptiveAvgPool1d(1) leaves a trailing time axis of length 1; drop it.
        audio_features = self.audio_encoder(mel_spectrogram).squeeze(-1)
        audio_embedding = self.audio_projection(audio_features)
        return F.normalize(audio_embedding, p=2, dim=1)

    def forward(self, input_ids, attention_mask, mel_spectrogram=None):
        """Embed text, and audio too when ``mel_spectrogram`` is given.

        Returns the text embedding alone, or ``(text_embedding, audio_embedding)``.
        """
        text_embedding = self.encode_text(input_ids, attention_mask)

        if mel_spectrogram is not None:
            audio_embedding = self.encode_audio(mel_spectrogram)
            return text_embedding, audio_embedding

        return text_embedding

def train_custom_musiclm(model, train_dataloader, num_epochs=100, lr=1e-4):
    """Fit a text-to-waveform model by regressing its output onto target waveforms.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` that returns a
            waveform tensor.
        train_dataloader: iterable of dict batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'audio_waveform'`` entries; must support ``len()``.
        num_epochs: number of passes over ``train_dataloader``.
        lr: Adam learning rate.

    Returns:
        The same ``model`` object, trained in place.

    Raises:
        ValueError: if a batch's model output and target waveform shapes differ.
    """
    reconstruction_loss = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Avoid ZeroDivisionError when reporting on an empty loader.
    num_batches = max(len(train_dataloader), 1)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        for batch in train_dataloader:
            audio_output = model(batch['input_ids'], batch['attention_mask'])
            audio_targets = batch['audio_waveform']

            # MSELoss silently broadcasts mismatched shapes (e.g. [B,1,T] vs [B,T] -> [B,B,T]),
            # which optimizes a meaningless objective; fail loudly instead.
            if audio_output.shape != audio_targets.shape:
                raise ValueError(
                    f"Model output shape {tuple(audio_output.shape)} does not match "
                    f"target waveform shape {tuple(audio_targets.shape)}"
                )

            loss = reconstruction_loss(audio_output, audio_targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

    print("Training complete!")
    return model

class CustomMusicLM(nn.Module):
    """Text-conditioned waveform generator.

    Pipeline: MuLan text embedding -> small MLP producing a latent vector in [-1, 1]
    -> SoundStream-style transposed-convolution decoder.
    """

    def __init__(self):
        super().__init__()

        # Shared text/audio embedding model; only its text tower is used at generation time.
        self.mulan_model = MuLanEmbedding()

        # Map the 256-d joint text embedding to a 256-d latent squashed into [-1, 1].
        latent_layers = [
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.Tanh(),
        ]
        self.latent_generator = nn.Sequential(*latent_layers)

        # Latent -> waveform decoder.
        self.soundstream_decoder = SoundStreamDecoder()

    def forward(self, input_ids, attention_mask):
        """Generate a waveform tensor for a batch of tokenized prompts."""
        embedded_text = self.mulan_model.encode_text(input_ids, attention_mask)
        latent = self.latent_generator(embedded_text)
        return self.soundstream_decoder(latent)

class SoundStreamDecoder(nn.Module):
    """Upsample a latent vector into a waveform with stacked transposed 1-D convolutions.

    The latent is projected to a 256-channel, 16-step sequence. Four stride-2
    transposed convolutions then each double the length (16 -> 256 samples) while
    narrowing channels 256 -> 128 -> 64 -> 32 -> ``output_channels``. A final
    ``Tanh`` keeps output values in [-1, 1].
    """

    def __init__(self, latent_dim=256, output_channels=1):
        super().__init__()

        self.initial_layer = nn.Linear(latent_dim, 16 * 256)

        channel_plan = [256, 128, 64, 32, output_channels]
        stages = []
        for in_channels, out_channels in zip(channel_plan[:-1], channel_plan[1:]):
            stages.append(nn.ConvTranspose1d(in_channels, out_channels, kernel_size=4, stride=2, padding=1))
            stages.append(nn.ReLU())
        # The final stage ends in Tanh instead of ReLU so the waveform can be negative.
        stages[-1] = nn.Tanh()
        self.decoder = nn.Sequential(*stages)

    def forward(self, latent_vector):
        """Decode ``[batch_size, latent_dim]`` latents into ``[batch_size, output_channels, 256]`` audio."""
        hidden = self.initial_layer(latent_vector)
        # Unflatten to [batch_size, 256 channels, 16 time steps].
        hidden = hidden.view(-1, 256, 16)
        return self.decoder(hidden)

class LiteraryMusicGenerator:
    """End-to-end pipeline turning literary text into emotion-aware music prompts and MIDI.

    Components, all defined elsewhere in the file:
    - text segmentation;
    - per-segment emotion extraction;
    - an emotion-to-music feature mapper;
    - a temporal-coherence model that smooths features across segments;
    - a MIDI generator.
    """

    # Order of emotions in the vector fed to the emotion-to-music mapper.
    _EMOTION_ORDER = ('joy', 'sadness', 'anger', 'fear', 'surprise', 'disgust', 'neutral')

    def __init__(self):
        self.text_processor = TextProcessor()
        self.emotion_extractor = EmotionExtractor()
        self.emotion_to_music_mapper = EmotionToMusicMapper()
        self.temporal_coherence_model = TemporalCoherenceModel()
        self.music_generator = MusicGenerator()

        # Pre-trained weights are not loaded automatically; call
        # load_pretrained_models() once the .pth files exist.

    def load_pretrained_models(self):
        """Load saved state dicts for the mapper and temporal model, if available.

        Best effort: on any load failure (missing file, mismatched architecture,
        corrupt checkpoint) a message is printed and the current weights are kept.
        """
        try:
            # map_location lets GPU-saved checkpoints load on CPU-only runtimes.
            self.emotion_to_music_mapper.load_state_dict(
                torch.load('emotion_to_music_mapper.pth', map_location='cpu'))
            self.temporal_coherence_model.load_state_dict(
                torch.load('temporal_coherence_model.pth', map_location='cpu'))
            print("Loaded pre-trained models successfully.")
        except Exception:
            print("Pre-trained models not found. Using initialized models.")

    def _map_emotions_to_features(self, emotion_maps):
        """Run each segment's emotion scores through the mapper; one 1-D tensor per segment."""
        # eval() so training-only layers (dropout/batch-norm, if any) don't perturb inference.
        self.emotion_to_music_mapper.eval()
        features = []
        with torch.no_grad():
            for emotion_map in emotion_maps:
                emotion_vector = torch.tensor(
                    [emotion_map[name] for name in self._EMOTION_ORDER],
                    dtype=torch.float32,
                ).unsqueeze(0)  # add batch dimension
                features.append(self.emotion_to_music_mapper(emotion_vector).squeeze(0))
        return features

    def _apply_temporal_coherence(self, musical_features):
        """Smooth per-segment features across the sequence; single/empty inputs pass through."""
        if len(musical_features) <= 1:
            return musical_features
        self.temporal_coherence_model.eval()
        feature_sequence = torch.stack(musical_features).unsqueeze(0)  # [1, seq_len, features]
        with torch.no_grad():
            coherent_features = self.temporal_coherence_model(feature_sequence)
        return list(coherent_features.squeeze(0))

    def process_literary_text(self, text):
        """Run the full pipeline on *text*.

        Returns:
            dict with keys ``segments``, ``emotion_progression``, ``musiclm_prompts``,
            ``transition_prompts``, ``emotional_narrative`` and ``midi_files``.
        """
        # Step 1: segment the text (the cleaned variant is not used downstream).
        original_segments, _cleaned_segments = self.text_processor.process_text(text)

        # Step 2: extract emotions.
        emotion_maps = self.emotion_extractor.extract_emotions(original_segments)
        emotional_progression = self.emotion_extractor.create_emotional_progression(emotion_maps)
        dominant_emotions = self.emotion_extractor.get_dominant_emotions(emotion_maps)

        # Steps 3-4: musical features per segment, smoothed across segments.
        musical_features = self._map_emotions_to_features(emotion_maps)
        musical_features = self._apply_temporal_coherence(musical_features)

        # Step 5: prompts, MIDI and transitions.
        musiclm_prompts = []
        midi_files = []
        transition_prompts = []
        prev_actual_values = None

        for i, features in enumerate(musical_features):
            actual_values = self.emotion_to_music_mapper.map_to_actual_values(features)
            current_emotion_map = emotion_maps[i]

            prompt = self.emotion_to_music_mapper.generate_musiclm_prompt(
                actual_values,
                emotion_scores=current_emotion_map
            )

            midi_files.append(self.music_generator.create_midi_from_features(actual_values))

            segment_context = f"Music representing: '{original_segments[i][:100]}...'"

            if i > 0:
                # Reuse the previous segment's values rather than recomputing them, so the
                # transition describes exactly the features that segment was rendered with.
                transition_prompt = self.emotion_to_music_mapper.generate_transition_prompt(
                    prev_actual_values,
                    actual_values,
                    emotion_maps[i-1],
                    current_emotion_map
                )
                transition_prompts.append({
                    "from_segment": i-1,
                    "to_segment": i,
                    "prompt": transition_prompt
                })

            musiclm_prompts.append({
                "segment_idx": i,
                "prompt": prompt,
                "context": segment_context,
                # float() so numpy/torch scalars don't break JSON export downstream.
                "dominant_emotions": [{"emotion": e[0], "intensity": float(e[1])} for e in dominant_emotions[i]],
                "musical_features": actual_values,
                "emotion_map": {k: float(v) for k, v in current_emotion_map.items()}
            })

            prev_actual_values = actual_values

        emotional_narrative = self.generate_emotional_flow_narrative(emotion_maps, original_segments)

        return {
            "segments": original_segments,
            "emotion_progression": emotional_progression,
            "musiclm_prompts": musiclm_prompts,
            "transition_prompts": transition_prompts,
            "emotional_narrative": emotional_narrative,
            "midi_files": midi_files
        }

    def generate_emotional_flow_narrative(self, emotion_maps, text_segments):
        """Describe the emotional arc of the whole text plus matching musical suggestions.

        Each segment's dominant emotion is its highest-scoring one. If that score is
        0.2 or lower, ``"neutral"`` is used instead. Runs of equal dominant emotions
        form arcs. Up to three arcs are joined with arrows; longer journeys get a
        sentence quoting each middle arc's first segment. Suggestions are drawn from
        ``emotion_to_music_mapper.emotion_correlations`` descriptions when that
        attribute exists.
        """
        narrative = "The emotional journey of this text flows through: "

        if not emotion_maps:
            return narrative + "no text segments were provided."

        dominant_per_segment = []
        for emotion_map in emotion_maps:
            top_emotion = max(emotion_map.items(), key=lambda x: x[1])
            if top_emotion[1] > 0.2:  # Only if the emotion is significant
                dominant_per_segment.append(top_emotion[0])
            else:
                dominant_per_segment.append("neutral")

        # Emotional arcs: consecutive segments sharing the same dominant emotion.
        emotional_arcs = []
        current_arc = {"emotion": dominant_per_segment[0], "start": 0, "end": 0}

        for i in range(1, len(dominant_per_segment)):
            if dominant_per_segment[i] == current_arc["emotion"]:
                current_arc["end"] = i
            else:
                emotional_arcs.append(current_arc)
                current_arc = {"emotion": dominant_per_segment[i], "start": i, "end": i}

        emotional_arcs.append(current_arc)

        if len(emotional_arcs) <= 3:
            emotions_journey = " → ".join([arc["emotion"] for arc in emotional_arcs])
            narrative += emotions_journey + "."
        else:
            narrative += "an initial " + emotional_arcs[0]["emotion"]

            for i in range(1, len(emotional_arcs) - 1):
                arc = emotional_arcs[i]
                segment_snippet = text_segments[arc["start"]][:50] + "..."
                narrative += f", then {arc['emotion']} during \"{segment_snippet}\""

            narrative += f", and concludes with {emotional_arcs[-1]['emotion']}."

        narrative += "\n\nA MusicLM composition capturing this emotional flow should: "

        suggestions = []
        for i, arc in enumerate(emotional_arcs):
            emotion = arc["emotion"]

            if hasattr(self.emotion_to_music_mapper, 'emotion_correlations'):
                correlations = self.emotion_to_music_mapper.emotion_correlations.get(emotion, {})

                if correlations:
                    if 'description' in correlations and correlations['description']:
                        descriptor = random.choice(correlations['description'])

                        # Phrase the suggestion according to the arc's position.
                        if i == 0:
                            suggestions.append(f"begin with a {descriptor} section")
                        elif i == len(emotional_arcs) - 1:
                            suggestions.append(f"conclude with a {descriptor} finale")
                        else:
                            suggestions.append(f"transition to a {descriptor} middle section")

        narrative += ", ".join(suggestions) + "."

        return narrative

    def visualize_emotional_progression(self, emotional_progression):
        """Plot each emotion's intensity per segment; returns the pyplot module for ``.show()``."""
        plt.figure(figsize=(12, 6))

        for emotion, values in emotional_progression.items():
            plt.plot(values, label=emotion)

        plt.xlabel('Segment Index')
        plt.ylabel('Emotion Intensity')
        plt.title('Emotional Progression Throughout the Text')
        plt.legend()
        plt.grid(True, alpha=0.3)
        return plt

    def train_models(self, training_data):
        """Placeholder for training on real data; currently does nothing."""
        pass

def use_custom_musiclm(literary_music_generator, text_prompts):
    """Generate audio for each text prompt with a locally trained CustomMusicLM.

    Args:
        literary_music_generator: Currently unused; kept so existing callers
            keep working.
        text_prompts: Iterable of prompt strings.

    Returns:
        A list with one NumPy array per prompt: the model output for that
        prompt with singleton dimensions squeezed out, moved to the CPU.
    """
    # The model is constructed on the CPU, so deserialize the checkpoint onto
    # the CPU as well; without map_location a checkpoint saved from a GPU run
    # fails to load on a CPU-only runtime.
    custom_musiclm = CustomMusicLM()
    state_dict = torch.load('custom_musiclm.pth', map_location='cpu')
    custom_musiclm.load_state_dict(state_dict)
    custom_musiclm.eval()

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

    audio_outputs = []
    # Inference only: no autograd graph is needed for any prompt.
    with torch.no_grad():
        for prompt in text_prompts:
            # Pad / truncate every prompt to a fixed 128-token window.
            encoded_input = tokenizer(
                prompt,
                padding='max_length',
                truncation=True,
                max_length=128,
                return_tensors='pt'
            )
            audio_output = custom_musiclm(
                encoded_input['input_ids'],
                encoded_input['attention_mask']
            )
            audio_outputs.append(audio_output.squeeze().cpu().numpy())

    return audio_outputs




[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...


Downloading additional NLTK resources...
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
Literary Text to Music Converter
1. Run demonstration with example text
2. Process your own literary text
3. Train models
4. Exit

Enter your choice (1-4): 2
Please enter or paste your literary text (type 'END' on a new line when finished):
It is a truth universally acknowledged, that a single man in possession of a good fortune, must be in want of a wife.  However little known the feelings or views of such a man may be on his first entering a neighbourhood, this truth is so well fixed in the minds of the surrounding families that he is considered as the rightful property of some one or other of their daugh

config.json:   0%|          | 0.00/1.00k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


pytorch_model.bin:   0%|          | 0.00/329M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/294 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/329M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Device set to use cpu


AttributeError: 'LiteraryMusicGenerator' object has no attribute 'process_literary_text'

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import nltk
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from sklearn.preprocessing import MinMaxScaler
import json
import os
import io
import base64
import tempfile
import random

# For MIDI generation
from midiutil import MIDIFile
import pretty_midi
import IPython.display as ipd

# Try to import Google Colab's file helper; outside Colab, fall back to a stub
# exposing the same ``download`` method so later calls don't crash.
try:
    from google.colab import files
except ImportError:
    class DummyFiles:
        """Stand-in for ``google.colab.files`` when not running in Colab."""

        def download(self, filename):
            print(f"File download not available outside Colab: {filename}")

    files = DummyFiles()

# Download the NLTK resources used below (sentence/word tokenizers, stopword
# list, WordNet).  NLTK >= 3.8.2 loads the sentence-tokenizer model from the
# ``punkt_tab`` resource rather than the pickled ``punkt`` package, so fetch
# both.  On older NLTK releases the unknown ``punkt_tab`` id is reported and
# ``download`` returns False (raise_on_error defaults to False) rather than raising.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
nltk.download('stopwords', quiet=True)
nltk.download('wordnet', quiet=True)

class TextProcessor:
    """Segment literary text and produce cleaned (normalised) copies of it."""

    def __init__(self):
        # Load the English stopword list, downloading the corpus if missing.
        try:
            self.stop_words = set(stopwords.words('english'))
        except LookupError:
            nltk.download('stopwords')
            self.stop_words = set(stopwords.words('english'))

        # Create the lemmatizer.  NOTE(review): NLTK may load WordNet lazily,
        # in which case a missing corpus only surfaces on the first
        # lemmatize() call; clean_text() handles that case too.
        try:
            self.lemmatizer = WordNetLemmatizer()
        except LookupError:
            nltk.download('wordnet')
            self.lemmatizer = WordNetLemmatizer()

    def clean_text(self, text):
        """Lowercase, strip punctuation, drop stopwords and lemmatize ``text``.

        Returns the remaining words joined by single spaces.  If the NLTK
        tokenizer data is unavailable, falls back to whitespace splitting; if
        the WordNet data is unavailable, the words are kept unlemmatized.
        """
        # Lowercase, then remove everything that is not a word char or space.
        text = re.sub(r'[^\w\s]', '', text.lower())

        try:
            words = word_tokenize(text)
        except LookupError:
            # Fallback: simple space-based tokenization.
            words = text.split()

        kept = [word for word in words if word not in self.stop_words]
        try:
            kept = [self.lemmatizer.lemmatize(word) for word in kept]
        except LookupError:
            # Best effort, mirroring the tokenizer fallback above: keep the
            # words unlemmatized rather than failing the whole pipeline.
            pass
        return ' '.join(kept)

    @staticmethod
    def _split_sentences(paragraph):
        """Split ``paragraph`` into sentences.

        Uses NLTK's ``sent_tokenize``.  If its Punkt model data is not
        installed (``LookupError``), falls back to splitting after ``.``,
        ``!`` or ``?`` followed by whitespace.
        """
        try:
            return sent_tokenize(paragraph)
        except LookupError:
            return [s for s in re.split(r'(?<=[.!?])\s+', paragraph) if s]

    def segment_text(self, text, segment_size=500):
        """Split text into paragraphs, then pack sentences into segments.

        Paragraphs are separated by blank lines; whitespace-only lines and
        Windows (CRLF) line endings also count as blank lines.  Within each
        paragraph, sentences are accumulated until adding another would exceed
        ``segment_size`` characters.  Segments never span paragraphs, and a
        single sentence longer than ``segment_size`` becomes its own segment.

        Returns:
            List of segment strings (sentences joined by single spaces).
        """
        paragraphs = [p.strip() for p in re.split(r'\n\s*\n', text) if p.strip()]

        segments = []
        for para in paragraphs:
            current_segment = []
            current_length = 0

            for sent in self._split_sentences(para):
                sent_length = len(sent)
                if current_length + sent_length > segment_size and current_segment:
                    segments.append(' '.join(current_segment))
                    current_segment = [sent]
                    current_length = sent_length
                else:
                    current_segment.append(sent)
                    current_length += sent_length

            if current_segment:
                segments.append(' '.join(current_segment))

        return segments

    def process_text(self, text, segment_size=500):
        """Segment ``text`` and clean each segment.

        Returns:
            ``(segments, cleaned_segments)``: the original segments (kept for
            display) and their cleaned counterparts, index-aligned.
        """
        segments = self.segment_text(text, segment_size)
        cleaned_segments = [self.clean_text(segment) for segment in segments]
        return segments, cleaned_segments

class EmotionExtractor:
    """Score text segments against seven basic emotion categories."""

    def __init__(self):
        # Pretrained emotion classifier; top_k=None asks the pipeline for every
        # label's score (replacing the older return_all_scores=True).
        try:
            self.emotion_classifier = pipeline(
                "text-classification",
                model="j-hartmann/emotion-english-distilroberta-base",
                top_k=None,
            )
        except Exception as e:
            print(f"Error loading emotion classifier: {e}")
            # Degrade to random scores so the rest of the pipeline still runs.
            self.emotion_classifier = self._fallback_classifier

        # Emotion labels tracked throughout the pipeline, in canonical order.
        self.emotion_categories = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'disgust', 'neutral']

    def _fallback_classifier(self, text):
        """Return random emotion scores shaped like the classifier's output.

        ``text`` is ignored.  One uniform random score per category is drawn
        and normalised to sum to 1.  The list of ``{'label', 'score'}`` dicts is
        wrapped in an outer list to match the ``[0]`` indexing applied to the
        classifier's result in ``extract_emotions``.
        """
        raw = np.random.rand(len(self.emotion_categories))
        probabilities = raw / raw.sum()
        return [[
            {'label': label, 'score': float(p)}
            for label, p in zip(self.emotion_categories, probabilities)
        ]]

    def extract_emotions(self, text_segments):
        """Classify each segment and return one emotion-score dict per segment.

        Every returned dict has exactly the keys joy, sadness, anger, fear,
        surprise, disgust and neutral (in that order); labels the classifier
        did not report get 0.  If classification of a segment raises, the
        error is printed and random fallback scores are used for it.
        """
        per_segment = []
        for segment in text_segments:
            try:
                scores = self.emotion_classifier(segment)[0]
            except Exception as e:
                print(f"Error in emotion extraction for segment: {e}")
                scores = self._fallback_classifier(segment)[0]

            by_label = {entry['label']: entry['score'] for entry in scores}
            per_segment.append({
                name: by_label.get(name, 0)
                for name in ('joy', 'sadness', 'anger', 'fear', 'surprise', 'disgust', 'neutral')
            })
        return per_segment

    def create_emotional_progression(self, emotion_maps):
        """Pivot per-segment dicts into ``{emotion: [score per segment]}``."""
        return {
            emotion: [emotion_map[emotion] for emotion_map in emotion_maps]
            for emotion in self.emotion_categories
        }

    def get_dominant_emotions(self, emotion_maps, top_n=2):
        """Return, per segment, the ``top_n`` ``(emotion, score)`` pairs by score.

        Ties keep the dict's insertion order (``sorted`` is stable).
        """
        return [
            sorted(emotion_map.items(), key=lambda pair: pair[1], reverse=True)[:top_n]
            for emotion_map in emotion_maps
        ]

# Added sequence-aware LSTM architecture
class EmotionToMusicMapper(nn.Module):
    """Map emotion vectors to normalised musical features and describe them.

    ``forward`` runs an LSTM over a sequence of emotion vectors and projects
    its final time step to ``output_dim`` features in (0, 1).  The remaining
    methods turn musical-feature dictionaries (and optional emotion scores)
    into MusicLM-style text prompts.

    NOTE(review): ``map_to_actual_values`` reads ``self.musical_features`` and
    ``self.feature_to_idx``, and ``get_descriptors_for_emotion_blend`` reads
    ``self.emotion_correlations``.  ``__init__`` defines none of them, so they
    must be assigned on the instance before those methods are used; otherwise
    ``nn.Module`` raises ``AttributeError``.  TODO: define them here.
    """

    # Pitch-class names indexed by key number 0-11.
    _KEY_NAMES = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
    # Instrument families indexed by the rounded 'instrumentation' feature.
    _INSTRUMENT_TYPES = ['piano', 'strings', 'guitar', 'synth', 'orchestral', 'percussion']

    def __init__(self, input_dim=7, hidden_dim=64, output_dim=10):
        super().__init__()
        # Sequence encoder over per-segment emotion vectors.
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        # Projection head; the final Sigmoid keeps every feature in (0, 1).
        # NOTE: BatchNorm1d rejects batches of size 1 in training mode.
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Map emotion sequences to normalised musical features.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_dim)`` (the LSTM is
                ``batch_first``).

        Returns:
            Tensor of shape ``(batch, output_dim)`` with values in (0, 1),
            computed from the LSTM output at the last time step.
        """
        lstm_out, _ = self.lstm(x)
        last_output = lstm_out[:, -1, :]
        return self.fc(last_output)

    @staticmethod
    def _key_name(key_value):
        """Name of the pitch class nearest to ``key_value`` (wraps modulo 12)."""
        # int() guards against round() returning a float type (e.g. older NumPy).
        return EmotionToMusicMapper._KEY_NAMES[int(round(key_value)) % 12]

    @staticmethod
    def _instrument_name(instrument_value):
        """Instrument family nearest to ``instrument_value``, clamped to the list."""
        types = EmotionToMusicMapper._INSTRUMENT_TYPES
        index = min(max(int(round(instrument_value)), 0), len(types) - 1)
        return types[index]

    def map_to_actual_values(self, normalized_features):
        """Rescale normalised outputs to each feature's real range.

        For every ``feature`` in ``self.musical_features`` (a mapping to
        ``{'min': ..., 'max': ...}``), the element
        ``normalized_features[self.feature_to_idx[feature]]`` (read with
        ``.item()``) is mapped linearly to ``min + v * (max - min)``.  'key'
        and 'instrumentation' are rounded to integers.  See the class NOTE:
        both attributes must be assigned first.

        Returns:
            Dict mapping feature name to its scaled value.
        """
        actual_values = {}
        for feature, feature_range in self.musical_features.items():
            norm_value = normalized_features[self.feature_to_idx[feature]].item()
            min_val = feature_range['min']
            max_val = feature_range['max']
            actual_value = min_val + norm_value * (max_val - min_val)

            # Categorical features must be integer indices.
            if feature in ('key', 'instrumentation'):
                actual_value = round(actual_value)

            actual_values[feature] = actual_value

        return actual_values

    def get_descriptors_for_emotion_blend(self, emotion_scores):
        """Pick musical descriptors and instruments for the two strongest emotions.

        Emotions scoring below 0.1 are ignored.  Each remaining emotion gets a
        share of up to 3 descriptors and up to 2 instruments in proportion to
        its weight, sampled from ``self.emotion_correlations[emotion]``
        ('description' and 'instrumentation' lists).

        Returns:
            ``(descriptors, instruments)``: de-duplicated lists in the order
            they were sampled, so the result is reproducible under a fixed
            ``random`` seed.
        """
        sorted_emotions = sorted(emotion_scores.items(), key=lambda x: x[1], reverse=True)
        top_emotions = sorted_emotions[:2]
        total_weight = sum(score for _, score in top_emotions)

        descriptors = []
        instruments = []
        for emotion, score in top_emotions:
            # Skip emotions with very low scores (this also guarantees
            # total_weight > 0 whenever the division below runs).
            if score < 0.1:
                continue

            weight = score / total_weight
            correlations = self.emotion_correlations[emotion]

            emotion_descriptors = correlations['description']
            num_descriptors = max(1, int(weight * 3))  # At least 1, up to 3
            descriptors.extend(random.sample(
                emotion_descriptors, min(num_descriptors, len(emotion_descriptors))))

            emotion_instruments = correlations['instrumentation']
            num_instruments = max(1, int(weight * 2))  # 1-2 instruments
            instruments.extend(random.sample(
                emotion_instruments, min(num_instruments, len(emotion_instruments))))

        # dict.fromkeys de-duplicates while keeping first-seen order; callers
        # index these lists, so a set's arbitrary order would be unstable.
        return list(dict.fromkeys(descriptors)), list(dict.fromkeys(instruments))

    def generate_musiclm_prompt(self, musical_features, emotion_scores=None):
        """Convert musical features (and optional emotion scores) to a text prompt.

        Args:
            musical_features: Mapping with 'key', 'mode', 'instrumentation',
                'tempo', 'intensity', 'texture', 'rhythm_complexity',
                'harmonic_complexity' and 'articulation'.
            emotion_scores: Optional mapping emotion -> score.  When given
                (and non-empty), emotional descriptors and instruments are
                woven into the prompt.

        Returns:
            The prompt string.
        """
        key_name = self._key_name(musical_features['key'])
        mode_name = "major" if musical_features['mode'] > 0.5 else "minor"
        instrument = self._instrument_name(musical_features['instrumentation'])

        tempo = musical_features['tempo']
        if tempo < 80:
            tempo_desc = "slow"
        elif tempo < 120:
            tempo_desc = "moderate"
        else:
            tempo_desc = "fast"

        intensity = musical_features['intensity']
        if intensity < 0.3:
            intensity_desc = "soft"
        elif intensity < 0.7:
            intensity_desc = "moderate"
        else:
            intensity_desc = "powerful"

        if emotion_scores:
            emotional_descriptors, suggested_instruments = self.get_descriptors_for_emotion_blend(emotion_scores)

            if suggested_instruments:
                # Lead with the mapped instrument, then add one suggestion.
                all_instruments = [instrument] + [i for i in suggested_instruments if i != instrument]
                instrument_phrase = " and ".join(all_instruments[:2])
            else:
                instrument_phrase = instrument

            # When every emotion is below the blend threshold there are no
            # descriptors; avoid emitting "A  piece".
            descriptor_phrase = " and ".join(emotional_descriptors[:2])
            opening = f"A {descriptor_phrase} piece" if descriptor_phrase else "A piece"
            prompt = f"{opening} in {key_name} {mode_name}, "
            prompt += f"{tempo_desc} tempo, {intensity_desc} in intensity, "
            prompt += f"featuring {instrument_phrase}, "
        else:
            prompt = f"A {intensity_desc} {tempo_desc} melody in {key_name} {mode_name}, "

            texture = musical_features['texture']
            if texture < 0.3:
                prompt += "with a sparse arrangement, "
            elif texture > 0.7:
                prompt += "with a dense, layered arrangement, "

            prompt += f"featuring {instrument}, "

        rhythm_complexity = musical_features['rhythm_complexity']
        harmonic_complexity = musical_features['harmonic_complexity']
        if rhythm_complexity > 0.7 and harmonic_complexity > 0.7:
            prompt += "with complex rhythms and harmonies, "
        elif rhythm_complexity > 0.7:
            prompt += "with complex rhythms, "
        elif harmonic_complexity > 0.7:
            prompt += "with rich harmonies, "
        elif rhythm_complexity < 0.3 and harmonic_complexity < 0.3:
            prompt += "with simple, straightforward patterns, "

        articulation = musical_features['articulation']
        if articulation < 0.3:
            prompt += "played with staccato articulation."
        elif articulation > 0.7:
            prompt += "played with smooth, legato phrasing."
        else:
            prompt += "with balanced articulation."

        # Close with the dominant emotion, but only if it is significant.
        if emotion_scores:
            top_emotion = max(emotion_scores.items(), key=lambda x: x[1])
            if top_emotion[1] > 0.3:
                prompt += f" The music conveys a sense of {top_emotion[0]}."

        return prompt

    def generate_transition_prompt(self, prev_features, current_features, prev_emotions, current_emotions):
        """Describe the transition between two consecutive segments.

        If the dominant emotion changes and both dominant scores exceed 0.3,
        the transition is named after that emotional shift.  Otherwise it is
        described by notable changes in tempo (> 20), intensity (> 0.3) or
        mode, falling back to "gradual transition".

        Returns:
            A prompt string (ends with a trailing space).
        """
        prev_top_emotion = max(prev_emotions.items(), key=lambda x: x[1])
        current_top_emotion = max(current_emotions.items(), key=lambda x: x[1])

        if (prev_top_emotion[0] != current_top_emotion[0]
                and prev_top_emotion[1] > 0.3 and current_top_emotion[1] > 0.3):
            transition_type = f"transition from {prev_top_emotion[0]} to {current_top_emotion[0]}"
        else:
            # No major emotional shift: describe musical parameter changes.
            changes = []

            prev_tempo = prev_features['tempo']
            current_tempo = current_features['tempo']
            if abs(prev_tempo - current_tempo) > 20:
                changes.append("accelerating" if current_tempo > prev_tempo else "decelerating")

            prev_intensity = prev_features['intensity']
            current_intensity = current_features['intensity']
            if abs(prev_intensity - current_intensity) > 0.3:
                changes.append("building in intensity" if current_intensity > prev_intensity else "becoming more subdued")

            prev_mode = "major" if prev_features['mode'] > 0.5 else "minor"
            current_mode = "major" if current_features['mode'] > 0.5 else "minor"
            if prev_mode != current_mode:
                changes.append(f"shifting from {prev_mode} to {current_mode}")

            transition_type = ", ".join(changes) if changes else "gradual transition"

        prompt = f"A {transition_type} that maintains musical coherence while "

        prev_descriptors, _ = self.get_descriptors_for_emotion_blend(prev_emotions)
        current_descriptors, _ = self.get_descriptors_for_emotion_blend(current_emotions)
        if prev_descriptors and current_descriptors:
            prompt += f"evolving from {prev_descriptors[0]} to {current_descriptors[0]}. "
        else:
            # Complete the "while ..." clause even without descriptors.
            prompt += "evolving smoothly. "

        prev_instrument = self._instrument_name(prev_features['instrumentation'])
        current_instrument = self._instrument_name(current_features['instrumentation'])
        if prev_instrument == current_instrument:
            prompt += f"Featuring {prev_instrument} throughout. "
        else:
            prompt += f"Transitioning from {prev_instrument} to {current_instrument}. "

        return prompt


# Added MusicLM interface component
class MusicLMInterface:
    """Thin wrapper: text prompt -> token generation -> audio decoding.

    NOTE(review): ``SoundStreamDecoder`` does not appear in this cell's imports
    or visible definitions -- TODO confirm where it comes from.  Also verify
    that the ``google/musiclm-medium`` checkpoint actually exists on the
    Hugging Face Hub before relying on this class.
    """

    def __init__(self):
        # AutoModelForCausalLM is not among the cell's top-level transformers
        # imports (only AutoTokenizer, AutoModelForSequenceClassification and
        # pipeline are), so import it here to avoid a NameError.
        from transformers import AutoModelForCausalLM

        self.tokenizer = AutoTokenizer.from_pretrained("google/musiclm-medium")
        self.model = AutoModelForCausalLM.from_pretrained("google/musiclm-medium")
        self.soundstream = SoundStreamDecoder()

    def generate_audio(self, prompt, max_length=512):
        """Generate audio for ``prompt``.

        Tokenizes the prompt, generates up to ``max_length`` tokens with the
        language model, and decodes the generated tokens with the SoundStream
        decoder.

        Returns:
            Whatever ``self.soundstream.decode`` returns.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt")
        outputs = self.model.generate(**inputs, max_length=max_length)
        return self.soundstream.decode(outputs)

# Added sequence modeling wrapper
class TemporalCoherenceWrapper(nn.Module):
    """Run ``base_model`` and pass its output through a 2-layer bidirectional LSTM.

    NOTE(review): the LSTM is not ``batch_first`` and has hidden_size=10 per
    direction, so its output feature dimension is 20.  If ``base_model``
    returns a 2-D ``(batch, 10)`` tensor (as ``EmotionToMusicMapper.forward``
    does), PyTorch treats it as one unbatched sequence of length ``batch``.
    ``mean(dim=1)`` then averages the 20 LSTM features, giving shape
    ``(batch,)`` rather than per-feature outputs.  Confirm this is intended.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        self.temporal_adjuster = nn.LSTM(input_size=10, hidden_size=10, num_layers=2, bidirectional=True)

    def forward(self, emotion_sequence):
        features = self.base_model(emotion_sequence)
        sequence_out = self.temporal_adjuster(features)[0]
        return sequence_out.mean(dim=1)

class TemporalCoherenceModel(nn.Module):
    """Smooth a sequence of musical-feature vectors with a stacked LSTM.

    Input is ``(batch, seq_len, input_dim)``.  Output is
    ``(batch, seq_len, output_dim)``, with every value squashed into (0, 1).
    """

    def __init__(self, input_dim=10, hidden_dim=64, output_dim=10):
        super().__init__()

        # Two stacked LSTM layers with dropout 0.3 between them; batch first.
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=0.3,
        )

        # Per-timestep projection followed by a Sigmoid.
        self.output_layer = nn.Sequential(nn.Linear(hidden_dim, output_dim), nn.Sigmoid())

    def forward(self, sequence):
        """Return a feature vector in (0, 1) for every time step of ``sequence``."""
        hidden_states = self.lstm(sequence)[0]
        return self.output_layer(hidden_states)

class MusicGenerator:
    """Render a dictionary of musical features as a short two-track MIDI file."""

    def __init__(self):
        # General MIDI program numbers for each instrument family.
        self.instruments = {
            'piano': 0,       # Acoustic Grand Piano
            'strings': 48,    # String Ensemble 1
            'guitar': 24,     # Acoustic Guitar (nylon)
            'synth': 80,      # Lead 1 (square)
            'orchestral': 48, # String Ensemble 1
            'percussion': 118 # Synth Drum
        }

        # Scale degrees as semitone offsets from the tonic.
        self.major_scale = [0, 2, 4, 5, 7, 9, 11]  # Whole, Whole, Half, Whole, Whole, Whole, Half
        self.minor_scale = [0, 2, 3, 5, 7, 8, 10]  # Whole, Half, Whole, Whole, Half, Whole, Whole (natural minor)

        # Common chord progressions as 1-based scale degrees.
        self.progressions = {
            'major': [
                [1, 4, 5, 1],       # I-IV-V-I
                [1, 6, 4, 5],       # I-vi-IV-V
                [1, 5, 6, 4],       # I-V-vi-IV
                [2, 5, 1, 6]        # ii-V-I-vi
            ],
            'minor': [
                [1, 4, 5, 1],       # i-iv-v-i
                [1, 6, 3, 7],       # i-VI-III-VII
                [1, 7, 6, 5],       # i-VII-VI-v
                [1, 4, 7, 3]        # i-iv-VII-III
            ]
        }

    def create_midi_from_features(self, musical_features, duration_seconds=15):
        """Generate a two-track MIDI file (melody + chord accompaniment).

        Args:
            musical_features: Mapping with 'tempo' (BPM), 'key' (pitch class,
                0 = C), 'mode' (> 0.5 means major), 'instrumentation' (index
                into piano/strings/guitar/synth/orchestral/percussion) and
                'rhythm_complexity', 'harmonic_complexity', 'melodic_range',
                'intensity', 'texture', 'articulation' (presumably in [0, 1];
                they are compared against 0.3 / 0.7 thresholds).
            duration_seconds: Approximate length of the piece.

        Returns:
            An ``io.BytesIO`` positioned at 0 containing the MIDI file.
        """
        tempo = musical_features['tempo']
        # Round (rather than truncate) the key and instrument indices so the
        # MIDI matches the names chosen by the prompt builders, which use
        # round(); wrap / clamp them into their valid ranges.
        key = int(round(musical_features['key'])) % 12
        mode = 'major' if musical_features['mode'] > 0.5 else 'minor'
        instrument_names = ['piano', 'strings', 'guitar', 'synth', 'orchestral', 'percussion']
        instrument_idx = min(max(int(round(musical_features['instrumentation'])), 0), len(instrument_names) - 1)
        instrument_type = instrument_names[instrument_idx]
        instrument = self.instruments[instrument_type]

        rhythm_complexity = musical_features['rhythm_complexity']
        harmonic_complexity = musical_features['harmonic_complexity']
        melodic_range = musical_features['melodic_range']
        intensity = musical_features['intensity']
        texture = musical_features['texture']
        articulation = musical_features['articulation']

        midi = MIDIFile(2)  # 2 tracks - one for melody, one for accompaniment
        track_melody = 0
        track_accomp = 1

        midi.addTempo(track_melody, 0, tempo)
        midi.addTempo(track_accomp, 0, tempo)

        scale = self.major_scale if mode == 'major' else self.minor_scale

        total_beats = int((tempo / 60) * duration_seconds)

        midi.addProgramChange(track_melody, 0, 0, instrument)
        midi.addProgramChange(track_accomp, 0, 0, instrument)

        # Higher harmonic complexity selects a later progression in the list.
        prog_idx = min(int(harmonic_complexity * len(self.progressions[mode])), len(self.progressions[mode]) - 1)
        progression = self.progressions[mode][prog_idx]

        # Chord duration in beats (at least 4).
        chord_duration = max(4, total_beats // len(progression))
        repetitions = max(1, total_beats // (chord_duration * len(progression)))

        base_octave = 5 if instrument_type != 'piano' else 4

        # Melody range: from C of the base octave up to one or two octaves higher.
        low_note = base_octave * 12
        high_note = low_note + int(12 + 12 * melodic_range)

        melody_velocity = 64 + int(intensity * 63)  # Between 64-127
        chord_velocity = int(melody_velocity * 0.8)  # Slightly quieter

        # 0.5 (staccato) to 1.0 (legato) of each note's nominal length.
        duration_modifier = 0.5 + (articulation * 0.5)

        current_beat = 0
        for rep in range(repetitions):
            for chord_idx, chord_root in enumerate(progression):
                chord_root_idx = chord_root - 1  # 1-based degree -> 0-based index
                root_note = (key + scale[chord_root_idx]) % 12

                # Triad quality in the key: in major, I/IV/V are major; in
                # natural minor, III/VI/VII are major (i/iv/v are minor).
                if mode == 'major':
                    is_major_chord = chord_root_idx in (0, 3, 4)
                else:
                    is_major_chord = chord_root_idx in (2, 5, 6)

                third_offset = 4 if is_major_chord else 3
                chord_notes = [
                    root_note + 60,  # Root note (C4 = 60, middle C)
                    root_note + 60 + third_offset,  # Third
                    root_note + 60 + 7   # Fifth
                ]

                if texture < 0.3:
                    # Sparse - just root and fifth
                    midi.addNote(track_accomp, 0, chord_notes[0], current_beat, chord_duration * 0.9, chord_velocity)
                    midi.addNote(track_accomp, 0, chord_notes[2], current_beat, chord_duration * 0.9, chord_velocity)
                elif texture < 0.7:
                    # Medium - broken chord
                    for i, note in enumerate(chord_notes):
                        midi.addNote(track_accomp, 0, note, current_beat + i*0.5, chord_duration * 0.9 - i*0.5, chord_velocity)
                else:
                    # Dense - full chord plus the root an octave up
                    for note in chord_notes:
                        midi.addNote(track_accomp, 0, note, current_beat, chord_duration * 0.9, chord_velocity)
                    midi.addNote(track_accomp, 0, chord_notes[0] + 12, current_beat + 1, chord_duration * 0.4, chord_velocity - 10)

                notes_per_beat = 1 + int(rhythm_complexity * 3)  # 1 to 4 notes per beat

                for beat_offset in range(chord_duration):
                    # Skip some beats randomly for variation
                    if random.random() < 0.2:
                        continue

                    for note_idx in range(notes_per_beat):
                        note_start = current_beat + beat_offset + (note_idx / notes_per_beat)

                        # Random scale tone, placed in one of two octaves and
                        # clamped to the melodic range.
                        scale_idx = random.randint(0, len(scale) - 1)
                        note = key + scale[scale_idx]
                        octave = random.randint(base_octave, base_octave + 1)
                        note = (note % 12) + (octave * 12)
                        note = max(low_note, min(note, high_note))

                        note_duration = (1.0 / notes_per_beat) * duration_modifier

                        if random.random() < 0.8:  # 80% chance to add a note (for rests)
                            midi.addNote(track_melody, 0, note, note_start, note_duration, melody_velocity)

                current_beat += chord_duration

        buffer = io.BytesIO()
        midi.writeFile(buffer)
        buffer.seek(0)

        return buffer

class LiteraryMusicGenerator:
    # ... existing code ...
    # NOTE(review): the methods below use attributes/methods that this class body
    # does not define (self.emotion_to_music_mapper, self.music_generator,
    # self.process_literary_text) — confirm they are provided wherever this class
    # is actually assembled.

    # Pitch-class names, indexed by semitones above C.
    _KEY_NAMES = ('C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B')

    def generate_music_score(self, emotion_maps, musical_features):
        """Generate a detailed music score with more advanced musical theory.

        Args:
            emotion_maps: Per-segment mappings of emotion name -> intensity.
            musical_features: Per-segment normalized feature sets, converted to
                concrete values via ``self.emotion_to_music_mapper``. Paired
                with ``emotion_maps`` by position; ``zip`` stops at the shorter.

        Returns:
            list[dict]: One score description per segment (key, mode, tempo,
            chord progression, instrumentation, dynamics, rhythm, contour).
        """
        scores = []

        for i, (emotion_map, features) in enumerate(zip(emotion_maps, musical_features)):
            # Map normalized features to actual values
            actual_values = self.emotion_to_music_mapper.map_to_actual_values(features)

            # Determine key signature and scale. Wrapping with % 12 keeps a key
            # value that rounds up to 12 (e.g. 11.6) on a valid pitch class
            # instead of raising IndexError.
            key_index = int(round(actual_values['key'])) % 12
            key_name = self._KEY_NAMES[key_index]
            mode_name = "major" if actual_values['mode'] > 0.5 else "minor"
            scale = (self.music_generator.major_scale if mode_name == "major"
                     else self.music_generator.minor_scale)

            # Determine chord progression based on emotion
            top_emotion = max(emotion_map.items(), key=lambda x: x[1])[0]
            progressions = self.music_generator.progressions[mode_name]
            # Higher harmonic complexity selects a later progression; the min()
            # keeps a complexity of exactly 1.0 on the last progression.
            progression_index = min(int(actual_values['harmonic_complexity'] * len(progressions)),
                                    len(progressions) - 1)
            progression = progressions[progression_index]

            # Create detailed score information
            score_info = {
                "segment": i,
                "key": key_name,
                "mode": mode_name,
                "tempo": actual_values['tempo'],
                "time_signature": "4/4",  # Default, could be variable
                "dominant_emotion": top_emotion,
                # Name each chord by its own root (C major I-IV-V -> C, F, G)
                # instead of repeating the key name for every degree.
                "chord_progression": [self._chord_symbol(key_index, p, scale, mode_name)
                                      for p in progression],
                "instrumentation": self._get_instrumentation(actual_values['instrumentation']),
                "dynamics": self._get_dynamics(actual_values['intensity']),
                "rhythmic_feel": self._get_rhythmic_feel(actual_values['rhythm_complexity']),
                "melodic_contour": self._generate_melodic_contour(emotion_map)
            }

            scores.append(score_info)

        return scores

    def _chord_notation(self, degree, mode):
        """Convert a 1-based scale degree to its triad-quality suffix.

        Returns "" (major), "m" (minor) or "dim" (diminished), following the
        diatonic triads of the major scale or the natural minor scale.
        """
        if mode == "major":
            if degree in [1, 4, 5]:
                return ""  # Major chord
            elif degree in [2, 3, 6]:
                return "m"  # Minor chord
            else:
                return "dim"  # Diminished chord
        else:  # minor
            if degree in [3, 6, 7]:
                return ""  # Major chord
            elif degree in [1, 4, 5]:
                return "m"  # Minor chord
            else:
                return "dim"  # Diminished chord

    def _chord_symbol(self, key_index, degree, scale, mode):
        """Return the chord symbol (root name + quality, e.g. 'Dm') for a degree.

        Args:
            key_index: Tonic pitch class, 0 (C) to 11 (B).
            degree: 1-based scale degree, as ``_chord_notation`` expects.
            scale: Semitone offsets from the tonic (the MIDI generator adds
                these entries to the key to obtain pitches).
            mode: "major" or "minor".
        """
        # The modulo keeps an out-of-range degree from raising IndexError.
        offset = scale[(int(degree) - 1) % len(scale)]
        root = self._KEY_NAMES[(key_index + offset) % 12]
        return f"{root}{self._chord_notation(degree, mode)}"

    def _get_instrumentation(self, instrumentation_value):
        """Get detailed instrumentation based on the value.

        ``instrumentation_value`` is rounded to an index into the instrument
        families below. The index is clamped to the valid range, so values that
        are too high or negative map to the nearest family. Without the clamp,
        high values raise IndexError and negative values wrap around.

        Returns:
            dict: ``{"primary": str, "secondary": list[str]}``.
        """
        instrument_types = ['piano', 'strings', 'guitar', 'synth', 'orchestral', 'percussion']
        index = min(max(int(round(instrumentation_value)), 0), len(instrument_types) - 1)
        primary = instrument_types[index]

        # Add secondary instruments based on primary (a fresh list per call).
        secondary_by_primary = {
            "piano": ["soft strings", "light percussion"],
            "strings": ["harp", "woodwinds"],
            "guitar": ["bass", "light percussion"],
            "synth": ["bass synth", "electronic drums"],
            "orchestral": ["brass", "timpani"],
            "percussion": ["bass", "synth pads"],
        }

        return {"primary": primary, "secondary": secondary_by_primary[primary]}

    def _get_dynamics(self, intensity):
        """Convert intensity to musical dynamics notation (pp ... ff) in 0.2 steps."""
        if intensity < 0.2:
            return "pp (pianissimo)"
        elif intensity < 0.4:
            return "p (piano)"
        elif intensity < 0.6:
            return "mp (mezzo-piano) to mf (mezzo-forte)"
        elif intensity < 0.8:
            return "f (forte)"
        else:
            return "ff (fortissimo)"

    def _get_rhythmic_feel(self, complexity):
        """Describe the rhythmic feel for a complexity value (thresholds 0.3 / 0.7)."""
        if complexity < 0.3:
            return "simple, steady rhythm with minimal syncopation"
        elif complexity < 0.7:
            return "moderately complex with some syncopation and rhythmic variety"
        else:
            return "complex polyrhythms with significant syncopation and rhythmic tension"

    def _generate_melodic_contour(self, emotion_map):
        """Generate a description of the melodic contour based on emotions.

        The first emotion (in the order checked below) whose intensity exceeds
        0.5 decides the contour. Emotions missing from ``emotion_map`` count
        as 0, so partial maps no longer raise KeyError.
        """
        if emotion_map.get('joy', 0.0) > 0.5:
            return "rising, upward melodic contour with occasional leaps"
        elif emotion_map.get('sadness', 0.0) > 0.5:
            return "gradually descending melodic contour with stepwise motion"
        elif emotion_map.get('anger', 0.0) > 0.5:
            return "angular melodic contour with dramatic leaps and falls"
        elif emotion_map.get('fear', 0.0) > 0.5:
            return "unstable melodic contour with unpredictable direction changes"
        elif emotion_map.get('surprise', 0.0) > 0.5:
            return "playful melodic contour with unexpected intervals"
        else:
            return "balanced melodic contour with moderate movement"

    def generate_music_visualization(self, musical_features, emotion_maps):
        """Generate visual representation parameters of the music for each segment.

        Args:
            musical_features: Per-segment normalized feature sets.
            emotion_maps: Per-segment emotion -> intensity mappings (paired
                by position with ``musical_features``).

        Returns:
            list[dict]: Waveform, color, texture and motion parameters per segment.
        """
        visualizations = []

        for i, (features, emotion_map) in enumerate(zip(musical_features, emotion_maps)):
            # Map normalized features to actual values
            actual_values = self.emotion_to_music_mapper.map_to_actual_values(features)

            # Create visualization data
            viz_data = {
                "segment": i,
                "waveform_parameters": {
                    "amplitude": actual_values['intensity'] * 100,
                    "frequency": 220 + actual_values['key'] * 20,  # Base frequency affected by key
                    "waveform_type": self._get_waveform_type(emotion_map)
                },
                "color_scheme": self._get_emotion_color(emotion_map),
                "texture_density": actual_values['texture'] * 100,
                "motion_parameters": {
                    "speed": actual_values['tempo'] / 120,  # Normalized to 1.0 at tempo 120
                    "direction": self._get_motion_direction(emotion_map),
                    "pattern": self._get_motion_pattern(actual_values['rhythm_complexity'])
                }
            }

            visualizations.append(viz_data)

        return visualizations

    def _get_waveform_type(self, emotion_map):
        """Pick a waveform name for the single most intense emotion in the map."""
        top_emotion = max(emotion_map.items(), key=lambda x: x[1])[0]

        waveform_by_emotion = {
            'joy': "sine",
            'sadness': "triangle",
            'anger': "sawtooth",
            'fear': "square with noise",
            'surprise': "modulated sine",
        }
        return waveform_by_emotion.get(top_emotion, "sine with harmonics")

    def _get_emotion_color(self, emotion_map):
        """Blend emotion colors, weighted by intensity, into one RGB color.

        Only emotions with intensity above 0.1 contribute; emotions without a
        color entry are ignored. If nothing contributes, neutral gray
        (169, 169, 169) is returned.

        Returns:
            dict: ``{"r": int, "g": int, "b": int, "hex": "#rrggbb"}``.
        """
        # Calculate weighted color based on emotion intensities
        colors = {
            'joy': {"r": 255, "g": 215, "b": 0},      # Gold/Yellow
            'sadness': {"r": 0, "g": 0, "b": 139},    # Dark Blue
            'anger': {"r": 220, "g": 20, "b": 60},    # Crimson/Red
            'fear': {"r": 75, "g": 0, "b": 130},      # Indigo/Purple
            'surprise': {"r": 0, "g": 191, "b": 255}, # Deep Sky Blue
            'disgust': {"r": 107, "g": 142, "b": 35}, # Olive Green
            'neutral': {"r": 169, "g": 169, "b": 169}  # Gray
        }

        # Sum the weighted RGB values
        r = g = b = 0
        total_weight = 0

        for emotion, intensity in emotion_map.items():
            color = colors.get(emotion)
            # Only consider known emotions with significant intensity
            if color is None or intensity <= 0.1:
                continue
            r += color["r"] * intensity
            g += color["g"] * intensity
            b += color["b"] * intensity
            total_weight += intensity

        # Normalize
        if total_weight > 0:
            r = int(r / total_weight)
            g = int(g / total_weight)
            b = int(b / total_weight)
        else:
            r, g, b = 169, 169, 169  # Default gray

        return {"r": r, "g": g, "b": b, "hex": f"#{r:02x}{g:02x}{b:02x}"}

    def _get_motion_direction(self, emotion_map):
        """Determine motion direction from the first emotion above 0.5.

        Emotions missing from ``emotion_map`` count as 0.
        """
        if emotion_map.get('joy', 0.0) > 0.5:
            return "upward"
        elif emotion_map.get('sadness', 0.0) > 0.5:
            return "downward"
        elif emotion_map.get('anger', 0.0) > 0.5:
            return "outward"
        elif emotion_map.get('fear', 0.0) > 0.5:
            return "inward"
        else:
            return "horizontal"

    def _get_motion_pattern(self, complexity):
        """Determine motion pattern based on rhythmic complexity (thresholds 0.3 / 0.7)."""
        if complexity < 0.3:
            return "linear"
        elif complexity < 0.7:
            return "undulating"
        else:
            return "complex"

    def create_emotional_narrative_score(self, emotion_maps, text_segments):
        """Create a narrative score that maps emotions to musical directions.

        Args:
            emotion_maps: Per-segment emotion -> intensity mappings.
            text_segments: The text of each segment (paired by position).

        Returns:
            list[dict]: Per segment, a text excerpt (at most 100 characters plus
            "..."), the primary and optional secondary emotion, a musical
            direction string, and performance notes.
        """
        narrative_score = []

        for i, (emotion_map, text) in enumerate(zip(emotion_maps, text_segments)):
            # Get dominant emotions
            emotions_sorted = sorted(emotion_map.items(), key=lambda x: x[1], reverse=True)
            primary_emotion = emotions_sorted[0][0]
            primary_intensity = emotions_sorted[0][1]

            # Get secondary emotion if available and significant (> 0.25)
            secondary_emotion = None
            secondary_intensity = 0
            if len(emotions_sorted) > 1 and emotions_sorted[1][1] > 0.25:
                secondary_emotion = emotions_sorted[1][0]
                secondary_intensity = emotions_sorted[1][1]

            # Create musical direction based on emotions
            musical_direction = self._emotion_to_musical_direction(primary_emotion, primary_intensity)

            # Add modifications from secondary emotion if applicable
            if secondary_emotion:
                musical_direction += " " + self._secondary_emotion_modification(secondary_emotion, secondary_intensity)

            # Create entry
            narrative_entry = {
                "segment": i,
                "text_excerpt": text[:100] + "..." if len(text) > 100 else text,
                "primary_emotion": {"name": primary_emotion, "intensity": float(primary_intensity)},
                "secondary_emotion": {"name": secondary_emotion, "intensity": float(secondary_intensity)} if secondary_emotion else None,
                "musical_direction": musical_direction,
                "performance_notes": self._generate_performance_notes(emotion_map)
            }

            narrative_score.append(narrative_entry)

        return narrative_score

    def _emotion_to_musical_direction(self, emotion, intensity):
        """Convert an emotion and its intensity to an Italian performance direction.

        Intensity above 0.8 gives "intensely", above 0.5 gives "moderately",
        and anything else gives "slightly". Unknown emotions (including
        'neutral') get the unmodified "moderato" direction.
        """
        if intensity > 0.8:
            intensity_modifier = "intensely "
        elif intensity > 0.5:
            intensity_modifier = "moderately "
        else:
            intensity_modifier = "slightly "

        if emotion == 'joy':
            return f"{intensity_modifier}vivace (lively and brisk)"
        elif emotion == 'sadness':
            return f"{intensity_modifier}lamentoso (lamenting, sorrowful)"
        elif emotion == 'anger':
            return f"{intensity_modifier}con fuoco (with fire and passion)"
        elif emotion == 'fear':
            return f"{intensity_modifier}misterioso (mysterious, uneasy)"
        elif emotion == 'surprise':
            return f"{intensity_modifier}capriccioso (playful, unpredictable)"
        elif emotion == 'disgust':
            return f"{intensity_modifier}pesante (heavy, ponderous)"
        else:  # neutral
            return "moderato (moderate tempo with balanced expression)"

    def _secondary_emotion_modification(self, emotion, intensity):
        """Return a phrase qualifying the main direction for a secondary emotion.

        ``intensity`` is accepted for interface symmetry but currently unused.
        """
        if emotion == 'joy':
            return "with occasional bright passages"
        elif emotion == 'sadness':
            return "with moments of reflection"
        elif emotion == 'anger':
            return "with underlying tension"
        elif emotion == 'fear':
            return "with hints of uncertainty"
        elif emotion == 'surprise':
            return "with unexpected shifts"
        elif emotion == 'disgust':
            return "with moments of dissonance"
        else:  # neutral
            return "maintaining balanced phrasing"

    def _generate_performance_notes(self, emotion_map):
        """Generate specific performance notes based on emotional content.

        Emotions missing from ``emotion_map`` count as 0. If no combination
        rule fires, a single note about the most intense emotion is returned.
        """
        notes = []
        joy = emotion_map.get('joy', 0.0)
        sadness = emotion_map.get('sadness', 0.0)
        anger = emotion_map.get('anger', 0.0)
        fear = emotion_map.get('fear', 0.0)
        surprise = emotion_map.get('surprise', 0.0)
        neutral = emotion_map.get('neutral', 0.0)

        # Check for specific emotional combinations and generate appropriate notes
        if joy > 0.3 and surprise > 0.3:
            notes.append("Use rubato freely to emphasize surprising elements while maintaining a joyful character")

        if sadness > 0.3 and fear > 0.2:
            notes.append("Allow subtle dissonances to linger, emphasizing the uneasy quality beneath the sadness")

        if anger > 0.5:
            notes.append("Emphasize rhythmic accents and use more percussive articulation")

        if neutral > 0.4:
            notes.append("Maintain tonal clarity and balanced phrasing")

        if surprise > 0.5:
            notes.append("Insert brief pauses before unexpected harmonic shifts")

        # Add default note if none were generated
        if not notes:
            top_emotion = max(emotion_map.items(), key=lambda x: x[1])[0]
            intensity = emotion_map[top_emotion]

            if intensity > 0.7:
                notes.append(f"Focus primarily on expressing {top_emotion} with full emotional commitment")
            else:
                notes.append(f"Express {top_emotion} with nuance, allowing for subtle emotional shifts")

        return notes

    def generate_adaptive_composition(self, user_parameters, text):
        """Generate a composition with user-defined parameters that affect emotional mapping.

        Runs ``self.process_literary_text`` on ``text`` and then applies
        ``user_parameters`` via ``_apply_user_modifications``.
        """
        # Process text normally first
        result = self.process_literary_text(text)

        # Apply user parameters to modify the results
        modified_result = self._apply_user_modifications(result, user_parameters)

        return modified_result

    @staticmethod
    def _copy_result_for_editing(result):
        """Copy *result* deeply enough that the user modifications can edit it safely.

        Copies the top-level dict, each prompt dict and its ``musical_features``,
        and the ``midi_files`` list. Other values are shared with *result*.
        """
        copied = dict(result)
        if 'musiclm_prompts' in result:
            prompts = []
            for prompt in result['musiclm_prompts']:
                prompt_copy = dict(prompt)
                if 'musical_features' in prompt_copy:
                    prompt_copy['musical_features'] = dict(prompt_copy['musical_features'])
                prompts.append(prompt_copy)
            copied['musiclm_prompts'] = prompts
        if 'midi_files' in result:
            copied['midi_files'] = list(result['midi_files'])
        return copied

    def _regenerate_segment_midi(self, result, index):
        """Rebuild segment *index*'s MIDI buffer from its current musical features."""
        features = result['musiclm_prompts'][index]['musical_features']
        result['midi_files'][index] = self.music_generator.create_midi_from_features(features)

    def _apply_user_modifications(self, original_result, user_parameters):
        """Apply user customizations to the generated music.

        ``original_result`` is not modified. Previously a shallow
        ``dict.copy()`` let the edits below leak into the caller's nested
        prompt features and MIDI list; the structures edited here are now
        copied first.

        Supported ``user_parameters`` keys:
            tempo_range: ``(min_tempo, max_tempo)``. Each tempo is treated as
                lying in 60-180 BPM and is rescaled into this range, clamped
                to its bounds.
            preferred_instruments: Instrument names. The best match for each
                segment's dominant emotion sets its 'instrumentation' value.
            mode_preference: 'major', 'minor' or 'follow_emotion'.
            personalized_narrative / user_name: Add a personal narrative.

        Every changed segment has its MIDI regenerated.
        """
        modified_result = self._copy_result_for_editing(original_result)

        # Modify tempo range if specified
        if 'tempo_range' in user_parameters:
            min_tempo, max_tempo = user_parameters['tempo_range']
            for i, prompt_data in enumerate(modified_result['musiclm_prompts']):
                # Adjust the tempo within musical features
                original_tempo = prompt_data['musical_features']['tempo']
                # Scale the tempo to the new range
                normalized_tempo = (original_tempo - 60) / 120  # Normalize to 0-1 range
                # Clamp so tempos outside 60-180 still land inside the user's range.
                normalized_tempo = min(max(normalized_tempo, 0.0), 1.0)
                new_tempo = min_tempo + normalized_tempo * (max_tempo - min_tempo)
                prompt_data['musical_features']['tempo'] = new_tempo

                # Regenerate MIDI with new tempo
                self._regenerate_segment_midi(modified_result, i)

        # Modify instrument selection if specified
        if 'preferred_instruments' in user_parameters:
            preferred_instruments = user_parameters['preferred_instruments']
            instrument_mapping = {
                'piano': 0,
                'strings': 1,
                'guitar': 2,
                'synth': 3,
                'orchestral': 4,
                'percussion': 5
            }

            for i, prompt_data in enumerate(modified_result['musiclm_prompts']):
                # Get the emotion that would be best represented by each preferred instrument
                emotion_instrument_match = self._match_emotions_to_instruments(
                    prompt_data['emotion_map'], preferred_instruments)

                # Update the instrumentation value
                if emotion_instrument_match:
                    instrument_value = instrument_mapping.get(emotion_instrument_match, 0)
                    prompt_data['musical_features']['instrumentation'] = instrument_value

                    # Regenerate MIDI with new instrumentation
                    self._regenerate_segment_midi(modified_result, i)

        # Modify mode preference if specified
        if 'mode_preference' in user_parameters:
            mode_pref = user_parameters['mode_preference']  # 'major', 'minor', or 'follow_emotion'

            if mode_pref != 'follow_emotion':
                mode_value = 1.0 if mode_pref == 'major' else 0.0

                for i, prompt_data in enumerate(modified_result['musiclm_prompts']):
                    prompt_data['musical_features']['mode'] = mode_value

                    # Regenerate MIDI with new mode
                    self._regenerate_segment_midi(modified_result, i)

        # Add user-specific narrative if requested
        if 'personalized_narrative' in user_parameters and user_parameters['personalized_narrative']:
            user_name = user_parameters.get('user_name', 'User')
            modified_result['personalized_narrative'] = self._generate_personalized_narrative(
                original_result['emotion_progression'], user_name)

        return modified_result

    def _match_emotions_to_instruments(self, emotion_map, preferred_instruments):
        """Pick the preferred instrument with the best affinity for the dominant emotion.

        Returns the first preferred instrument if none lists the dominant
        emotion, or None when ``preferred_instruments`` is empty.
        """
        # Define emotional affinities for different instruments
        instrument_emotion_affinity = {
            'piano': ['joy', 'sadness', 'neutral'],
            'strings': ['sadness', 'fear', 'joy'],
            'guitar': ['neutral', 'sadness', 'joy'],
            'synth': ['surprise', 'fear', 'anger'],
            'orchestral': ['joy', 'fear', 'anger'],
            'percussion': ['anger', 'surprise', 'disgust']
        }

        # Find the dominant emotion
        dominant_emotion = max(emotion_map.items(), key=lambda x: x[1])[0]

        # Check which preferred instrument has the best affinity for this emotion
        best_match = None
        best_affinity = -1

        for instrument in preferred_instruments:
            affinities = instrument_emotion_affinity.get(instrument)
            if affinities and dominant_emotion in affinities:
                affinity_score = 3 - affinities.index(dominant_emotion)  # Higher score for earlier in the list
                if affinity_score > best_affinity:
                    best_affinity = affinity_score
                    best_match = instrument

        # If no good match, just use the first preferred instrument
        if not best_match and preferred_instruments:
            best_match = preferred_instruments[0]

        return best_match

    def _generate_personalized_narrative(self, emotion_progression, user_name):
        """Generate a personalized letter-style narrative about the emotional journey.

        Args:
            emotion_progression: Mapping of emotion name -> sequence of
                per-segment intensities. Emotions with no values, or with an
                average of 0.2 or less, are left out.
            user_name: Name used in the greeting.
        """
        # Identify the emotional journey
        emotional_journey = []
        for emotion, values in emotion_progression.items():
            # Skip empty sequences (would divide by zero). len() rather than
            # truthiness so array-like sequences also work.
            if len(values) == 0:
                continue
            avg_intensity = sum(values) / len(values)
            if avg_intensity > 0.2:  # Only consider emotions with significant presence
                emotional_journey.append((emotion, avg_intensity))

        # Sort by intensity
        emotional_journey.sort(key=lambda x: x[1], reverse=True)

        # Create personalized narrative
        narrative = f"Dear {user_name},\n\n"
        narrative += "Your literary journey evokes a rich emotional tapestry that I've translated into music. "

        if emotional_journey:
            top_emotion, top_intensity = emotional_journey[0]
            narrative += f"The predominant feeling is {top_emotion}, "

            if top_emotion == 'joy':
                narrative += "suggesting a bright, uplifting quality in your writing. "
            elif top_emotion == 'sadness':
                narrative += "revealing depth and poignancy in your narrative. "
            elif top_emotion == 'anger':
                narrative += "showing passionate intensity and powerful expression. "
            elif top_emotion == 'fear':
                narrative += "creating tension and anticipation throughout your text. "
            elif top_emotion == 'surprise':
                narrative += "indicating unexpected turns and creative unpredictability. "
            else:
                narrative += "creating a distinctive mood throughout your text. "

            # Add secondary emotions
            if len(emotional_journey) > 1:
                narrative += "This is beautifully complemented by undertones of "
                secondary_emotions = [e[0] for e in emotional_journey[1:3]]  # Up to 2 secondary emotions
                narrative += " and ".join(secondary_emotions) + ". "

        narrative += "\n\nThe music I've composed reflects these emotional qualities, with each segment "
        narrative += "carefully crafted to enhance the specific mood of your writing. "
        narrative += "As the narrative progresses, listen for the subtle shifts in melody, harmony, and tempo "
        narrative += "that mirror the emotional journey of your text.\n\n"
        narrative += "I hope this musical interpretation resonates with your creative vision.\n\n"
        narrative += "Musically yours,\nLiterary Music Generator"

        return narrative

    def export_full_composition(self, result, output_format="midi", output_path="literary_composition"):
        """Export the full musical composition in the specified format.

        Args:
            result: The composition result dict.
            output_format: "midi", "json" or "musicxml".
            output_path: Output directory (created if missing).

        Returns:
            dict: Paths of the written files (format-specific keys).

        Raises:
            ValueError: For an unsupported ``output_format``.
        """
        if output_format == "midi":
            return self._export_midi_composition(result, output_path)
        elif output_format == "json":
            return self._export_json_composition(result, output_path)
        elif output_format == "musicxml":
            return self._export_musicxml_composition(result, output_path)
        else:
            raise ValueError(f"Unsupported output format: {output_format}")

    def _export_midi_composition(self, result, output_path):
        """Export the composition as individual MIDI files and a combined file.

        Writes ``segment_<n>.mid`` (1-based) for each buffer in
        ``result['midi_files']``, then delegates the combined file to
        ``combine_segments`` (defined elsewhere in the notebook).
        """
        # Ensure output directory exists
        os.makedirs(output_path, exist_ok=True)

        # Save individual segment files
        segment_files = []
        for i, midi_buffer in enumerate(result['midi_files']):
            filename = f"{output_path}/segment_{i+1}.mid"
            with open(filename, 'wb') as f:
                f.write(midi_buffer.getvalue())
            segment_files.append(filename)

        # Create combined file
        combined_filename = f"{output_path}/combined_composition.mid"
        combined_file = combine_segments(result['midi_files'], combined_filename)

        return {
            "segment_files": segment_files,
            "combined_file": combined_file
        }

    def _export_json_composition(self, result, output_path):
        """Export all composition data to ``<output_path>/composition_data.json``.

        Numeric musical features and emotion progressions are converted to
        Python floats so ``json`` can serialize them.
        """
        # Ensure output directory exists
        os.makedirs(output_path, exist_ok=True)

        # Create exportable data structure
        export_data = {
            "title": "Literary Music Composition",
            "created_date": pd.Timestamp.now().strftime("%Y-%m-%d %H:%M:%S"),
            "segments": [],
            "emotion_progression": {k: [float(v) for v in vals] for k, vals in result['emotion_progression'].items()},
            "emotional_narrative": result['emotional_narrative']
        }

        # Add segment data
        for i, prompt_data in enumerate(result['musiclm_prompts']):
            segment_data = {
                "segment_number": i + 1,
                "text": result['segments'][i],
                "dominant_emotions": prompt_data['dominant_emotions'],
                "musical_features": {k: float(v) if isinstance(v, (int, float)) else v
                                    for k, v in prompt_data['musical_features'].items()},
                "musiclm_prompt": prompt_data['prompt']
            }
            export_data["segments"].append(segment_data)

        # Add transition data if available
        if result.get('transition_prompts'):
            export_data["transitions"] = result['transition_prompts']

        # Save to file (explicit encoding: don't depend on the platform default)
        output_file = f"{output_path}/composition_data.json"
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(export_data, f, indent=2)

        return {"json_file": output_file}

    def _export_musicxml_composition(self, result, output_path):
        """Export composition in MusicXML format for notation software.

        Currently writes only a skeleton score (one empty part) to
        ``<output_path>/composition.musicxml``; ``result`` is not used yet.
        """
        # Note: This is a placeholder for the MusicXML export functionality
        # In a real implementation, you would use a library like music21
        # to create MusicXML files from the MIDI data

        # Ensure output directory exists
        os.makedirs(output_path, exist_ok=True)

        # Create a mock MusicXML file for demonstration
        output_file = f"{output_path}/composition.musicxml"

        lines = [
            '<?xml version="1.0" encoding="UTF-8"?>\n',
            '<!DOCTYPE score-partwise PUBLIC "-//Recordare//DTD MusicXML 3.1 Partwise//EN" "http://www.musicxml.com/dtds/partwise.dtd">\n',
            '<score-partwise version="3.1">\n',
            '  <part-list>\n',
            '    <score-part id="P1">\n',
            '      <part-name>Literary Music Composition</part-name>\n',
            '    </score-part>\n',
            '  </part-list>\n',
            '  <part id="P1">\n',
            '    <!-- Placeholder for actual music notation -->\n',
            '  </part>\n',
            '</score-partwise>\n',
        ]
        # Encoding matches the UTF-8 declared in the XML prolog.
        with open(output_file, 'w', encoding='utf-8') as f:
            f.writelines(lines)

        return {"musicxml_file": output_file}

    def analyze_user_feedback(self, user_feedback, result):
        """Analyze user feedback to improve future compositions.

        Args:
            user_feedback: Dict with numeric ratings 'emotion_accuracy',
                'musical_quality', 'narrative_coherence' and
                'overall_satisfaction' (missing -> 0; ratings of 4 or more
                count as strengths, below 3 as weaknesses), plus optional
                free-text 'comments' (missing or None -> "").
            result: The composition result (currently unused).

        Returns:
            dict: weighted_score, strengths, areas_for_improvement and
            suggested_adjustments.
        """
        # Extract feedback metrics
        emotion_accuracy = user_feedback.get('emotion_accuracy', 0)
        musical_quality = user_feedback.get('musical_quality', 0)
        narrative_coherence = user_feedback.get('narrative_coherence', 0)
        overall_satisfaction = user_feedback.get('overall_satisfaction', 0)

        # Calculate weighted score
        weighted_score = (emotion_accuracy * 0.4 +
                          musical_quality * 0.3 +
                          narrative_coherence * 0.2 +
                          overall_satisfaction * 0.1)

        # Process specific comments (an explicit None would break .lower())
        comments = user_feedback.get('comments') or ''

        # Generate suggestions for improvement
        improvement_suggestions = self._generate_improvement_suggestions(
            weighted_score, emotion_accuracy, musical_quality, narrative_coherence, comments)

        # Create summary report
        feedback_analysis = {
            "weighted_score": weighted_score,
            "strengths": [],
            "areas_for_improvement": [],
            "suggested_adjustments": improvement_suggestions
        }

        # Identify strengths
        if emotion_accuracy >= 4:
            feedback_analysis["strengths"].append("Emotional mapping accuracy")
        if musical_quality >= 4:
            feedback_analysis["strengths"].append("Musical composition quality")
        if narrative_coherence >= 4:
            feedback_analysis["strengths"].append("Narrative coherence")

        # Identify weaknesses
        if emotion_accuracy < 3:
            feedback_analysis["areas_for_improvement"].append("Emotional mapping accuracy")
        if musical_quality < 3:
            feedback_analysis["areas_for_improvement"].append("Musical composition quality")
        if narrative_coherence < 3:
            feedback_analysis["areas_for_improvement"].append("Narrative coherence")

        return feedback_analysis

    def _generate_improvement_suggestions(self, weighted_score, emotion_accuracy,
                                         musical_quality, narrative_coherence, comments):
        """Generate specific suggestions based on ratings, comment keywords and score.

        ``comments`` may be None, which is treated as an empty string.
        """
        suggestions = []
        comments_lower = (comments or '').lower()

        # Low emotion accuracy suggestions
        if emotion_accuracy < 3:
            suggestions.append("Refine emotional analysis by adjusting the text segmentation algorithm")
            suggestions.append("Update emotion-to-music mapping with more nuanced correlations")

        # Low musical quality suggestions
        if musical_quality < 3:
            suggestions.append("Enhance musical complexity in composition algorithm")
            suggestions.append("Improve dynamic range and expressive elements")

        # Low narrative coherence suggestions
        if narrative_coherence < 3:
            suggestions.append("Strengthen transitions between segments")
            suggestions.append("Improve temporal coherence model parameters")

        # Check comments for specific keywords
        if 'repetitive' in comments_lower:
            suggestions.append("Increase variation in melodic and harmonic patterns")

        if 'disjointed' in comments_lower or 'disconnected' in comments_lower:
            suggestions.append("Apply smoother transition algorithms between emotional segments")

        if 'simple' in comments_lower or 'basic' in comments_lower:
            suggestions.append("Increase complexity parameters for harmony and rhythm")

        # Add general suggestions based on overall score
        if weighted_score < 2:
            suggestions.append("Consider fundamental revision of the emotion-to-music mapping algorithm")
        elif weighted_score < 3:
            suggestions.append("Fine-tune parameters across multiple dimensions for better coherence")
        elif weighted_score < 4:
            suggestions.append("Make minor adjustments to enhance specific areas of weakness")

        return suggestions

In [None]:
#

In [None]:
#

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import BertModel, BertTokenizer
from tqdm import tqdm


class MuLANTextEncoder(nn.Module):
    """Text encoder: frozen BERT [CLS] features + learned music-term embeddings.

    ``forward`` returns a [batch, embedding_dim] tensor: BERT's [CLS] vector
    projected to ``embedding_dim``, plus 0.3 x the mean embedding of every
    whole-word music term found in the (lower-cased) input text.
    """
    def __init__(self, bert_model="bert-base-uncased", embedding_dim=512):
        super(MuLANTextEncoder, self).__init__()
        # Load pre-trained BERT model and its tokenizer
        self.bert = BertModel.from_pretrained(bert_model)
        self.tokenizer = BertTokenizer.from_pretrained(bert_model)

        # Project BERT outputs to our embedding space
        self.projection = nn.Linear(self.bert.config.hidden_size, embedding_dim)

        # Learned embeddings for music-specific terms; the table has 1000 rows,
        # of which only the first len(self.music_vocab) are indexed.
        self.music_token_embedding = nn.Embedding(1000, embedding_dim)
        self.music_vocab = self._create_music_vocab()

    def _create_music_vocab(self):
        """Return a dict mapping each music term (lower case) to its embedding row index."""
        music_terms = [
            "tempo", "rhythm", "melody", "harmony", "bass", "treble",
            "major", "minor", "piano", "guitar", "drums", "strings",
            "loud", "soft", "fast", "slow", "staccato", "legato",
            # Emotions
            "happy", "sad", "angry", "fearful", "tender", "excited",
            # Instruments
            "violin", "cello", "flute", "trumpet", "saxophone", "harp",
            "synthesizer", "electric guitar", "acoustic guitar", "bass guitar",
            # Musical styles
            "classical", "jazz", "rock", "electronic", "ambient", "folk",
            "pop", "hip hop", "blues", "country", "metal", "funk",
            # Descriptive terms
            "bright", "dark", "warm", "cold", "mellow", "harsh",
            "smooth", "rough", "ethereal", "gritty", "lush", "sparse",
            "uplifting", "melancholic", "energetic", "calm", "tense", "relaxed"
        ]
        return {term: i for i, term in enumerate(music_terms)}

    def forward(self, text):
        """Encode a string or list of strings into [batch, embedding_dim] embeddings.

        BERT runs under ``torch.no_grad()``, so it is frozen here; only
        ``projection`` and ``music_token_embedding`` receive gradients.
        """
        # Tokenize input (truncated to BERT's 512-token limit)
        tokens = self.tokenizer(text, return_tensors="pt",
                               padding=True, truncation=True, max_length=512)

        # Move to the same device as the model
        tokens = {k: v.to(self.bert.device) for k, v in tokens.items()}

        # Get BERT embeddings (frozen)
        with torch.no_grad():
            outputs = self.bert(**tokens)

        # Use the [CLS] token embedding as sequence representation
        sequence_embedding = outputs.last_hidden_state[:, 0, :]

        # Project to our embedding space
        projected_embedding = self.projection(sequence_embedding)

        # Add music-term information
        music_embedding = self._enhance_music_terms(text, projected_embedding)

        return music_embedding

    def _enhance_music_terms(self, text, embedding):
        """Add 0.3 x the mean embedding of the music terms present in each text.

        Args:
            text: a string or a list of strings; row ``i`` of ``embedding``
                corresponds to ``text[i]``.
            embedding: tensor of shape [batch, embedding_dim].

        Returns:
            Tensor of shape [len(text), embedding_dim]. Rows whose text contains
            no vocabulary term are returned unchanged.
        """
        import re

        if isinstance(text, str):
            text = [text]  # Convert to list for consistent processing

        # Whole-word patterns compiled once per call, not once per (text, term).
        # \b boundaries stop e.g. "rock" from matching inside "rockets".
        term_patterns = [
            (idx, re.compile(r'\b' + re.escape(term) + r'\b'))
            for term, idx in self.music_vocab.items()
        ]

        batch_enhanced = []
        for i, t in enumerate(text):
            t_lower = t.lower()
            hit_indices = [idx for idx, pattern in term_patterns if pattern.search(t_lower)]

            if hit_indices:
                idx_tensor = torch.tensor(hit_indices, device=embedding.device)
                music_contrib = self.music_token_embedding(idx_tensor).mean(dim=0)
                enhanced = embedding[i] + music_contrib * 0.3  # 30% contribution
            else:
                enhanced = embedding[i]

            batch_enhanced.append(enhanced)

        return torch.stack(batch_enhanced)


class SoundStreamEncoder(nn.Module):
    """SoundStream-inspired waveform encoder with split vector quantization.

    Four stride-2 convolutions downsample time by 16. The time-pooled feature is
    projected to ``embedding_dim`` and split into ``num_quantizers`` equal chunks.
    Each chunk is snapped to the nearest entry of its own codebook.

    Note: the chunks are quantized independently (product quantization); there
    is no residual chaining between quantizers.

    After each ``forward`` call, ``self.last_vq_loss`` holds the summed
    commitment + codebook loss over all quantizers. Add it to the training loss;
    otherwise the codebooks never receive gradient, because the straight-through
    estimator routes gradients only to the encoder.
    """
    def __init__(self, input_channels=1, embedding_dim=512):
        super(SoundStreamEncoder, self).__init__()

        # Convolutional encoder
        self.encoder = nn.Sequential(
            # Initial convolution (length preserving)
            nn.Conv1d(input_channels, 32, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),

            # Downsampling convolutions
            nn.Conv1d(32, 64, kernel_size=4, stride=2, padding=1),  # Downsample 2x
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=4, stride=2, padding=1),  # Downsample 2x
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=4, stride=2, padding=1),  # Downsample 2x
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Conv1d(256, 512, kernel_size=4, stride=2, padding=1),  # Downsample 2x
            nn.BatchNorm1d(512),
            nn.ReLU(),

            # Additional convolutions
            nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
        )

        # Projection to embedding space
        self.projection = nn.Linear(512, embedding_dim)

        # Single full-width codebook. forward() does not use it; it is kept so
        # existing state_dicts/checkpoints still load.
        self.codebook_size = 1024
        self.codebook = nn.Embedding(self.codebook_size, embedding_dim)

        # One codebook per embedding chunk
        self.num_quantizers = 8
        if embedding_dim % self.num_quantizers != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by "
                f"num_quantizers ({self.num_quantizers})"
            )
        self.codebook_dim = embedding_dim // self.num_quantizers
        self.codebooks = nn.ModuleList([
            nn.Embedding(self.codebook_size, self.codebook_dim)
            for _ in range(self.num_quantizers)
        ])

        # Summed VQ loss from the most recent forward() (None before the first call).
        self.last_vq_loss = None

    def vector_quantize(self, x, codebook):
        """Snap each row of ``x`` to its nearest codebook entry.

        Args:
            x: tensor of shape [batch_size, codebook_dim].
            codebook: ``nn.Embedding`` whose rows are the code vectors.

        Returns:
            (quantized_st, indices, vq_loss):
            - ``quantized_st``: nearest code vectors, carrying a straight-through
              gradient to ``x``.
            - ``indices``: chosen codebook rows, shape [batch_size].
            - ``vq_loss``: commitment loss + codebook loss (scalar).
        """
        # Squared L2 distance via ||x||^2 + ||c||^2 - 2 x.c
        x_norm = torch.sum(x**2, dim=1, keepdim=True)
        codebook_norm = torch.sum(codebook.weight**2, dim=1)
        dot_product = torch.matmul(x, codebook.weight.t())
        d = x_norm + codebook_norm - 2 * dot_product

        # The expanded form can go slightly negative from rounding; clamp it.
        d = torch.clamp(d, min=1e-5)

        # Find nearest neighbor
        indices = torch.argmin(d, dim=1)
        quantized = codebook(indices)

        # Straight-through estimator: forward value = quantized, gradient flows to x
        quantized_st = x + (quantized - x).detach()

        # VQ loss components for training
        commitment_loss = F.mse_loss(x, quantized.detach())
        codebook_loss = F.mse_loss(quantized, x.detach())

        return quantized_st, indices, commitment_loss + codebook_loss

    def forward(self, x):
        """Encode audio into a quantized global embedding plus temporal features.

        Args:
            x: waveform tensor of shape [batch_size, input_channels, time].

        Returns:
            (quantized, encoded):
            - ``quantized``: [batch_size, embedding_dim].
            - ``encoded``: the conv features, [batch_size, 512, ~time/16].
        """
        encoded = self.encoder(x)

        # Global pooling over time to get a fixed-size representation
        pooled = F.adaptive_avg_pool1d(encoded, 1).squeeze(-1)  # [batch_size, 512]
        embedding = self.projection(pooled)  # [batch_size, embedding_dim]

        # Split into num_quantizers chunks of codebook_dim each
        batch_size = embedding.shape[0]
        chunks = embedding.view(batch_size, self.num_quantizers, self.codebook_dim)

        quantized_list = []
        vq_loss = embedding.new_zeros(())
        for i, codebook in enumerate(self.codebooks):
            q, _, loss = self.vector_quantize(chunks[:, i], codebook)
            quantized_list.append(q)
            vq_loss = vq_loss + loss
        self.last_vq_loss = vq_loss

        quantized = torch.cat(quantized_list, dim=1)  # [batch_size, embedding_dim]

        return quantized, encoded


class SoundStreamDecoder(nn.Module):
    """SoundStream-inspired decoder: a global embedding to a waveform.

    Four stride-2 transposed convolutions upsample time by 16 (mirroring
    ``SoundStreamEncoder``). The final ``Tanh`` keeps samples in [-1, 1].
    """
    # Total temporal upsampling of the four stride-2 ConvTranspose1d layers.
    upsample_factor = 16

    def __init__(self, embedding_dim=512, output_channels=1):
        super(SoundStreamDecoder, self).__init__()

        # Project embedding to the right dimension for the decoder
        self.pre_decoder = nn.Linear(embedding_dim, 512)

        # Convolutional decoder
        self.decoder = nn.Sequential(
            # Initial convolution
            nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            # Upsampling convolutions using transposed convolutions
            nn.ConvTranspose1d(512, 256, kernel_size=4, stride=2, padding=1),  # Upsample 2x
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.ConvTranspose1d(256, 128, kernel_size=4, stride=2, padding=1),  # Upsample 2x
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.ConvTranspose1d(128, 64, kernel_size=4, stride=2, padding=1),   # Upsample 2x
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.ConvTranspose1d(64, 32, kernel_size=4, stride=2, padding=1),    # Upsample 2x
            nn.BatchNorm1d(32),
            nn.ReLU(),

            # Final convolution to get to the right number of channels
            nn.Conv1d(32, output_channels, kernel_size=7, stride=1, padding=3),
            nn.Tanh()  # Output in [-1, 1] range for audio
        )

    def forward(self, z, encoded=None, length=16000):
        """Decode a latent vector to audio.

        Args:
            z: [batch_size, embedding_dim] latent.
            encoded: optional [batch_size, 512, T] encoder features. When given,
                they modulate ``z`` over time and the output has 16 * T samples;
                ``length`` is ignored on this path.
            length: number of output samples when ``encoded`` is None. Any
                positive value is supported; the output is generated at the next
                multiple of 16 and then cropped.

        Returns:
            Tensor of shape [batch_size, output_channels, samples].
        """
        z_proj = self.pre_decoder(z)  # [batch_size, 512]

        if encoded is not None:
            # Broadcast the global vector over the encoder's time axis
            z_temporal = z_proj.unsqueeze(-1) * encoded
        else:
            # Ceil division so that the decoded length is >= `length`
            time_steps = -(-length // self.upsample_factor)
            z_temporal = z_proj.unsqueeze(-1).repeat(1, 1, time_steps)

        decoded = self.decoder(z_temporal)  # [batch_size, output_channels, time]

        if encoded is None:
            decoded = decoded[..., :length]

        return decoded


class AttentionBlock(nn.Module):
    """Single-head self-attention over the time axis with a residual connection.

    Input and output shape: [batch, channels, length]. The attention matrix is
    [batch, length, length], so memory grows quadratically with ``length``.
    ``channels`` must be a multiple of 8 (GroupNorm with 8 groups).
    """
    def __init__(self, channels):
        super(AttentionBlock, self).__init__()
        self.channels = channels

        self.norm = nn.GroupNorm(num_groups=8, num_channels=channels)
        self.qkv_proj = nn.Conv1d(channels, channels * 3, kernel_size=1)
        self.out_proj = nn.Conv1d(channels, channels, kernel_size=1)

    def forward(self, x):
        batch, channels, length = x.shape
        if channels != self.channels:
            raise ValueError(f"AttentionBlock expected {self.channels} channels, got {channels}")

        h = self.norm(x)

        # Split the 3*C projection into query/key/value, each [batch, C, length]
        qkv = self.qkv_proj(h)
        qkv = qkv.reshape(batch, 3, channels, length)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]

        # attention[b, t, s] = weight of source step s for target step t
        scale = 1.0 / (channels ** 0.5)
        attention = torch.einsum('bct,bcs->bts', q, k) * scale
        attention = F.softmax(attention, dim=2)

        output = torch.einsum('bts,bcs->bct', attention, v)
        output = self.out_proj(output)

        # Residual connection
        return x + output


class TextConditionedResidualBlock(nn.Module):
    """Two-conv residual block whose hidden activation is shifted by a text embedding.

    The text vector is linearly projected to ``out_channels`` and added, constant
    over time, after the first convolution. ``in_channels`` and ``out_channels``
    must be multiples of 8 (GroupNorm with 8 groups). Time length is preserved.
    """
    def __init__(self, in_channels, out_channels, text_channels, dropout=0.1):
        super(TextConditionedResidualBlock, self).__init__()

        self.norm1 = nn.GroupNorm(num_groups=8, num_channels=in_channels)
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)

        self.norm2 = nn.GroupNorm(num_groups=8, num_channels=out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)

        # Text conditioning
        self.text_proj = nn.Linear(text_channels, out_channels)

        # 1x1 conv shortcut only when the channel count changes
        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, text_embed):
        """
        x: [batch_size, in_channels, time]
        text_embed: [batch_size, text_channels]
        returns: [batch_size, out_channels, time]
        """
        h = self.norm1(x)
        h = F.silu(h)
        h = self.conv1(h)

        # Add the projected text vector at every time step
        text_embed = self.text_proj(text_embed).unsqueeze(-1)
        text_embed = text_embed.expand(-1, -1, h.size(-1))
        h = h + text_embed

        h = self.norm2(h)
        h = F.silu(h)
        h = self.dropout(h)
        h = self.conv2(h)

        return h + self.shortcut(x)


class MusicConditionedUNet(nn.Module):
    """1-D UNet mapping [batch, input_channels, T] to [batch, output_channels, T], text-conditioned.

    Encoder: four stages of (ResBlock, ResBlock, stride-2 conv). The activation
    just before each stride-2 conv is stored as a skip connection.
    Decoder: four stages of (stride-2 transposed conv, crop/pad to the skip's
    length, concatenate the skip, ResBlock, ResBlock). Every concatenation
    therefore joins tensors of equal temporal length, and the output length
    equals the input length for any T >= 16.

    ``base_channels`` must be a multiple of 8 (GroupNorm with 8 groups).
    """
    def __init__(self, input_channels=1, output_channels=1, base_channels=32, embedding_dim=512):
        super(MusicConditionedUNet, self).__init__()
        c1, c2, c4, c8 = base_channels, base_channels * 2, base_channels * 4, base_channels * 8

        # Initial convolution
        self.initial_conv = nn.Conv1d(input_channels, c1, kernel_size=3, padding=1)

        # Encoder (downsampling) blocks; each block's output channels = its skip channels
        self.down1 = self._make_down_block(c1, c2, embedding_dim)
        self.down2 = self._make_down_block(c2, c4, embedding_dim)
        self.down3 = self._make_down_block(c4, c8, embedding_dim)
        self.down4 = self._make_down_block(c8, c8, embedding_dim)

        # Middle blocks (lowest resolution, T/16)
        self.mid_resblock1 = TextConditionedResidualBlock(c8, c8, embedding_dim)
        self.mid_attn = AttentionBlock(c8)
        self.mid_resblock2 = TextConditionedResidualBlock(c8, c8, embedding_dim)

        # Decoder (upsampling) blocks: (incoming channels, skip channels, output channels)
        self.up4 = self._make_up_block(c8, c8, c8, embedding_dim)
        self.up3 = self._make_up_block(c8, c8, c4, embedding_dim)
        self.up2 = self._make_up_block(c4, c4, c2, embedding_dim)
        self.up1 = self._make_up_block(c2, c2, c1, embedding_dim)

        # Final layers
        self.final_norm = nn.GroupNorm(num_groups=8, num_channels=c1)
        self.final_conv = nn.Conv1d(c1, output_channels, kernel_size=3, padding=1)

    def _make_down_block(self, in_channels, out_channels, text_channels):
        """[ResBlock(in->in), ResBlock(in->out), stride-2 Conv1d (halves time)]."""
        return nn.ModuleList([
            TextConditionedResidualBlock(in_channels, in_channels, text_channels),
            TextConditionedResidualBlock(in_channels, out_channels, text_channels),
            nn.Conv1d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)  # Downsample
        ])

    def _make_up_block(self, in_channels, skip_channels, out_channels, text_channels):
        """[stride-2 ConvTranspose1d (doubles time), ResBlock(in+skip->out), ResBlock(out->out)]."""
        return nn.ModuleList([
            nn.ConvTranspose1d(in_channels, in_channels, kernel_size=4, stride=2, padding=1),  # Upsample
            TextConditionedResidualBlock(in_channels + skip_channels, out_channels, text_channels),
            TextConditionedResidualBlock(out_channels, out_channels, text_channels),
        ])

    @staticmethod
    def _match_length(h, length):
        """Right-pad with zeros or crop ``h`` along time so that ``h.size(-1) == length``."""
        diff = length - h.size(-1)
        if diff > 0:
            return F.pad(h, (0, diff))
        if diff < 0:
            return h[..., :length]
        return h

    def forward(self, x, text_embedding):
        """
        x: [batch_size, input_channels, time] - audio or noise, time >= 16
        text_embedding: [batch_size, embedding_dim]
        returns: [batch_size, output_channels, time]
        """
        h = self.initial_conv(x)

        # Downsampling path; keep the pre-downsample activation of each stage
        skips = []
        for res_a, res_b, downsample in (self.down1, self.down2, self.down3, self.down4):
            h = res_a(h, text_embedding)
            h = res_b(h, text_embedding)
            skips.append(h)
            h = downsample(h)

        # Middle
        h = self.mid_resblock1(h, text_embedding)
        h = self.mid_attn(h)
        h = self.mid_resblock2(h, text_embedding)

        # Upsampling path: upsample first, so h matches the skip's resolution
        for upsample, res_a, res_b in (self.up4, self.up3, self.up2, self.up1):
            skip = skips.pop()
            # Odd lengths lose a sample in the stride-2 conv; realign before concat
            h = self._match_length(upsample(h), skip.size(-1))
            h = torch.cat([h, skip], dim=1)
            h = res_a(h, text_embedding)
            h = res_b(h, text_embedding)

        # Final layers
        h = self.final_norm(h)
        h = F.silu(h)
        h = self.final_conv(h)

        return h

class TextToMusicGenerationModel(nn.Module):
    """Full text-to-music model combining all components.

    - ``text_encoder`` (MuLANTextEncoder): text -> [batch, embedding_dim].
    - ``audio_encoder`` / ``audio_decoder`` (SoundStream-style): reconstruction path.
    - ``unet`` (MusicConditionedUNet): text-conditioned refinement of audio-shaped tensors.
    """
    def __init__(self, embedding_dim=512):
        super(TextToMusicGenerationModel, self).__init__()

        # Text encoder
        self.text_encoder = MuLANTextEncoder(embedding_dim=embedding_dim)

        # Audio encoder for conditioning
        self.audio_encoder = SoundStreamEncoder(embedding_dim=embedding_dim)

        # Audio decoder for reconstruction
        self.audio_decoder = SoundStreamDecoder(embedding_dim=embedding_dim)

        # UNet for high-resolution generation
        self.unet = MusicConditionedUNet(embedding_dim=embedding_dim)

    def encode_text(self, text):
        """Encode a text description (string or list of strings)."""
        return self.text_encoder(text)

    def encode_audio(self, audio):
        """Encode audio; returns ``(quantized, encoded)`` from the audio encoder."""
        return self.audio_encoder(audio)

    def decode_audio(self, z, encoded=None, length=16000):
        """Decode a latent representation to audio."""
        return self.audio_decoder(z, encoded, length)

    def generate_from_text(self, text, noise_level=0.5, steps=50, length=16000):
        """Generate audio from text by iteratively blending UNet outputs into noise.

        Runs under ``torch.no_grad()``: generation is inference-only, and with
        autograd enabled every one of the ``steps`` UNet passes would be kept
        alive in the graph. Dropout still follows the module's train/eval mode,
        so call ``model.eval()`` first for deterministic refinement.

        Args:
            text: a prompt string, or a list of prompts (one clip per prompt).
            noise_level: std of the initial Gaussian noise, and the blend
                weight given to the UNet output at step 0 (it decays linearly).
            steps: number of refinement iterations.
            length: number of samples per generated clip.

        Returns:
            Tensor of shape [num_prompts, 1, length], clamped to [-1, 1].
        """
        from tqdm.notebook import tqdm

        print(f"Starting generation for text: '{text}'")

        with torch.no_grad():
            text_embedding = self.text_encoder(text)
            print(f"Text encoded, embedding shape: {text_embedding.shape}")

            # One noise clip per prompt, so the batch matches text_embedding
            device = next(self.parameters()).device
            audio = torch.randn(text_embedding.size(0), 1, length).to(device) * noise_level
            print(f"Initial noise created with shape: {audio.shape}")

            for i in tqdm(range(steps)):
                # Weight on the UNet output, decaying linearly from noise_level
                noise_scale = noise_level * (1.0 - i / steps)

                update = self.unet(audio, text_embedding)
                audio = audio * (1.0 - noise_scale) + update * noise_scale

                # Keep samples in the valid audio range
                audio = torch.clamp(audio, -1.0, 1.0)

                if i % 10 == 0:
                    print(f"Step {i}/{steps}, audio range: {audio.min().item():.3f} to {audio.max().item():.3f}")

        print("Generation complete!")
        return audio

    def forward(self, audio_input=None, text_input=None, mode='train'):
        """
        Forward pass through the model.

        Modes:
        - 'train': needs both inputs; returns a dict with keys "reconstructed",
          "enhanced", "quantized", "encoded", "text_embedding"
        - 'encode_text': only encode text
        - 'encode_audio': only encode audio
        - 'generate': generate audio from text with default parameters

        Raises:
            ValueError: unknown mode, or a mode whose required input is missing.
        """
        if mode == 'encode_text' and text_input is not None:
            return self.encode_text(text_input)

        elif mode == 'encode_audio' and audio_input is not None:
            return self.encode_audio(audio_input)

        elif mode == 'generate' and text_input is not None:
            return self.generate_from_text(text_input)

        elif mode == 'train':
            if audio_input is None or text_input is None:
                raise ValueError("Both audio_input and text_input are required for training")

            text_embedding = self.encode_text(text_input)
            quantized, encoded = self.encode_audio(audio_input)

            # Reconstruction path
            reconstructed = self.decode_audio(quantized, encoded, audio_input.shape[-1])

            # Enhancement path: UNet refines the reconstruction given the text
            enhanced = self.unet(reconstructed, text_embedding)

            return {
                "reconstructed": reconstructed,
                "enhanced": enhanced,
                "quantized": quantized,
                "encoded": encoded,
                "text_embedding": text_embedding
            }

        else:
            raise ValueError(f"Unknown mode: {mode}")

# 5. Define training functions for the model
def train_step(model, audio_batch, text_batch, optimizer, device, lambda_rec=1.0, lambda_enhance=0.5):
    """Run one optimisation step and return the scalar losses.

    Both the reconstruction and the enhanced outputs are regressed (MSE) onto
    the input waveform. The weighted sum is back-propagated, gradients are
    clipped to a global L2 norm of 1.0, and the optimizer steps.

    Returns:
        dict with float values for "loss", "rec_loss" and "enhance_loss".
    """
    target = audio_batch.to(device)
    predictions = model(target, text_batch, mode='train')

    losses = {
        "rec_loss": F.mse_loss(predictions["reconstructed"], target),
        "enhance_loss": F.mse_loss(predictions["enhanced"], target),
    }
    total = lambda_rec * losses["rec_loss"] + lambda_enhance * losses["enhance_loss"]

    optimizer.zero_grad()
    total.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    return {"loss": total.item(), **{name: value.item() for name, value in losses.items()}}

# 6. Example usage of the model
def example_usage(dataloader=None, num_epochs=1, device=None, save_fn=None):
    """End-to-end demo: optionally train the model, then generate one clip.

    Args:
        dataloader: iterable of ``(audio_batch, text_batch)`` pairs. Training
            is skipped when None.
        num_epochs: number of passes over ``dataloader``.
        device: ``torch.device`` to run on; defaults to CUDA when available,
            else CPU.
        save_fn: optional ``callable(array, path)`` that persists the generated
            audio (e.g. a wrapper around ``scipy.io.wavfile.write``). Saving is
            skipped when None.

    Returns:
        The generated audio tensor from ``generate_from_text``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = TextToMusicGenerationModel(embedding_dim=512)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    if dataloader is not None:
        for epoch in range(num_epochs):
            for audio_batch, text_batch in dataloader:
                metrics = train_step(model, audio_batch, text_batch, optimizer, device)
                print(f"Epoch {epoch}, Loss: {metrics['loss']:.4f}")

    text_prompt = "A calm classical piano piece with a melancholic melody"
    generated_audio = model.generate_from_text(text_prompt, steps=100)

    if save_fn is not None:
        save_fn(generated_audio.cpu().numpy(), "generated_music.wav")

    return generated_audio

# Add this at the top of your notebook
import IPython.display as ipd
from scipy.io import wavfile
import numpy as np
import matplotlib.pyplot as plt

# After generating your audio
def display_audio_in_colab(audio_tensor, sample_rate=16000, filename="generated_audio.wav"):
    """Plot a waveform, save it as 16-bit PCM WAV, and return a notebook audio player.

    Args:
        audio_tensor: tensor of samples; singleton dimensions are squeezed away
            and values are clipped to [-1, 1].
        sample_rate: sample rate used for both the WAV file and the player.
        filename: path of the WAV file to write.

    Returns:
        ``IPython.display.Audio`` built from the clipped float samples.
    """
    samples = np.clip(audio_tensor.squeeze().detach().cpu().numpy(), -1.0, 1.0)

    plt.figure(figsize=(10, 4))
    plt.plot(samples)
    plt.title("Generated Audio Waveform")
    plt.xlabel("Time (samples)")
    plt.ylabel("Amplitude")
    plt.grid(True)
    plt.show()

    pcm16 = (samples * 32767).astype(np.int16)
    wavfile.write(filename, sample_rate, pcm16)

    print("Generated audio playback:")
    return ipd.Audio(samples, rate=sample_rate)

# Demo: build the model on the best available device, generate one clip
# for a prompt, then plot/save/play it. `device` and `model` are left as
# module-level names for later cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TextToMusicGenerationModel(embedding_dim=512).to(device)

text_prompt = "A calm piano melody with soft strings"
generated_audio = model.generate_from_text(text_prompt)
display_audio_in_colab(generated_audio)

Starting generation for text: 'A calm piano melody with soft strings'
Text encoded, embedding shape: torch.Size([1, 512])
Initial noise created with shape: torch.Size([1, 1, 16000])


  0%|          | 0/50 [00:00<?, ?it/s]

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 1000 but got size 2000 for tensor number 1 in the list.