In [1]:
import torch
import torch.nn as nn

from torchinfo import summary

import neptune

import matplotlib.pyplot as plt

from tqdm.auto import trange

import sys
import os
 
import numpy as np
import PIL

In [None]:
# Start the Neptune run used for all logging in this notebook.
# Credentials are read from the environment so that no secret is ever stored in
# the notebook or its outputs: set NEPTUNE_PROJECT ("workspace/project") and
# NEPTUNE_API_TOKEN before starting the kernel.
exp = neptune.init_run(
    project=os.getenv("NEPTUNE_PROJECT"),
    api_token=os.getenv("NEPTUNE_API_TOKEN"),
)

# Tag the run so it can be filtered in the Neptune UI.
exp["sys/tags"].add(["MicrostructureGAN", 'PT'])

In [3]:
# Identifier of the current Neptune run; used to namespace local output folders.
exp_id = exp['sys/id'].fetch()

# exist_ok=True keeps this cell idempotent: re-running it (or resuming a run
# with the same id) must not fail just because the folders already exist.
os.makedirs(f'saved_images/{exp_id}', exist_ok=True)
os.makedirs(f'checkpoints/{exp_id}', exist_ok=True)

In [4]:
# Define image size
img_width = 256
img_height = 256
img_channels = 1
img_shape = (img_channels, img_height, img_width)

# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs (slowly) on machines without a GPU instead of failing at `.to(device)`.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Define hyperparameters
gp_coef = 1.       # weight of the gradient penalty in the discriminator loss
latent_dim = 100   # size of the generator's noise vector
lr_d = 1e-4        # discriminator learning rate
lr_g = 2e-5        # generator learning rate

batch_size = 128

In [5]:
class LabeledImageDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each image with its label.

    Args:
        imgs: Indexable collection of images (anything supporting ``len`` and
            integer indexing, e.g. a list or an array).
        labels: Indexable collection of labels, one per image.

    Raises:
        ValueError: If ``imgs`` and ``labels`` do not have the same length.
    """

    def __init__(self, imgs, labels):
        # Fail early: a length mismatch would otherwise only show up later as a
        # silently truncated dataset or an IndexError inside the DataLoader.
        if len(imgs) != len(labels):
            raise ValueError(
                f'imgs and labels must have the same length, '
                f'got {len(imgs)} and {len(labels)}'
            )
        self.imgs = imgs
        self.labels = labels

    def __len__(self):
        """Return the number of (image, label) pairs."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        return self.imgs[idx], self.labels[idx]


        
    
def _pack_split(imgs, labels):
    """Turn lists of 2-D uint8 images and scalar labels into float32 arrays.

    Images are reshaped to ``(N, 1, 256, 256)`` and rescaled with
    ``(x - 127.5) / 127.5``; labels are reshaped to ``(N, 1)``.
    """
    n = len(imgs)
    # N comes from the list length, so any image whose size is not 256x256
    # single-channel still makes the reshape fail loudly.
    img_arr = np.asarray(imgs, dtype=np.float32).reshape((n, 1, 256, 256))
    img_arr = (img_arr - 127.5) / 127.5
    label_arr = np.asarray(labels, dtype=np.float32).reshape((n, 1))
    return img_arr, label_arr


def prepare_data(data_dir='training data'):
    """Load the training images, augment them by rotation and split train/test.

    Images are read from the folders ``{data_dir}/{i}-{j}`` for ``i`` in 4..11
    and ``j`` in 1..5 (hidden files are skipped). Every image is added four
    times: as loaded and rotated by 90, 180 and 270 degrees. All images that
    share the same ``i`` get the same scalar label; the group ``i == 9`` is
    held out as the test set, the other seven groups form the training set.

    Args:
        data_dir: Root folder that contains the ``{i}-{j}`` sub-folders.

    Returns:
        Tuple ``(train_imgs, train_labels, test_imgs, test_labels)`` of float32
        NumPy arrays; images have shape ``(N, 1, 256, 256)`` and labels
        ``(N, 1)``.
    """
    # `import PIL` alone does not guarantee that the `Image` submodule has been
    # loaded; import it explicitly instead of relying on another library having
    # done so as a side effect.
    import PIL.Image

    # One label per group i = 4..11, in that order.
    labels = [0.73, 0.72, 0.7, 0.67, 0.66, 0.62, 0.56, 0.51]
    held_out_group = 9
    rotations = (PIL.Image.ROTATE_90, PIL.Image.ROTATE_180, PIL.Image.ROTATE_270)

    train_imgs, train_labels = [], []
    test_imgs, test_labels = [], []

    for i, subset_label in zip(range(4, 12), labels):
        subset_imgs = []
        for j in range(1, 6):
            img_dir = f'{data_dir}/{i}-{j}'
            for img_file in os.listdir(img_dir):
                if img_file.startswith('.'):
                    continue
                # The context manager closes the file handle once the pixel
                # data has been copied into arrays.
                with PIL.Image.open(f'{img_dir}/{img_file}') as img:
                    subset_imgs.append(np.asarray(img))
                    for rotation in rotations:
                        subset_imgs.append(np.asarray(img.transpose(rotation)))

        if i != held_out_group:
            train_imgs.extend(subset_imgs)
            train_labels.extend([subset_label] * len(subset_imgs))
        else:
            test_imgs.extend(subset_imgs)
            test_labels.extend([subset_label] * len(subset_imgs))

    # Sizes are derived from what was actually loaded, so the groups no longer
    # need to contain exactly 1080 images each.
    train_imgs, train_labels = _pack_split(train_imgs, train_labels)
    test_imgs, test_labels = _pack_split(test_imgs, test_labels)

    return train_imgs, train_labels, test_imgs, test_labels

    
# Load everything into memory once, then wrap each split in a dataset and a
# shuffling DataLoader that yields (images, labels) mini-batches.
train_imgs, train_labels, test_imgs, test_labels = prepare_data()

dataset_train = LabeledImageDataset(train_imgs, train_labels)
dataset_test = LabeledImageDataset(test_imgs, test_labels)

dataloader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=batch_size,
    shuffle=True,
)
dataloader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=True,
)


In [6]:
def init_weights(m):
    """Initialise one module in place; intended for use with ``Module.apply``.

    ``nn.Linear``, ``nn.Conv2d`` and ``nn.ConvTranspose2d`` layers get
    Xavier-uniform weights and zero biases. Every other module type is left
    untouched.

    Args:
        m: The module to initialise.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.xavier_uniform_(m.weight)
        # Layers created with bias=False have `bias is None`; skip them
        # instead of crashing inside nn.init.zeros_.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [7]:
class ResBlock(nn.Module):
    """Residual block that keeps both channel count and spatial size.

    The residual branch is conv3x3 -> LeakyReLU(0.2) -> conv3x3 with 'same'
    padding; its output is added to the unchanged input and passed through a
    final LeakyReLU(0.2).

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels: int):
        super().__init__()
        conv_args = dict(kernel_size=3, stride=1, padding='same')

        self.skip_conn = nn.Identity()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, **conv_args),
            nn.LeakyReLU(0.2),
            nn.Conv2d(channels, channels, **conv_args),
        )
        self.leakyrelu = nn.LeakyReLU(0.2)

    def forward(self, X):
        """Return ``LeakyReLU(block(X) + X)``; the output has the shape of ``X``."""
        residual = self.block(X)
        shortcut = self.skip_conn(X)
        return self.leakyrelu(residual + shortcut)

In [8]:
class Generator(nn.Module):
    """Label-conditioned generator mapping (noise, label) to a 1-channel image.

    The noise vector is projected to a 100x8x8 map and the scalar label to a
    16x8x8 map; both are concatenated along the channel axis (116 channels) and
    upsampled 8 -> 32 -> 64 -> 128 -> 256 pixels. The final ``Tanh`` keeps the
    output in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Fully connected projections of both inputs onto 8x8 feature maps.
        self.net_noise = nn.Sequential(
            nn.Linear(latent_dim, 100 * 8 * 8),
            nn.Unflatten(1, (100, 8, 8)),
        )
        self.net_label = nn.Sequential(
            nn.Linear(1, 16 * 8 * 8),
            nn.Unflatten(1, (16, 8, 8)),
        )

        # 116 = 100 noise channels + 16 label channels; stride 4 takes 8x8 to 32x32.
        layers = [
            nn.ConvTranspose2d(116, 64, kernel_size=9, stride=4, padding=3, output_padding=1),
        ]
        # Six residual blocks at 32x32 resolution.
        layers.extend(ResBlock(64) for _ in range(6))
        # Two conv + nearest-neighbour upsampling stages: 32 -> 64 -> 128.
        layers.extend([
            nn.Conv2d(64, 256, kernel_size=3, padding='same'),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 256, kernel_size=3, padding='same'),
            nn.Upsample(scale_factor=2),
        ])
        # Last learned upsampling (128 -> 256) followed by the output head.
        layers.extend([
            nn.ConvTranspose2d(256, 128, kernel_size=7, stride=2, padding=3, output_padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 1, kernel_size=11, stride=1, padding='same'),
            nn.Tanh(),
        ])
        self.net = nn.Sequential(*layers)

        for sub_net in (self.net_noise, self.net_label, self.net):
            sub_net.apply(init_weights)

    def forward(self, noise: torch.Tensor, label: torch.Tensor):
        """Generate images.

        Args:
            noise: Tensor of shape ``(batch, latent_dim)``.
            label: Tensor of shape ``(batch, 1)``.

        Returns:
            The output of the final ``Tanh`` layer for the concatenated
            noise/label feature maps.
        """
        noise_maps = self.net_noise(noise)
        label_maps = self.net_label(label)
        # Both maps are 4-D, so channel-wise concatenation is along dim 1.
        return self.net(torch.cat([noise_maps, label_maps], dim=1))

In [9]:
class Discriminator(nn.Module):
    """Label-conditioned critic returning one unbounded score per image.

    The image is downsampled by two stride-2 convolutions to 128 feature maps,
    the scalar label is projected to 20 maps of 64x64, and both are
    concatenated along the channel axis (148 channels) before the shared
    trunk. There is no final activation: the last layer is a plain
    ``Linear(512, 1)``.
    """

    def __init__(self):
        super().__init__()

        # Image branch: two stride-2 convolutions, each halving the resolution.
        self.net_img = nn.Sequential(
            nn.Conv2d(img_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        )
        # Label branch: scalar -> 20 feature maps of 64x64.
        self.net_label = nn.Sequential(
            nn.Linear(1, 64 * 64 * 20),
            nn.Unflatten(1, (20, 64, 64)),
        )

        # Original version
        # Trunk input has 148 = 128 image + 20 label channels.
        trunk = [
            nn.Conv2d(148, 512, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(0.2),
        ]
        # Channel-reducing 3x3 convolutions: 512 -> 256 -> 128 -> 64.
        for c_in, c_out in ((512, 256), (256, 128), (128, 64)):
            trunk.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding='same'))
            trunk.append(nn.LeakyReLU(0.2))
        trunk.extend([
            ResBlock(64),
            nn.Flatten(),
            # 65536 = 64 channels * 32 * 32, i.e. the flattened size for 256x256 inputs.
            nn.Linear(65536, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1),
        ])
        self.net = nn.Sequential(*trunk)

        for sub_net in (self.net_img, self.net_label, self.net):
            sub_net.apply(init_weights)

    def forward(self, img: torch.Tensor, label: torch.Tensor):
        """Score a batch of images.

        Args:
            img: Image batch fed to the convolutional image branch.
            label: Tensor of shape ``(batch, 1)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with one score per image.
        """
        img_maps = self.net_img(img)
        label_maps = self.net_label(label)
        # Both maps are 4-D, so channel-wise concatenation is along dim 1.
        return self.net(torch.cat([img_maps, label_maps], dim=1))

In [10]:
# Build both networks and move them to the training device.
net_g = Generator().to(device)
net_d = Discriminator().to(device)

# One Adam optimizer per network; both use beta1 = 0 and beta2 = 0.9 and differ
# only in their learning rate (lr_g for the generator, lr_d for the discriminator).
opt_g = torch.optim.Adam(
    net_g.parameters(),
    lr=lr_g,
    betas=(0., .9),
    eps=1e-07,
)
opt_d = torch.optim.Adam(
    net_d.parameters(),
    lr=lr_d,
    betas=(0., .9),
    eps=1e-07,
)

def train(num_epochs=500, visualize_every=5, checkpoint_every=50):
    """Run the adversarial training loop over ``dataloader_train``.

    Each mini-batch triggers one ``train_step``. Sample images are rendered and
    checkpoints are written on every epoch whose index is a multiple of the
    respective period, which includes epoch 0.

    Args:
        num_epochs: Number of passes over the training data.
        visualize_every: Period, in epochs, between calls to ``visualize``.
        checkpoint_every: Period, in epochs, between calls to ``checkpoint``.
    """
    for epoch in trange(num_epochs):
        for batch in dataloader_train:
            train_step(batch)
        if epoch % visualize_every == 0:
            visualize(epoch)
        if epoch % checkpoint_every == 0:
            checkpoint(epoch)

def train_step(batch):
    """Run one discriminator update followed by one generator update.

    Args:
        batch: Pair ``(imgs_real, labels)`` as produced by the DataLoader.

    The same noise vectors, sampled from a standard normal distribution, are
    used for both updates.
    """
    imgs_real, labels = (tensor.to(device) for tensor in batch)

    # Size the noise from the actual batch: the last batch of an epoch may be
    # smaller than the configured batch size.
    n_samples = labels.shape[0]
    noises = torch.randn((n_samples, latent_dim), device=device)

    train_step_d(imgs_real, labels, noises)
    train_step_g(noises, labels)
    
def train_step_d(imgs_real, labels, noises):
    """Perform one discriminator update and log its loss terms.

    The loss is ``mean(D(fake)) - mean(D(real)) + gp_coef * gradient_penalty``.

    Args:
        imgs_real: Batch of real images, already on the training device.
        labels: Labels of shape ``(batch, 1)`` used for both real and fake images.
        noises: Noise vectors of shape ``(batch, latent_dim)``.
    """
    opt_d.zero_grad()

    # The generator is not updated in this step, so its forward pass needs no
    # autograd graph. Without this, loss_d.backward() would also back-propagate
    # through the whole generator and fill gradients that are thrown away.
    with torch.no_grad():
        imgs_fake = net_g(noises, labels)

    loss_d_real = net_d(imgs_real, labels).mean()
    loss_d_fake = net_d(imgs_fake, labels).mean()

    grad_penalty = compute_gp(imgs_real, imgs_fake, labels)

    loss_d = loss_d_fake - loss_d_real + gp_coef * grad_penalty
    loss_d.backward()

    opt_d.step()

    # Log plain Python floats rather than graph-attached device tensors.
    exp['loss_d_fake'].append(loss_d_fake.item())
    exp['loss_d_real'].append(loss_d_real.item())
    exp['loss_d'].append(loss_d.item())
    exp['grad_penalty'].append(grad_penalty.item())
    
def train_step_g(noises, labels):
    """Perform one generator update and log its loss.

    The generator is trained to maximise the discriminator's mean score on its
    samples, i.e. it minimises ``-mean(D(G(noise, label), label))``.

    Args:
        noises: Noise vectors of shape ``(batch, latent_dim)``.
        labels: Conditioning labels of shape ``(batch, 1)``.
    """
    opt_g.zero_grad()

    imgs_gen = net_g(noises, labels)

    loss_g = -net_d(imgs_gen, labels).mean()
    loss_g.backward()

    opt_g.step()

    # Log a plain Python float rather than a graph-attached device tensor.
    exp['loss_g'].append(loss_g.item())



def compute_gp(imgs_real, imgs_fake, labels):
    """Compute the gradient penalty of the discriminator on interpolated images.

    For every sample a random point on the straight line between the real and
    the fake image is drawn, and the L2 norm of the discriminator's gradient
    with respect to that point is pushed towards 1.

    Args:
        imgs_real: Batch of real images.
        imgs_fake: Batch of generated images with the same shape as ``imgs_real``.
        labels: Labels of shape ``(batch, 1)`` passed to the discriminator.

    Returns:
        Scalar tensor ``mean((||grad||_2 - 1) ** 2)``, differentiable with
        respect to the discriminator parameters.
    """
    batch_size = labels.size()[0]

    # One mixing coefficient per sample, broadcast over channels and pixels.
    # It is created on the inputs' device so the function does not depend on
    # the global `device`.
    epsilon = torch.rand((batch_size, 1, 1, 1), device=imgs_real.device)

    # Detach so the interpolate is a fresh leaf: the penalty must only produce
    # gradients for the discriminator, never for the generator behind imgs_fake.
    imgs_interpolated = (epsilon * imgs_real + (1 - epsilon) * imgs_fake).detach()
    imgs_interpolated.requires_grad_(True)

    logits_interpolated = net_d(imgs_interpolated, labels)

    # create_graph=True keeps the gradient itself differentiable, so the penalty
    # can be back-propagated; retain_graph defaults to the same value.
    (grad_interpolated,) = torch.autograd.grad(
        outputs=logits_interpolated,
        inputs=imgs_interpolated,
        grad_outputs=torch.ones_like(logits_interpolated),
        create_graph=True,
    )

    # reshape (unlike view) also works if the gradient is not contiguous.
    grad_norm = grad_interpolated.reshape(batch_size, -1).norm(2, 1)
    grad_penalty = ((grad_norm - 1) ** 2).mean()

    return grad_penalty

def visualize(current_epoch):
    """Render a 2x2 grid of generated samples, log it and save it to disk.

    One noise vector is shared by all four samples, so differences between the
    panels come from the conditioning label only (0.72, 0.7, 0.62, 0.51).

    Args:
        current_epoch: Epoch index; used as the logging step and the file name.
    """
    r = 2
    c = 2

    # Sample from the same standard normal distribution used during training
    # (torch.randn, not the uniform torch.rand) and with the configured latent
    # size, so the generator sees inputs from the distribution it was trained on.
    noises = torch.randn((1, latent_dim), device=device).repeat((r * c, 1))
    labels = torch.tensor([0.72, 0.7, 0.62, 0.51], device=device).reshape((4, 1))

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # Map the tanh output from [-1, 1] back to [0, 255].
        imgs_gen = (net_g(noises, labels) * 127.5 + 127.5).cpu()

    fig, axs = plt.subplots(r, c)
    for ax, img in zip(axs.flat, imgs_gen):
        ax.imshow(img[0].numpy(), cmap='gray')
        ax.axis('off')

    exp["generated_imgs"].append(fig, step=current_epoch)
    fig.savefig(f'saved_images/{exp_id}/{current_epoch}.png')
    # Close this specific figure; a bare plt.close() closes whichever figure
    # happens to be current.
    plt.close(fig)
    
def checkpoint(tag):
    """Save both networks to disk and upload the file to the Neptune run.

    Args:
        tag: Name of the checkpoint (an epoch index or a string such as
            'final'); used as the file stem and the Neptune field name.
    """
    ckpt_path = f'checkpoints/{exp_id}/{tag}.pt'
    # NOTE(review): this pickles the complete module objects, so loading the
    # file requires the same class definitions to be importable.
    torch.save([net_g, net_d], ckpt_path)
    exp[f'model_checkpoints/{tag}'].upload(ckpt_path)

In [11]:
# Launch the training loop; progress is shown by the tqdm bar below.
train()

  0%|          | 0/500 [00:00<?, ?it/s]

In [13]:
# Save and upload one last checkpoint, tagged 'final'.
checkpoint('final')

In [None]:
# Stop the Neptune run that was started with neptune.init_run().
exp.stop()