# Dual Domain Synthesis

- [github](https://github.com/denabazazian/Dual-Domain-Synthesis)

In [None]:
import os
# NOTE(review): hard-coded, machine-specific working directory. The relative
# paths configured further down (checkpoints/, submodules/, artifacts/) are
# resolved against it, so this must point at the project checkout.
os.chdir("/home/extra/micheal/dd_synthesis")

In [None]:
import argparse
from argparse import Namespace

from tqdm import tqdm
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.transforms import InterpolationMode

from utils_repurpose import tensor2image
from perceptual_model import VGG16_for_Perceptual
from stylegan2 import Generator

## Step 1. Optimize a Dual-domain Latent

Here I'll optimize a latent so that it generates images similar to OCT and iOCT with respective generators.

### 1. Configurations

In [None]:
# Checkpoint whose "g_ema" weights are loaded into `generator`.
generator_path = "checkpoints/aroi_to_op/036000.pt"
# Checkpoint whose "g_ema" weights are loaded into `targ_generator`.
target_model_path = "submodules/stylegan2-pytorch/checkpoint/030000.pt" # OCT (AROI)
# Prefix used when building the file name of the final rendered image.
save_path_root = "artifacts/op2ioct/results"
# Source (iOCT) b-scan and its label map.
source_image_path = "/home/extra/micheal/pixel2style2pixel/data/ioct/bscans/train/OS-2020-02-03_135647fs-073.png"
source_mask_path = "/home/extra/micheal/pixel2style2pixel/data/ioct/labels/train/OS-2020-02-03_135647fs-073.png"
# Target (OCT) b-scan and its label map.
target_image_path = "/home/extra/micheal/IDP/data/splits/AROI/original/bscans/train/patient15_raw0025.png"
target_mask_path = "/home/extra/micheal/IDP/data/splits/AROI/original/labels/train/patient15_raw0025.png"
# Generator output resolution; also used as the resolution of the loss masks.
image_size = 256
# NOTE(review): n_samples, imshow_size, truncation, n_test, id_dir and
# sample_z_path are not referenced in the visible part of this notebook.
n_samples = 1
imshow_size = 256
# Style dimension passed to both Generator constructors.
latent_dim = 512
truncation = 0.7
n_test = 1
id_dir = "mix"
sample_z_path = "artifacts/op2ioct/latents"
# When True, an image is rendered and kept every time a latent snapshot is logged.
save_iterations = True
# Iteration index (minus one) at which the final image is written to disk.
mask_guided_iterations = 1002
# Adam learning rate, re-applied unchanged on every step.
lr = 0.01
# Number of random z samples averaged to obtain the starting latent.
n_mean_latent = 10000

### 2. Helper Functions

In [None]:
@torch.no_grad()
def concat_features(features):
    """Bring every feature map to the largest spatial size and stack them on dim 1.

    Args:
        features: sequence of tensors whose last two dims are (height, width).

    Returns:
        Tensor: all maps resized with nearest-neighbour interpolation to the
        maximum height and maximum width found in ``features`` and
        concatenated along the channel dimension.
    """
    target_h = max(fmap.shape[-2] for fmap in features)
    target_w = max(fmap.shape[-1] for fmap in features)

    resized = []
    for fmap in features:
        resized.append(
            torch.nn.functional.interpolate(fmap, (target_h, target_w), mode='nearest'))
    return torch.cat(resized, dim=1)

class Convert2Uint8(torch.nn.Module):
    '''
    Scale a tensor by 255 and round to the nearest integer value.

    Despite the name, the dtype is not changed: a float tensor stays float.
    In ``get_transforms`` this runs right after ``ToTensor`` on label images.
    '''

    def forward(self, img):
        """
        Args:
            img (Tensor): Values to be scaled.

        Returns:
            Tensor: ``img * 255`` rounded element-wise, same dtype as the product.
        """
        scaled = torch.mul(img, 255)
        return scaled.round()


class ToOneHot(torch.nn.Module):
    '''
    Convert a single-channel label map to a channel-first one-hot encoding.
    '''

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, img):
        """
        Args:
            img (Tensor): Label map of shape (1, h, w) holding class indices
                in ``[0, num_classes)``; it is cast to int64 before encoding.

        Returns:
            Tensor: int64 one-hot encoding of shape (num_classes, h, w).
        """
        labels = img.long()[0]  # (h, w)
        # The encoding step was commented out, which left a 2-D tensor for the
        # 3-axis permute below and made every call fail.
        encoded = torch.nn.functional.one_hot(
            labels, num_classes=self.num_classes)  # (h, w, num_classes)
        return encoded.permute(2, 0, 1)


class MapVal(torch.nn.Module):
    '''
    Replace each value in ``src_vals`` by the value at the same position in
    ``dst_vals``.
    '''

    def __init__(self, src_vals, dst_vals):
        super().__init__()
        assert len(src_vals) == len(
            dst_vals), "src_vals and dst_vals must of equal length"
        self.src_vals = src_vals
        self.dst_vals = dst_vals

    def forward(self, img):
        """
        Args:
            img (Tensor or ndarray): Values to remap. Modified in place.

        Returns:
            The same ``img`` object with every ``src`` value replaced by its
            ``dst`` counterpart.
        """
        # Compare against a snapshot of the input so that a value written by
        # an earlier pair is never picked up and remapped again by a later
        # pair (e.g. 1->2 followed by 2->3 must not turn 1 into 3).
        if isinstance(img, torch.Tensor):
            snapshot = img.clone()
        else:
            snapshot = img.copy()
        for s, d in zip(self.src_vals, self.dst_vals):
            img[snapshot == s] = d
        return img


def get_transforms(opts):
    """Build the preprocessing pipelines for b-scans and for label maps.

    Args:
        opts: object providing ``output_nc`` (number of image channels used
            for normalisation) and ``src_vals`` / ``dst_vals`` (label values
            to remap, paired by position).

    Returns:
        dict with two ``transforms.Compose`` entries:
        ``'transform_gt_train'`` resizes to 256x256, converts to a tensor and
        normalises each channel with mean 0.5 / std 0.5;
        ``'transform_source'`` resizes to 256x256 with nearest-neighbour
        interpolation, converts to a tensor, scales back by 255 and remaps
        the label values.
    """
    target_size = (256, 256)

    bscan_pipeline = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * opts.output_nc, [0.5] * opts.output_nc),
    ])

    # Nearest-neighbour resizing keeps label values discrete.
    label_pipeline = transforms.Compose([
        transforms.Resize(target_size, interpolation=InterpolationMode.NEAREST),
        transforms.ToTensor(),
        Convert2Uint8(),
        MapVal(opts.src_vals, opts.dst_vals),
        # ToOneHot(opts.label_nc)
    ])

    return {
        'transform_gt_train': bscan_pipeline,
        'transform_source': label_pipeline,
    }


def load_bscan(bscan_path):
    """Read a b-scan from disk as RGB and pass it through ``bscan_transform``.

    No batch dimension is added here; the result is whatever the module-level
    ``bscan_transform`` returns.
    """
    with Image.open(bscan_path) as raw:
        rgb = raw.convert('RGB')
    return bscan_transform(rgb)


def load_label(label_path):
    """Read a label map from disk as 8-bit grayscale and pass it through ``label_transform``.

    No batch dimension is added here; the result is whatever the module-level
    ``label_transform`` returns.
    """
    with Image.open(label_path) as raw:
        gray = raw.convert('L')
    return label_transform(gray)



def caluclate_loss(synth_img, img, perceptual_net, mask, MSE_Loss, image_resolution):
    """Masked perceptual loss between a synthesised and a reference image.

    Both images and the mask are first rescaled by ``256 / image_resolution``
    with bilinear interpolation. ``perceptual_net`` must return four feature
    maps that broadcast against (1, 64, 256, 256), (1, 64, 256, 256),
    (1, 256, 64, 64) and (1, 512, 32, 32). At every level both feature maps
    are multiplied by the mask (resampled to that level's resolution) before
    ``MSE_Loss`` is applied; the four terms are summed.

    Args:
        synth_img: generated image batch.
        img: reference image batch.
        perceptual_net: callable returning four feature maps.
        mask: weighting mask, expandable to the shapes listed above.
        MSE_Loss: callable taking (prediction, target).
        image_resolution: spatial size of the inputs before rescaling.

    Returns:
        Sum of the four masked feature losses.
    """
    # A single stateless resampler serves the two images and the mask.
    to_256 = torch.nn.Upsample(
        scale_factor=(256/image_resolution), mode='bilinear')

    real_0, real_1, real_2, real_3 = perceptual_net(to_256(img))
    synth_0, synth_1, synth_2, synth_3 = perceptual_net(to_256(synth_img))

    # (synth features, real features, extra mask rescale, mask expand shape)
    levels = (
        (synth_0, real_0, None, (1, 64, 256, 256)),
        (synth_1, real_1, None, (1, 64, 256, 256)),
        (synth_2, real_2, (64/256), (1, 256, 64, 64)),
        (synth_3, real_3, (32/64), (1, 512, 32, 32)),
    )

    level_mask = to_256(mask)
    perceptual_loss = 0
    for synth_feat, real_feat, rescale, shape in levels:
        if rescale is not None:
            # The mask is shrunk progressively, each time from the previous level.
            level_mask = torch.nn.Upsample(
                scale_factor=rescale, mode='bilinear')(level_mask)
        weights = level_mask.expand(*shape)
        perceptual_loss += MSE_Loss(synth_feat*weights, real_feat*weights)

    return perceptual_loss


def noise_normalize_(noises):
    """Standardise every tensor in ``noises`` in place.

    Each tensor has its own mean subtracted and is then divided by its own
    standard deviation. The update goes through ``.data`` and nothing is
    returned.
    """
    for tensor in noises:
        mu, sigma = tensor.mean(), tensor.std()
        tensor.data.sub_(mu).div_(sigma)


def horizontal_expand(label, feature, to_expand=20):
    """Smear every pixel equal to ``feature`` rightwards along the last axis.

    Each matching pixel at column ``c`` also marks columns ``c`` to
    ``c + to_expand - 1`` of the same row with ``feature``. Columns that would
    fall beyond the right edge are dropped instead of raising an IndexError.

    Args:
        label: 2-D label map (``torch.Tensor`` or ndarray). Not modified.
        feature: value to expand.
        to_expand: number of column offsets (starting at 0) applied to each
            matching pixel; values <= 0 leave the map unchanged.

    Returns:
        A copy of ``label`` (same type) with the expansion applied.
    """
    if isinstance(label, torch.Tensor):
        label_copy = label.clone()
    else:
        label_copy = label.copy()

    rows, cols = np.where(label_copy == feature)

    # All (row, col + offset) pairs at once instead of growing index arrays
    # inside a Python loop.
    offsets = np.arange(max(to_expand, 0))
    grown_rows = np.repeat(rows, len(offsets))
    grown_cols = (cols[:, None] + offsets[None, :]).ravel()

    # Clip at the right border: features close to the edge used to index past
    # the last column.
    inside = grown_cols < label_copy.shape[-1]
    label_copy[grown_rows[inside], grown_cols[inside]] = feature
    return label_copy


def expand_label(label, instrument_label=2, shadow_label=4, expansion_instrument=30,
                 expansion_shadow=60):
    """Widen two label classes rightwards; the input is expected to be [h, w].

    ``instrument_label`` is expanded first with ``expansion_instrument``,
    then ``shadow_label`` with ``expansion_shadow`` on the result, so the
    second expansion wins where the two overlap.

    Returns:
        The expanded label map produced by ``horizontal_expand``.
    """
    # Order matters: instrument first, then the class that is generally broader.
    widen_plan = (
        (instrument_label, expansion_instrument),
        (shadow_label, expansion_shadow),
    )
    for class_id, n_columns in widen_plan:
        label = horizontal_expand(label, class_id, to_expand=n_columns)
    return label


def get_shadow(label, instrument_label=2, shadow_label=4, top_layer_label=1, img_width=256, img_height=256):
    """Return the pixel coordinates lying below the instrument / shadow classes.

    For every column between the leftmost and the rightmost column (both
    inclusive) that contains ``instrument_label`` or ``shadow_label``, all
    rows from the lowest such pixel down to ``img_height - 1`` are returned.
    A column inside that span with no such pixel reuses the deepest lower
    bound seen so far in the columns to its left.

    Args:
        label: 2-D label map indexed as [row, column].
        instrument_label: first class value that casts a shadow.
        shadow_label: second class value that casts a shadow.
        top_layer_label: unused; kept for interface compatibility.
        img_width: unused; kept for interface compatibility.
        img_height: exclusive lower row limit of the shadow.

    Returns:
        (shadow_x, shadow_y): int64 arrays of row and column indices; both
        empty when neither class is present.
    """
    shadow_x = np.array([], dtype=np.int64)
    shadow_y = np.array([], dtype=np.int64)

    # Boolean map of the pixels that cast a shadow, computed once.
    occupied = np.asarray(
        np.logical_or(label == instrument_label, label == shadow_label))
    _, cols = np.where(occupied)
    if len(cols) == 0:
        return shadow_x, shadow_y

    left_bound = np.min(cols)
    right_bound = np.max(cols)

    # None means "no lower bound seen yet"; the previous 0 sentinel could not
    # be told apart from a real lower bound at row 0.
    deepest_lowerbound = None
    x_parts = [shadow_x]
    y_parts = [shadow_y]
    # right_bound itself contains instrument pixels, so it is included; the
    # exclusive range used to drop the last column (and produced no shadow at
    # all for a one-column instrument).
    for col in range(left_bound, right_bound + 1):
        rows_here = np.where(occupied[:, col])[0]
        if len(rows_here) == 0:
            if deepest_lowerbound is None:
                continue
            # Gap column: fall back to the deepest bound recorded so far.
            instrument_lowerbound = deepest_lowerbound
        else:
            instrument_lowerbound = np.max(rows_here)
            if deepest_lowerbound is None:
                deepest_lowerbound = instrument_lowerbound
            else:
                deepest_lowerbound = max(deepest_lowerbound, instrument_lowerbound)
        x_vertical = np.arange(instrument_lowerbound, img_height, dtype=np.int64)
        x_parts.append(x_vertical)
        y_parts.append(np.full_like(x_vertical, col))

    # Single concatenation instead of re-allocating on every column.
    return np.concatenate(x_parts), np.concatenate(y_parts)


### 2. Load Generators

In [None]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Generator that is optimised against below (its checkpoint is `generator_path`).
generator = Generator(image_size, latent_dim, 8)
generator_ckpt = torch.load(generator_path, map_location='cpu')
generator.load_state_dict(generator_ckpt["g_ema"], strict=False)
generator = generator.eval().to(device)

# Second generator, loaded from `target_model_path`.
targ_generator = Generator(image_size, latent_dim, 8)
# map_location keeps this loadable on CPU-only machines, as for the first
# checkpoint.
targ_generator_ckpt = torch.load(target_model_path, map_location='cpu')

# Load the weights into the bare module BEFORE wrapping it in DataParallel.
# A DataParallel wrapper expects every key to start with "module."; with
# strict=False an unprefixed checkpoint would be silently ignored and the
# generator would keep its random initialisation. The prefix is stripped in
# case the checkpoint itself was saved from a wrapped model.
targ_state = {
    (key[len("module."):] if key.startswith("module.") else key): value
    for key, value in targ_generator_ckpt["g_ema"].items()
}
targ_generator.load_state_dict(targ_state, strict=False)

targ_generator = nn.parallel.DataParallel(targ_generator.to(device))
targ_generator = targ_generator.eval().to(device)

### 3. Load bscan & label pair

Load source & target bscan and label

In [None]:
# Raw label values and what they are remapped to. `src_vals` and `dst_vals`
# are paired by position inside MapVal, so the lists below translate to:
#   2, 3, 4           -> 5, 2, 6
#   19, 57, 171, 190  -> 1, 2, 3, 4   (AROI_LABELS)
#   100, 200          -> 5, 6         (INSTRUMENT_LABELS -> instrument_map)
#   80, 160, 240      -> 0, 0, 0      (FLUID_LABELS)
AROI_LABELS = [19, 57, 171, 190]
INSTRUMENT_LABELS = [100, 200]
instrument_map = [5, 6]
FLUID_LABELS = [80, 160, 240]
# NOTE(review): label_nc is only consumed by the ToOneHot step, which is
# commented out in get_transforms, so it currently has no effect.
transform_dict = get_transforms(Namespace(
    output_nc=3, label_nc=2, src_vals=[2, 3, 4]+AROI_LABELS+INSTRUMENT_LABELS+FLUID_LABELS,
    dst_vals=[5, 2, 6]+[1, 2, 3, 4]+instrument_map+[0, 0, 0]))
# Module-level pipelines read by load_bscan / load_label.
bscan_transform = transform_dict['transform_gt_train']
label_transform = transform_dict['transform_source']

In [None]:
# Load the source/target b-scans and their label maps through the pipelines
# defined above (no batch dimension yet).
imgs_gen = load_bscan(source_image_path)  # source is iOCT, [3, h, w]
masks = load_label(source_mask_path)  # assume to be [1, h, w]
targ_imgs_gen = load_bscan(target_image_path)  # target is OCT
targ_masks = load_label(target_mask_path)

# Sanity check: tensor shapes and the set of label values after remapping.
print(f"image shape {list(imgs_gen.shape)}, mask shape {list(masks.shape)}")
print(np.unique(masks))

In [None]:
# Add the batch dimension and move both images to the compute device.
img_source = imgs_gen.unsqueeze(0).to(
    device)  # (1,3,image_size,image_size) (1,3,256,256)
img_target = targ_imgs_gen.unsqueeze(0).to(
    device)  # (1,3,image_size,image_size) (1,3,256,256)

In [None]:
# Visualize source and target image
# (the files are re-read from disk, so this shows them before any resizing
# or normalisation)
fig, axes = plt.subplots(1, 2, figsize=(5, 10))
axes[0].imshow(Image.open(source_image_path))
axes[1].imshow(Image.open(target_image_path))

In [None]:
# Visualize source and target mask
# (first channel of the transformed label tensors, i.e. after value remapping)
fig, axes = plt.subplots(1, 2, figsize=(5, 10))
axes[0].imshow(masks[0])
axes[1].imshow(targ_masks[0])

### 4. Create cross-over masks

Expand the instrument and shadow labels in the source mask.

In [None]:
# Widen classes 5 and 6 (the values INSTRUMENT_LABELS were remapped to) to the
# right on the first channel of the source label map.
mask = expand_label(masks[0, :, :], instrument_label=5, shadow_label=6,
                    expansion_instrument=15, expansion_shadow=15)  # (256, 256)

In [None]:
# Select classes of interest (instrument, its mirroring and the shadow below)
classes_of_interest = [5, 6]
mask_copy = np.zeros_like(mask)
for c in classes_of_interest:
    mask_copy[mask == c] = 1

# Get the shadow and set to intrested
shadow_x, shadow_y = get_shadow(
    mask, instrument_label=5, shadow_label=6, top_layer_label=1)
mask_copy[shadow_x, shadow_y] = 1

# Only instrument, mirroring and shadow are 1 now
mask = torch.as_tensor(mask_copy)

mask_0 = mask.unsqueeze(0)
mask_1 = mask_0.clone()
mask_1 = 1 - (mask_1)  # (1,h,w)

# Here I set the target mask to the same as original mask
targ_mask_0 = mask_0.clone()
targ_mask_1 = mask_1.clone()

mask_0 = mask_0.to(device)
mask_1 = mask_1.to(device)
targ_mask_0 = targ_mask_0.to(device)
targ_mask_1 = targ_mask_1.to(device)

cross_over_source = (img_source * mask_1) + (img_target * targ_mask_0)
cross_over_target = (img_source * mask_0) + (img_target * targ_mask_1)

# for visualization
cross_over_source_image = tensor2image(cross_over_source.to('cpu'))  # [256, 256, n_channels(1)]
cross_over_target_image = tensor2image(cross_over_target.to('cpu'))

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(5, 10))
axes[0].imshow(cross_over_source_image)
axes[1].imshow(cross_over_target_image)

### 5. Create an input latent and noise

In [None]:
g_ema = generator

# Estimate the mean latent: push `n_mean_latent` random z vectors through the
# generator's style network and average the results.
with torch.no_grad():
    # Use latent_dim (the value the generator was built with) rather than a
    # hard-coded 512, so changing the configuration cannot desynchronise them.
    noise_sample = torch.randn(n_mean_latent, latent_dim, device=device)
    latent_out = g_ema.style(noise_sample)
    latent_mean = latent_out.mean(0)

# One freshly sampled noise tensor per entry returned by make_noise(),
# repeated along the batch dimension to match img_source.
noises_single = g_ema.make_noise()
noises = []
for noise in noises_single:
    noises.append(noise.repeat(img_source.shape[0], 1, 1, 1).normal_())

# Create a new latent as optimization starting point: the mean latent,
# repeated per batch element and once per generator layer (g_ema.n_latent).
latent_in = latent_mean.detach().clone().unsqueeze(
    0).repeat(img_source.shape[0], 1)
latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)

# Both the latent and the noises track gradients, but only latent_in is handed
# to the optimizer in the next section, so the noises are never stepped.
latent_in.requires_grad = True
for noise in noises:
    noise.requires_grad = True

In [None]:
# Sanity check: shape of the starting latent and of the first noise tensor.
print(latent_in.shape)
print(noises[0].shape)

### 6. Initialize models

Prepare models, optimizers, etc.

In [None]:
# Feature extractor used by caluclate_loss; it is unpacked there into four
# feature maps, one per entry of n_layers.
perceptual_net = VGG16_for_Perceptual(n_layers=[2, 4, 14, 21]).to(
    device)  # conv1_1,conv1_2,conv2_2,conv3_3
MSE_Loss = nn.MSELoss(reduction="mean")
# Only the latent is optimised.
optimizer = optim.Adam([latent_in], lr=lr)
# TODO: manage size elsewhere
# NOTE(review): these add a leading batch dimension in place of the old
# tensors; re-running this cell adds yet another dimension each time.
mask_1 = mask_1.unsqueeze(0)
mask_0 = mask_0.unsqueeze(0)

# NOTE(review): loss_list is never appended to in the visible code;
# latent_path collects latent snapshots during optimisation.
loss_list = []
latent_path = []

In [None]:
# Echo the configured iteration count as the cell output.
mask_guided_iterations

### 7. Optimize the latent

In [None]:
# Previews rendered at each logging step (HxWx3 arrays with values in [0, 1]).
intermediate_images = []

# tqdm needs an iterable: passing the bare integer raised
# "TypeError: 'int' object is not iterable" before the first step.
for i in tqdm(range(10000)):
    # Constant learning rate, re-applied on every step (no schedule).
    optimizer.param_groups[0]["lr"] = lr

    synth_img, _ = g_ema([latent_in], input_is_latent=True, noise=noises)

    batch, channel, height, width = synth_img.shape

    if height > image_size:
        # Average-pool the output down to image_size by an integer factor.
        factor = height // image_size

        synth_img = synth_img.reshape(
            batch, channel, height // factor, factor, width // factor, factor
        )
        synth_img = synth_img.mean([3, 5])

    # Perceptual terms: match the source image where mask_1 is set (outside
    # the region of interest) and the target image where mask_0 is set.
    loss_wl1 = caluclate_loss(synth_img, img_source,
                              perceptual_net, mask_1, MSE_Loss, image_size)
    loss_wl0 = caluclate_loss(synth_img, img_target,
                              perceptual_net, mask_0, MSE_Loss, image_size)
    # Pixel-space terms over the same two regions.
    mse_w0 = F.mse_loss(synth_img*mask_1.expand(1, 3, image_size, image_size),
                        img_source*mask_1.expand(1, 3, image_size, image_size))
    mse_w1 = F.mse_loss(synth_img*mask_0.expand(1, 3, image_size, image_size),
                        img_target*mask_0.expand(1, 3, image_size, image_size))
    # Pixel-space term against the pre-composited cross-over image.
    mse_crossover = 3*(F.mse_loss(synth_img.float(),
                       cross_over_source.float()))
    p_loss = 2*((loss_wl0)+loss_wl1)
    mse_loss = (mse_w0)+mse_w1
    loss = (p_loss)+(mse_loss)+(mse_crossover)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    noise_normalize_(noises)

    lr_schedule = optimizer.param_groups[0]['lr']

    if (i + 1) % 1000 == 0:
        latent_path.append(latent_in.detach().clone())

        loss_np = loss.detach().cpu().numpy()
        loss_0 = loss_wl0.detach().cpu().numpy()
        loss_1 = loss_wl1.detach().cpu().numpy()
        mse_0 = mse_w0.detach().cpu().numpy()
        mse_1 = mse_w1.detach().cpu().numpy()
        mse_loss = mse_loss.detach().cpu().numpy()

        print("iter{}: loss -- {:.5f},  loss0 --{:.5f},  loss1 --{:.5f}, mse0--{:.5f}, mse1--{:.5f}, mseTot--{:.5f}, lr--{:.5f}".format(i,
              loss_np, loss_0, loss_1, mse_0, mse_1, mse_loss, lr_schedule))

        if save_iterations:
            # Preview only: no autograd graph is needed for this render.
            with torch.no_grad():
                img_gen, _ = g_ema([latent_path[-1]],
                                   input_is_latent=True, noise=noises)
            img_tens = (
                img_gen.clamp_(-1., 1.).detach().squeeze().permute(1, 2, 0).cpu().numpy())*0.5 + 0.5
            # pil_img = Image.fromarray((img_tens*255).astype(np.uint8))
            # pil_img.save(img_name)
            intermediate_images.append(img_tens)

    if i == (mask_guided_iterations-1):
        # save_path_root has no trailing separator, so plain concatenation
        # produced ".../results001001_D1.png" next to the intended directory.
        os.makedirs(save_path_root, exist_ok=True)
        img_name = os.path.join(
            save_path_root, "{}_D1.png".format(str(i).zfill(6)))
        # Use the latest snapshot when there is one; fall back to the current
        # latent so that no snapshot yet does not raise an IndexError.
        final_latent = latent_path[-1] if latent_path else latent_in.detach()
        with torch.no_grad():
            img_gen, _ = g_ema([final_latent],
                               input_is_latent=True, noise=noises)
        img_tens = (img_gen.clamp_(-1., 1.).detach().squeeze().permute(1,
                    2, 0).cpu().numpy())*0.5 + 0.5
        pil_img = Image.fromarray((img_tens*255).astype(np.uint8))
        pil_img.save(img_name)

In [None]:
import math

# Lay the intermediate previews out on a grid with a fixed number of columns.
n_col = 4
# Keep at least one row so an empty preview list does not request a
# zero-sized figure.
n_row = max(1, math.ceil(len(intermediate_images) / n_col))
# figsize is (width, height): width follows the columns, height the rows
# (the two were swapped before, distorting the grid).
plt.figure(figsize=(n_col * 3, n_row * 3))
for i, img in enumerate(intermediate_images):
    ax = plt.subplot(n_row, n_col, i+1)
    ax.imshow(img)