In [None]:

class Discriminator(nn.Module):
    """Convolutional 1-D critic that maps a length-32 signal to a single score.

    Layout (for an input of shape ``(N, channels_img, 32)``)::

        Conv1d -> LeakyReLU                         (N, features_d,   16)
        Conv1d -> InstanceNorm1d -> LeakyReLU       (N, features_d*2,  8)
        Conv1d -> InstanceNorm1d -> LeakyReLU       (N, features_d*4,  4)
        Conv1d                                      (N, 1,             1)

    No final activation is applied, so the output is a raw, unbounded score.

    Args:
        channels_img: number of channels of the input signal.
        features_d: channel width of the first hidden layer; it doubles at
            every following downsampling stage.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()

        # Stage 1: plain strided conv, no normalization.
        stages = [
            nn.Conv1d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]

        # Stages 2 and 3: halve the length, double the channels, normalize.
        width = features_d
        for _ in range(2):
            doubled = width * 2
            stages.extend([
                nn.Conv1d(width, doubled, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm1d(doubled, affine=True),
                nn.LeakyReLU(0.2),
            ])
            width = doubled

        # Final projection: length 4 -> 1, channels -> 1 score.
        stages.append(nn.Conv1d(width, 1, kernel_size=4, stride=1, padding=0))

        # Attribute name ``disc`` is part of the state_dict key layout.
        self.disc = nn.Sequential(*stages)

    def forward(self, x):
        """Return the critic score tensor produced by ``self.disc`` for ``x``."""
        return self.disc(x)
    
class Generator(nn.Module):
    """Transposed-convolution 1-D generator: noise -> length-32 signal in (-1, 1).

    For an input of shape ``(N, channels_noise, 1)`` the layers produce::

        ConvT(s=1) -> BatchNorm1d -> ReLU     (N, features_g*8,  4)
        ConvT(s=2) -> BatchNorm1d -> ReLU     (N, features_g*4,  8)
        ConvT(s=2) -> BatchNorm1d -> ReLU     (N, features_g*2, 16)
        ConvT(s=2) -> Tanh                    (N, channels_img, 32)

    Args:
        channels_noise: number of latent channels in the input noise.
        channels_img: number of channels of the generated signal.
        features_g: base width; hidden layers use 8x, 4x and 2x this many
            channels.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super(Generator, self).__init__()
        # NOTE: the attribute keeps the name ``disc`` even though this is the
        # generator; renaming it would change the state_dict keys
        # ("disc.0.weight", ...) and break loading of existing checkpoints.
        self.disc = nn.Sequential(
            # N x channels_noise x 1
            nn.ConvTranspose1d(channels_noise, features_g*8, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm1d(features_g*8),
            nn.ReLU(),

            # N x features_g*8 x 4
            nn.ConvTranspose1d(features_g*8, features_g*4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(features_g*4),
            nn.ReLU(),

            # N x features_g*4 x 8
            nn.ConvTranspose1d(features_g*4, features_g*2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(features_g*2),
            nn.ReLU(),

            # N x features_g*2 x 16
            nn.ConvTranspose1d(features_g*2, channels_img, kernel_size=4, stride=2, padding=1),
            # N x channels_img x 32
            nn.Tanh(),
        )

    def forward(self, x):
        """Map noise ``x`` of shape ``(N, channels_noise, 1)`` to ``(N, channels_img, 32)``."""
        return self.disc(x)
    
def initialize_weights(model):
    """Re-draw conv / transposed-conv / batch-norm weights in place from N(0, 0.02).

    Initializes weights according to the DCGAN paper: every ``weight`` of an
    ``nn.Conv1d``, ``nn.ConvTranspose1d`` or ``nn.BatchNorm1d`` found anywhere in
    ``model`` (recursively, via ``model.modules()``) is sampled from a normal
    distribution with mean 0.0 and std 0.02. Biases and every other layer type
    (including ``nn.InstanceNorm1d``) are left untouched.

    Args:
        model: the ``nn.Module`` whose submodules are initialized; modified in place.
    """
    targets = (nn.Conv1d, nn.ConvTranspose1d, nn.BatchNorm1d)
    for module in model.modules():
        if not isinstance(module, targets):
            continue
        # BatchNorm1d(affine=False) has no learnable weight (it is None);
        # skip it instead of crashing with AttributeError.
        if module.weight is None:
            continue
        # NOTE(review): BatchNorm scale is drawn with mean 0.0 here, whereas the
        # PyTorch DCGAN tutorial uses mean 1.0 for BatchNorm weights — confirm intent.
        nn.init.normal_(module.weight, 0.0, 0.02)

In [None]:
def gradient_penalty(critic, real, fake, device="cpu"):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolates.

    Each sample gets its own mixing weight ``alpha ~ U[0, 1)``, and the critic is
    evaluated at ``x_hat = alpha * real + (1 - alpha) * fake``. The penalty is
    the batch mean of ``(||d critic(x_hat) / d x_hat||_2 - 1) ** 2``, where the
    norm is taken over all non-batch dimensions of each sample.

    Args:
        critic: callable mapping a batch shaped like ``real`` to scores.
        real: batch of real samples, shape ``(N, ...)`` (e.g. ``(N, C, L)``).
        fake: batch of generated samples, same shape as ``real``. It may be
            detached or still attached to the generator graph.
        device: device on which the mixing weights are drawn; should match the
            device of ``real`` / ``fake``.

    Returns:
        Scalar tensor. It is built with ``create_graph=True``, so it can be added
        to the critic loss and back-propagated.
    """
    batch_size = real.shape[0]

    # One mixing weight per sample, broadcast over all remaining dimensions,
    # so the same code works for (N, C, L) signals and (N, C, H, W) images.
    alpha_shape = (batch_size,) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device)
    interpolated = real * alpha + fake * (1 - alpha)

    # autograd.grad needs a differentiable input. When both real and fake are
    # detached, the interpolate is a leaf that does not track gradients yet.
    if not interpolated.requires_grad:
        interpolated.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated)

    # Take the gradient of the scores with respect to the interpolated inputs
    gradient = torch.autograd.grad(
        inputs=interpolated,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # reshape (not view) so a non-contiguous gradient cannot raise.
    gradient = gradient.reshape(batch_size, -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)