In [10]:
import os
import sys
import numpy as np
import julius

import torch
import torchaudio

sys.path.append('lib')
from lib.Task2 import *

In [13]:
from marg_network import Estimator
from marg_postprocess import PostProcessor
from marg_utils import *


In [3]:
ckpt = torch.load('weights/marg_ckpt.pt')
estimator = Estimator()
estimator.load_state_dict(ckpt['net'])

<All keys matched successfully>

In [4]:
estimator.eval()
estimator.reset_state()


In [5]:
set_num = 5
drone_num = 3
data_path = '../temp_data'
answer_list = []
for i in range(set_num):
    sub_answer_list = []
    for j in range(drone_num):
        pred_cla_full = []
        folder_name = 'set_0' + str(i+1)
        file_name = 'set0' + str(i+1) + '_drone0' + str(j+1) + '_ch1.wav'
        y, sr = torchaudio.load(os.path.join(data_path, folder_name, file_name))
        wav_full = julius.resample_frac(y, sr, 16000)[0]

        with torch.no_grad():
            for k in range(wav_full.shape[0] // (16000 * 10)):
                wav = wav_full[i * 16000 * 10: (i + 1) * 16000 * 10]

                wav = torch.tensor(wav).float()
                wav = wav.view(1, 1, -1).repeat(1, 7, 1)
                pred_cla = estimator(wav)
                pred_cla_full.append(pred_cla)

            pred_cla_full = torch.cat(pred_cla_full, -2).squeeze()
            output = torch.sigmoid(pred_cla_full).detach().cpu().numpy()
            output_zeroone = (output > 0.5).astype(int)
        postprocessor = PostProcessor()
        estimation = postprocessor.report(output_zeroone)
#         print(estimation)                
        sub_answer_list.append(estimation)
    answer_list.append(sub_answer_list)


  "The use of pseudo complex type in spectrogram is now deprecated."


In [6]:
out_str = answer_list_to_json(answer_list)                    


In [7]:
print(out_str)

{
    "task2_answer": [{
        "set_1":[
            {drone_1:[{
                "M": ["NONE"],
                "W": ["00:00","00:03","00:13","00:23","00:33","00:43","00:53","01:03","01:13","01:23","01:33","01:43","01:53","02:03","02:13","02:23","02:33","02:43","02:53","03:03","03:13","03:23","03:33","03:43","03:53"],
                "C": ["00:07","00:17","00:27","00:37","00:47","00:57","01:07","01:17","01:27","01:37","01:47","01:57","02:07","02:17","02:27","02:37","02:47","02:57","03:07","03:17","03:27","03:37","03:47","03:58"]}]},
            {drone_2:[{
                "M": ["NONE"],
                "W": ["00:03","00:13","00:23","00:33","00:43","00:53","01:03","01:13","01:23","01:33","01:43","01:53","02:03","02:13","02:23","02:33","02:43","02:53","03:03","03:13","03:23","03:33","03:43","03:53"],
                "C": ["00:07","00:17","00:27","00:37","00:47","00:57","01:07","01:17","01:27","01:37","01:47","01:57","02:07","02:17","02:27","02:37","02:47","02:57","03:07","03:17","03:27

In [8]:
gt_list_for_samples = [   # gt_list for sample video (1 set, 3 drone, 3 class)
                    [
                        [[[210, 214], [217, 220]], [[222, 226]], [[74, 79], [225, 231]]],
                        [[[168, 172], [183, 186]], [[175, 179]], [[164, 169]]],
                        [[[213, 216], [220, 224]], [[214, 218]], [[226, 231]]],
                    ],
                    [
                        [[[210, 214], [217, 220]], [[222, 226]], [[74, 79], [225, 231]]],
                        [[[168, 172], [183, 186]], [[175, 179]], [[164, 169]]],
                        [[[213, 216], [220, 224]], [[214, 218]], [[226, 231]]],
                    ],
                    [
                        [[[210, 214], [217, 220]], [[222, 226]], [[74, 79], [225, 231]]],
                        [[[168, 172], [183, 186]], [[175, 179]], [[164, 169]]],
                        [[[213, 216], [220, 224]], [[214, 218]], [[226, 231]]],
                    ],
                    [
                        [[[210, 214], [217, 220]], [[222, 226]], [[74, 79], [225, 231]]],
                        [[[168, 172], [183, 186]], [[175, 179]], [[164, 169]]],
                        [[[213, 216], [220, 224]], [[214, 218]], [[226, 231]]],
                    ],
                    [
                        [[[210, 214], [217, 220]], [[222, 226]], [[74, 79], [225, 231]]],
                        [[[168, 172], [183, 186]], [[175, 179]], [[164, 169]]],
                        [[[213, 216], [220, 224]], [[214, 218]], [[226, 231]]],
                    ],                                                                                                
        ]

In [26]:
def cw_metrics_cal(gt, pred):
    correct = 0
    deletion = 0
    insertion = 0
    include_list = []
    for ii in range(len(gt)):
        is_include = 0
        if gt[ii][0] == None:
            if pred[0] == None:
                correct += 1
        else:
            if pred[0] == None:
                deletion += 1
            else:
                gt_start, gt_end = gt[ii][0], gt[ii][1]
                for jj in range(len(pred)):
                    if gt_start <= pred[jj] and pred[jj] <= gt_end:
                        include_list.append(pred[jj])
                        is_include += 1
                if is_include == 0:
                    deletion += 1
                elif is_include > 1:
                    insertion = is_include - 1
                elif is_include == 1:
                    correct += 1
    substitution = len(pred) - len(list(set(include_list)))
#     print('s, d, i, correct:', substitution, deletion, insertion, correct)
    return substitution, deletion, insertion, correct


def evaluation(gt_list, answer_list):
    set_num = len(gt_list)
    total_s, total_d, total_i, total_n, total_correct = 0, 0, 0, 0, 0
    for i in range(set_num):
        sw_s, sw_d, sw_i, sw_n, sw_correct = 0, 0, 0, 0, 0
        for j in range(3): # for drones 1 ~ 3
            dw_s, dw_d, dw_i, dw_n, dw_correct = 0, 0, 0, 0, 0
            for k in range(3): # for class man, woman, child
                cw_s, cw_d, cw_i, cw_correct = cw_metrics_cal(gt_list[i][j][k], answer_list[i][j][k])
                dw_s += cw_s
                dw_d += cw_d
                dw_i += cw_i
                dw_n += len(gt_list[i][j][k])
                dw_correct += cw_correct
            dw_er = (dw_s + dw_d + dw_i) / dw_n
            print('Set', str(i), 'Drone', str(j), 's, d, i, er, correct:', dw_s, dw_d, dw_i, np.round(dw_er, 2), dw_correct)
            sw_s += dw_s
            sw_d += dw_d
            sw_i += dw_i
            sw_n += dw_n
            sw_er = (sw_s + sw_d + sw_i) / sw_n
            sw_correct += dw_correct
        total_s += sw_s
        total_d += sw_d
        total_i += sw_i
        total_n += sw_n
        total_er = (total_s + total_d + total_i) / total_n
        total_correct += sw_correct
        print('Subtotal Set', str(i), 's, d, i, er, correct:', sw_s, sw_d, sw_i, np.round(sw_er, 2), sw_correct)
    print('Total', 's, d, i, er, correct:', total_s, total_d, total_i, np.round(total_er, 2), total_correct)
    return total_s, total_d, total_i, total_er, total_correct


In [27]:
evaluation(gt_list_for_samples, answer_list)


Set 0 Drone 0 s, d, i, er, correct: 47 2 0 9.8 3
Set 0 Drone 1 s, d, i, er, correct: 48 3 0 12.75 1
Set 0 Drone 2 s, d, i, er, correct: 48 3 0 12.75 1
Subtotal Set 0 s, d, i, er, correct: 143 8 0 11.62 5
Set 1 Drone 0 s, d, i, er, correct: 59 4 0 12.6 1
Set 1 Drone 1 s, d, i, er, correct: 59 2 0 15.25 2
Set 1 Drone 2 s, d, i, er, correct: 61 4 0 16.25 0
Subtotal Set 1 s, d, i, er, correct: 179 10 0 14.54 3
Set 2 Drone 0 s, d, i, er, correct: 3 5 0 1.6 0
Set 2 Drone 1 s, d, i, er, correct: 3 4 0 1.75 0
Set 2 Drone 2 s, d, i, er, correct: 3 4 0 1.75 0
Subtotal Set 2 s, d, i, er, correct: 9 13 0 1.69 0
Set 3 Drone 0 s, d, i, er, correct: 28 4 0 6.4 1
Set 3 Drone 1 s, d, i, er, correct: 29 4 0 8.25 0
Set 3 Drone 2 s, d, i, er, correct: 29 4 0 8.25 0
Subtotal Set 3 s, d, i, er, correct: 86 12 0 7.54 1
Set 4 Drone 0 s, d, i, er, correct: 27 4 0 6.2 1
Set 4 Drone 1 s, d, i, er, correct: 27 3 0 7.5 1
Set 4 Drone 2 s, d, i, er, correct: 27 3 0 7.5 1
Subtotal Set 4 s, d, i, er, correct: 81 10 0 

(498, 53, 0, 8.476923076923077, 12)

In [19]:
answer_list[0][0][0][0] == None

True