In [9]:
import torch
import torch.nn.functional as F

In [8]:
# Scratch: small integer tensors used to sanity-check shapes / broadcasting.
a = torch.tensor([
    [1, 2],
    [1, 4],
    [1, 1],
])
b = torch.tensor([4, 5, 6])
a.size()

torch.Size([3, 2])

In [None]:
class DiffWaveBlock(torch.nn.Module):
    """One residual layer: a step-conditioned, gated, dilated 1-D convolution.

    Args:
        layer_index: position of the layer in the stack; the dilated conv uses
            ``dilation = 2 ** layer_index``.
        C: number of residual channels (input and output channels of the layer).
        emb_dim: size of the last dimension of the step embedding ``t`` passed
            to ``forward`` (default 512).
    """

    def __init__(self, layer_index, C, emb_dim=512) -> None:
        super().__init__()
        self.layer_index = layer_index
        self.C = C
        # Input and skip output of the most recent forward() call; stored so a
        # caller can read them after the layer has run.
        self.input = None
        self.x_skip = None

        # Projects the step embedding to one additive offset per residual channel.
        self.fc_timestep = torch.nn.Linear(emb_dim, C)

        # Non-causal dilated conv; its 2*C output channels are split into the
        # tanh half and the sigmoid half of the gate.
        self.conv_dilated = torch.nn.Conv1d(C, 2 * C, 3, dilation=2 ** layer_index, padding='same')

        # 1x1 convs producing the skip output and the residual-path output.
        self.conv_skip = torch.nn.Conv1d(C, C, 1)
        self.conv_next = torch.nn.Conv1d(C, C, 1)

    def forward(self, x, t):
        """Apply the layer.

        Args:
            x: tensor of shape (batch, C, length) or unbatched (C, length).
            t: step embedding with last dimension ``emb_dim``, e.g. shape
                (batch, emb_dim), or (emb_dim,) to share one embedding across the batch.

        Returns:
            Residual output ``conv_next(gated) + x`` with the same shape as ``x``.
            The skip output ``conv_skip(gated)`` is stored in ``self.x_skip``.
        """
        # No operation below modifies x in place (each returns a new tensor),
        # so keeping a plain reference is safe; no clone() is required.
        self.input = x

        # (..., C) -> (..., C, 1) so the per-channel offset broadcasts over time.
        # Without the unsqueeze, the C axis would be matched against x's length axis.
        x = x + self.fc_timestep(t).unsqueeze(-1)

        x = self.conv_dilated(x)
        # Channel axis is -2 for both batched (B, 2C, L) and unbatched (2C, L) input.
        gate_tanh, gate_sigmoid = x.chunk(2, dim=-2)
        x = torch.tanh(gate_tanh) * torch.sigmoid(gate_sigmoid)

        self.x_skip = self.conv_skip(x)
        return self.conv_next(x) + self.input







In [None]:
class DiffWave(torch.nn.Module):
    """Small DiffWave-style denoiser: input projection, two residual layers, output head.

    Args:
        C: number of residual channels used throughout the network.
    """

    def __init__(self, C) -> None:
        super().__init__()
        # Step-embedding MLP: 128-dim step encoding -> 512-dim embedding.
        self.fc1 = torch.nn.Linear(128, 512)
        self.fc2 = torch.nn.Linear(512, 512)
        # 1x1 conv lifting the single waveform channel to C channels.
        self.conv_in_1 = torch.nn.Conv1d(1, C, 1)

        # Residual layers with dilations 1 and 2.
        self.layer1 = DiffWaveBlock(0, C)
        self.layer2 = DiffWaveBlock(1, C)

        # Output head mapping the summed skip connections back to one channel.
        self.conv_out_1 = torch.nn.Conv1d(C, C, 1)
        self.conv_out_2 = torch.nn.Conv1d(C, 1, 1)

    @staticmethod
    def _step_encoding(steps, dim=128):
        """Sinusoidal encoding of diffusion step(s): [sin(s * 10**(4k/(h-1))), cos(...)], h = dim // 2.

        Args:
            steps: Python number or tensor of shape () / (batch,).
            dim: encoding size (first half sines, second half cosines).

        Returns:
            float32 tensor of shape ``steps.shape + (dim,)``.
        """
        steps = torch.as_tensor(steps, dtype=torch.float32)
        half = dim // 2
        exponents = torch.arange(half, dtype=torch.float32, device=steps.device) * 4.0 / (half - 1)
        angles = steps.unsqueeze(-1) * 10.0 ** exponents
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, x, t):
        """Predict the output waveform channel from a noisy waveform and diffusion step.

        Args:
            x: waveform tensor for ``conv_in_1``, e.g. shape (batch, 1, length).
            t: either integer diffusion step(s) (Python int or integer tensor of
                shape () / (batch,)), which are sinusoidally encoded, or a floating
                tensor that already holds the encoding (last dim ``fc1.in_features``).

        Returns:
            Tensor with one channel and the same length as ``x``.
        """
        # Waveform input projection.
        x = F.relu(self.conv_in_1(x))

        # Step embedding: computed once from the caller's t and shared by every
        # residual layer (previously t was overwritten with ints and crashed fc1).
        if not torch.is_tensor(t) or not t.is_floating_point():
            t = self._step_encoding(t, self.fc1.in_features).to(
                device=self.fc1.weight.device, dtype=self.fc1.weight.dtype)
        t_emb = F.silu(self.fc1(t))
        t_emb = F.silu(self.fc2(t_emb))

        # Residual stack; each layer's skip output is accumulated so conv_skip
        # actually contributes to the result (and receives gradients).
        layers = (self.layer1, self.layer2)
        skip_sum = 0
        for layer in layers:
            x = layer(x, t_emb)
            skip_sum = skip_sum + layer.x_skip
        # Scale the sum of skips by 1/sqrt(num_layers) to keep its magnitude
        # independent of depth.
        x = skip_sum / len(layers) ** 0.5

        # Output head; the ReLU keeps the two 1x1 convs from collapsing into one linear map.
        x = F.relu(self.conv_out_1(x))
        return self.conv_out_2(x)