In [15]:
import torch
from torch.utils import data
import torch.nn as nn
from glob import glob
import matplotlib.pyplot as plt
import librosa as li

In [154]:
class GuitarChord(data.Dataset):
    """Dataset of mel-scaled magnitude spectrograms of guitar-chord recordings.

    On construction every ``AC_GuitarChords/*.wav`` file is loaded, forced to
    exactly ``CLIP_SECONDS`` seconds (truncated or zero-padded), converted to a
    mel-weighted STFT magnitude and saved as a float32 tensor in a ``.pt`` file
    next to the audio.  Items are read back from those ``.pt`` files on demand.
    """

    # Analysis settings. The WAE model in this notebook hard-codes the feature
    # size produced by these values, so change them together with the model.
    N_FFT = 2048
    N_MELS = 500
    CLIP_SECONDS = 4

    def __init__(self):
        # Sorted so that index -> file is stable across runs and platforms.
        self.liste = sorted(glob("AC_GuitarChords/*.wav"))
        print("Preprocessing stuff... ", end="")

        # The mel filter bank only depends on the sample rate: build it once
        # per rate instead of once per file.
        mel_banks = {}
        for elm in self.liste:
            x, fs = li.load(elm)
            # Truncate long clips AND zero-pad short ones, so that every
            # spectrogram has the same number of frames and can be batched.
            x = li.util.fix_length(x, size=self.CLIP_SECONDS * fs)
            if fs not in mel_banks:
                # Keyword arguments: recent librosa releases no longer accept
                # these parameters positionally.
                mel_banks[fs] = li.filters.mel(
                    sr=fs, n_fft=self.N_FFT, n_mels=self.N_MELS
                )
            magnitude = abs(li.stft(x, n_fft=self.N_FFT))
            S = torch.from_numpy(mel_banks[fs].dot(magnitude)).float()
            torch.save(S, self._cache_path(elm))

        print("Done!")

    @staticmethod
    def _cache_path(wav_path):
        """Return the ``.pt`` path that sits next to ``wav_path``.

        Only the final extension is swapped, so a ``.wav`` substring elsewhere
        in the path is left alone and an upper-case ``.WAV`` file can never be
        overwritten by its own cache.
        """
        return wav_path.rsplit(".", 1)[0] + ".pt"

    def __getitem__(self, i):
        """Load and return the cached spectrogram tensor of the ``i``-th file."""
        return torch.load(self._cache_path(self.liste[i]))

    def __len__(self):
        """Number of audio files found at construction time."""
        return len(self.liste)
            
class WAE(nn.Module):
    """Convolutional auto-encoder for batches of 500 x 173 spectrograms.

    ``encode`` maps a ``(batch, 500, 173)`` tensor to a ``(batch, zdim)`` code
    through five stride-2 convolutions and three linear layers; ``decode``
    mirrors that path back to ``(batch, 500, 173)``.  The linear layers are
    sized for a 256 x 16 x 6 bottleneck, so other input sizes are rejected
    with a shape error in ``lin1``.

    Args:
        zdim: size of the latent code (default 32).
    """

    def __init__(self, zdim=32):
        super().__init__()
        size = [1, 16, 32, 64, 128, 256]
        # Shape of the last encoder feature map for a 500 x 173 input.
        self.bottleneck = (size[5], 16, 6)
        flat = 256 * 16 * 6

        self.act = nn.LeakyReLU()

        conv = dict(stride=2, kernel_size=5, padding=2)

        self.enc1 = nn.Conv2d(size[0], size[1], **conv)
        self.enc2 = nn.Conv2d(size[1], size[2], **conv)
        self.enc3 = nn.Conv2d(size[2], size[3], **conv)
        self.enc4 = nn.Conv2d(size[3], size[4], **conv)
        self.enc5 = nn.Conv2d(size[4], size[5], **conv)

        self.lin1 = nn.Linear(flat, 1024)
        self.lin2 = nn.Linear(1024, 256)
        self.lin3 = nn.Linear(256, zdim)

        # Each encoder conv maps n -> floor((n - 1) / 2) + 1, giving
        #   height 500 -> 250 -> 125 -> 63 -> 32 -> 16
        #   width  173 ->  87 ->  44 -> 22 -> 11 ->  6
        # A transposed conv maps n -> 2n - 1 + output_padding, so
        # output_padding is 1 exactly where the encoder input was even.
        self.dec1 = nn.ConvTranspose2d(size[1], size[0], output_padding=(1, 0), **conv)  # -> 500 x 173
        self.dec2 = nn.ConvTranspose2d(size[2], size[1], output_padding=(1, 0), **conv)  # -> 250 x 87
        self.dec3 = nn.ConvTranspose2d(size[3], size[2], output_padding=(0, 1), **conv)  # -> 125 x 44
        self.dec4 = nn.ConvTranspose2d(size[4], size[3], output_padding=(0, 1), **conv)  # -> 63 x 22
        self.dec5 = nn.ConvTranspose2d(size[5], size[4], output_padding=(1, 0), **conv)  # -> 32 x 11

        self.dlin1 = nn.Linear(1024, flat)
        self.dlin2 = nn.Linear(256, 1024)
        self.dlin3 = nn.Linear(zdim, 256)

        # Encoder, convolutional part.
        self.f1 = nn.Sequential(
            self.enc1, nn.BatchNorm2d(num_features=size[1]), self.act,
            self.enc2, nn.BatchNorm2d(num_features=size[2]), self.act,
            self.enc3, nn.BatchNorm2d(num_features=size[3]), self.act,
            self.enc4, nn.BatchNorm2d(num_features=size[4]), self.act,
            self.enc5, nn.BatchNorm2d(num_features=size[5]), self.act,
        )

        # Encoder, dense part (no normalisation/activation on the code).
        self.f2 = nn.Sequential(
            self.lin1, nn.BatchNorm1d(num_features=1024), self.act,
            self.lin2, nn.BatchNorm1d(num_features=256), self.act,
            self.lin3,
        )

        # Decoder, dense part.
        self.f3 = nn.Sequential(
            self.dlin3, nn.BatchNorm1d(num_features=256), self.act,
            self.dlin2, nn.BatchNorm1d(num_features=1024), self.act,
            self.dlin1, nn.BatchNorm1d(num_features=flat), self.act,
        )

        # Decoder, convolutional part; mirrors f1 and, like f2, leaves the
        # final layer without normalisation/activation.
        self.f4 = nn.Sequential(
            self.dec5, nn.BatchNorm2d(num_features=size[4]), self.act,
            self.dec4, nn.BatchNorm2d(num_features=size[3]), self.act,
            self.dec3, nn.BatchNorm2d(num_features=size[2]), self.act,
            self.dec2, nn.BatchNorm2d(num_features=size[1]), self.act,
            self.dec1,
        )

    def flatten(self, inp):
        """Collapse every dimension except the first into one."""
        return inp.reshape(inp.size(0), -1)

    def encode(self, inp):
        """Encode a ``(batch, 500, 173)`` tensor into ``(batch, zdim)`` codes."""
        inp = inp.unsqueeze(1)  # add the single input channel expected by enc1
        inp = self.f1(inp)
        inp = self.flatten(inp)
        inp = self.f2(inp)
        return inp

    def decode(self, inp):
        """Decode ``(batch, zdim)`` codes back into ``(batch, 500, 173)`` tensors."""
        inp = self.f3(inp)
        inp = inp.view(-1, *self.bottleneck)
        inp = self.f4(inp)
        return inp.squeeze(1)  # drop the channel dimension added in encode

    def forward(self, inp):
        """Reconstruct ``inp`` by encoding then decoding it."""
        return self.decode(self.encode(inp))
        
        
        
        

In [155]:
# Build the dataset in this cell so the notebook survives Restart & Run All;
# with the line commented out, `GC` only existed as leftover kernel state.
GC = GuitarChord()
# Shuffled mini-batches of 8 spectrograms.
GCloader = data.DataLoader(GC, batch_size=8, shuffle=True)

In [156]:
# Instantiate the auto-encoder with its default constructor arguments.
model = WAE()

In [157]:
# Sanity check: print the shape of the latent codes produced for each batch.
for batch in GCloader:
    codes = model.encode(batch)
    print(codes.size())

torch.Size([8, 32])
torch.Size([8, 32])
