In [30]:
# import pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

In [143]:
class HVC_RA(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(HVC_RA, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.HVC = nn.RNN(input_size, hidden_size, nonlinearity='relu')
        self.RA = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        batch_size = x.size(0)
        hidden = self.init_hidden(batch_size)
        output, hidden = self.HVC(x, hidden)
        output = self.RA(output)
        return output, hidden
    
    def init_hidden(self, batch_size):
        hidden = torch.zeros(1, batch_size, self.hidden_size)
        return hidden
    
    def loss_function():
        '''
        Purpose: 
            - to calculate the difference between the HVC neuron activation, and proper RA output 
            - Uses MSE loss
    
        Input:
            - HVC neuron activation
            - RA output

        Output:
            - MSE loss
        '''

        

        pass

In [144]:
class Training():
    def __init__(self):
        pass
        
    def generate_syllables(self, syllable_count):
        '''
        Purpose: 
            - Generate syllables for training
            - 4 params for syllable generation, freq, amplitude, y-offset, length
        Input: 
            - Number of syllables to generate
        Output: 
            - returns songs as a numpy array of shape (syllable_count * individual lengths, 3)
            - returns length of each syllable 
        ''' 

        self.syllable_count = syllable_count

        # careful, the code below is cringe af!
        syllable_duration = np.random.uniform(3, 10, syllable_count)
        syllable_duration = np.array(syllable_duration, dtype=int)
        self.syllable_lengths = syllable_duration

        total_syllable_length = np.sum(syllable_duration)

        syllables = np.zeros((total_syllable_length, 3))
        self.syllable_index = np.zeros((total_syllable_length, 1))
        
        cum_length = 0
        for i, length in enumerate(syllable_duration):
            self.syllable_index[cum_length:cum_length+length] = i
            freq = np.random.uniform(0, 1)
            amp = np.random.uniform(0, 1)
            y_offset = np.random.uniform(0, 1)
            for t in range(length):
                syllables[cum_length, 0] = freq
                syllables[cum_length, 1] = amp
                syllables[cum_length, 2] = y_offset
                cum_length += 1

        self.syllables = syllables

        self.num_syllables = syllable_count

        return syllables, syllable_duration

    def generate_train_set(self, train_size):
        '''
        Purpose:
            - Generate a training set for the model
        Input:
            - train_size: number of training examples to generate
        Output:
            - train_set: a list of training examples
        '''
        output_set = []
        input_set = []

        for i in range(train_size):
            # self.syllables.shape[0] is the the total num of timesteps 
            padded_syllables = np.zeros((self.syllables.shape[0], 3))
            padded_HVC_on = np.zeros((self.syllables.shape[0], self.syllable_count))

    
            lower_bound_syllable_index = np.random.randint(0, self.num_syllables)
            upper_bound_syllable_index = np.random.randint(lower_bound_syllable_index, self.num_syllables)

            if lower_bound_syllable_index != upper_bound_syllable_index:
                list_of_syllables_indices = [i for i in range(lower_bound_syllable_index, upper_bound_syllable_index)]
            else:
                list_of_syllables_indices = [lower_bound_syllable_index]

            # create a list of first timestep of each syllable
            list_of_first_timestep_per_syllable = []

            for syllable_index in list_of_syllables_indices:
                first_timestep = np.sum(self.syllable_lengths[:syllable_index])
                list_of_first_timestep_per_syllable.append(first_timestep)

        
            # set the HVC_on to 1 for the interval
            for i, syllable in enumerate(list_of_syllables_indices):
                padded_HVC_on[list_of_first_timestep_per_syllable[i]:list_of_first_timestep_per_syllable[i]+self.syllable_lengths[syllable], syllable] = 1

            for syllable in list_of_syllables_indices:
                padded_syllables[list_of_first_timestep_per_syllable[i]:list_of_first_timestep_per_syllable[i]+self.syllable_lengths[syllable], :] = self.syllables[list_of_first_timestep_per_syllable[i]:list_of_first_timestep_per_syllable[i]+self.syllable_lengths[syllable], :]
          
            # padded to tensor
            padded_syllables = torch.from_numpy(padded_syllables).float()
            padded_HVC_on = torch.from_numpy(padded_HVC_on).float()

            output_set.append(padded_syllables)
            input_set.append(padded_HVC_on)

        return output_set, input_set


    def train(self, model, epochs, train_size):
        '''
        Purpose:
            - Train the model
        Input:
            - model
            - train_loader
            - epochs
        Output:
            - trained model
        ''' 
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.MSELoss()

        output_set, input_set = self.generate_train_set(train_size)

        for epoch in range(epochs): 
            for i in range(train_size):
                optimizer.zero_grad()
                output, hidden = model(input_set[i])
                loss = criterion(output, output_set[i])
                loss.backward()
                optimizer.step()
        pass

    def get_syllables(self):
        return self.syllables

In [148]:
num_syllables = 4
input_size = 10
hidden_size = 5
output_size = 3

model = HVC_RA(input_size, hidden_size, output_size)
train = Training()

syllables, syllable_duration = train.generate_syllables(num_syllables)

syllables = train.get_syllables()

train.train(model, epochs=10, train_size=1000)

RuntimeError: For unbatched 2-D input, hx should also be 2-D but got 3-D tensor