### Build and Train 3DGAN


In [1]:
import os
import torch
from torch import optim
from torch import nn
from utils import utils3D
from torch.utils import data
from torch.autograd import Variable
from models.threed.gan import GAN
import matplotlib
import pickle
# Select matplotlib's non-interactive Agg backend so figures can be rendered
# without a display (e.g. on a headless training machine).
matplotlib.use('agg')

#### Hyperparameters: latent vector size, learning rates of G and D, number of epochs, Adam beta parameters, batch size and voxel cube size

In [2]:
# Training hyperparameters consumed by the GAN constructor in the next cell.
Z_LATENT_SPACE = 200        # length of the latent vector z (latent_v)
G_LR = 2.5e-3               # Adam learning rate of the generator
D_LR = 1e-3                 # Adam learning rate of the discriminator
EPOCHS = 1                  # number of training epochs
# Adam (beta1, beta2). NOTE(review): beta2 = 0.5 is far below Adam's usual
# 0.999 -- confirm this is intended.
BETA = (0.5,) * 2
BSIZE = 32                  # mini-batch size
CUBE_LEN = 64               # cube_len passed to GAN (voxel grid side length)

#### Define model

In [3]:
# Build the 3D-GAN wrapper from the hyperparameters defined above; the cell
# output below shows the architecture it prints on construction.
gan3D = GAN(
    epochs=EPOCHS,
    sample=8,                  # presumably number of samples to render -- confirm in GAN
    batch=BSIZE,
    betas=BETA,
    g_lr=G_LR,
    d_lr=D_LR,
    cube_len=CUBE_LEN,
    latent_v=Z_LATENT_SPACE,
)

---------- Networks architecture -------------
_G(
  (layer1): Sequential(
    (0): ConvTranspose3d(200, 512, kernel_size=(4, 4, 4), stride=(2, 2, 2), bias=False)
    (1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layer2): Sequential(
    (0): ConvTranspose3d(512, 256, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
    (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layer3): Sequential(
    (0): ConvTranspose3d(256, 128, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
    (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layer4): Sequential(
    (0): ConvTranspose3d(128, 64, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (lay

    Found GPU0 GeForce GTX 860M which is of cuda capability 5.0.
    PyTorch no longer supports this GPU because it is too old.
    


In [None]:
# Run the training loop configured by the GAN constructor above, then report
# completion.
gan3D.train()
print("Training finished!")

### Load and plot trained nets

In [None]:
import torch
from models.threed.generator import _G
from models.threed.discriminator import _D
from utils import utils3D
import skimage.measure as sk
import visdom
import trimesh

In [None]:
# Visdom client used by the plotting cell below; presumably requires a
# running visdom server -- confirm before executing.
vis = visdom.Visdom()

# Fall back to the CPU when no CUDA device is available instead of failing on
# a hard-coded `.cuda()`; map_location lets GPU-saved checkpoints be loaded on
# whichever device is selected.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
D = _D().to(device)
G = _G().to(device)
G.load_state_dict(torch.load('output/gan_tmp/3DGAN_100epochs_G.pkl', map_location=device))
D.load_state_dict(torch.load('output/gan_tmp/3DGAN_100epochs_D.pkl', map_location=device))

# Inference only: G contains BatchNorm3d layers (see the architecture
# printout above). eval() makes them use the running statistics loaded from
# the checkpoint rather than per-batch statistics, and stops sampling forward
# passes from overwriting those statistics.
G.eval()
D.eval()

In [None]:
# Latent size must match the generator's input channels (200 in the
# ConvTranspose3d(200, 512, ...) layer printed above).
Z_LATENT_SPACE = 200
BATCH_SIZE = 32
N_PLOT = 10  # how many of the generated volumes to send to visdom

# Pure sampling: no_grad avoids building an autograd graph (and holding its
# activations) for the whole batch.
with torch.no_grad():
    Z = utils3D.var_or_cuda(torch.randn(BATCH_SIZE, Z_LATENT_SPACE))
    fake = G(Z)
samples = fake.detach().cpu()[:N_PLOT].squeeze().numpy()
for s, sample in enumerate(samples):
    utils3D.plotVoxelVisdom(str(s), sample, vis, "3D vessels")


#### Save generated vessels as STL meshes

In [None]:
MESH_INDEX = 4  # which generated volume to export

# marching_cubes_classic was removed in newer scikit-image releases; fall back
# to marching_cubes with the classic Lorensen algorithm when it is absent.
# level=0.5 presumably thresholds per-voxel occupancy -- confirm G's output range.
if hasattr(sk, 'marching_cubes_classic'):
    v, f = sk.marching_cubes_classic(samples[MESH_INDEX], level=0.5)
else:
    v, f = sk.marching_cubes(samples[MESH_INDEX], level=0.5, method='lorensen')[:2]

# Build a triangle mesh from the iso-surface vertices/faces and write it as STL.
sample_mesh = trimesh.Trimesh(v, f)
sample_mesh.export('/tmp/test.stl')

In [None]:
# marching_cubes_classic was removed in newer scikit-image releases; fall back
# to marching_cubes with the classic Lorensen algorithm when it is absent.
if hasattr(sk, 'marching_cubes_classic'):
    v, f = sk.marching_cubes_classic(samples[0], level=0.5)
else:
    v, f = sk.marching_cubes(samples[0], level=0.5, method='lorensen')[:2]