In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNN_Discriminator(nn.Module):
    """Class-conditional 1-D convolutional discriminator.

    Code adapted from
    https://github.com/vlbthambawita/deepfake-ecg/blob/9fa9a02a9fefe579e322d56fa591c3887d7ad135/deepfakeecg/models/pulse2pulse.py#L5

    The label is embedded into a vector of length ``signal_length`` and
    concatenated to the input as one extra channel. Six strided Conv1d layers
    with leaky-ReLU follow, then ``fc1`` maps the flattened features to one
    score per sample (no sigmoid is applied).

    Args:
        num_classes: number of label values accepted by the embedding.
        signal_length: length of the last axis of ``x``; also determines the
            size of the flattened feature vector fed to ``fc1``.
        model_size: base channel width ``d``.
        ngpus: stored only; not used by this class.
        num_channels: channels ``c`` of ``x`` (the label channel is extra).
        shift_factor: stored only; not used by this class.
        alpha: negative slope of the leaky ReLUs.
        verbose: print intermediate shapes during ``forward``.
        dropout: dropout probability applied after every conv stage while the
            module is in training mode; falsy (None/0) disables it.

    Raises:
        ValueError: if ``signal_length`` is too short to survive the six
            strided convolutions.
    """

    def __init__(self, num_classes, signal_length, model_size=64, ngpus=1, num_channels=2, shift_factor=2,
                 alpha=0.2, verbose=False, dropout=None):
        super(CNN_Discriminator, self).__init__()
        self.num_classes = num_classes
        self.signal_length = signal_length
        self.model_size = model_size  # d
        self.ngpus = ngpus
        self.num_channels = num_channels  # c
        self.shift_factor = shift_factor  # n
        self.alpha = alpha
        self.verbose = verbose

        # One vector per class, reshaped in forward into a (N, 1, signal_length) channel.
        self.embed = nn.Embedding(num_classes, 1 * signal_length)
        self.conv1 = nn.Conv1d(num_channels + 1, model_size, 5, stride=2, padding=1)
        self.conv2 = nn.Conv1d(model_size, 2 * model_size, 5, stride=2, padding=1)
        self.conv3 = nn.Conv1d(2 * model_size, 5 * model_size, 5, stride=2, padding=1)
        self.conv4 = nn.Conv1d(5 * model_size, 10 * model_size, 5, stride=2, padding=1)
        self.conv5 = nn.Conv1d(10 * model_size, 20 * model_size, 5, stride=2, padding=1)
        self.conv6 = nn.Conv1d(20 * model_size, 25 * model_size, 5, stride=4, padding=1)

        self.dropout = dropout
        # Flattened size = conv6 channels * conv6 output length. This used to be
        # hard-coded to 9600, which only fits a narrow band of signal lengths
        # (e.g. 768 with model_size=64); for matching configs the value is unchanged.
        self.fc1 = nn.Linear(self._flat_features(), 1)

    def _convs(self):
        """The conv stack in application order."""
        return (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6)

    def _flat_features(self):
        """Number of features after flattening conv6's output for ``signal_length`` inputs."""
        length = self.signal_length
        for conv in self._convs():
            # Conv1d output length with dilation 1: floor((L + 2p - k) / s) + 1.
            length = (length + 2 * conv.padding[0] - conv.kernel_size[0]) // conv.stride[0] + 1
            if length < 1:
                raise ValueError(
                    f"signal_length={self.signal_length} is too short for the discriminator's conv stack")
        return self.conv6.out_channels * length

    def _stage(self, conv, x):
        """conv -> leaky ReLU -> optional shape print -> optional dropout."""
        x = F.leaky_relu(conv(x), negative_slope=self.alpha)
        if self.verbose:
            print(x.shape)
        if self.dropout:
            # F.dropout defaults to training=True; forward the module's mode so
            # that .eval() actually disables dropout.
            x = F.dropout(x, self.dropout, training=self.training)
        return x

    def forward(self, x, labels):
        """Score a batch of signals.

        Args:
            x: tensor concatenated on dim 1 with a (N, 1, signal_length) label
                channel, so presumably (N, num_channels, signal_length).
            labels: integer class indices; the first dimension is the batch size.

        Returns:
            (N, 1) tensor of raw scores.
        """
        embedding = self.embed(labels).view(labels.shape[0], 1, self.signal_length)
        x = torch.cat([x, embedding], dim=1)
        for conv in self._convs():
            x = self._stage(conv, x)

        x = x.view(-1, x.shape[1] * x.shape[2])
        if self.verbose:
            print(x.shape)

        return self.fc1(x)

In [19]:
class DoubleConv(nn.Module):
    '''(conv => BN => ReLU) * 2, length-preserving (kernel 3, padding 1).

    Each convolution now has its own BatchNorm1d. Earlier versions reused one
    ``self.bn`` for both stages, so the two stages shared affine parameters and
    running statistics (updated twice per forward pass). ``self.bn`` keeps its
    name for the first stage. Checkpoints saved before ``bn2`` existed lack the
    ``bn2.*`` keys and need ``load_state_dict(..., strict=False)``.

    Args:
        in_ch: input channels.
        out_ch: output channels of both convolutions.
    '''
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv1 = nn.Conv1d(in_ch, out_ch, 3, padding=1)
        self.bn = nn.BatchNorm1d(out_ch)
        # Stateless, so one ReLU instance can safely serve both stages.
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(out_ch, out_ch, 3, padding=1)
        self.bn2 = nn.BatchNorm1d(out_ch)

    def forward(self, x):
        x = self.relu(self.bn(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        return x

class InConv(nn.Module):
    '''Input stem of the U-Net: a single DoubleConv from in_ch to out_ch.'''
    def __init__(self, in_ch, out_ch):
        super(InConv, self).__init__()
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x):
        return self.conv(x)

class Down(nn.Module):
    '''Encoder stage: MaxPool1d(2) halves the length, then DoubleConv, then optional dropout.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        dropout: dropout probability; falsy (None/0) disables it. Dropout is
            applied only in training mode. F.dropout defaults to training=True,
            so the module's mode is forwarded explicitly; otherwise dropout
            would keep firing after .eval().
    '''
    def __init__(self, in_ch, out_ch, dropout=None):
        super(Down, self).__init__()
        self.dropout = dropout
        self.maxpool = nn.MaxPool1d(2)
        self.dbc = DoubleConv(in_ch, out_ch)

    def forward(self, x):
        x = self.dbc(self.maxpool(x))
        if self.dropout:
            x = F.dropout(x, self.dropout, training=self.training)
        return x


class Up(nn.Module):
    '''Decoder stage: upsample x1 by 2, align it to skip tensor x2, concatenate, DoubleConv.

    Args:
        in_ch: channels after concatenation; x1 is expected to carry in_ch // 2
            channels (the transposed conv's in/out width).
        out_ch: output channels.
    '''
    def __init__(self, in_ch, out_ch):
        super(Up, self).__init__()
        half = in_ch // 2
        self.up = nn.ConvTranspose1d(half, half, 2, stride=2)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # Indexing size(2) implies 3-D (N, C, L) tensors. Pad (or, for a negative
        # gap, crop) the upsampled length so it matches the skip connection.
        gap = x2.size(2) - upsampled.size(2)
        left = gap // 2
        upsampled = F.pad(upsampled, (left, gap - left))
        # for padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        merged = torch.cat([x2, upsampled], dim=1)  # join on the channel axis
        return self.conv(merged)


class OutConv(nn.Module):
    '''Pointwise (kernel size 1) Conv1d projecting in_ch channels to out_ch.'''
    def __init__(self, in_ch, out_ch):
        super(OutConv, self).__init__()
        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet1D(nn.Module):
    """Class-conditional 1-D U-Net generator.

    A learned embedding of the label (length ``embed_size``) is appended to the
    input as one extra channel, so ``embed_size`` must equal the length of the
    last axis of ``x``. The encoder always has four down-sampling stages;
    ``n_layers`` of 5 or 6 adds one or two more (values above 6 behave like 6).
    The output passes through tanh.

    Args:
        in_channels: channels of ``x`` (the label channel is extra).
        out_channels: channels of the output.
        num_classes: number of label values accepted by the embedding.
        embed_size: length of the label embedding.
        n_layers: encoder depth; must be >= 4.
        ngpus: unused.
        starting_layers: channel width of the first stage.
        dropout: dropout probability passed to down1..down4 (down5/down6 get none).

    Raises:
        ValueError: if ``n_layers < 4``. up3 is built for the 4*starting_layers
            channels produced by up4. With fewer layers the old forward pass fed
            the raw input into up3, which in general fails with a channel mismatch.
    """
    def __init__(self, in_channels, out_channels, num_classes, embed_size, n_layers, ngpus=1, starting_layers=32, dropout=None):
        super(UNet1D, self).__init__()
        if n_layers < 4:
            raise ValueError(f"n_layers must be >= 4, got {n_layers}")
        self.dropout = dropout
        self.n_layers = n_layers
        # Submodule creation order is unchanged, so a given seed yields the same weights.
        self.inc = InConv(in_channels + 1, starting_layers)
        self.down1 = Down(starting_layers * 1, starting_layers * 2, dropout)  # Only dropout on early layers
        self.down2 = Down(starting_layers * 2, starting_layers * 4, dropout)
        self.down3 = Down(starting_layers * 4, starting_layers * 8, dropout)
        self.down4 = Down(starting_layers * 8, starting_layers * 8, dropout)
        if self.n_layers >= 5:
            self.down5 = Down(starting_layers * 8, starting_layers * 8)
            if self.n_layers >= 6:
                self.down6 = Down(starting_layers * 8, starting_layers * 8)
                self.up6 = Up(starting_layers * 16, starting_layers * 8)
            self.up5 = Up(starting_layers * 16, starting_layers * 8)
        self.up4 = Up(starting_layers * 16, starting_layers * 4)
        self.up3 = Up(starting_layers * 8, starting_layers * 2)
        self.up2 = Up(starting_layers * 4, starting_layers)
        self.up1 = Up(starting_layers * 2, starting_layers)
        self.out = OutConv(starting_layers, out_channels)
        self.embed = nn.Embedding(num_classes, embed_size)

    def forward(self, x, labels):
        """Generate signals for ``labels`` from input ``x``.

        Args:
            x: 3-D tensor whose last axis has length ``embed_size`` (required by
                the concat with the label channel).
            labels: integer class indices, one per batch element.

        Returns:
            Tensor with ``out_channels`` channels and x's length, values in (-1, 1).
        """
        # (N, embed_size) -> (N, 1, embed_size): the label becomes one extra channel.
        embedding = self.embed(labels).unsqueeze(1)
        x = torch.cat([x, embedding], dim=1)

        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x = x5
        if self.n_layers >= 5:
            x6 = self.down5(x5)
            x = x6
            if self.n_layers >= 6:
                x7 = self.down6(x6)
                x = self.up6(x7, x6)
            x = self.up5(x, x5)
        x = self.up4(x, x4)
        x = self.up3(x, x3)
        x = self.up2(x, x2)
        x = self.up1(x, x1)
        return torch.tanh(self.out(x))


In [21]:
# from torchvision import models
# from torchsummary import summary

# num_classes = 4
# embed_size = 750
# signal_length = 750

# device = "cpu"
# generator = UNet1D(2, 2, 4,750, n_layers=5).to(device)

# labels2 = [2 for i in range(1)]

# batch_size = len(labels2)
# fixed_noise = torch.randn(batch_size,2,750).uniform_(-1, 1).to(device)
# labels2 =  torch.Tensor(labels2).to(device)
# labels2 = labels2.int()
# #labels2 = labels2.to(device)
# #generator.eval()

# with torch.no_grad():
#     generated_signals = generator(fixed_noise,labels2)
#     gener = generated_signals.cpu().detach().numpy()

torch.Size([1, 1, 750])
torch.Size([1, 3, 750])
torch.Size([1, 3, 750])
torch.Size([1, 128, 93])
torch.Size([1, 64, 187])
torch.Size([1, 32, 375])
torch.Size([1, 2, 750])


In [None]:
def initialize_weights(model, method="kaiming"):
    """Re-initialize the weights of every Conv1d, ConvTranspose1d and Linear layer in ``model``.

    The previous comment claimed DCGAN initialization, but the code has always
    used Kaiming-normal. That remains the default, and DCGAN's N(0, 0.02) is
    available via ``method="dcgan"``. Biases and other layer types (e.g.
    Embedding, BatchNorm) are left untouched.

    Args:
        model: module whose ``modules()`` (including itself) are visited.
        method: ``"kaiming"`` -> ``nn.init.kaiming_normal_`` with default
            arguments; ``"dcgan"`` -> ``nn.init.normal_(w, 0.0, 0.02)``.

    Raises:
        ValueError: for an unknown ``method``.
    """
    if method not in ("kaiming", "dcgan"):
        raise ValueError(f"unknown init method {method!r}; expected 'kaiming' or 'dcgan'")
    for m in model.modules():
        if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear)):
            if method == "dcgan":
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            else:
                nn.init.kaiming_normal_(m.weight.data)