# Porous media flow field estimation using cGANs

### Import libraries

In [1]:
import os

import torch
import torchvision
import torch.nn as nn # All neural network modules, e.g. nn.Linear, nn.Conv2d, BatchNorm, Loss Functions
import torch.optim as optim # For all optimization algorithms, SGD, Adam, etc
import torchvision.datasets as datasets # Has standard datasets we can import in a nice way
import torchvision.transforms as transforms # Transformations we can perform on our dataset
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader # Gives easier dataset management and creates mini batches
# from torch.utils.tensorboard import SummaryWriter # to print to tensorboard
import PIL
from PIL import Image
import glob
import matplotlib.pyplot as plt

### Define data paths

In [2]:
# Get mesh and velocity data paths.
# glob.glob returns matches in arbitrary (OS-dependent) order, so both lists
# are sorted: this keeps the i-th mesh lined up with the i-th velocity file
# and makes the train/test split reproducible between runs.
# NOTE(review): pairing by sorted position assumes the two folders use
# matching file names — confirm against the data layout.
mesh_paths = sorted(glob.glob('./mesh_data/*.tif'))
vel_paths = sorted(glob.glob('./vel_data/*.tif'))

# Every mesh needs exactly one velocity field; fail early instead of
# silently building mismatched pairs.
if len(mesh_paths) != len(vel_paths):
    raise ValueError(
        "Found %d mesh files but %d velocity files"
        % (len(mesh_paths), len(vel_paths))
    )

total_samples = len(mesh_paths)
train_size = 0.85

# Index of the first test sample, computed once so both lists split identically
split_idx = int(total_samples * train_size)

# Separate training samples
train_mesh_paths = mesh_paths[:split_idx]
train_vel_paths = vel_paths[:split_idx]

# Separate test samples
test_mesh_paths = mesh_paths[split_idx:]
test_vel_paths = vel_paths[split_idx:]

print("Total samples:", total_samples)
print("Train samples", len(train_mesh_paths))
print("Test samples", len(test_mesh_paths))

Total samples: 200
Train samples 170
Test samples 30


### Preprocess the dataset

In [3]:
# Create a dataset class to consider an image pair 
class mesh_vel_dataset(Dataset):
    """Dataset of paired (mesh, velocity) images read from file paths.

    Indexing returns a ``(mesh, vel)`` tuple of tensors produced by
    ``transform``.

    Args:
        meshes: sequence of mesh image paths.
        veles: sequence of velocity image paths; item ``i`` is paired with
            ``meshes[i]``.
        train: accepted for call-site compatibility; not used by this class.
    """

    def __init__(self, meshes, veles, train=True):
        self.meshes = meshes
        self.vels = veles

    def transform(self, mesh, vel):
        """Resize both images to 64x64, convert to tensors, binarise ``vel``.

        Returns:
            ``(mesh, vel)`` where ``vel`` holds 1.0 where the converted
            velocity value is >= 0.7 and 0.0 elsewhere.
        """
        # Both images use the same nearest-neighbour resize, so one transform
        # object serves for the pair.
        resize = transforms.Resize(size=(64, 64), interpolation=Image.NEAREST)

        mesh = TF.to_tensor(resize(mesh))
        vel = TF.to_tensor(resize(vel))

        # Threshold the velocity field into a float 0/1 map
        vel = (vel >= 0.7).float()

        return mesh, vel

    def __getitem__(self, idx):
        """Load the ``idx``-th image pair from disk and return it transformed."""
        # Context managers release the file handles as soon as the pixel data
        # has been turned into tensors.
        with Image.open(self.meshes[idx]) as mesh:
            with Image.open(self.vels[idx]) as vel:
                return self.transform(mesh, vel)

    def __len__(self):
        """Number of image pairs (taken from the mesh list)."""
        return len(self.meshes)



### Load the data

In [4]:
# Wrap the path lists in datasets first ...
train_data = mesh_vel_dataset(train_mesh_paths, train_vel_paths, train=True)
test_data = mesh_vel_dataset(test_mesh_paths, test_vel_paths, train=False)

# ... then batch them, 5 pairs per batch; only the training loader shuffles.
train_loader = DataLoader(
    train_data,
    shuffle=True,
    batch_size=5,
)
test_loader = DataLoader(
    test_data,
    shuffle=False,
    batch_size=5,
)

# The loader's dataset is train_data, so this prints the training set size
print(len(train_data))

170


### Build the Discriminator

In [5]:
# Discriminator
class Discriminator(torch.nn.Module):
    """Conditional discriminator that scores a pair of single-channel images.

    The two inputs are stacked as channels, so every convolution sees both
    images of a pair at the same pixel location. Four stride-2 convolutions
    each halve the spatial size; the last one is followed by a sigmoid, so the
    output is a map of per-patch scores in [0, 1] (4x4 for 64x64 inputs).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Not used by forward(); kept so existing attribute access still works.
        self.label_embedding = nn.Embedding(10, 10)

        # 2 input channels: the two images of the pair, stacked in forward()
        self.hidden0 = nn.Sequential(
            nn.Conv2d(2, 64, kernel_size=2, stride=2),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )

        self.hidden1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=2, stride=2),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )

        self.hidden2 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=2, stride=2),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )

        self.out = nn.Sequential(
            nn.Conv2d(256, 1, kernel_size=2, stride=2),
            torch.nn.Sigmoid()
        )

    def forward(self, x, msk):
        """Score the image pair ``(x, msk)``.

        Args:
            x: batch of single-channel images, shape (B, 1, H, W).
            msk: batch with the same shape as ``x``.

        Returns:
            Tensor of shape (B, 1, H/16, W/16) with values in [0, 1].
        """
        # dim=1 stacks the pair as channels. torch.cat's default (dim=0) would
        # instead append msk as extra, unrelated samples in the batch, and the
        # network would never see the two images together.
        x = torch.cat([x, msk], dim=1)

        output = self.hidden0(x)
        output = self.hidden1(output)
        output = self.hidden2(output)
        output = self.out(output)
        return output

### Build the Generator


In [6]:
class Generator(nn.Module):
    """Image-to-image generator: encoder-decoder that preserves spatial size.

    Three stride-2 convolutions halve the resolution three times
    (1 -> 64 -> 128 -> 256 channels), then three stride-2 transposed
    convolutions double it back (256 -> 128 -> 64 -> 1 channel). For inputs
    whose height and width are divisible by 8 the output has exactly the
    input's shape, which is what an element-wise loss against the target and
    a channel/batch concatenation with the input both require. A stack made
    only of transposed convolutions would enlarge the image instead.

    The final Tanh bounds the output to [-1, 1].
    """

    def __init__(self):#, channels_noise, channels_img, features_g):
        super(Generator, self).__init__()
        # Not used by forward(); kept so existing attribute access still works.
        self.maxpool = nn.MaxPool2d(kernel_size =2, stride =2)

        self.net = nn.Sequential(

            # ---- Encoder: each block halves H and W ----
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            # ---- Decoder: each block doubles H and W ----
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()

        )

    def forward(self, x):
        """Map a (B, 1, H, W) batch to a (B, 1, H, W) batch (H, W divisible by 8)."""
        return self.net(x)

### Build the Adversarial Model

In [7]:
class adversarialModel(object):
    """Training / evaluation harness for the conditional GAN.

    Owns a ``Generator`` and a ``Discriminator``, one Adam optimizer for each,
    and the BCE and L1 losses that ``train`` and ``evaluate`` combine.

    Args:
        num_epochs: number of passes over ``dataloader`` made by ``train``.
        g_lr: Adam learning rate for the generator.
        d_lr: Adam learning rate for the discriminator.
        dataloader: iterable yielding ``(meshes, vels)`` batches for training.
        samples, batch, betas, size, channels_img, channels_noise,
        features_g, features_d: stored on the instance; not otherwise used
            by this class.
        data_path, transforms: accepted but not used.
    """

    def __init__(self, num_epochs=500, samples=3, batch=5, betas=(0.5,0.5), g_lr = 0.001, d_lr = 0.001,
                 size = 64, data_path = './models', channels_img = 1, channels_noise =  64, features_g = 16,
                 features_d = 16, dataloader = train_loader, transforms = None):

        # Define parameters
        self.num_epochs = num_epochs
        self.samples = samples
        self.batch = batch
        self.betas = betas
        self.g_lr = g_lr
        self.d_lr = d_lr
        self.size = size
        self.channels_img = channels_img
        self.channels_noise = channels_noise
        self.features_d = features_d
        self.features_g = features_g
        self.dataset = 'input_data'
        self.name = 'adversarialModel'
        # NOTE(review): this is a list of glob matches for './', not a path
        # string, and `data_path` is ignored — confirm what was intended.
        self.output_dir = glob.glob('./')

        # Pick the device here so the networks and every batch are guaranteed
        # to live on the same one, with a CPU fallback when CUDA is absent.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Generator and Discriminator (moved before the optimizers are built)
        self.generator = Generator().to(self.device)
        self.discriminator = Discriminator().to(self.device)

        # Optimizers
        # NOTE(review): the `betas` argument is stored above but these fixed
        # values are what Adam actually uses — confirm which is intended.
        self.optim_g = optim.Adam(self.generator.parameters(), lr = self.g_lr, betas =(0.9, 0.999))
        self.optim_d = optim.Adam(self.discriminator.parameters(), lr = self.d_lr, betas =(0.9, 0.999))

        # Loss functions
        self.bce_loss = nn.BCELoss()
        self.L1_loss = nn.L1Loss()

        # Dataset
        self.dataloader = dataloader

    def _discriminator_loss(self, meshes, vels, fake_vels):
        """BCE of real pairs against ones plus BCE of fake pairs against zeros."""
        pred_real = self.discriminator(meshes, vels)
        pred_fake = self.discriminator(meshes, fake_vels)
        loss_real = self.bce_loss(pred_real, torch.ones_like(pred_real))
        loss_fake = self.bce_loss(pred_fake, torch.zeros_like(pred_fake))
        return loss_real + loss_fake

    def _generator_loss(self, meshes, vels, fake_vels):
        """BCE of fake pairs against ones plus L1 distance to the target."""
        pred_fake = self.discriminator(meshes, fake_vels)
        loss_bce = self.bce_loss(pred_fake, torch.ones_like(pred_fake))
        loss_l1 = self.L1_loss(fake_vels, vels)
        return loss_bce + loss_l1

    def train(self):
        """Alternate discriminator and generator updates for ``num_epochs``.

        Prints both losses on the first batch of every 10th epoch and shows
        the loss curves on the first batch of every 50th epoch.
        """
        losses_d = []
        losses_g = []

        # evaluate() leaves the networks in eval mode; re-enable dropout and
        # batch-norm statistic updates before training.
        self.generator.train()
        self.discriminator.train()

        for epoch in range(self.num_epochs):
            for batch_idx, (meshes, vels) in enumerate(self.dataloader):

                meshes = meshes.to(self.device)
                vels = vels.to(self.device)

                fake_vels = self.generator(meshes)

                # Discriminator loss
                self.optim_d.zero_grad()
                # detach(): the discriminator update must not backpropagate
                # into the generator.
                loss_discriminator = self._discriminator_loss(
                    meshes, vels, fake_vels.detach()
                )
                loss_discriminator.backward()
                self.optim_d.step()

                # Generator Loss
                self.optim_g.zero_grad()
                loss_generator = self._generator_loss(meshes, vels, fake_vels)
                loss_generator.backward()
                self.optim_g.step()

                # Store plain floats: keeping the tensors would keep every
                # batch's autograd graph alive and they could not be plotted.
                losses_d.append(loss_discriminator.item())
                losses_g.append(loss_generator.item())

                if epoch % 10 == 0 and batch_idx == 0:
                    print('Epoch : ', epoch)
                    print("generator loss: ", losses_g[-1])
                    print("discriminator loss: ", losses_d[-1])

                if epoch % 50 == 0 and batch_idx == 0:
                    plt.title("Model losses")
                    plt.plot(losses_d, label="Discriminator")
                    plt.plot(losses_g, label="Generator")
                    # One point is recorded per batch, so the x axis counts
                    # training iterations rather than epochs.
                    plt.xlabel("Iteration")
                    plt.legend()
                    plt.show()

    def evaluate(self, test_loader = test_loader):
        """Compute both losses over ``test_loader`` without updating weights.

        Prints the losses and shows input / generated / target images for the
        first sample of the first batch, then plots the per-batch losses.
        """
        losses_d = []
        losses_g = []

        # Inference mode: disable dropout, use batch-norm running statistics
        self.generator.eval()
        self.discriminator.eval()

        with torch.no_grad():
            for batch_idx, (meshes, vels) in enumerate(test_loader):

                meshes = meshes.to(self.device)
                vels = vels.to(self.device)

                fake_vels = self.generator(meshes)
                loss_discriminator = self._discriminator_loss(meshes, vels, fake_vels)
                loss_generator = self._generator_loss(meshes, vels, fake_vels)

                losses_d.append(loss_discriminator.item())
                losses_g.append(loss_generator.item())

                if batch_idx == 0:
                    print("generator loss: ", losses_g[-1])
                    print("discriminator loss: ", losses_d[-1])

                    # First sample of the batch: input, prediction, target
                    fig = plt.figure()
                    ax1 = fig.add_subplot(2,2,1)
                    ax1.imshow(meshes[0].cpu().numpy().squeeze(), cmap = 'gray')
                    ax2 = fig.add_subplot(2,2,2)
                    ax2.imshow(fake_vels[0].cpu().numpy().squeeze(), cmap = 'gray')
                    ax3 = fig.add_subplot(2,2,3)
                    ax3.imshow(vels[0].cpu().numpy().squeeze(), cmap = 'gray')

                    plt.show()

        # Plot after the loop so the curves cover every test batch
        if losses_d:
            plt.title("Model losses")
            plt.plot(losses_d, label="Discriminator")
            plt.plot(losses_g, label="Generator")
            plt.legend()
            plt.show()





### Set the hyperparameters

In [8]:
# Use the GPU when one is available; otherwise fall back to the CPU so that
# `device` is always defined (a bare `if` left it undefined on CPU-only hosts).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [11]:

# Build the adversarial trainer on the training loader, set to run 800 epochs
model = adversarialModel(
    dataloader=train_loader,
    num_epochs=800,
)

In [12]:
# Sanity check: print the class of the object bound to train_loader
print(type(train_loader))

<class 'torch.utils.data.dataloader.DataLoader'>


### Instantiate Discriminator and Generator