In [1]:
import warnings
warnings.filterwarnings('ignore')

import torch
from transformers import RobertaTokenizer
from src import laion_clap


# Checkpoint to inspect.
# NOTE(review): hard-coded absolute, cluster-specific path — anyone running this
# elsewhere must edit it; consider making it configurable.
checkpoint_path = "/fs/nexus-scratch/milis/848K/CLAP/logs/reweighting_5/checkpoints/epoch_latest.pt"


# Build the CLAP wrapper and load the checkpoint weights into it.
model = laion_clap.CLAP_Module()
model.load_ckpt(checkpoint_path)
# presumably switches the module to inference mode (standard torch `eval()`) — confirm
model.eval()


# Learned log-weights stored on the text branch. Later cells index this with
# token ids and exponentiate the result to obtain one weight per token.
log_reweighting = model.model.text_branch.log_reweighting

# Tokenizer used by `tokenize` below to map text to token ids / token strings.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")


def tokenize(text):
    """Encode ``text`` with the module-level ``tokenizer``.

    Returns:
        A ``(token_ids, tokens)`` pair: the ids produced by
        ``tokenizer.encode`` and the token strings those ids map back to.
    """
    ids = tokenizer.encode(text)
    return ids, tokenizer.convert_ids_to_tokens(ids)

Initializing empty model here (0)


Some weights of the model checkpoint at roberta-base were not used when initializing CustomRobertaModel: ['lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing CustomRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CustomRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of CustomRobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.log_reweighting', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predicti

Load the specified checkpoint /fs/nexus-scratch/milis/848K/CLAP/logs/reweighting_5/checkpoints/epoch_latest.pt from users.
Load Checkpoint...
Loaded state dict to memory (1)
Reweighting modules:
['text_branch.log_reweighting']
Loading state dict to model (2)
Loaded state dict to model strictly (3)


In [37]:
import html

import matplotlib
import numpy as np
from IPython.display import HTML, display


def _normalize_weights(weights):
    """Min-max scale a 1-D float array into the range [0, 1].

    When every weight is identical there is no contrast to show, so each
    entry maps to the colormap midpoint (0.5) instead of dividing by zero.
    An empty array is returned unchanged.
    """
    if weights.size == 0:
        return weights
    low, high = weights.min(), weights.max()
    if high == low:
        return np.full_like(weights, 0.5)
    return (weights - low) / (high - low)


def _token_text(token):
    """Return the text to display for one tokenizer token.

    The sentence delimiters ``<s>`` / ``</s>`` are hidden (empty string).
    ``Ġ`` is the byte-level BPE marker for a space, so it is turned back into
    one. A literal ``##`` is ordinary text for this kind of tokenizer and is
    left untouched.
    """
    if token in ("<s>", "</s>"):
        return ""
    return token.replace("Ġ", " ")


def visualize_sentences(tokenized_sentences, token_weights):
    """Render sentences as HTML with each token shaded by its weight.

    Weights are min-max normalised per sentence and mapped onto the "OrRd"
    colormap, so colours are comparable within a sentence but not across
    sentences. The raw tokens and rounded weights are also printed.

    Args:
        tokenized_sentences: list of token lists, one per sentence.
        token_weights: list of weight lists, parallel to
            ``tokenized_sentences`` (one weight per token).

    Raises:
        ValueError: if a sentence's token and weight counts differ.
    """
    # `matplotlib.cm.get_cmap` was deprecated in matplotlib 3.7 and removed in
    # 3.9; the `colormaps` registry is its replacement (available since 3.5).
    try:
        cmap = matplotlib.colormaps["OrRd"]
    except AttributeError:  # matplotlib < 3.5
        cmap = matplotlib.cm.get_cmap("OrRd")

    sentence_html_list = []
    for tokens, weights in zip(tokenized_sentences, token_weights):
        weights = np.atleast_1d(np.asarray(weights, dtype=float))
        if len(tokens) != len(weights):
            raise ValueError(
                f"Got {len(tokens)} tokens but {len(weights)} weights for the same sentence."
            )

        norm_weights = _normalize_weights(weights)
        colors = [matplotlib.colors.rgb2hex(cmap(w)) for w in norm_weights]

        print("Tokens:", tokens)
        # Cast to plain floats so the list prints as numbers under NumPy >= 2.
        print("Weights:", [round(float(w), 2) for w in weights])

        # Rebuild the sentence, one highlighted <span> per visible token.
        spans = []
        for token, color in zip(tokens, colors):
            text = _token_text(token)
            if not text:
                continue
            # Escape so tokens such as "<", "&" or "<unk>" are shown literally
            # instead of being interpreted as markup.
            spans.append(
                f'<span style="background-color:{color}; color:black; padding:0px; border-radius:3px; font-weight:bold;">{html.escape(text)}</span>'
            )

        sentence_html_list.append(f"<li>{''.join(spans).strip()}</li>")

    # Combine all sentences into a single container with a white background
    html_content = f"""
    <div style='font-family:monospace; background-color:white; padding:10px;'>
        <ul style='list-style-type:none; padding:1px; margin:1px;'>
            {''.join(sentence_html_list)}
        </ul>
    </div>
    """
    display(HTML(html_content))


def visualize_data_driven(sentences):
    """Visualise sentences using the learned per-token weights.

    Each sentence is tokenized with ``tokenize``; the matching entries of the
    module-level ``log_reweighting`` are looked up by token id and
    exponentiated to obtain the weights, then everything is handed to
    ``visualize_sentences``.

    sentences: list of strings
    """
    per_sentence_tokens = []
    per_sentence_weights = []

    for text in sentences:
        ids, pieces = tokenize(text)
        learned = torch.exp(log_reweighting[ids])

        per_sentence_tokens.append(pieces)
        per_sentence_weights.append(learned.squeeze().tolist())

    visualize_sentences(per_sentence_tokens, per_sentence_weights)


def visualize_user_driven(sentences, weights_list, sos_token="<s>", eos_token="</s>", default_weight=1.0):
    """Visualise sentences using user-supplied per-word weights.

    Each sentence is split on whitespace and every word is tokenized on its
    own with ``tokenize`` (words after the first get a leading space first).
    A word's weight is repeated for each of its sub-word tokens. The start and
    end tokens are added around each sentence with ``default_weight``.

    Args:
        sentences: list of strings.
        weights_list: list of lists of floats, one inner list per sentence
            holding one weight per whitespace-separated word.
        sos_token: token placed at the start of every sentence.
        eos_token: token placed at the end of every sentence.
        default_weight: weight given to ``sos_token`` and ``eos_token``.

    Raises:
        ValueError: if the number of weight lists differs from the number of
            sentences, or a sentence's word count differs from its number of
            weights (``zip`` would otherwise silently drop the extras).
    """
    if len(sentences) != len(weights_list):
        raise ValueError(
            f"Got {len(sentences)} sentences but {len(weights_list)} weight lists."
        )

    tokenized_sentences_list = []
    tokenized_weights_list = []

    for sentence, weights in zip(sentences, weights_list):
        words = sentence.split()
        if len(words) != len(weights):
            raise ValueError(
                f"Sentence {sentence!r} has {len(words)} words but {len(weights)} weights."
            )

        # Start with the <SOS> token and its weight
        tokenized_sentence = [sos_token]
        tokenized_weights = [default_weight]

        for position, (word, weight) in enumerate(zip(words, weights)):
            # Every word but the first is tokenized with a leading space
            if position != 0:
                word = " " + word
            # Tokenize the word into subwords; [1:-1] drops the first and last
            # token that `tokenize` returns around the word
            subwords = tokenize(word)[1][1:-1]

            # Replicate the word's weight for each of its subwords
            tokenized_sentence.extend(subwords)
            tokenized_weights.extend([weight] * len(subwords))

        # Close with the <EOS> token and its weight
        tokenized_sentence.append(eos_token)
        tokenized_weights.append(default_weight)

        tokenized_sentences_list.append(tokenized_sentence)
        tokenized_weights_list.append(tokenized_weights)

    visualize_sentences(tokenized_sentences_list, tokenized_weights_list)

In [38]:
# Demo 1: colour each token by the weight read from the checkpoint's
# `log_reweighting` (see `visualize_data_driven`).
sentences = [
    "The sound of the ocean waves crashing while a kid is yelling.",
    "You can hear a lot of car noises in the background.",
    "A person is running and you can hear their footrun."
]

visualize_data_driven(sentences)


# Demo 2: colour tokens by hand-picked per-word weights. Note this rebinds
# `sentences`; there must be one weight per whitespace-separated word.
sentences = ["This is a testttt"]
weights = [[0.1, 0.2, 0.8, 0.3]]

visualize_user_driven(sentences, weights)

Tokens: ['<s>', 'The', 'Ġsound', 'Ġof', 'Ġthe', 'Ġocean', 'Ġwaves', 'Ġcrashing', 'Ġwhile', 'Ġa', 'Ġkid', 'Ġis', 'Ġyelling', '.', '</s>']
Weights: [0.89, 0.9, 0.83, 0.97, 0.88, 1.09, 1.14, 1.08, 0.94, 0.99, 0.77, 1.14, 1.12, 0.82, 1.36]
Tokens: ['<s>', 'You', 'Ġcan', 'Ġhear', 'Ġa', 'Ġlot', 'Ġof', 'Ġcar', 'Ġnoises', 'Ġin', 'Ġthe', 'Ġbackground', '.', '</s>']
Weights: [0.89, 1.0, 0.97, 0.94, 0.99, 0.94, 0.97, 0.98, 1.01, 0.87, 0.88, 0.94, 0.82, 1.36]
Tokens: ['<s>', 'A', 'Ġperson', 'Ġis', 'Ġrunning', 'Ġand', 'Ġyou', 'Ġcan', 'Ġhear', 'Ġtheir', 'Ġfoot', 'run', '.', '</s>']
Weights: [0.89, 0.97, 1.02, 1.14, 0.99, 0.84, 0.97, 0.97, 0.94, 0.88, 0.93, 1.0, 0.82, 1.36]


Tokens: ['<s>', 'This', 'Ġis', 'Ġa', 'Ġtest', 'tt', 't', '</s>']
Weights: [1.0, 0.1, 0.2, 0.8, 0.3, 0.3, 0.3, 1.0]
