In [1]:
# Importing necessary libraries
import pandas as pd
import numpy as np


from matplotlib import pyplot as plt

pd.options.mode.chained_assignment = None  # default='warn'


# from transformers import SentenceTransformer
import numpy as np
import json
from sentence_transformers import SentenceTransformer
import torch
from torch import nn
from collections import OrderedDict


import pyarrow.parquet as pa

In [2]:
# Load the pre-processed MagnaTagATune clip table; the first CSV column is used as the index.
df = pd.read_csv('/home/mendu/Thesis/data/magnatagatune/processed_df.csv', index_col=[0])

In [3]:
df

Unnamed: 0_level_0,mp3_path,tags
clip_id,Unnamed: 1_level_1,Unnamed: 2_level_1
2,american_bach_soloists-j_s__bach_solo_cantatas...,opera
6,american_bach_soloists-j_s__bach_solo_cantatas...,opera
10,american_bach_soloists-j_s__bach_solo_cantatas...,opera
11,american_bach_soloists-j_s__bach_solo_cantatas...,opera
14,lvx_nova-lvx_nova-01-contimune-30-59.mp3,electronic
...,...,...
58896,jacob_heringman-blame_not_my_lute-56-la_bressa...,classical
58897,jacob_heringman-blame_not_my_lute-56-la_bressa...,classical
58898,jacob_heringman-blame_not_my_lute-56-la_bressa...,classical
58907,jacob_heringman-blame_not_my_lute-57-lost_is_m...,classical


In [4]:
# Read the test / train / validation caption splits as pyarrow tables.
# NOTE: `pa` here is the `pyarrow.parquet` module (see the imports cell), not `pyarrow` itself.
table_test = pa.read_table('/home/mendu/Thesis/data/magnatagatune/captions/test-00000-of-00001-94781ef88fa7ed89.parquet') 
table_train = pa.read_table('/home/mendu/Thesis/data/magnatagatune/captions/train-00000-of-00001-28dbf9154d6d526d.parquet') 
table_val = pa.read_table('/home/mendu/Thesis/data/magnatagatune/captions/valid-00000-of-00001-bf9893b31ca2d5e5.parquet') 

In [5]:
# Convert each split to a pandas DataFrame indexed by its 'track_id' column.
df_test = table_test.to_pandas().set_index('track_id')
df_train = table_train.to_pandas().set_index('track_id')
df_val = table_val.to_pandas().set_index('track_id') 

In [6]:
len(df_train)+len(df_test)+len(df_val)

25860

In [7]:
# Stack the three caption splits into a single frame (rows in test, train, validation order).
frames = [df_test, df_train, df_val]
df_captions = pd.concat(frames)

In [8]:
# Keep only the 'caption_writing' column; the frame stays indexed by track_id.
captions = df_captions[['caption_writing']]
# Cast the index to int64 so it can be aligned with an integer-indexed frame
# (presumably track_id is not stored as int64 in the parquet files — TODO confirm).
captions.index = captions.index.astype('int64')
captions

Unnamed: 0_level_0,caption_writing
track_id,Unnamed: 1_level_1
2,Experience the majestic beauty of classical mu...
6,Experience the rich sound of classical eleganc...
10,This powerful classic opera piece showcases th...
11,This atmospheric and introspective song blends...
12,Experience a powerful and uptempo classical me...
...,...
58716,This breathtaking song features a mesmerizing ...
58717,This folk-inspired song features intricate str...
58719,This hauntingly beautiful ballad takes its tim...
58736,This beautiful piece of music features intrica...


In [9]:
# Attach each clip's caption by aligning the two indexes. A left join keeps every
# row of df; clips without a caption get NaN in 'caption_writing'.
# Any 'caption_writing' column left over from a previous execution is dropped first
# so the cell can be re-run: merging twice would otherwise produce
# 'caption_writing_x' / 'caption_writing_y' instead of a single column.
df = pd.merge(
    df.drop(columns=['caption_writing'], errors='ignore'),
    captions,
    how='left',
    left_index=True,
    right_index=True,
)


In [10]:
df

Unnamed: 0_level_0,mp3_path,tags,caption_writing
clip_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
2,american_bach_soloists-j_s__bach_solo_cantatas...,opera,Experience the majestic beauty of classical mu...
6,american_bach_soloists-j_s__bach_solo_cantatas...,opera,Experience the rich sound of classical eleganc...
10,american_bach_soloists-j_s__bach_solo_cantatas...,opera,This powerful classic opera piece showcases th...
11,american_bach_soloists-j_s__bach_solo_cantatas...,opera,This atmospheric and introspective song blends...
14,lvx_nova-lvx_nova-01-contimune-30-59.mp3,electronic,This upbeat dance track features a pulsing tec...
...,...,...,...
58896,jacob_heringman-blame_not_my_lute-56-la_bressa...,classical,This classical guitar solo piece features intr...
58897,jacob_heringman-blame_not_my_lute-56-la_bressa...,classical,This beautiful classical piece features a haun...
58898,jacob_heringman-blame_not_my_lute-56-la_bressa...,classical,This classical piece features beautiful melodi...
58907,jacob_heringman-blame_not_my_lute-57-lost_is_m...,classical,This beautiful classical piece features a gent...


In [22]:
# df.to_csv('/home/mendu/Thesis/data/magnatagatune/df_w_captions.csv')

## Getting the embedding for the caption

In [11]:
# Define the Encoder class focusing only on the encoding part
class Encoder(nn.Module):
    """Feed-forward encoder mapping ``input_size`` features to ``encoding_size`` features.

    Layer stack, in order:
    Linear(input_size, 512) -> ReLU -> BatchNorm1d(512) ->
    Linear(512, 256) -> ReLU -> BatchNorm1d(256) ->
    Linear(256, 128) -> ReLU -> Linear(128, encoding_size) -> ReLU.
    The final ReLU means every output value is non-negative.
    """

    def __init__(self, input_size, encoding_size):
        super().__init__()
        # The attribute name 'layers' and the position of each module determine the
        # state_dict keys ('layers.0.weight', ...), so both are kept fixed.
        stack = [
            nn.Linear(input_size, 512),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(256),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, encoding_size),
            nn.ReLU(inplace=True),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Pass ``x`` through the layer stack and return the resulting encoding."""
        encoding = self.layers(x)
        return encoding

In [12]:
# Define input and encoding sizes
# Width of the vectors fed to the encoder (presumably the sentence-embedding size — TODO confirm).
input_size = 768
# Width of the encoder output.
encoding_size = 64

In [15]:
def load_encoder_state_dict(encoder, state_dict_path):
    """Load saved weights into ``encoder`` and return it.

    The checkpoint is expected to hold the state_dict of the bare layer stack
    (keys such as ``0.weight``), so every key is prefixed with ``layers.`` to match
    the attribute name used by the ``Encoder`` class. Keys that already carry the
    prefix are left untouched instead of being prefixed a second time.

    Args:
        encoder: module whose parameters live under a ``layers`` attribute.
        state_dict_path: path (or file-like object) of the saved state_dict.

    Returns:
        The same ``encoder`` instance with the weights loaded.

    Raises:
        RuntimeError: raised by ``load_state_dict`` when keys or shapes do not match.
    """
    # Deserialize onto the CPU so a checkpoint saved from a GPU still loads on a
    # machine without CUDA; load_state_dict then copies each tensor into the
    # encoder's existing parameters, wherever those live.
    # NOTE: torch.load unpickles the file — only use it on trusted checkpoints.
    state_dict = torch.load(state_dict_path, map_location='cpu')

    prefix = 'layers.'
    new_state_dict = OrderedDict(
        (key if key.startswith(prefix) else f'{prefix}{key}', value)
        for key, value in state_dict.items()
    )

    encoder.load_state_dict(new_state_dict)
    return encoder

def get_device():
    """Return the first CUDA device when CUDA is available, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda:0')
    return torch.device('cpu')


# Device shared by the cells below.
device = get_device()

In [16]:
# Instantiate the encoder on the selected device
encoder = Encoder(input_size=input_size, encoding_size=encoding_size).to(device)

# Load the pre-trained weights for the encoder
encoder = load_encoder_state_dict(encoder, '/home/mendu/Thesis/data/musiccaps/auto_encoder/encoder_state_dict.pth')

# Switch to evaluation mode once the weights are in place, so the BatchNorm layers
# use their stored running statistics and single-sample batches are accepted.
encoder.eval()

# Load the SentenceTransformer model and move to the correct device
roberta_model = SentenceTransformer('/home/mendu/Thesis/data/musiccaps/new_embedding_model').to(device)

# Take a sample caption from df: the row whose index label (clip_id) is 2.
# .loc makes the label-based lookup explicit; integer keys passed to Series[...]
# are label-based only because the index is integer, which is easy to misread.
caption = df.loc[2, 'caption_writing']

In [18]:
# Function to encode caption
def encode_caption(encoder, sentence_model, text):
    """Embed ``text`` with ``sentence_model`` and compress the result with ``encoder``.

    Args:
        encoder: torch module applied to the sentence embedding.
        sentence_model: object exposing ``encode(text, convert_to_tensor=True)``
            that returns a tensor (e.g. a SentenceTransformer).
        text: a single caption, or a list of captions.

    Returns:
        The encoder output with one row per caption; a single caption yields a
        tensor with a leading batch dimension of 1. No gradients are tracked.
    """
    with torch.no_grad():
        embedding = sentence_model.encode(text, convert_to_tensor=True)

        # Follow the encoder's own device instead of a notebook-level global, so the
        # function works regardless of which cells have been executed.
        first_param = next(encoder.parameters(), None)
        if first_param is not None:
            embedding = embedding.to(first_param.device)

        # A single caption comes back as a 1-D vector and needs a batch dimension;
        # a list of captions is already 2-D and must not be wrapped again.
        if embedding.dim() == 1:
            embedding = embedding.unsqueeze(0)

        return encoder(embedding)

In [19]:
# Encode the sample caption: sentence-model embedding followed by the encoder.
fully_encoded_caption = encode_caption(encoder, roberta_model, caption)

In [20]:
fully_encoded_caption

tensor([[ 52.0009,  28.6403,  34.2275,  59.2607,   1.1570,  57.7524, 142.9124,
          44.6702,  41.6169, 104.4225,  92.6284, 119.1738,  42.9483,  79.1854,
         216.7657,  67.0605,  43.6379,  77.6768, 100.6521,  87.6473, 132.7112,
         126.8025,  86.1116, 131.3719,  68.6272,  81.4069, 134.8302, 196.9826,
          90.1404,  65.5783,  97.3774, 187.2035,  53.5716, 243.3475,   0.0000,
           0.0000,  55.3763, 132.5184,  23.2223,   0.0000, 155.0516,  99.4067,
          38.8811, 111.8058, 213.8127, 153.5288,   0.0000,  39.6156,  19.5320,
         135.4064,  96.3223,  40.3084, 116.1382, 154.0468,  79.5539,  27.2223,
          96.1176,  39.1434, 160.7672,  63.0411,  12.1961,  99.3322,  49.6889,
          55.5985]], device='cuda:0')

In [21]:
df.columns

Index(['mp3_path', 'tags', 'caption_writing'], dtype='object')