In [1]:
import pickle
import torch
import sklearn
from tqdm import tqdm
import numpy as np
from asif import ASIF, extract_candidate_sets_from_clusters, compute_embedding
from transformers import AutoTokenizer, AutoModel
from matplotlib import pyplot as plt
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cuda')

chord_tokenizer = AutoTokenizer.from_pretrained("jammai/chocolm-modernbert-base-transposed")
chord_model = AutoModel.from_pretrained("jammai/chocolm-modernbert-base-transposed")
chord_model.to(device)

text_tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
text_model = AutoModel.from_pretrained("answerdotai/ModernBERT-base")
text_model.to(device)

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


ModernBertModel(
  (embeddings): ModernBertEmbeddings(
    (tok_embeddings): Embedding(50368, 768, padding_idx=50283)
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (drop): Dropout(p=0.0, inplace=False)
  )
  (layers): ModuleList(
    (0): ModernBertEncoderLayer(
      (attn_norm): Identity()
      (attn): ModernBertAttention(
        (Wqkv): Linear(in_features=768, out_features=2304, bias=False)
        (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=160000.0, scale_base=None)
        (Wo): Linear(in_features=768, out_features=768, bias=False)
        (out_drop): Identity()
      )
      (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): ModernBertMLP(
        (Wi): Linear(in_features=768, out_features=2304, bias=False)
        (act): GELUActivation()
        (drop): Dropout(p=0.0, inplace=False)
        (Wo): Linear(in_features=1152, out_features=768, bias=False)
      )
    )
    (1-2): 2 x ModernBertEncoderLayer(
     

In [None]:
def extract_text_chords_pairs(dataset):
    result_chords = []
    result_lyrics = []
    for song in tqdm(dataset):
        lyrics = eval(song["verse_to_lyrics"])
        chords = eval(song["verse_to_harte_chords"])
        for verse in chords:
            if verse + 1 in lyrics and len(lyrics[verse + 1].rstrip()):
                result_chords.append(" ".join(chords[verse]))
                result_lyrics.append(lyrics[verse + 1])
    return result_lyrics, result_chords


def compute_embeddings_ls_hs(data_loader, tokenizer, model):
    embeddings_ls, embeddings_hs = [], []
    with torch.no_grad():
        for batch in tqdm(data_loader):
            ls, hs = compute_embedding(tokenizer, model, batch, mode="hs")
            embeddings_ls.append(ls)
            embeddings_hs.append(hs)
    return embeddings_ls, embeddings_hs


In [4]:
chords_lyrics = load_dataset("jammai/chords_and_lyrics")

In [5]:
ds = chords_lyrics["train"].train_test_split(test_size=0.2)

In [6]:
train_lyrics, train_chords = extract_text_chords_pairs(ds["train"])
test_lyrics, test_chords = extract_text_chords_pairs(ds["test"])

100%|██████████| 108626/108626 [00:31<00:00, 3430.01it/s]
100%|██████████| 27157/27157 [00:07<00:00, 3426.41it/s]


In [9]:
data_loader_train_lyrics = DataLoader(train_lyrics, batch_size=256)
data_loader_test_lyrics = DataLoader(test_lyrics, batch_size=256)

data_loader_train_chords = DataLoader(train_chords, batch_size=256)
data_loader_test_chords = DataLoader(test_chords, batch_size=256)

train_chords_embeddings_ls, train_chords_embeddings_hs = compute_embeddings_ls_hs(data_loader_train_chords, chord_tokenizer, chord_model)
test_chords_embeddings_ls, test_chords_embeddings_hs = compute_embeddings_ls_hs(data_loader_test_chords, chord_tokenizer, chord_model)

train_lyrics_embeddings_ls, train_lyrics_embeddings_hs = compute_embeddings_ls_hs(data_loader_train_lyrics, text_tokenizer, text_model)
test_lyrics_embeddings_ls, test_lyrics_embeddings_hs = compute_embeddings_ls_hs(data_loader_test_lyrics, text_tokenizer, text_model)

100%|██████████| 10988/10988 [05:43<00:00, 32.03it/s]
100%|██████████| 2752/2752 [01:27<00:00, 31.31it/s]
100%|██████████| 10988/10988 [16:52<00:00, 10.86it/s]
100%|██████████| 2752/2752 [04:07<00:00, 11.11it/s]


In [11]:
pickle.dump(torch.vstack(train_chords_embeddings_ls), open("experimental_data/train_chords_embeddings_ls.pkl", "wb"))
pickle.dump(torch.vstack(train_chords_embeddings_hs), open("experimental_data/train_chords_embeddings_hs.pkl", "wb"))
pickle.dump(torch.vstack(test_chords_embeddings_ls), open("experimental_data/test_chords_embeddings_ls.pkl", "wb"))
pickle.dump(torch.vstack(test_chords_embeddings_hs), open("experimental_data/test_chords_embeddings_hs.pkl", "wb"))

pickle.dump(torch.vstack(train_lyrics_embeddings_ls), open("experimental_data/train_lyrics_embeddings_ls.pkl", "wb"))
pickle.dump(torch.vstack(train_lyrics_embeddings_hs), open("experimental_data/train_lyrics_embeddings_hs.pkl", "wb"))
pickle.dump(torch.vstack(test_lyrics_embeddings_ls), open("experimental_data/test_lyrics_embeddings_ls.pkl", "wb"))
pickle.dump(torch.vstack(test_lyrics_embeddings_hs), open("experimental_data/test_lyrics_embeddings_hs.pkl", "wb"))

pickle.dump(train_lyrics, open("experimental_data/train_lyrics.pkl", "wb"))
pickle.dump(train_chords, open("experimental_data/train_chords.pkl", "wb"))
pickle.dump(test_lyrics, open("experimental_data/test_lyrics.pkl", "wb"))
pickle.dump(test_chords, open("experimental_data/test_chords.pkl", "wb"))