In [1]:
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import torch
import torchvision
from torchvision import datasets, transforms
import torch.utils.data as data
import torchvision.models as models
import matplotlib.image as pli
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from PIL import Image
from PIL import ImageOps
from PIL import ImageEnhance
import random
import math
import pickle
import glob
import librosa

# Root directory of the training recordings; the dataset looks files up as
# <path>/<label>/*/*.pkl.
path = './dataset/train'

# Object classes. A sample's integer target is the position of its class
# name in this list.
labels = [
    '061_foam_brick',
    'green_basketball',
    'salt_cylinder',
    'shiny_toy_gun',
    'stanley_screwdriver',
    'strawberry',
    'toothpaste_box',
    'toy_elephant',
    'whiteboard_spray',
    'yellow_block',
]

# When True, the dataset prints intermediate shapes and shows every
# spectrogram channel it produces.
is_plot = False

# Spectrogram crop: number of leading frequency bins and number of STFT
# frames kept per sample.
freq_length = 47
time_length = 201

# Mini-batch size used by the training DataLoader.
batch_size = 16

ModuleNotFoundError: No module named 'matplotlib'

In [120]:

class ImageSet(data.Dataset):
    """Randomly sampled multi-channel audio spectrograms with class targets.

    An item is produced as follows:

    1. The target class is ``index % len(labels)``.
    2. One recording of that class is picked at random from
       ``{path}/{label}/*/*.pkl``.
    3. Each audio channel is resampled, turned into an STFT magnitude
       spectrogram, and the stack is scaled by its global maximum.
    4. The first ``freq_length`` frequency bins and a ``time_length``-frame
       window are cut out.

    Because step 2 is random, the same index can yield different recordings;
    ``index`` only determines the class.
    """

    # Number of audio channels read from each recording (columns of `audio`).
    num_channels = 4
    # Sample rate every channel is resampled to before the STFT.
    target_sample_rate = 11000
    # FFT size passed to librosa.stft.
    n_fft = 512
    # The time window starts this many STFT frames before the middle frame.
    frames_before_center = 80

    def __init__(self):
        # Nominal number of items; items are drawn at random, so this only
        # sets how many draws make up one pass over the dataset.
        self.length = 10000
        # label name -> list of recording files, filled on first use so the
        # directory tree is not re-scanned for every item.
        self._files_by_label = {}

    def _files_for(self, label_name):
        """Return the cached list of recording files for one class.

        Raises:
            FileNotFoundError: if no file matches the class's glob pattern.
                (An IndexError here would be mistaken for "end of data" by
                Python's legacy ``__getitem__`` iteration protocol.)
        """
        files = self._files_by_label.get(label_name)
        if files is None:
            pattern = f'{path}/{label_name}/*/*.pkl'
            files = glob.glob(pattern)
            if not files:
                raise FileNotFoundError(f'no recordings match {pattern}')
            self._files_by_label[label_name] = files
        return files

    @staticmethod
    def _crop_time_window(spec, start, width):
        """Return ``width`` frames of ``spec`` along its last axis.

        Args:
            spec: array of shape (channels, freq, frames).
            start: index of the first frame; negative values are clamped
                to 0 instead of being interpreted as "from the end".
            width: number of frames in the result.

        Returns:
            Array of shape (channels, freq, width). If ``spec`` has fewer
            than ``start + width`` frames the tail is zero-padded, so every
            item has the same shape and can be batched.
        """
        start = max(start, 0)
        window = spec[:, :, start:start + width]
        missing = width - window.shape[2]
        if missing > 0:
            window = np.pad(window, ((0, 0), (0, 0), (0, missing)),
                            mode='constant')
        return window

    def __getitem__(self, index):
        """Return ``(audio_map, label)`` for a random recording of one class.

        Args:
            index: item index; only ``index % len(labels)`` is used.

        Returns:
            audio_map: array of shape
                (num_channels, freq_length, time_length).
            label: integer class index into ``labels``.
        """
        label = index % len(labels)
        audio_file = random.choice(self._files_for(labels[label]))

        # SECURITY: allow_pickle=True unpickles the file, which can execute
        # arbitrary code. Only load recordings from a trusted source.
        record = np.load(audio_file, allow_pickle=True)
        audio = record['audio']
        sample_rate = record['audio_samplerate']

        stft_result = []
        for channel in range(self.num_channels):
            # Sample rates are passed by keyword: they are keyword-only in
            # librosa >= 0.10 and accepted by keyword in older releases.
            audio_resample = librosa.resample(
                audio[:, channel],
                orig_sr=sample_rate,
                target_sr=self.target_sample_rate)
            stft_result.append(
                np.abs(librosa.stft(audio_resample, n_fft=self.n_fft)))
            if is_plot:
                print(audio_resample.shape)
                print(stft_result[channel].shape)
        stft_result = np.array(stft_result)

        # Scale by the global peak; skip silent recordings so an all-zero
        # spectrogram does not turn into NaNs.
        peak = np.max(stft_result)
        if peak > 0:
            stft_result /= peak

        time_mid = stft_result.shape[2] // 2
        time_left = time_mid - self.frames_before_center
        time_right = time_left + time_length
        audio_map = self._crop_time_window(
            stft_result[:, 0:freq_length, :], time_left, time_length)

        if is_plot:
            print(audio_file)
            print(audio.shape)
            print(stft_result.shape)
            # Requested window bounds (before clamping / zero-padding).
            print(time_left)
            print(time_right)
            print(audio_map.shape)

            for channel_map in audio_map:
                plt.imshow(channel_map, cmap='gray')
                plt.show()

        return audio_map, label

    def __len__(self):
        """Return the nominal number of items in one pass."""
        return self.length

# Shuffled mini-batch iterator over a freshly constructed ImageSet.
train_loader = data.DataLoader(
    dataset=ImageSet(),
    batch_size=batch_size,
    shuffle=True,
)

In [128]:
class AudioCNN(nn.Module):
    """Five-stage convolutional classifier for 4-channel spectrogram crops.

    Every stage is ``Conv2d -> BatchNorm2d -> ReLU`` with no padding; stages
    2 and 4 are strided. The last feature map is global-average-pooled to a
    256-vector per sample and mapped to class logits by one linear layer.

    Args:
        num_classes: number of output logits. ``None`` (the default) uses
            ``len(labels)``, the module-level list of class names.

    The shape comments in ``__init__`` are (channels, height, width) for an
    input of (4, 47, 201). Because of the adaptive pooling, other spatial
    sizes also work as long as every convolution still fits.
    """

    def __init__(self, num_classes=None):
        super(AudioCNN, self).__init__()
        if num_classes is None:
            num_classes = len(labels)

        # input:                                  (4, 47, 201)
        self.layer1 = self._conv_block(4, 64, kernel_size=(3, 11))
        #                                      -> (64, 45, 191)
        self.layer2 = self._conv_block(64, 64, kernel_size=(3, 8),
                                       stride=(2, 3))
        #                                      -> (64, 22, 62)
        self.layer3 = self._conv_block(64, 128, kernel_size=(3, 5))
        #                                      -> (128, 20, 58)
        self.layer4 = self._conv_block(128, 128, kernel_size=(3, 5),
                                       stride=(2, 2))
        #                                      -> (128, 9, 27)
        self.layer5 = self._conv_block(128, 256, kernel_size=(3, 5))
        #                                      -> (256, 7, 23)
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = nn.Linear(256, num_classes)

    @staticmethod
    def _conv_block(in_channels, out_channels, kernel_size, stride=1):
        """Build one unpadded ``Conv2d -> BatchNorm2d -> ReLU`` stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=stride),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, input):
        """Map a batch of spectrograms to class logits.

        Args:
            input: float tensor of shape (N, 4, H, W).

        Returns:
            Tensor of shape (N, num_classes) with unnormalised logits.
        """
        out = input
        for stage in (self.layer1, self.layer2, self.layer3,
                      self.layer4, self.layer5):
            out = stage(out)
        out = self.avg_pool(out)
        # (N, 256, 1, 1) -> (N, 256)
        out = out.reshape(out.size(0), -1)
        return self.fc(out)

In [129]:
# Instantiate the classifier with its default output size (one logit per
# entry of `labels`).
convNet = AudioCNN()

In [130]:
# Train `convNet` for one pass over `train_loader`, logging every second batch.

device = torch.device(
    "cuda") if torch.cuda.is_available() else torch.device('cpu')
# Move the model before building the optimizer so the optimizer is created
# over the parameters that are actually trained.
convNet = convNet.to(device)
convNet.train()

loss_func = nn.CrossEntropyLoss()
# NOTE(review): lr=0.01 is 10x Adam's library default (1e-3) and the logged
# loss does not decrease over the first batches -- confirm this value.
optimizer = torch.optim.Adam(convNet.parameters(), lr=0.01)

for i, (imgs, lbs) in enumerate(train_loader):
    # Batches come off the loader on the CPU while the model lives on
    # `device`; without this transfer the forward pass fails on CUDA.
    imgs = imgs.float().to(device)
    lbs = lbs.to(device)

    outputs = convNet(imgs)
    loss = loss_func(outputs, lbs)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # softmax is monotonic, so argmax over the raw logits gives the same
    # predictions.
    predict = torch.argmax(outputs, dim=1)
    if i % 2 == 0:
        print(f"i = {i},  loss = {loss.item()}, labels = {lbs}, predicts = {predict},  accuracy = {(lbs == predict).sum().item() / lbs.size(0)}")

i = 0,  loss = 2.2426676750183105, labels = tensor([2, 9, 5, 5, 9, 6, 1, 5]), predicts = tensor([2, 2, 2, 2, 2, 2, 2, 2]),  accuracy = 0.125
i = 2,  loss = 2.085440158843994, labels = tensor([9, 0, 6, 6, 5, 6, 9, 5]), predicts = tensor([9, 5, 5, 9, 9, 5, 9, 9]),  accuracy = 0.25
i = 4,  loss = 2.238786220550537, labels = tensor([4, 5, 6, 1, 1, 6, 4, 2]), predicts = tensor([9, 5, 9, 9, 9, 5, 9, 9]),  accuracy = 0.125
i = 6,  loss = 2.900486707687378, labels = tensor([5, 2, 4, 0, 9, 1, 8, 4]), predicts = tensor([6, 6, 6, 6, 6, 6, 6, 6]),  accuracy = 0.0
i = 8,  loss = 2.9933876991271973, labels = tensor([7, 8, 5, 8, 2, 1, 0, 2]), predicts = tensor([6, 6, 6, 6, 6, 6, 6, 4]),  accuracy = 0.0
i = 10,  loss = 2.0345137119293213, labels = tensor([5, 8, 6, 4, 5, 2, 3, 1]), predicts = tensor([4, 4, 4, 4, 4, 1, 4, 1]),  accuracy = 0.25
i = 12,  loss = 2.090325117111206, labels = tensor([7, 7, 6, 5, 8, 5, 8, 1]), predicts = tensor([4, 8, 5, 4, 4, 8, 8, 4]),  accuracy = 0.125
i = 14,  loss = 2.184

KeyboardInterrupt: 