In [6]:
from audio_dataset import MelSpectrogramDataset,denormalize

from pathlib import Path

import os

import numpy as np
import pandas as pd

import librosa
import librosa.display

import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils.data import DataLoader

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

In [4]:
# Project root. Set the HEART_SHENZHEN_ROOT environment variable to point at the
# dataset on another machine instead of editing this machine-specific default.
root = os.environ.get('HEART_SHENZHEN_ROOT', '/Users/haohe/Desktop/Heart_shenzhen')
audio_data = os.path.join(root, 'audio')  # directory of raw audio files
data = os.path.join(root, 'data')         # directory holding train.csv / valid_test.csv

In [7]:
# Load split metadata. valid_test.csv presumably mixes validation and test rows;
# its 'valid' column selects the validation subset.
df_train = pd.read_csv(os.path.join(data, 'train.csv'))
df_valid = (
    pd.read_csv(os.path.join(data, 'valid_test.csv'))
    # `== True` (not a bare boolean mask) is kept on purpose: it also matches 1/1.0.
    .loc[lambda d: d['valid'] == True]  # noqa: E712
    # Filtering leaves gaps in the index; reset it so df_valid has a contiguous
    # 0..n-1 index like df_train, which index-based Dataset lookups rely on.
    .reset_index(drop=True)
)

In [8]:
def _normalize_and_tensorize():
    """Build the shared tail of every pipeline: normalise pixels, then convert to a tensor."""
    normalize = A.Normalize(p=1.0)
    to_tensor = ToTensorV2(p=1.0)
    return [normalize, to_tensor]


def get_train_transform():
    """Return a fresh albumentations pipeline for training images.

    No geometric augmentation is applied. HorizontalFlip, VerticalFlip and
    Resize(512, 512) are candidates that are currently disabled; only
    normalisation and tensor conversion run.
    NOTE(review): A.Normalize is used with its defaults (ImageNet 3-channel
    mean/std, max_pixel_value=255). Confirm these suit the mel-spectrogram
    images and the 1-channel model input.
    """
    return A.Compose(_normalize_and_tensorize(), p=1.0)


def get_valid_transform():
    """Return a fresh albumentations pipeline for validation images.

    Resize(512, 512) is a candidate that is currently disabled; only
    normalisation and tensor conversion run.
    """
    return A.Compose(_normalize_and_tensorize(), p=1.0)

In [9]:
# Wrap each split in a MelSpectrogramDataset reading audio from `audio_data`;
# each split gets its own freshly built transform pipeline.
train_ds, valid_ds = (
    MelSpectrogramDataset(df_train, audio_data, img_tfms=get_train_transform()),
    MelSpectrogramDataset(df_valid, audio_data, img_tfms=get_valid_transform()),
)

In [10]:
# Mish - "Mish: A Self Regularized Non-Monotonic Neural Activation Function"
# https://arxiv.org/abs/1908.08681v1
# implemented for PyTorch / FastAI by lessw2020
# github: https://github.com/lessw2020/mish
class Mish(nn.Module):
    """Mish activation, applied element-wise: ``x * tanh(softplus(x))``.

    Has no parameters or buffers.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # `F` (torch.nn.functional) is never imported in this notebook, so the
        # previous `F.softplus` raised NameError on a fresh kernel; reach it via `nn`.
        # Kept as a single inlined expression: upstream measured ~1 s/epoch (V100)
        # saved versus binding a temporary and returning it.
        return x * torch.tanh(nn.functional.softplus(x))

In [11]:
class _AdaptiveConcatPool2d(nn.Module):
    '''Concatenate adaptive max- and average-pooling along the channel axis.

    Pure-PyTorch stand-in for fastai v1 ``AdaptiveConcatPool2d``. The output
    has ``2 * C`` channels, with the max-pooled channels first.
    '''

    def __init__(self, output_size=1):
        super().__init__()
        self.output_size = output_size
        self.ap = nn.AdaptiveAvgPool2d(output_size)
        self.mp = nn.AdaptiveMaxPool2d(output_size)

    def forward(self, x):
        return torch.cat([self.mp(x), self.ap(x)], 1)


class Model_Head(nn.Module):
    def __init__(self,ni,nc,ps=0.25):
        '''
        Classification head applied to a backbone feature map:
          Mish -> Conv2d(ni, ni, 3x3, no bias) -> BatchNorm2d(ni)
          -> concat(max-pool, avg-pool) -> flatten
          -> [BatchNorm1d(2*ni), Dropout(ps), Linear(2*ni, 512), Mish]
          -> [BatchNorm1d(512), Dropout(2*ps), Linear(512, nc)]
        A Dropout layer is omitted when its rate is 0.

        The head is built from plain torch.nn. The fastai v1 helpers used before
        (conv2d, batchnorm_2d, bn_drop_lin, AdaptiveConcatPool2d, Flatten) are never
        imported in this notebook. Initialisation is intended to mirror those
        helpers' defaults. Layer order and Sequential indices (and therefore
        state_dict keys) are unchanged.

        ni : input filter size (channels of the incoming feature map)
        nc : output class size
        ps : dropout rate (the final stage uses ps*2)
        '''
        super().__init__()
        # fastai conv2d defaults: ks=3, padding=ks//2, no bias, kaiming_normal_ weights.
        conv = nn.Conv2d(ni, ni, kernel_size=3, stride=1, padding=1, bias=False)
        nn.init.kaiming_normal_(conv.weight)
        # fastai batchnorm_2d defaults: weight=1, bias=1e-3.
        bn = nn.BatchNorm2d(ni)
        with torch.no_grad():
            bn.weight.fill_(1.)
            bn.bias.fill_(1e-3)
        layers = ([Mish(), conv, bn, _AdaptiveConcatPool2d(), nn.Flatten()]
                  + self._bn_drop_lin(ni * 2, 512, p=ps, actn=Mish())
                  + self._bn_drop_lin(512, nc, p=ps * 2))
        self.head = nn.Sequential(*layers)

    @staticmethod
    def _bn_drop_lin(n_in, n_out, p=0., actn=None):
        '''Return [BatchNorm1d, Dropout (only if p != 0), Linear, actn (if given)].'''
        layers = [nn.BatchNorm1d(n_in)]
        if p != 0:
            layers.append(nn.Dropout(p))
        layers.append(nn.Linear(n_in, n_out))
        if actn is not None:
            layers.append(actn)
        return layers

    def forward(self,xb):
        return self.head(xb)
    
def _num_features_model(body, in_channels=1, sizes=(64, 128, 256, 512, 1024, 2048)):
    '''Return the channel count of ``body``'s output by running a dummy zero batch.

    Pure-PyTorch stand-in for fastai's ``num_features_model``. It tries square
    inputs of increasing size until one is large enough for the network, and
    re-raises the last error if none works. The module's train/eval mode is
    restored afterwards.
    '''
    was_training = body.training
    body.eval()
    try:
        for sz in sizes:
            try:
                with torch.no_grad():
                    return body(torch.zeros(1, in_channels, sz, sz)).shape[1]
            except RuntimeError:
                if sz == sizes[-1]:
                    raise
    finally:
        body.train(was_training)


class Resnet_1ch(nn.Module):
    def __init__(self,arch,nc=(1,1,1),pretrained=True):
        '''
        Single-channel backbone with three parallel classification heads.

        arch       : torchvision-style constructor, called as arch(pretrained=...).
                     Its last two children are dropped (for torchvision ResNets:
                     avgpool and fc), and its first child must be the stem Conv2d.
        nc         : output sizes of the three heads (indexed 0, 1, 2)
        pretrained : forwarded to `arch`
        '''
        super().__init__()
        self.body = nn.Sequential(*list(arch(pretrained=pretrained).children())[:-2])

        # Change the stem conv to accept 1 input channel. The pretrained weights are
        # summed over the input-channel axis. Stride, padding and dilation are copied
        # from the original conv. The previous fastai `conv2d(1, nf, ks=h)` used
        # stride 1, which silently dropped the ResNet stem's stride 2.
        old_conv = self.body[0]
        conv_input = nn.Conv2d(1, old_conv.out_channels,
                               kernel_size=old_conv.kernel_size,
                               stride=old_conv.stride,
                               padding=old_conv.padding,
                               dilation=old_conv.dilation,
                               bias=False)
        with torch.no_grad():
            conv_input.weight.copy_(old_conv.weight.sum(dim=1, keepdim=True))
        self.body[0] = conv_input

        # Multi-head output. The head attribute names are kept for state_dict
        # compatibility; they come from a grapheme/vowel/consonant setup.
        # NOTE(review): confirm what each head predicts for this heart-sound task.
        ni = _num_features_model(self.body, in_channels=1)
        self.head_grapheme = Model_Head(ni,nc[0])
        self.head_vowel = Model_Head(ni,nc[1])
        self.head_consonant = Model_Head(ni,nc[2])

    def forward(self,x):
        # Returns a tuple of three head outputs, all computed from the same features.
        x = self.body(x)
        return (self.head_grapheme(x),self.head_vowel(x),self.head_consonant(x))
    
def to_mish(model):
    """Swap every nn.ReLU submodule of `model` for a fresh Mish, in place.

    Walks the module tree with an explicit stack. A ReLU child is replaced by
    attribute assignment on its parent; every other child is descended into.
    Returns None.
    """
    pending = [model]
    while pending:
        parent = pending.pop()
        for child_name, child in list(parent.named_children()):
            if not isinstance(child, nn.ReLU):
                pending.append(child)
                continue
            setattr(parent, child_name, Mish())