In [2]:
# system
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import models

from tqdm import tqdm
from collections import OrderedDict

In [3]:
# AuSiL
from feature_extraction.network_architectures import weak_mxh64_1024
import feature_extraction.extractor as exm

In [4]:
# Watterfle dataset (FMA)
from preprocessing import load_poly_encoder_dataset

In [5]:
device_num = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [6]:
last_pad_length = 938
num_feature = 128
cs, ns, labels, cs_path, ns_path = load_poly_encoder_dataset(last_pad_length, num_feature)

  0%|          | 1/32000 [00:00<1:11:43,  7.44it/s]



100%|██████████| 32000/32000 [01:15<00:00, 425.29it/s] 


In [7]:
# bsz * 10 개만 할당
small_size = 32000
cs = torch.tensor(cs[:small_size])
ns = torch.tensor(ns[:small_size])
labels = torch.tensor(labels[:small_size])

cs = cs.view(-1, 1, last_pad_length, num_feature)
ns = ns.view(-1, 1, last_pad_length, num_feature)

cs.size(), ns.size()

(torch.Size([32000, 1, 938, 128]), torch.Size([32000, 1, 938, 128]))

In [8]:
# feture extraction by pre-trained CNN
trainType = 'weak_mxh64_1024'
pre_model_path = 'feature_extraction/mx-h64-1024_0d3-1.17.pkl'
featType = ['layer1', 'layer2', 'layer4', 'layer5', 'layer7', 'layer8', 'layer10', 'layer11', 'layer13', 'layer14', 'layer16', 'layer18'] # or layer 19 -  layer19 might not work well
globalpoolfn = F.max_pool2d # can use max also
netwrkgpl = F.avg_pool2d # keep it fixed

In [9]:
# Load model
def load_model(netx,modpath):
    state_dict = torch.load(modpath, map_location=lambda storage, loc: storage)
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if 'module.' in k:
            name = k[7:]
        else:
            name = k
        new_state_dict[name] = v
    netx.load_state_dict(new_state_dict)
    
netx = weak_mxh64_1024(527, netwrkgpl)
load_model(netx, pre_model_path)

feat_extractor = torch.nn.DataParallel(exm.featExtractor(netx, featType))

feat_extractor.to(device_num)
feat_extractor.eval()

DataParallel(
  (module): featExtractor(
    (layer1): Sequential(
      (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (layer2): Sequential(
      (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (layer3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (layer4): Sequential(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (layer5): Sequential(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )

In [14]:
# cs feature extraction
cs_loader = DataLoader(cs, batch_size=1, num_workers=16, shuffle=False)

pbar = tqdm(cs_loader)
for i, spectrogram in enumerate(pbar):
    p1d = (0, 0, 0, 1024-spectrogram.shape[2])
    spectrogram = F.pad(spectrogram, p1d, "constant", 0)
    features = []
    #print('spectrogram', spectrogram.shape)
    for j in range(spectrogram.shape[2]//128):
        batch = spectrogram[:,:,j*128:(j+1)*128,:]
        if batch.shape[0] > 0:
            features.append(feat_extractor(batch.to(device_num)).data.cpu().numpy())
    features = np.concatenate(features, axis=0)
    save_path = cs_path[i]
    np.savez_compressed('{}_wlaf'.format(save_path), features=features)

100%|██████████| 32000/32000 [25:56<00:00, 20.56it/s]


In [15]:
#ns feature extraction
ns_loader = DataLoader(ns, batch_size=1, num_workers=16, shuffle=False)

pbar = tqdm(ns_loader)
for i, spectrogram in enumerate(pbar):
    p1d = (0, 0, 0, 1024-spectrogram.shape[2])
    spectrogram = F.pad(spectrogram, p1d, "constant", 0)
    features = []
    #print('spectrogram', spectrogram.shape)
    for j in range(spectrogram.shape[2]//128):
        batch = spectrogram[:,:,j*128:(j+1)*128,:]
        if batch.shape[0] > 0:
            features.append(feat_extractor(batch.to(device_num)).data.cpu().numpy())
    features = np.concatenate(features, axis=0)
    save_path = ns_path[i]
    np.savez_compressed('{}_wlaf'.format(save_path), features=features)

100%|██████████| 32000/32000 [25:57<00:00, 20.54it/s]


In [None]:
print(ns_path)