In [1]:
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from bc.beats import get_beat_bank

# Paths: base_dir is the parent of the current working directory
# (presumably the repo root when the notebook runs from its own folder).
base_dir = os.path.abspath('..')
data_dir = os.path.join(base_dir, 'data')

# Table of record names and the beat types they contain, indexed by record.
# 'record' is read with object dtype so record names stay strings instead of
# being parsed as numbers.
beat_types_csv = os.path.join(data_dir, 'beat-types.csv')
beat_table = (
    pd.read_csv(beat_types_csv, dtype={'record': object})
    .set_index('record')
)

In [2]:
# Load a bank of fixed-width beats for each beat type of interest.
# Every call requests windows of BEAT_WIDTH samples via fixed_width.
BEAT_WIDTH = 240
BEAT_TYPES = ('N', 'L', 'R', 'V')

# One get_beat_bank call per type, in the order N, L, R, V; each result is
# unpacked below as (beats, centers).
beat_banks = {
    beat_type: get_beat_bank(data_dir=data_dir, beat_table=beat_table,
                             wanted_type=beat_type, filter=True,
                             fixed_width=BEAT_WIDTH)
    for beat_type in BEAT_TYPES
}

# Keep the per-type names that later cells use.
n_beats, n_centers = beat_banks['N']
l_beats, l_centers = beat_banks['L']
r_beats, r_centers = beat_banks['R']
v_beats, v_centers = beat_banks['V']

In [6]:
# Sanity check: the beat banks were loaded with fixed_width=240, so every bank
# should be non-empty and every beat should be exactly 240 samples long.
for beat_type, beats in (('N', n_beats), ('L', l_beats), ('R', r_beats), ('V', v_beats)):
    assert len(beats) > 0, f"{beat_type}: no beats loaded"
    bad_idx = [i for i, beat in enumerate(beats) if len(beat) != 240]
    assert not bad_idx, (
        f"{beat_type}: {len(bad_idx)} of {len(beats)} beats are not 240 samples wide "
        f"(first offenders: {bad_idx[:5]})"
    )

In [None]:
class BeatNet(nn.Module):
    """Small 1-D CNN that maps one fixed-width beat to class logits.

    Architecture: two (Conv1d -> ReLU -> MaxPool1d) stages, then two fully
    connected layers.

    Args:
        beat_length: number of samples per beat window (default 240, matching
            the fixed width the beat banks are loaded with).
        n_classes: number of output logits (default 4).

    Raises:
        ValueError: if ``beat_length`` is too short to survive both
            conv/pool stages.

    Input:  tensor of shape (batch, 1, beat_length) or (batch, beat_length).
    Output: unnormalised logits of shape (batch, n_classes).
    """

    def __init__(self, beat_length=240, n_classes=4):
        super(BeatNet, self).__init__()
        kernel_size = 3
        # Each beat is a 1-channel 1-D signal, so use 1-D convolutions. They
        # keep the (batch, channels, length) layout that MaxPool1d expects.
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=2, kernel_size=kernel_size)
        self.conv2 = nn.Conv1d(in_channels=2, out_channels=4, kernel_size=kernel_size)
        # Halve the time axis after each conv stage.
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)

        # Time-axis length after both stages: an unpadded conv shrinks it by
        # (kernel_size - 1), and the pool then floors it to half.
        length = beat_length
        for _ in range(2):
            length = (length - (kernel_size - 1)) // 2
        if length < 1:
            raise ValueError(f"beat_length={beat_length} is too short for BeatNet")
        self.n_flat = self.conv2.out_channels * length

        # First fully connected layer
        self.fc1 = nn.Linear(self.n_flat, 12)
        # Final fully connected layer
        self.fc2 = nn.Linear(12, n_classes)

    def forward(self, x):
        """Return class logits of shape (batch, n_classes) for a batch of beats."""
        if x.dim() == 2:
            # (batch, length) -> (batch, 1, length): add the single input channel.
            x = x.unsqueeze(1)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension; unlike view(-1, k),
        # a size mismatch then fails in fc1 instead of mixing samples together.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [None]:
# Seed torch's RNG so the randomly initialised weights are identical on every
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
beatnet = BeatNet()

In [None]:
# Display the instantiated network's layers and their hyper-parameters.
beatnet