In [None]:
import os
import numpy as np
import pandas as pd
import absl.logging
from ast import literal_eval
import torch
from nlp_embeddings_no_nlu import DistilBERT, SentenceTransformerMPNET

# Set absl's logging verbosity to ERROR
# (presumably to quiet lower-severity messages from the embedding libraries -- confirm).
absl.logging.set_verbosity(absl.logging.ERROR)

In [None]:
# Dataset identifier: used to build the input CSV paths and the output file names.
dataset_name = 'dataset2_proc'

# Passed to the DistilBERT embedder together with `device`
# (presumably a per-text word limit -- confirm against nlp_embeddings_no_nlu).
max_words = 400

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
def save_embedding(data_x, nlp_embedding, batch_size=5000, start_idx=0, dir_path='data', prefix=''):
    """Embed ``data_x`` batch by batch and append every batch to one CSV file.

    The output file is
    ``<dir_path>/<prefix>_<nlp_embedding.name>_<dataset_name>.csv``, where
    ``dataset_name`` is the module-level variable of that name. Batches are
    appended without header or index.

    Parameters
    ----------
    data_x
        Sliceable collection exposing ``shape[0]`` (e.g. a pandas Series).
    nlp_embedding
        Object with a ``name`` attribute and an ``embed_lyrics`` method.
        ``embed_lyrics`` is called with each slice of ``data_x`` and its
        return value is passed to ``pd.DataFrame``.
    batch_size : int
        Number of rows embedded per call; must be positive.
    start_idx : int
        First row to process. With 0 an existing output file is deleted
        first; any other value appends to whatever is already on disk.
    dir_path : str
        Output directory, created if it does not exist.
    prefix : str
        Leading part of the output file name.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive.
    """
    # Validate before touching the filesystem: otherwise a zero or negative
    # batch size would delete an existing output file and then either crash
    # in range() or silently write nothing while still reporting success.
    if batch_size <= 0:
        raise ValueError(f'batch_size must be a positive integer, got {batch_size!r}')

    fname = os.path.join(dir_path, f'{prefix}_{nlp_embedding.name}_{dataset_name}.csv')

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(dir_path, exist_ok=True)

    # A run from the beginning starts from a clean file, because every batch
    # below is written in append mode.
    if start_idx == 0 and os.path.exists(fname):
        os.remove(fname)

    n_rows = data_x.shape[0]
    for i in range(start_idx, n_rows, batch_size):
        # Clamp the last batch to the end of the data.
        j = min(i + batch_size, n_rows)

        print(f'Processing rows: {i} - {j - 1}')

        embeddings = nlp_embedding.embed_lyrics(data_x[i:j])
        pd.DataFrame(embeddings).to_csv(fname, mode='a', index=False, header=False)

    print('Success!')

In [None]:
def add_normalized_lyrics(data):
    """Add a ``normalized_lyrics`` column to ``data`` in place.

    Every entry of the ``tokens`` column is parsed with ``literal_eval`` and
    the parsed items are joined with single spaces. Nothing is returned.
    """
    # Parse every row first, then join, so the column is only assigned once
    # all rows have been processed successfully.
    parsed_rows = [literal_eval(raw_tokens) for raw_tokens in data.tokens]
    data['normalized_lyrics'] = list(map(' '.join, parsed_rows))

In [None]:
# Load the CSV named after `dataset_name` from data/train and from data/test.
train_data = pd.read_csv(f'data/train/{dataset_name}.csv')
test_data = pd.read_csv(f'data/test/{dataset_name}.csv')

In [None]:
# Adds the `normalized_lyrics` column to both frames in place.
# NOTE(review): this runs before rows with missing lyrics are filtered out, so any
# row whose `tokens` value is missing would presumably make literal_eval raise --
# confirm `tokens` is always populated, or filter first.
add_normalized_lyrics(train_data)
add_normalized_lyrics(test_data)

In [None]:
# Discard every row that has no lyrics text.
train_data = train_data.dropna(subset=['lyrics'])
test_data = test_data.dropna(subset=['lyrics'])

In [None]:
# Output directories for the embedding CSVs. exist_ok=True makes these cells
# safely re-runnable and avoids the check-then-create race of a separate
# os.path.exists() test.
embedded_train_data_path = 'data/train/embeddings'
os.makedirs(embedded_train_data_path, exist_ok=True)

embedded_test_data_path = 'data/test/embeddings'
os.makedirs(embedded_test_data_path, exist_ok=True)

# File-name prefixes: one for embeddings of the raw lyrics, one for
# embeddings of the normalized lyrics.
prefix = 'embedded'
prefix_normalized = 'embedded_norm'

## DistilBERT

In [None]:
# Build the DistilBERT embedder from the max_words / device settings.
emb_dbert = DistilBERT(max_words, device)

In [None]:
# Raw test lyrics -> DistilBERT embeddings, written to the test embeddings directory.
# With the default start_idx=0 an existing output file is replaced.
save_embedding(test_data.lyrics, emb_dbert, dir_path=embedded_test_data_path, prefix=prefix)

In [None]:
# Raw train lyrics -> DistilBERT embeddings, written to the train embeddings directory.
save_embedding(train_data.lyrics, emb_dbert, dir_path=embedded_train_data_path, prefix=prefix)

### Normalized data

In [None]:
# Normalized test lyrics -> DistilBERT embeddings, saved under the normalized prefix.
save_embedding(test_data.normalized_lyrics, emb_dbert, dir_path=embedded_test_data_path, prefix=prefix_normalized)

In [None]:
# Normalized train lyrics -> DistilBERT embeddings, saved under the normalized prefix.
save_embedding(train_data.normalized_lyrics, emb_dbert, dir_path=embedded_train_data_path, prefix=prefix_normalized)

## SentenceTransformerMPNET

In [None]:
# Build the SentenceTransformerMPNET embedder with its default constructor arguments.
emb_mpnet = SentenceTransformerMPNET()

In [None]:
# Raw test lyrics -> SentenceTransformerMPNET embeddings, written to the test embeddings directory.
save_embedding(test_data.lyrics, emb_mpnet, dir_path=embedded_test_data_path, prefix=prefix)

In [None]:
# Raw train lyrics -> SentenceTransformerMPNET embeddings, written to the train embeddings directory.
save_embedding(train_data.lyrics, emb_mpnet, dir_path=embedded_train_data_path, prefix=prefix)

### Normalized data

In [None]:
# Normalized test lyrics -> SentenceTransformerMPNET embeddings, saved under the normalized prefix.
save_embedding(test_data.normalized_lyrics, emb_mpnet, dir_path=embedded_test_data_path, prefix=prefix_normalized)

In [None]:
# Normalized train lyrics -> SentenceTransformerMPNET embeddings, saved under the normalized prefix.
save_embedding(train_data.normalized_lyrics, emb_mpnet, dir_path=embedded_train_data_path, prefix=prefix_normalized)