In [None]:
# Colab setup: install torchaudio and PyTorch Lightning (this run resolved
# torchaudio 0.9.1 and pytorch_lightning 1.4.9 — see the output below).
# NOTE(review): versions are unpinned; a newer Lightning release may no longer
# provide the `pytorch_lightning.metrics` module imported in the next cell —
# consider pinning to the versions above.
!pip install torchaudio pytorch_lightning

Collecting torchaudio
  Downloading torchaudio-0.9.1-cp37-cp37m-manylinux1_x86_64.whl (1.9 MB)
[K     |████████████████████████████████| 1.9 MB 5.3 MB/s 
[?25hCollecting pytorch_lightning
  Downloading pytorch_lightning-1.4.9-py3-none-any.whl (925 kB)
[K     |████████████████████████████████| 925 kB 46.8 MB/s 
[?25hCollecting torch==1.9.1
  Downloading torch-1.9.1-cp37-cp37m-manylinux1_x86_64.whl (831.4 MB)
[K     |████████████████████████████████| 831.4 MB 6.3 kB/s 
Collecting torchmetrics>=0.4.0
  Downloading torchmetrics-0.5.1-py3-none-any.whl (282 kB)
[K     |████████████████████████████████| 282 kB 72.1 MB/s 
Collecting fsspec[http]!=2021.06.0,>=2021.05.0
  Downloading fsspec-2021.10.0-py3-none-any.whl (125 kB)
[K     |████████████████████████████████| 125 kB 57.7 MB/s 
Collecting future>=0.17.1
  Downloading future-0.18.2.tar.gz (829 kB)
[K     |████████████████████████████████| 829 kB 51.7 MB/s 
[?25hCollecting pyDeprecate==0.3.1
  Downloading pyDeprecate-0.3.1-py3-none

In [None]:
import torch, torchaudio
from torch import nn
from torch.nn import functional as F

import pytorch_lightning as pl
from pytorch_lightning.metrics import functional

import pickle
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

from google.colab import drive


# Mount Google Drive; the notebook reads its data from the relative path
# "drive/My Drive/EEGNET".
drive.mount('/content/drive')
# Display whether torch can see a CUDA GPU.
torch.cuda.is_available()

In [None]:
# Sanity check: the data folder should contain features.pkl and eeg.pkl.
!ls "drive/My Drive/EEGNET"

In [None]:
# Folder on the mounted Drive holding features.pkl and eeg.pkl.
path = Path("drive") / "My Drive" / "EEGNET"

In [None]:
# Load the raw arrays once for inspection (EEGDataset re-reads the same files).
# NOTE: pickle.load can execute arbitrary code from the file — only load trusted data.
with open(path / "features.pkl", "rb") as fh:
    feat = pickle.load(fh)
with open(path / "eeg.pkl", "rb") as fh:
    eeg_raw = pickle.load(fh)

X = torch.from_numpy(feat['X'])
y = torch.from_numpy(feat['y'])
eeg = torch.from_numpy(eeg_raw['EEG'])

# Show the shapes so the layout assumed by EEGDataset / AudioNet can be checked.
X.shape, y.shape, eeg.shape

In [None]:
class EEGDataset(torch.utils.data.Dataset):
    """EEG trials with labels, served as stacked per-channel log-mel spectrograms.

    Reads ``features.pkl`` (keys ``'X'``, ``'y'``) and ``eeg.pkl`` (key
    ``'EEG'``) from ``path``. Item ``i`` is ``(spectrogram, label)``: the
    spectrogram concatenates, along dim 0, one ``(1, n_mels, frames)`` log-mel
    image per row of ``EEG[i]`` (presumably one row per electrode — confirm).

    Args:
        path: folder containing ``features.pkl`` and ``eeg.pkl``. Previously
            ignored in favour of a hard-coded path; the default is unchanged.
        sample_rate: rate the signals are resampled to before the mel transform.
        orig_freq: native sampling rate of the recordings. Defaults to 250, the
            value that was hard-coded before (TODO confirm against the source).
    """

    def __init__(self, path: Path = Path("drive/My Drive/EEGNET"),
                 sample_rate: int = 8000, orig_freq: int = 250):
        path = Path(path)  # also accept plain strings
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(path / "features.pkl", "rb") as fh:
            feat = pickle.load(fh)
        with open(path / "eeg.pkl", "rb") as fh:
            eeg = pickle.load(fh)

        self.X = torch.from_numpy(feat['X']).float()  # kept; not used by __getitem__
        self.y = torch.from_numpy(feat['y']).float()
        self.eeg = torch.from_numpy(eeg['EEG']).float()

        # Transform pipeline: Resample --> MelSpectrogram --> AmplitudeToDB.
        # NOTE(review): with the defaults this upsamples 250 Hz -> 8 kHz;
        # whether that helps is still an open question.
        self.resample = torchaudio.transforms.Resample(
            orig_freq=orig_freq, new_freq=sample_rate)
        self.melspec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate)
        self.db = torchaudio.transforms.AmplitudeToDB(top_db=80)

    def _channel_to_spectrogram(self, signal):
        """Turn one 1-D signal into a ``(1, n_mels, frames)`` log-mel tensor."""
        return self.db(self.melspec(self.resample(signal.reshape(1, -1))))

    def __getitem__(self, index):
        """Return ``(spectrogram, label)`` for trial ``index``.

        The spectrogram is a float tensor (previously a NumPy array built with
        ``np.vstack``); DataLoader batches are the same either way.
        """
        # Channels are transformed one at a time, as before: AmplitudeToDB's
        # top_db clamp may be taken relative to the peak of the whole tensor
        # it receives, so one call over all channels could couple them.
        channels = [self._channel_to_spectrogram(ch) for ch in self.eeg[index]]
        return torch.cat(channels, dim=0), self.y[index]

    def __len__(self):
        """Number of trials (size of the first dimension of the EEG array)."""
        return len(self.eeg)

In [None]:
# Peek at the first sample to check the transform pipeline end to end.
# Indexing raises IndexError on an empty dataset instead of silently leaving
# stale xb / yb from an earlier cell (as `for ...: break` would).
train_data = EEGDataset()
xb, yb = train_data[0]
xb.shape

In [None]:
# Display the (float) label of the sample fetched above.
yb

In [None]:
# Split ONE dataset into disjoint train / validation / test subsets (60/20/20).
# Previously all three loaders wrapped the full dataset, so validation and test
# scores were computed on the training data. The split is seeded so it is
# reproducible.
# NOTE(review): a random per-trial split can leak information if several trials
# come from the same subject/session — consider a grouped split if so.
BATCH_SIZE = 8
SPLIT_SEED = 0

full_data = EEGDataset()
n_total = len(full_data)
n_train = int(0.6 * n_total)
n_val = int(0.2 * n_total)
n_test = n_total - n_train - n_val  # remainder, so the lengths sum to n_total
train_data, val_data, test_data = torch.utils.data.random_split(
    full_data,
    [n_train, n_val, n_test],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE)

In [None]:
class AudioNet(pl.LightningModule):
    """Small 2-D CNN that regresses ``n_classes`` values from a spectrogram stack.

    Expects input of shape ``(batch, in_channels, height, width)`` (presumably
    EEG channels x mel bins x frames — confirm against the dataset) and outputs
    ``(batch, n_classes)``. Trained with mean-squared error; validation and
    test report the same MSE, since accuracy via argmax is meaningless for a
    single-output regressor.

    Args:
        n_classes: number of outputs per sample (1 = scalar regression).
        base_filters: filters in the first conv layers; later layers use 2x / 4x.
        in_channels: input planes. Defaults to 32, the value previously hard-coded.
        lr: Adam learning rate. Defaults to 1e-3, the value previously hard-coded.
    """

    def __init__(self, n_classes=1, base_filters=32, in_channels=32, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.conv1 = nn.Conv2d(in_channels, base_filters, 11, padding=5)
        self.bn1 = nn.BatchNorm2d(base_filters)
        self.conv2 = nn.Conv2d(base_filters, base_filters, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(base_filters)
        self.pool1 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(base_filters, base_filters * 2, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(base_filters * 2)
        self.conv4 = nn.Conv2d(base_filters * 2, base_filters * 4, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(base_filters * 4)
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(base_filters * 4, n_classes)

    def forward(self, x):
        """Map ``(batch, in_channels, H, W)`` to ``(batch, n_classes)``."""
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.conv2(x)
        x = F.relu(self.bn2(x))
        x = self.pool1(x)
        x = self.conv3(x)
        x = F.relu(self.bn3(x))
        x = self.conv4(x)
        x = F.relu(self.bn4(x))
        x = self.pool2(x)
        # Global average pool to (batch, C, 1, 1), then drop the spatial dims.
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = self.fc1(x[:, :, 0, 0])
        return x

    @staticmethod
    def _mse(y_hat, y):
        """MSE after reshaping predictions to the target's shape.

        With ``n_classes=1`` the network outputs ``(batch, 1)``. A 1-D target of
        shape ``(batch,)`` would otherwise broadcast to ``(batch, batch)`` and
        compare every prediction with every label. ``reshape`` raises if the
        element counts differ, instead of silently broadcasting.
        """
        return F.mse_loss(y_hat.reshape(y.shape), y)

    def _evaluate(self, batch, stage):
        """Compute and log the MSE on one batch under ``f'{stage}_loss'``."""
        x, y = batch
        loss = self._mse(self(x), y)
        self.log(f'{stage}_loss', loss, on_epoch=True, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        """One optimisation step: MSE between predictions and targets."""
        x, y = batch
        loss = self._mse(self(x), y)
        self.log('train_loss', loss, on_step=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation MSE as ``val_loss``."""
        return self._evaluate(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Log the test MSE as ``test_loss``."""
        return self._evaluate(batch, 'test')

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [None]:
# Seed the RNGs so model initialisation is repeatable.
pl.seed_everything(0)
# Test that the network works on a single mini-batch: one forward pass,
# displaying the output shape (expected (batch, n_classes)).
audionet = AudioNet()
xb, yb = next(iter(train_loader))
# NOTE(review): the module is in training mode here, so this forward pass also
# updates the BatchNorm running statistics before fit() starts.
audionet(xb).shape

In [None]:
# Train for 25 epochs. Use one GPU when torch can see one; otherwise fall back
# to CPU instead of failing at Trainer construction.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0, max_epochs=25)
trainer.fit(audionet, train_loader, val_loader)


In [None]:
# Evaluate on test_loader; Lightning runs AudioNet.test_step for each batch
# and reports the metrics it logs.
trainer.test(audionet, test_loader)