<a href="https://colab.research.google.com/github/iakioh/MusiCAN/blob/main/models/first_music_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch

## Settings

In [6]:
# --- Pianoroll geometry ---
# Pitch axis: number of octaves (same choice as MuseGAN) times notes per octave.
number_octaves = 6
notes_per_octave = 12
number_pitches = notes_per_octave * number_octaves

# Time axis: bars -> beats -> blips (the smallest time step).
blips_per_beat = 24   # temporal resolution of the Lakh dataset
beats_per_bar = 4   # 4/4 rhythm
number_bars = 1
number_blips = blips_per_beat * beats_per_bar * number_bars

# Total number of cells in one flattened pianoroll.
pianoroll_size = number_blips * number_pitches

print(f"pianoroll array: {number_pitches}x{number_blips} = {pianoroll_size}")


# --- Generator ---
seed_length = 64   # length of the generator's input vector


pianoroll array: 72x96 = 6912


## Data Preparation

## Model

In [7]:
# code from https://www.hassanaskary.com/python/pytorch/deep%20learning/2020/09/19/intuitive-explanation-of-straight-through-estimators.html#:~:text=A%20straight%2Dthrough%20estimator%20is,function%20was%20an%20identity%20function.

class STEFunction(torch.autograd.Function):
    """Straight-through estimator as a custom autograd function.

    Forward: hard threshold at zero (1.0 where the input is strictly
    positive, 0.0 elsewhere).
    Backward: the step itself is not differentiated; the incoming gradient
    is passed on after being clipped to [-1, 1] by hardtanh.
    """

    @staticmethod
    def forward(ctx, input):
        """Return a float32 tensor that is 1.0 where ``input > 0``, else 0.0."""
        return torch.gt(input, 0).to(torch.float32)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` clipped element-wise to the range [-1, 1]."""
        return torch.nn.functional.hardtanh(
            grad_output, min_val=-1.0, max_val=1.0
        )

class StraightThroughEstimator(torch.nn.Module):
    """Parameter-free layer that applies ``STEFunction`` to its input.

    Wrapping the autograd function in a module lets it be placed inside
    containers such as ``torch.nn.Sequential``.
    """

    def forward(self, x):
        """Return ``STEFunction.apply(x)``."""
        return STEFunction.apply(x)

In [8]:
class Generator(torch.nn.Module):
    """
    First GAN generator: an MLP whose hidden width doubles for as long as
    the doubled width stays below the output size O, followed by a final
    linear projection to O and a straight-through binarisation.

    input: normally distributed random vector of length I (= seed_length)
    output: binary vector of length O (= pianoroll_size)
    """

    def __init__(self):
        super().__init__()

        I = seed_length   # length of input vector
        O = pianoroll_size   # length of output vector

        # torch.nn.Sequential expects each positional argument to be a Module,
        # so `layers` must stay a flat list of modules: use extend(), not
        # append() of a nested list (which makes Sequential raise TypeError).
        layers = []
        current_size = I
        while 2 * current_size < O:
            layers.extend([
                torch.nn.Linear(current_size, 2 * current_size),
                torch.nn.ReLU(),
                torch.nn.BatchNorm1d(2 * current_size),
            ])
            current_size *= 2

        # Final projection to the output size, binarised to {0, 1}.
        layers.extend([
            torch.nn.Linear(current_size, O),
            StraightThroughEstimator(),
        ])

        self.generator = torch.nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result.

        NOTE(review): presumably ``input`` has shape (batch, seed_length);
        confirm against the training loop.
        """
        return self.generator(input)

In [None]:
        # Scratch cell: hand-unrolled drafts of the Generator layer stack, kept
        # only as inert string literals -- nothing in this cell builds a model.
        #
        # Draft 1: explicit layer list, doubling the width from I up to 64*I
        # before the final projection to O and the straight-through binarisation.
        '''
        torch.nn.Linear(I, 2*I),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(2*I),
            torch.nn.Linear(2*I, 4*I),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(4*I),
            torch.nn.Linear(4*I, 8*I),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(8*I),
            torch.nn.Linear(8*I, 16*I),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(16*I),
            torch.nn.Linear(16*I, 32*I),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(32*I),
            torch.nn.Linear(32*I, 64*I),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(64*I),
            torch.nn.Linear(64*I, O),
            StraightThroughEstimator()
        '''
        
        # Draft 2: one attribute per linear layer plus a manual forward pass
        # (no batch norm).
        # NOTE(review): as written, every step passes `input` instead of the
        # running `layer_input`, and `StraightThroughEstimator.forward` is
        # called on the class without an instance -- both would need fixing
        # before this draft could be reused.
        '''
        self.layer0 = torch.nn.Linear(I, 2*I)
        self.layer1 = torch.nn.Linear(2*I, 4*I)
        self.layer2 = torch.nn.Linear(4*I, 8*I)
        self.layer3 = torch.nn.Linear(8*I, 16*I)
        self.layer4 = torch.nn.Linear(16*I, 32*I)
        self.layer5 = torch.nn.Linear(32*I, 64*I)
        self.layer6 = torch.nn.Linear(64*I, O)

        layer_input = torch.nn.functional.relu(self.layer0(input))
        layer_input = torch.nn.functional.relu(self.layer1(input))
        layer_input = torch.nn.functional.relu(self.layer2(input))
        layer_input = torch.nn.functional.relu(self.layer3(input))
        layer_input = torch.nn.functional.relu(self.layer4(input))
        layer_input = torch.nn.functional.relu(self.layer5(input))
        output = StraightThroughEstimator.forward(self.layer6(input))
        '''