In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# autoreload
%reload_ext autoreload
%autoreload 

In [54]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels=64, out_channels=64,kernel_size=5, stride=1,requires_sum=True,generator=True):
        super(ConvBlock, self).__init__()
        self.requires_sum = requires_sum
        self.block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,stride=stride, padding='same' if generator else 2),
            nn.PReLU() if generator else nn.LeakyReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size,stride=stride, padding='same' if generator else 2) if self.requires_sum else nn.Identity()) # changed in_channels to out_channels 15 Jul 2023 @ 15:39:28
        
    def forward(self, x):
        if self.requires_sum is True:
            out = self.block(x)
            return x + out
        else:
            out = self.block(x)
            return out
        


class ConvBlock1D(nn.Module):
    def __init__(self,in_channels=1, out_channels=32,kernel_size=(1,32), stride=2,padding=2):
        super(ConvBlock1D, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.LeakyReLU()
        )

    def forward(self,x):
        return self.block(x)
    


class Generator(nn.Module):
    def __init__(self,in_channel=2, out_channel=64, blocks=4):
        super(Generator, self).__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=out_channel,kernel_size=11,stride=1,padding='same'),
            nn.PReLU()
        )

        self.blocks = nn.Sequential(*[ConvBlock() for _ in range(blocks)])
        self.conv = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=5,stride=1,padding='same')

        self.terminal = nn.Sequential(
            ConvBlock(in_channels=out_channel, out_channels=out_channel,requires_sum=False),
            ConvBlock(in_channels=out_channel, out_channels=out_channel,requires_sum=False),
            nn.Conv2d(in_channels=out_channel, out_channels=2,kernel_size=11,stride=1,padding='same')
        )

    def forward(self, z):
        z = self.initial(z)
        out = self.blocks(z)
        out = self.conv(out)
        out = out + z
        return self.terminal(out)


class Discriminator(nn.Module):
    def __init__(self,in_channels=3,out_channels=32,in_features=2048,out_features=512,in_features_final=2048,blocks=2):
        super(Discriminator, self).__init__()
        
        self.magnitude_path = nn.Sequential(
            ConvBlock(in_channels=1,out_channels=out_channels,stride=2,requires_sum=False,generator=False),
            ConvBlock(in_channels=out_channels,out_channels=out_channels,stride=2,requires_sum=False,generator=False),
            nn.Flatten(1,-1),
            nn.Linear(in_features=in_features,out_features=in_features),
            nn.LeakyReLU()
            )
        
        self.phase_path = ConvBlock1D()

        # [f(x) if condition else g(x) for x in sequence]
        self.blocks = nn.Sequential(*[ConvBlock1D(in_channels=1, out_channels=64) if blocks == 0 else ConvBlock1D(in_channels=64, out_channels=64) for _ in range(blocks)])

        self.terminal = nn.Sequential(
            nn.Linear(in_features=in_features_final,out_features=1024),
            nn.LeakyReLU(),
            nn.Linear(in_features=1024,out_features=1)
        )


    def forward(self, comp, magnitude):
        mag = self.magnitude_path(magnitude) # will be 32 * H_mag * W_mag
        print(mag.shape)
        phase = torch.istft(comp,n_fft=128)
        phase = self.phase_path(phase)

        # out = torch.cat([mag,phase])
        # out = self.blocks(out)
        # return self.terminal(out)

In [62]:
x = torch.randn(1,2,4,4,dtype=torch.complex64)
residual = torch.randn(4,64,64,64)


gen = Generator(2, 64, 4)
disc = Discriminator(in_features=32*7*7)

gen.to(torch.complex64)



# t1, t2 = torch.randn(1,2,28,28), torch.randn(1,1,28,28)

# disc(t1,t2)

gen(x)

RuntimeError: "prelu_cpu" not implemented for 'ComplexFloat'

In [48]:
from src.data_loader import *
import os


path = '../data/AudioMNIST_Indicies/dummy_labels.csv'
root_dir = '../data/AudioMNIST/'

data = AvianNatureSounds(annotation_file=path,root_dir=root_dir,mode='testing',max_ms=56)
train_loader = DataLoader(dataset=data, batch_size=4, shuffle=True)
batch = next(iter(train_loader)) 

batch[0].shape

torch.Size([4, 2, 29, 29])

In [22]:
conv = nn.Conv2d(1, 3, 5,1,2)
t = torch.randn((4,1,28,28)) # (K - 1) / 2 

conv(t).shape

torch.Size([4, 3, 28, 28])

In [28]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels=64, out_channels=64,kernel_size=5, stride=1,requires_sum=True,generator=True):
        super(ConvBlock, self).__init__()
        self.requires_sum = requires_sum
        self.block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,stride=stride, padding='same' if generator else 2),
            nn.PReLU() if generator else nn.LeakyReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size,stride=stride, padding='same' if generator else 2) if self.requires_sum else nn.Identity()) # changed in_channels to out_channels 15 Jul 2023 @ 15:39:28
        
    def forward(self, x):
        if self.requires_sum is True:
            out = self.block(x)
            return x + out
        else:
            out = self.block(x)
            return out 
        
cv = ConvBlock(1,32,5,2,False, False)
t = torch.randn((4,1,28,28))

cv(t).shape

torch.Size([4, 32, 14, 14])

In [None]:
nn.Complex