In [1]:
import os
os.environ['http_proxy'] = 'http://127.0.0.1:12639'
os.environ['https_proxy'] = 'http://127.0.0.1:12639'

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl

In [2]:
import torchvision.models as models

class SimpleCNN(pl.LightningModule):
    def __init__(self, num_classes=2, num_channels=1, target_sr=44100, conv_targ_out_size=2000):
        super().__init__()
        # init a pretrained resnet
        self.num_classes = num_classes
        self.conv_targ_out_size = conv_targ_out_size
        self.conv1 = nn.Conv1d(num_channels, 2, 5)
        self.conv2 = nn.Conv1d(2, 4, 5)
        self.pool = nn.MaxPool1d(2)
        self.adaptive_pool = nn.AdaptiveMaxPool1d(conv_targ_out_size)
        self.fc1 = nn.Linear(4 * conv_targ_out_size, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.adaptive_pool(F.relu(self.conv2(x)))
        print(x.shape)
        x = x.view(-1, 4 * self.conv_targ_out_size)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def training_step(self, batch, batch_idx):
        images, target = batch
        preds = self.forward(images)
        loss = F.cross_entropy(preds, target)
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
        return [optimizer], [scheduler]

Data

In [3]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torchaudio

In [4]:
def wav_loader(fn, num_channels=1, target_sr=44100):
    waveform, sr = torchaudio.load(fn)
    transformed = torchaudio.transforms.Resample(sr, target_sr)(waveform[:num_channels,:])
    return transformed

In [5]:
def create_dataloader(dataset, batch_size=1):
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle = True
    )
    return loader

In [6]:
train_dataloader = create_dataloader(datasets.DatasetFolder('demo_data', loader=wav_loader, extensions='.wav'))

Training

In [7]:
model = SimpleCNN()
trainer = pl.Trainer(max_epochs=10, gpus=torch.cuda.device_count())
trainer.fit(model, train_dataloader)

h.Size([1, 4, 2000])
Epoch 6:  85%|████████▌ | 51/60 [00:06&lt;00:01,  8.31it/s, loss=0.698, v_num=9]torch.Size([1, 4, 2000])
Epoch 6:  87%|████████▋ | 52/60 [00:06&lt;00:00,  8.22it/s, loss=0.698, v_num=9]torch.Size([1, 4, 2000])
Epoch 6:  88%|████████▊ | 53/60 [00:06&lt;00:00,  8.19it/s, loss=0.697, v_num=9]torch.Size([1, 4, 2000])
Epoch 6:  90%|█████████ | 54/60 [00:06&lt;00:00,  8.16it/s, loss=0.696, v_num=9]torch.Size([1, 4, 2000])
Epoch 6:  92%|█████████▏| 55/60 [00:06&lt;00:00,  8.17it/s, loss=0.697, v_num=9]torch.Size([1, 4, 2000])
Epoch 6:  93%|█████████▎| 56/60 [00:06&lt;00:00,  8.20it/s, loss=0.696, v_num=9]torch.Size([1, 4, 2000])
Epoch 6:  95%|█████████▌| 57/60 [00:06&lt;00:00,  8.21it/s, loss=0.697, v_num=9]torch.Size([1, 4, 2000])
Epoch 6:  97%|█████████▋| 58/60 [00:07&lt;00:00,  8.22it/s, loss=0.696, v_num=9]torch.Size([1, 4, 2000])
Epoch 6:  98%|█████████▊| 59/60 [00:07&lt;00:00,  8.23it/s, loss=0.696, v_num=9]torch.Size([1, 4, 2000])
Epoch 7:   0%|          | 0/60 [00

1