In [1]:
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import torch
from torchvision import models, transforms
import torch.nn.functional as F
import numpy as np
import skimage.transform

import librosa
import librosa.display

import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
from pathlib import Path
from IPython.display import Audio


In [2]:
# config
import yaml

config_filename = 'ResNeSt003.yaml'
# Read the experiment configuration. The encoding is given explicitly so the
# file is decoded the same way regardless of the platform's default text
# encoding.
with open(f"../configs/{config_filename}", encoding="utf-8") as f:
    config = yaml.safe_load(f)

In [3]:
import pandas as pd
from pathlib import Path

# Competition input directory: the audio recordings live under train/, next to
# the two label tables (train_tp.csv / train_fp.csv).
input_dir = Path('../input/rfcx-species-audio-detection')
audio_path = input_dir / 'train'
train_tp = pd.read_csv(input_dir / 'train_tp.csv')
train_fp = pd.read_csv(input_dir / 'train_fp.csv')

# Tag every row with the table it came from so the two sources stay
# distinguishable after they are stacked.
train_tp['data_type'] = 'tp'
train_fp['data_type'] = 'fp'

# ignore_index=True: each CSV is read with its own 0..n-1 index, so a plain
# concat would leave duplicate index labels in `train`.
train = pd.concat([train_tp, train_fp], ignore_index=True)
train.head()

Unnamed: 0,recording_id,species_id,songtype_id,t_min,f_min,t_max,f_max,data_type
0,003bec244,14,1,44.544,2531.25,45.1307,5531.25,tp
1,006ab765f,23,1,39.9615,7235.16,46.0452,11283.4,tp
2,007f87ba2,12,1,39.136,562.5,42.272,3281.25,tp
3,0099c367b,17,4,51.4206,1464.26,55.1996,4565.04,tp
4,009b760e6,10,1,50.0854,947.461,52.5293,10852.7,tp


In [4]:
import numpy as np
import matplotlib.patches as patches

def get_spec_sample(df, total_time=60, PERIOD=10, num_classes=None):
    """Draw one random row from ``df`` and turn its recording into spectrogram images.

    The file ``<recording_id>.flac`` is read from the module-level
    ``audio_path``, zero-padded at a random offset or randomly cropped to
    exactly ``total_time`` seconds, and cut into ``total_time / PERIOD``
    equal-length segments. Every segment is converted to a dB-scaled mel
    spectrogram and to a 3-channel image derived from it.

    Args:
        df: DataFrame with ``recording_id`` and ``species_id`` columns.
        total_time: Clip length in seconds after padding / cropping.
        PERIOD: Segment length in seconds; must divide ``total_time`` evenly.
        num_classes: Length of the one-hot label vector. When ``None`` it is
            inferred from ``df`` as the larger of the number of distinct
            species ids and ``max(species_id) + 1``, so the sampled id always
            fits even when ``df`` is a subset that skips some ids.

    Returns:
        Tuple ``(original_images, images, labels)``:
        ``original_images`` is a list with one float32 dB mel spectrogram per
        segment; ``images`` is a list with one float32 array per segment,
        channel axis first and values scaled to [0, 1]; ``labels`` is a
        float32 one-hot vector marking the sampled row's ``species_id``.

    Raises:
        ValueError: If ``total_time`` is not a positive whole multiple of
            ``PERIOD``.
    """
    n_segments = total_time / PERIOD
    if n_segments < 1 or n_segments != int(n_segments):
        raise ValueError(
            f"total_time ({total_time}) must be a positive whole multiple of PERIOD ({PERIOD})"
        )
    n_segments = int(n_segments)

    sample = df.sample(1)
    recording_id = sample["recording_id"].values[0]
    main_species_id = int(sample["species_id"].values[0])

    y, sr = sf.read(audio_path / f"{recording_id}.flac")  # for default

    # Bring every clip to exactly total_time seconds
    len_y = len(y)
    total_length = int(total_time * sr)
    if len_y < total_length:
        # Too short: place the audio at a random offset inside a silent buffer
        new_y = np.zeros(total_length, dtype=y.dtype)
        start = np.random.randint(total_length - len_y)
        new_y[start:start + len_y] = y
        y = new_y.astype(np.float32)
    elif len_y > total_length:
        # Too long: keep a random window of the target length
        start = np.random.randint(len_y - total_length)
        y = y[start:start + total_length].astype(np.float32)
    else:
        y = y.astype(np.float32)

    # Split into PERIOD-second segments (6 equal parts with the defaults)
    segments = np.split(y, n_segments)

    images = []
    original_images = []
    # Convert each segment to an image and collect the results in lists
    for segment in segments:
        # The audio is passed by keyword: librosa releases where `y` is
        # keyword-only reject a positional argument here.
        melspec = librosa.feature.melspectrogram(
            y=segment, sr=sr, n_mels=128, fmin=80, fmax=15000, power=1.5)
        melspec = librosa.power_to_db(melspec).astype(np.float32)
        original_images.append(melspec)
        image = mono_to_color(melspec)
        image = cv2.resize(image, (400, 224))
        image = np.moveaxis(image, 2, 0)  # channel axis last -> first
        image = (image / 255.0).astype(np.float32)
        images.append(image)

    if num_classes is None:
        species_ids = df['species_id']
        num_classes = max(len(species_ids.unique()), int(species_ids.max()) + 1)
    labels = np.zeros(num_classes, dtype=np.float32)
    labels[main_species_id] = 1.0
    return original_images, images, labels


def mono_to_color(X: np.ndarray,
                  mean=None,
                  std=None,
                  norm_max=None,
                  norm_min=None,
                  eps=1e-6):
    """Convert a single-channel array into a 3-channel uint8 image.

    ``X`` is replicated three times along a new last axis, standardized with
    ``mean`` / ``std`` and then scaled linearly so that
    ``[norm_min, norm_max]`` maps onto ``[0, 255]``.

    Args:
        X: Input array; the output has one extra trailing axis of size 3.
        mean: Value subtracted before scaling. ``None`` uses the mean of the
            stacked array.
        std: Divisor applied after centering (``eps`` is added to it).
            ``None`` uses the standard deviation of the centered array.
        norm_max: Upper clipping bound in standardized units. ``None`` uses
            the maximum of the standardized array.
        norm_min: Lower clipping bound in standardized units. ``None`` uses
            the minimum of the standardized array.
        eps: Small constant that guards the divisions and the range checks.

    Returns:
        ``uint8`` array of shape ``X.shape + (3,)``. It is all zeros when the
        standardized data, or the requested normalization range, spans no
        more than ``eps``.
    """
    # Stack X as [X, X, X]
    X = np.stack([X, X, X], axis=-1)

    # Standardize. Defaults are chosen with `is None` rather than `or`, so an
    # explicit 0 (or an array-valued argument) is used as given.
    if mean is None:
        mean = X.mean()
    X = X - mean
    if std is None:
        std = X.std()
    Xstd = X / (std + eps)

    _min, _max = Xstd.min(), Xstd.max()
    if norm_max is None:
        norm_max = _max
    if norm_min is None:
        norm_min = _min

    if (_max - _min) > eps and (norm_max - norm_min) > eps:
        # Normalize to [0, 255]
        V = np.clip(Xstd, norm_min, norm_max)
        V = 255 * (V - norm_min) / (norm_max - norm_min)
        return V.astype(np.uint8)

    # Flat input or degenerate range: just zero
    return np.zeros_like(Xstd, dtype=np.uint8)

In [5]:
# Draw one random row from the tp table and build its per-segment
# spectrograms, model input images and one-hot label vector.
original_images, images, labels = get_spec_sample(train_tp)


In [8]:
# Pinned to the version this notebook was last run with. %pip (rather than
# !pip) installs into the environment of the running kernel.
%pip install torchlibrosa==0.0.5

Collecting torchlibrosa
  Downloading torchlibrosa-0.0.5-py3-none-any.whl (8.9 kB)
Installing collected packages: torchlibrosa
Successfully installed torchlibrosa-0.0.5


In [9]:
import sys
if '../' not in sys.path:  # keep re-runs of this cell from appending duplicates
    sys.path.append('../')
from src.models import get_model

fold = 0
output_dir = Path('../output/1226_162624/')
model_name = 'ResNeSt50'
model = get_model(config, fold)

# TODO: make it possible to pick up the checkpoint of each fold
# Prefer the "-v0" checkpoint name and fall back to the unversioned one.
ckpt_candidates = [
    output_dir / f'{model_name}-{fold}-v0.ckpt',
    output_dir / f'{model_name}-{fold}.ckpt',
]
ckpt_path = next((path for path in ckpt_candidates if path.exists()), None)
if ckpt_path is None:
    raise FileNotFoundError(
        'No checkpoint found; tried: ' + ', '.join(str(path) for path in ckpt_candidates)
    )

# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location='cpu' lets a checkpoint that was saved on a GPU load on a CPU-only machine.
ckpt = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(ckpt['state_dict'])

Downloading: "https://s3.us-west-1.wasabisys.com/resnest/torch/resnest50-528c19ca.pth" to /root/.cache/torch/hub/checkpoints/resnest50-528c19ca.pth


  0%|          | 0.00/105M [00:00<?, ?B/s]

FileNotFoundError: [Errno 2] No such file or directory: '../output/1226_162624/ResNeSt50-0.ckpt'

In [None]:
# Get the features from a model
class SaveFeatures():
    """Forward hook that keeps the latest output of ``module`` as a NumPy array.

    Creating an instance registers the hook on ``module``. After each forward
    pass through that module, ``features`` holds its output. Call ``remove()``
    to unregister the hook once the features are no longer needed.
    """

    # Last captured output; stays None until the hooked module has run.
    features = None

    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # detach() drops the autograd graph and cpu() keeps the conversion
        # valid when the output lives on an accelerator.
        self.features = output.detach().cpu().numpy()

    def remove(self):
        """Unregister the forward hook."""
        self.hook.remove()

def getCAM(feature_conv, weight_fc, class_idx):
    """Compute the class activation map (CAM) of one class.

    Args:
        feature_conv: Array of shape ``(batch, channels, h, w)``. Only the
            first sample of the batch is used.
        weight_fc: Array whose row ``weight_fc[class_idx]`` holds one weight
            per channel of ``feature_conv``.
        class_idx: Index of the class to visualise; a plain int or a
            one-element tensor / array.

    Returns:
        One-element list holding the ``(h, w)`` map, shifted and scaled to
        ``[0, 1]``. A map with no variation is returned as all zeros rather
        than the NaNs a 0/0 division would produce.
    """
    _, nc, h, w = feature_conv.shape

    # Accept one-element tensors / arrays as well as plain ints.
    if hasattr(class_idx, "item"):
        class_idx = class_idx.item()
    class_idx = int(class_idx)

    # Weighted sum of the channel maps of the first sample.
    cam = weight_fc[class_idx].dot(feature_conv[0].reshape((nc, h * w)))
    cam = cam.reshape(h, w)
    cam = cam - np.min(cam)
    peak = np.max(cam)
    if peak > 0:
        cam_img = cam / peak
    else:
        cam_img = np.zeros_like(cam)
    return [cam_img]

In [None]:
# Get features from last conv layer
final_layer = model.model.layer4
activated_features = SaveFeatures(final_layer)

# Inference on the first segment only; unsqueeze adds a leading batch axis
model.eval()
input_tensor = torch.as_tensor(images[0], dtype=torch.float32).unsqueeze(0)
try:
    with torch.no_grad():  # no gradients are needed to build a CAM
        prediction = model(input_tensor)
finally:
    # Always unregister the hook, even when the forward pass fails
    activated_features.remove()

# An explicit dim replaces the deprecated implicit choice; the last axis is
# what the implicit rule picks for 1-D and 2-D score tensors.
pred_probabilities = F.softmax(prediction, dim=-1).squeeze()
top_prob, top_idx = torch.topk(pred_probabilities, 1)
print('Top-1 prediction:', top_prob)

# Take weights from the first linear layer
weight_softmax_params = list(model.model.fc.parameters())
weight_softmax = np.squeeze(weight_softmax_params[0].detach().cpu().numpy())

# Get the top-1 prediction and get CAM
# A plain int avoids indexing a NumPy array with a torch tensor.
class_idx = int(top_idx)
print(class_idx)
print(activated_features.features.shape)
print(weight_softmax[class_idx])
overlay = getCAM(activated_features.features, weight_softmax, class_idx)

In [None]:
# Top: dB mel spectrogram of the first segment. Bottom: the same spectrogram
# with the class activation map drawn over it.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10))
image = original_images[0]

ax1.imshow(image)
ax1.set(title='Mel spectrogram (first segment)', xlabel='frame', ylabel='mel bin')

ax2.imshow(image)
# The CAM is resized to the spectrogram's shape so the two layers line up.
ax2.imshow(skimage.transform.resize(overlay[0], image.shape[:2]), alpha=0.5, cmap='jet')
ax2.set(title='Class activation map overlay', xlabel='frame', ylabel='mel bin')
plt.show()