In [1]:
import os
from sklearn.datasets import fetch_mldata
import random
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torch.nn as nn
import torch.nn.parallel
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import cv2
import numpy as np

# Make GPU 0 the default CUDA device for the argument-less .cuda() calls below.
torch.cuda.set_device(0)
device_ids = [0, 1]  # GPUs handed to nn.DataParallel when `parallel` is switched on

batchsize = 1  # images per DataLoader batch and noise vectors per fake batch
rand = 128     # width of the unstructured noise vector z
cont = 4       # number of continuous latent codes, sampled from U(-1, 1) during training
dis = 1        # number of 10-way categorical codes; the training loop only samples one

In [2]:
def get_mnist(data_home="/home/msragpu/cellwork/test_dataset/", n_train=60000):
    """Load MNIST as shuffled 3-channel 32x32 images scaled to [-1, 1].

    Each 28x28 digit is resized to 32x32, rescaled from [0, 255] to [-1, 1] and its single
    grey channel is replicated three times. The stack is then shuffled with a fixed seed.

    Args:
        data_home: directory ``fetch_mldata`` caches the dataset in.
        n_train: number of shuffled images to return; the rest (the held-out split) is dropped.

    Returns:
        float32 array of shape (n_train, 3, 32, 32).
    """
    mnist = fetch_mldata('MNIST original', data_home=data_home)
    np.random.seed(1234)  # deterministic shuffle; note this reseeds the global NumPy RNG
    n_images = mnist.data.shape[0]  # 70000 for 'MNIST original'
    X = mnist.data.reshape((n_images, 28, 28))

    X = np.asarray([cv2.resize(x, (32, 32)) for x in X])
    X = X.astype(np.float32) / (255.0 / 2) - 1.0

    # Add a channel axis and replicate grey -> 3 channels to match the colour model.
    X = np.tile(X[:, np.newaxis, :, :], (1, 3, 1, 1))

    X = X[np.random.permutation(n_images)]
    return X[:n_train]

def visual(X):
    """Convert one image from (1, C, H, W) in [-1, 1] to an (H, W, C) uint8 array.

    Raises:
        AssertionError: if ``X`` is not 4-D.
        ValueError: (from ``reshape``) if the batch holds more than one image.
    """
    assert len(X.shape) == 4
    X = X.transpose((0, 2, 3, 1))   # NCHW -> NHWC
    X = (X + 1.0) * (255.0 / 2.0)   # [-1, 1] -> [0, 255]
    # Drop the batch axis; this only succeeds when the batch size is 1.
    X = X.reshape(X.shape[1], X.shape[2], X.shape[3])
    # Clip before casting: float -> uint8 conversion of out-of-range values wraps or is
    # undefined, turning saturated pixels into garbage.
    return np.uint8(np.clip(X, 0, 255))

def fill_buf(buf, i, img, shape):
    """Paste tile ``img`` into slot ``i`` of the mosaic ``buf`` in place.

    Slots are numbered row-major: left to right, then top to bottom.

    Args:
        buf: (H, W, C) mosaic array, modified in place.
        i: zero-based slot index.
        img: tile of shape (shape[1], shape[0], C).
        shape: (tile_width, tile_height) -- note the x-size comes first.
    """
    # Tiles per mosaic row. Floor division keeps the offsets integral on Python 3, where
    # '/' would produce floats and break the slicing below.
    m = buf.shape[1] // shape[0]

    sx = (i % m) * shape[0]
    sy = (i // m) * shape[1]
    buf[sy:sy + shape[1], sx:sx + shape[0], :] = img


In [3]:
def bonemarrow_cell(path="/home/msragpu/cellwork/data/data.npy"):
    """Load the bone-marrow cell images as a (N, 3, 32, 32) float32 array in [-1, 1].

    Each image is resized to 32x32 and its channel axis is reversed (presumably OpenCV
    BGR -> RGB). It is then moved to channels-first layout and rescaled from [0, 255]
    (assuming 8-bit input) to [-1, 1].

    Args:
        path: ``.npy`` file holding the image stack; assumed (N, H, W, 3) -- TODO confirm.
    """
    X = np.load(path)
    X = np.asarray([cv2.resize(x, (32, 32)) for x in X])
    X = np.asarray([x[:, :, ::-1].transpose((2, 0, 1)) for x in X])
    X = X.astype(np.float32) / (255.0 / 2) - 1.0
    return X


X = bonemarrow_cell()
print (X.shape)

(3841, 3, 32, 32)


In [4]:
import cv2

def test(path='./111.jpg'):
    """Load one image from disk as a (1, 3, H, W) float32 batch scaled to [-1, 1].

    Applies the same channel flip / channels-first / rescaling steps as the training
    data, but without resizing.

    Raises:
        IOError: if OpenCV cannot read ``path``. ``cv2.imread`` returns None instead of raising.
    """
    _X = cv2.imread(path, 3)  # flag 3 == cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH
    if _X is None:
        # Without this check np.float32(None) silently becomes NaN and the reshape
        # below fails with a confusing error.
        raise IOError('could not read image: {0}'.format(path))
    X = np.float32(_X)
    print (X.dtype)
    X = X.reshape(1, X.shape[0], X.shape[1], X.shape[2])
    X = np.asarray([x[:, :, ::-1].transpose((2, 0, 1)) for x in X])
    X = X.astype(np.float32) / (255.0 / 2) - 1.0
    return X


X = test()
print (X.shape)

float32
(1, 3, 320, 320)


In [5]:
# Wrap the cell crops in a shuffling DataLoader. Training is unsupervised, so every
# sample carries the same dummy label 0.
X_train = torch.FloatTensor(bonemarrow_cell())
X_label = torch.zeros(X_train.size(0)).long()
train = torch.utils.data.TensorDataset(X_train, X_label)
train_loader = torch.utils.data.DataLoader(train, batch_size=batchsize, shuffle=True)

dataiter = iter(train_loader)  # recreated at the start of every training epoch

In [6]:
import torch.nn.parallel

class _netG(nn.Module):
    """DCGAN-style generator: noise (N, nz, 1, 1) -> images (N, nc, isize, isize) in [-1, 1].

    A stride-1 transposed conv first expands the 1x1 input to 4x4 at the widest channel
    count. Each pyramid stage doubles the spatial size and halves the channels, and the
    output layer doubles the size once more. A final Tanh bounds the values.

    Args:
        isize: output image side; must be a multiple of 16 and a power of two.
        nz: number of input (noise) channels.
        nc: number of output image channels.
        ngf: channel count of the last pyramid stage (earlier stages are wider).
        n_extra_layers: number of extra 3x3 conv/BatchNorm/ReLU blocks before the output layer.

    Raises:
        AssertionError: if ``isize`` violates the constraints above.
    """

    def __init__(self, isize = 32, nz = 149, nc = 3, ngf = 64, n_extra_layers=0):
        super(_netG, self).__init__()
        assert isize % 16 == 0, "isize has to be a multiple of 16"
        # The width search below only terminates when isize == 4 * 2**k; reject e.g. 48
        # up front instead of looping forever.
        assert isize & (isize - 1) == 0, "isize has to be a power of two"

        # Channel count of the 4x4 stage: ngf doubled once per upsampling step up to isize.
        cngf, tisize = ngf//2, 4
        while tisize != isize:
            cngf = cngf * 2
            tisize = tisize * 2

        main = nn.Sequential()
        # input is Z, going into a convolution
        main.add_module('initial.{0}-{1}.convt'.format(nz, cngf),
                        nn.ConvTranspose2d(nz, cngf, 4, 1, 0, bias=False))
        main.add_module('initial.{0}.batchnorm'.format(cngf),
                        nn.BatchNorm2d(cngf))
        main.add_module('initial.{0}.relu'.format(cngf),
                        nn.ReLU(True))

        # Upsample x2 per stage until one step short of isize (the output layer does the last one).
        csize = 4
        while csize < isize//2:
            main.add_module('pyramid.{0}-{1}.convt'.format(cngf, cngf//2),
                            nn.ConvTranspose2d(cngf, cngf//2, 4, 2, 1, bias=False))
            main.add_module('pyramid.{0}.batchnorm'.format(cngf//2),
                            nn.BatchNorm2d(cngf//2))
            main.add_module('pyramid.{0}.relu'.format(cngf//2),
                            nn.ReLU(True))
            cngf = cngf // 2
            csize = csize * 2

        # Extra layers (resolution-preserving 3x3 convs)
        for t in range(n_extra_layers):
            main.add_module('extra-layers-{0}.{1}.conv'.format(t, cngf),
                            nn.Conv2d(cngf, cngf, 3, 1, 1, bias=False))
            main.add_module('extra-layers-{0}.{1}.batchnorm'.format(t, cngf),
                            nn.BatchNorm2d(cngf))
            main.add_module('extra-layers-{0}.{1}.relu'.format(t, cngf),
                            nn.ReLU(True))

        main.add_module('final.{0}-{1}.convt'.format(cngf, nc),
                        nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False))
        main.add_module('final.{0}.tanh'.format(nc),
                        nn.Tanh())
        self.main = main

    def forward(self, input):
        return self.main(input)
    
# Noise width = z (rand) + one-hot categorical code(s) (10 * dis) + continuous codes (cont).
netG = _netG(nz=rand + 10 * dis + cont)
print(netG)


""" ==================== DISCRIMINATOR  ======================== """

class _netD(nn.Module):
    """DCGAN-style discriminator trunk (the final scoring conv is not part of it).

    Maps images (N, nc, isize, isize) to a feature map. A first stride-2 conv halves the
    resolution, and optional 3x3 "extra" blocks keep it. Stride-2 pyramid blocks then keep
    halving it (doubling the channels) while it is larger than 4. For isize=32 and ndf=64
    the output is (N, 256, 4, 4), matching the 256-channel input of the separate heads.

    Args:
        isize: input image side; must be a multiple of 16.
        nz: accepted for symmetry with _netG; not used.
        nc: input image channels.
        ndf: channels produced by the first conv.
        n_extra_layers: number of extra 3x3 conv/BatchNorm/LeakyReLU blocks.
    """

    def __init__(self, isize = 32, nz = 149, nc = 3, ndf = 64, n_extra_layers=0):
        super(_netD, self).__init__()
        assert isize % 16 == 0, "isize has to be a multiple of 16"

        layers = nn.Sequential()
        add = layers.add_module  # module names and order define the state_dict keys

        # input is nc x isize x isize
        add('initial.conv.{0}-{1}'.format(nc, ndf),
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False))
        add('initial.relu.{0}'.format(ndf),
            nn.LeakyReLU(0.2, inplace=True))
        width = ndf
        # Plain '/' on purpose: switching to '//' changes the pyramid depth for some isize on Python 3.
        size = isize / 2

        for k in range(n_extra_layers):
            add('extra-layers-{0}.{1}.conv'.format(k, width),
                nn.Conv2d(width, width, 3, 1, 1, bias=False))
            add('extra-layers-{0}.{1}.batchnorm'.format(k, width),
                nn.BatchNorm2d(width))
            add('extra-layers-{0}.{1}.relu'.format(k, width),
                nn.LeakyReLU(0.2, inplace=True))

        while size > 4:
            doubled = 2 * width
            add('pyramid.{0}-{1}.conv'.format(width, doubled),
                nn.Conv2d(width, doubled, 4, 2, 1, bias=False))
            add('pyramid.{0}.batchnorm'.format(doubled),
                nn.BatchNorm2d(doubled))
            add('pyramid.{0}.relu'.format(doubled),
                nn.LeakyReLU(0.2, inplace=True))
            width, size = doubled, size / 2

        self.main = layers

    def forward(self, input):
        return self.main(input)


netD = _netD(nz=rand + 10 * dis + cont)  # nz is accepted but not used by _netD
print(netD)

_netG (
  (main): Sequential (
    (initial.142-256.convt): ConvTranspose2d(142, 256, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (initial.256.batchnorm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (initial.256.relu): ReLU (inplace)
    (pyramid.256-128.convt): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid.128.batchnorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (pyramid.128.relu): ReLU (inplace)
    (pyramid.128-64.convt): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid.64.batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (pyramid.64.relu): ReLU (inplace)
    (final.64-3.convt): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (final.3.tanh): Tanh ()
  )
)
_netD (
  (main): Sequential (
    (initial.conv.3-64): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bi

In [7]:
class _netD_D(nn.Module):
    """Real/fake critic head: one 4x4 conv from 256 channels down to a single score.

    For a (N, 256, 4, 4) trunk output the result has shape (N, 1, 1, 1).
    """

    def __init__(self):
        super(_netD_D, self).__init__()
        self.conv = nn.Conv2d(256, 1, 4, 1, 0, bias=False)

    def forward(self, x):
        return self.conv(x)
    
class _netD_Q(nn.Module):
    """Categorical-code head: log-probabilities over ``nd`` classes.

    For a (N, 256, 4, 4) trunk output the result has shape (N, nd, 1, 1).

    Args:
        nd: number of categories. This used to be ignored (10 was hard-coded); the default
            keeps 10, so existing callers are unaffected.
    """

    def __init__(self, nd = 10):
        super(_netD_Q, self).__init__()
        self.conv = nn.Conv2d(256, nd, 4, 1, 0, bias=False)
        # No explicit `dim`, to stay compatible with the PyTorch version this notebook targets.
        # For 4-D input PyTorch's implicit choice is the channel axis (dim=1) -- verify on upgrade.
        self.softmax = nn.LogSoftmax()

    def forward(self, x):
        x = self.conv(x)
        x = self.softmax(x)
        return x
# NOTE(review): dead code. This second categorical head is wrapped in a string literal,
# so it is never executed and neither `_netD_Q_2` nor any instance of it is defined.
'''   
class _netD_Q_2(nn.Module):
    def __init__(self):
        super(_netD_Q_2, self).__init__()
        # input is Z, going into a convolution
        self.conv = nn.Conv2d(256, 10, 4, 1, 0, bias=False)
        self.softmax = nn.LogSoftmax()
        
    def forward(self, x):
        x = self.conv(x)
        x = self.softmax(x)
      #  x = x.view(64,10)
        return x
''' 

class _netD_Q_3(nn.Module):
    """Continuous-code head: regresses ``nc`` latent values from the 256-channel trunk output.

    Returns raw, unbounded values of shape (N, nc, 1, 1) for a (N, 256, 4, 4) input.
    """

    def __init__(self, nc = 4):
        super(_netD_Q_3, self).__init__()
        self.conv = nn.Conv2d(256, nc, 4, 1, 0, bias=False)

    def forward(self, x):
        return self.conv(x)

    
# Heads sharing the discriminator trunk netD.
netD_D = _netD_D()               # real/fake critic score
netD_Q = _netD_Q(nd=10 * dis)    # categorical-code log-probabilities
netD_Q_3 = _netD_Q_3(nc=cont)    # continuous-code regression

In [8]:
parallel = False

# Move every network to the GPU, optionally wrapping it for multi-GPU data parallelism.
# The unpacking order must match the tuple order below.
if parallel:
    netD, netG, netD_D, netD_Q, netD_Q_3 = [
        torch.nn.DataParallel(net.cuda(), device_ids=device_ids)
        for net in (netD, netG, netD_D, netD_Q, netD_Q_3)
    ]
else:
    netD, netG, netD_D, netD_Q, netD_Q_3 = [
        net.cuda() for net in (netD, netG, netD_D, netD_Q, netD_Q_3)
    ]
# In[35]:

def weights_init(m):
    """DCGAN-style initialisation for one module; use as ``net.apply(weights_init)``.

    Modules whose class name contains 'Conv' get weights ~ N(0, 0.02). Modules whose class
    name contains 'BatchNorm' get scale ~ N(1, 0.02) and shift = 0. Other modules, and
    missing affine parameters, are left untouched.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        # affine=False batch norms have weight/bias set to None; skip instead of crashing.
        if m.weight is not None:
            m.weight.data.normal_(1.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)

for _net in (netG, netD, netD_Q, netD_Q_3):
    _net.apply(weights_init)
netD_D.apply(weights_init)  # kept last: its return value is the cell's displayed output

_netD_D (
  (conv): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
)

In [9]:
# Critic step: shared trunk + real/fake head.
optimizerD = optim.RMSprop(
    [{'params': netD.parameters()}, {'params': netD_D.parameters()}],
    lr=0.00005)

# Adversarial generator step.
optimizerG = optim.RMSprop(netG.parameters(), lr=0.00005)

# Mutual-information step: generator, shared trunk and both code heads.
optimizerQ = optim.RMSprop(
    [{'params': net.parameters()} for net in (netG, netD, netD_Q, netD_Q_3)],
    lr=0.00005)

In [10]:
# GPU buffers, refilled in place during training. The random draws keep their original
# order (numpy multinomial, torch c, c2, numpy c3, torch z) so the RNG streams are unchanged.
input = torch.FloatTensor(batchsize, 3, 32, 32).cuda()
noise = torch.FloatTensor(batchsize, rand + 10 * dis + cont, 1, 1).cuda()

fixed_noise = torch.FloatTensor(np.random.multinomial(batchsize, 10 * [0.1], size=1)).cuda()
c = torch.randn(batchsize, 10).cuda()       # categorical code buffer
c2 = torch.randn(batchsize, 10).cuda()      # not used by the training loop
c3 = torch.FloatTensor(np.random.uniform(-1, 1, (batchsize, cont))).cuda()  # continuous codes
z = torch.randn(batchsize, rand).cuda()     # unstructured noise

# Leftovers from a BCE-based GAN loss; the WGAN loop does not use them.
label = torch.FloatTensor(1).cuda()
real_label, fake_label = 1, 0

criterion = nn.BCELoss().cuda()
criterion_logli = nn.NLLLoss().cuda()   # categorical-code loss (expects log-probabilities)
criterion_mse = nn.MSELoss().cuda()     # continuous-code loss

In [11]:
def sample_c(batchsize):
    """Draw ``batchsize`` uniformly random one-hot 10-way codes.

    Returns:
        (codes, labels): a (batchsize, 10) float32 one-hot tensor and the matching
        (batchsize,) LongTensor of class indices.
    """
    one_hot = np.zeros((batchsize, 10), dtype='float32')
    for row in range(batchsize):
        # One multinomial draw per row, exactly as many NumPy RNG calls as rows.
        one_hot[row] = np.random.multinomial(1, [0.1] * 10, size=1)

    labels = torch.LongTensor(np.argmax(one_hot, axis=1).astype('int'))
    codes = torch.from_numpy(one_hot.astype('float32'))
    return codes, labels

def zero_grad():
    """Clear the accumulated gradients of the trunk, all heads and the generator."""
    for net in (netD, netD_Q, netD_Q_3, netD_D, netG):
        net.zero_grad()

def weight_clamp():
    """WGAN weight clipping: clamp critic parameters to [-0.01, 0.01] in place.

    Only the shared trunk ``netD`` and the critic head ``netD_D`` are clipped; the
    Q heads are left unconstrained.
    """
    for net in (netD, netD_D):
        for param in net.parameters():
            param.data.clamp_(-0.01, 0.01)
        
def generate_fix_noise(dis=1, cont=4, rand=128):
    """Build 10x10 latent grids for visualising what each latent code controls.

    Each grid holds 100 noise vectors laid out as 10 rows x 10 columns (as in
    ``save_image(..., nrow=10)``). All vectors in one grid row share the same random z,
    so only the codes vary along a row. Vector layout is
    [10 categorical | cont continuous | rand z].

    Args:
        dis: accepted for interface compatibility; exactly one 10-way code is generated.
        cont: number of continuous codes (one sweep grid is produced per code).
        rand: width of z.

    Returns:
        list of ``1 + cont`` float32 arrays of shape (100, 10 + cont + rand):
        [0]   categorical sweep -- grid column j uses category j, continuous codes are 0.
        [1+t] sweep of continuous code t over (n - 5) * 2 / 5 for column n (-2.0 .. 1.6);
              grid rows 2k and 2k+1 use category 2k.
        All grids share the same z values.
    """
    n_cells = 100
    fixed_z = np.random.randn(10, rand).repeat(10, axis=0)  # one z per grid row
    columns = np.arange(n_cells) % 10                        # grid column of each vector

    changing_dis = np.zeros((n_cells, 10), dtype=np.float32)
    changing_dis[np.arange(n_cells), columns] = 1
    no_cont = np.zeros((n_cells, cont), dtype=np.float32)
    maps = [np.concatenate((changing_dis, no_cont, fixed_z), axis=1).astype(np.float32)]

    # NOTE(review): asymmetric and wider than the U(-1, 1) range the codes are trained on
    # -- confirm intended.
    single_cont = ((columns - 5) * 2 / 5.0).astype(np.float32)

    fixed_dis = np.zeros((n_cells, 10), dtype=np.float32)
    for t in range(5):
        fixed_dis[t * 20:t * 20 + 20, t * 2] = 1

    # One grid per continuous code (previously a hard-coded 4, which broke for cont != 4).
    for t in range(cont):
        fixed_cont = np.zeros((n_cells, cont), dtype=np.float32)
        fixed_cont[:, t] = single_cont
        maps.append(np.concatenate((fixed_dis, fixed_cont, fixed_z), axis=1).astype(np.float32))

    return maps

# Gradient seeds for backward(): +1 / -1 set the sign with which each critic score
# enters the accumulated gradients.
one = torch.FloatTensor([1])
mone = -1 * one
one, mone = one.cuda(), mone.cuda()

In [12]:
gen_iterations = 0

for epoch in range(100000):

    dataiter = iter(train_loader)
    i = 0

    while i < len(train_loader):
        weight_clamp()
        # Critic updates per generator update. The usual WGAN schedule (more critic steps
        # early on / every 500 iterations) is effectively disabled: both branches were 1.
        Diters = 1

        # ---- critic step (netD + netD_D) ----
        j = 0
        while j < Diters and i < len(train_loader):
            j += 1
            image_, _ = next(dataiter)  # builtin next() works on Python 2 and 3 iterators
            image_ = image_.cuda()

            i += 1
            weight_clamp()

            # real data
            zero_grad()
            input.resize_as_(image_).copy_(image_)
            inputv = Variable(input)
            D_real = netD_D(netD(inputv)).mean(0).view(1)
            D_real.backward(one)

            # generated data; noise = [categorical c | continuous c3 | z]
            z.normal_(0, 1)
            rand_c, label_c = sample_c(batchsize)
            c.copy_(rand_c)
            c3.uniform_(-1, 1)
            noise = torch.cat([c, c3, z], 1)
            noise_resize = noise.view(batchsize, rand + 10 * dis + cont, 1, 1)
            noisev = Variable(noise_resize)

            # Re-wrap .data so this backward pass does not reach the generator.
            G_sample = Variable(netG(noisev).data)
            inputv = G_sample
            D_fake = netD_D(netD(inputv)).mean(0).view(1)
            D_fake.backward(mone)

            optimizerD.step()

        # ---- generator step ----
        for p in netD.parameters():
            p.requires_grad = False  # trunk/critic are not updated here; skip their grads
        for p in netD_D.parameters():
            p.requires_grad = False

        zero_grad()
        noisev = Variable(noise_resize)  # reuse the noise of the last critic step
        G_sample = netG(noisev)
        D_fake = netD_D(netD(G_sample)).mean(0).view(1)
        D_fake.backward(one)
        optimizerG.step()

        gen_iterations += 1

        for p in netD.parameters():
            p.requires_grad = True  # re-enable for the Q step and the next critic step
        for p in netD_D.parameters():
            p.requires_grad = True

        # ---- mutual-information (Q) step ----
        zero_grad()
        noisev = Variable(noise_resize)
        G_sample = netG(noisev)
        Q_c_given_x = netD_Q(netD(G_sample)).view(batchsize, 10)
        # Flatten (N, cont, 1, 1) -> (N, cont) so the MSE compares element-wise with c3
        # instead of broadcasting the two shapes against each other.
        Q_c_given_x_3 = netD_Q_3(netD(G_sample)).view(batchsize, cont)

        crossent_loss = criterion_logli(Q_c_given_x, Variable(label_c.cuda()))
        crossent_loss_3 = criterion_mse(Q_c_given_x_3, Variable(c3))

        mi_loss = 0.1 * crossent_loss + 1 * crossent_loss_3

        mi_loss.backward()
        optimizerQ.step()

        if gen_iterations % 20 == 0:
            # NOTE: D_fake at this point is the generator-step value, not the critic-step one.
            errD = D_real - D_fake
            # NOTE: mode "w" overwrites the file, so it only holds the latest status line.
            with open("output_cell.txt", "w") as f:
                f.write('{0} {1} {2} {3}'.format(epoch, gen_iterations, -errD.data[0], mi_loss.data[0]) + '\n')

            vutils.save_image(G_sample.data, 'fake_samples.png', normalize=True)

            # Build the latent grids once per report so every map shares the same z, and index
            # them with the map counter `t` (the data counter `i` quickly exceeds the map count).
            fixed_maps = generate_fix_noise(dis, cont, rand)
            for t in range(len(fixed_maps)):
                fixed_noise = fixed_maps[t].reshape(100, rand + dis * 10 + cont, 1, 1)
                G_sample = netG(Variable(torch.FloatTensor(fixed_noise).cuda()))
                vutils.save_image(G_sample.data, 'map_%d_cell.png' % (t), nrow=10, normalize=True)

0
3841
1
3840
2
3839
3
3838
4
3837
5
3836
6
3835
7
3834
8
3833
9
3832
10
3831
11
3830
12
3829
13
3828
14
3827
15
3826
16
3825
17
3824
18
3823
19
3822
3
3821
4
3820
5
3819
6
3818
7
3817
8
3816
9
3815
10
3814
11
3813
12
3812
13
3811
14
3810
15
3809
16
3808
17
3807
18
3806
19
3805
20
3804
21
3803
22
3802
3
3801
4
3800
5
3799
6
3798
7
3797
8
3796
9
3795
10
3794
11
3793
12
3792
13
3791
14
3790
15
3789
16
3788
17
3787
18
3786
19
3785
20
3784
21
3783
22
3782
3
3781
4
3780
5
3779
6
3778
7
3777
8
3776
9
3775
10
3774
11
3773
12
3772
13
3771
14
3770
15
3769
16
3768
17
3767
18
3766
19
3765
20
3764
21
3763
22
3762
3
3761
4
3760
5
3759
6
3758
7
3757
8
3756
9
3755
10
3754
11
3753
12
3752
13
3751
14
3750
15
3749
16
3748
17
3747
18
3746
19
3745
20
3744
21
3743
22
3742
3
3741
4
3740
5
3739
6
3738
7
3737
8
3736
9
3735
10
3734
11
3733
12
3732
13
3731
14
3730
15
3729
16
3728
17
3727
18
3726
19
3725
20
3724
21
3723
22
3722
3
3721
4
3720
5
3719
6
3718
7
3717
8
3716
9
3715
10
3714
11
3713
12
3712
13
3711
14
3

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/ultratb.py", line 1132, in get_records
    return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/ultratb.py", line 313, in wrapped
    return f(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/ultratb.py", line 358, in _fixed_getinnerframes
    records = fix_frame_records_filenames(inspect.getinnerframes(etb, context))
  File "/usr/lib/python2.7/inspect.py", line 1049, in getinnerframes
    framelist.append((tb.tb_frame,) + getframeinfo(tb, context))
  File "/usr/lib/python2.7/inspect.py", line 1009, in getframeinfo
    filename = getsourcefile(frame) or getfile(frame)
  File "/usr/lib/python2.7/inspect.py", line 454, in getsourcefile
    if hasattr(getmodule(object, filename), '__loader__'):
  File "/usr/lib/python2.7/inspect.py", line 500, in getmodule
    os.path.realpath(f)] = module

IndexError: string index out of range

In [None]:
# Run the generator on the latest latent grid built by the training loop. netG lives on the
# GPU, so the float32 NumPy grid is copied to a CUDA tensor first, as the training loop does.
netG(Variable(torch.FloatTensor(fixed_noise).cuda()))

In [None]:
# Inspect how many samples the current epoch's DataLoader iterator still had left.
# NOTE(review): `samples_remaining` looks like an internal attribute of the iterator in the
# PyTorch version this was run with -- confirm it exists before relying on it.
dataiter.samples_remaining

In [None]:
# Checkpoint every network that exists. NOTE(review): absolute output directory -- adjust per machine.
# netD_Q_2's class is disabled, so it is not saved; netD_Q_3 (the continuous-code head) is.
ckpt_dir = '/data/params'
if not os.path.isdir(ckpt_dir):
    os.makedirs(ckpt_dir)
for ckpt_name, ckpt_net in [('netG', netG), ('netD', netD), ('netD_D', netD_D),
                            ('netD_Q', netD_Q), ('netD_Q_3', netD_Q_3)]:
    torch.save(ckpt_net.state_dict(),
               os.path.join(ckpt_dir, '%s_0524_epoch_%d.pth' % (ckpt_name, epoch)))