In [1]:
import numpy as np
from matplotlib import pyplot as plt

In [2]:
# Load pytorch modules
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import Dataset, TensorDataset, DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.autograd as autograd
import torch.nn.functional as F

In [3]:
# Some other helpful modules
import time
from Bio.Seq import Seq
from Bio.Alphabet import single_letter_alphabet
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO

##### Read In Data

In [None]:
# Read in data: one-hot sequences (x_*) and their component labels (y_*).
# DATA_DIR is the only machine-specific setting in this cell; point it at your copy of the arrays.
DATA_DIR = "/home/pbromley/generative_dhs/data_numpy/"
x_train = np.load(DATA_DIR + "one_hot_seqs_train_100.npy")
y_train = np.load(DATA_DIR + "components_train_100.npy")
x_test = np.load(DATA_DIR + "one_hot_seqs_test_100.npy")
y_test = np.load(DATA_DIR + "components_test_100.npy")

# Raw component labels run 1-15; shift to 0-14 so they can be used directly as class indices.
y_train = y_train - 1
y_test = y_test - 1

 


In [None]:
# Add a singleton channel axis -> (N, 1, seq_len, 4). The sample count comes from the labels, so
# every row stays paired with its label; seq_len is inferred from the data instead of hard-coded.
# NOTE(review): the discriminators below reshape inputs to length 200, while these files are
# named *_100 — confirm which sequence length the arrays actually hold.
x_train = x_train.reshape(len(y_train), 1, -1, 4)
# Per-class sample counts over the (not yet filtered) training labels; used later for
# inverse-frequency sampling weights.
class_dist = np.bincount(y_train.astype(int))

In [None]:
## For trying with fewer classes: keep samples whose label satisfies 2 < label < 8
## (0-based labels 3-7 when labels are integral).
idx = (y_train > 2.0) & (y_train < 8.0)   # vectorised boolean mask, one entry per sample
x_train = x_train[idx]
y_train = y_train[idx]

##### Set Up Custom Dataset

In [None]:
# Setting up the data (custom collate function for variable length inputs w/in batch)
#   https://discuss.pytorch.org/t/how-to-create-a-dataloader-with-variable-size-input/8278/2

# Custom Dataset wrapping in-memory numpy arrays (or lists of arrays),
#   so the sequences don't have to be written to files and loaded with DatasetFolder.
class DHSSequencesDataset(Dataset):
    """DHS sequences of varying length, each paired with a component label."""

    def __init__(self, seqs, components, transform=None):
        """
        Args:
            seqs (list/np.array): One-hot DNA sequences, indexable by sample position.
            components (list/np.array): Integer component label for each sequence
                (this notebook shifts the raw 1-15 labels to 0-14 before use).
            transform (callable, optional): Applied to each sequence in __getitem__.
        """
        self.seqs = seqs
        self.components = components
        self.transform = transform

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        """Return ``(sequence, component)`` for sample ``idx``, with the sequence transformed if set."""
        one_hot = self.seqs[idx]
        component = self.components[idx]

        # The transformed sequence must replace the raw one; previously it was stored in an
        # unused local and the untransformed sequence was returned.
        if self.transform is not None:
            one_hot = self.transform(one_hot)

        return one_hot, component
    
    
# For loading batches with variable length inputs
def collate_variable_length(batch):
    """Collate ``(sequence, label)`` samples without stacking the sequences.

    Sequences may differ in length, so they are returned as a plain list;
    the labels are packed into a ``torch.LongTensor``.

    Returns:
        ``[list_of_sequences, LongTensor_of_labels]``
    """
    sequences, labels = [], []
    for sample in batch:
        sequences.append(sample[0])
        labels.append(sample[1])
    return [sequences, torch.LongTensor(labels)]



dhs_dataset = DHSSequencesDataset(x_train, y_train)


### Sample balanced class distribution
# Inverse-frequency weight per class (class_dist is the np.bincount of the training labels).
class_weights = np.sum(class_dist) / class_dist
# One weight per sample, looked up in a single vectorised index instead of a Python loop over
# every (sequence, label) pair in the dataset.
weights = class_weights[np.asarray(dhs_dataset.components).astype(int)]
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))  # create sampler

In [None]:
###### For Uniform Length ######
# Batches of 128 drawn through the class-balancing `sampler` from the previous cell.
dhs_dataloader = DataLoader(dhs_dataset, batch_size=128, sampler=sampler)

###### For Variable Length ######
# dhs_dataloader = DataLoader(dataset=dhs_dataset,
#                             batch_size=4,
#                             shuffle=True,
#                             collate_fn=collate_variable_length)

##### Models

In [None]:
class cdcgan_generator(nn.Module):
    """Conditional DCGAN generator emitting (channels, 200, 4) softmax "sequences".

    A class embedding (``num_classes // 2`` wide) is concatenated to the noise,
    projected by ``fc`` to a 256 x 12 x 1 seed, and upsampled along the length
    axis 12 -> 25 -> 50 -> 100 -> 200 by transposed convolutions. The last one
    also widens to 4 columns, which a softmax over dim 3 normalises.
    """

    def __init__(self, channels, nz, num_classes):
        super(cdcgan_generator, self).__init__()
        embed_dim = num_classes // 2
        self.embed = nn.Embedding(num_classes, embed_dim)
        # In this variant there is no activation between fc and the first deconvolution.
        self.fc = nn.Linear(nz + embed_dim, 256 * 12 * 1)

        layers = []
        # (in channels, out channels, kernel height, padding height) for the stride-2 up-blocks
        for c_in, c_out, k_h, p_h in ((256, 128, 7, 2), (128, 64, 8, 3), (64, 32, 8, 3)):
            layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=(k_h, 1), stride=2,
                                             padding=(p_h, 0), bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.ConvTranspose2d(32, channels, kernel_size=(16, 4), stride=2,
                                         padding=(7, 0), bias=False))
        layers.append(nn.Softmax(dim=3))
        self.net = nn.Sequential(*layers)

    def forward(self, noise, label):
        """Generate one sample per (noise row, integer label) pair."""
        conditioned = torch.cat((noise, self.embed(label)), dim=1)
        seed = self.fc(conditioned).view(-1, 256, 12, 1)
        return self.net(seed)
    
    
class cdcgan_discriminator(nn.Module):
    """Conditional DCGAN discriminator for length-200, 4-column sequences.

    The class label is embedded to 10 values, upsampled by ``label_upsample``
    to a (1, 200, 4) map, and stacked onto the sequence as one extra channel.
    ``net`` reduces the (channels + 1, 200, 4) input to one sigmoid score per sample.
    """

    def __init__(self, channels, num_classes):
        super(cdcgan_discriminator, self).__init__()

        self.channels = channels

        self.embed = nn.Embedding(num_classes, 10)

        # (1, 10, 1) -> (8, 100, 1) -> (1, 200, 4)
        self.label_upsample = nn.Sequential(
            nn.ConvTranspose2d(1, 8, (20, 1), 10, (5, 0), bias=False),
            nn.BatchNorm2d(8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(8, 1, (8, 4), 2, (3, 0), bias=False),
            nn.BatchNorm2d(1),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Length 200 -> 200 (width 4 -> 1) -> 100 -> 50 -> 25 -> 1
        self.net = nn.Sequential(
            nn.Conv2d(channels+1, 32, kernel_size=(15, 4), stride=1, padding=(7, 0), bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(p=0.5),
            nn.Conv2d(32, 64, kernel_size=(8, 1), stride=2, padding=(3, 0), bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(p=0.5),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=(8, 1), stride=2, padding=(3, 0), bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(p=0.5),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 256, kernel_size=(8, 1), stride=2, padding=(3, 0), bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(p=0.5),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 1, kernel_size=(25, 1), stride=1, padding=(0, 0), bias=False),
            nn.Sigmoid()
        )

    def forward(self, seq, label):
        """Score ``seq`` conditioned on ``label``.

        Returns:
            A 1-D tensor of shape (batch,) with values in (0, 1).
        """
        embed = self.embed(label).view(-1, 1, 10, 1)
        label_matrix = self.label_upsample(embed)
        # Use the configured channel count so the concat matches net's first conv (channels + 1).
        concat = torch.cat([seq.view(-1, self.channels, 200, 4), label_matrix], 1)
        # view(-1) rather than squeeze(): squeeze() also drops the batch dimension when the
        # batch has one sample, giving a 0-d output that no longer matches a (1,) target.
        return self.net(concat).view(-1)
    
    
    

class DHS_cDCGAN():
    """Trainer for the conditional DCGAN (cdcgan_generator / cdcgan_discriminator).

    Args:
        channels: channel count of the sequence "images".
        bs: nominal batch size; used to pre-size the noise/label buffers.
        nz: length of the generator noise vector.
        nc: number of component classes.
        dataloader: iterable of (sequence batch, label batch).
        use_cuda: run on the GPU when one is available; otherwise falls back to CPU.
        init_weights: apply ``weights_init`` to G and D after construction.
    """
    def __init__(self, channels, bs, nz, nc, dataloader, use_cuda=True, init_weights=False):

        self.channels = channels
        self.bs = bs
        self.nz = nz
        self.nc = nc
        self.dataloader = dataloader
        # Only select CUDA when it exists; otherwise the .to(self.device) calls below would fail.
        self.use_cuda = use_cuda and torch.cuda.is_available()

        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        self.G = cdcgan_generator(self.channels, self.nz, self.nc).to(self.device)
        self.D = cdcgan_discriminator(self.channels, self.nc).to(self.device)

        self.criterion = nn.BCELoss().to(self.device)
        self.optD = optim.Adam(self.D.parameters(), lr=0.0002, betas=(0.5, 0.99))
        self.optG = optim.Adam(self.G.parameters(), lr=0.001, betas=(0.5, 0.99))

        if init_weights:
            for net in (self.G, self.D):
                net.apply(self.weights_init)

        self.train_hist = {'d_loss': [], 'g_loss': [], 'epoch_time': [], 'total_time': []}

    def weights_init(self, m):
        """DCGAN-style init: conv weights ~ N(0, 0.02); BatchNorm weight ~ N(1, 0.02), bias 0."""
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.detach().normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.detach().normal_(1.0, 0.02)
            m.bias.detach().fill_(0)

    def create_fixed_inputs(self, num_total_imgs):
        """Build a fixed (noise, label) set for the per-epoch progress images.

        Each class is repeated ``num_total_imgs // self.nc`` times (at least once),
        and exactly one noise row is drawn per label, so the two always have the
        same number of rows (G concatenates them along dim 1).
        """
        per_class = max(1, num_total_imgs // self.nc)
        np_fixed_c = np.repeat(np.arange(0, self.nc, 1), per_class, axis=0)
        fixed_noise = torch.empty(len(np_fixed_c), self.nz).normal_(0, 1).to(self.device)
        fixed_c = torch.from_numpy(np_fixed_c).long().to(self.device)
        return fixed_noise, fixed_c

    def update_train_hist(self, d_loss, g_loss):
        """Append the scalar values of the current D and G losses to the history."""
        self.train_hist['d_loss'].append(d_loss.item())
        self.train_hist['g_loss'].append(g_loss.item())

    def plot_loss(self, path):
        """Plot the recorded per-iteration D/G losses and save the figure to ``path``."""
        x = range(len(self.train_hist['d_loss']))
        plt.figure(figsize=(8, 8))
        plt.plot(x, self.train_hist['d_loss'], label='d_loss')
        plt.plot(x, self.train_hist['g_loss'], label='g_loss')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        plt.legend(loc=4)
        plt.savefig(path)
        plt.close()

    def save_imgs(self, f_noise, f_c, epoch):
        """Render G's outputs for the fixed inputs and save them as ``<epoch>.png``."""
        with torch.no_grad():
            images = self.G(f_noise, f_c)
        to_save = images.transpose(2, 3)
        img_path = '/home/pbromley/generative_dhs/images/pytorch-cdcgan-dhs-200/%d.png' % epoch
        vutils.save_image(to_save, img_path, nrow=1, normalize=True)

    def train(self, epochs=100, save_model=False, plot_loss=False):
        """Run the conditional-GAN loop: one D step and one G step per batch.

        Args:
            epochs: passes over ``self.dataloader``.
            save_model: save G and D state dicts when training finishes.
            plot_loss: record per-iteration losses and save a loss plot at the end.
        """
        noise = torch.zeros(self.bs, self.nz).to(self.device)
        label = torch.zeros(self.bs).to(self.device)

        fixed_noise, fixed_c = self.create_fixed_inputs(45)

        start_time = time.time()
        for epoch in range(epochs):
            # A short final batch in the previous epoch may have shrunk the buffers.
            noise.resize_(self.bs, self.nz)
            label.resize_(self.bs)
            epoch_start_time = time.time()
            for i, (x, c) in enumerate(self.dataloader):

                # DISCRIMINATOR
                self.optD.zero_grad()
                x, c = x.float().to(self.device), c.long().to(self.device)

                if x.size(0) != self.bs:
                    noise.resize_(x.size(0), self.nz)
                    label.resize_(x.size(0))

                pred_real = self.D(x, c)
                label.fill_(1)
                d_loss_real = self.criterion(pred_real, label)
                d_loss_real.backward()

                noise.normal_(0, 1)
                fake = self.G(noise, c)
                # detach(): the D update must not backpropagate into G
                pred_fake = self.D(fake.detach(), c)
                label.fill_(0)
                d_loss_fake = self.criterion(pred_fake, label)
                d_loss_fake.backward()

                d_loss_total = d_loss_real + d_loss_fake

                self.optD.step()

                # GENERATOR: train G so that D scores its fakes as real
                self.optG.zero_grad()
                pred_g = self.D(fake, c)
                label.fill_(1)
                g_loss = self.criterion(pred_g, label)

                g_loss.backward()

                self.optG.step()

                if plot_loss:
                    self.update_train_hist(d_loss_total, g_loss)

                if i % 1000 == 0:
                    print('Epoch/Iter:{0}/{1}, Dloss: {2}, Gloss: {3}'.format(
                            epoch, i, d_loss_total.item(), g_loss.item())
                         )

            self.save_imgs(fixed_noise, fixed_c, epoch)
            self.train_hist['epoch_time'].append(time.time() - epoch_start_time)
            print("Time to complete epoch: " + str(self.train_hist['epoch_time'][-1]))

        print("Training is complete!")
        if save_model:
            path = "/home/pbromley/generative_dhs/saved_models/pytorch-cdcgan-dhs-200"
            print("Saving Model Weights...")
            torch.save(self.G.state_dict(), path + "-g.pth")
            torch.save(self.D.state_dict(), path + "-d.pth")
            print("Saved Model Weights")

        self.train_hist['total_time'].append(time.time() - start_time)
        # [-1]: report this run's time, not the first run's, when train() is called again
        print("Total training time (%d epochs): " % epochs + str(self.train_hist['total_time'][-1]))
        avg_epoch_time = np.mean(self.train_hist['epoch_time'])
        print("Average time for each epoch: %.2f" % avg_epoch_time)

        if plot_loss:
            self.plot_loss("/home/pbromley/generative_dhs/loss_plots/pytorch-cdcgan-dhs-200.png")
                
                
        

In [None]:
class twoNgan_generator(nn.Module):
    """Conditional generator for the 2N-class GAN, emitting (channels, 200, 4) softmax outputs.

    Noise plus a ``num_classes // 2``-wide class embedding is projected by ``fc``,
    passed through ReLU, reshaped to a 64 x 12 x 1 seed, and upsampled along the
    length axis 12 -> 25 -> 50 -> 100 -> 200. The final layer widens to 4 columns,
    which a softmax over dim 3 normalises.
    """

    def __init__(self, channels, nz, num_classes):
        super(twoNgan_generator, self).__init__()
        embed_dim = num_classes // 2
        self.embed = nn.Embedding(num_classes, embed_dim)
        self.fc = nn.Linear(nz + embed_dim, 64 * 12 * 1)
        self.relu = nn.ReLU(inplace=True)

        layers = []
        # (in channels, out channels, kernel height, padding height) for the stride-2 up-blocks
        for c_in, c_out, k_h, p_h in ((64, 32, 7, 2), (32, 16, 8, 3), (16, 8, 8, 3)):
            layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=(k_h, 1), stride=2,
                                             padding=(p_h, 0), bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.ConvTranspose2d(8, channels, kernel_size=(16, 4), stride=2,
                                         padding=(7, 0), bias=False))
        layers.append(nn.Softmax(dim=3))
        self.net = nn.Sequential(*layers)

    def forward(self, noise, label):
        """Generate one sample per (noise row, integer label) pair."""
        conditioned = torch.cat((noise, self.embed(label)), dim=1)
        seed = self.relu(self.fc(conditioned)).view(-1, 64, 12, 1)
        return self.net(seed)
    
    
class twoNgan_discriminator(nn.Module):
    """2N-way discriminator: classes 0..N-1 for real samples, N..2N-1 for fakes.

    The input is first warped by a small spatial transformer (``stn``), then a
    wide convolution + max-pool produces 320 x 5 features for ``dense_net``.

    Returns raw logits of shape (batch, 2 * num_classes). They are meant for
    ``nn.CrossEntropyLoss``, which applies log-softmax internally; apply
    ``softmax(dim=1)`` yourself if probabilities are needed.
    """

    def __init__(self, channels, num_classes):
        super(twoNgan_discriminator, self).__init__()

        # (1, 200, 4) -> conv (320, 200, 1) -> max-pool over 40 positions -> (320, 5, 1)
        self.net = nn.Sequential(
            nn.Conv2d(1, 320, kernel_size=(25, 4), stride=1, padding=(12, 0), bias=False),
            nn.MaxPool2d(kernel_size=(40, 1)),
            nn.BatchNorm2d(320),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Dropout(0.5),
        )

        # No final Softmax: CrossEntropyLoss expects unnormalised logits, and a softmax here
        # would be applied twice.
        self.dense_net = nn.Sequential(
            nn.Linear(320*5, 256),
            nn.LeakyReLU(0.1, inplace=True),
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes*2),
        )

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=(25, 4), padding=(12, 0)),
            nn.MaxPool2d(kernel_size=(40, 1)),
            nn.ReLU(True),
        )

        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(16 * 5, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def stn(self, x):
        """Apply the learned affine warp (spatial transformer) to ``x``."""
        xs = self.localization(x)
        xs = xs.view(-1, 16 * 5)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)

        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)

        return x

    def forward(self, seq):
        """Return (batch, 2 * num_classes) logits for ``seq``."""
        seq = self.stn(seq)
        net = self.net(seq)
        output = self.dense_net(net.view(-1, 320 * 5))
        return output
    
    
    

class DHS_TwoNGAN():
    """Trainer for the 2N-class GAN (twoNgan_generator / twoNgan_discriminator).

    D classifies into 2 * nc classes: a real sample of class c is target c, and a
    fake generated for class k is target k + nc. G is trained so that D assigns
    its fakes to the real class k.

    Args:
        channels: channel count of the sequence "images".
        bs: nominal batch size; used to pre-size the noise buffer.
        nz: length of the generator noise vector.
        nc: number of component classes.
        dataloader: iterable of (sequence batch, label batch).
        use_cuda: run on the GPU when one is available; otherwise falls back to CPU.
        init_weights: apply ``weights_init`` to G and D after construction.
    """
    def __init__(self, channels, bs, nz, nc, dataloader, use_cuda=True, init_weights=True):

        self.channels = channels
        self.bs = bs
        self.nz = nz
        self.nc = nc
        self.dataloader = dataloader
        # Only select CUDA when it exists; otherwise the .to(self.device) calls below would fail.
        self.use_cuda = use_cuda and torch.cuda.is_available()

        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        self.G = twoNgan_generator(self.channels, self.nz, self.nc).to(self.device)
        self.D = twoNgan_discriminator(self.channels, self.nc).to(self.device)

        # CrossEntropyLoss applies log-softmax itself, so it expects raw logits from D.
        self.criterion = nn.CrossEntropyLoss().to(self.device)
        self.optD = optim.Adam(self.D.parameters(), lr=0.0002, betas=(0.5, 0.99))
        self.optG = optim.Adam(self.G.parameters(), lr=0.0002, betas=(0.5, 0.99))

        if init_weights:
            for net in (self.G, self.D):
                net.apply(self.weights_init)

        self.train_hist = {'d_loss': [], 'g_loss': [], 'epoch_time': [], 'total_time': []}

    def weights_init(self, m):
        """DCGAN-style init: conv weights ~ N(0, 0.02); BatchNorm weight ~ N(1, 0.02), bias 0."""
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.detach().normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.detach().normal_(1.0, 0.02)
            m.bias.detach().fill_(0)

    def create_fixed_inputs(self, num_total_imgs):
        """Build a fixed (noise, label) set for the per-epoch progress images.

        Each class is repeated ``num_total_imgs // self.nc`` times (at least once),
        and exactly one noise row is drawn per label, so the two always have the
        same number of rows (G concatenates them along dim 1).
        """
        per_class = max(1, num_total_imgs // self.nc)
        np_fixed_c = np.repeat(np.arange(0, self.nc, 1), per_class, axis=0)
        fixed_noise = torch.empty(len(np_fixed_c), self.nz).normal_(0, 1).to(self.device)
        fixed_c = torch.from_numpy(np_fixed_c).long().to(self.device)
        return fixed_noise, fixed_c

    def update_train_hist(self, d_loss, g_loss):
        """Append the scalar values of the current D and G losses to the history."""
        self.train_hist['d_loss'].append(d_loss.item())
        self.train_hist['g_loss'].append(g_loss.item())

    def plot_loss(self, path):
        """Plot the recorded per-iteration D/G losses and save the figure to ``path``."""
        x = range(len(self.train_hist['d_loss']))
        plt.figure(figsize=(8, 8))
        plt.plot(x, self.train_hist['d_loss'], label='d_loss')
        plt.plot(x, self.train_hist['g_loss'], label='g_loss')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        plt.legend(loc=4)
        plt.savefig(path)
        plt.close()

    def save_imgs(self, f_noise, f_c, epoch):
        """Render G's outputs for the fixed inputs and save them as ``2-<epoch>.png``."""
        with torch.no_grad():
            images = self.G(f_noise, f_c)
        to_save = images.transpose(2, 3)
        img_path = '/home/pbromley/generative_dhs/images/pytorch-twongan-dhs-200/2-%d.png' % epoch
        vutils.save_image(to_save, img_path, nrow=1, normalize=True)

    def train(self, epochs=100, save_model=False, plot_loss=False):
        """Run the 2N-class GAN loop: one D step and one G step per batch.

        Args:
            epochs: passes over ``self.dataloader``.
            save_model: save G and D state dicts when training finishes.
            plot_loss: record per-iteration losses and save a loss plot at the end.
        """
        noise = torch.zeros(self.bs, self.nz).to(self.device)

        fixed_noise, fixed_c = self.create_fixed_inputs(45)

        start_time = time.time()
        for epoch in range(epochs):
            # A short final batch in the previous epoch may have shrunk the buffer.
            noise.resize_(self.bs, self.nz)
            epoch_start_time = time.time()
            for i, (x, c) in enumerate(self.dataloader):

                # DISCRIMINATOR
                self.optD.zero_grad()
                x, c = x.float().to(self.device), c.long().to(self.device)

                if x.size(0) != self.bs:
                    noise.resize_(x.size(0), self.nz)

                # Real samples: target is their own class c in [0, nc).
                pred_real = self.D(x)
                d_loss_real = self.criterion(pred_real, c)
                d_loss_real.backward()

                noise.normal_(0, 1)
                fake_c = torch.randint_like(c, 0, self.nc)               # fake random classes
                fake_c_for_d = fake_c + self.nc                          # fakes map to [nc, 2*nc) for D
                fake = self.G(noise, fake_c)
                # detach(): the D update must not backpropagate into G
                pred_fake = self.D(fake.detach())
                d_loss_fake = self.criterion(pred_fake, fake_c_for_d)
                d_loss_fake.backward()

                d_loss_total = d_loss_real + d_loss_fake

                self.optD.step()

                # GENERATOR: train G so that D assigns each fake to the real class it was generated for
                self.optG.zero_grad()
                pred_g = self.D(fake)
                g_loss = self.criterion(pred_g, fake_c)

                g_loss.backward()

                self.optG.step()

                if plot_loss:
                    self.update_train_hist(d_loss_total, g_loss)

                if i % 2000 == 0:
                    print('Epoch/Iter:{0}/{1}, Dloss: {2}, Gloss: {3}'.format(
                            epoch, i, d_loss_total.item(), g_loss.item())
                         )

            self.save_imgs(fixed_noise, fixed_c, epoch)
            self.train_hist['epoch_time'].append(time.time() - epoch_start_time)
            print("Time to complete epoch: " + str(self.train_hist['epoch_time'][-1]))

        print("Training is complete!")
        if save_model:
            path = "/home/pbromley/generative_dhs/saved_models/pytorch-twongan-dhs-200"
            print("Saving Model Weights...")
            torch.save(self.G.state_dict(), path + "-g-2.pth")
            torch.save(self.D.state_dict(), path + "-d-2.pth")
            print("Saved Model Weights")

        self.train_hist['total_time'].append(time.time() - start_time)
        # [-1]: report this run's time, not the first run's, when train() is called again
        print("Total training time (%d epochs): " % epochs + str(self.train_hist['total_time'][-1]))
        avg_epoch_time = np.mean(self.train_hist['epoch_time'])
        print("Average time for each epoch: %.2f" % avg_epoch_time)

        if plot_loss:
            self.plot_loss("/home/pbromley/generative_dhs/loss_plots/pytorch-twongan-dhs-200-2.png")
                
                
        

In [None]:
class acgan_generator(nn.Module):
    """AC-GAN generator emitting (channels, 200, 4) softmax "sequences".

    Noise plus a ``num_classes // 2``-wide class embedding is projected by ``fc``,
    passed through ReLU, reshaped to a 256 x 12 x 1 seed, and upsampled along the
    length axis 12 -> 25 -> 50 -> 100 -> 200. The final layer widens to 4 columns,
    which a softmax over dim 3 normalises.
    """

    def __init__(self, channels, nz, num_classes):
        super(acgan_generator, self).__init__()
        embed_dim = num_classes // 2
        self.embed = nn.Embedding(num_classes, embed_dim)
        self.fc = nn.Linear(nz + embed_dim, 256 * 12 * 1)
        self.relu = nn.ReLU(inplace=True)

        layers = []
        # (in channels, out channels, kernel height, padding height) for the stride-2 up-blocks
        for c_in, c_out, k_h, p_h in ((256, 128, 7, 2), (128, 64, 8, 3), (64, 32, 8, 3)):
            layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=(k_h, 1), stride=2,
                                             padding=(p_h, 0), bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.ConvTranspose2d(32, channels, kernel_size=(16, 4), stride=2,
                                         padding=(7, 0), bias=False))
        layers.append(nn.Softmax(dim=3))
        self.net = nn.Sequential(*layers)

    def forward(self, noise, label):
        """Generate one sample per (noise row, integer label) pair."""
        conditioned = torch.cat((noise, self.embed(label)), dim=1)
        seed = self.relu(self.fc(conditioned)).view(-1, 256, 12, 1)
        return self.net(seed)
    
    
class acgan_discriminator(nn.Module):
    """AC-GAN discriminator with a real/fake head and an auxiliary class head.

    Returns ``(pred, aux)``:
        pred: (batch, 1) raw score from a bare Linear layer (no sigmoid).
        aux:  (batch, num_classes) raw class logits, meant for ``nn.CrossEntropyLoss``,
              which applies log-softmax itself; apply ``softmax(dim=1)`` for probabilities.
    """

    def __init__(self, channels, num_classes):
        super(acgan_discriminator, self).__init__()

        # (1, 200, 4) -> conv (320, 200, 1) -> max-pool over 40 positions -> (320, 5, 1)
        self.net = nn.Sequential(
            nn.Conv2d(1, 320, kernel_size=(25, 4), stride=1, padding=(12, 0), bias=False),
            nn.MaxPool2d(kernel_size=(40, 1)),
            nn.LayerNorm([320, 5, 1]),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )

        self.dense_net = nn.Sequential(
            nn.Linear(320*5, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
        )

        # NOTE(review): no sigmoid here; if this score is fed to nn.BCELoss (as DHS_ACGAN's
        # criterion_pred suggests), confirm the training loop applies torch.sigmoid first,
        # or switch to nn.BCEWithLogitsLoss.
        self.pred = nn.Sequential(
            nn.Linear(256, 1)
        )

        # No Softmax: CrossEntropyLoss expects unnormalised logits, and a softmax here would
        # be applied twice.
        self.aux = nn.Sequential(
            nn.Linear(256, num_classes),
        )

    def forward(self, seq):
        """Return ``(pred, aux)`` for a batch of (1, 200, 4) sequences."""
        net = self.net(seq)
        dense = self.dense_net(net.view(-1, 320 * 5))
        pred = self.pred(dense)
        aux = self.aux(dense)
        return pred, aux
    
    
    

class DHS_ACGAN():
    """AC-GAN trainer for one-hot DHS sequences using a WGAN-GP critic objective.

    Owns an ``acgan_generator`` / ``acgan_discriminator`` pair, their Adam optimizers
    and a small loss/time history.

    Args:
        channels: passed through to both networks.
        bs: nominal batch size (only used to allocate the noise buffer).
        nz: noise dimension fed to the generator.
        nc: number of component classes.
        dataloader: iterable yielding ``(x, c)`` batches.
        use_cuda: run on CUDA when it is available (falls back to CPU otherwise).
        init_weights: apply ``weights_init`` to both networks.
    """

    def __init__(self, channels, bs, nz, nc, dataloader, use_cuda=True, init_weights=True):

        self.channels = channels
        self.bs = bs
        self.nz = nz
        self.nc = nc
        self.dataloader = dataloader
        self.use_cuda = use_cuda

        # Fall back to CPU rather than crashing when CUDA is requested but absent.
        self.device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")

        self.G = acgan_generator(self.channels, self.nz, self.nc).to(self.device)
        self.D = acgan_discriminator(self.channels, self.nc).to(self.device)

        # criterion_pred is not used by train() (which uses the Wasserstein objective);
        # it is kept so existing code that reads it keeps working.
        self.criterion_pred = nn.BCELoss().to(self.device)
        self.criterion_aux = nn.CrossEntropyLoss().to(self.device)
        self.optD = optim.Adam(self.D.parameters(), lr=0.0002, betas=(0.5, 0.99))
        self.optG = optim.Adam(self.G.parameters(), lr=0.0002, betas=(0.5, 0.99))

        if init_weights:
            for net in (self.G, self.D):
                net.apply(self.weights_init)

        self.train_hist = {'d_loss': [], 'g_loss': [], 'epoch_time': [], 'total_time': []}

    def weights_init(self, m):
        """DCGAN-style init: conv weights ~ N(0, 0.02); BatchNorm scale ~ N(1, 0.02), shift 0."""
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.detach().normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.detach().normal_(1.0, 0.02)
            m.bias.detach().fill_(0)

    def create_fixed_inputs(self, num_total_imgs):
        """Fixed noise and labels for the per-epoch sample grid.

        Labels are ``num_total_imgs // nc`` copies of each class 0..nc-1, in order.
        The noise has exactly one row per label, so the two always agree even when
        ``num_total_imgs`` is not a multiple of ``nc``.
        """
        per_class = num_total_imgs // self.nc
        fixed_c = torch.arange(self.nc, device=self.device).repeat_interleave(per_class)
        fixed_noise = torch.randn(fixed_c.size(0), self.nz, device=self.device)
        return fixed_noise, fixed_c

    def update_train_hist(self, d_loss, g_loss):
        """Append scalar D and G losses (0-dim tensors) to the history."""
        self.train_hist['d_loss'].append(d_loss.item())
        self.train_hist['g_loss'].append(g_loss.item())

    def plot_loss(self, path):
        """Plot the recorded D/G loss curves and save the figure to ``path``."""
        x = range(len(self.train_hist['d_loss']))
        d_loss_hist = self.train_hist['d_loss']
        g_loss_hist = self.train_hist['g_loss']
        plt.figure(figsize=(8, 8))
        plt.plot(x, d_loss_hist, label='d_loss')
        plt.plot(x, g_loss_hist, label='g_loss')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        plt.legend(loc=4)
        plt.savefig(path)
        plt.close()

    def save_imgs(self, f_noise, f_c, epoch):
        """Generate from the fixed inputs and save them as a PNG grid for ``epoch``."""
        with torch.no_grad():
            images = self.G(f_noise, f_c)
        to_save = images.transpose(2, 3)
        img_path = '/home/pbromley/generative_dhs/images/pytorch-acgan-dhs-200/%d.png' % epoch
        vutils.save_image(to_save, img_path, nrow=1, normalize=True)

    # From https://github.com/jalola/improved-wgan-pytorch/blob/master/congan_train.py
    def calc_gradient_penalty(self, x, fake_x, gp_weight=10):
        """WGAN-GP penalty: gp_weight * mean over samples of (||dD/dx_hat||_2 - 1)^2.

        ``x_hat`` is a per-sample random interpolation between ``x`` and ``fake_x``.
        The per-sample shape is taken from ``x``; ``fake_x`` is reshaped to match.
        """
        bs = x.size(0)
        # One mixing coefficient per sample, broadcast over all remaining dims.
        alpha = torch.rand(bs, *([1] * (x.dim() - 1)), device=self.device)
        fake_x = fake_x.reshape(x.shape)

        interpolates = (alpha * x.detach().to(self.device)
                        + (1 - alpha) * fake_x.detach().to(self.device))
        interpolates.requires_grad_(True)

        disc_interpolates, _ = self.D(interpolates)

        gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                  grad_outputs=torch.ones_like(disc_interpolates),
                                  create_graph=True, retain_graph=True)[0]

        gradients = gradients.reshape(bs, -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * gp_weight

    def train(self, epochs=100, save_model=False, plot_loss=False, d_steps=5):
        """Train with WGAN-GP plus an auxiliary classification loss.

        Each loader batch is split into ``d_steps`` equal slices. The generator
        takes one step per loader batch, and the critic takes one step per slice.

        Args:
            epochs: number of passes over ``self.dataloader``.
            save_model: save G/D state dicts when training finishes.
            plot_loss: record per-batch losses and save a loss plot at the end.
            d_steps: critic steps per generator step (default 5, as before).
        """
        noise = torch.zeros(self.bs, self.nz, device=self.device)

        fixed_noise, fixed_c = self.create_fixed_inputs(45)

        start_time = time.time()
        for epoch in range(epochs):
            epoch_start_time = time.time()
            for i, (x, c) in enumerate(self.dataloader):
                x, c = x.float().to(self.device), c.long().to(self.device)
                bs = x.size(0) // d_steps
                if bs == 0:
                    # Too few samples to give every critic step a non-empty slice.
                    continue
                noise.resize_(bs, self.nz)

                # GENERATOR
                self.optG.zero_grad()
                noise.normal_(0, 1)
                fake_c = torch.randint_like(c[0:bs], 0, self.nc)               # fake random classes
                fake = self.G(noise, fake_c)

                pred_g, pred_g_aux = self.D(fake)
                g_loss_pred = -pred_g.mean()
                g_loss_aux = self.criterion_aux(pred_g_aux, fake_c)
                g_loss = g_loss_pred + g_loss_aux
                g_loss.backward()
                self.optG.step()

                # DISCRIMINATOR (critic)
                for j in range(d_steps):
                    # Reset every step. The generator backward above also wrote into
                    # D's .grad, and earlier critic steps must not leak into this one.
                    self.optD.zero_grad()
                    batch = x[j * bs:(j + 1) * bs]
                    batch_c = c[j * bs:(j + 1) * bs]
                    pred_real, pred_real_aux = self.D(batch)

                    d_loss_real_pred = pred_real.mean()
                    d_loss_real_aux = self.criterion_aux(pred_real_aux, batch_c)

                    noise.normal_(0, 1)
                    fake_c_d = torch.randint_like(fake_c, 0, self.nc)
                    # The critic step never updates G, so don't build G's graph.
                    with torch.no_grad():
                        fake_d = self.G(noise, fake_c_d)

                    pred_fake, _ = self.D(fake_d)
                    d_loss_fake_pred = pred_fake.mean()
                    gradient_penalty = self.calc_gradient_penalty(batch, fake_d)

                    w_dist = d_loss_fake_pred - d_loss_real_pred
                    d_loss_total = w_dist + gradient_penalty + d_loss_real_aux
                    d_loss_total.backward()
                    self.optD.step()

                if plot_loss:
                    self.update_train_hist(d_loss_total, g_loss)

                if i % 200 == 0:
                    print('Epoch/Iter:{0}/{1}, Dloss: {2}, Gloss: {3}, WDist: {4}'.format(
                            epoch, i, d_loss_total.item(), g_loss.item(), w_dist.item())
                         )

            self.save_imgs(fixed_noise, fixed_c, epoch)
            self.train_hist['epoch_time'].append(time.time() - epoch_start_time)
            print("Time to complete epoch: " + str(self.train_hist['epoch_time'][-1]))

        print("Training is complete!")
        if save_model:
            path = "/home/pbromley/generative_dhs/saved_models/pytorch-acgan-dhs-200"
            print("Saving Model Weights...")
            torch.save(self.G.state_dict(), path + "-g.pth")
            torch.save(self.D.state_dict(), path + "-d.pth")
            print("Saved Model Weights")

        self.train_hist['total_time'].append(time.time() - start_time)
        # [-1]: report this run, not the first run when train() is called repeatedly.
        print("Total training time (%d epochs): " % epochs + str(self.train_hist['total_time'][-1]))
        avg_epoch_time = np.mean(self.train_hist['epoch_time'])
        print("Average time for each epoch: %.2f" % avg_epoch_time)

        if plot_loss:
            self.plot_loss("/home/pbromley/generative_dhs/loss_plots/pytorch-acgan-dhs-200.png")
                
                
        

##### Train

In [None]:
# Conditional DCGAN: (channels, batch size, noise dim, classes, loader), default init off.
dhs_cdcgan = DHS_cDCGAN(1, 128, 100, 15, dhs_dataloader, init_weights=False)

In [None]:
dhs_cdcgan.train(epochs=50, plot_loss=True, save_model=True)

In [None]:
dhs_twongan = DHS_TwoNGAN(
    1, 128, 70, 15, dhs_dataloader,
    init_weights=True,
)

In [None]:
dhs_twongan.train(epochs=50, plot_loss=True, save_model=True)

In [None]:
dhs_acgan = DHS_ACGAN(channels=1, bs=128, nz=500, nc=15,
                      dataloader=dhs_dataloader, init_weights=True)

In [None]:
dhs_acgan.train(epochs=50, plot_loss=True, save_model=True)

In [None]:
dhs_twongan.D.to("cpu").double()
# D = twoNgan_discriminator(1, 15)
# D.load_state_dict(torch.load("/home/pbromley/generative_dhs/saved_models/pytorch-twongan-dhs-200-d.pth"))

In [None]:
# Count how many of training samples 10000..14999 the TwoN-GAN discriminator labels correctly.
# NOTE(review): the load cell reshapes x_train to length 100, but this cell assumes 200 — confirm.
eval_slice = slice(10000, 15000)
was_training = dhs_twongan.D.training
dhs_twongan.D.eval()      # inference: disable train-mode layers (dropout / batch statistics)
with torch.no_grad():     # no autograd graph needed for thousands of forward passes
    pred = dhs_twongan.D(torch.from_numpy((x_train).reshape(-1, 1, 200, 4)[eval_slice]))
dhs_twongan.D.train(was_training)   # restore the previous mode for later training cells

target = torch.from_numpy(y_train[eval_slice]).long()

max_index = pred.max(dim=1)[1]

(max_index == target).sum()

In [None]:
# First 25 first-layer discriminator filters, one heatmap each (5x5 grid, row-major).
filters = dhs_twongan.D.state_dict()['net.0.weight'].data.cpu().numpy().squeeze()
f, ax = plt.subplots(5, 5, figsize=(50, 50))
for k, panel in enumerate(ax.flat):
    panel.imshow(filters[k].transpose(), vmin=0, cmap="Greens")

##### Analyze

In [None]:
# Read in pretrained models if necessary
# G = cdcgan_generator(1, 100, 15)
# D = cdcgan_discriminator(1, 15)
# G.load_state_dict(torch.load("/home/pbromley/generative_dhs/saved_models/pytorch-cdcgan-dhs-200-g.pth"))
# D.load_state_dict(torch.load("/home/pbromley/generative_dhs/saved_models/pytorch-cdcgan-dhs-200-d.pth"))

# Otherwise take G and D straight from the trained TwoN-GAN wrapper
G, D = dhs_twongan.G, dhs_twongan.D

In [None]:
twongan_d_path = "/home/pbromley/generative_dhs/saved_models/pytorch-twongan-dhs-200-d.pth"
torch.load(twongan_d_path).keys()  # ['dense_net.1.weight'].size()

In [None]:
G.double().to("cpu")

In [None]:

G.state_dict()["net.9.weight"].size()

In [None]:
# Summarize the generator's weights as {name: shape} instead of dumping every tensor.
{name: tuple(tensor.shape) for name, tensor in G.state_dict().items()}

In [None]:
filters = G.state_dict()['conv3.weight'].data.cpu().numpy().squeeze()

In [None]:
filters[0][0]

In [None]:
# 5x5 grid of the first 25 filters, filled row by row.
f, ax = plt.subplots(5, 5, figsize=(50, 50))
for k, panel in enumerate(ax.flat):
    panel.imshow(filters[k].transpose(), vmin=0, cmap="Greens")

In [None]:
# 5000 generated sequences for each of the 15 components, grouped by component.
n_per_component = 5000
noise = torch.from_numpy(np.random.normal(0, 1, (15 * n_per_component, 100)))
components = torch.arange(15).repeat_interleave(n_per_component).long()
with torch.no_grad():
    one_hot_seqs = G(noise, components).numpy()

one_hot_seqs = one_hot_seqs.reshape(-1, 200, 4)
    
def one_hot_to_seq(one_hot, order="ATCG"):
    """Decode a (length, 4) one-hot / score matrix into a Biopython ``Seq``.

    Each row maps to the base with the highest score. ``order`` gives the base
    for each column (default A, T, C, G, matching the previous hard-coded mapping).
    """
    idxs = np.argmax(one_hot, axis=1)
    seq = "".join(order[i] for i in idxs)   # join instead of repeated string +=
    # NOTE(review): Bio.Alphabet was removed in Biopython 1.78; on newer versions
    # construct Seq(seq) without an alphabet.
    return Seq(seq, single_letter_alphabet)

print("Converting one-hot to normal...")
seqs = list(map(one_hot_to_seq, one_hot_seqs))
components = components.numpy()

In [None]:
print(one_hot_seqs[-1].round())

In [None]:
motifs = ["GGCGC", "ATGAGTCAT", "CGAAACCGAAAC", "TGATGCAA", "ATGA", "ATTGT",
          "AACCGGTT", "AATTA", "AAATAG", "GTCACGCTT", "GTAAACA", "AACAGCTGT", "CAAAGT", "CGGAT",
          "ACTTCC"]
motifs = [Seq(motif, single_letter_alphabet) for motif in motifs]

# count_dict[component][motif_index] = total Seq.count occurrences of that motif
# across all generated sequences of that component.
num_components = 15
count_dict = {num: np.zeros(len(motifs)) for num in range(num_components)}
print("Motif: ", end=" ")
for i, motif in enumerate(motifs):
    print(str(i) + ":" + motif, end=' ')
    for j, seq in enumerate(seqs):
        count_dict[components[j]][i] += seq.count(motif)

count_mat = np.array([count_dict[num] for num in range(num_components)])
# Scale each motif column by its max; motifs that never occur stay 0 instead of 0/0 = NaN.
col_max = count_mat.max(axis=0)
plt.xticks(np.arange(len(motifs)), [str(x+1) for x in range(len(motifs))])
plt.yticks(np.arange(num_components), [str(x+1) for x in range(num_components)])
plt.imshow(count_mat / np.where(col_max == 0, 1, col_max), cmap='Greens')
plt.colorbar()
plt.xlabel("Motifs (for each component)")
plt.ylabel("Sequences (separated by component)")
plt.show()


    
# motif_search(seqs, components, motifs)

In [None]:
np.set_printoptions(precision=2, linewidth=250)
# Column-normalized motif counts; all-zero motif columns stay 0 instead of NaN.
col_max = count_mat.max(axis=0)
count_mat / np.where(col_max == 0, 1, col_max)

In [None]:
# ALIGNMENT SCORE
from Bio import pairwise2

In [None]:
seq_a, seq_b = seqs[0], seqs[4]
pairwise2.align.globalxx(seq_a, seq_b)

In [None]:
# Sum of global alignment scores of x_train_seqs[0] against the following training
# sequences, accumulated per component label, plus the number of sequences per label.
# NOTE: x_train_seqs is created in a later cell; run that cell first.
n_compare = min(50000, len(x_train_seqs) - 1)   # don't index past the end of the data
align_sum = {k: 0 for k in range(15)}
totals = {k: 0 for k in range(15)}
for i in range(1, n_compare + 1):
    comp = int(y_train[i])   # labels may be stored as floats; use int keys
    align_sum[comp] += pairwise2.align.globalxx(x_train_seqs[0], x_train_seqs[i], score_only=True)
    totals[comp] += 1
    

In [None]:
# Mean alignment score per component (NaN for components with no compared sequences).
[align_sum[k] / totals[k] if totals[k] else float('nan') for k in range(15)]

In [None]:
# NOTE(review): assumes x_train holds length-200 sequences; the load cell reshapes to 100 — confirm.
x_train_seqs = list(map(one_hot_to_seq, x_train.reshape(-1, 200, 4)))

In [None]:
pairwise2.align.globalxx(*x_train_seqs[-2:], score_only=True)

In [None]:
# Inspect the training component labels (shifted by -1 when loaded).
y_train

In [None]:
torch.from_numpy(np.random.normal(loc=0, scale=1, size=(75000, 100))).requires_grad

In [None]:
# Raw motif counts (rows: components, columns: motifs) shown as integers.
count_mat.astype(int)

In [None]:
def disp_seqs(seqs, idx):
    """Plot the ``x_train`` rows listed in ``idx`` as stacked heatmaps and save the figure.

    NOTE(review): ``seqs`` is unused — rows are read from the global ``x_train``,
    as before; confirm which data source is intended.
    """
    num_seqs = len(idx)
    if num_seqs == 0:
        return   # plt.subplots(0) would raise; nothing to draw
    # squeeze=False -> always a 2-D axes array, so one and many rows share a code path.
    f, ax = plt.subplots(num_seqs, figsize=(10, 2*num_seqs), sharex=True, squeeze=False)
    for row, seq_idx in enumerate(idx):
        # Index axes by plot position, not by the sequence index (which can exceed num_seqs).
        ax[row, 0].imshow(x_train[seq_idx].transpose(), cmap="Blues")
    f.subplots_adjust(bottom=0.7)
    f.savefig("/home/pbromley/test.png")

In [None]:
f, ax = plt.subplots(2, figsize=(100, 4), sharex=True)
for row, panel in enumerate(ax):
    panel.imshow(x_train[row].transpose(), cmap="Blues")
f.subplots_adjust(bottom=0.7)

In [None]:
d_loss_hist = dhs_cdcgan.train_hist['d_loss']
g_loss_hist = dhs_cdcgan.train_hist['g_loss']
x = range(len(d_loss_hist))
fig, ax = plt.subplots(figsize=(8, 8))
ax.scatter(x, [float(d.cpu()) for d in d_loss_hist], label='d_loss')
# ax.scatter(x, [float(g.cpu()) for g in g_loss_hist], label='g_loss')
ax.set_xlabel('Iterations')
ax.set_ylabel('Loss')
ax.legend(loc=4)
fig.savefig("/home/pbromley/generative_dhs/loss_plots/pytorch-cdcgan-dhs-200-d.png")
plt.close(fig)

In [None]:
# Scratch: a standalone transposed-conv upsampling stack for shape experiments.
upsample_layers = [
    nn.ConvTranspose2d(256, 128, kernel_size=(7, 1), stride=2, padding=(2, 0), bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(128, 64, kernel_size=(8, 1), stride=2, padding=(3, 0), bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(64, 32, kernel_size=(8, 1), stride=2, padding=(3, 0), bias=False),
    nn.BatchNorm2d(32),
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(32, 1, kernel_size=(16, 4), stride=2, padding=(7, 0), bias=False),
    nn.Softmax(dim=3),
]
x = nn.Sequential(*upsample_layers)

In [None]:
# `Variable` is never imported here and has been a no-op wrapper since PyTorch 0.4;
# pass the tensor directly.
x(torch.rand(1, 256, 12, 1)).size()

In [None]:
x = nn.Embedding(num_embeddings=15, embedding_dim=10)

In [None]:
# `Variable` is not imported (and is unnecessary since PyTorch 0.4); pass the tensor directly.
x(torch.LongTensor([1])).size()


In [None]:
# Scratch: strided-conv feature extractor used to check output shapes.
downsample_layers = [
    nn.Conv2d(1, 32, kernel_size=(15, 4), stride=1, padding=(7, 0), bias=False),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(32, 64, kernel_size=(15, 1), stride=3, padding=(8, 0), bias=False),
    nn.LeakyReLU(0.2, inplace=True),
    nn.BatchNorm2d(64),
    nn.Conv2d(64, 128, kernel_size=(15, 1), stride=3, padding=(8, 0), bias=False),
    nn.LeakyReLU(0.2, inplace=True),
    nn.BatchNorm2d(128),
    nn.Conv2d(128, 256, kernel_size=(15, 1), stride=1, padding=(0, 0), bias=False),
    nn.LeakyReLU(0.2, inplace=True),
]
x = nn.Sequential(*downsample_layers)
        

In [None]:
x(torch.rand(1, 1, 200, 4)).shape

In [None]:
# Scratch arithmetic: 320 filters x 10 positions — presumably a candidate flatten size; confirm.
320 * 10

##### VAE

In [None]:
# inspired by https://github.com/pytorch/examples/blob/master/vae/main.py
class encoder(nn.Module):
    """MLP encoder: flattened (N, 800) input -> (mu, logvar), each (N, 200)."""

    def __init__(self):
        super(encoder, self).__init__()
        # Creation order is kept so parameter initialization order is unchanged.
        self.fc1 = nn.Linear(800, 400)
        self.relu = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm1d(400)
        self.fc21 = nn.Linear(400, 200)   # mean head
        self.fc22 = nn.Linear(400, 200)   # log-variance head

    def forward(self, x):
        hidden = self.bn(self.relu(self.fc1(x)))
        return self.fc21(hidden), self.fc22(hidden)
    

class decoder(nn.Module):
    """MLP decoder: latent (N, 200) -> per-position base probabilities (N, 200, 4)."""

    def __init__(self):
        super(decoder, self).__init__()
        self.fc1 = nn.Linear(200, 400)
        self.relu = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm1d(400)
        self.fc2 = nn.Linear(400, 800)
        self.softmax = nn.Softmax(dim=2)   # normalize over the 4 bases at each position

    def forward(self, z):
        logits = self.fc2(self.bn(self.relu(self.fc1(z))))
        return self.softmax(logits.view(-1, 200, 4))
    
    
class VAE(nn.Module):
    """Fully connected VAE over flattened 200x4 sequences (encoder -> sample -> decoder)."""

    def __init__(self):
        super(VAE, self).__init__()
        self.encoder = encoder()
        self.decoder = decoder()

    def reparameterize(self, mu, logvar):
        """Return mu + eps * exp(0.5 * logvar) while training, and mu in eval mode."""
        if not self.training:
            return mu
        std = (0.5 * logvar).exp()
        return mu + torch.randn_like(std) * std

    def forward(self, x):
        """Return (reconstruction, mu, logvar) for input ``x`` (any shape with 800 values per sample)."""
        mu, logvar = self.encoder(x.view(-1, 800))
        return self.decoder(self.reparameterize(mu, logvar)), mu, logvar
    

device = "cuda"
model = VAE().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)


def loss_function(recon_x, x, mu, logvar, bs):
    """VAE loss: per-element mean BCE plus the KL term normalized the same way.

    Args:
        recon_x, x: reconstruction and target, each with 800 (= 200 * 4) values per sample.
        mu, logvar: posterior parameters from the encoder.
        bs: batch size; the summed KL is divided by bs * 800 so it is on the same
            per-element scale as the mean-reduced BCE.
    Returns:
        Scalar loss tensor.
    """
    BCE = F.binary_cross_entropy(recon_x.view(-1, 800), x.view(-1, 800))
    # KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # 800 reconstructed elements per sample (previously 784, an MNIST leftover).
    KLD = KLD / (bs * 800)
    return BCE + KLD


def save_imgs(real_x, epoch):
    """Save ``real_x`` and the VAE reconstruction of it as PNG grids for ``epoch``.

    The model runs in eval mode here, so reconstructions use the posterior mean
    (see ``VAE.reparameterize``) and BatchNorm running statistics are not updated
    with these held-out samples. The previous train/eval mode is restored afterwards.
    NOTE: a later cell redefines a different ``save_imgs`` with the same name.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            recon_x, _, _ = model(real_x)
    finally:
        model.train(was_training)
    to_save = recon_x.transpose(1, 2).view(-1, 1, 4, 200)
    img_path_real = '/home/pbromley/generative_dhs/images/pytorch-vae-dhs-200/real-%d.png' % epoch
    img_path_fake = '/home/pbromley/generative_dhs/images/pytorch-vae-dhs-200/fake-%d.png' % epoch
    vutils.save_image(real_x.transpose(1, 2).view(-1, 1, 4, 200), img_path_real, nrow=1, normalize=True)
    vutils.save_image(to_save, img_path_fake, nrow=1, normalize=True)



def train(epoch, save):
    """Run one VAE training epoch over ``dhs_dataloader``, then save reconstructions of ``save``.

    ``loss_function`` returns a per-element mean, so each logged value is already a
    batch mean. The epoch average is therefore taken over batches; the earlier code
    divided by batch/dataset size, which was a leftover from a sum-reduced loss.
    """
    model.train()
    train_loss = 0
    num_batches = 0
    for batch_idx, (data, _) in enumerate(dhs_dataloader):
        data = data.to(device).float()
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar, data.size(0))
        loss.backward()
        train_loss += loss.item()
        num_batches += 1
        optimizer.step()
        if batch_idx % 1000 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(dhs_dataloader.dataset),
                100. * batch_idx / len(dhs_dataloader),
                loss.item()))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / max(num_batches, 1)))
    save_imgs(save, epoch)
    


# def test(epoch):
#     model.eval()
#     test_loss = 0
#     with torch.no_grad():
#         for i, (data, _) in enumerate(test_loader):
#             data = data.to(device)
#             recon_batch, mu, logvar = model(data)
#             test_loss += loss_function(recon_batch, data, mu, logvar).item()
#             if i == 0:
#                 n = min(data.size(0), 8)
#                 comparison = torch.cat([data[:n],
#                                       recon_batch.view(args.batch_size, 1, 28, 28)[:n]])
#                 save_image(comparison.cpu(),
#                          'results/reconstruction_' + str(epoch) + '.png', nrow=n)

#     test_loss /= len(test_loader.dataset)
#     print('====> Test set loss: {:.4f}'.format(test_loss))

# Ten held-out test sequences are reconstructed and saved after every epoch.
save = torch.from_numpy(x_test[:10]).float().to(device)
for epoch in range(50):
    train(epoch, save)
    
#     test(epoch)
#     with torch.no_grad():
#         sample = torch.randn(64, 20).to(device)
#         sample = model.decode(sample).cpu()
#         save_image(sample.view(64, 1, 28, 28),
#                    'results/sample_' + str(epoch) + '.png')

        
        

### SN-GAN with Projection Discriminator

In [None]:
# Helpful PyTorch code from external sources:


###################################################################################
##                            Conditional Batch Norm                             ##
## (https://github.com/ap229997/Conditional-Batch-Norm/blob/master/model/cbn.py) ##
###################################################################################
'''
CBN (Conditional Batch Normalization layer)
    uses an MLP to predict the beta and gamma parameters in the batch norm equation
    Reference : https://papers.nips.cc/paper/7237-modulating-early-visual-processing-by-language.pdf
'''
class CBN(nn.Module):
    """Conditional batch norm: per-channel normalization whose affine terms are shifted by MLP(emb).

    Adapted from ap229997/Conditional-Batch-Norm with these fixes:
      * ``betas`` / ``gammas`` are per-channel ``(out_size,)`` parameters. The original
        used ``(batch, out_size)`` parameters that were ``resize_``d every call; growing
        a tensor that way leaves uninitialized values and breaks optimizer state
        whenever the batch size changes.
      * Mean and variance are computed per channel over (N, H, W), as in batch norm,
        instead of as one scalar over the whole tensor.
      * There is no hard-coded ``.cuda()``; the layer follows its parent module's device.

    Args:
        lstm_size: size of the conditioning embedding (MLP input).
        emb_size: hidden size of the two MLPs.
        out_size: number of feature channels to normalize.
        use_betas / use_gammas: learn the conditional shift / scale deltas.
        eps: added to the variance for numerical stability.
    """

    def __init__(self, lstm_size, emb_size, out_size, use_betas=True, use_gammas=True, eps=1.0e-5):
        super(CBN, self).__init__()

        self.lstm_size = lstm_size # size of the conditioning embedding which is input to the MLP
        self.emb_size = emb_size # size of hidden layer of MLP
        self.out_size = out_size # output of the MLP - one value per channel
        self.use_betas = use_betas
        self.use_gammas = use_gammas

        # Feature-map geometry; overwritten with the actual shape on every forward.
        self.batch_size = 16
        self.channels = out_size
        self.height = 100
        self.width = 4

        # Base per-channel shift/scale; the MLPs add per-sample deltas on top.
        self.betas = nn.Parameter(torch.zeros(self.channels))
        self.gammas = nn.Parameter(torch.ones(self.channels))
        self.eps = eps

        # MLP used to predict betas and gammas
        self.fc_gamma = nn.Sequential(
            nn.Linear(self.lstm_size, self.emb_size),
            nn.ReLU(inplace=True),
            nn.Linear(self.emb_size, self.out_size),
            )

        self.fc_beta = nn.Sequential(
            nn.Linear(self.lstm_size, self.emb_size),
            nn.ReLU(inplace=True),
            nn.Linear(self.emb_size, self.out_size),
            )

        # initialize weights using Xavier initialization and biases with constant value
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0.1)

    def create_cbn_input(self, lstm_emb):
        """Predict per-sample (delta_betas, delta_gammas), each ``(batch, channels)``.

        A disabled branch returns zeros on ``lstm_emb``'s device and dtype.
        """
        batch = lstm_emb.size(0)

        if self.use_betas:
            delta_betas = self.fc_beta(lstm_emb)
        else:
            delta_betas = lstm_emb.new_zeros(batch, self.channels)

        if self.use_gammas:
            delta_gammas = self.fc_gamma(lstm_emb)
        else:
            delta_gammas = lstm_emb.new_zeros(batch, self.channels)

        return delta_betas, delta_gammas

    def forward(self, feature, lstm_emb):
        """Compute the normalized feature map with conditional beta and gamma.

        Args:
            feature: feature map from the previous layer, (N, C, H, W).
            lstm_emb: conditioning embedding, (N, lstm_size).
        Returns:
            Tensor with the same shape as ``feature``.
        """
        self.batch_size, self.channels, self.height, self.width = feature.shape

        delta_betas, delta_gammas = self.create_cbn_input(lstm_emb)

        # (C,) base + (N, C) delta -> (N, C, 1, 1), broadcast over height and width.
        betas = (self.betas + delta_betas)[:, :, None, None]
        gammas = (self.gammas + delta_gammas)[:, :, None, None]

        # Per-channel batch statistics, as in standard batch norm.
        batch_mean = feature.mean(dim=(0, 2, 3), keepdim=True)
        batch_var = feature.var(dim=(0, 2, 3), unbiased=False, keepdim=True)

        feature_normalized = (feature - batch_mean) / torch.sqrt(batch_var + self.eps)

        return feature_normalized * gammas + betas
    
    
    
    
###################################################################################################################
##                                                  Spectral Norm                                                ##
## https://github.com/christiancosgrove/pytorch-spectral-normalization-gan/blob/master/spectral_normalization.py ##
###################################################################################################################

def l2normalize(v, eps=1e-12):
    """Return ``v`` scaled to unit L2 norm; ``eps`` avoids division by zero."""
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    """Wrap ``module`` so its ``name`` weight is divided by an estimate of its top singular value.

    The raw weight is stored as ``<name>_bar``. ``<name>_u`` and ``<name>_v`` hold
    power-iteration vectors that are refreshed on every forward pass. They are
    registered with ``requires_grad=False`` (the upstream code used True), so the
    optimizer no longer perturbs them and no gradients are computed for them.
    """

    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Run power iteration and set ``module.<name> = w_bar / sigma``."""
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        w_mat = w.data.view(height, -1)
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w_mat), u.data))
            u.data = l2normalize(torch.mv(w_mat, v.data))

        # sigma stays differentiable w.r.t. w_bar (u and v are constants here).
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        """True if ``module`` already carries the ``_u`` / ``_v`` / ``_bar`` parameters."""
        return all(hasattr(self.module, self.name + suffix) for suffix in ("_u", "_v", "_bar"))

    def _make_params(self):
        """Replace ``module.<name>`` with ``_bar`` plus random unit ``_u`` / ``_v`` vectors."""
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = nn.Parameter(w.data, requires_grad=True)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)
    

In [None]:
class snp_generator(nn.Module):
    """Conditional generator: noise (N, 100) + class ids (N,) -> base probabilities (N, 1, 200, 4).

    The input goes through a linear layer to a (256, 12, 1) map, then three
    ConvTranspose + conditional-BN + ReLU stages, then a final transposed conv and a
    softmax over the 4 bases. The class embedding conditions every CBN layer.
    """

    def __init__(self):
        super(snp_generator, self).__init__()

        # Same creation order as before, so parameter initialization is unchanged.
        self.embed = nn.Embedding(15, 4)

        self.fc = nn.Linear(100, 256*12*1)

        self.conv1 = nn.ConvTranspose2d(256, 128, kernel_size=(7, 1), stride=2,
                                        padding=(2, 0), bias=False)
        self.cbn1 = CBN(4, 128, 128)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.ConvTranspose2d(128, 64, kernel_size=(8, 1), stride=2,
                                        padding=(3, 0), bias=False)
        self.cbn2 = CBN(4, 128, 64)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.ConvTranspose2d(64, 32, kernel_size=(8, 1), stride=2,
                                        padding=(3, 0), bias=False)
        self.cbn3 = CBN(4, 128, 32)
        self.relu3 = nn.ReLU(inplace=True)
        self.conv4 = nn.ConvTranspose2d(32, 1, kernel_size=(16, 4), stride=2,
                                        padding=(7, 0), bias=False)
        self.softmax = nn.Softmax(dim=3)

    def forward(self, nz, c):
        class_emb = self.embed(c)
        h = self.fc(nz).view(-1, 256, 12, 1)
        stages = (
            (self.conv1, self.cbn1, self.relu1),
            (self.conv2, self.cbn2, self.relu2),
            (self.conv3, self.cbn3, self.relu3),
        )
        for conv, cond_norm, act in stages:
            h = act(cond_norm(conv(h), class_emb))
        return self.softmax(self.conv4(h))



class snp_discriminator(nn.Module):
    """Projection discriminator with spectrally normalized layers.

    Scores (x, c) as sigmoid(fc2(h) + <h, embed(c)>), where h is an 8-dim feature
    computed from x viewed as (N, 1, 200, 4). Returns the squeezed probability.
    """

    def __init__(self):
        super(snp_discriminator, self).__init__()

        # Creation order kept as before (affects parameter initialization order).
        self.embed = nn.Embedding(15, 8)

        self.conv1 = SpectralNorm(nn.Conv2d(1, 320, kernel_size=(11, 4), stride=1, padding=(5, 0), bias=False))
        self.pool1 = nn.MaxPool2d(kernel_size=(40, 1))
        self.lrelu1 = nn.LeakyReLU(0.1, inplace=True)
        self.fc1 = SpectralNorm(nn.Linear(320*5, 8))
        self.drop1 = nn.Dropout(0.5)
        self.lrelu2 = nn.LeakyReLU(0.1, inplace=True)
        self.fc2 = SpectralNorm(nn.Linear(8, 1))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, c):
        class_vec = self.embed(c)
        feats = self.lrelu1(self.pool1(self.conv1(x.view(-1, 1, 200, 4))))
        hidden = self.lrelu2(self.drop1(self.fc1(feats.view(-1, 320 * 5))))
        # Unconditional score plus the class-projection term.
        score = self.fc2(hidden) + torch.sum(hidden * class_vec, 1, keepdim=True)
        return self.sigmoid(score).squeeze()

        

In [None]:
BATCH_SIZE = 128
NZ = 100
NC = 15

# Networks are built in the same order as before (G, then D) and moved to the GPU.
G = snp_generator().to("cuda")
D = snp_discriminator().to("cuda")

criterion = nn.BCELoss().to("cuda")
optD = optim.Adam(D.parameters(), lr=1e-4, betas=(0.5, 0.99))
optG = optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.99))

def weights_init(m):
    """DCGAN-style init: conv weights ~ N(0, 0.02); BatchNorm scale ~ N(1, 0.02), shift 0.

    Modules are matched by class-name substring, so ConvTranspose layers count as 'Conv'.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        m.weight.detach().normal_(0.0, 0.02)
    elif 'BatchNorm' in kind:
        m.weight.detach().normal_(1.0, 0.02)
        m.bias.detach().fill_(0)

G.apply(weights_init)

train_hist = {key: [] for key in ('d_loss', 'g_loss', 'epoch_time', 'total_time')}


def create_fixed_inputs(num_total_imgs, nz=None, nc=None, device="cuda"):
    """Build fixed (noise, class-label) inputs used for periodic sample images.

    Labels are grouped in ascending class order. Each class gets
    ``num_total_imgs // nc`` samples. When ``num_total_imgs`` is not a
    multiple of ``nc``, the lowest-numbered classes get one extra sample, so
    the labels always line up with the noise batch. Previously the label
    tensor came up short in that case. For exact multiples the output is
    unchanged.

    Args:
        num_total_imgs: number of samples.
        nz: noise dimension; defaults to the global NZ.
        nc: number of classes; defaults to the global NC.
        device: device for both returned tensors.

    Returns:
        (fixed_noise, fixed_c): float tensor (num_total_imgs, nz) drawn from
        N(0, 1) on the CPU generator, and a long tensor (num_total_imgs,) of labels.
    """
    if nz is None:
        nz = NZ
    if nc is None:
        nc = NC

    fixed_noise = torch.randn(num_total_imgs, nz).to(device)

    per_class, extra = divmod(num_total_imgs, nc)
    counts = per_class + (np.arange(nc) < extra)
    labels = np.repeat(np.arange(nc), counts)
    fixed_c = torch.from_numpy(labels).long().to(device)

    return fixed_noise, fixed_c


def save_imgs(G, f_noise, f_c, it,
              out_dir='/home/pbromley/generative_dhs/images/pytorch-sngan-dhs-200'):
    """Generate samples from fixed inputs and save them as one PNG.

    Args:
        G: generator, called as ``G(f_noise, f_c)`` under ``torch.no_grad()``.
        f_noise: fixed noise batch.
        f_c: fixed class labels matching ``f_noise``.
        it: training iteration; the file is written as ``<out_dir>/<it>.png``.
        out_dir: output directory, created if it does not exist. The default
            is the path that used to be hard-coded here.
    """
    import os

    with torch.no_grad():
        images = G(f_noise, f_c)
    # Swap the last two axes before saving; presumably turns each
    # (1, length, 4) sample into a 4-pixel-tall strip — TODO confirm.
    to_save = images.transpose(2, 3)
    os.makedirs(out_dir, exist_ok=True)
    img_path = os.path.join(out_dir, '%d.png' % it)
    vutils.save_image(to_save, img_path, nrow=1, normalize=True)
    
def accuracy(output, target):
    """Fraction of positions where `output` and `target` agree after thresholding at 0.5.

    Both tensors are binarised with ``>= 0.5``; the number of matches is
    divided by ``target.numel()``.
    """
    hits = torch.eq(output >= 0.5, target >= 0.5)
    return hits.sum().item() / target.numel()


def _next_full_batch(dataiter, dataloader, batch_size):
    """Return (batch, dataiter) where the batch has exactly `batch_size` rows.

    If the iterator is exhausted or yields a short final batch, a fresh
    iterator over `dataloader` is started and its first batch is used.
    """
    batch = next(dataiter, None)
    if batch is None or batch[0].size(0) != batch_size:
        dataiter = iter(dataloader)
        # builtin next(): DataLoader iterators no longer expose a `.next()` method
        batch = next(dataiter)
    return batch, dataiter


def train(G, D, iterations, n_critic=5, log_every=100, save_every=1000):
    """Train generator G and critic D with a WGAN-style objective.

    Each iteration makes one generator update, which maximises mean D(fake).
    It then makes `n_critic` critic updates, which minimise
    mean D(fake) - mean D(real).

    Relies on notebook globals: dhs_dataloader, optG, optD, BATCH_SIZE, NZ,
    NC, create_fixed_inputs and save_imgs.

    Args:
        G, D: generator and discriminator modules (on CUDA).
        iterations: number of generator updates.
        n_critic: critic updates per generator update (must be >= 1).
        log_every: print losses every this many iterations.
        save_every: save sample images every this many iterations. As before,
            this is only checked on logging iterations.

    Raises:
        ValueError: if n_critic < 1.
    """
    if n_critic < 1:
        raise ValueError("n_critic must be >= 1")
    device = "cuda"
    dataiter = iter(dhs_dataloader)
    noise = torch.zeros(BATCH_SIZE, NZ, device=device)
    fixed_noise, fixed_c = create_fixed_inputs(45)

    for iteration in range(iterations):
        # --- Generator step: maximise mean D(G(z, c)) ---
        optG.zero_grad()
        noise.normal_(0, 1)
        fake_c = torch.randint(0, NC, (BATCH_SIZE,), device=device)
        fake = G(noise, fake_c)
        g_loss = D(fake, fake_c).mean()
        # Same as the old g_loss.backward(FloatTensor([-1])). That call fails on
        # current PyTorch because a [1]-shaped gradient cannot seed a 0-d scalar.
        (-g_loss).backward()
        optG.step()

        # --- Critic steps: minimise mean D(fake) - mean D(real) ---
        for _ in range(n_critic):
            optD.zero_grad()
            batch, dataiter = _next_full_batch(dataiter, dhs_dataloader, BATCH_SIZE)
            x, c = batch
            x = x.float().to(device)
            c = c.long().to(device)
            d_loss_real = D(x, c).mean()

            noise.normal_(0, 1)
            fake_c = torch.randint(0, NC, (BATCH_SIZE,), device=device)
            with torch.no_grad():  # critic step: no graph through G needed
                fake = G(noise, fake_c)
            d_loss_fake = D(fake, fake_c).mean()

            d_loss_total = d_loss_fake - d_loss_real
            d_loss_total.backward()
            optD.step()

        if iteration % log_every == 0:
            print('Iter:{0}, Dloss: {1}, Gloss: {2}'.format(
                iteration, d_loss_total.item(), g_loss.item()))
            if iteration % save_every == 0:
                save_imgs(G, fixed_noise, fixed_c, iteration)
            
# Run 500k generator updates (each followed by the critic updates inside train).
train(G, D, iterations=500000)
        

In [None]:
# path = "/home/pbromley/generative_dhs/saved_models/pytorch-sngan-dhs-wgan-200"
# print("Saving Model Weights...")
# torch.save(G.state_dict(), path + "-g.pth")
# torch.save(D.state_dict(), path + "-d.pth")
# print("Saved Model Weights")

class snp_generator_2d(nn.Module):
    """Class-conditional generator built from two transposed convolutions.

    noise (N, nz) -> fc -> (N, num_filters//2, 10, 1) -> ReLU -> CBN
    -> ConvTranspose (x10 in length) -> (N, num_filters, 100, 1) -> ReLU -> CBN
    -> ConvTranspose (len_filters, 4) -> (N, 1, L, 4) -> softmax over the last axis.

    L is 100 when `len_filters` is odd. An even `len_filters` yields L = 99,
    because the final padding is len_filters // 2.

    Args:
        nz: noise dimension.
        ne: class-embedding size (15 classes).
        cbn_h: second argument passed to CBN (presumably its hidden width).
        num_filters: channels after the first upsampling.
        len_filters: kernel length of the final transposed conv.
        dropout, concat: accepted but unused.
    """

    def __init__(self, nz, ne, cbn_h, num_filters, len_filters, dropout=False, concat=False):
        super().__init__()
        half = num_filters // 2
        self.num_filters = num_filters

        # Registration order is kept so state_dict keys/ordering match saved checkpoints.
        self.embed = nn.Embedding(15, ne)
        self.fc = SpectralNorm(nn.Linear(nz, half * 10))
        self.relu1 = nn.ReLU(True)
        self.cbn1 = CBN(ne, cbn_h, half)
        self.up1 = SpectralNorm(nn.ConvTranspose2d(half, num_filters, (10, 1), 10, bias=False))
        self.relu2 = nn.ReLU(True)
        self.cbn2 = CBN(ne, cbn_h, num_filters)
        self.up2 = SpectralNorm(nn.ConvTranspose2d(
            num_filters, 1, (len_filters, 4), 1, (len_filters // 2, 0)))
        self.softmax = nn.Softmax(dim=3)

    def forward(self, nz, c):
        cond = self.embed(c)
        out = self.fc(nz).view(-1, self.num_filters // 2, 10, 1)
        out = self.cbn1(self.relu1(out), cond)
        out = self.cbn2(self.relu2(self.up1(out)), cond)
        return self.softmax(self.up2(out))

# Rebuild the hinge-loss generator architecture and load its saved weights.
G = snp_generator_2d(nz=100, ne=4, cbn_h=256, num_filters=320, len_filters=17).to("cuda")
G.load_state_dict(
    torch.load("/home/pbromley/generative_dhs/saved_models/pytorch-sngan-hinge-dhs-100-2-g.pth")
)

In [None]:
# Sample on the CPU in float64 (the sampling cells feed float64 NumPy noise).
G.double().to("cpu")

In [None]:
# Draw 15 * 640 batches of 16 sequences (153,600 total) from G, each batch
# with uniformly random component labels in [0, 15).
G.eval()  # same as G.train(False)

n_outer, n_inner, bs = 15, 640, 16
total = n_outer * n_inner * bs
one_hot_seqs = np.zeros((total, 1, 100, 4))
components = np.zeros(total, dtype=int)
for i in range(n_outer):
    for j in range(n_inner):
        start = (i * n_inner + j) * bs
        noise = torch.from_numpy(np.random.normal(0, 1, (bs, 100)))
        comps = torch.from_numpy(np.random.randint(0, 15, bs, dtype=int)).long()
        with torch.no_grad():
            one_hot_seqs[start:start + bs] = G(noise, comps).numpy()
            components[start:start + bs] = comps

one_hot_seqs = one_hot_seqs.reshape(-1, 100, 4)
    
def one_hot_to_seq(one_hot):
    """Decode a (length, 4) one-hot / probability matrix into a ``Seq``.

    Each row becomes the base of its highest-scoring column, using the
    column order A, T, C, G.
    """
    decode = {0: 'A', 1: 'T', 2: 'C', 3: 'G'}
    seq = "".join(decode[k] for k in np.argmax(one_hot, axis=1))
    return Seq(seq, single_letter_alphabet)

print("Converting one-hot to normal...")
seqs = list(map(one_hot_to_seq, one_hot_seqs))
# components = np.repeat(np.arange(0, 15), 20480)

In [None]:
# Preview of the grouped label layout used for fixed inputs (5 classes x 9 each).
np.repeat(np.arange(5), 45 // 5)

In [None]:
# Decode the real training sequences (x_train) for the analyses below.
# reshape (not squeeze) keeps the sample axis even when only one sequence remains.
seqs = [one_hot_to_seq(one_hot_seq) for one_hot_seq in x_train.reshape(-1, 100, 4)]
components = y_train

In [None]:
# Push 10 throwaway random batches (16 samples each) through G; outputs are discarded.
# NOTE(review): G is put in eval mode in the sampling cell, so this may only
# advance NumPy's RNG — confirm what this cell is meant to do.
for i in range(10):
    noise = torch.from_numpy(np.random.normal(0, 1, size=(16, 100)))
    comps = torch.from_numpy(np.random.randint(0, 15, size=16, dtype=int)).long()
    with torch.no_grad():
        _ = G(noise, comps)

In [None]:
motifs = ["GGCGC", "ATGAGTCAT", "CGAAACCGAAAC", "TGATGCAA", "ATGA", "ATTGT",
          "AACCGGTT", "AATTA", "AAATAG", "GTCACGCTT", "GTAAACA", "AACAGCTGT", "CAAAGT", "CGGAT",
          "ACTTCC"]
motifs = [Seq(motif, single_letter_alphabet) for motif in motifs]


def motif_search(seqs, components, motifs, num_classes=15):
    """Total motif occurrences per component class.

    Args:
        seqs: sequences supporting ``.count(motif)``.
        components: integer-valued class label (0..num_classes-1) per sequence;
            must have at least len(seqs) entries.
        motifs: motifs to count (any number; previously fixed at 15).
        num_classes: number of component classes (rows of the result).

    Returns:
        float array of shape (num_classes, len(motifs)); entry [k, i] is the
        sum of ``seq.count(motifs[i])`` over sequences with component k.

    Raises:
        ValueError: if labels are missing or fall outside 0..num_classes-1.
    """
    labels = np.asarray(components).astype(int)
    if len(labels) < len(seqs):
        raise ValueError("need one component label per sequence")
    if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
        raise ValueError("component labels must lie in [0, %d)" % num_classes)
    counts = np.zeros((num_classes, len(motifs)))
    print("Motif: ", end=" ")
    for i, motif in enumerate(motifs):
        print(str(i) + ":" + motif, end=' ')
        for seq, label in zip(seqs, labels):
            counts[label, i] += seq.count(motif)
    print()
    return counts


count_mat = motif_search(seqs, components, motifs)

# Scale each motif column by its maximum; motifs that never occur stay 0
# instead of producing NaN from 0/0.
col_max = count_mat.max(axis=0)
count_norm = np.divide(count_mat, col_max, out=np.zeros_like(count_mat), where=col_max > 0)

fig, ax = plt.subplots()
im = ax.imshow(count_norm, cmap='Greens')
fig.colorbar(im, ax=ax)
ax.set_xticks(np.arange(len(motifs)))
ax.set_xticklabels([str(x + 1) for x in range(len(motifs))])
ax.set_yticks(np.arange(count_mat.shape[0]))
ax.set_yticklabels([str(x + 1) for x in range(count_mat.shape[0])])
ax.set_xlabel("Motifs (for each component)")
ax.set_ylabel("Sequences (separated by component)")
plt.show()

In [None]:
# Widen NumPy's print width so rows of the count matrix do not wrap.
np.set_printoptions(linewidth=150)
count_mat.astype(np.int_)

In [None]:
# NOTE: calc_overall_composition is defined in the following cell; on a fresh
# kernel (Restart & Run All) this line raises NameError unless that cell runs first.
calc_overall_composition(seqs)

In [None]:
def calc_overall_composition(seqs):
    """Mean number of A, T, C and G per sequence across `seqs`.

    Args:
        seqs: sized collection of sequences supporting ``.count`` (Bio.Seq or str).

    Returns:
        Float array ``[A/n, T/n, C/n, G/n]`` where n = len(seqs).
    """
    totals = np.zeros(4, dtype=int)
    for s in seqs:
        totals += [s.count(base) for base in "ATCG"]
    return totals / len(seqs)

def per_class_composition(seqs, c, plot=True):
    """Average A/T/C/G counts per sequence within each component class.

    Args:
        seqs: sequences supporting ``.count(base)`` (Bio.Seq or str).
        c: class label for each sequence. Labels are cast to int, so float
            labels such as 3.0 are accepted.
        plot: if True, draw a grouped bar chart with one group per class and
            one bar per base.

    Returns:
        dict mapping every class k in 0..max(c) to a float array
        ``[A, T, C, G] / n_k``, where n_k is the number of labels equal to k.
        Classes with no sequences map to an all-NaN array. This matches the
        old 0/0 result, but without the RuntimeWarning.
    """
    bases = ("A", "T", "C", "G")
    labels = np.asarray(c).astype(int)
    num_c = int(labels.max()) + 1

    totals = {k: np.zeros(len(bases), dtype=int) for k in range(num_c)}
    for i in range(len(seqs)):
        totals[int(labels[i])] += [seqs[i].count(b) for b in bases]

    num_per_class = np.bincount(labels, minlength=num_c)
    comp_dict = {
        k: (v / num_per_class[k]) if num_per_class[k] > 0 else np.full(len(bases), np.nan)
        for k, v in totals.items()
    }

    if plot:
        fig, ax = plt.subplots(figsize=(10, 10))
        index = np.arange(num_c)
        bar_width = 0.2
        for offset, base in enumerate(bases):
            heights = [comp_dict[k][offset] for k in range(num_c)]
            ax.bar(index + offset * bar_width, heights, bar_width, label=base)
        # Center each tick under its group of four bars; labels are 1-based
        # like the raw component ids.
        ax.set_xticks(index + 1.5 * bar_width)
        ax.set_xticklabels([str(k + 1) for k in range(num_c)])
        ax.set_xlabel("Component")
        ax.set_ylabel("Mean count per sequence")
        ax.legend(title="Base")

    return comp_dict

In [None]:
# Per-class base composition (labels cast to int so they can index classes).
per_class_composition(seqs=seqs, c=components.astype(int), plot=True)

In [None]:
# One FASTA record per sequence, using its component label as the record id.
seq_array = [SeqRecord(seqs[k], id=str(components[k])) for k in range(len(seqs))]
SeqIO.write(seq_array, "/home/pbromley/generative_dhs/memes/fake_test_4.fasta", "fasta")

In [None]:
# Scratch arithmetic (= 20480).
160 * 128

In [None]:
# Preview only: displaying every sequence floods the notebook output.
len(seqs), seqs[:5]

In [None]:
## for the pytorch-sngan-dhs-100-siggenerator
class snp_generator(nn.Module):
    """Class-conditional generator: noise (N, 100) + labels (N,) -> (N, 1, 100, 4).

    The fc output is viewed as (N, 128, 12, 1) and upsampled along the length
    axis by three stride-2 transposed convolutions (12 -> 25 -> 50 -> 100).
    The last convolution also widens the map to 4 columns, and a softmax over
    those columns gives per-position probabilities.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept so state_dict keys/ordering match saved checkpoints.
        self.embed = nn.Embedding(15, 4)
        self.fc = nn.Linear(100, 128 * 12 * 1)

        self.conv1 = nn.ConvTranspose2d(128, 64, kernel_size=(7, 1), stride=2,
                                        padding=(2, 0), bias=False)
        self.cbn1 = CBN(4, 128, 64)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.ConvTranspose2d(64, 32, kernel_size=(8, 1), stride=2,
                                        padding=(3, 0), bias=False)
        self.cbn2 = CBN(4, 128, 32)
        self.relu2 = nn.ReLU(inplace=True)

        self.conv3 = nn.ConvTranspose2d(32, 1, kernel_size=(16, 4), stride=2,
                                        padding=(7, 0), bias=False)
        self.softmax = nn.Softmax(dim=3)

    def forward(self, nz, c):
        cond = self.embed(c)
        out = self.conv1(self.fc(nz).view(-1, 128, 12, 1))
        out = self.relu1(self.cbn1(out, cond))
        out = self.relu2(self.cbn2(self.conv2(out), cond))
        return self.softmax(self.conv3(out))



In [None]:
## for the pytorch-sngan-dhs-100-sig-1dg generator
class snp_generator_1d(nn.Module):
    """Class-conditional generator mapping noise (N, 100) to (N, 1, 100, 4).

    Each of the 100 noise values is treated as one sequence position. A 1x1
    conv lifts it to `hidden` channels, CBN conditions on the class embedding,
    and a second conv projects to 4 channels, which are softmaxed per position.

    Args:
        hidden: intermediate channels. The default of 100 reproduces the
            architecture this cell was written for (sig-1dg).
        kernel_size: length of the output conv (default 11). Padding is
            kernel_size // 2, which keeps the length at 100 for odd sizes.
    """

    def __init__(self, hidden=100, kernel_size=11):
        super().__init__()

        self.embed = nn.Embedding(15, 4)

        self.conv1 = nn.Conv1d(1, hidden, 1, 1, bias=False)
        self.cbn1 = CBN(4, 128, hidden)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(hidden, 4, kernel_size, 1, kernel_size // 2, bias=False)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, nz, c):
        embedding = self.embed(c)
        h = self.conv1(nz.view(-1, 1, 100))          # (N, hidden, 100)
        h = self.cbn1(h.unsqueeze(3), embedding)     # CBN applied to an (N, C, 100, 1) map
        # squeeze(3) drops only the width axis. The previous bare squeeze() also
        # dropped the batch axis when N == 1, which broke single-sample calls.
        h = self.relu2(h.squeeze(3))
        h = self.conv2(h)
        output = self.softmax(h)                     # probabilities over the 4 channels
        return output.transpose(1, 2).unsqueeze(1)   # (N, 1, length, 4)

In [None]:
## for the pytorch-sngan-dhs-100-sig-1dg-2 generator
class snp_generator_1d(nn.Module):
    """Class-conditional generator mapping noise (N, 100) to (N, 1, 100, 4).

    This redefines snp_generator_1d from the previous cell with different
    defaults. Each noise value is one sequence position. A 1x1 conv lifts it
    to `hidden` channels, CBN conditions on the class embedding, and a final
    conv projects to 4 channels, which are softmaxed per position.

    Args:
        hidden: intermediate channels. The default of 200 reproduces the
            architecture this cell was written for (sig-1dg-2).
        kernel_size: length of the output conv (default 1). Padding is
            kernel_size // 2, which keeps the length at 100 for odd sizes.
    """

    def __init__(self, hidden=200, kernel_size=1):
        super().__init__()

        self.embed = nn.Embedding(15, 4)

        self.conv1 = nn.Conv1d(1, hidden, 1, 1, bias=False)
        self.cbn1 = CBN(4, 128, hidden)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(hidden, 4, kernel_size, 1, kernel_size // 2, bias=False)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, nz, c):
        embedding = self.embed(c)
        h = self.conv1(nz.view(-1, 1, 100))          # (N, hidden, 100)
        h = self.cbn1(h.unsqueeze(3), embedding)     # CBN applied to an (N, C, 100, 1) map
        # squeeze(3) drops only the width axis. The previous bare squeeze() also
        # dropped the batch axis when N == 1, which broke single-sample calls.
        h = self.relu2(h.squeeze(3))
        h = self.conv2(h)
        output = self.softmax(h)                     # probabilities over the 4 channels
        return output.transpose(1, 2).unsqueeze(1)   # (N, 1, length, 4)

In [None]:
class snp_generator_2(nn.Module):
    """Class-conditional generator; same architecture as snp_generator_2d defined earlier.

    noise (N, nz) -> fc -> (N, num_filters//2, 10, 1) -> ReLU -> CBN
    -> ConvTranspose (x10 in length) -> (N, num_filters, 100, 1) -> ReLU -> CBN
    -> ConvTranspose (len_filters, 4) -> (N, 1, L, 4) -> softmax over the last axis.

    L is 100 for odd `len_filters` and 99 for even ones (padding = len_filters // 2).

    Args:
        nz: noise dimension.
        ne: class-embedding size (15 classes).
        cbn_h: second argument passed to CBN (presumably its hidden width).
        num_filters: channels after the first upsampling.
        len_filters: kernel length of the final transposed conv.
        dropout, concat: accepted but unused.
    """

    def __init__(self, nz, ne, cbn_h, num_filters, len_filters, dropout=False, concat=False):
        super().__init__()
        half = num_filters // 2
        self.num_filters = num_filters

        # Registration order is kept so state_dict keys/ordering match saved checkpoints.
        self.embed = nn.Embedding(15, ne)
        self.fc = SpectralNorm(nn.Linear(nz, half * 10))
        self.relu1 = nn.ReLU(True)
        self.cbn1 = CBN(ne, cbn_h, half)
        self.up1 = SpectralNorm(nn.ConvTranspose2d(half, num_filters, (10, 1), 10, bias=False))
        self.relu2 = nn.ReLU(True)
        self.cbn2 = CBN(ne, cbn_h, num_filters)
        self.up2 = SpectralNorm(nn.ConvTranspose2d(
            num_filters, 1, (len_filters, 4), 1, (len_filters // 2, 0)))
        self.softmax = nn.Softmax(dim=3)

    def forward(self, nz, c):
        cond = self.embed(c)
        feat = self.fc(nz).view(-1, self.num_filters // 2, 10, 1)
        feat = self.relu1(feat)
        feat = self.cbn1(feat, cond)
        feat = self.relu2(self.up1(feat))
        feat = self.cbn2(feat, cond)
        return self.softmax(self.up2(feat))

In [None]:
# Shape check: x10 length upsampling of a (1, 160, 10, 1) map, then a (11, 4)
# transposed conv with padding (5, 0) down to a single channel.
x = nn.ConvTranspose2d(160, 320, kernel_size=(10, 1), stride=10)
y = nn.ConvTranspose2d(320, 1, kernel_size=(11, 4), stride=1, padding=(5, 0))
y(x(torch.rand(1, 160, 10, 1))).size()