In [1]:
import matplotlib
import numpy as np
from IPython.display import HTML, display


def visualize_sentences(sentences, weight_lists, cmap_name="OrRd"):
    """
    Visualize multiple sentences with token weights in a list format, each with a white background.

    Args:
    - sentences (list of list of str): A list of sentences, where each sentence is a list of tokens.
    - weight_lists (list of list of float): A list of weight lists, one for each sentence's tokens.
    - cmap_name (str): Name of the matplotlib colormap to use.

    Returns:
    - None: Displays the highlighted sentences.
    """
    # Get colormap
    cmap = matplotlib.colormaps[cmap_name]
    
    # Build HTML for each sentence
    sentence_html_list = []
    for tokens, weights in zip(sentences, weight_lists):
        weights = np.array(weights)

        # Normalize weights to [0, 1]
        norm_weights = (weights - min(weights)) / (max(weights) - min(weights))
        
        # Convert weights to colors
        colors = [matplotlib.colors.rgb2hex(cmap(w)) for w in norm_weights]
        
        # Create highlighted text for the sentence
        highlighted_sentence = " ".join(
            f'<span style="background-color:{color}; padding:0px;  font-weight:bold; color:black;">{token}</span>'
            for token, color in zip(tokens, colors)
        )
        sentence_html_list.append(f"<li>{highlighted_sentence}</li>")
    
    # Combine all sentences into a single container with a white background
    html_content = f"""
    <div style='font-family:monospace; background-color:white; padding:10px;'>
        <ul style='list-style-type:none; padding:1px; margin:1px;'>
            {''.join(sentence_html_list)}
        </ul>
    </div>
    """
    display(HTML(html_content))


def visualize_tokenized_sentences(tokenized_sentences, token_weights, cmap_name="OrRd"):
    # Get colormap
    cmap = matplotlib.cm.get_cmap(cmap_name)
    
    # Build HTML for each sentence
    sentence_html_list = []
    for tokens, weights in zip(tokenized_sentences, token_weights):
        # Normalize weights to [0, 1]
        weights = np.array(weights)
        norm_weights = (weights - min(weights)) / (max(weights) - min(weights))
        
        # Convert weights to colors
        colors = [matplotlib.colors.rgb2hex(cmap(w)) for w in norm_weights]
        
        # Reconstruct the original sentence with token highlights
        highlighted_sentence = ""
        for i, token in enumerate(tokens):
            # Remove prefix indicators (e.g., ## or Ġ) for natural appearance
            if "##" in token or "Ġ" in token:
                token = " " + token.replace("Ġ", "")
            if token in ["<s>", "</s>"]:
                token = ""
            color = colors[i]
            highlighted_sentence += f'<span style="background-color:{color}; color:black; padding:0px; border-radius:3px; font-weight:bold;">{token}</span>'
        
        # Append the highlighted sentence to the list
        sentence_html_list.append(f"<li>{highlighted_sentence.strip()}</li>")
    
    # Combine all sentences into a single container with a white background
    html_content = f"""
    <div style='font-family:monospace; background-color:white; padding:10px;'>
        <ul style='list-style-type:none; padding:1px; margin:1px;'>
            {''.join(sentence_html_list)}
        </ul>
    </div>
    """
    display(HTML(html_content))


# # Example usage
# sentences = [
#     ["This", "is", "a", "test", "."],
#     ["Another", "example", "sentence", "here", "."],
#     ["Visualizing", "multiple", "sentences", "is", "easy", "!"]
# ]
# weight_lists = [
#     [0.1, 0.2, 0.5, 0.8, 0.3],
#     [0.3, 0.6, 0.9, 0.4, 0.2],
#     [0.5, 0.3, 0.8, 0.6, 0.7, 0.9]
# ]
# visualize_sentences(sentences, weight_lists)


In [2]:
import warnings
warnings.filterwarnings('ignore')

import torch
from src import laion_clap


checkpoint_path = "/fs/nexus-scratch/milis/848K/CLAP/logs/2024_12_07-17_46_23-model_HTSAT-tiny-lr_0.01-b_96-j_1-p_fp32/checkpoints/epoch_latest.pt"

# checkpoint_path = "/fs/nexus-scratch/milis/848K/CLAP/models/630k-audioset-best.pt"


model = laion_clap.CLAP_Module()
model.load_ckpt(checkpoint_path)
model.eval()

Initializing empty model here (0)


Some weights of the model checkpoint at roberta-base were not used when initializing CustomRobertaModel: ['lm_head.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight']
- This IS expected if you are initializing CustomRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CustomRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of CustomRobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias', 'roberta.log_reweighting']
You should probably TRAIN this model on a down-stream task to be able to use it for predicti

Load the specified checkpoint /fs/nexus-scratch/milis/848K/CLAP/logs/2024_12_07-17_46_23-model_HTSAT-tiny-lr_0.01-b_96-j_1-p_fp32/checkpoints/epoch_latest.pt from users.
Load Checkpoint...
Loaded state dict to memory (1)
['text_branch.log_reweighting']
Loading state dict to model (2)
Loaded state dict to model strictly (3)


CLAP_Module(
  (model): CLAP(
    (audio_branch): HTSAT_Swin_Transformer(
      (spectrogram_extractor): Spectrogram(
        (stft): STFT(
          (conv_real): Conv1d(1, 513, kernel_size=(1024,), stride=(480,), bias=False)
          (conv_imag): Conv1d(1, 513, kernel_size=(1024,), stride=(480,), bias=False)
        )
      )
      (logmel_extractor): LogmelFilterBank()
      (spec_augmenter): SpecAugmentation(
        (time_dropper): DropStripes()
        (freq_dropper): DropStripes()
      )
      (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (patch_embed): PatchEmbed(
        (proj): Conv2d(1, 96, kernel_size=(4, 4), stride=(4, 4))
        (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (layers): ModuleList(
        (0): BasicLayer(
          dim=96, input_resolution=(64, 64), depth=2
          (blocks): ModuleList(
            (0): SwinTransformerBlock(
      

In [3]:
log_reweighting = model.model.text_branch.log_reweighting
tokenizer = model.tokenizer

print(log_reweighting)

Parameter containing:
tensor([-0.1206,  0.0000,  0.3762,  ...,  0.0000,  0.0000,  0.0000],
       requires_grad=True)


In [4]:
from transformers import RobertaTokenizer
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')


def tokenize(text):
    token_ids = tokenizer.encode(text)
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    return token_ids, tokens

In [5]:
sentences = [
    "The sound of the ocean waves crashing while a kid is yelling.",
    "You can hear a lot of car noises in the background.",
    "A person is running and you can hear their footrun."
]

token_list = []
weight_list = []
for sentence in sentences:
    token_ids, tokens = tokenize(sentence)
    print(tokens)
    token_list.append(tokens)

    log_weights = log_reweighting[token_ids]
    weight_list.append(torch.exp(log_weights).squeeze().tolist())


visualize_tokenized_sentences(token_list, weight_list)

['<s>', 'The', 'Ġsound', 'Ġof', 'Ġthe', 'Ġocean', 'Ġwaves', 'Ġcrashing', 'Ġwhile', 'Ġa', 'Ġkid', 'Ġis', 'Ġyelling', '.', '</s>']
['<s>', 'You', 'Ġcan', 'Ġhear', 'Ġa', 'Ġlot', 'Ġof', 'Ġcar', 'Ġnoises', 'Ġin', 'Ġthe', 'Ġbackground', '.', '</s>']
['<s>', 'A', 'Ġperson', 'Ġis', 'Ġrunning', 'Ġand', 'Ġyou', 'Ġcan', 'Ġhear', 'Ġtheir', 'Ġfoot', 'run', '.', '</s>']
