In [2]:
import os

from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [None]:
# Select the compute device: the first CUDA GPU when one is available, else the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"

# DataLoader worker processes and pinned host memory are switched off on every
# device. NOTE(review): presumably because SignalDataset creates its tensors
# directly on `dev` inside __getitem__ -- confirm before enabling either.
num_workers = 0
pin_memory = False

In [None]:
# Base directory of the slice files; SignalDataset loads
# <root_dir>/fold<fold>/<filename> from it.
root_dir = '../slices'

In [None]:
# Index of all slices. Later cells read its 'filename', 'fold' and 'classID' columns.
slice_df = pd.read_csv("../slice_filenames.csv")
slice_df.head()

In [None]:
class SignalDataset(Dataset):
    """Dataset of signal slices stored as ``.npy`` files, indexed by a DataFrame.

    Each row of ``slice_df`` must provide ``'filename'``, ``'fold'`` and
    ``'classID'``. The array for a row is loaded from
    ``<data_dir>/fold<fold>/<filename>``.

    Args:
        slice_df: DataFrame with one row per slice.
        data_dir: Directory containing the ``fold<N>`` sub-directories. When
            ``None`` (default) the module-level ``root_dir`` is used, looked up
            at access time exactly as before.
        device: Device on which the returned tensors are created. When ``None``
            (default) the module-level ``dev`` is used. Pass ``"cpu"`` to get
            host tensors, e.g. for use with DataLoader worker processes.
    """

    def __init__(self, slice_df, data_dir=None, device=None):
        self.slice_df = slice_df
        self.data_dir = data_dir
        self.device = device

    def __len__(self):
        """Return the number of rows in the index DataFrame."""
        return len(self.slice_df)

    def __getitem__(self, idx):
        """Load the slice at positional index ``idx``.

        Returns:
            Tuple ``(x, y)``: ``x`` is the loaded array as a float32 tensor with
            a new leading axis of size 1; ``y`` is the row's ``'classID'`` as an
            int64 tensor. Both live on the configured device.
        """
        # Fall back to the notebook globals lazily so that the default
        # behaviour (including later reassignment of the globals) is unchanged.
        data_dir = root_dir if self.data_dir is None else self.data_dir
        device = dev if self.device is None else self.device

        # Positional lookup: works for filtered frames whose index labels
        # are no longer 0..n-1.
        row = self.slice_df.iloc[idx, :]
        path = os.path.join(data_dir, f"fold{row['fold']}", row['filename'])
        signal = np.load(path)

        # Convert to float32 while building the tensor instead of copying the
        # source dtype to the device first and casting afterwards.
        x = torch.tensor(signal, dtype=torch.float32, device=device).unsqueeze(0)
        # Class-index targets for nn.CrossEntropyLoss must be int64.
        y = torch.tensor(row['classID'], dtype=torch.long, device=device)
        return x, y

In [None]:
# Folds 8 and 9 are held out as the test split; every other row is used for training.
is_held_out = slice_df['fold'].isin([8, 9])
train_df = slice_df[~is_held_out]
test_df = slice_df[is_held_out]

In [None]:
# Base directory of the slice files; SignalDataset loads
# <root_dir>/fold<fold>/<filename> from it.
# NOTE(review): the same assignment also appears in an earlier cell of this
# notebook -- consider deleting one of the two.
root_dir = '../slices'

In [None]:
# Index of all slices. Later cells read its 'filename', 'fold' and 'classID' columns.
# NOTE(review): the same CSV is also loaded in an earlier cell of this
# notebook -- consider deleting one of the two.
slice_df = pd.read_csv("../slice_filenames.csv")
slice_df.head()

In [None]:
# NOTE(review): SignalDataset is also defined in an earlier cell of this
# notebook; on a top-to-bottom run this later definition replaces it.
# Consider deleting one of the two.
class SignalDataset(Dataset):
    """Dataset of signal slices stored as ``.npy`` files, indexed by a DataFrame.

    Each row of ``slice_df`` must provide ``'filename'``, ``'fold'`` and
    ``'classID'``. The array for a row is loaded from
    ``<data_dir>/fold<fold>/<filename>``.

    Args:
        slice_df: DataFrame with one row per slice.
        data_dir: Directory containing the ``fold<N>`` sub-directories. When
            ``None`` (default) the module-level ``root_dir`` is used, looked up
            at access time exactly as before.
        device: Device on which the returned tensors are created. When ``None``
            (default) the module-level ``dev`` is used. Pass ``"cpu"`` to get
            host tensors, e.g. for use with DataLoader worker processes.
    """

    def __init__(self, slice_df, data_dir=None, device=None):
        self.slice_df = slice_df
        self.data_dir = data_dir
        self.device = device

    def __len__(self):
        """Return the number of rows in the index DataFrame."""
        return len(self.slice_df)

    def __getitem__(self, idx):
        """Load the slice at positional index ``idx``.

        Returns:
            Tuple ``(x, y)``: ``x`` is the loaded array as a float32 tensor with
            a new leading axis of size 1; ``y`` is the row's ``'classID'`` as an
            int64 tensor. Both live on the configured device.
        """
        # Fall back to the notebook globals lazily so that the default
        # behaviour (including later reassignment of the globals) is unchanged.
        data_dir = root_dir if self.data_dir is None else self.data_dir
        device = dev if self.device is None else self.device

        # Positional lookup: works for filtered frames whose index labels
        # are no longer 0..n-1.
        row = self.slice_df.iloc[idx, :]
        path = os.path.join(data_dir, f"fold{row['fold']}", row['filename'])
        signal = np.load(path)

        # Convert to float32 while building the tensor instead of copying the
        # source dtype to the device first and casting afterwards.
        x = torch.tensor(signal, dtype=torch.float32, device=device).unsqueeze(0)
        # Class-index targets for nn.CrossEntropyLoss must be int64.
        y = torch.tensor(row['classID'], dtype=torch.long, device=device)
        return x, y

In [56]:
# Folds 8 and 9 are held out as the test split; every other row is used for training.
# NOTE(review): the same split is also computed in an earlier cell of this notebook.
is_held_out = slice_df['fold'].isin([8, 9])
train_df = slice_df[~is_held_out]
test_df = slice_df[is_held_out]

In [64]:
class SignalModel(nn.Module):
    """1-D convolutional classifier for single-channel signals.

    Four ``Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(4)`` stages are followed
    by average pooling over the whole remaining time axis and one linear layer
    that outputs unnormalised class scores (logits).

    Args:
        n_channels: Output channels of the first two convolutions; the last
            two convolutions use ``2 * n_channels``.
        n_classes: Number of scores produced per sample. Defaults to 10, the
            value that used to be hard-coded.
    """

    def __init__(self, n_channels=32, n_classes=10):
        super().__init__()
        self.relu = nn.ReLU()
        # Not applied in forward(): the model returns raw logits. Kept only so
        # the set of module attributes is unchanged (it holds no parameters).
        self.log_softmax = nn.LogSoftmax(dim=2)

        # Layers are created in the original order so that seeded parameter
        # initialisation and state_dict keys are unchanged.
        self.conv1 = nn.Conv1d(1, n_channels, kernel_size=240, stride=16)
        self.bn1 = nn.BatchNorm1d(n_channels)
        self.pool1 = nn.MaxPool1d(4)
        self.conv2 = nn.Conv1d(n_channels, n_channels, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(n_channels)
        self.pool2 = nn.MaxPool1d(4)
        self.conv3 = nn.Conv1d(n_channels, 2 * n_channels, kernel_size=3)
        self.bn3 = nn.BatchNorm1d(2 * n_channels)
        self.pool3 = nn.MaxPool1d(4)
        self.conv4 = nn.Conv1d(2 * n_channels, 2 * n_channels, kernel_size=3)
        self.bn4 = nn.BatchNorm1d(2 * n_channels)
        self.pool4 = nn.MaxPool1d(4)
        self.fc1 = nn.Linear(2 * n_channels, n_classes)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: Tensor of shape ``(batch, 1, length)``. ``length`` must be large
                enough that at least one time step is left after the four
                convolution/pooling stages.

        Returns:
            Tensor of shape ``(batch, n_classes)``. The batch axis is kept
            even when ``batch == 1``.
        """
        x = self.conv1(x)
        x = self.relu(self.bn1(x))
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu(self.bn2(x))
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.relu(self.bn3(x))
        x = self.pool3(x)
        x = self.conv4(x)
        x = self.relu(self.bn4(x))
        x = self.pool4(x)
        # Average over the whole remaining time axis -> (batch, channels, 1).
        x = F.avg_pool1d(x, x.shape[-1])
        # Move channels last so the linear layer acts on them -> (batch, 1, channels).
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        # Drop only the singleton time axis. A bare torch.squeeze(x) would also
        # drop the batch axis for a batch of one sample.
        return x.squeeze(1)

In [65]:
# Build SignalModel with its default arguments and move it to the selected device.
model = SignalModel().to(device=dev)

In [66]:
# Staged training: ten train_loop calls in total, each with a newly constructed
# Adam optimizer -- one at lr=0.01, then three at lr=0.001, then six at lr=0.0005.
# NOTE(review): train_loop, train_dl and test_dl are not defined in the cells
# visible here -- confirm they exist before this cell on a fresh kernel run.
# NOTE(review): constructing a new optimizer for every call also discards Adam's
# running moment estimates each time -- confirm that is intended.
loss_fn = nn.CrossEntropyLoss()
stage_lrs = (0.01,) + (0.001,) * 3 + (0.0005,) * 6
for stage_lr in stage_lrs:
    optimizer = torch.optim.Adam(model.parameters(), lr=stage_lr, weight_decay=0.001)
    train_loop(train_dl, test_dl, model, loss_fn, optimizer)

  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 2.171638990441958, val_loss: 2.1370315296309337


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 2.082924457391103, val_loss: 2.0913774285997664


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 2.062578489383062, val_loss: 2.0236293290342604


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 2.024481275677681, val_loss: 2.024122485092708


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 2.0069278796513874, val_loss: 1.995130947657994


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.9750117441018422, val_loss: 1.9602098039218359


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.968151819705963, val_loss: 1.9733299485274725


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.9512041827042899, val_loss: 1.926779295716967


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.9471990913152695, val_loss: 1.8959851562976837


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.922140325109164, val_loss: 1.9008562990597315


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.8651067992051442, val_loss: 1.875817779983793


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.83436987499396, val_loss: 1.8426436058112554


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.8109943648179372, val_loss: 1.8207889667579107


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.792018249630928, val_loss: 1.803649468081338


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.7747472961743673, val_loss: 1.7899213560989924


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.7578540205955506, val_loss: 1.7752206751278468


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.7403303081790606, val_loss: 1.7615614916597093


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.7231063509980837, val_loss: 1.749498724937439


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.7055982321500778, val_loss: 1.7376600461346763


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.6871025497714678, val_loss: 1.7257959416934423


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.665229400495688, val_loss: 1.7153353435652596


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.6480875889460245, val_loss: 1.7073563592774528


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.631591677169005, val_loss: 1.6997098752430506


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.6146755794684091, val_loss: 1.692933372088841


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.5984348644812902, val_loss: 1.6866987986224038


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.581979808708032, val_loss: 1.6818477852003915


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.5646188537279764, val_loss: 1.6776799474443709


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.5485582580169042, val_loss: 1.6712463412966048


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.531800232330958, val_loss: 1.6649084091186523


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.51549697269996, val_loss: 1.6581445038318634


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.493545720477899, val_loss: 1.6478451916149683


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.4774343878030778, val_loss: 1.642058287348066


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.4631938179334005, val_loss: 1.6356304500784193


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.4485656554500261, val_loss: 1.6309109926223755


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.4354006841778755, val_loss: 1.6274601689406805


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.4196890393892925, val_loss: 1.6228431165218353


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.4050072446465491, val_loss: 1.6211887981210436


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.3890242780248323, val_loss: 1.619031982762473


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.3739601378639539, val_loss: 1.6172473515783037


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.357829356690248, val_loss: 1.6151413470506668


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.328248673180739, val_loss: 1.6135807505675726


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.3151816494762898, val_loss: 1.6089411165033067


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.3034210036198297, val_loss: 1.6071342655590601


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.2934211850166322, val_loss: 1.605848286833082


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.2838005262116592, val_loss: 1.6052124606711524


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.2741870592037836, val_loss: 1.6040045597723551


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.2638189621269702, val_loss: 1.602721516575132


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.2540936768054962, val_loss: 1.6016071609088354


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.2447044673065344, val_loss: 1.6009348801204137


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.2349871809283892, val_loss: 1.6005446399961198


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.2259068176150323, val_loss: 1.5982289910316467


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.2173438300689061, val_loss: 1.5993030880178725


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.2093044119576613, val_loss: 1.5996653182165963


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.2022435670097669, val_loss: 1.600603774189949


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.1952634006738663, val_loss: 1.6017715696777617


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.1886202454566956, val_loss: 1.602078018443925


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.1790806405246257, val_loss: 1.601323617356164


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.172785197943449, val_loss: 1.6012358452592577


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.1661050468683243, val_loss: 1.6016095514808382


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.1583727943400541, val_loss: 1.6021530202456884


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.14837117853264, val_loss: 1.597675655569349


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.1427769874533018, val_loss: 1.5982035781655992


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.1358828091373045, val_loss: 1.5993956412587846


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.1288481041789056, val_loss: 1.5993748754262924


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.1212735470384358, val_loss: 1.5994339010545187


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.1136368289589882, val_loss: 1.599068516067096


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.1090547914306323, val_loss: 1.6000928218875612


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.102374235416452, val_loss: 1.6006286122969218


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.0948987003415822, val_loss: 1.6015929728746414


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.08795674542586, val_loss: 1.6020046706710542


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.0782150015234948, val_loss: 1.5979211649724416


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.0701557607700427, val_loss: 1.5988305572952544


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.0647867829849322, val_loss: 1.6007359964506966


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.0574771496156852, val_loss: 1.6016495014939989


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.0503578091661134, val_loss: 1.6034072488546371


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.04428211239477, val_loss: 1.6046237072774343


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.036911831671993, val_loss: 1.605593476976667


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.0297428575654826, val_loss: 1.6070069989987783


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.0234221006433168, val_loss: 1.6082166752644949


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.0171260774135589, val_loss: 1.6091614621026176


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.0067518981794517, val_loss: 1.601371863058635


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 1.0020796481519938, val_loss: 1.6062979804618018


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 0.9954554254810015, val_loss: 1.6093708553484507


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 0.9897626967479786, val_loss: 1.6110351298536574


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 0.9824615729351839, val_loss: 1.6135307401418686


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 0.9770615754028161, val_loss: 1.6168480558054787


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 0.9729228691508373, val_loss: 1.6200820377894811


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 0.965693931405743, val_loss: 1.6211806620870317


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 0.9607169196009636, val_loss: 1.6222518299307143


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 0.9531231995671987, val_loss: 1.6238354103905814


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 0.9400432728851835, val_loss: 1.6119633551154817


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 0.9365014720708131, val_loss: 1.6187393495014735


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 0.9299323336531718, val_loss: 1.6220928643430983


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 0.9242582094545165, val_loss: 1.6255334402833665


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 0.9209696466103197, val_loss: 1.6294165977409907


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 0.9144424425438047, val_loss: 1.6310863771608897


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 0.9087017248695095, val_loss: 1.634055916752134


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 0.9032951680322489, val_loss: 1.6375148126057215


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 0.8972707418724895, val_loss: 1.640453195997647


  0%|          | 0/120 [00:00<?, ?it/s]

train_loss: 0.8902540589372318, val_loss: 1.640840313264302


In [62]:
# Run one_pass_acc on the model with the test loader and display its return value.
# NOTE(review): presumably the returned number is test-set accuracy -- confirm
# against the definition of one_pass_acc, which is not visible in these cells
# (nor is test_dl).
# NOTE(review): this cell's execution count (62) is lower than the training
# cell's (66) above it, so the displayed result may predate the last training
# run -- re-run top to bottom on a fresh kernel.
one_pass_acc(model, test_dl)

torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([256, 10])
torch.Size([230, 10])


0.4567348081769812