In [1]:
import os
# os.environ['http_proxy'] = 'http://127.0.0.1:12639'
# os.environ['https_proxy'] = 'http://127.0.0.1:12639'

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl

In [74]:
import torchvision.models as models

class SimpleCNN(pl.LightningModule):
    def __init__(self, num_classes=2, num_channels=1, target_sr=16000, conv_targ_out_size=2000):
        super().__init__()
        # init a pretrained resnet
        self.num_classes = num_classes
        self.conv_targ_out_size = conv_targ_out_size
        # self.conv1 = nn.Conv1d(num_channels, 2, 5)
        # self.conv2 = nn.Conv1d(2, 4, 5)
        # self.pool = nn.MaxPool1d(2)
        # self.adaptive_pool = nn.AdaptiveMaxPool1d(conv_targ_out_size)
        self.conv1 = nn.Conv2d(num_channels, 2, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(2, 4, 5)
        self.adaptive_pool = nn.AdaptiveMaxPool2d((40, 50))
        self.fc1 = nn.Linear(4 * conv_targ_out_size, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        print('==========')
        print(x.shape)
        x = self.pool(F.relu(self.conv1(x)))
        print(x.shape)
        x = self.adaptive_pool(F.relu(self.conv2(x)))
        print(x.shape)
        x = x.view(-1, 4 * self.conv_targ_out_size)
        print(x.shape)
        x = F.relu(self.fc1(x))
        print(x.shape)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def training_step(self, batch, batch_idx):
        images, target = batch
        preds = self.forward(images)
        loss = F.cross_entropy(preds, target)
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
        return [optimizer], [scheduler]

Data

In [75]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torchaudio

In [76]:
# def wav_loader(fn, num_channels=1, target_sr=44100):
#     waveform, sr = torchaudio.load(fn)
#     transformed = torchaudio.transforms.Resample(sr, target_sr)(waveform[:num_channels,:])
#     return transformed

from audio_classifier.wav2vec.wav2vec import Wav2VecFeat

wav2vec_feat = Wav2VecFeat()

def wav_loader(fn, num_channels=1, target_sr=16000):
    waveform, sr = torchaudio.load(fn)
    transformed = torchaudio.transforms.Resample(sr, target_sr)(waveform[:num_channels,:])
    features = wav2vec_feat.extract_feature(transformed)
    return features

In [77]:
def create_dataloader(dataset, batch_size=1):
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle = True
    )
    return loader

In [78]:
train_dataloader = create_dataloader(datasets.DatasetFolder('demo_data', loader=wav_loader, extensions='.wav'))

Training

In [79]:
model = SimpleCNN()
trainer = pl.Trainer(max_epochs=5, gpus=torch.cuda.device_count())
trainer.fit(model, train_dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name          | Type              | Params
----------------------------------------------------
0 | conv1         | Conv2d            | 52    
1 | pool          | MaxPool2d         | 0     
2 | conv2         | Conv2d            | 204   
3 | adaptive_pool | AdaptiveMaxPool2d | 0     
4 | fc1           | Linear            | 960 K 
5 | fc2           | Linear            | 10 K  
6 | fc3           | Linear            | 170   
torch.Size([1, 1, 512, 1551])
torch.Size([1, 2, 254, 773])
torch.Size([1, 4, 40, 50])
torch.Size([1, 8000])
torch.Size([1, 120])
torch.Size([1, 1, 512, 950])
torch.Size([1, 2, 254, 473])
torch.Size([1, 4, 40, 50])
torch.Size([1, 8000])
torch.Size([1, 120])
torch.Size([1, 1, 512, 950])
torch.Size([1, 2, 254, 473])
torch.Size([1, 4, 40, 50])
torch.Size([1, 8000])
torch.Size([1, 120])
torch.Size([1, 1, 512, 1349])
torch.Size([1, 2, 254, 672])
torch.Size([1, 4, 40, 50])
torch.Size([1, 8000])
to

1