In [300]:
import os

import librosa
import librosa.display
import matplotlib.pyplot as plt
import mir_eval.sonify
import numpy as np
import pandas as pd
from IPython.display import Audio


In [302]:
class Drums:
    """Load a drum recording and expose librosa-based analysis and plotting helpers."""

    def __init__(self, file):
        """Load *file* with ``librosa.load`` and record its duration.

        Sets ``file`` (the given path), ``y`` (audio time series), ``sr``
        (sampling rate) and ``time`` (length in seconds).
        """
        self.file = file
        self.y, self.sr = librosa.load(self.file) #load audio: y = time series, sr = sampling rate
        self.time = librosa.get_duration(y = self.y, sr = self.sr) #audio length in seconds

    #play audio
    def play(self):
        """Return an IPython ``Audio`` widget for the loaded signal."""
        return(Audio(data = self.y, rate = self.sr))

    #get tempo & beats
    def tempo(self):
        """Return ``(tempo, beat_frames)`` as estimated by ``librosa.beat.beat_track``."""
        tempo, beat_frames = librosa.beat.beat_track(y = self.y, sr = self.sr)
        return(tempo, beat_frames)

    #mel spectrogram
    def melspec(self, xlim = 0):
        """Compute the mel spectrogram and plot it in dB.

        Parameters
        ----------
        xlim : 0 or (left, right)
            ``0`` (default) draws the full spectrogram on the current axes.
            Anything else creates a new figure with a colorbar and limits
            its time axis to *xlim* (seconds).

        Returns
        -------
        (S, img)
            The mel power spectrogram and the artist returned by ``specshow``.
        """
        S = librosa.feature.melspectrogram(y = self.y, sr = self.sr)
        S_db = librosa.power_to_db(S, ref=np.max)

        # `xlim != 0` is ambiguous for NumPy arrays, so test for a scalar zero explicitly.
        full_view = np.isscalar(xlim) and xlim == 0
        if full_view:
            img = librosa.display.specshow(S_db, x_axis='time', y_axis='mel')
        else:
            fig, ax = plt.subplots()
            img = librosa.display.specshow(S_db, x_axis='time', y_axis='mel', ax = ax)
            # Apply the window after drawing so limits set while drawing are overridden.
            ax.set(xlim = xlim)
            fig.colorbar(img, ax=ax)

        return(S, img)

    #save mel spec image
    def melsave(self, filename, imgdir, xlim = 0):
        """Draw the mel spectrogram and save it to ``{imgdir}{filename}.png``.

        *imgdir* is prepended verbatim, so include a trailing separator when
        it names a directory. *xlim* is forwarded to :meth:`melspec`.
        """
        self.melspec(xlim = xlim)
        plt.savefig(f"{imgdir}{filename}.png")

    #chromagram
    def chroma(self):
        """Plot a log-frequency power spectrogram above its chromagram."""
        Sc = np.abs(librosa.stft(self.y, n_fft=4096))**2  # power spectrogram
        chroma = librosa.feature.chroma_stft(S=Sc, sr=self.sr)

        fig, (ax, ax2) = plt.subplots(nrows=2, sharex=True)
        ax.set(title = 'Chromagram')
        # Sc holds power, so power_to_db gives the correct dB scale
        # (amplitude_to_db would double it).
        img = librosa.display.specshow(librosa.power_to_db(Sc, ref=np.max),
                                       y_axis='log', x_axis='time', ax=ax)
        fig.colorbar(img, ax=[ax])
        ax.label_outer()
        img = librosa.display.specshow(chroma, y_axis='chroma', x_axis='time', ax=ax2)
        fig.colorbar(img, ax=[ax2])

    #get onset times
    def onsets(self):
        """Return detected onset times in seconds."""
        o_env = librosa.onset.onset_strength(y=self.y, sr=self.sr)
        times = librosa.times_like(o_env, sr=self.sr)
        onset_frames = librosa.onset.onset_detect(onset_envelope=o_env, sr=self.sr)
        return(times[onset_frames])

    #save melspec at each onset
    def splitaudio(self, filename, gap = 0.1):
        """Save one mel-spectrogram image per detected onset.

        Image ``o`` shows ``[onset_o, onset_{o+1} - gap]``; the last image runs
        to the end of the audio. Images are written to
        ``{filename}/{filename}_{o}.png``, and the directory is created if
        needed.

        Parameters
        ----------
        filename : str
            Output directory name and image-name prefix.
        gap : float
            Seconds trimmed before the next onset (default 0.1). If two onsets
            are closer than *gap*, the window ends at the next onset instead,
            so the time axis is never inverted.
        """
        onsetlst = list(self.onsets())
        imgdir = f'{filename}/'
        os.makedirs(imgdir, exist_ok=True)

        # Pair each onset with the following one; None marks the final onset.
        nexts = onsetlst[1:] + [None]
        for o, (start, nxt) in enumerate(zip(onsetlst, nexts)):
            if nxt is None:
                end = self.time
            elif nxt - gap > start:
                end = nxt - gap
            else:
                end = nxt
            self.melsave(filename = f"{filename}_{o}", imgdir = imgdir, xlim = [start, end])
            plt.close()

    #convert annotations.txt to df chart
    def key(self, txt, onsets = None):
        """Build a per-onset drum chart from a tab-separated annotation file.

        Each non-blank line of *txt* is ``<time>`` TAB ``<label>``. Lines that
        share the same time are merged into one entry, with later labels
        prepended. The merged labels are paired, in order, with the detected
        onset times rounded to 2 decimals. The annotation times themselves
        are not kept, and the result is truncated to the shorter sequence.

        Parameters
        ----------
        txt : str
            Path to the annotation file.
        onsets : array-like or None
            Precomputed onset times in seconds. If None, ``self.onsets()`` is used.

        Returns
        -------
        pandas.DataFrame
            Columns ``onset``, ``kick``, ``snare``, ``cymbal`` and ``tom``.
            Each instrument column is 1 if the label contains ``KD``, ``SD``,
            ``CY`` or ``TT`` respectively, else 0.
        """
        if onsets is None:
            onsets = self.onsets()

        # txt to list of classes, one entry per distinct onset time
        ons_index = {}  # annotation time -> position in cls
        cls = []
        with open(txt, 'r') as fh:
            for line in fh:
                if not line.strip():
                    continue  # tolerate blank lines and a missing trailing newline
                splt = line.split('\t')
                t = float(splt[0].strip())
                label = splt[1].strip()
                if t in ons_index:
                    idx = ons_index[t]
                    cls[idx] = label + cls[idx]  # combine multiple classes at 1 onset
                else:
                    ons_index[t] = len(cls)
                    cls.append(label)

        onsets = np.round(onsets, 2)
        df = pd.DataFrame(list(zip(onsets, cls)), columns = ['onset', 'class'])

        #df drum chart: 1 when the label contains the instrument code
        codes = {'kick': 'KD', 'snare': 'SD', 'cymbal': 'CY', 'tom': 'TT'}
        for col, code in codes.items():
            df[col] = [1 if code in str(c) else 0 for c in df['class']]
        df = df.drop(columns = ['class'])

        return(df)