<a href="https://colab.research.google.com/github/erika-n/FractalMusicBox/blob/master/soundgen.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchaudio
import IPython
from matplotlib import pyplot as plt
from os import listdir
from os.path import isfile, join
import random
from torch import nn
import torch.nn.functional as F

In [None]:
# Load one track from Drive: `data` is the decoded waveform tensor, `rate` its sample rate.
data, rate = torchaudio.load('/content/drive/MyDrive/music/songsinmyhead/songsinmyhead_2016/01blame.wav')
rate

In [None]:
# Waveform shape straight from the loader, before the reduction below.
data.shape

In [None]:
# Sum over dim 0 (presumably the channel axis, giving a mono mix — confirm
# against torchaudio.load's layout). The sum is not rescaled afterwards.
data = data.sum(0)

In [None]:
# Shape after dim 0 has been summed away.
data.shape

In [None]:
# Smallest sample value in the summed signal.
data.min()

In [None]:
# Listen to the first 5*rate samples, i.e. five seconds at `rate` samples per second.
IPython.display.Audio(data[:5*rate],rate=rate)

In [None]:
# Short-time Fourier transform of the summed signal with n_fft=2000, kept as a complex tensor.
spect = torch.stft(data, 2000, return_complex=True)


In [None]:
# Shape of the complex spectrogram.
spect.shape

In [None]:
# Magnitude (absolute value) of every complex STFT coefficient.
mag = torch.absolute(spect)

In [None]:
mag.shape

In [None]:
# Show the magnitude for the first 2000 entries along dim 1; origin='lower' puts row 0 at the bottom.
plt.imshow(mag[:, :2000], origin='lower')

In [None]:
# Phase angle of every complex STFT coefficient.
phase = torch.angle(spect)

In [None]:
phase.shape

In [None]:
# Same view as the magnitude plot above, but for the phase.
plt.imshow(phase[:, :2000], origin='lower')

In [None]:
def mag_phase_to_complex(mag, phase):
    """Rebuild complex values from their magnitude and phase.

    Args:
        mag: Tensor of magnitudes.
        phase: Tensor of phase angles in radians, broadcastable against ``mag``.

    Returns:
        Complex tensor ``mag * exp(i * phase)``.
    """
    # exp(i*phase) is the unit-length phasor; scaling it by mag restores the coefficient.
    unit_phasor = torch.exp(phase * 1.j)
    return mag * unit_phasor

In [None]:
# Round-trip check: recombine the separated magnitude and phase into a complex spectrogram.
new_data = mag_phase_to_complex(mag, phase)

In [None]:
new_data.shape

In [None]:
# Inverse STFT with the same n_fft (2000) that produced `spect`, back to a waveform.
time_domain = torch.istft(new_data, 2000)
time_domain.shape

In [None]:
# Play the first five seconds of the reconstruction for comparison with the original clip.
IPython.display.Audio(time_domain[:5*rate],rate=rate)

In [None]:
class Dataset:
  """Serve fixed-width spectrogram chunks cut from the audio files in a folder.

  A loaded file is stored as a tensor of shape
  (n_chunks, 2, n_fft // 2 + 1, fft_width), where channel 0 holds the STFT
  magnitude and channel 1 the STFT phase.
  """

  def __init__(self, folder, n_fft=100, fft_width=50):
    """Index the files in `folder`.

    Args:
      folder: Directory whose regular files are treated as audio files.
      n_fft: FFT size used for both the forward and the inverse STFT.
      fft_width: Number of STFT frames per chunk.
    """
    self.folder = folder

    self.files = [f for f in listdir(folder) if isfile(join(folder, f))]
    self.files = [self.files[0]] #TMPDEBUG: use one file to see if it will converge
    random.shuffle(self.files)
    self.fi = 0       # index into self.files of the most recently selected file
    self.di = 0       # index of the next chunk to serve from self.data
    self.data = None  # chunk tensor of the loaded file; filled lazily by get_batch
    self.n_fft = n_fft
    self.fft_width = fft_width

  def mag_phase_to_complex(self, mag, phase):
    """Return the complex tensor mag * exp(i * phase)."""
    return mag*torch.exp(1.j*phase)

  def to_audio(self, data):
    """Turn a stack of spectrogram chunks back into a waveform.

    Args:
      data: Tensor of shape (n_chunks, channels, freq, frames). Channel 0 is
        the magnitude. Channel 1, when present, is the phase. Input with a
        single channel (what get_batch returns) is reconstructed with zero
        phase.

    Returns:
      1-D waveform from the inverse STFT of the chunks joined along time.
    """
    mag = data[:, 0, :, :]
    if data.shape[1] > 1:
      phase = data[:, 1, :, :]
    else:
      # No phase channel available: fall back to zero phase rather than
      # failing on the missing index.
      phase = torch.zeros_like(mag)

    # (chunk, freq, frame) -> (freq, chunk, frame) -> (freq, chunk*frame):
    # this lays the chunks end to end along the time axis.
    n_freq = mag.shape[1]
    mag = mag.transpose(0, 1).reshape(n_freq, -1)
    phase = phase.transpose(0, 1).reshape(n_freq, -1)

    spect = self.mag_phase_to_complex(mag, phase)

    # The inverse must use the same n_fft as load_file's forward transform,
    # otherwise the number of frequency bins does not match.
    return torch.istft(spect, self.n_fft)

  def load_file(self, path):
    """Load an audio file and return its chunked magnitude/phase tensor.

    Args:
      path: Path of the audio file.

    Returns:
      Tensor of shape (n_chunks, 2, n_fft // 2 + 1, fft_width).

    Raises:
      AssertionError: If the sample rate is not 44100 or the mixed signal
        leaves [-1, 1].
    """
    data, rate = torchaudio.load(path)
    assert rate == 44100, f"expected 44100 Hz audio, got {rate}"
    # Average over dim 0; for a two-channel file this equals sum(0) / 2.0 and
    # it stays in range for any other channel count.
    data = data.mean(0)
    assert data.max() <= 1.0
    assert data.min() >= -1.0
    spect = torch.stft(data, self.n_fft, return_complex=True)
    mag = torch.absolute(spect)
    phase = torch.angle(spect)

    # Stack on a new dim 1 so that channel 0 = magnitude, channel 1 = phase.
    return torch.stack((self.chunk(mag), self.chunk(phase)), 1)

  def chunk(self, data):
    """Split a (freq, frames) tensor into chunks of fft_width frames.

    Trailing frames that do not fill a whole chunk are dropped.

    Returns:
      Tensor of shape (n_chunks, freq, fft_width). n_chunks is 0 when the
      input has fewer than fft_width frames.
    """
    n_chunks = data.shape[1] // self.fft_width
    usable = data[:, :n_chunks * self.fft_width]

    # Explicit sizes (instead of -1) keep the reshape valid for an empty input.
    chunked_data = usable.reshape(data.shape[0], n_chunks, self.fft_width)
    return torch.transpose(chunked_data, 0, 1)

  def get_batch(self, batch_size):
    """Return the next `batch_size` magnitude chunks, wrapping at the end.

    The first call loads a file. Moving on to another file once the current
    one is exhausted is currently disabled, so the same data is cycled.

    Returns:
      Tensor of shape (batch, 1, freq, fft_width) holding only the magnitude
      channel.
    """
    if self.data is None:
      self.fi = (self.fi + 1) % len(self.files)
      print("Loading file", self.files[self.fi])
      self.data = self.load_file(join(self.folder, self.files[self.fi]))
      self.di = 0
      self.data = self.data[:5*batch_size, :, :, :]#TMPDEBUG

    # Wrap only when a full batch no longer fits. A batch that ends exactly
    # on the last chunk is still valid and must be served.
    if self.di + batch_size > self.data.shape[0]:
      self.di = 0
    batch = self.data[self.di:self.di + batch_size, :, :, :]
    self.di += batch_size

    # Keep only the magnitude channel, as a singleton channel dimension.
    return batch[:, 0:1, :, :]





In [None]:
# test dataset
# Build a Dataset over the album folder, load one file directly and show the
# magnitude (channel 0) of its first chunk.

ds = Dataset('/content/drive/MyDrive/music/songsinmyhead/songsinmyhead_2016/')
magphase = ds.load_file('/content/drive/MyDrive/music/songsinmyhead/songsinmyhead_2016/01blame.wav')
plt.imshow(magphase[0,0, :, :], origin='lower')


In [None]:
# Phase (channel 1) of the same first chunk.
plt.imshow(magphase[0, 1, :, :], origin='lower')

In [None]:
# Pull ten batches of two chunks each; only the last batch is kept in `data` and plotted.
for i in range(10):
  data = ds.get_batch(2)
plt.imshow(data[0, 0, :, :], origin='lower')
data.shape

In [None]:


# NOTE(review): `data` is a get_batch result, which carries only the magnitude
# channel, so to_audio has to cope with the missing phase channel — confirm.
time_domain = ds.to_audio(data)

In [None]:
# `rate` is the global assigned by the torchaudio.load cell near the top of the notebook.
IPython.display.Audio(time_domain,rate=rate)

In [None]:


class AutoEncode(nn.Module):
    """Convolutional encoder with a fully connected decoder for spectrogram chunks.

    Input and output both have shape (batch, 1, n_freq, n_frames). The
    defaults (51, 50) match a Dataset built with n_fft=100 and fft_width=50.

    Args:
        n_freq: Height of the input, i.e. number of frequency bins.
        n_frames: Width of the input, i.e. number of STFT frames per chunk.

    Raises:
        ValueError: If either dimension is too small to survive the two
            conv + pool stages.
    """

    def __init__(self, n_freq=51, n_frames=50):
        super().__init__()
        self.n_freq = n_freq
        self.n_frames = n_frames

        enc_h = self._encoded_size(n_freq)
        enc_w = self._encoded_size(n_frames)
        if enc_h < 1 or enc_w < 1:
            raise ValueError(
                f"input of {n_freq}x{n_frames} is too small for the encoder"
            )

        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 16 feature maps of enc_h x enc_w come out of the encoder
        # (16 * 9 * 9 = 1296 for the default 51 x 50 input).
        self.fc1 = nn.Linear(16 * enc_h * enc_w, 100)
        self.fc2 = nn.Linear(100, 1 * n_freq * n_frames)

    @staticmethod
    def _encoded_size(n):
        """Length of one spatial dim after two (5x5 conv, 2x2 max-pool) stages."""
        # An unpadded 5x5 conv removes 4; a 2x2 pool with stride 2 halves, rounding down.
        return ((n - 4) // 2 - 4) // 2

    def forward(self, x):
        """Encode `x` to a 100-dim code and decode it back to the input shape."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        # The final ReLU keeps the reconstruction non-negative.
        x = F.relu(self.fc2(x))
        return x.view(x.shape[0], 1, self.n_freq, self.n_frames)




In [None]:
# Shape of the batch left over from the dataset test cells; it is fed to the model below.
data.shape

In [None]:
# Instantiate the untrained autoencoder and run one forward pass to check the output shape.
model = AutoEncode()
output = model(data)
output.shape

In [None]:
import torch.optim as optim

# Reconstruction objective: mean squared error between the model output and its own input.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.00001)

n_batches = 2000
batch_size = 2
print_every = 100  # report the mean loss over each window of this many batches

running_loss = 0.0
for i in range(n_batches):
    # autoencoder: the batch is both the input and the target
    batch = ds.get_batch(batch_size)

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs = model(batch)
    loss = criterion(outputs, batch)
    loss.backward()
    optimizer.step()

    # print statistics
    running_loss += loss.item()
    # Report only once a full window has accumulated, so dividing by
    # print_every yields a true mean (reporting at i == 0 would divide a
    # single batch loss by print_every).
    if (i + 1) % print_every == 0:
        print(f'[{i + 1:5d}] loss: {running_loss/print_every:.6f}')
        running_loss = 0.0



In [None]:
# Feed the trained model noise: the fetched batch is used only for its shape
# and is immediately replaced by torch.rand values of that shape.
data = ds.get_batch(64)
data = torch.rand(data.shape)
# Keep two of the noise samples (indices 5 and 6).
data = data[5:7]
print(data.shape)
out = model(data)
out.shape
# Detach from the autograd graph so the result can be plotted.
out = out.detach()



In [None]:
# The noise input ...
plt.imshow(data[0, 0, :, :], origin='lower')


In [None]:
# ... and what the model turned it into.
plt.imshow(out[0, 0, :, :], origin='lower')

In [None]:
# Same experiment on a whole batch of noise.
data = ds.get_batch(64)
data = torch.rand(data.shape)
output = model(data)
output.shape
output = output.detach()
# NOTE(review): both imshow calls run in one cell with no new figure in
# between, so the second probably draws over the first — confirm intent.
plt.imshow(data[0, 0, :, :], origin='lower')
plt.imshow(output[0, 0, :, :], origin='lower')

In [None]:
# Convert the model output back to a waveform and play it.
# NOTE(review): `output` has a single (magnitude) channel — confirm to_audio handles that.
audio = ds.to_audio(output)
IPython.display.Audio(audio.detach(),rate=rate)