In [8]:
import math

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange
from transformers import AutoModel, AutoTokenizer


class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

class SigmoidContrastiveLearning(nn.Module):
    def __init__(
        self,
        *,
        layers = 1,
        init_temp = 10,
        init_bias = -10
    ):
        super().__init__()
        self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
        self.bias = nn.Parameter(torch.ones(layers, 1, 1) * init_bias)

    def device(self):
        return next(self.parameters()).device

    def forward(self, sims):
        device = sims.device
        if sims.ndim == 2:
            sims = rearrange(sims, 'i j -> 1 i j')

        n = sims.shape[-1] 
        sims = sims * self.temperatures.exp() + self.bias 
        labels = 2 * rearrange(torch.eye(n).to(device), 'i j -> 1 i j') - torch.ones_like(sims)

        return -F.logsigmoid(labels * sims).sum() / n


class NT_Xent(nn.Module):
    def __init__(self, temperature: float=.5):
        super(NT_Xent, self).__init__()
        self.temperature = nn.Parameter(torch.ones(1) * temperature)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)
        
    def mask_correlated_samples(self, batch_size):
        N = 2 * batch_size
        mask = torch.ones((N, N), dtype=bool)
        mask = mask.fill_diagonal_(0)
        for i in range(batch_size):
            mask[i, batch_size + i] = 0
            mask[batch_size + i, i] = 0
        return mask
    
    def forward(self, z_i, z_j):
        batch_size = z_i.shape[0]
        
        N = 2 * batch_size
        mask = self.mask_correlated_samples(batch_size)
        z = torch.cat((z_i, z_j), dim=0)
        
        sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0))
        sim /= self.temperature
        
        sim_i_j = torch.diag(sim, batch_size) # 우상 삼각 행렬
        sim_j_i = torch.diag(sim, -batch_size) # 좌하 삼각 행렬

        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N,1)
        negative_samples = sim[mask].reshape(N,-1)

        labels = torch.zeros(N).to(positive_samples.device).long()
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels) / N
        return loss, labels, logits
    
class MuLaN(nn.Module):
    def __init__(
        self,

        dim_latent=128,
    ):
        super().__init__()
        self.dim_latent = dim_latent

        self.audio_model = ASTModel(fstride=10, tstride=10)
        self.text_model = AutoModel.from_pretrained("klue/bert-base")

        self.audio_to_latent = nn.Sequential(
            nn.Linear(768, 768, bias=False),
            nn.ReLU(),
            nn.Linear(768, dim_latent, bias=False)
        )
        
        self.text_to_latent = nn.Sequential(
            nn.Linear(768, 768, bias=False),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(768, dim_latent, bias=False)
        )

    def get_audio_latents(self, wavs):
        audio_h = self.audio_model(wavs)
        audio_latents = self.audio_to_latent(audio_h)
        return audio_latents

    def get_text_latents(self, input_ids, attention_mask):
        text_h = self.text_model(input_ids.squeeze(1), attention_mask=attention_mask.squeeze(1))[0][:, 0, :]
        text_letents = self.text_to_latent(text_h)
        return text_letents

    def forward(
        self,
        wavs=None,
        input_ids=None,
        attention_mask=None,
        return_latents=False,
        return_similarities=False,
        return_pairwise_similarities=False,
    ):
        
        audio_latents = self.get_audio_latents(wavs)
        text_latents = self.get_text_latents(input_ids, attention_mask)

        return audio_latents, text_latents


In [2]:
import os
import json
import argparse
import logging
from pathlib import Path
import torch
import torch.nn.functional as F

from torch.utils.data import DataLoader

import os
import random
from typing import Dict

import numpy as np
import pandas as pd
from pandas import DataFrame

import torch
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import Dataset

from transformers import AutoModel, AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained("klue/bert-base")

In [3]:

class DatasetBase(Dataset):
    def __init__(self, batch_size: int, data_type: str):
        super().__init__()
        self.batch_size = batch_size
        sf_path = '/data/mulan_text_dataset/shortform_total.parquet.gzip'
        lf_path = '/data/mulan_text_dataset/longform_total.parquet.gzip'
        pl_path = '/data/mulan_text_dataset/playlist_total.parquet.gzip'

        self.shortform_df = pd.read_parquet(sf_path)
        self.shortform_rows = self.shortform_df.shape[0]

        self.longform_df = pd.read_parquet(lf_path)
        self.longform_rows = self.longform_df.shape[0]

        self.playlist_df = pd.read_parquet(pl_path)
        self.playlist_rows = self.playlist_df.shape[0]

        # The number of data should be  a multiple of the batch size
        self.total_length = (
            (self.shortform_rows + self.longform_rows + self.playlist_rows)
            // self.batch_size
            * self.batch_size
        )

    def __len__(self):
        return self.total_length

    def __getitem__(self, idx):
        raise RuntimeError("abstract function")

    def get_fbank(
        self,
        path: str,
        output_freq: int = 16000,
        seconds: int = 30,
        infer=None
    ):
        """
        음악 파일을 fbank 형태로 변환

        Args:
            filename (str): 음원 파일 경로.
            output_freq (int, optional): fbank로 생성될 Hz. Defaults to 16000.
            start_seconds (int, optional): fbank로 생성될 시작 시간. Defaults to 30.
            end_seconds (int, optional): fbank로 생성될 종료 시간. Defaults to 60.
        """
        self.infer = infer
        waveform, orig_freq = torchaudio.load(path)
        waveform  = torch.mean(waveform, dim=0, keepdim=True)
        
        resampler = T.Resample(
            orig_freq=orig_freq, new_freq=output_freq, dtype=waveform.dtype
        )
        waveform = resampler(waveform)
        
        length = seconds * output_freq
        if self.infer==True:
            if waveform.shape[1] <= 30 * output_freq + length:
                start = 0
                end = waveform.shape[1]
            else:  
                start = 0
                end = start+length
        else:
            if waveform.shape[1] <= length:
                start = 0
                end = waveform.shape[1]
            else:
                start = random.randint(0, waveform.shape[1] - length)
                end = start + length

        # print(f"crop start: {start}, end: {end}")
        waveform = waveform[:, start:end]

        # normalization
        waveform = waveform - waveform.mean()

        # wavefile -> mel filter back 만들어줌
        fbank = torchaudio.compliance.kaldi.fbank(
            waveform,
            htk_compat=True,
            sample_frequency=output_freq,
            use_energy=False,
            window_type="hamming",
            num_mel_bins=128,
            dither=0.0,
            frame_shift=10,
        )
        # torchaudio.compliance.kaldi.fbank를 거치면 웨이브파일 -> (frame개수, mel filter 개수)
        # input wave의 maximum 길이를 1000으로 맞춤
        target_length = 3000
        n_frames = fbank.shape[0]
        p = target_length - n_frames

        # max 길이보다 크면 자르고, max길이보다 작으면 zero padding
        if p > 0:
            m = torch.nn.ZeroPad2d((0, 0, 0, p))
            fbank = m(fbank)
        elif p < 0:
            fbank = fbank[0:target_length, :]

        return fbank

    def get_music_path(self, folder, music_id, ext):
        music_path = (
            f"{folder}/{music_id}.aac" if ext == None else f"{folder}/{music_id}.{ext}"
        )

        if not os.path.exists(music_path):
            music_path = f"{folder}/{music_id}.m4a"

        return music_path

class InferDataset(DatasetBase):
    def __init__(self, batch_size: int):
        super().__init__(batch_size=batch_size, data_type="total")

    def __getitem__(self, idx):
        try:
            return self.getitem(idx)
        except:
            return self.getitem(idx - 1)

    def getitem(self, idx):
        caption = self.shortform_df["text"][idx]
        music_id = self.shortform_df["id"][idx]

        encoding = tokenizer(caption, return_tensors='pt', padding='max_length', truncation=True, max_length=512)
        input_ids = encoding['input_ids']
        attention_mask = encoding['attention_mask']
        
        music_path = self.get_music_path(folder="music", music_id=music_id, ext=None)
        fbank = self.get_fbank(music_path, infer=True)
        # ast에서 input spectrogram norm 시킴
        fbank = (fbank - (-4.2677393)) / (4.5689974 * 2)

        return music_id, caption, input_ids, attention_mask, fbank


In [9]:
output_path = '/data/mulan_output_v7/'
    
batch_size = 2
mulan = MuLaN()
infer_dataset = InferDataset(batch_size=batch_size)

In [10]:
infer_loader = DataLoader(
    dataset=infer_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    prefetch_factor=4,
)
device = torch.device("cuda:1")

from pathlib import Path
path = Path('/data/mulan_checkpoints/mulan.v6.490000.pt')
pkg = torch.load(str(path), map_location='cpu')
mulan = MuLaN()        
mulan.load_state_dict(pkg["model"])
mulan = mulan.to(device)

RuntimeError: Error(s) in loading state_dict for MuLaN:
	Unexpected key(s) in state_dict: "text_model.embeddings.position_ids". 

In [None]:
output_audio_path = os.path.join(output_path, "audio_vector")
output_text_path = os.path.join(output_path, "text_vector")
    
Path(output_audio_path).mkdir(parents=True, exist_ok=True)
Path(output_text_path).mkdir(parents=True, exist_ok=True)
    
for idx, batch in enumerate(infer_loader):
        v_audio_path = os.path.join(output_audio_path, f"vector_audio_{str(idx)}.json")
        v_text_path = os.path.join(output_text_path, f"vector_text_{str(idx)}.json")
        audio_features = {}
        text_features = {}
        music_id, caption = batch[0], batch[1]
        input_ids, attention_mask, audio = batch[2], batch[3], batch[4]

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        audio = audio.to(device)

        audio_embeds = mulan.get_audio_latents(audio)  # during training
        text_embeds = mulan.get_text_latents(input_ids, attention_mask)  # during inference
        audio_out = F.normalize(audio_embeds, p=2, dim=-1)
        text_out = F.normalize(text_embeds, p=2, dim=-1)
        for a,b,c,d in zip(music_id, caption, audio_out, text_out):
            cbf_id = os.path.basename(str(int(a.numpy())))
            audio_features[cbf_id] = c.detach().cpu().numpy().tolist()
            text_features[b] = d.detach().cpu().numpy().tolist()
        with open(v_audio_path, 'w') as outfile:
            json.dump(audio_features, outfile)
        with open(v_text_path, 'w') as outfile:
            json.dump(text_features, outfile)
        del audio_embeds
        del text_embeds
