In [1]:
import pandas as pd
from pprint import pprint
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import os


In [2]:
# Load Phi-3-mini (4k-context instruct) with bitsandbytes 8-bit weights, placed on the GPU.
model_name = "microsoft/Phi-3-mini-4k-instruct"
device = "cuda"
quant_config = BitsAndBytesConfig(load_in_8bit=True)

# NOTE(review): trust_remote_code=True allows transformers to execute Python code shipped
# in the Hub repository. Recent transformers releases include a native Phi-3 implementation,
# so this flag may be unnecessary — confirm against the installed version before dropping it.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device,
    quantization_config=quant_config,
    trust_remote_code=True,
)

# Left padding: prompts of different lengths are generated in batches, and a decoder-only
# model needs every prompt to end exactly where its generated continuation begins.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    padding_side="left",
)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
# Songs without lyrics cannot be titled from their lyrics, so drop them up front.
df = pd.read_csv("../data/spotify-songs.csv").dropna(subset=["lyrics"])

In [4]:
# Pick one known track as a running example; show its lyrics length and text.
example_track_id = "07nH4ifBxUB4lZcsf44Brn"
example_lyrics = df.loc[df["track_id"] == example_track_id, "lyrics"].iloc[0]
print(len(example_lyrics), example_lyrics)

4135 Can't be sleeping Keep on waking without the woman next to me Guilt is burning, inside I'm hurting This ain't a feeling I can keep So blame it on the night Don't blame it on me, don't blame it on me Blame it on the night Don't blame it on me, don't blame it on me Blame it on the night Don't blame it on me, don't blame it on me So blame it on the night Don't blame it on me, don't blame it on me Can't you see it? I was manipulated I had to let her through the door I had no choice in this, I was the friend she missed She needed me to talk So blame it on the night Don't blame it on me, don't blame it on me Blame it on the night Don't blame it on me, don't blame it on me Blame it on the night Don't blame it on me, don't blame it on me So blame it on the night Don't blame it on me, don't blame it on me Oh I'm so sorry, so sorry baby (I'll be better this time...) I will be better this time I got to say, I'm so sorry Oh I promise (I'll be better this time, I'll be better this time...) Don

In [5]:
def _parse_title(generated: str) -> str:
    """Turn the model's raw continuation into a clean title.

    Returns the first non-blank line of ``generated`` with surrounding whitespace
    and double quotes removed, or "" if the continuation is entirely blank.
    """
    for line in generated.splitlines():
        line = line.strip()
        if line:
            return line.strip('"').strip()
    return ""


def generate_query(lyrics: list[str], num_candidates=5):
    """Sample candidate song titles for each lyrics text with the loaded LLM.

    Uses the notebook-level ``model``, ``tokenizer`` and ``device``.

    Args:
        lyrics: Lyrics texts, one per song. Any iterable of str works (e.g. a
            pandas Series); each text is truncated to 2000 characters before
            being inserted into the prompt.
        num_candidates: Number of sampled titles to return per song.

    Returns:
        One list per input song (same order as ``lyrics``), each holding
        ``num_candidates`` title strings. An empty input returns ``[]``
        without calling the model.
    """
    prompt_template = """
    Generate a title for the song based on the following lyrics.
    Do not add any creative or additional elements.
    Stick strictly to the words and phrases present in the lyrics.
    Ensure the title is in the same language as the lyrics.

    Lyrics: {lyrics}

    Title:"""

    # Materialize so Series/generators work and emptiness can be checked safely.
    lyrics = list(lyrics)
    if not lyrics:
        return []

    # str.replace (not str.format) so braces inside the lyrics are left untouched.
    prompts = [prompt_template.replace("{lyrics}", text[:2000]) for text in lyrics]

    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=30,
        temperature=0.5,
        do_sample=True,
        num_return_sequences=num_candidates,
    )

    # Decode only the newly generated tokens: after padding every prompt row has the
    # same length, so the continuation starts right after it. Splitting the full decoded
    # text on "Title: " instead raises IndexError whenever the model's first token is not
    # a space (e.g. a newline), and misfires if the lyrics themselves contain "Title: ".
    prompt_len = inputs["input_ids"].shape[1]
    generated = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
    titles = [_parse_title(text) for text in generated]

    # generate() returns the num_candidates samples of each prompt contiguously, so
    # consecutive chunks of num_candidates belong to the same song.
    return [titles[i:i + num_candidates] for i in range(0, len(titles), num_candidates)]

# Sample the default number of candidate titles for the example song, one per line.
candidate_lists = generate_query([example_lyrics])
example_queries = candidate_lists[0]
for title in example_queries:
    print(title)


You are not running the flash-attention implementation, expect numerical differences.


Blame It on the Night
Blame It on the Night
Blame It on the Night
Blame It on the Night
Blame It on the Night


In [6]:
QUERIES_CSV = "../data/queries.csv"
batch_size = 10
checkpoint_rows = 50  # write progress to disk at least every this many new rows
num_songs = len(df)

# Resume from an earlier (possibly interrupted) run: rows already in the CSV are assumed
# to correspond, in order, to the first len(queries_df) songs of df.
if os.path.exists(QUERIES_CSV):
    queries_df = pd.read_csv(QUERIES_CSV)
else:
    queries_df = pd.DataFrame(columns=["track_id", "query"])

# Accumulate rows in a plain list; growing a DataFrame with pd.concat per batch is quadratic.
records = queries_df.to_dict("records")
saved_rows = len(records)

for i in range(len(queries_df), num_songs, batch_size):
    batch = df.iloc[i:i + batch_size]
    batch_queries = generate_query(batch["lyrics"], num_candidates=2)
    assert len(batch_queries) == len(batch), "one candidate list expected per song"

    # Keep the longest sampled candidate for each song.
    longest_queries = [max(candidates, key=len) for candidates in batch_queries]
    records.extend(
        {"track_id": track_id, "query": query}
        for track_id, query in zip(batch["track_id"], longest_queries)
    )

    # Row-count based checkpointing works for any batch_size and resume offset,
    # unlike testing (i + batch_size) % 50 == 0.
    if len(records) - saved_rows >= checkpoint_rows:
        pd.DataFrame(records, columns=["track_id", "query"]).to_csv(QUERIES_CSV, index=False)
        saved_rows = len(records)

    print(f"Progress: {min(i + batch_size, num_songs)}/{num_songs}")

queries_df = pd.DataFrame(records, columns=["track_id", "query"])
# Final save: without it, rows generated after the last checkpoint would be lost.
if len(records) > saved_rows:
    queries_df.to_csv(QUERIES_CSV, index=False)


Progress: 460/18194
Progress: 470/18194
Progress: 480/18194
Progress: 490/18194
Progress: 500/18194
Progress: 510/18194
Progress: 520/18194
Progress: 530/18194
Progress: 540/18194
Progress: 550/18194
Progress: 560/18194
Progress: 570/18194
Progress: 580/18194
Progress: 590/18194
Progress: 600/18194
Progress: 610/18194
Progress: 620/18194
Progress: 630/18194
Progress: 640/18194
Progress: 650/18194
Progress: 660/18194
Progress: 670/18194
Progress: 680/18194
Progress: 690/18194
Progress: 700/18194
Progress: 710/18194
Progress: 720/18194
Progress: 730/18194
Progress: 740/18194
Progress: 750/18194
Progress: 760/18194
Progress: 770/18194
Progress: 780/18194
Progress: 790/18194
Progress: 800/18194
Progress: 810/18194
Progress: 820/18194
Progress: 830/18194
Progress: 840/18194
Progress: 850/18194
Progress: 860/18194
Progress: 870/18194
Progress: 880/18194
Progress: 890/18194
Progress: 900/18194
Progress: 910/18194
Progress: 920/18194
Progress: 930/18194
Progress: 940/18194
Progress: 950/18194
