In [2]:
import os
import h5py
import numpy as np
from tqdm.auto import tqdm
import random
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchsummary import summary
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

In [None]:
# --- Optimisation ------------------------------------------------------------
batch_size = 100         # samples drawn per DataLoader batch (one optimiser step)
mini_batch_size = 50     # chunk size the batch is split into for each forward/backward pass
lr_G = 0.0025            # generator learning rate
lr_D = 0.0001            # discriminator learning rate
beta1 = 0.5              # first Adam beta, shared by both optimisers
workers = 0              # DataLoader worker processes

# --- Data --------------------------------------------------------------------
dataset_name = 'shapenet_v2'
obj = 'airplane'
dim = 128                # edge length of the cubic voxel grid

# --- Model -------------------------------------------------------------------
noise_dim = 200          # latent space vector dim
# NOTE(review): `in_channels` is not read by any visible cell; the networks are
# built with a literal 256 instead. Confirm which value is intended.
in_channels = 512        # convolutional channels
run_parallel = False     # wrap the networks in torch.nn.DataParallel when True

In [3]:
# How many mini-batches make up one full batch. Floor division keeps the
# arithmetic in integers instead of round-tripping through a float.
# `k` is only reported here; the training loop derives the actual number of
# chunks from torch.split.
k = batch_size // mini_batch_size
print('batch size:', batch_size, 'mini batch:', mini_batch_size, 'k:', k)

# Set random seed for reproducibility
manualSeed = 42
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
# NumPy keeps its own global RNG, independent of `random` and torch, so it has
# to be seeded separately for the run to be reproducible end to end.
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)

print('dim:', dim)

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('device:', device)

Random Seed:  42
batch size: 100 mini batch: 50 k: 2
dim: 128
device: cpu


In [4]:
# Base name shared by the input HDF5 file and the output weight directory
data_filename = f'{dataset_name}_{obj}_r{dim}'

# Read the first dataset stored in the file fully into memory. np.array() makes
# an in-memory copy, so the handle can be closed right away; the context manager
# guarantees that even if reading or reshaping raises.
with h5py.File(data_filename + '.h5', 'r') as f:
    first_key = list(f.keys())[0]
    # Reshape to (N, 1, dim, dim, dim): one channel per cubic volume
    dataset = torch.from_numpy(np.array(f[first_key]).reshape(-1, 1, dim, dim, dim)).to(torch.float)

print('dataset shape:', dataset.shape)

FileNotFoundError: [Errno 2] Unable to open file (unable to open file: name = 'shapenet_v2_airplane_r128.h5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)

# GAN Structure 

In [4]:
from src.GAN import Discriminator, Generator, weights_init


def _prepare(net):
    """Optionally wrap `net` in DataParallel, move it to `device`, then apply `weights_init`."""
    if run_parallel:
        net = torch.nn.DataParallel(net)
    net = net.to(device)
    net.apply(weights_init)
    return net


# Generator is built and initialised before the discriminator; the order is kept
# unchanged in case initialisation draws from the seeded RNG.
netG = _prepare(Generator(in_channels=256, out_dim=dim, out_channels=1, noise_dim=noise_dim))
netD = _prepare(Discriminator(in_channels=1, out_conv_channels=256, dim=dim))

# Binary cross-entropy on the discriminator's outputs
criterion = torch.nn.BCELoss()

# Setup Adam optimizers for both G and D (same betas, separate learning rates)
adam_betas = (beta1, 0.999)
optimizerD = optim.Adam(netD.parameters(), lr=lr_D, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr_G, betas=adam_betas)

# Layer-by-layer summaries: latent vector in for G, single-channel volume in for D
print("\n\nGenerator summary\n\n")
summary(netG, (1, noise_dim))
print("\n\nDiscriminator summary\n\n")
summary(netD, (1, dim, dim, dim))



Generator summary


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1            [-1, 1, 131072]      26,345,472
   ConvTranspose3d-2      [-1, 128, 16, 16, 16]       2,097,152
       BatchNorm3d-3      [-1, 128, 16, 16, 16]             256
              ReLU-4      [-1, 128, 16, 16, 16]               0
   ConvTranspose3d-5       [-1, 64, 32, 32, 32]         524,288
       BatchNorm3d-6       [-1, 64, 32, 32, 32]             128
              ReLU-7       [-1, 64, 32, 32, 32]               0
   ConvTranspose3d-8       [-1, 32, 64, 64, 64]         131,072
       BatchNorm3d-9       [-1, 32, 64, 64, 64]              64
             ReLU-10       [-1, 32, 64, 64, 64]               0
  ConvTranspose3d-11     [-1, 1, 128, 128, 128]           2,048
          Sigmoid-12     [-1, 1, 128, 128, 128]               0
Total params: 29,100,480
Trainable params: 29,100,480
Non-trainable params: 0
---

# Training and Testing 3D-GAN

In [None]:
# Serve the dataset in batches of `batch_size`, reshuffled on every pass
loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=workers)
dataloader = DataLoader(dataset, **loader_options)

# Number of batches per epoch (shown as the cell output)
len(dataloader)

# Running Epochs

In [None]:
# Training history, appended to by the training loop and plotted afterwards
G_losses, D_real_losses, D_fake_losses = [], [], []
real_accuracies, fake_accuracies = [], []

# Epoch offset for the training loop and a running iteration counter
start_epoch, iters = 0, 0

# Target values fed to the BCE criterion
real_label, fake_label = 1., 0.

# Directory that receives the network checkpoints
os.makedirs(f'./weights/{data_filename}', exist_ok=True)

In [None]:
num_epochs = 500

# Training Loop
#
# Each DataLoader batch is split into mini-batches of at most `mini_batch_size`
# samples. Gradients are accumulated over those mini-batches (each loss is
# divided by the number of mini-batches) and each optimiser steps once per batch.
print("Starting Training Loop...")
# For each epoch
for epoch in tqdm(range(start_epoch, start_epoch + num_epochs)):
    # For each batch in the dataloader
    for i, data_all in enumerate(dataloader, 0):
        data_split = torch.split(data_all, mini_batch_size)
        n_splits = len(data_split)

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        optimizerD.zero_grad()
        batch_acc_real = []   # discriminator accuracy on real data, one entry per mini-batch
        batch_acc_fake = []   # discriminator accuracy on generated data, one entry per mini-batch
        errD_total = 0.0      # discriminator loss summed over the whole batch (for logging)
        for data in data_split:
            real = data.to(device)
            b_size = real.size(0)
            label_real = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            label_fake = torch.full((b_size,), fake_label, dtype=torch.float, device=device)

            ## Train with all-real mini-batch
            output_real = netD(real).view(-1)
            D_x = output_real.mean().item()
            train_acc_real = torch.sum((output_real > 0.5).to(int) == label_real) / b_size
            errD_real = criterion(output_real, label_real) / n_splits
            errD_real.backward()

            ## Train with all-fake mini-batch
            noise = torch.rand(b_size, noise_dim, device=device)
            # No gradient is needed through G while training D
            with torch.no_grad():
                fake = netG(noise)
            output_fake = netD(fake).view(-1)
            D_G_z1 = output_fake.mean().item()
            train_acc_fake = torch.sum((output_fake > 0.5).to(int) == label_fake) / b_size
            errD_fake = criterion(output_fake, label_fake) / n_splits
            errD_fake.backward()

            batch_acc_real.append(train_acc_real.item())
            batch_acc_fake.append(train_acc_fake.item())
            errD_total += errD_real.item() + errD_fake.item()

            # Save this mini-batch's own discriminator stats for plotting later
            D_real_losses.append(errD_real.item())
            D_fake_losses.append(errD_fake.item())
            real_accuracies.append(train_acc_real.item())
            fake_accuracies.append(train_acc_fake.item())

        # update D only if classification acc is less than 80% for stability
        mean_acc = (np.mean(batch_acc_real) + np.mean(batch_acc_fake)) / 2
        if mean_acc < 0.8:
            optimizerD.step()  # update the weights only after accumulating all mini-batches
        optimizerD.zero_grad()  # drop the accumulated gradients whether or not D stepped

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        optimizerG.zero_grad()
        errG_total = 0.0      # generator loss summed over the whole batch (for logging)
        for data in data_split:
            # Match each real mini-batch's size so G sees as many samples as D did
            b_size = data.size(0)
            # fake labels are real for generator cost
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            # Draw fresh latent vectors for every mini-batch; reusing one noise
            # tensor would accumulate the same gradient n_splits times.
            noise = torch.rand(b_size, noise_dim, device=device)
            # Since we just updated D, perform another forward pass of all-fake batch through D
            fake = netG(noise)
            output = netD(fake).view(-1)
            errG = criterion(output, label) / n_splits
            errG.backward()

            D_G_z2 = output.mean().item()
            errG_total += errG.item()
            # Save Losses for plotting later (one entry per mini-batch, like the D lists)
            G_losses.append(errG.item())

        optimizerG.step()  # update the weights only after accumulating all mini-batches
        optimizerG.zero_grad()

        # Output training stats
        if i % 10 == 0:  # print progress every 10 iterations
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, start_epoch+num_epochs, i, len(dataloader),
                     errD_total, errG_total, D_x, D_G_z1, D_G_z2))

        iters += 1

    if epoch % 10 == 0:
        # save network weights
        netG_filename = f'./weights/{data_filename}/netG_e{epoch}_r{dim}_weights.pth'
        netD_filename = f'./weights/{data_filename}/netD_e{epoch}_r{dim}_weights.pth'
        torch.save(netG.state_dict(), netG_filename)
        torch.save(netD.state_dict(), netD_filename)
        print('saved network weights', netG_filename)


# Always checkpoint the final epoch, whether or not it is a multiple of 10
torch.save(netG.state_dict(), f'./weights/{data_filename}/netG_e{epoch}_r{dim}_weights.pth')
torch.save(netD.state_dict(), f'./weights/{data_filename}/netD_e{epoch}_r{dim}_weights.pth')
# Resume after the last completed epoch, so re-running this cell does not repeat
# that epoch number and overwrite its checkpoint.
start_epoch = epoch + 1

# Results 

## Convergence Graph

In [None]:
# Loss curves for the generator and both halves of the discriminator loss
fig_loss, ax_loss = plt.subplots(figsize=(10, 5))
ax_loss.set_title("Generator and Discriminator Loss During Training")
ax_loss.plot(G_losses, label="G")
ax_loss.plot(D_real_losses, label="D_real")
ax_loss.plot(D_fake_losses, label="D_fake")
ax_loss.set_xlabel("iterations")
ax_loss.set_ylabel("Loss")
ax_loss.set_ylim([0, 5])   # clip the y-axis to the 0-5 range
ax_loss.legend()
plt.show()

# Discriminator accuracy on real and generated samples
fig_acc, ax_acc = plt.subplots(figsize=(10, 5))
ax_acc.set_title("Discriminator Accuracies During Training")
ax_acc.plot(real_accuracies, label="acc_real")
ax_acc.plot(fake_accuracies, label="acc_fake")
ax_acc.set_xlabel("iterations")
ax_acc.set_ylabel("Accuracy")
ax_acc.legend()
plt.show()

# Visualisation

## Real Samples

In [None]:
# Pull one batch from the shuffling training loader to use as real examples
real_sample = next(iter(dataloader))

In [None]:
# Render the first five volumes of the sampled batch, one 3D figure each
for sample_idx in range(5):
    volume = real_sample[sample_idx][0]   # index 0 selects the single channel
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.voxels(volume)
    plt.show()

# Generate fake samples

In [None]:
fake_samples = []

# Sample the latent vectors once, so every checkpoint is evaluated on the same
# inputs and the generated volumes are comparable across epochs.
fixed_noise = torch.rand(5, noise_dim, device=device)

# NOTE(review): this loop overwrites the weights held by `netG` with each
# checkpoint in turn; afterwards `netG` holds the last checkpoint that loaded
# successfully, not necessarily the final trained weights.
for i in tqdm(range(0, epoch, 20)):
    file_netG = f'weights/{data_filename}/netG_e{i}_r{dim}_weights.pth'
    try:
        print(file_netG)
        # map_location lets a checkpoint saved on a GPU be restored on whatever
        # device this session is using (including CPU-only machines).
        netG.load_state_dict(torch.load(file_netG, map_location=device))

        with torch.no_grad():
            fake = netG(fixed_noise).cpu().numpy()
        fake_samples.append(fake)
        print('generated fake samples')
    except Exception as e:
        # Best effort: a missing or unreadable checkpoint is reported and skipped
        print(e)
        print('epoch', i, 'failed')

# Stack to (n_checkpoints, 5, ...) for storage
fake_samples = np.array(fake_samples)

os.makedirs('./fake_samples', exist_ok=True)
# NOTE(review): `data_filename` already ends in `_r{dim}`, so the suffix appears
# twice in this file name. Left unchanged in case downstream code expects it.
h5_filename = f'./fake_samples/{data_filename}_r{dim}.h5'
with h5py.File(h5_filename, "w") as f:
    f.create_dataset("data", data=fake_samples)
# Report only after the file has been closed (and therefore flushed)
print(h5_filename, 'saved')

print('fake sample shape:', fake_samples.shape)