In [1]:
import torch
from torch import nn
from tqdm import tqdm
from pathlib import Path
from audiocraft.models import MusicGen
import argparse
import random
import numpy as np

SEED = 1

random.seed(SEED)
np.random.seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.set_num_threads(1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
music_model = MusicGen.get_pretrained("facebook/musicgen-small", device=device)
music_model.set_generation_params(duration=4, use_sampling=False)
music_model.compression_model.eval()
music_model.lm.eval()
print(f"Transformer layers: {len(music_model.lm.transformer.layers)}")



Transformer layers: 24


In [3]:
display(music_model.lm)

LMModel(
  (cfg_dropout): ClassifierFreeGuidanceDropout(p=0.3)
  (att_dropout): AttributeDropout({})
  (condition_provider): ConditioningProvider(
    (conditioners): ModuleDict(
      (description): T5Conditioner(
        (output_proj): Linear(in_features=768, out_features=1024, bias=True)
      )
    )
  )
  (fuser): ConditionFuser()
  (emb): ModuleList(
    (0-3): 4 x ScaledEmbedding(2049, 1024)
  )
  (transformer): StreamingTransformer(
    (layers): ModuleList(
      (0-23): 24 x StreamingTransformerLayer(
        (self_attn): StreamingMultiheadAttention(
          (out_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        (linear1): Linear(in_features=1024, out_features=4096, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
        (linear2): Linear(in_features=4096, out_features=1024, bias=False)
        (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  

In [4]:
out_norm_layer = music_model.lm.out_norm
print(f"Output norm: {out_norm_layer}")

Output norm: LayerNorm((1024,), eps=1e-05, elementwise_affine=True)


In [5]:
linear_layers = music_model.lm.linears
print(f"Output norm: {linear_layers}")

Output norm: ModuleList(
  (0-3): 4 x Linear(in_features=1024, out_features=2048, bias=False)
)


In [6]:
from typing import Optional, Tuple, Literal, get_args

MUSICGEN_SAMPLE_RATE, PARLER_TTS_SAMPLE_RATE = 32_000, 44_100
_AUDIO_TYPES = Literal["music", "speech"]

def display_audio(samples: torch.Tensor, sample_rate: int = None, audio_type: str = None) -> None:
    """Renders an audio player for the given audio samples.

    Args:
        samples (torch.Tensor): a Tensor of decoded audio samples
            with shapes [B, C, T] or [C, T]
        sample_rate (int): sample rate audio should be displayed with.
        audio_type (str): choose music or speech to set sample_rate to model
            default sample rates.
    """
    audio_options = get_args(_AUDIO_TYPES)
    if audio_type != None:
      assert audio_type in audio_options, f"'{audio_type}' is not in {audio_options}"
    assert (sample_rate != None) != (audio_type != None), \
      f" either sample_rate or audio_type has to be set but not both or neither, \
      sample_rate: '{sample_rate}' audio_type: '{audio_type}'"
    assert samples.dim() == 2 or samples.dim() == 3


    if audio_type == 'music':
      sample_rate = MUSICGEN_SAMPLE_RATE
    elif audio_type == 'speech':
      sample_rate = PARLER_TTS_SAMPLE_RATE

    samples = samples.detach().cpu()
    if samples.dim() == 2:
        samples = samples[None, ...]

    for audio in samples:
        ipd.display(ipd.Audio(audio, rate=sample_rate))

def display_music(samples: torch.Tensor) -> None:
  display_audio(samples, audio_type='music')

def display_speech(samples: torch.Tensor) -> None:
  display_audio(samples, audio_type='speech')

def display_wav_file(paths: str|list) -> None:
  if isinstance(paths, list):
    for path in paths:
      ipd.display(ipd.Audio(path))
  else:
    ipd.display(ipd.Audio(paths))

In [7]:
import torch.nn.functional as F
import IPython.display as ipd

x = torch.randn(2, 1024)

def top_k_sample(logits, k=10):
    values, indices = torch.topk(logits, k, dim=-1)
    probs = F.softmax(values, dim=-1)
    idx = torch.multinomial(probs, 1).squeeze(-1)
    return torch.gather(indices, -1, idx.unsqueeze(-1)).squeeze(-1)


with torch.no_grad():
    norm_out = music_model.lm.out_norm(x)
    print(f"Norm output shape: {norm_out.shape}")
    logits_out = [linear(norm_out) for linear in music_model.lm.linears]
    print(logits_out[0].shape)


    pred_tokens = [top_k_sample(logit) for logit in logits_out]
    
    codes = torch.stack(pred_tokens, dim=0).unsqueeze(0)
    waveform = music_model.compression_model.decode(codes)

display_audio(waveform, sample_rate=32000)

Norm output shape: torch.Size([2, 1024])
torch.Size([2, 2048])


## Baseline evaluation

## Our model evaluation

In [8]:
from activations_dataset import ActivationsDataset
from mlp import MLP

mlp_model = MLP(
    input_dim=1024,
    hidden_dim=2048,
    output_dim=1024,
    dropout=0.2
)

instruments = ['guitar', 'piano', 'trumpet', 'violin']
instruments = ['piano']

for layer in range(len(music_model.lm.transformer.layers) - 1):
    for instrument in instruments:
        print(f"Testing layer {layer} for instrument {instrument}")
        
        # Create the dataset for the current layer and instrument
        test_set = ActivationsDataset(
            data_dir='/home/scur1188/ai-intepr-project/data',
            instruments=[instrument],
            seeds=[1, 2, 3],
            split="test",
            layer=layer,
        )

        pt_file = torch.load(
            f"/home/scur1188/ai-intepr-project/weights/layer_{layer:02d}_{instrument}_seeds-1-2-3_mlp.pt", map_location=device)
        mlp_model.load_state_dict(pt_file['model_state_dict'])
        mlp_model.eval()
        mlp_model.to(device)
        music_model.lm.eval()
        music_model.lm.to(device)
        music_model.compression_model.to(device)
        music_model.compression_model.eval()
        break

        for i in tqdm(range(len(test_set))):
            x, y = test_set[i]
            x, y = x.to(device), y.to(device)
            x = x.unsqueeze_(0)  # Add batch dimension

            with torch.no_grad():
                # Get the output from the MLP
                output = mlp_model(x)

                # Pass the input through the MusicGen model
                norm_out = music_model.lm.out_norm(x)
                logits_out = [linear(norm_out) for linear in music_model.lm.linears]
                pred_tokens = [top_k_sample(logit) for logit in logits_out]

                codes = torch.stack(pred_tokens, dim=0).unsqueeze(0)
                waveform = music_model.compression_model.decode(codes)

                # Display the waveform
                display_audio(waveform, sample_rate=32000)
                break

Testing layer 0 for instrument piano


FileNotFoundError: [Errno 2] No such file or directory: '/home/scur1188/ai-intepr-project/weights/layer_00_piano_seeds-1-2-3_mlp.pt'

In [None]:
from transformers import AutoProcessor

# load data as prompts
x = 'Compose a happy classical song with a piano melody. Use a fast tempo.'

mlp_model = MLP(
    input_dim=1024,
    hidden_dim=2048,
    output_dim=1024,
    dropout=0.1
)
pt_file = torch.load(
            f"/home/scur1188/ai-intepr-project/weights/layer_20_piano_seeds-1-2-3_mlp.pt", map_location=device)
mlp_model.load_state_dict(pt_file['model_state_dict'])
mlp_model.eval()
mlp_model.to(device)
music_model = MusicGen.get_pretrained("facebook/musicgen-small", device=device)
music_model.set_generation_params(duration=4)
music_model.compression_model.eval()
music_model.lm.eval()

original_first = music_model.lm.transformer.layers[:20]
music_model.lm.transformer.layers = nn.ModuleList([*original_first, mlp_model])
audio = music_model.generate([x])

music_model.lm.transformer.layers = nn.ModuleList([*original_first])
audio_decoder_lens = music_model.generate([x])



In [None]:
display_music(audio)
display_music(audio_decoder_lens)

In [None]:
from hear21passt.base import load_model

music_classifier = load_model(mode="logits").to(device)

In [None]:
piano_logit_index = 153
guitar_logit_index = 140
trumpet_logit_index = 187
violin_fiddle_logit_index = 191

logits = music_classifier(audio)

piano_logit = logits[piano_logit_index]

In [None]:
print(f"Piano logit: {piano_logit.item()}")