In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import logging
from os import listdir
from os.path import splitext
from pathlib import Path
import h5py
import numpy as np
import random
import scipy.io as io


In [3]:
class ResidualBlock(nn.Module):
    """Identity skip connection around two 3x3 conv + InstanceNorm stages.

    The block computes ``x + F(x)``, where ``F`` is conv3x3 -> InstanceNorm -> ReLU
    -> conv3x3 -> InstanceNorm. The skip addition only works when
    ``dim_in == dim_out``.
    """
    def __init__(self, dim_in, dim_out):
        super().__init__()

        def conv3x3(in_ch, out_ch):
            # 'same' padding keeps the spatial size so the skip addition lines up
            return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False)

        def inorm(ch):
            return nn.InstanceNorm2d(ch, affine=True, track_running_stats=True)

        # Submodule order fixes the state-dict keys (main.0 ... main.4) that saved
        # checkpoints rely on.
        stages = [
            conv3x3(dim_in, dim_out),
            inorm(dim_out),
            nn.ReLU(inplace=True),
            conv3x3(dim_out, dim_out),
            inorm(dim_out),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.main(x)
        return x + residual

class LocalAwareAttention(nn.Module):
    """Residual re-weighting from a blurred (pooled then re-upsampled) copy of the input.

    ``forward`` average-pools ``x`` by 2, restores it to ``x``'s spatial size with a
    learned stride-2 transposed conv, and returns
    ``x + up * relu(up - x) * beta``. Only positions where the restored map
    exceeds the input contribute.

    Args:
        channels: number of input/output channels.
        beta: scale of the added attention term.
    """
    def __init__(self, channels, beta=0.5):
        super(LocalAwareAttention, self).__init__()

        self.beta = beta

        self.avg = nn.AvgPool2d(kernel_size=4, stride=2, padding=1)
        self.up = nn.ConvTranspose2d(channels, channels, kernel_size=4, stride=2, padding=1, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        pooled = self.avg(x)
        # Pinning output_size makes the transposed conv reproduce x's exact
        # spatial size. Without it, odd H/W come back one pixel short and the
        # element-wise ops below fail. For even sizes this is a no-op.
        restored = self.up(pooled, output_size=list(x.shape[-2:]))
        return x + restored * self.relu(restored - x) * self.beta

class PA(nn.Module):
    '''Pixel attention: gates every element of x by sigmoid(conv1x1(x)).'''
    def __init__(self, nf):
        super().__init__()
        # 1x1 conv with bias produces one gate logit per channel and pixel
        self.conv = nn.Conv2d(nf, nf, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.sigmoid(self.conv(x))
        return x * gate

In [4]:
class StarGanGenerator(nn.Module):
    """StarGAN-style generator: conv encoder, residual/attention bottleneck, transposed-conv decoder.

    Args:
        input_nc: channels of the input ``x``.
        output_nc: channels of the generated output.
        conv_dim: width of the first convolution; doubled by each of the two
            down-sampling stages.
        c_dim: length of the condition vector ``c``. Forced to 0 when ``masked`` is false.
        repeat_num: number of ResidualBlock + LocalAwareAttention pairs in the bottleneck.
        masked: if true, ``forward`` requires ``c``. It is tiled over the spatial
            grid and concatenated to ``x`` as ``c_dim`` extra input channels.

    The network has two stride-2 down-sampling stages and three stride-2
    up-sampling stages. A final up-sampling stage is stride 2 along the width
    only. So an input of spatial size (H, W) with H and W divisible by 8 yields
    an output of size (2H, 4W). The final Tanh bounds outputs to (-1, 1).
    """
    def __init__(self, input_nc=1, output_nc=1, conv_dim=64, c_dim=5, repeat_num=6, masked=True):
        super(StarGanGenerator, self).__init__()

        # Normalise once so __init__ (input channel count) and forward (whether c
        # is concatenated) always agree, even for non-bool values such as None.
        self.masked = bool(masked)
        if not self.masked:
            c_dim = 0

        layers = []
        layers.append(nn.Conv2d(input_nc+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
        layers.append(nn.InstanceNorm2d(conv_dim, affine=True, track_running_stats=True))
        layers.append(nn.LeakyReLU(inplace=True))

        # Down-sampling layers: each halves H and W and doubles the channels.
        curr_dim = conv_dim
        for i in range(2):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True, track_running_stats=True))
            layers.append(nn.LeakyReLU(inplace=True))
            curr_dim = curr_dim * 2

        # Bottleneck layers (spatial size unchanged).
        for i in range(repeat_num):
            layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))
            layers.append(LocalAwareAttention(channels=curr_dim))

        # Up-sampling layers: each doubles H and W and halves the channels.
        # Three up-samplings after two down-samplings, so the output is taller and
        # wider than the input.
        for i in range(3):
            layers.append(PA(curr_dim))
            layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True, track_running_stats=True))
            layers.append(nn.LeakyReLU(inplace=True))
            curr_dim = curr_dim // 2

        # Extra up-sampling along the width only (stride (1, 2)).
        layers.append(PA(curr_dim))
        layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=(1,4), stride=(1,2), padding=(0,1), bias=False))
        layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True, track_running_stats=True))
        layers.append(nn.LeakyReLU(inplace=True))
        curr_dim = curr_dim // 2

        layers.append(nn.Conv2d(curr_dim, output_nc, kernel_size=7, stride=1, padding=3, bias=False))
        layers.append(nn.Tanh())
        # Layer order defines the state-dict keys (main.N.*) used by checkpoints.
        self.main = nn.Sequential(*layers)

    def forward(self, x, c=None):
        """Generate an output from ``x``, conditioned on ``c`` when the model is masked.

        Args:
            x: input tensor indexed as (batch, channel, H, W).
            c: condition tensor of shape (batch, c_dim), or (1, c_dim) to share one
                condition across the whole batch. Required when ``masked``;
                ignored otherwise.

        Raises:
            ValueError: if the model is masked and ``c`` is None.
        """
        if not self.masked:
            return self.main(x)
        if c is None:
            raise ValueError('StarGanGenerator(masked=True) requires a condition tensor c')
        c = c.view(c.size(0), c.size(1), 1, 1)
        # A single condition row is broadcast over the batch; otherwise batch sizes
        # must already match (torch.cat below enforces it).
        batch = x.size(0) if c.size(0) == 1 else -1
        c = c.expand(batch, -1, x.size(2), x.size(3))
        return self.main(torch.cat([x, c], dim=1))

In [10]:
class BasicDataset(Dataset):
    """RF / image pairs read from HDF5 files (via h5py) in a directory.

    Each file must contain datasets ``'RF'`` and ``'Img'``. Both are transposed
    on load, presumably because they were written by MATLAB (v7.3 .mat); verify
    against the data producer. ``__getitem__`` returns a dict:

    - ``'A'``: a ``WINDOW``-row crop of the RF data around its strong part,
      noise-floored and scaled to max |value| 1, shape (1, rows, cols).
    - ``'B'``: the image min-max scaled to [-1, 1], first 512 columns, shape
      (1, rows, <=512).
    - ``'name'``: the file stem.
    """

    # Samples with |RF| >= this locate the strong part of the signal for cropping.
    PEAK_THRESHOLD = 0.3
    # Samples with |RF| < this are zeroed in the returned crop.
    NOISE_THRESHOLD = 0.01
    # Number of RF rows kept in the crop.
    WINDOW = 256

    def __init__(self, data_dir: str):
        self.data_dir = Path(data_dir)
        # listdir order is arbitrary; sort so sample order is reproducible.
        self.ids = sorted(splitext(file)[0] for file in listdir(data_dir) if not file.startswith('.'))
        # Not read anywhere in this notebook; kept for compatibility.
        self.type = []
        if not self.ids:
            raise RuntimeError(f'No input file found in {data_dir}, make sure you put your images there')
        logging.info(f'Creating dataset with {len(self.ids)} examples')

    def __len__(self):
        return len(self.ids)

    @staticmethod
    def _safe_divide(values, scale):
        """Return ``values / scale``, or ``values`` unchanged when ``scale`` is 0 (avoids 0/0 NaNs)."""
        if scale == 0:
            return values
        return values / scale

    @classmethod
    def preprocess(cls, pil_img, is_mask):
        """Convert an array to float32 with a leading channel axis and normalise it.

        Args:
            pil_img: array-like input (despite the name, any array works).
            is_mask: False for RF data (scaled by max |value|, so values lie in
                [-1, 1]). True for the target image (min-max scaled to [-1, 1] and
                cropped to the first 512 columns). All-zero / constant inputs are
                left at 0 before the final shift instead of becoming NaN.

        Returns:
            float32 array of shape (1, *input.shape), columns cropped for masks.
        """
        img_ndarray = np.asarray(pil_img).astype(np.float32)
        if not is_mask:
            img_ndarray = img_ndarray[np.newaxis, ...]
            return cls._safe_divide(img_ndarray, np.max(np.abs(img_ndarray)))

        lo, hi = np.min(img_ndarray), np.max(img_ndarray)
        img_ndarray = cls._safe_divide(img_ndarray - lo, hi - lo)
        img_ndarray = img_ndarray * 2.0 - 1.0
        return img_ndarray[np.newaxis, :, 0:512]

    @classmethod
    def load(cls, filename):
        """Read and transpose the ``'RF'`` and ``'Img'`` datasets from an HDF5 file."""
        # The context manager closes the file handle; [()] reads the data into
        # memory before it closes.
        with h5py.File(filename, 'r') as dic:
            rf = np.transpose(dic['RF'][()], [1, 0])
            img = np.transpose(dic['Img'][()], [1, 0])
        return rf, img

    @staticmethod
    def _crop_window(rf, threshold, window):
        """Return (upper, lower) row bounds of a ``window``-row crop centred on the strong part of ``rf``.

        The centre is the midpoint between the first and last row holding a sample
        with |value| >= ``threshold``. The window is shifted (not shrunk) to stay
        inside the array; it is only shorter when ``rf`` has fewer than
        ``window`` rows. This is the fixed, centred window. A training variant
        instead shifted it randomly (by up to +/-64 rows, or uniformly within the
        strong region when that region is wider than the window).

        Raises:
            ValueError: if no sample reaches ``threshold``.
        """
        n_rows = rf.shape[0]
        strong_rows = np.nonzero(np.abs(rf) >= threshold)[0]
        if strong_rows.size == 0:
            raise ValueError(f'No RF sample reaches |value| >= {threshold}; cannot locate the signal to crop')
        center = int((strong_rows.min() + strong_rows.max()) / 2)
        upper = center - window // 2
        lower = upper + window
        if upper < 0:
            upper, lower = 0, min(window, n_rows)
        elif lower > n_rows:
            upper, lower = max(0, n_rows - window), n_rows
        return upper, lower

    def __getitem__(self, idx):
        name = self.ids[idx]
        matches = list(self.data_dir.glob(name + '.*'))
        if not matches:
            raise FileNotFoundError(f'No file matching {name}.* in {self.data_dir}')

        rf, img = self.load(matches[0])

        try:
            upper, lower = self._crop_window(rf, self.PEAK_THRESHOLD, self.WINDOW)
        except ValueError as err:
            raise ValueError(f'{matches[0]}: {err}') from err

        # Zero the noise floor, crop, then normalise the crop globally.
        rf_denoised = np.copy(rf)
        rf_denoised[np.abs(rf_denoised) < self.NOISE_THRESHOLD] = 0
        rf_crop = rf_denoised[upper:lower, :]
        rf_norm = self._safe_divide(rf_crop, np.max(np.abs(rf_crop)))

        rf = self.preprocess(rf_norm, is_mask=False)
        img = self.preprocess(img, is_mask=True)

        return {
            'A': torch.as_tensor(rf.copy()).float().contiguous(),
            'B': torch.as_tensor(img.copy()).float().contiguous(),
            'name': name,
        }


In [11]:
def disable_instance_norm_running_stats(model):
    """Make every InstanceNorm2d in ``model`` normalise with per-sample statistics.

    Turns off ``track_running_stats`` and drops the stored running mean/variance,
    in place, for every ``torch.nn.InstanceNorm2d`` submodule. Other modules are
    untouched. Returns None.
    """
    instance_norms = (mod for mod in model.modules()
                      if isinstance(mod, torch.nn.InstanceNorm2d))
    for norm in instance_norms:
        norm.track_running_stats = False
        norm.running_mean = None
        norm.running_var = None

# Build the generator with the checkpoint's hyper-parameters and load its weights.
# Use the first GPU when present, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = StarGanGenerator(input_nc=1, output_nc=1, c_dim=3, masked=True)
# Path joins work on every OS (the old "model\\..." string only worked on
# Windows). map_location lets a GPU-saved checkpoint load on the chosen device.
net.load_state_dict(torch.load(Path('model') / 'latest_net_G.pth', map_location=device))
net.to(device)
net.eval()
# The running-stat buffers are needed for load_state_dict to find its keys.
# They are discarded afterwards, so inference normalises each sample with its
# own statistics.
disable_instance_norm_running_stats(net)

In [12]:
def get_label(trans_type, trans_types=(1, 2, 3)):
    """Return the one-hot condition vector for a transducer code.

    Codes used in this notebook: 1 = P4-1, 2 = L7-4, 3 = CL15-7.

    Args:
        trans_type: transducer code to encode; must be one of ``trans_types``.
        trans_types: ordered codes, one per label position.

    Returns:
        float32 array of shape (1, len(trans_types)) with a single 1 at the
        position of ``trans_type``.

    Raises:
        ValueError: if ``trans_type`` is not a known code. This replaces returning
            an all-zero label, which the generator would accept silently.
    """
    codes = list(trans_types)
    if trans_type not in codes:
        raise ValueError(f'Unknown trans_type {trans_type!r}; expected one of {codes}')
    label = np.zeros((1, len(codes)), dtype=np.float32)
    label[0, codes.index(trans_type)] = 1
    return label

## Choose the transducer to reconstruct. The code selects the one-hot position
## of the generator's condition vector (see get_label).
TRANSDUCERS = {'P4-1': 1, 'L7-4': 2, 'CL15-7': 3}
TRANSDUCER = 'P4-1'  # one of: 'P4-1', 'L7-4', 'CL15-7'

# Path joins instead of backslash literals so this also runs on Linux/macOS.
dir_data = Path('data') / TRANSDUCER
trans_type = TRANSDUCERS[TRANSDUCER]

img_path1 = 'imgs/DNN/'
img_path2 = 'imgs/EISRCB/'
# scipy.io.savemat does not create missing directories.
for out_dir in (img_path1, img_path2):
    Path(out_dir).mkdir(parents=True, exist_ok=True)

dataset = BasicDataset(dir_data)
loader_args = dict(batch_size=1, num_workers=0, pin_memory=False)
train_loader = DataLoader(dataset, shuffle=False, **loader_args)

# Run on whatever device the network was loaded onto. The label is the same for
# every sample, so build it once.
device = next(net.parameters()).device
label = torch.as_tensor(get_label(trans_type)).to(device)

i = 0
for i, data in enumerate(train_loader, start=1):
    input_data = data['A'].to(device)
    if trans_type == 1:
        # Zero-pads the last axis (the RF crop's second axis) by 16 on each side.
        # NOTE(review): presumably this matches the width the generator expects
        # for P4-1 data; confirm.
        input_data = F.pad(input_data, pad=(16, 16), mode='constant', value=0)

    with torch.no_grad():
        output = net(input_data, label)

    output = np.squeeze(output.cpu().numpy())
    name = data['name'][0]

    io.savemat(str(Path(img_path1) / ('DNN_Img_' + name + '.mat')), {'DNN_Img': output})
    io.savemat(str(Path(img_path2) / ('EISRCB_Img_' + name + '.mat')),
               {'EISRCB_Img': np.squeeze(data['B'].cpu().numpy()),
                'label': np.squeeze(label.cpu().numpy())})

    print(i)

1
