# Dual Domain Synthesis

- [github](https://github.com/denabazazian/Dual-Domain-Synthesis)

In [None]:
import os
# Switch the working directory to the project checkout (machine-specific path).
# Presumably needed so project-local imports such as `dataset.images_dataset`
# resolve from the notebook -- adjust for your environment.
os.chdir("/home/extra/micheal/dd_synthesis")

## Step 1. Optimize a Dual-domain Latent

Here I optimize a single latent code so that, when fed to the respective generators, it produces images resembling both OCT and iOCT.

### 1. Load B-scan & label pairs

In [None]:
from argparse import Namespace
import torch
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from dataset.images_dataset import ImagesDataset

In [None]:
# Layer boundaries are drawn in the label images as evenly spaced grey values:
# the k-th boundary (k = 1, 2, ...) has intensity k * gap.
num_gap = 12 + 1  # number of grey-level slots (self-defined)
gap = 255 // num_gap  # == 19

# Boundary grey values, in order k = 1..10. Notes carried over from the author:
#   ILM     - present in 1, 2, 3, 4
#   RNFL_o  - NFL/FCL in DME, present in 2
#   OPL_o   - OPL/ONL in DME
#   RPE     - unverified; an alternative scheme considered was
#             RPEDC = 10 * gap and RPE = 11 * gap
(ILM, RNFL_o, IPL_INL, INL_OPL, OPL_o,
 ISM_ISE, IS_OS, OS_RPE, RPE, BM) = (k * gap for k in range(1, 11))

AROI_LABELS = [ILM, IPL_INL, RPE, BM]  # [19, 57, 171, 190]
FLUID_LABELS = [80, 160, 240]
OP_LABELS = [ILM, RPE]
aroi_map = list(range(1, 5))  # AROI_LABELS -> class ids 1..4
op_map = [1, 3]  # OP_LABELS -> class ids 1 and 3
fluid_map = [0] * len(FLUID_LABELS)  # every fluid value -> class id 0

In [None]:
# from /home/extra/micheal/pixel2style2pixel/configs/transforms_config.py
class Convert2Uint8(torch.nn.Module):
    '''
    Rescale a tensor with values in [0, 1] onto the integer grid 0..255.

    NOTE: despite the name, the dtype is not changed to uint8. The values are
    multiplied by 255 and rounded, so a float input stays float.
    '''
    def __init__(self):
        super().__init__()

    def forward(self, img):
        """
        Args:
            img (Tensor): Image tensor, e.g. the output of ``ToTensor`` with
                values in [0, 1]. A PIL Image is not accepted, because
                ``torch.mul`` requires a tensor.

        Returns:
            Tensor: ``round(img * 255)`` with the same shape as ``img``
            (``torch.round`` rounds halves to even).
        """
        return torch.round(torch.mul(img, 255))

class ToOneHot(torch.nn.Module):
    '''
    Convert an integer-valued label map to a channel-first one-hot tensor.
    '''
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, img):
        """
        Args:
            img (Tensor): Label map of shape (1, h, w) (as produced by
                ``ToTensor``) or (h, w). Its values are class indices and are
                truncated to integers with ``.long()``. For a 3-D input only
                the first channel is used. torch's ``one_hot`` rejects values
                outside [0, num_classes).

        Returns:
            Tensor: int64 one-hot tensor of shape (num_classes, h, w).
        """
        labels = img.long()
        # Accept a bare (h, w) map too. Indexing [0] on a 2-D input would
        # keep only its first row and then fail in permute below.
        if labels.dim() != 2:
            labels = labels[0]
        encoded = torch.nn.functional.one_hot(labels, num_classes=self.num_classes)
        # one_hot appends the class axis last: (h, w, C) -> (C, h, w)
        return encoded.permute(2, 0, 1)
    
class MapVal(torch.nn.Module):
    '''
    Map a list of values to another, element-wise and simultaneously.

    Every element equal to ``src_vals[i]`` becomes ``dst_vals[i]``. Elements
    that match no source value are left unchanged. If a source value is
    listed more than once, its first pair wins.
    '''
    def __init__(self, src_vals, dst_vals):
        super().__init__()
        assert len(src_vals) == len(dst_vals), "src_vals and dst_vals must of equal length"
        self.src_vals = src_vals
        self.dst_vals = dst_vals

    def forward(self, img):
        """
        Args:
            img (Tensor): Tensor to remap. It is modified in place.

        Returns:
            Tensor: ``img`` itself, after remapping.
        """
        # Match every source value against the ORIGINAL values before writing
        # anything. Otherwise a destination value that equals a later source
        # value would be remapped again (src=[1, 2], dst=[2, 3] would send 1->3).
        masks = [img == s for s in self.src_vals]
        # Write in reverse order so the first pair wins for a duplicated
        # source value, matching the precedence of a sequential remap.
        for mask, d in zip(reversed(masks), reversed(self.dst_vals)):
            img[mask] = d
        return img

def get_transforms(opts):
    """Build the torchvision pipelines for the real images and the label maps.

    Args:
        opts: Namespace-like object providing ``output_nc`` (number of image
            channels), ``label_nc`` (number of one-hot classes), and
            ``src_vals`` / ``dst_vals`` (grey values to remap in the labels).

    Returns:
        dict with two ``transforms.Compose`` pipelines:

        * ``'transform_gt_train'``: resize to 256x256 with the default
          interpolation, ``ToTensor``, then normalise every channel with
          mean 0.5 and std 0.5.
        * ``'transform_source'``: resize to 256x256 with nearest-neighbour
          interpolation (so label values are never blended), ``ToTensor``,
          rescale to 0..255, remap grey values to class ids, and one-hot
          encode into ``label_nc`` channels.
    """
    mean = [0.5] * opts.output_nc
    std = [0.5] * opts.output_nc
    image_pipeline = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    label_pipeline = transforms.Compose([
        transforms.Resize((256, 256), interpolation=InterpolationMode.NEAREST),
        transforms.ToTensor(),
        Convert2Uint8(),
        MapVal(opts.src_vals, opts.dst_vals),
        ToOneHot(opts.label_nc),
    ])

    return {
        'transform_gt_train': image_pipeline,
        'transform_source': label_pipeline,
    }

In [None]:
# Machine-specific locations of the AROI training split.
aroi_bscan_root = "/home/extra/micheal/IDP/data/splits/AROI/original/bscans/train"
aroi_label_root = "/home/extra/micheal/IDP/data/splits/AROI/original/labels/train"

# Label grey values (AROI layers + fluid) are remapped to class ids and then
# one-hot encoded into `label_nc` channels.
# NOTE(review): the remap targets are only ids 0-4 (aroi_map + fluid_map), so
# label_nc=7 leaves channels 5-6 unused unless other raw grey values pass
# through unmapped -- confirm this is intended.
aroi_opts = Namespace(
    output_nc=1,
    label_nc=7,
    src_vals=AROI_LABELS + FLUID_LABELS,
    dst_vals=aroi_map + fluid_map,
)
transform_dict = get_transforms(aroi_opts)

# Source side = label maps, target side = B-scans.
aroi_dataset = ImagesDataset(
    source_root=aroi_label_root,
    target_root=aroi_bscan_root,
    source_transform=transform_dict['transform_source'],
    target_transform=transform_dict['transform_gt_train'],
    opts=aroi_opts,
)

In [None]:
# Bare expression: in a notebook this only echoes the label directory path;
# it has no other effect.
"/home/extra/micheal/IDP/data/splits/AROI/original/labels/train"

In [None]:
import numpy as np

# Inspect the shape of the first element of sample 0.
# NOTE(review): assumes ImagesDataset returns (source, target). If so, this is
# the one-hot label from `transform_source`, expected (label_nc, 256, 256) =
# (7, 256, 256) -- confirm against ImagesDataset.__getitem__.
aroi_dataset[0][0].shape