In [1]:
import os
import pickle

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint identifiers for the two encoders used in this notebook.
CHORD_MODEL_NAME = "jammai/chocolm-modernbert-base"
TEXT_MODEL_NAME = "xlm-roberta-base"
# Alternative text encoder (currently disabled): "neavo/modern_bert_multilingual"

# All models are placed on the CUDA device.
device = torch.device("cuda")

# Chord encoder: tokenizer plus base model, moved to the GPU right after loading.
chord_tokenizer = AutoTokenizer.from_pretrained(CHORD_MODEL_NAME)
chord_model = AutoModel.from_pretrained(CHORD_MODEL_NAME).to(device)

# Lyrics encoder: loaded with the "eager" attention implementation.
text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
text_model = AutoModel.from_pretrained(TEXT_MODEL_NAME, attn_implementation="eager")
# Kept as the cell's last expression, so the notebook echoes the module repr.
text_model.to(device)

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


XLMRobertaModel(
  (embeddings): XLMRobertaEmbeddings(
    (word_embeddings): Embedding(250002, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768, padding_idx=1)
    (token_type_embeddings): Embedding(1, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): XLMRobertaEncoder(
    (layer): ModuleList(
      (0-11): 12 x XLMRobertaLayer(
        (attention): XLMRobertaAttention(
          (self): XLMRobertaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): XLMRobertaSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=Tru

In [3]:
# Load the "jammai/chords_and_lyrics" dataset via `datasets.load_dataset`;
# later cells read its "train" split.
chords_lyrics = load_dataset("jammai/chords_and_lyrics")

In [7]:
def extract_text_chords_pairs(dataset):
    """Collect aligned (lyrics, chords) strings, one pair per verse.

    Each record's ``verse_to_lyrics`` and ``verse_to_harte_chords`` fields are
    strings that are evaluated into containers (presumably dicts keyed by an
    integer verse number -- confirm against the dataset card). A chord entry
    stored under key ``v`` is paired with the lyrics stored under key
    ``v + 1``; verses whose lyrics are missing or contain only whitespace are
    skipped.

    Args:
        dataset: iterable of records supporting ``record[field_name]``.

    Returns:
        ``(lyrics, chords)``: two lists of equal length. ``lyrics[i]`` is the
        verse text exactly as stored, ``chords[i]`` is that verse's chord
        symbols joined with single spaces.
    """
    lyric_texts, chord_texts = [], []
    for record in tqdm(dataset):
        # NOTE(review): eval() executes whatever expression the dataset field
        # contains. If the fields are plain literals, ast.literal_eval would be
        # the safe drop-in -- confirm the data format before switching.
        lyrics_by_verse = eval(record["verse_to_lyrics"])
        chords_by_verse = eval(record["verse_to_harte_chords"])
        for verse in chords_by_verse:
            lyric_key = verse + 1  # lyrics keys are offset by one from chord keys
            if lyric_key not in lyrics_by_verse:
                continue
            verse_text = lyrics_by_verse[lyric_key]
            if not verse_text.rstrip():
                continue
            chord_texts.append(" ".join(chords_by_verse[verse]))
            lyric_texts.append(verse_text)
    return lyric_texts, chord_texts


def compute_embeddings_ls_hs(data_loader, tokenizer, model):
    """Embed every batch of ``data_loader`` with gradient tracking disabled.

    Each batch is passed to ``compute_embedding`` with
    ``output_embending_from_hidden_states=True``, which yields two results per
    batch: the embedding pooled from the last hidden state ("ls") and the one
    averaged over all hidden states ("hs").

    Returns:
        ``(embeddings_ls, embeddings_hs)``: two lists with one entry per
        batch, in iteration order. The per-batch results are not concatenated.
    """
    with torch.no_grad():
        per_batch = [
            compute_embedding(tokenizer, model, batch, output_embending_from_hidden_states=True)
            for batch in tqdm(data_loader)
        ]
    last_state_parts = [ls for ls, _ in per_batch]
    hidden_state_parts = [hs for _, hs in per_batch]
    return last_state_parts, hidden_state_parts


def apply_attention(attention_mask, model_state):
    """Sum-pool token states over the positions selected by the attention mask.

    For each sequence in the batch, every hidden vector is weighted by the
    mask value at its position (1 keeps a token, 0 drops padding) and the
    weighted vectors are added up along the sequence axis. The result is a
    sum, not a mean: it is not divided by the number of kept tokens.

    Args:
        attention_mask: tensor of shape (batch, seq_len), e.g. the
            ``attention_mask`` returned by a tokenizer called with padding.
        model_state: tensor of shape (batch, seq_len, hidden).

    Returns:
        Tensor of shape (batch, hidden) in ``model_state``'s dtype.
    """
    # Cast the mask to the state's dtype: matmul does not promote mixed dtypes,
    # so a fixed float32 mask would fail on half/double precision states.
    weights = attention_mask.to(model_state.dtype).unsqueeze(1)  # (batch, 1, seq_len)
    # One row-vector x matrix product per sample. Multiplying the 2-D mask
    # against the whole 3-D state instead would broadcast it over every sample
    # and build a (batch, batch, hidden) tensor of which only the diagonal is
    # needed.
    return torch.matmul(weights, model_state).squeeze(1)

def compute_embedding(tokenizer, model, input, device="cuda", output_embending_from_hidden_states=False):
    """Encode ``input`` with ``model`` and pool the token states per sequence.

    The input is tokenized with padding, moved to ``device`` and run through
    the model with ``output_hidden_states=True``. Token states are pooled by
    ``apply_attention`` using the tokenizer's attention mask.

    Args:
        tokenizer: callable accepting ``(input, return_tensors="pt", padding=True)``.
        model: callable accepting the tokenized fields as keyword arguments.
        input: text(s) handed to the tokenizer unchanged.
        device: device the tokenized batch is moved to before the forward pass.
        output_embending_from_hidden_states: when true, also return the pooled
            embedding averaged over every entry of ``hidden_states``.

    Returns:
        The pooled last-hidden-state embedding on the CPU. When
        ``output_embending_from_hidden_states`` is true, a tuple
        ``(last_state_embedding, all_layers_embedding)`` instead.
    """
    encoded = tokenizer(input, return_tensors="pt", padding=True).to(device)
    outputs = model(**encoded, output_hidden_states=True)
    mask = encoded.attention_mask

    last_state_embedding = apply_attention(mask, outputs.last_hidden_state).to("cpu")
    if not output_embending_from_hidden_states:
        return last_state_embedding

    # Pool each layer separately, then average the pooled vectors across layers.
    pooled_layers = torch.stack(
        [apply_attention(mask, layer_state) for layer_state in outputs.hidden_states],
        dim=0,
    )
    all_layers_embedding = pooled_layers.mean(dim=0).to("cpu")
    return last_state_embedding, all_layers_embedding

In [10]:
# Hold out 20% of the rows of the "train" split for testing. The fixed seed
# pins the shuffle so the split -- and every artifact derived from it -- is the
# same on each run instead of changing with every kernel restart.
SPLIT_SEED = 42
ds = chords_lyrics["train"].train_test_split(test_size=0.2, seed=SPLIT_SEED)

In [11]:
# Flatten each split into two parallel lists: one lyrics string and one
# space-joined chord string per verse (same index = same verse).
train_lyrics, train_chords = extract_text_chords_pairs(ds["train"])
test_lyrics, test_chords = extract_text_chords_pairs(ds["test"])

100%|██████████| 108626/108626 [00:31<00:00, 3404.95it/s]
100%|██████████| 27157/27157 [00:08<00:00, 3380.63it/s]


In [12]:
# Number of verses encoded per forward pass, shared by all four loaders.
BATCH_SIZE = 256

# The loaders use default settings (no shuffling is requested), so batches
# follow the order of the underlying lists and chord/lyrics rows stay aligned.
data_loader_train_chords = DataLoader(train_chords, batch_size=BATCH_SIZE)
data_loader_test_chords = DataLoader(test_chords, batch_size=BATCH_SIZE)

data_loader_train_lyrics = DataLoader(train_lyrics, batch_size=BATCH_SIZE)
data_loader_test_lyrics = DataLoader(test_lyrics, batch_size=BATCH_SIZE)

# Chord sequences go through the chord encoder ...
train_chords_embeddings_ls, train_chords_embeddings_hs = compute_embeddings_ls_hs(
    data_loader_train_chords, chord_tokenizer, chord_model
)
test_chords_embeddings_ls, test_chords_embeddings_hs = compute_embeddings_ls_hs(
    data_loader_test_chords, chord_tokenizer, chord_model
)

# ... and lyrics through the text encoder.
train_lyrics_embeddings_ls, train_lyrics_embeddings_hs = compute_embeddings_ls_hs(
    data_loader_train_lyrics, text_tokenizer, text_model
)
test_lyrics_embeddings_ls, test_lyrics_embeddings_hs = compute_embeddings_ls_hs(
    data_loader_test_lyrics, text_tokenizer, text_model
)

100%|██████████| 10985/10985 [02:09<00:00, 84.93it/s]
100%|██████████| 2756/2756 [00:32<00:00, 84.85it/s]
100%|██████████| 10985/10985 [14:43<00:00, 12.44it/s]
100%|██████████| 2756/2756 [03:42<00:00, 12.39it/s]


In [11]:
def _dump_pickle(obj, stem, out_dir="experimental_data"):
    """Pickle ``obj`` to ``<out_dir>/<stem>.pkl``.

    The output directory is created when it does not exist yet, and the file
    is opened in a ``with`` block so its handle is flushed and closed as soon
    as the dump finishes.
    """
    os.makedirs(out_dir, exist_ok=True)
    with open(f"{out_dir}/{stem}.pkl", "wb") as handle:
        pickle.dump(obj, handle)


# Embeddings were collected as one tensor per batch. Each list is stacked into
# a single tensor right before it is written, so only one stacked copy exists
# at a time.
for _stem, _batches in [
    ("train_chords_embeddings_ls", train_chords_embeddings_ls),
    ("train_chords_embeddings_hs", train_chords_embeddings_hs),
    ("test_chords_embeddings_ls", test_chords_embeddings_ls),
    ("test_chords_embeddings_hs", test_chords_embeddings_hs),
    ("train_lyrics_embeddings_ls", train_lyrics_embeddings_ls),
    ("train_lyrics_embeddings_hs", train_lyrics_embeddings_hs),
    ("test_lyrics_embeddings_ls", test_lyrics_embeddings_ls),
    ("test_lyrics_embeddings_hs", test_lyrics_embeddings_hs),
]:
    _dump_pickle(torch.vstack(_batches), _stem)

# The raw strings are saved as well, in the same order as the embedding rows.
for _stem, _texts in [
    ("train_lyrics", train_lyrics),
    ("train_chords", train_chords),
    ("test_lyrics", test_lyrics),
    ("test_chords", test_chords),
]:
    _dump_pickle(_texts, _stem)