# MolShapeGAN

## Generative multiscale analysis of de novo proteome-inspired molecular structures and nanomechanical optimization using a VoxelPerceiver transformer model

Zhenze Yang, Yu-Chuan Hsu, Markus J. Buehler, "Generative multiscale analysis of de novo proteome-inspired molecular structures and nanomechanical optimization using a VoxelPerceiver transformer model," Journal of the Mechanics and Physics of Solids, Volume 170, January 2023, 105098.

mbuehler@MIT.EDU

In [None]:
import torch
from torch import optim
from torch import nn
from collections import OrderedDict
from utils import *
import os
 
import datetime

import matplotlib.pyplot as plt
import numpy as np

import visdom

print("Torch version:", torch.__version__) 

In [None]:
# Flip to True to force CPU execution even when a CUDA device is available.
CPUonly = False

# Prefer the GPU whenever PyTorch reports one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

if CPUonly:
    print ("CPU!")
    device = torch.device("cpu")

In [None]:
import numpy as np
import os
import trimesh  
import time, warnings
from PIL import Image
import glob
import matplotlib.pyplot as plt
import cv2
        
def img2vox(loc, end='png', thresh=127, im_sh=False):
    """Load a folder of image slices into a binarised voxel stack.

    Files matching ``<loc>/*.<end>`` are ordered by (path length, path) so
    that unpadded numeric names sort naturally (``2.png`` before ``10.png``).
    Each image is converted to 8-bit greyscale and thresholded with OpenCV:
    pixels above ``thresh`` become 255, all others 0.

    Args:
        loc: Directory containing the slice images.
        end: File extension of the slices, without the dot.
        thresh: Grey-level threshold passed to ``cv2.threshold``.
        im_sh: When true, display every thresholded slice with matplotlib.

    Returns:
        ``np.ndarray`` with the thresholded slices stacked along axis 0
        (an empty array when no file matched).
    """
    imgs = sorted(glob.glob(loc+'/*.'+end), key=lambda x: (len(x), x))
    print('found',len(imgs),'images.',imgs)

    vox = []
    for image_path in imgs:
        # The context manager releases the file handle as soon as the pixel
        # data has been copied out; a bare Image.open() keeps it open until
        # garbage collection, which can exhaust descriptors on large stacks.
        with Image.open(image_path) as handle:
            grey = np.array(handle.convert('RGB').convert('L'))

        _, binary = cv2.threshold(grey, thresh, 255, cv2.THRESH_BINARY)

        if im_sh:
            plt.imshow(binary, interpolation='nearest', cmap="hot")
            plt.colorbar()
            plt.show()

        vox.append(binary)

    return np.array(vox)

def TwoDimg23Dvox(img_,maxheight=10, invvv=False, thresh=0, normalize=False,norm_zero=False,clipvalue=0,sat=1,GaussSmoothimage=False, iblur=0, BilatSmoothimage=False, centerrep=0, darea=0, invertresult=False):
   """Turn a sequence of 2D intensity images into one stacked voxel volume.

   Each image is treated as a height map: after optional smoothing,
   thresholding, normalisation and clean-up it is sliced at ``maxheight``
   evenly spaced grey levels, and every slice (a 0/255 mask) is appended to
   a single output list shared by all input images.

   Args:
       img_: Indexable sequence of 2D images (``img_[i]`` is one image).
           NOTE(review): presumably uint8 greyscale arrays, as required by
           the OpenCV filters used below - confirm against callers.
       maxheight: Number of slices generated per ramp.
       invvv: When true, prepend a ramp with descending thresholds so the
           result is mirrored about the centre.
       thresh: Pixels below this value are set to 0 before normalisation.
       normalize: Rescale the image so that its maximum becomes 255.
       norm_zero: With ``normalize``, first subtract the smallest positive
           value (minus ``clipvalue``), clip to [0, 255] and show the result.
       clipvalue: Offset subtracted from the minimum used by ``norm_zero``.
       sat: Gain applied after normalisation; values are clamped at 255.
       GaussSmoothimage: Apply a 5x5 Gaussian blur, then ``iblur`` extra
           3x3 passes.
       iblur: Number of additional smoothing passes.
       BilatSmoothimage: Apply a bilateral filter, then ``iblur`` extra passes.
       centerrep: Number of completely filled (all 255) slices inserted
           between the descending and ascending ramps.
       darea: When > 0, contours with an area below this value are painted
           black in the image.
       invertresult: Bitwise-invert the processed image before slicing.

   Returns:
       ``np.ndarray`` built from the list of all generated slices.
   """

   vox = []
   print ("Number of images: ", len (img_[:]))  
   for imc in range (len(img_[:])):
    img=img_[imc]
    print ("Considering image: ", imc)
    if GaussSmoothimage==True:
        print ("smoothen using Gaussian Blur...")
        # Base 5x5 Gaussian blur, followed by `iblur` extra 3x3 passes
        img = cv2.GaussianBlur(img,(5,5),0)
        
        for ij in range (iblur):
                img = cv2.GaussianBlur(img, (3,3), 10,10)
    if BilatSmoothimage==True:
        print ("smoothen using Bilat Blur...")
        # Edge-preserving bilateral filter, repeated `iblur` more times
        img = cv2.bilateralFilter(img,9,75,75) 
        for ij in range (iblur):
                img = cv2.bilateralFilter(img,9,75,75) 

    img =np.array(img)
    # Work in floating point so the subtraction/scaling below is not done in
    # unsigned 8-bit arithmetic.
    img = np.array(img, dtype = np.float16)

    #apply overall threshold - below thresh = 0, creates holes
    img[img <thresh] = 0
    
    if normalize==True:
            # Smallest strictly positive pixel value, lowered by `clipvalue`.
            # NOTE(review): np.amin raises on an empty selection, i.e. when the
            # image has no positive pixel left after thresholding.
            minval=np.amin(img[img>0])-clipvalue
         
            if norm_zero: 
             
                # Shift so the darkest non-zero pixel maps to ~0, then clip
                img=img-minval
              
                img=np.clip(img, 0, 255)
               
                
                plt.imshow(img, interpolation='nearest',cmap="hot")
                plt.colorbar()
                plt.show()
                
            print ( np.amin(img), np.amax(img))
        
            # Stretch so the brightest pixel becomes 255
            img=img/np.amax(img)*255
    
            
    # Apply gain and saturate at the 8-bit ceiling before casting back
    img=img*sat
    img[img >255] = 255
    
    
    img = np.array(img, dtype = np.uint8)

    if darea>0:
        
        # Binarise at a fixed level of 80 to find connected regions
        _, thresh_i = cv2.threshold(img,80,255,cv2.THRESH_BINARY)
        plt.imshow(thresh_i, interpolation='nearest',cmap="hot")
        plt.colorbar()
        plt.show()

        cnts = cv2.findContours(thresh_i, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE )
        # findContours returns 2 or 3 values depending on the OpenCV version;
        # pick the contour list in either case.
        cnts = cnts[0] if len(cnts) == 2 else cnts[1]
        for c in cnts:
            area = cv2.contourArea(c)
           
            if area < darea:
              
                # Fill (thickness -1) small contours with black to remove them
                cv2.drawContours(img, [c], -1, (0,0,0), -1)
 
    if invertresult:
        img = cv2.bitwise_not(img)
    
    plt.imshow(img, interpolation='nearest',cmap="hot")
    plt.colorbar()
    plt.show()
   
    if invvv==True:
   
      # Descending ramp: threshold goes from 255 down to 255/maxheight
      for ii in range (maxheight):
    
        threshold=(maxheight-ii)/(maxheight)*255
        
        z = np.copy(img)
        
        # Binarise: below the level -> empty, anything remaining -> solid
        z[z <threshold] = 0
        z[z >0] = 255
            
        vox.append(z)

    if centerrep>0:
    
        z2 = np.copy(img)

        # img is uint8 here, so every pixel satisfies >= 0: fully solid slice
        z2[z2 >=0] = 255
    
        # The same array object is appended `centerrep` times
        for infg in range( (centerrep)):
             vox.append(z2)

    # Ascending ramp: threshold goes from 0 up to 255*(maxheight-1)/maxheight
    for ii in range (maxheight):
    
        threshold=(ii/maxheight)*255
       
        z = np.copy(img)
        
        z[z <threshold] = 0
        z[z >0] = 255
        
        vox.append(z)

   return np.array(vox)

def TwoDimg23Dvox_BACKUP(img,maxheight=10, invvv=False, thresh=0, normalize=False,sat=1,GaussSmoothimage=False, BilatSmoothimage=False):
    """Legacy single-image variant of ``TwoDimg23Dvox``.

    Treats ``img`` as a height map and slices it at ``maxheight`` grey levels;
    every slice is a 0/255 mask.

    Args:
        img: One 2D image.
        maxheight: Number of slices generated per ramp.
        invvv: When true, prepend a ramp with descending thresholds.
        thresh: Pixels below this value are set to 0.
        normalize: Subtract the minimum and rescale so the maximum is 255.
        sat: Gain applied after normalisation; values are clamped at 255.
        GaussSmoothimage: Apply a 5x5 Gaussian blur first.
        BilatSmoothimage: Apply a bilateral filter first.

    Returns:
        ``np.ndarray`` stacking the generated slices along axis 0.
    """
    if GaussSmoothimage:
        img = cv2.GaussianBlur(img, (5, 5), 0)

    if BilatSmoothimage:
        # The message used to say "Gaussian Blur" although this branch runs
        # the bilateral filter; it now matches TwoDimg23Dvox.
        print ("smoothen using Bilat Blur...")
        img = cv2.bilateralFilter(img, 9, 75, 75)

    img = np.array(img)
    img[img < thresh] = 0

    if normalize:
        print ("Normalize...", np.amin(img), np.amax(img))
        img = img - np.amin(img)
        img[img < 0] = 0

        print (  np.amax(img))
        img = img / np.amax(img) * 255

    # Apply gain and saturate at the 8-bit ceiling.
    img = img * sat
    img[img > 255] = 255

    plt.imshow(img, interpolation='nearest', cmap="hot")
    plt.colorbar()
    plt.show()

    def _slice_at(threshold):
        # Binarise a copy: below the level -> 0, anything remaining -> 255.
        z = np.copy(img)
        z[z < threshold] = 0
        z[z > 0] = 255
        return z

    vox = []
    if invvv:
        # Descending ramp: thresholds from 255 down to 255/maxheight.
        vox.extend(_slice_at((maxheight - ii) / (maxheight) * 255) for ii in range(maxheight))

    # Ascending ramp: thresholds from 0 up to 255*(maxheight-1)/maxheight.
    vox.extend(_slice_at((ii / maxheight) * 255) for ii in range(maxheight))

    return np.array(vox)

def vox2stl(vox, loc='.', filename='', save=True, smooth=False, smooth_iter=20, stamp=False):
    """Convert a voxel grid to a surface mesh and optionally export it as STL.

    The mesh is built with marching cubes, optionally smoothed, and moved so
    that its bounding box starts at the origin (``rezero``).

    Args:
        vox: 3D occupancy array passed to
            ``trimesh.voxel.ops.matrix_to_marching_cubes``.
        loc: Output directory; created if missing.
        filename: Base name of the STL file (``.stl`` is appended).
        save: Export the mesh when true.
        smooth: Apply mutable-diffusion Laplacian smoothing before export.
        smooth_iter: Number of smoothing iterations.
        stamp: When true, append a ``_MM_DD_HH_MM`` timestamp to the file
            name. The default keeps the plain ``<loc>/<filename>.stl`` name.
            (Previously this argument was overwritten and had no effect.)

    Returns:
        Path of the exported STL file, or ``None`` when ``save`` is false.
    """
    mesh = trimesh.voxel.ops.matrix_to_marching_cubes(vox)

    if smooth:
        # The filter edits the mesh in place; keep the original object if the
        # installed trimesh version does not also return it.
        smoothed = trimesh.smoothing.filter_mut_dif_laplacian(mesh, lamb=0.85, iterations=smooth_iter)
        if smoothed is not None:
            mesh = smoothed

    mesh.rezero()
    if not save:
        return None

    from time import strftime
    os.makedirs(loc, exist_ok=True)

    suffix = '_' + strftime("%m_%d_%H_%M") if stamp else ''
    exportname = loc+'/'+filename+suffix+'.stl'
    mesh.export(exportname)
    print(f'save stl model to {exportname}' )

    return exportname

def union(vox1, vox2):
    """Return the element-wise OR of two voxel grids (A or B)."""
    merged = np.logical_or(vox1, vox2)
    return merged

def xor(vox1, vox2):
    """Return voxels set in exactly one of the two grids (A xor B)."""
    exclusive = np.logical_xor(vox1, vox2)
    return exclusive

def substraction(vox1, vox2):
    """Return voxels set in ``vox1`` but not in ``vox2`` (A minus B)."""
    # (A or B) xor B is logically the same as A and (not B).
    not_in_second = np.logical_not(vox2)
    return np.logical_and(vox1, not_in_second)

def intersection(vox1, vox2):
    """Return voxels set in both grids (A and B)."""
    common = np.logical_and(vox1, vox2)
    return common

def inverse(vox):
    """Return the logical complement of a voxel grid (not A)."""
    flipped = np.logical_not(vox)
    return flipped

def repeat(vox, repeatance_array):
    """Tile ``vox`` ``repeatance_array`` times per axis.

    The result is multiplied by 1, which turns boolean grids into integers.
    """
    tiled = np.tile(vox, repeatance_array)
    return tiled * 1

def vox2img(vox, loc='.', filename=''):
    """Write every slice of ``vox`` (along axis 0) as a greyscale PNG.

    Images go to ``<loc>/<filename>_img/`` and are named
    ``<filename>_<MM_DD_HH_MM>_<index>.png``. Each slice is also shown
    inline, replacing the previous notebook output.
    """
    from time import strftime

    time_tag = strftime("%m_%d_%H_%M")
    out_dir = loc+'/'+filename+'_img/'
    os.makedirs(out_dir, exist_ok=True)

    for idx in range(vox.shape[0]):
        layer = vox[idx]
        plt.imsave(f'{out_dir}{filename}_{time_tag}_{idx}.png', layer, cmap='gray')

        # Refresh the notebook cell output so only the latest slice is shown.
        from IPython import display
        display.clear_output(wait=True)
        plt.imshow(layer, cmap='gray')
        plt.axis('off')
        plt.title(str(idx))
        plt.show()

    print('save stl model as a stack of images into {}'.format(out_dir))

def vox2npy(vox, loc='.', filename=''):
    """Save ``vox`` as ``<loc>/<filename>_<MM_DD_HH_MM>.npy`` (creating ``loc``)."""
    from time import strftime

    target = loc+'/'+filename+'_'+strftime("%m_%d_%H_%M")+'.npy'
    os.makedirs(loc, exist_ok=True)
    np.save(target, vox)
    print('save stl model as a 3D array to {}'.format(target))

def npy2vox(filename=''):
    """Load a ``.npy`` voxel grid; multiplying by 1 turns booleans into integers."""
    loaded = np.load(filename)
    return loaded * 1

In [None]:
#https://github.com/xchhuang/simple-pytorch-3dgan
class net_G(torch.nn.Module):
    """3D-GAN generator: maps a latent vector to a voxel occupancy grid.

    Reads the module-level settings ``cube_len``, ``bias`` and ``z_dim``.
    The stack is five transposed 3D convolutions (kernel 4, stride 2), each
    doubling the spatial size; the last one ends in a sigmoid. With
    ``cube_len == 64`` the first layer uses no padding (1 -> 4) and the
    feature maps grow 4 -> 8 -> 16 -> 32 -> 64; with ``cube_len == 32`` the
    first layer is padded by 1 (1 -> 2) and the output is 32 per axis.

    Layer attribute names (``layer1``, ``layer2``, ``layer3``, ``layer5``,
    ``layer6``) are part of the saved state-dict keys and must not change.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.cube_len = cube_len
        self.bias = bias
        self.z_dim = z_dim
        self.f_dim = cube_len

        # Only the first layer's padding depends on the target resolution.
        first_pad = (1, 1, 1) if self.cube_len == 32 else (0, 0, 0)
        unit_pad = (1, 1, 1)
        width = self.f_dim

        self.layer1 = self.conv_layer(self.z_dim, width*8, kernel_size=4, stride=2, padding=first_pad, bias=self.bias)
        self.layer2 = self.conv_layer(width*8, width*4, kernel_size=4, stride=2, padding=unit_pad, bias=self.bias)
        self.layer3 = self.conv_layer(width*4, width*2, kernel_size=4, stride=2, padding=unit_pad, bias=self.bias)
        self.layer5 = self.conv_layer(width*2, width, kernel_size=4, stride=2, padding=unit_pad, bias=self.bias)

        # Output head: single channel squashed to (0, 1).
        self.layer6 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(width, 1, kernel_size=4, stride=2, bias=self.bias, padding=unit_pad),
            torch.nn.Sigmoid(),
        )

    def conv_layer(self, input_dim, output_dim, kernel_size=4, stride=2, padding=(1,1,1), bias=False):
        """Build one up-sampling stage: ConvTranspose3d -> BatchNorm3d -> ReLU."""
        upsample = torch.nn.ConvTranspose3d(input_dim, output_dim, kernel_size=kernel_size, stride=stride, bias=bias, padding=padding)
        normalise = torch.nn.BatchNorm3d(output_dim)
        activate = torch.nn.ReLU(True)
        return torch.nn.Sequential(upsample, normalise, activate)

    def forward(self, x):
        """Generate voxel grids from latent vectors.

        ``x`` is reshaped to ``(-1, z_dim, 1, 1, 1)``. All size-1 dimensions
        of the result are squeezed, so a batch of one loses its batch axis.
        """
        out = x.view(-1, self.z_dim, 1, 1, 1)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer5, self.layer6):
            out = stage(out)
        return out.squeeze()


class net_D(torch.nn.Module):
    """3D-GAN discriminator: scores a voxel grid with a value in (0, 1).

    Reads the module-level settings ``cube_len``, ``leak_value`` and ``bias``.
    Four strided 3D convolutions (kernel 4, stride 2, padding 1) halve the
    spatial size each time; a final convolution plus sigmoid reduces the
    remaining volume to a single score. With ``cube_len == 64`` the sizes go
    64 -> 32 -> 16 -> 8 -> 4 -> 1 (last layer unpadded); with
    ``cube_len == 32`` they go 32 -> 16 -> 8 -> 4 -> 2 -> 1 (last layer
    padded by 1).

    Layer attribute names (``layer1``, ``layer2``, ``layer4``, ``layer5``,
    ``layer6``) are part of the saved state-dict keys and must not change.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.cube_len = cube_len
        self.leak_value = leak_value
        self.bias = bias
        self.f_dim = cube_len

        # Only the last layer's padding depends on the input resolution.
        last_pad = (1, 1, 1) if self.cube_len == 32 else (0, 0, 0)
        unit_pad = (1, 1, 1)
        width = self.f_dim

        self.layer1 = self.conv_layer(1, width, kernel_size=4, stride=2, padding=unit_pad, bias=self.bias)
        self.layer2 = self.conv_layer(width, width*2, kernel_size=4, stride=2, padding=unit_pad, bias=self.bias)
        self.layer4 = self.conv_layer(width*2, width*4, kernel_size=4, stride=2, padding=unit_pad, bias=self.bias)
        self.layer5 = self.conv_layer(width*4, width*8, kernel_size=4, stride=2, padding=unit_pad, bias=self.bias)

        # Scoring head: collapse to one channel and squash to (0, 1).
        self.layer6 = torch.nn.Sequential(
            torch.nn.Conv3d(width*8, 1, kernel_size=4, stride=2, bias=self.bias, padding=last_pad),
            torch.nn.Sigmoid(),
        )

    def conv_layer(self, input_dim, output_dim, kernel_size=4, stride=2, padding=(1,1,1), bias=False):
        """Build one down-sampling stage: Conv3d -> BatchNorm3d -> LeakyReLU."""
        downsample = torch.nn.Conv3d(input_dim, output_dim, kernel_size=kernel_size, stride=stride, bias=bias, padding=padding)
        normalise = torch.nn.BatchNorm3d(output_dim)
        activate = torch.nn.LeakyReLU(self.leak_value, inplace=True)
        return torch.nn.Sequential(downsample, normalise, activate)

    def forward(self, x):
        """Score voxel grids.

        ``x`` is reshaped to ``(-1, 1, cube_len, cube_len, cube_len)``. All
        size-1 dimensions of the result are squeezed, so a batch of one
        yields a 0-d tensor.
        """
        out = x.view(-1, 1, self.cube_len, self.cube_len, self.cube_len)
        for stage in (self.layer1, self.layer2, self.layer4, self.layer5, self.layer6):
            out = stage(out)
        return out.squeeze()

In [None]:
from simple_3dviz import Mesh
from simple_3dviz.window import show
from simple_3dviz.utils import render
from simple_3dviz.behaviours.io import SaveFrames
from simple_3dviz.behaviours.movements import CameraTrajectory
from simple_3dviz.behaviours.trajectory import Circle
from simple_3dviz.behaviours.misc import LightToCamera

def tester(args, threshold=0.5, startnum=0):
    """Sample ``num_examples`` shapes from a trained generator and save them.

    Checkpoints are read from
    ``<output_dir>/<args.model_name>/<args.logs>/models/{G,D}.pth`` and
    results are written to ``.../<args.logs>/test_outputs``. Note that the
    module-level ``output_dir`` is used, not ``args.output_dir``.

    Args:
        args: Parsed options; reads ``model_name``, ``logs``, ``use_visdom``
            and ``save_np``.
        threshold: Occupancy threshold forwarded to ``SavePloat_Voxels``.
        startnum: Offset added to the sample index in output file names.
    """
    # This assignment was missing (only a bare expression referencing an
    # undefined `model_name` was left), so every call raised NameError.
    image_saved_path = output_dir + '/' + args.model_name + '/' + args.logs + '/test_outputs'
    os.makedirs(image_saved_path, exist_ok=True)

    if args.use_visdom:
        vis = visdom.Visdom()

    save_file_path = output_dir + '/' + args.model_name
    pretrained_file_path_G = save_file_path + '/' + args.logs + '/models/G.pth'
    pretrained_file_path_D = save_file_path + '/' + args.logs + '/models/D.pth'

    print(pretrained_file_path_G)

    D = net_D(args)
    G = net_G(args)

    # map_location=device loads the weights onto the active device whatever
    # device they were saved from (CPU-only host, forced CPU, other GPU id).
    G.load_state_dict(torch.load(pretrained_file_path_G, map_location=device))
    # D is not used for sampling; it is still loaded so that a missing or
    # incompatible discriminator checkpoint is reported as before.
    D.load_state_dict(torch.load(pretrained_file_path_D, map_location=device))

    print('visualizing model')

    G.to(device)
    D.to(device)
    G.eval()
    D.eval()

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for i in range(num_examples):
            z = generateZ(args, 1)

            # G squeezes the batch axis for a single sample; restore it.
            fake = G(z)
            samples = fake.unsqueeze(dim=0).detach().cpu().numpy()

            # visualization
            if not args.use_visdom:
                _ = SavePloat_Voxels(samples, image_saved_path, 'tester_' + str(i+startnum), threshold)
            else:
                plotVoxelVisdom(samples[0, :], vis, "tester_" + str(i+startnum))

            if args.save_np:
                print ("Save numpy file(s)...: ")

                for iii in range(samples.shape[0]):
                    f = image_saved_path+f'/np_vox_{i+startnum:04d}_{iii:04d}'
                    print ("save ", f)

                    np.save(f, samples[iii,:,:,:])
            
def tester_interpolate(args, threshold=0.5, z1=None, z2=None, steps=5):
    """Generate shapes along a straight line between two latent vectors.

    Checkpoints are read from
    ``<output_dir>/<args.model_name>/<args.logs>/models/{G,D}.pth`` and
    results are written to ``.../<args.logs>/test_outputs``. Note that the
    module-level ``output_dir`` is used, not ``args.output_dir``.

    Args:
        args: Parsed options; reads ``model_name``, ``logs``, ``use_visdom``,
            ``use_3dviz`` and ``save_np``.
        threshold: Occupancy threshold for plots and 3D renders.
        z1: Start latent vector; sampled with ``generateZ`` when ``None``.
        z2: End latent vector; sampled with ``generateZ`` when ``None``.
        steps: Number of interpolation points, end points included. With a
            single step only ``z1`` is rendered.
    """
    print('Evaluation Mode...')

    image_saved_path = output_dir + '/' + args.model_name + '/' + args.logs + '/test_outputs'
    os.makedirs(image_saved_path, exist_ok=True)

    if args.use_visdom:
        vis = visdom.Visdom()

    save_file_path = output_dir + '/' + args.model_name
    pretrained_file_path_G = save_file_path + '/' + args.logs + '/models/G.pth'
    pretrained_file_path_D = save_file_path + '/' + args.logs + '/models/D.pth'

    print(pretrained_file_path_G)

    D = net_D(args)
    G = net_G(args)

    # map_location=device loads the weights onto the active device whatever
    # device they were saved from.
    G.load_state_dict(torch.load(pretrained_file_path_G, map_location=device))
    # D is not used for sampling; it is still loaded so that a missing or
    # incompatible discriminator checkpoint is reported as before.
    D.load_state_dict(torch.load(pretrained_file_path_D, map_location=device))

    print('visualizing model')

    G.to(device)
    D.to(device)
    G.eval()
    D.eval()

    # Identity test: `tensor == None` is not a reliable None check.
    if z1 is None:
        z1 = generateZ(args, 1)
    if z2 is None:
        z2 = generateZ(args, 1)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for i in range(steps):
            # Fraction of the way from z1 to z2; guard the one-step case,
            # which used to divide by zero.
            alpha = i/(steps-1) if steps > 1 else 0.0
            z = z1 + (z2-z1)*alpha

            # G squeezes the batch axis for a single sample; restore it.
            fake = G(z)
            samples = fake.unsqueeze(dim=0).detach().cpu().numpy()

            # visualization
            if not args.use_visdom:
                fname = SavePloat_Voxels(samples, image_saved_path, 'tester_interpol_' + str(i), threshold)
                # Rename to a zero-padded frame number so frames sort in order;
                # os.replace also overwrites frames left by an earlier run.
                fname_n = image_saved_path+f'/{i:04d}.png'
                print ("New name: ", fname_n)
                os.replace(fname, fname_n)

            if args.save_np:
                print ("Save numpy file(s)...: ")

                for iii in range(samples.shape[0]):
                    f = image_saved_path+f'/np_vox_{i:04d}_{iii:04d}'
                    print ("save ", f)

                    np.save(f, samples[iii,:,:,:])

            if args.use_3dviz:
                render(
                    Mesh.from_voxel_grid(voxels=samples[0,:] > threshold, colors=[0.4,0.4,0.8], sizes=(.006,.006,.006)),
                    behaviours=[
                        LightToCamera(),
                        SaveFrames(image_saved_path+f'/3dvis_{i:04d}.png'),
                    ],
                    size=(1024,1024),
                    background=(1.0, 1.0, 1.0, 1.0),
                    n_frames=1,
                    camera_position=(.8,.8,.8), camera_target=(0., 0, 0),
                    light=(1, 5, 5),
                )
            elif args.use_visdom:
                # Previously a bare `else`, which referenced the undefined
                # `vis` whenever both 3dviz and visdom were switched off.
                plotVoxelVisdom(samples[0, :], vis, "tester_interpol_" + str(i))

In [None]:
import torch
from torch import optim
from torch import nn
from utils import *
import os

# added
import datetime
import time
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
import numpy as np
import params
from tqdm import tqdm


def save_train_log(writer, loss_D, loss_G, itr):
    """Write training losses to ``writer`` as scalars at step ``itr``.

    Generator entries are tagged ``train_loss_G/<key>`` and written first,
    discriminator entries are tagged ``train_loss_D/<key>``.
    """
    tagged = {'train_loss_G/' + name: val for name, val in loss_G.items()}
    tagged.update({'train_loss_D/' + name: val for name, val in loss_D.items()})

    for tag in tagged:
        writer.add_scalar(tag, tagged[tag], itr)


def save_val_log(writer, loss_D, loss_G, itr):
    """Write validation losses to ``writer`` as scalars at step ``itr``.

    Generator entries are tagged ``val_loss_G/<key>`` and written first,
    discriminator entries are tagged ``val_loss_D/<key>``.
    """
    tagged = {'val_loss_G/' + name: val for name, val in loss_G.items()}
    tagged.update({'val_loss_D/' + name: val for name, val in loss_D.items()})

    for tag in tagged:
        writer.add_scalar(tag, tagged[tag], itr)


def trainer(args,train_dset_loaders, restart=False):
    """Train the 3D-GAN (``net_G`` / ``net_D``) on batches of voxel grids.

    Both networks are trained with a least-squares (MSE) adversarial loss.
    The discriminator is only updated while its accuracy on the current
    batch is below ``d_thresh``. An L1 distance between generated and real
    batches is computed for logging only and does not affect the generator
    update.

    Outputs are written below ``<output_dir>/<args.model_name>/<args.logs>/``:
    ``logs`` (tensorboardX, only when ``args.logs`` is non-empty), ``images``
    and ``models`` (``G.pth`` / ``D.pth``). Note that the module-level
    ``output_dir`` and hyper-parameters (``epochs``, ``d_lr``, ``g_lr``,
    ``beta``, ``soft_label``, ``d_thresh``, ``model_save_step``,
    ``num_examples``) are used.

    Args:
        args: Parsed options; reads ``model_name``, ``logs`` and ``local_test``.
        train_dset_loaders: Iterable of real voxel batches (tensors).
        restart: When true, initialise both networks from the ``G.pth`` and
            ``D.pth`` checkpoints in the run's ``models`` directory instead
            of starting from random weights. (Previously this flag left the
            networks undefined and raised ``UnboundLocalError``.)
    """
    save_file_path = output_dir + '/' + args.model_name
    print(save_file_path)  # ../outputs/dcgan
    os.makedirs(save_file_path, exist_ok=True)

    # The checkpoint and image folders are needed whether or not tensorboard
    # logging is on; they used to be defined only inside `if args.logs`.
    run_dir = save_file_path + '/' + (args.logs or '')
    image_saved_path = run_dir + '/images'
    model_saved_path = run_dir + '/models'
    os.makedirs(image_saved_path, exist_ok=True)
    os.makedirs(model_saved_path, exist_ok=True)

    writer = None
    if args.logs:
        print ("SETTING UP LOGS...")
        writer = SummaryWriter(run_dir + '/logs')

    # model define (D before G keeps the original random-initialisation order)
    D = net_D(args)
    G = net_G(args)
    if restart:
        print ("Restart from existing neural net....")
        G.load_state_dict(torch.load(model_saved_path + '/G.pth', map_location=device))
        D.load_state_dict(torch.load(model_saved_path + '/D.pth', map_location=device))

    # print total number of trainable parameters per model
    n_params_G = sum(p.numel() for p in G.parameters() if p.requires_grad)
    n_params_D = sum(p.numel() for p in D.parameters() if p.requires_grad)
    print('Trainable parameters: G =', n_params_G, ', D =', n_params_D)

    D_solver = optim.Adam(D.parameters(), lr=d_lr, betas=beta)
    G_solver = optim.Adam(G.parameters(), lr=g_lr, betas=beta)

    D.to(device)
    G.to(device)

    criterion_D = nn.MSELoss()  # adversarial loss (least squares)
    criterion_G = nn.L1Loss()   # reconstruction distance, logged only

    itr_train = -1

    print ("START TRAINING....")
    for epoch in range(epochs):
        start = time.time()
        D.train()
        G.train()

        running_loss_G = 0.0
        running_loss_D = 0.0
        running_loss_adv_G = 0.0
        fake = None

        for X in train_dset_loaders:
            itr_train += 1

            X = X.to(device)
            batch = X.size()[0]

            # =============== Train the discriminator ===============#
            Z = generateZ(args, batch)
            d_real = D(X)
            fake = G(Z)
            # detach: the discriminator update needs no gradient through G
            # (G's gradients were zeroed before its own step anyway).
            d_fake = D(fake.detach())

            real_labels = torch.ones_like(d_real).to(device)
            fake_labels = torch.zeros_like(d_fake).to(device)

            if soft_label:
                real_labels = torch.Tensor(batch).uniform_(0.7, 1.2).to(device)
                fake_labels = torch.Tensor(batch).uniform_(0, 0.3).to(device)

            d_real_loss = criterion_D(d_real, real_labels)
            d_fake_loss = criterion_D(d_fake, fake_labels)
            d_loss = d_real_loss + d_fake_loss

            # Fraction of real/fake samples the discriminator classifies
            # correctly at the 0.5 decision level.
            d_real_acu = torch.ge(d_real.squeeze(), 0.5).float()
            d_fake_acu = torch.le(d_fake.squeeze(), 0.5).float()
            d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu), 0))

            # Freeze D while it is already accurate enough, so it does not
            # overpower the generator.
            if d_total_acu < d_thresh:
                D.zero_grad()
                d_loss.backward()
                D_solver.step()

            # =============== Train the generator ===============#
            Z = generateZ(args, batch)

            fake = G(Z)  # generated fake: 0-1, X: 0/1
            d_fake = D(fake)

            adv_g_loss = criterion_D(d_fake, real_labels)
            recon_g_loss = criterion_G(fake, X)
            g_loss = adv_g_loss

            if args.local_test:
                print('Iteration-{} , D(x) : {:.4}, D(G(x)) : {:.4}'.format(itr_train, d_loss.item(),
                                                                            adv_g_loss.item()))

            D.zero_grad()
            G.zero_grad()
            g_loss.backward()
            G_solver.step()

            # =============== logging each 10 iterations ===============#
            running_loss_G += recon_g_loss.item() * X.size(0)
            running_loss_D += d_loss.item() * X.size(0)
            running_loss_adv_G += adv_g_loss.item() * X.size(0)

            if writer is not None and itr_train % 10 == 0:
                print(".", end="")
                loss_G = {
                    'adv_loss_G': adv_g_loss,
                    'recon_loss_G': recon_g_loss,
                }
                loss_D = {
                    'adv_real_loss_D': d_real_loss,
                    'adv_fake_loss_D': d_fake_loss,
                }
                save_train_log(writer, loss_D, loss_G, itr_train)

        # =============== each epoch save model or save image ===============#
        # NOTE: these are batch-size-weighted sums over the epoch, not means.
        epoch_loss_D = running_loss_D
        epoch_loss_adv_G = running_loss_adv_G

        epoch_time = time.time() - start

        print('\nEpoch {}, D(x) : {:.4}, D(G(x)) : {:.4}'.format(epoch, epoch_loss_D, epoch_loss_adv_G))
        print('Elapsed Time: {:.4} min'.format(epoch_time / 60.0))

        if (epoch + 1) % model_save_step == 0:
            torch.save(G.state_dict(), model_saved_path + '/G.pth')
            torch.save(D.state_dict(), model_saved_path + '/D.pth')

            # `fake` is the last generated batch; it is None only when the
            # loader yielded nothing this epoch.
            if fake is not None:
                samples = fake.detach().cpu()[:num_examples].squeeze().numpy()
                _ = SavePloat_Voxels(samples, image_saved_path, epoch)

    if writer is not None:
        writer.close()

In [None]:
import scipy.ndimage as nd
import scipy.io as io
import matplotlib
import params

import matplotlib.pyplot as plt
import skimage.measure as sk
from mpl_toolkits import mplot3d
import matplotlib.gridspec as gridspec
import numpy as np
from torch.utils import data
from torch.autograd import Variable
import torch
import os
import pickle

def getVoxelFromMat(path, cube_len=64,pr_info=False, numpyfile=False):
    """Load one voxel grid from a ``.npy`` or MATLAB ``.mat`` file.

    Args:
        path: File name or open binary file object.
        cube_len: Target edge length for ``.mat`` input. 32 keeps the padded
            grid as is; any other value up-samples it by a factor of 2.
            Not used for ``.npy`` input.
        pr_info: Print the raw and final shapes (``.mat`` input only).
        numpyfile: When true, ``path`` is a ``.npy`` array that is
            down-sampled by a factor of 0.5 per axis (nearest neighbour).

    Returns:
        The voxel grid as a NumPy array.
    """
    if numpyfile:
        # .npy input: halve the resolution along every axis.
        return nd.zoom(np.load(path), (.5, .5, .5), mode='constant', order=0)

    if numpyfile==False:
        # .mat input: read the array stored under the module-level
        # `Instance_name` key and add a one-voxel empty border on every side.
        raw = io.loadmat(path)[Instance_name]
        if pr_info:
            print ('Raw: ', raw.shape)
        voxels = np.pad(raw, (1, 1), 'constant', constant_values=(0, 0))

        if cube_len != 32:
            voxels = nd.zoom(voxels, (2, 2, 2), mode='constant', order=0)

        if pr_info:
            print ("Final shape: ", voxels.shape)
    return voxels
 


def getVFByMarchingCubes(voxels, threshold=0.5):
    """Run marching cubes on ``voxels`` at ``threshold``; return (vertices, faces)."""
    vertices, faces, _normals, _values = sk.marching_cubes(voxels, level=threshold)
    return vertices, faces


def plotVoxelVisdom(voxels, visdom, title):
    """Send the marching-cubes surface of ``voxels`` to a visdom session.

    ``visdom`` is the session object whose ``mesh`` method is called (the
    parameter shadows the module of the same name).
    """
    vertices, faces = getVFByMarchingCubes(voxels)
    mesh_opts = dict(opacity=0.5, title=title)
    visdom.mesh(X=vertices, Y=faces, opts=mesh_opts)


def SavePloat_Voxels(voxels, path, iteration, threshold=.5, name=''):
    """Plot up to ``num_examples`` voxel grids as 3D scatter plots in one PNG.

    When the module-level ``save_STL`` flag is set, each of those grids is
    also binarised and exported as an STL file through ``vox2stl``.

    Args:
        voxels: Array of voxel grids, indexed along axis 0.
        path: Output directory (must already exist for the PNG).
        iteration: Label for the file name, zero-padded to four characters.
        threshold: Occupancy level; values at or above it are plotted.
        name: Prefix for the generated file names.

    Returns:
        Path of the saved PNG, ``<path>/<name>_<iteration>.png``.
    """
    print ("Generate ", num_examples)

    if save_STL:
        # Binarise a copy so the caller's array stays untouched.
        stl_voxels = np.copy(voxels)
        for idx, grid in enumerate(stl_voxels[:num_examples]):
            grid[grid > threshold] = 1
            grid[grid <= threshold] = 0

            vox2stl(grid, loc=path, filename=f'{name}_voxel_{idx}_{iteration}', save=True, smooth=smooth_STL, smooth_iter=20)

    occupied = voxels[:num_examples] >= threshold

    plt.figure(figsize=(32, 16))
    # Fixed 2 x 4 layout: at most eight grids fit into one figure.
    layout = gridspec.GridSpec(2, 4)
    layout.update(wspace=0.05, hspace=0.05)

    for cell, grid in enumerate(occupied):
        # Coordinates of all occupied voxels
        xs, ys, zs = grid.nonzero()
        axis = plt.subplot(layout[cell], projection='3d')
        axis.scatter(xs, ys, zs, zdir='z', c='red', marker='s', s=2)
        axis.set_xticklabels([])
        axis.set_yticklabels([])

        # Same axis range for every sample so the plots are comparable
        full_range = (0, cube_len)
        axis.set_xlim(full_range)
        axis.set_ylim(full_range)
        axis.set_zlim(full_range)

    fname = path + f'/{name}_{str(iteration).zfill(4)}.png'
    plt.savefig(fname, bbox_inches='tight')
    plt.close()

    return fname


class ShapeNetDataset(data.Dataset):
    """Dataset of voxel grids stored as one file per sample in a directory.

    Every directory entry of ``root`` is treated as a sample and loaded on
    demand with ``getVoxelFromMat`` at the module-level ``cube_len``.
    """

    def __init__(self, root, args, train_or_val="train", numpyfile=False):
        """Index the files in ``root``.

        Args:
            root: Directory holding the sample files.
            args: Parsed options, stored on the instance.
            train_or_val: Accepted for interface compatibility; not used.
            numpyfile: True when the samples are ``.npy`` arrays rather than
                MATLAB ``.mat`` files.
        """
        self.root = root
        # Sorted so that an index always maps to the same file, independent
        # of the order the operating system lists the directory in.
        self.listdir = sorted(os.listdir(self.root))
        self.numpyfile = numpyfile

        print ('Total data size =', len(self.listdir))
        self.args = args

    def __getitem__(self, index):
        """Load sample ``index`` and return it as a float32 tensor."""
        # os.path.join works whether or not `root` ends with a separator.
        sample_path = os.path.join(self.root, self.listdir[index])
        with open(sample_path, "rb") as f:
            volume = np.asarray(getVoxelFromMat(f, cube_len, numpyfile=self.numpyfile), dtype=np.float32)

        return torch.FloatTensor(volume)

    def __len__(self):
        """Return the number of sample files found in ``root``."""
        return len(self.listdir)

def generateZ(args, batch):
    """Sample a batch of latent vectors on the active ``device``.

    The distribution is chosen by the module-level ``z_dis`` setting:
    ``"norm"`` draws from a normal distribution with mean 0 and standard
    deviation 0.33, ``"uni"`` draws uniformly from [0, 1).

    NOTE: the ``"uni"`` branch used to call ``torch.randn`` (standard
    normal). A generator trained with that older behaviour expects normally
    distributed input.

    Args:
        args: Unused; kept for interface compatibility.
        batch: Number of latent vectors to draw.

    Returns:
        Tensor of shape ``(batch, z_dim)``.

    Raises:
        ValueError: If ``z_dis`` is neither ``"norm"`` nor ``"uni"``.
    """
    if z_dis == "norm":
        return torch.Tensor(batch, z_dim).normal_(0, 0.33).to(device)
    if z_dis == "uni":
        return torch.rand(batch, z_dim).to(device)

    # Fail with a clear message instead of printing and then crashing on an
    # unassigned local variable.
    raise ValueError("z_dist is not normal or uniform: {!r}".format(z_dis))

In [None]:
def str2bool(v):
    """Parse a command-line style boolean.

    Accepts (case-insensitively) ``yes/true/t/y/1`` and ``no/false/f/n/0``.
    Actual ``bool`` values are returned unchanged, so the function can also
    be applied to defaults that are already booleans.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(v, bool):
        return v

    text = v.lower()
    if text in ('yes', 'true', 't', 'y', '1'):
        return True
    if text in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

In [None]:
def print_params():
    """Print the current module-level hyper-parameters between two banners."""
    stars = '*' * 16
    banner = stars + 'hyper-parameters' + stars
    print(banner)

    print(f'epochs = {epochs}')
    print(f'batch_size = {batch_size}')
    print(f'soft_labels = {soft_label}')
    print(f'adv_weight = {adv_weight}')
    print(f'd_thresh = {d_thresh}')
    print(f'z_dim = {z_dim}')
    print(f'z_dis = {z_dis}')
    print(f'model_images_save_step = {model_save_step}')
    print(f'data = {model_dir}')
    print(f'device = {device}')
    print(f'g_lr = {g_lr}')
    print(f'd_lr = {d_lr}')
    print(f'cube_len = {cube_len}')
    print(f'leak_value = {leak_value}')
    print(f'bias = {bias}')

    print(banner)

### Set up and train

In [None]:
import argparse

# ---- Module-level hyper-parameters (read as globals by the networks,
# ---- trainer, tester and helper functions) ----

# Number of passes over the training loader
epochs = 500

batch_size = 128

# When True the trainer draws real labels from U(0.7, 1.2) and fake labels
# from U(0, 0.3) instead of using exact 1/0 targets
soft_label = False
# Only reported by print_params(); not read by the training code in this file
adv_weight = 0
# The discriminator is updated only while its batch accuracy is below this
d_thresh = 0.8
# Length of the latent vector fed to the generator
z_dim = 512
# Latent distribution selector used by generateZ ("norm" or "uni")
z_dis = "norm"
# Save checkpoints and sample images every this many epochs
model_save_step = 1
g_lr = 0.0025
d_lr = 0.00001
# Adam betas for both optimizers
beta = (0.5, 0.999)
# Edge length of the cubic voxel grid
cube_len = 64
# Negative slope of the discriminator's LeakyReLU
leak_value = 0.2
# Whether the (transposed) convolutions use a bias term
bias = False
# Key of the voxel array inside .mat files; the second assignment wins
Instance_name='instance' #original chair dataset
Instance_name='Volume'  #data in VOXELDATA
# Smooth exported STL meshes (passed to vox2stl by SavePloat_Voxels)
smooth_STL=False

# Also export STL files whenever SavePloat_Voxels is called
save_STL=False

# Number of samples plotted per figure / generated by tester()
num_examples =8
model_dir= './np_voxels_proteins_128x128x128_ALLPDB/'

parser = argparse.ArgumentParser()

# logging / run-selection parameters
# NOTE(review): --output_dir and --test are parsed, but the functions in this
# file read the module-level `output_dir` and never read `args.test`.
parser.add_argument('--logs', type=str, default='first_test', help='logs by tensorboardX')
parser.add_argument('--local_test', type=str2bool, default=False, help='local test verbose')
parser.add_argument('--model_name', type=str, default="dcgan", help='model name for saving')
parser.add_argument('--output_dir', type=str, default="output", help='output_dir')
parser.add_argument('--test', type=str2bool, default=False, help='call tester.py')
parser.add_argument('--use_visdom', type=str2bool, default=False, help='visualization by visdom')
parser.add_argument('--use_3dviz', type=str2bool, default=False, help='visualization by 3dviz')
parser.add_argument('--save_np', type=str2bool, default=False, help='save voxel files as npy')

print_params()

# Parse a fixed argument string (notebook use) instead of sys.argv
args = parser.parse_args("--model_name PMMD_3DGAN --output_dir output --use_3dviz True".split())

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import cv2
from PIL import Image

import pyvista as pv

def plot_single_mat (matname='./volumetric_data/chair/30/train/chair_000000200_12.mat',cube_len=32, 
                     numpyfile=False, savepath='./', name='', ID=''):
    """Load one voxel file, plot it, export it as STL and render a screenshot.

    Args:
        matname: Path of the ``.mat`` or ``.npy`` voxel file.
        cube_len: Target edge length forwarded to ``getVoxelFromMat``.
        numpyfile: True when ``matname`` is a ``.npy`` file.
        savepath: Directory for the scatter-plot PNG.
        name: Prefix for the scatter-plot file name.
        ID: Prefix for the rendered screenshot file name.
    """
    from time import strftime

    voxel = getVoxelFromMat(matname, cube_len=cube_len, pr_info=True, numpyfile=numpyfile)
    # SavePloat_Voxels expects a batch of grids, so add a leading axis.
    voxel_exp = np.expand_dims(voxel, 0)

    fname = SavePloat_Voxels(voxel_exp, savepath, iteration=0, name=name)

    # Show the saved scatter plot inline; close the file once it is drawn.
    with Image.open(fname) as img:
        plt.imshow(img)
    plt.axis('off')
    plt.show()

    # vox2stl appends '.stl' itself, so the base name carries no extension
    # (the previous name ended in '.stl.stl').
    filename = f'{matname}_STL_'
    exportname = vox2stl(voxel, loc='./', filename=filename, save=True, smooth=True, smooth_iter=20)
    print (exportname)

    mesh = pv.read(exportname)
    # The previous name lacked the '.' before 'png' and had './' embedded
    # after the ID prefix, which pointed into a non-existent directory.
    screenshot_name = f"./{ID}_{strftime('%m_%d_%H_%M')}.png"
    mesh.plot(show_edges=False, color=True, screenshot=screenshot_name,
              background='white', show_axes=False, parallel_projection=True)


In [None]:
# Key used to index loaded .mat files; it is not read when samples are
# loaded with numpyfile=True, as below.
Instance_name=''  

data_dir = './PyUUL - protein to point cloud/pyuul/'
model_dir = './np_voxels_proteins_128x128x128_ALLPDB/'  # change it to train on other data models

# Root folder for checkpoints, logs and images; trainer() and the testers
# read this global rather than args.output_dir.
output_dir = './outputs'
cube_len = 64
batch_size = 128
num_examples =8

# The dataset class joins this folder with each file name it lists.
dsets_path = data_dir + model_dir  

# Samples are .npy arrays, loaded one file at a time and served in shuffled batches
train_dsets = ShapeNetDataset(dsets_path, args, "train",  numpyfile=True)
train_dset_loaders = torch.utils.data.DataLoader(train_dsets, batch_size=batch_size, shuffle=True)#,num_workers=1)


In [None]:
#set up trainer

# --model_name selects the sub-folder below the global `output_dir`;
# the parsed --output_dir value is not read by trainer().
args = parser.parse_args("--model_name PMMD_PROTEIN_64_ALLPDB_30K --output_dir output".split())

# Train from freshly initialised networks
trainer(args,train_dset_loaders, restart=False  )

In [None]:
# NOTE(review): --output_dir is parsed here, but tester() reads the global
# `output_dir`, so results land below that folder instead.
args = parser.parse_args("--model_name PMMD_PROTEIN_64_ALLPDB_30K --output_dir output_10000 --use_visdom False --test True --save_np True".split())
print (args)
 
cube_len = 64

# Export a smoothed STL mesh for every generated sample
smooth_STL=True
# tester() generates this many samples
num_examples =9000
save_STL=True
print_params()
# Output files are numbered starting at `startnum`
tester(args, threshold=.1, startnum=1024)

In [None]:
#generate two z1 z2 and interpolate between them

# use_3dviz is left at its default (False) and visdom is switched off, so
# only the matplotlib scatter plots (and STL files) are requested here.
args = parser.parse_args("--model_name PMMD_PROTEIN_64_ALLPDB_30K --output_dir output --use_visdom False --test True".split())
print (args)
 
cube_len = 64

# Export a smoothed STL mesh for every interpolation step
smooth_STL=True
 
save_STL=True
print_params()
# z1 and z2 are not given, so both end points are sampled at random
tester_interpolate(args, threshold=.1, steps=20)
