In [None]:
%load_ext autoreload
%autoreload 2
# some basics
import os
# Works around the Intel OpenMP "OMP: Error #15 ... already initialized" abort when two
# OpenMP runtimes get loaded (e.g. via torch + numpy/MKL). It must be set before torch is
# imported. This masks the duplicate-runtime conflict rather than fixing it.
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import torch
import torch.utils.data as data
import numpy as np
import matplotlib.pyplot as plt

# unconditional GAN: project-local models and training helpers
from gan import Generator, Discriminator
from gan.utils import reset_random_seeds, train
from gan.utils import EnsembleDataset
from gan.utils import count_trainable_parameters

# 3-D model visualization (project-local helper)
from visual_3d import visual_3d

In [9]:
# set random seed
# Fix the random seed via the project helper (presumably seeds numpy/torch; see gan.utils).
reset_random_seeds(77_777)

In [10]:
# import training data - 100 channel facies model
# Training data: an ensemble of channel-facies models (100 of them, per the original note),
# read from a NumPy binary resolved relative to the current working directory.
ensemble = np.load(file='Ensemble.npy')

In [None]:
# Quick look at slice 0 (along axis 1) of the first model in the ensemble.
plt.imshow(ensemble[0][0])

In [None]:
# Interactive 3-D view of the first model; save_html presumably also writes it to that file.
visual_3d(ensemble[0, ...], save_html='test.html')

In [None]:
# Raw ensemble dimensions, displayed as the cell output.
np.shape(ensemble)

In [18]:
# Preprocess data
# Preprocess data: reflect-pad every model in one vectorised call.
# Per model: one extra layer at the end of axis 0 and two cells on each side of axes 1-2.
# The leading (0, 0) leaves the sample axis of the stack untouched, so this matches the old
# per-sample loop but no longer hard-codes the number of models (previously 100).
# The old fixed (100, 8, 64, 64) buffer implies raw models of shape (7, 60, 60).
# NOTE(review): presumably (8, 64, 64) is the grid the GAN expects; confirm.
PAD_WIDTH = ((0, 0), (0, 1), (2, 2), (2, 2))  # (samples, axis 0, axis 1, axis 2) of the stack
ensemble_pad = np.pad(ensemble, PAD_WIDTH, mode='reflect').astype(np.float64, copy=False)
# The old loop failed on any other per-model shape; keep that fail-fast behaviour explicit.
assert ensemble_pad.shape[1:] == (8, 64, 64), f'unexpected padded shape {ensemble_pad.shape}'


In [None]:
# Padded ensemble dimensions, displayed as the cell output.
np.shape(ensemble_pad)

In [None]:
# Side-by-side sanity check of the padding: slice 0 of model 0, before and after.
fig, (ax_raw, ax_pad) = plt.subplots(1, 2)
ax_raw.imshow(ensemble[0, 0])
ax_raw.set_title('raw')
ax_pad.imshow(ensemble_pad[0, 0])
ax_pad.set_title('reflect-padded');

In [None]:
# hyper-parameters
BUFFER_SIZE = 6000000  # NOTE(review): not used by any visible cell (looks like a tf.data leftover); confirm and drop
BATCH_SIZE = 15
LATENT_DIM = 100
EPOCHS = 2501  # presumably 2501 so that an epoch-2500 checkpoint gets saved; confirm against gan.utils.train
FIRST_CHANNEL_GEN = 32*3
FIRST_CHANNEL_DIS = 32*3
LOG_PERM_FLOOR = 4      # log-values below this are raised to it before rescaling
SAVE_EVERY_EPOCH = 100
TEST_EVERY_EPOCH = 50

# Log-transform, floor at LOG_PERM_FLOOR, then min-max rescale to [-1, 1].
# np.maximum replaces the in-place masked assignment with identical results: -inf from
# log(0) is lifted to the floor, and NaN (log of a negative) propagates as before.
with np.errstate(divide='ignore'):  # log(0) -> -inf is expected and absorbed by the floor
    ensemble_log_perm_pad = np.maximum(np.log(ensemble_pad), LOG_PERM_FLOOR)
ensemble_log_perm_min, ensemble_log_perm_max = ensemble_log_perm_pad.min(), ensemble_log_perm_pad.max()
ensemble_log_perm_scaled = (
    (ensemble_log_perm_pad - ensemble_log_perm_min)
    / (ensemble_log_perm_max - ensemble_log_perm_min) * 2 - 1
)
# Fail fast instead of training on NaN/inf (negative inputs, or a constant ensemble -> zero range).
assert np.isfinite(ensemble_log_perm_scaled).all(), 'scaled training data contains NaN/inf'
ensemble_log_perm_sacled = ensemble_log_perm_scaled  # backward-compatible alias for the old misspelled name

train_dataset = EnsembleDataset(ensemble_log_perm_scaled)
train_dataloader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Instantiate models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator(latent_dim=LATENT_DIM, first_channel=FIRST_CHANNEL_GEN)
discriminator = Discriminator(first_channel=FIRST_CHANNEL_DIS)
# NOTE(review): the models are not moved to `device` here; presumably train() does it. Confirm.

# Train models
train(generator,
      discriminator,
      train_dataloader,
      EPOCHS,
      LATENT_DIM,
      device,
      save_every_epoch=SAVE_EVERY_EPOCH,
      test_every_epoch=TEST_EVERY_EPOCH)

In [None]:
# Sample new models from a saved generator checkpoint and plot slice 0 of each one.
num_of_model = 9
CHECKPOINT_PATH = 'saved_gan_models/generator_weights_epoch_02500.pth'

# LATENT_DIM / FIRST_CHANNEL_GEN / device come from the training cell. The old in-cell
# `LATENT_DIM = 100` came after the generator was built and could silently disagree with it.
generator = Generator(latent_dim=LATENT_DIM, first_channel=FIRST_CHANNEL_GEN).to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine (and vice versa).
generator.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
# NOTE(review): the generator stays in train mode, as before. If it has BatchNorm/Dropout,
# decide whether sampling should call generator.eval() first.
noise = torch.randn(num_of_model, LATENT_DIM, device=device)
with torch.no_grad():  # inference only: no autograd graph needed
    raw_samples = generator(noise).cpu().numpy()
# Drop singleton axes except the batch axis, so num_of_model == 1 still indexes per sample.
# This is identical to the old .squeeze() whenever num_of_model > 1.
generated_models = np.squeeze(
    raw_samples, axis=tuple(a for a in range(1, raw_samples.ndim) if raw_samples.shape[a] == 1)
)

# Near-square grid sized from num_of_model (3 x 3 for 9 samples); unused panels are hidden.
n_cols = int(np.ceil(np.sqrt(num_of_model)))
n_rows = int(np.ceil(num_of_model / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 10), squeeze=False)
for idx, ax in enumerate(axes.flat):
    if idx < num_of_model:
        ax.imshow(generated_models[idx][0])
        ax.set_title(f'sample {idx}')
    else:
        ax.axis('off')
