In [1]:
import torch
import librosa
import matplotlib.pyplot as plt
from PIL import Image
import io
import numpy as np
from librosa.display import specshow
from torchvision import datasets, models, transforms

## Load the pre-saved model

In [2]:
# Load the serialized TorchScript model; the path is relative to the
# notebook's working directory. `model` is used as a module-level global by
# the classification functions defined in this notebook.
# NOTE(review): the file name suggests ~97% accuracy — confirm against the training run.
model = torch.jit.load('cnn_97_accuracy.pt')
# Alternative kept for reference: restoring a plain state_dict checkpoint on CPU.
#model.load_state_dict(torch.load(r'.\dataset\final_model_97', map_location=torch.device('cpu')))

## Functions 

In [22]:
def classify_walkout(path):
    """Classify a single audio file as 'walkout' (True) or not (False).

    The file is loaded as 16 kHz audio, zero-padded or truncated to exactly
    one second (16000 samples), and passed to ``classify_walkout_wav``, which
    renders the spectrogram and runs the model.

    Parameters
    ----------
    path : str
        Path to an audio file readable by ``librosa.load``.

    Returns
    -------
    bool
        True when the model's predicted class index is 1, otherwise False.
    """
    sr = 16000

    # Load once, resampled to 16 kHz. The previous version loaded the file a
    # second time without `sr`, which overwrote this waveform with one at
    # librosa's default rate; the pad/trim below then no longer produced a
    # one-second clip and the input differed from the stream classifier's.
    wav, sr = librosa.load(path, sr=sr)

    # Force the clip to exactly one second of samples.
    if len(wav) < sr:
        wav = np.append(wav, np.zeros(sr - len(wav), dtype=wav.dtype))
    else:
        wav = wav[:sr]

    # Spectrogram rendering and inference live in one place so the file and
    # stream classifiers cannot drift apart.
    return classify_walkout_wav(wav)
    
# Sample random, non-overlapping 1-second chunks from an audio file and classify each one
def classify_walkout_10sec_stream(path, n_attempts=20):
    """Find one-second windows of an audio file that the model labels 'walkout'.

    Random one-second windows are drawn from the 16 kHz waveform. A window
    that overlaps any previously drawn window is skipped; every other window
    is classified with ``classify_walkout_wav``.

    Parameters
    ----------
    path : str
        Path to an audio file readable by ``librosa.load``.
    n_attempts : int, optional
        Number of random windows to draw (default 20, the original value).
        Overlapping draws are discarded, so fewer windows may be classified.

    Returns
    -------
    numpy.ndarray
        Array of shape (k, 2) holding ``[start, end]`` times in seconds for
        each window classified as 'walkout', in the order they were drawn
        (not sorted). The shape is (0, 2) when nothing is found or when the
        audio is one second or shorter.
    """
    sr = 16000
    segment_len = sr  # window length in samples == one second at 16 kHz

    wav, sr = librosa.load(path, sr=sr)  # load and resample the audio file
    n_samples = wav.shape[0]

    # A clip this short cannot supply a random one-second window; previously
    # this surfaced as an opaque ValueError from numpy's randint.
    if n_samples <= segment_len:
        return np.empty((0, 2))

    previous_segments = []
    walkout_segments = []

    for _ in range(n_attempts):
        segment_time = get_random_time_segment(segment_len, n_samples)

        if is_overlapping(segment_time, previous_segments):
            continue

        previous_segments.append(segment_time)
        start, end = segment_time
        audio_clip = wav[start:end + 1]  # `end` is inclusive

        if classify_walkout_wav(audio_clip):
            walkout_segments.append(segment_time)

    # reshape keeps the result two-dimensional even when no window matched.
    return np.array(walkout_segments, dtype=float).reshape(-1, 2) / sr

# Pick a random time segment of the given length
def get_random_time_segment(segment_ms, max_length):
    """Pick a random window of ``segment_ms`` units inside ``[0, max_length)``.

    Despite the ``_ms`` suffix, the value is used in whatever unit the caller
    indexes with; start and end are returned in that same unit.

    Parameters
    ----------
    segment_ms : int
        Length of the window.
    max_length : int
        Total length of the signal the window must fit into.

    Returns
    -------
    tuple of (int, int)
        ``(segment_start, segment_end)`` with an inclusive end, so
        ``segment_end - segment_start == segment_ms - 1`` and
        ``segment_end <= max_length - 1``.

    Raises
    ------
    ValueError
        If the signal is shorter than the requested window.
    """
    if max_length < segment_ms:
        raise ValueError(
            "max_length (%d) is shorter than the requested segment (%d)"
            % (max_length, segment_ms)
        )

    # randint's `high` is exclusive; the +1 lets the window end on the very
    # last sample and makes a signal of exactly `segment_ms` valid.
    segment_start = int(np.random.randint(low=0, high=max_length - segment_ms + 1))
    segment_end = segment_start + segment_ms - 1

    return (segment_start, segment_end)
                                      
# Check if a segment is overlapping with any previous segments
def is_overlapping(segment_time, previous_segments):
    """Report whether ``segment_time`` intersects any earlier segment.

    Segments are ``(start, end)`` pairs with inclusive bounds, so two
    segments that merely share an endpoint count as overlapping.

    Parameters
    ----------
    segment_time : tuple
        ``(start, end)`` of the candidate segment.
    previous_segments : iterable of tuple
        ``(start, end)`` pairs already accepted.

    Returns
    -------
    bool
        True if at least one previous segment intersects the candidate.
    """
    new_start, new_end = segment_time

    # Evaluate every previous segment (no early exit), then reduce.
    collisions = [
        old_start <= new_end and new_start <= old_end
        for old_start, old_end in previous_segments
    ]
    return any(collisions)
       
def classify_walkout_wav(wav):
    """Classify an in-memory waveform as 'walkout' (True) or not (False).

    The waveform is turned into a dB-scaled STFT spectrogram, rendered with
    matplotlib into an in-memory PNG, resized to 201x81, converted to a
    tensor and passed through the module-level ``model``.

    Parameters
    ----------
    wav : numpy.ndarray
        Audio samples. The spectrogram is drawn with ``sr=16000``.
        # NOTE(review): callers pass one-second clips — confirm the model
        # expects that duration.

    Returns
    -------
    bool
        True when the model's predicted class index is 1, otherwise False.
    """
    sr = 16000

    # Compute the spectrogram before any figure exists, so a failure here
    # leaves nothing to clean up.
    src_ft = librosa.stft(wav)
    src_db = librosa.amplitude_to_db(abs(src_ft))

    fig = plt.figure()
    try:
        plt.axis('off')
        specshow(src_db, sr=sr)

        # Render the figure to an in-memory PNG instead of a file on disk.
        buf = io.BytesIO()
        plt.savefig(buf, format='png')
        buf.seek(0)
        img = Image.open(buf)

        # Resize, drop the PNG's alpha channel, then convert to a tensor.
        img = transforms.Resize((201, 81))(img)
        img = np.array(img)[:, :, :3]
        img = transforms.ToTensor()(img)
    finally:
        # Close this specific figure even if rendering failed; otherwise
        # repeated calls from the stream classifier accumulate open figures.
        plt.close(fig)

    # Inference only: no autograd bookkeeping is needed.
    model.eval()
    with torch.no_grad():
        pred = model(img[None, ...])  # add the batch dimension

    return bool(pred.argmax(1).item() == 1)

## Inference

In [23]:
# Single-file walkout classifier. Input: path string. Output: Boolean
# (True when the model's predicted class index is 1).

classify_walkout("raza_1.wav")

True

In [21]:
# Walkout classifier for an audio stream. Input: path string. Output: array of
# [start, end] timestamps (in seconds) of the detected occurrences.
# The windows are drawn at random and no seed is set in the visible cells, so
# the rows are unsorted and can differ between runs.

segments_list = classify_walkout_10sec_stream('raza.wav')

print("'Walkout' was found in the following time durations.\n \n", segments_list)

'Walkout' was found in the following time durations.
 
 [[ 1.7506875  2.750625 ]
 [ 4.0700625  5.07     ]
 [ 5.3208125  6.32075  ]
 [11.7626875 12.762625 ]
 [ 7.867      8.8669375]
 [ 0.3104375  1.310375 ]]
