In [1]:
import torch.nn as nn
import torch
import scipy.io as sio
import numpy as np

In [2]:
class ICLabelNetImg(nn.Module):
    """Convolutional branch applied to the ``images`` input of ``ICLabelNet``.

    Three ``Conv2d`` layers (4x4 kernel, stride 2, padding 1), each followed by
    ``LeakyReLU(0.2)``, widen a single input channel to 128, 256 and finally
    512 channels. With this kernel/stride/padding every convolution halves an
    even spatial size, so e.g. a 32x32 input comes out as 4x4.

    The layers are registered twice on purpose: once under their own names
    (``conv1`` ... ``relu3``) and once inside ``sequential``. The state dict
    therefore contains both ``convN.*`` and ``sequential.*`` keys.
    """

    def __init__(self):
        super().__init__()

        widths = (1, 128, 256, 512)
        # Registration order (conv1, relu1, conv2, ...) fixes the state-dict key order.
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"conv{stage}",
                    nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            setattr(self, f"relu{stage}", nn.LeakyReLU(negative_slope=0.2))

        self.sequential = nn.Sequential(
            self.conv1, self.relu1,
            self.conv2, self.relu2,
            self.conv3, self.relu3,
        )

    def forward(self, x):
        """Run ``x`` through the three conv + LeakyReLU stages."""
        features = self.sequential(x)
        return features
    
    
class ICLabelNetPSDS(nn.Module):
    """Convolutional branch applied to the ``psds`` input of ``ICLabelNet``.

    Three ``Conv2d`` layers with a (1, 3) kernel, stride 1 and padding (0, 1),
    each followed by ``LeakyReLU(0.2)``. Channels go 1 -> 128 -> 256 -> 1 and,
    because of the padding, the spatial size of the input is preserved.

    Every layer is reachable both by name (``conv1`` ... ``relu3``) and through
    ``sequential``, so the state dict holds both sets of keys.
    """

    def __init__(self):
        super().__init__()

        def conv_1x3(c_in, c_out):
            # Width-only convolution; padding (0, 1) keeps the last dimension's length.
            return nn.Conv2d(c_in, c_out, kernel_size=(1, 3), stride=1, padding=(0, 1))

        self.conv1 = conv_1x3(1, 128)
        self.relu1 = nn.LeakyReLU(0.2)
        self.conv2 = conv_1x3(128, 256)
        self.relu2 = nn.LeakyReLU(0.2)
        self.conv3 = conv_1x3(256, 1)
        self.relu3 = nn.LeakyReLU(0.2)

        stages = [self.conv1, self.relu1,
                  self.conv2, self.relu2,
                  self.conv3, self.relu3]
        self.sequential = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the three conv + LeakyReLU stages."""
        return self.sequential(x)



class ICLabelNetAutocorr(nn.Module):
    """Convolutional branch applied to the ``autocorr`` input of ``ICLabelNet``.

    Architecturally identical to the PSD branch: three (1, 3) convolutions with
    stride 1 and padding (0, 1), channels 1 -> 128 -> 256 -> 1, each followed by
    ``LeakyReLU(0.2)``. The input's spatial size is preserved.

    Layers are registered under ``conv1`` ... ``relu3`` and again inside
    ``sequential``; both key sets appear in the state dict.
    """

    def __init__(self):
        super().__init__()

        channel_plan = [1, 128, 256, 1]
        pipeline = []
        for i in range(3):
            conv = nn.Conv2d(channel_plan[i], channel_plan[i + 1], (1, 3),
                             stride=1, padding=(0, 1))
            activation = nn.LeakyReLU(0.2)
            # Named registration keeps the 'convN.*' state-dict keys available.
            self.add_module(f"conv{i + 1}", conv)
            self.add_module(f"relu{i + 1}", activation)
            pipeline.extend([conv, activation])

        self.sequential = nn.Sequential(*pipeline)

    def forward(self, x):
        """Run ``x`` through the three conv + LeakyReLU stages."""
        result = self.sequential(x)
        return result



class ICLabelNet(nn.Module):
    """Three-branch classifier producing 7 class probabilities per component.

    The ``images``, ``psds`` and ``autocorr`` inputs each go through their own
    convolutional branch. The two 100-long 1-D branch outputs are tiled over a
    4x4 grid, concatenated with the 512-channel image features
    (512 + 100 + 100 = 712 channels) and reduced to 7 soft-maxed scores by a
    single 4x4 convolution.

    The batch is expected to hold 4 variants of each of ``N`` components,
    stacked variant-major (row ``n + N * j`` is variant ``j`` of component
    ``n``); ``forward`` averages the 4 variants and returns ``(N, 7)``.
    """

    def __init__(self):
        super().__init__()

        self.img_conv = ICLabelNetImg()
        self.psds_conv = ICLabelNetPSDS()
        self.autocorr_conv = ICLabelNetAutocorr()

        self.conv = nn.Conv2d(in_channels=712,
                              out_channels=7,
                              kernel_size=(4, 4),
                              padding=0,
                              stride=1)
        self.softmax = nn.Softmax(dim=1)

        # `seq` re-registers `conv`, so the state dict carries both
        # 'conv.*' and 'seq.0.*' keys; saved checkpoints rely on that.
        self.seq = nn.Sequential(self.conv, self.softmax)

    @staticmethod
    def reshape_fortran(x: torch.Tensor, shape) -> torch.Tensor:
        """Reshape ``x`` to ``shape`` using column-major (Fortran/MATLAB) element order.

        Implemented by reversing the axes, doing an ordinary row-major reshape
        to the reversed target shape, and reversing the axes back.
        """
        if len(x.shape) > 0:
            x = x.permute(*reversed(range(len(x.shape))))
        return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))

    def _tile_to_grid(self, features: torch.Tensor) -> torch.Tensor:
        """Turn a 100-element-per-sample tensor into a ``(B, 100, 4, 4)`` map.

        Each of the 100 values becomes one channel that is constant over the
        4x4 grid, so it can be concatenated with the image features.
        """
        flat = self.reshape_fortran(features, [-1, 1, 1, 100])
        rows = torch.concat([flat] * 4, 1)   # (B, 4, 1, 100)
        grid = torch.concat([rows] * 4, 2)   # (B, 4, 4, 100)
        return torch.permute(grid, (0, 3, 1, 2))

    def forward(self, images, psds, autocorr):
        """Classify a batch of stacked component variants.

        Args:
            images: input for the image branch; its branch output must be
                ``(B, 512, 4, 4)`` to line up with the tiled 1-D features.
            psds: input for the PSD branch; its branch output must hold 100
                values per sample.
            autocorr: input for the autocorrelation branch; same requirement
                as ``psds``.

        Returns:
            Tensor of shape ``(B // 4, 7)``: per-component class probabilities
            averaged over the 4 variants.

        Raises:
            ValueError: if the batch size ``B`` is not a multiple of 4.
        """
        batch = images.shape[0]
        if batch % 4 != 0:
            # The averaging below regroups the batch into 4 variants per
            # component; any other batch size used to fail with an opaque
            # reshape/permute error.
            raise ValueError(
                f"batch size must be a multiple of 4 (4 variants per component), got {batch}"
            )

        out_img = self.img_conv(images)
        psds_map = self._tile_to_grid(self.psds_conv(psds))
        autocorr_map = self._tile_to_grid(self.autocorr_conv(autocorr))

        concat = torch.concat([out_img, psds_map, autocorr_map], 1)

        labels = self.seq(concat)

        # Drop only the two singleton spatial dims: (B, 7, 1, 1) -> (B, 7).
        labels = labels.squeeze(-1).squeeze(-1)

        # Column-major regrouping: each row of the (7 * N, 4) matrix holds the
        # 4 variant scores of one (component, class) pair.
        labels = self.reshape_fortran(labels.permute(1, 0), [-1, 4])
        labels = torch.mean(labels, 1)
        labels = self.reshape_fortran(labels, [7, -1])
        labels = labels.permute(1, 0)

        return labels

In [3]:
iclabelNet = ICLabelNet()

# The .mat kernels are permuted with (3, 2, 0, 1) to reach the
# (out_channels, in_channels, kH, kW) layout that nn.Conv2d stores.
permute = (3, 2, 0, 1)


def _copy_conv_params(state, layer, kernel, bias, squeezed_out_channel=False):
    """Copy one conv layer's kernel and bias from loadmat arrays into ``state``.

    Args:
        state: mapping returned by ``Module.state_dict()``; its tensors share
            storage with the model parameters, so the in-place copy updates
            the model itself.
        layer: state-dict prefix of the layer, e.g. ``'img_conv.conv1'``.
        kernel: kernel array as loaded from the .mat file.
        bias: bias array as loaded from the .mat file.
        squeezed_out_channel: set for the 3-D kernels of the single-output
            ``conv3`` layers, which need an axis inserted before permuting.

    Raises:
        ValueError: if a converted array does not match the parameter shape
            (slice assignment would silently broadcast it instead).
    """
    if squeezed_out_channel:
        weight = torch.permute(
            torch.from_numpy(np.expand_dims(kernel, axis=2)), (2, 3, 0, 1))
    else:
        weight = torch.permute(torch.from_numpy(kernel), permute)
    # reshape(-1) rather than squeeze(): a one-element bias must stay 1-D
    # so that it matches the parameter shape exactly.
    flat_bias = torch.from_numpy(bias).reshape(-1)

    for suffix, value in (('weight', weight), ('bias', flat_bias)):
        key = f'{layer}.{suffix}'
        target = state[key]
        if target.shape != value.shape:
            raise ValueError(
                f'{key}: expected shape {tuple(target.shape)}, '
                f'got {tuple(value.shape)}')
        target.copy_(value)


# Load all parameter files
img_params = sio.loadmat('img_params.mat')
psds_params = sio.loadmat('psds_params.mat')
autocorr_params = sio.loadmat('autocorrs_params.mat')
conv_params = sio.loadmat('conv_params.mat')

# Build the state dict once instead of once per assignment.
_state = iclabelNet.state_dict()

# Branch params: (state-dict prefix, loaded .mat dict, conv3 kernel stored as 3-D?)
for _prefix, _params, _conv3_is_3d in (
    ('img_conv', img_params, False),
    ('psds_conv', psds_params, True),
    ('autocorr_conv', autocorr_params, True),
):
    for _idx in (1, 2, 3):
        _copy_conv_params(
            _state,
            f'{_prefix}.conv{_idx}',
            _params[f'conv{_idx}_kernel'],
            _params[f'conv{_idx}_bias'],
            squeezed_out_channel=_conv3_is_3d and _idx == 3,
        )

# Final conv params
_copy_conv_params(_state, 'conv',
                  conv_params['conv_kernel'], conv_params['conv_bias'])

# Save State Dict
torch.save(iclabelNet.state_dict(), 'iclabelNet.pt')

# Format input

In [4]:
# The first row of the loaded 'features' array holds the three network inputs.
features = sio.loadmat('features.mat')['features']
images, psds, autocorrs = (features[0, col] for col in range(3))

# Four variants of every image are stacked along axis 3: the original, its
# negation, and both of those flipped along axis 1.
negated_images = -1 * images
formatted_images = np.concatenate(
    [images,
     negated_images,
     np.flip(images, axis=1),
     np.flip(negated_images, axis=1)],
    axis=3,
)

# The 1-D features are simply repeated 4 times along axis 3 to match.
tile_reps = (1, 1, 1, 4)
formatted_psds = np.tile(psds, tile_reps)
formatted_autocorrs = np.tile(autocorrs, tile_reps)

# Move axis 3 (the stacked samples) to the front and axis 2 to the channel slot.
to_batch_first = (3, 2, 0, 1)
img_tnsr = torch.from_numpy(formatted_images).permute(to_batch_first).float()
psds_tnsr = torch.from_numpy(formatted_psds).float().permute(to_batch_first)
autocorrs_tnsr = torch.from_numpy(formatted_autocorrs).float().permute(to_batch_first)

## ICLabel Function

In [5]:
def run_iclabel(images, psds, autocorr, weights_path='iclabelNet.pt'):
    """Run the saved ICLabel network on pre-formatted inputs.

    Args:
        images: tensor for the image branch of ``ICLabelNet``.
        psds: tensor for the PSD branch.
        autocorr: tensor for the autocorrelation branch.
        weights_path: state-dict checkpoint to load; defaults to the file
            written by the weight-conversion cell.

    Returns:
        NumPy array with the labels returned by ``ICLabelNet.forward``.
    """
    net = ICLabelNet()
    # map_location keeps the weights on CPU regardless of where they were
    # saved; the result is converted with .numpy(), which needs CPU tensors.
    net.load_state_dict(torch.load(weights_path, map_location='cpu'))
    net.eval()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        labels = net(images, psds, autocorr)
    return labels.numpy()

In [6]:
# Run the PyTorch model on the formatted feature tensors.
labels_npy = run_iclabel(img_tnsr, psds_tnsr, autocorrs_tnsr)

In [7]:
# Print arrays with 4 decimals and without scientific notation.
np.set_printoptions(precision=4, suppress=True)

In [8]:
# Display the labels computed by the PyTorch model.
labels_npy

array([[0.9883, 0.    , 0.    , 0.    , 0.0117, 0.    , 0.    ],
       [0.0066, 0.0005, 0.9627, 0.0009, 0.0042, 0.0028, 0.0224],
       [0.1333, 0.0004, 0.0009, 0.0021, 0.7133, 0.016 , 0.1339],
       [0.8638, 0.    , 0.0002, 0.0007, 0.1096, 0.0001, 0.0256],
       [0.9924, 0.    , 0.    , 0.    , 0.0073, 0.    , 0.0002],
       [0.0061, 0.0007, 0.7202, 0.0013, 0.2251, 0.0043, 0.0423],
       [0.7828, 0.    , 0.0044, 0.002 , 0.2021, 0.0008, 0.0078],
       [0.8543, 0.    , 0.    , 0.    , 0.1453, 0.    , 0.0004],
       [0.9153, 0.    , 0.    , 0.    , 0.084 , 0.    , 0.0006],
       [0.9927, 0.    , 0.    , 0.0001, 0.0071, 0.    , 0.0001],
       [0.6861, 0.0016, 0.0017, 0.0006, 0.2643, 0.0097, 0.0359],
       [0.9949, 0.    , 0.    , 0.    , 0.0038, 0.    , 0.0012],
       [0.0611, 0.0002, 0.0072, 0.0024, 0.6473, 0.0106, 0.2712],
       [0.5056, 0.0001, 0.0003, 0.0004, 0.4238, 0.    , 0.0698],
       [0.3032, 0.    , 0.0011, 0.0005, 0.6825, 0.0006, 0.012 ],
       [0.0548, 0.0001, 0

In [9]:
# Load and display the reference labels stored in labels.mat
# (presumably the output of the original MATLAB implementation; confirm).
labels_mat = sio.loadmat('labels.mat')['labels']
labels_mat

array([[0.9883, 0.    , 0.    , 0.    , 0.0117, 0.    , 0.    ],
       [0.0066, 0.0005, 0.9627, 0.0009, 0.0042, 0.0028, 0.0224],
       [0.1333, 0.0004, 0.0009, 0.0021, 0.7133, 0.016 , 0.1339],
       [0.8638, 0.    , 0.0002, 0.0007, 0.1096, 0.0001, 0.0256],
       [0.9924, 0.    , 0.    , 0.    , 0.0073, 0.    , 0.0002],
       [0.0061, 0.0007, 0.7202, 0.0013, 0.2251, 0.0043, 0.0423],
       [0.7828, 0.    , 0.0044, 0.002 , 0.2021, 0.0008, 0.0078],
       [0.8543, 0.    , 0.    , 0.    , 0.1453, 0.    , 0.0004],
       [0.9153, 0.    , 0.    , 0.    , 0.084 , 0.    , 0.0006],
       [0.9927, 0.    , 0.    , 0.0001, 0.0071, 0.    , 0.0001],
       [0.6861, 0.0016, 0.0017, 0.0006, 0.2643, 0.0097, 0.0359],
       [0.9949, 0.    , 0.    , 0.    , 0.0038, 0.    , 0.0012],
       [0.0611, 0.0002, 0.0072, 0.0024, 0.6473, 0.0106, 0.2712],
       [0.5056, 0.0001, 0.0003, 0.0004, 0.4238, 0.    , 0.0698],
       [0.3032, 0.    , 0.0011, 0.0005, 0.6825, 0.0006, 0.012 ],
       [0.0548, 0.0001, 0

In [16]:
# Compare the PyTorch labels with the reference labels. np.allclose also
# applies its default relative tolerance (rtol=1e-5) on top of atol, so the
# effective tolerance is atol + rtol * |labels_npy|.
np.allclose(labels_mat, labels_npy, atol=1e-7)

True