In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Literal, Union, Dict, Any
import random
from PIL import Image
import matplotlib.pyplot as plt


In [None]:
def weights_init_uniform(activation: str = 'relu'):
    """Build a Kaiming-uniform initialiser to be used with ``Module.apply``.

    Parameters
    ----------
    activation : str, optional
        Name of the non-linearity, forwarded to ``nn.init.kaiming_uniform_``
        and ``nn.init.calculate_gain``, by default 'relu'.

    Returns
    -------
    Callable[[torch.nn.Module], None]
        Function that initialises ``nn.Linear`` modules in place and ignores
        every other module type.

    Raises
    ------
    ValueError
        If ``activation`` is not a string.
    """
    if not isinstance(activation, str):
        raise ValueError('activation must be a string')

    def _weights_init_uniform(m: torch.nn.Module):
        # ``apply`` visits every sub-module; only linear layers are touched.
        if not isinstance(m, nn.Linear):
            return
        with torch.no_grad():
            weight_v = getattr(m, 'weight_v', None)
            weight_g = getattr(m, 'weight_g', None)
            if weight_v is not None and weight_g is not None:
                # Layer wrapped by the legacy ``nn.utils.weight_norm``: ``m.weight``
                # is recomputed from (weight_g, weight_v) before every forward,
                # so writing to ``m.weight`` would be discarded. Initialise the
                # direction tensor instead and set the magnitude to its norm so
                # that the effective weight equals the freshly drawn values.
                nn.init.kaiming_uniform_(weight_v, mode='fan_in', nonlinearity=activation)
                if weight_g.numel() == 1:
                    norm = weight_v.norm(2).reshape(weight_g.shape)
                else:
                    reduce_dims = tuple(d for d in range(weight_v.dim())
                                        if weight_g.shape[d] == 1)
                    norm = weight_v.norm(2, dim=reduce_dims, keepdim=True)
                weight_g.copy_(norm)
                fan_source = weight_v
            else:
                nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity=activation)
                fan_source = m.weight
            # Bias is drawn uniformly from [-gain/sqrt(fan_in), gain/sqrt(fan_in)].
            gain = nn.init.calculate_gain(activation, 0)
            fan = nn.init._calculate_correct_fan(fan_source, 'fan_in')
            std = gain / math.sqrt(fan)
            if m.bias is not None:
                m.bias.data.uniform_(-std, std)
    return _weights_init_uniform

def weights_init_normal(activation: str = 'relu'):
    """Build a Kaiming-normal initialiser to be used with ``Module.apply``.

    Parameters
    ----------
    activation : str, optional
        Name of the non-linearity, forwarded to ``nn.init.kaiming_normal_``
        and ``nn.init.calculate_gain``, by default 'relu'.

    Returns
    -------
    Callable[[torch.nn.Module], None]
        Function that initialises ``nn.Linear`` modules in place and ignores
        every other module type.

    Raises
    ------
    ValueError
        If ``activation`` is not a string.
    """
    if not isinstance(activation, str):
        raise ValueError('activation must be a string')

    def _weights_init_normal(m: torch.nn.Module):
        # ``apply`` visits every sub-module; only linear layers are touched.
        if not isinstance(m, nn.Linear):
            return
        with torch.no_grad():
            weight_v = getattr(m, 'weight_v', None)
            weight_g = getattr(m, 'weight_g', None)
            if weight_v is not None and weight_g is not None:
                # Layer wrapped by the legacy ``nn.utils.weight_norm``: ``m.weight``
                # is recomputed from (weight_g, weight_v) before every forward,
                # so writing to ``m.weight`` would be discarded. Initialise the
                # direction tensor instead and set the magnitude to its norm so
                # that the effective weight equals the freshly drawn values.
                nn.init.kaiming_normal_(weight_v, mode='fan_in', nonlinearity=activation)
                if weight_g.numel() == 1:
                    norm = weight_v.norm(2).reshape(weight_g.shape)
                else:
                    reduce_dims = tuple(d for d in range(weight_v.dim())
                                        if weight_g.shape[d] == 1)
                    norm = weight_v.norm(2, dim=reduce_dims, keepdim=True)
                weight_g.copy_(norm)
                fan_source = weight_v
            else:
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity=activation)
                fan_source = m.weight
            # Bias is drawn uniformly from [-gain/sqrt(fan_in), gain/sqrt(fan_in)].
            gain = nn.init.calculate_gain(activation, 0)
            fan = nn.init._calculate_correct_fan(fan_source, 'fan_in')
            std = gain / math.sqrt(fan)
            if m.bias is not None:
                m.bias.data.uniform_(-std, std)
    return _weights_init_normal

In [None]:
class SkipBlock(nn.Module):
    """Hidden layer that adds a bias-free projection of the network input.

    Computes ``relu(ln(x) + skp(x_input))`` where ``ln`` maps the previous
    hidden state and ``skp`` maps the original network input.
    """

    def __init__(self,
                 in_features: int = 130,
                 out_features: int = 130,
                 in_skip_features: int = 2,
                 **kwargs) -> None:
        # Extra keyword arguments are accepted and ignored.
        super().__init__()
        self.ln = nn.Linear(in_features, out_features)
        self.skp = nn.Linear(in_skip_features, out_features, bias=False)

    def forward(self, x: torch.Tensor, x_input: torch.Tensor):
        """Return ``relu(ln(x) + skp(x_input))``."""
        hidden_term = self.ln(x)
        skip_term = self.skp(x_input)
        return F.relu(hidden_term + skip_term)

    def reset_parameters(self) -> None:
        """Re-initialise both linear layers with the 'relu' uniform scheme."""
        for layer in (self.ln, self.skp):
            layer.apply(weights_init_uniform('relu'))

    def enforce_convexity(self) -> None:
        """Replace negative entries of ``ln.weight`` by zero."""
        with torch.no_grad():
            non_negative = F.relu(self.ln.weight.data)
            self.ln.weight.data = non_negative


class OutBlock(SkipBlock):
    """Output layer: same two linear maps as ``SkipBlock`` but without ReLU."""

    def __init__(self,
                 in_features: int = 130,
                 out_features: int = 1,
                 in_skip_features: int = 2,
                 **kwargs) -> None:
        # Extra keyword arguments are accepted and not forwarded.
        super().__init__(in_features, out_features, in_skip_features)

    def forward(self, x: torch.Tensor, x_input: torch.Tensor):
        """Return ``ln(x) + skp(x_input)`` with no activation applied."""
        affine = self.ln(x)
        return affine + self.skp(x_input)

    def reset_parameters(self) -> None:
        """Re-initialise both linear layers with the 'linear' uniform scheme."""
        for layer in (self.ln, self.skp):
            layer.apply(weights_init_uniform('linear'))

class ConvexNextNet(nn.Module):
    """Feed-forward net whose hidden and output layers also see the raw input.

    Structure: ``input`` linear layer + ReLU, then ``n_hidden_layers``
    ``SkipBlock`` layers, then one ``OutBlock``. Each block receives the
    current hidden state together with the original network input.
    """

    def __init__(self,
                 n_hidden: int = 130,
                 in_features: int = 2,
                 out_features: int = 1,
                 n_hidden_layers: int = 1,
                 ** kwargs):
        # Extra keyword arguments are accepted and ignored.
        super().__init__()

        self.input = nn.Linear(in_features, n_hidden)

        hidden_blocks = []
        for _ in range(n_hidden_layers):
            hidden_blocks.append(SkipBlock(in_features=n_hidden,
                                           out_features=n_hidden,
                                           in_skip_features=in_features))
        self.skip = nn.ModuleList(hidden_blocks)

        self.out = OutBlock(
            in_features=n_hidden,
            out_features=out_features,
            in_skip_features=in_features)

    def reset_parameters(self) -> bool:
        """Re-initialise every layer; returns ``True``."""
        self.input.apply(weights_init_uniform('linear'))
        for block in self.skip:
            block.reset_parameters()
        self.out.reset_parameters()
        return True

    def forward(self, x):
        """Run the network; the code comments expect input of shape (batch_size, 2)."""
        x_input = x
        hidden = F.relu(self.input(x))
        for block in self.skip:
            hidden = block(hidden, x_input=x_input)
        return self.out(hidden, x_input=x_input)

    def enforce_convexity(self) -> None:
        """Call ``enforce_convexity`` on every skip block and on the output block."""
        with torch.no_grad():
            for block in self.skip:
                block.enforce_convexity()
            self.out.enforce_convexity()

In [None]:
class SimpleResnet(nn.Module):
    """Weight-normalised residual network whose output is squashed by ``tanh``."""

    def __init__(self, 
                 in_channels: int = 2,
                 mid_channels: int = 128, 
                 out_channels: int = 2,
                 num_blocks: int = 1, 
                 double_after_norm: bool = False
                 ):
        """1D ResNet for scale and translate factors in 1D Real NVP.


        Parameters
        ----------
        in_channels : int, optional
            Number of input channels, by default 2
        mid_channels : int, optional
            Number of intermediate channels, by default 128
        out_channels : int, optional
            Number of output channels, by default 2
        num_blocks : int, optional
            Number of residual blocks, by default 1
        double_after_norm : bool, optional
            If the channel values should be doubled after norming the input, by default False
        """
        super(SimpleResnet, self).__init__()
        self.in_norm = nn.BatchNorm1d(in_channels, track_running_stats=False)
        self.double_after_norm = double_after_norm
        # The input is concatenated with its negation in ``forward``, hence 2 * in_channels.
        self.in_linear = WNLinear(2 * in_channels, mid_channels, bias=True)
        self.in_skip = WNLinear(mid_channels, mid_channels, bias=True)
        self.blocks = nn.ModuleList([ResidualBlock1D(mid_channels, mid_channels)
                                     for _ in range(num_blocks)])
        self.skips = nn.ModuleList([WNLinear(mid_channels, mid_channels, bias=True)
                                    for _ in range(num_blocks)])
        self.out_norm = nn.BatchNorm1d(mid_channels, track_running_stats=False)
        self.out_linear = WNLinear(mid_channels, out_channels, bias=True)

        self._init_linear_layers()

    def _init_linear_layers(self) -> None:
        """Apply the construction-time initialisers to this module's own linear layers."""
        self.in_linear.apply(weights_init_normal('relu'))
        self.in_skip.apply(weights_init_normal('relu'))
        self.skips.apply(weights_init_normal('relu'))
        self.out_linear.apply(weights_init_normal('tanh'))

    def reset_parameters(self) -> None:
        """Re-apply the initialisers used at construction time.

        ``NormalizingFlow1D.reset_parameters`` calls this method on every
        backbone, so it must exist for the 'resnet' backbone as well.
        """
        self._init_linear_layers()
        # ResidualBlock1D initialises both of its linear layers with 'relu'.
        self.blocks.apply(weights_init_normal('relu'))
        for module in self.modules():
            if isinstance(module, nn.BatchNorm1d):
                module.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` through the residual stack and return ``tanh`` of the result."""
        x = self.in_norm(x)
        if self.double_after_norm:
            # Out-of-place so no tensor recorded by autograd is overwritten.
            x = x * 2.
        x = torch.cat((x, -x), dim=1)
        x = F.relu(x)
        x = self.in_linear(x)
        x_skip = self.in_skip(x)

        # Every block output contributes to the accumulated skip signal.
        for block, skip in zip(self.blocks, self.skips):
            x = block(x)
            x_skip = x_skip + skip(x)

        x = self.out_norm(x_skip)
        x = F.relu(x)
        x = self.out_linear(x)
        x = torch.tanh(x)
        return x

class ResidualBlock1D(nn.Module):
    """Pre-activation residual block built from weight-normalised linear layers.

    ``forward`` computes ``x + out_linear(relu(out_norm(in_linear(relu(in_norm(x))))))``.
    """

    def __init__(self,
                 in_channels: int = 1,
                 out_channels: int = 1,
                 **kwargs):
        # Extra keyword arguments are accepted and ignored.
        super().__init__()

        self.in_norm = nn.BatchNorm1d(in_channels, track_running_stats=False)
        self.in_linear = WNLinear(in_channels, out_channels, bias=False)

        self.out_norm = nn.BatchNorm1d(out_channels, track_running_stats=False)
        self.out_linear = WNLinear(out_channels, out_channels, bias=True)

        for layer in (self.in_linear, self.out_linear):
            layer.apply(weights_init_normal('relu'))

    def forward(self, x):
        """Return the block output added to its input."""
        residual = self.in_linear(F.relu(self.in_norm(x)))
        residual = self.out_linear(F.relu(self.out_norm(residual)))
        return residual + x


class ResNet1D(nn.Module):
    """1D ResNet for scale and translate factors in 1D Real NVP."""

    def __init__(self,
                 in_channels: int = 2,
                 mid_channels: int = 128,
                 out_channels: int = 2,
                 num_blocks: int = 2,
                 double_after_norm: bool = False
                 ):
        """Create the normalisation, linear, residual and skip layers.

        Parameters
        ----------
        in_channels : int, optional
            Number of input channels, by default 2
        mid_channels : int, optional
            Number of intermediate channels, by default 128
        out_channels : int, optional
            Number of output channels, by default 2
        num_blocks : int, optional
            Number of residual blocks, by default 2
        double_after_norm : bool, optional
            If the channel values should be doubled after norming the input, by default False
        """
        super().__init__()
        self.in_norm = nn.BatchNorm1d(in_channels, track_running_stats=False)
        self.double_after_norm = double_after_norm
        # ``forward`` concatenates the input with its negation, hence 2 * in_channels.
        self.in_linear = WNLinear(2 * in_channels, mid_channels, bias=True)
        self.in_skip = WNLinear(mid_channels, mid_channels, bias=True)

        residual_blocks = []
        for _ in range(num_blocks):
            residual_blocks.append(ResidualBlock1D(mid_channels, mid_channels))
        self.blocks = nn.ModuleList(residual_blocks)

        skip_layers = []
        for _ in range(num_blocks):
            skip_layers.append(WNLinear(mid_channels, mid_channels, bias=True))
        self.skips = nn.ModuleList(skip_layers)

        self.out_norm = nn.BatchNorm1d(mid_channels, track_running_stats=False)
        self.out_linear = WNLinear(mid_channels, out_channels, bias=True)

        # Hidden layers use the 'relu' gain, the un-activated output uses 'linear'.
        for hidden_module in (self.in_linear, self.in_skip, self.skips):
            hidden_module.apply(weights_init_normal('relu'))
        self.out_linear.apply(weights_init_normal('linear'))

    def forward(self, x):
        """Map ``x`` through the residual stack; no activation on the output."""
        x = self.in_norm(x)
        if self.double_after_norm:
            x *= 2.
        x = F.relu(torch.cat((x, -x), dim=1))
        x = self.in_linear(x)
        x_skip = self.in_skip(x)

        # Every block output contributes to the accumulated skip signal.
        for block, skip in zip(self.blocks, self.skips):
            x = block(x)
            x_skip += skip(x)

        return self.out_linear(F.relu(self.out_norm(x_skip)))

class WNLinear(nn.Module):
    """Linear layer wrapped with ``nn.utils.weight_norm`` (``dim=None``)."""

    def __init__(self,
                 in_channels: int = 1,
                 out_channels: int = 1,
                 bias=True,
                 **kwargs):
        # Extra keyword arguments are accepted and ignored.
        super().__init__()
        base_layer = nn.Linear(in_channels, out_channels, bias=bias)
        self.linear = nn.utils.weight_norm(base_layer, dim=None)

    def reset_parameters(self, activation: str = 'relu'):
        """Reset magnitude to 1, direction to Kaiming-uniform, bias to uniform.

        Parameters
        ----------
        activation : str, optional
            Non-linearity name used for the gain and the Kaiming init, by default 'relu'.
        """
        layer = self.linear
        with torch.no_grad():
            layer.weight_g.data.fill_(1)
            fan_in = nn.init._calculate_correct_fan(layer.weight_v, 'fan_in')
            bias_bound = nn.init.calculate_gain(activation, 0) / math.sqrt(fan_in)
            nn.init.kaiming_uniform_(layer.weight_v, mode='fan_in', nonlinearity=activation)
            if layer.bias is None:
                return
            layer.bias.data.uniform_(-bias_bound, bias_bound)

    def forward(self, x):
        """Apply the wrapped linear layer."""
        return self.linear(x)

class SimpleBackbone(nn.Module):
    """Two weight-normalised linear layers: ``tanh(linear2(relu(linear1(x))))``."""

    def __init__(self,
                 in_channels: int = 2,
                 network_width: int = 10,
                 **kwargs) -> None:
        # Extra keyword arguments are accepted and ignored. No custom
        # initialiser is applied at construction time.
        super().__init__()
        self.linear1 = WNLinear(in_channels, network_width)
        self.linear2 = WNLinear(network_width, in_channels)

    def reset_parameters(self) -> None:
        """Apply the uniform initialiser ('relu' for layer 1, 'tanh' for layer 2)."""
        for layer, activation in ((self.linear1, 'relu'), (self.linear2, 'tanh')):
            layer.apply(weights_init_uniform(activation))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``tanh(linear2(relu(linear1(x))))``."""
        hidden = F.relu(self.linear1(x))
        return torch.tanh(self.linear2(hidden))
    
class NormalBlock(nn.Module):
    """Two weight-normalised linear layers: ``tanh(out_linear(leaky_relu(in_linear(x))))``."""

    def __init__(self,
                 in_channels: int = 1,
                 mid_channels: int = 128,
                 out_channels: int = 1,
                 **kwargs):
        # Extra keyword arguments are accepted and ignored.
        super().__init__()
        self.in_linear = WNLinear(in_channels, mid_channels, bias=True)
        self.out_linear = WNLinear(mid_channels, out_channels, bias=True)

    def reset_parameters(self) -> None:
        """Apply the uniform initialiser ('leaky_relu' inside, 'tanh' on the output layer)."""
        for layer, activation in ((self.in_linear, 'leaky_relu'), (self.out_linear, 'tanh')):
            layer.apply(weights_init_uniform(activation))

    def forward(self, x):
        """Return ``tanh(out_linear(leaky_relu(in_linear(x))))``."""
        hidden = F.leaky_relu(self.in_linear(x))
        return torch.tanh(self.out_linear(hidden))

class WNScale(nn.Module):
    """Learnable scale: a weight-normalised linear map applied to a learnable vector.

    ``forward`` ignores its arguments and returns ``scale(weight)``.
    """

    weight: torch.Tensor

    def __init__(self, dim: int = 1, **kwargs) -> None:
        # Extra keyword arguments are accepted and ignored.
        super().__init__()
        self.scale = nn.utils.weight_norm(nn.Linear(dim, dim))
        self.weights_init_normal(self.scale)
        # One entry per input feature of ``scale`` so that ``scale(weight)`` is
        # well-defined for any ``dim`` (shape (1,) for the default ``dim=1``).
        self.weight = nn.Parameter(1.0 + 0.01 * torch.randn(dim))

    def reset_parameters(self) -> None:
        """Redraw the linear map and reset ``weight`` to ``1 + 0.01 * N(0, 1)``."""
        self.weights_init_normal(self.scale)
        with torch.no_grad():
            self.weight.copy_(1.0 + 0.01 * torch.randn_like(self.weight))

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Return ``scale(weight)``; all arguments are ignored."""
        return self.scale(self.weight)

    def weights_init_normal(self, m):
        """Draw the weights of ``m`` from ``N(0, 1/in_features)`` and zero its bias.

        For a layer wrapped by the legacy ``nn.utils.weight_norm`` the attribute
        ``weight`` is recomputed from ``weight_g``/``weight_v`` before each
        forward pass, so the draw is written to ``weight_v`` and ``weight_g`` is
        set to the matching norm; the effective weight then equals the draw.
        """
        std = 1 / np.sqrt(m.in_features)
        weight_v = getattr(m, 'weight_v', None)
        weight_g = getattr(m, 'weight_g', None)
        with torch.no_grad():
            if weight_v is not None and weight_g is not None:
                weight_v.normal_(0.0, std)
                if weight_g.numel() == 1:
                    norm = weight_v.norm(2).reshape(weight_g.shape)
                else:
                    reduce_dims = tuple(d for d in range(weight_v.dim())
                                        if weight_g.shape[d] == 1)
                    norm = weight_v.norm(2, dim=reduce_dims, keepdim=True)
                weight_g.copy_(norm)
            else:
                m.weight.data.normal_(0.0, std)
            m.bias.data.fill_(0)

class NormalizingFlow1D(nn.Module):
    """Stack of affine coupling layers acting on a two-column input.

    Each coupling layer owns a scale network ``s``, a translation network ``t``
    and a learnable ``WNScale`` factor. ``forward`` splits the input into its
    first column and the remaining columns and alternately updates one from the
    other. The backbones are built with ``in_channels=1``.
    """

    def __init__(self,
                 num_coupling: int = 4,
                 width: int = 130,
                 num_blocks: int = 1,
                 in_features: int = 2,
                 backbone: Literal['default', 'residual_block', 'normal_block', 'resnet'] = 'default',
                 **kwargs
                 ):
        """Build ``num_coupling`` coupling layers.

        Parameters
        ----------
        num_coupling : int, optional
            Number of coupling layers; must be divisible by ``in_features``, by default 4
        width : int, optional
            Hidden width handed to the backbone, by default 130
        num_blocks : int, optional
            Number of residual blocks, only used by the 'resnet' backbone, by default 1
        in_features : int, optional
            Number of input features, by default 2
        backbone : str, optional
            'default' (SimpleBackbone), 'resnet' (SimpleResnet) or
            'residual_block' / 'normal_block' (NormalBlock), by default 'default'

        Raises
        ------
        ValueError
            If ``num_coupling`` is not divisible by ``in_features`` or the
            backbone name is unknown.
        """
        super().__init__()
        self.num_coupling = num_coupling
        if self.num_coupling % in_features != 0:
            raise ValueError(
                f'Number of coupling layers should be divisible by in_features ({in_features})')

        # Select the backbone class and the keyword arguments it is built with.
        args = dict(in_channels=1)
        if backbone == 'default':
            _backbone = SimpleBackbone
            args['network_width'] = width
        elif backbone == 'resnet':
            _backbone = SimpleResnet
            args.update(mid_channels=width,
                        out_channels=args['in_channels'],
                        num_blocks=num_blocks)
        elif backbone in ('residual_block', 'normal_block'):
            _backbone = NormalBlock
            args.update(mid_channels=width,
                        out_channels=args['in_channels'])
        else:
            raise ValueError(f'Unknown backbone: {backbone}')

        self.in_features = in_features
        self.s = nn.ModuleList([_backbone(**args) for _ in range(num_coupling)])
        self.t = nn.ModuleList([_backbone(**args) for _ in range(num_coupling)])
        # Learnable scaling parameters for outputs of S
        self.scale = nn.ModuleList([WNScale(dim=1) for _ in range(num_coupling)])

    def reset_parameters(self) -> bool:
        """Call ``reset_parameters`` on every s, t and scale module; returns ``True``."""
        for coupling_modules in zip(self.s, self.t, self.scale):
            for module in coupling_modules:
                module.reset_parameters()
        return True

    def forward(self, x):
        """Apply all coupling layers and return the two parts concatenated along dim 1."""
        x1, x2 = x[:, :1], x[:, 1:]
        for i in range(self.num_coupling):
            scale_factor, s_net, t_net = self.scale[i], self.s[i], self.t[i]
            # Even layers transform x2 conditioned on x1, odd layers the reverse.
            if i % 2 == 0:
                log_scale = scale_factor() * s_net(x1)
                x2 = torch.exp(log_scale) * x2 + t_net(x1)
            else:
                log_scale = scale_factor() * s_net(x2)
                x1 = torch.exp(log_scale) * x1 + t_net(x2)

        return torch.cat([x1, x2], 1)


In [None]:
class ConvexDiffeomorphismNet(nn.Module):
    """Composition ``convex_net(diffeo_net(linear(x)))``.

    ``linear`` is a plain ``nn.Linear``, ``diffeo_net`` a ``NormalizingFlow1D``
    and ``convex_net`` a ``ConvexNextNet``.
    """

    def __init__(self,
                 n_hidden: int = 130,
                 n_hidden_layers: int = 1,
                 nf_layers: int = 4,
                 nf_hidden: int = 70,
                 in_features: int = 2,
                 diffeo_args: Union[Dict[str, Any], None] = None,
                 **kwargs):
        """Build the three sub-networks.

        Parameters
        ----------
        n_hidden : int, optional
            Hidden width of the convex network, by default 130
        n_hidden_layers : int, optional
            Number of skip blocks in the convex network, by default 1
        nf_layers : int, optional
            Used as ``num_coupling`` of the flow unless given in ``diffeo_args``, by default 4
        nf_hidden : int, optional
            Used as ``width`` of the flow unless given in ``diffeo_args``, by default 70
        in_features : int, optional
            Number of input features, by default 2
        diffeo_args : Dict[str, Any], optional
            Keyword arguments for ``NormalizingFlow1D``; the dict passed in is
            not modified, by default None
        """
        # call constructor from superclass
        super().__init__()

        self.convex_net = ConvexNextNet(
            n_hidden=n_hidden,
            in_features=in_features,
            n_hidden_layers=n_hidden_layers)
        # Work on a copy so the defaults filled in below never leak into the
        # caller's dictionary.
        diffeo_args = dict(diffeo_args) if diffeo_args is not None else dict()
        self.in_features = in_features
        diffeo_args.setdefault("num_coupling", nf_layers)
        diffeo_args.setdefault("width", nf_hidden)
        diffeo_args.setdefault("in_features", in_features)

        self.diffeo_net = NormalizingFlow1D(**diffeo_args)
        self.linear = nn.Linear(in_features, in_features)
        self.linear.apply(self.weights_init_normal)

    def weights_init_normal(self, m):
        """Initialise modules whose class name contains 'Linear'.

        Weights are drawn from ``N(0, 1/in_features)`` and the bias is zeroed.
        """
        classname = type(m).__name__
        if classname.find('Linear') != -1:
            y = m.in_features
            m.weight.data.normal_(0.0, 1/np.sqrt(y))
            m.bias.data.fill_(0)

    def reset_parameters(self) -> bool:
        """Reset all three sub-networks; returns ``True``."""
        self.convex_net.reset_parameters()
        self.diffeo_net.reset_parameters()
        self.linear.apply(self.weights_init_normal)
        return True

    def forward(self, x):
        """Return ``convex_net(diffeo_net(linear(x)))``."""
        x = self.linear(x)
        xd = self.diffeo_net(x)
        xc = self.convex_net(xd)
        return xc


In [None]:
def extractInformationFromLikelihood(likelihood, mask):
    """Collect normalised coordinates and likelihood values of the masked pixels.

    Parameters
    ----------
    likelihood : torch.Tensor
        2D tensor of per-pixel values.
    mask : torch.Tensor
        Boolean tensor with the same 2D shape as ``likelihood`` selecting pixels.

    Returns
    -------
    pixel_info : torch.Tensor
        Tensor of shape (N, 2) with ``row / n_rows - 0.5`` in column 0 and
        ``col / n_cols - 0.5`` in column 1 for each of the N selected pixels.
    labels : torch.Tensor
        ``likelihood[mask]``, the values at the selected pixels.
    """
    indices = torch.nonzero(mask)
    # The normalisation uses the grid size of the mask itself instead of
    # notebook-level globals, so the function works wherever it is called.
    n_rows, n_cols = mask.shape[0], mask.shape[1]

    pixel_info = torch.zeros((indices.shape[0], 2))
    pixel_info[:, 0] = indices[:, 0] / n_rows - 0.5
    pixel_info[:, 1] = indices[:, 1] / n_cols - 0.5
    labels = likelihood[mask]
    return pixel_info, labels


In [None]:
# Load the scribbled image, halve both side lengths and convert it to a
# float array in [0, 1] keeping only the first three channels.
img_dir = "cat_scribbled.jpg"
img_pil = Image.open(img_dir)
width, height = img_pil.size
newsize = tuple(int(side / 2) for side in (width, height))
img_pil = img_pil.resize(newsize)

img = np.array(img_pil, dtype='float')[:, :, 0:3] / 255.0
# nx, ny are the first two entries of the array shape, nc the channel count.
nx, ny, nc = img.shape

In [None]:
# Load the unscribbled image; the resize target equals the image's own size.
img_orig_dir = "cat.jpg"
img_orig_pil = Image.open(img_orig_dir)
width, height = img_orig_pil.size
newsize = tuple(int(side) for side in (width, height))
img_orig_pil = img_orig_pil.resize(newsize)

# Float array in [0, 1] restricted to the first three channels.
img_orig = np.array(img_orig_pil, dtype='float')[:, :, 0:3] / 255.0

In [None]:
# A pixel is marked (1.0) when channel 0 exceeds channel 1 by more than 0.7.
likelihood = torch.tensor(img[:, :, 0] - img[:, :, 1] > 0.7).float()

# Store the binary map as an 8-bit image (0 / 255).
like = Image.fromarray(255 * likelihood.detach().numpy().astype('uint8'))
like.save("likelihood.png")

plt.imshow(likelihood)
plt.colorbar()
plt.show()
# Sample of raw pixel values: row 20, first five columns.
print(img[20:21, 0:5, :])

In [None]:
def train(optimizer, criterion,  convexdiff, pix_fore, labels_fore, pix_back, labels_back, num_epochs, number):
    """Train ``convexdiff`` on random minibatches of foreground and background pixels.

    Each epoch draws up to ``number`` random rows from the background and from
    the foreground set, passes them through ``sigmoid(convexdiff(.))`` and
    minimises ``2 * criterion(background) + criterion(foreground)``.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer over the parameters of ``convexdiff``.
    criterion : callable
        Loss taking (predictions, labels).
    convexdiff : torch.nn.Module
        Model mapping an (N, 2) coordinate batch to one value per row.
    pix_fore, labels_fore : torch.Tensor
        Foreground coordinates and their labels.
    pix_back, labels_back : torch.Tensor
        Background coordinates and their labels.
    num_epochs : int
        Number of optimisation steps.
    number : int
        Maximum number of samples drawn from each set per step.

    Returns
    -------
    convexdiff : torch.nn.Module
        The trained model (same object that was passed in).
    loss_full : list of torch.Tensor
        Detached scalar loss of every step.
    """
    # NOTE(review): this loop never calls ``enforce_convexity``; if the convex
    # sub-network's weights are supposed to stay non-negative during training
    # it has to be invoked after ``optimizer.step()`` -- TODO confirm.
    loss_full = []
    print(pix_back.size())
    print(pix_fore.size())
    for epoch in range(num_epochs):
        perm = torch.randperm(pix_back.size(0))
        idx = perm[:number]
        random_pix_back = pix_back[idx,:]
        pix_back_labels = labels_back[idx]

        perm = torch.randperm(pix_fore.size(0))
        idx = perm[:number]
        random_pix_fore = pix_fore[idx,:]
        pix_fore_labels = labels_fore[idx]

        # Drop only the trailing feature axis so a batch of one keeps shape (1,)
        # and still matches the shape of its labels.
        outputs_back = torch.sigmoid(convexdiff(random_pix_back)).squeeze(-1)
        outputs_fore = torch.sigmoid(convexdiff(random_pix_fore)).squeeze(-1)

        loss = 2*criterion(outputs_back, pix_back_labels) + 1*criterion(outputs_fore, pix_fore_labels)

        # Store a detached copy; keeping ``loss`` itself would retain the
        # autograd graph of every step for the whole run.
        loss_full.append(loss.detach())

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (epoch+1) % 100 == 0:
            print ('Epoch [{}/{}],  Loss: {:.4f}' 
                .format(epoch+1, num_epochs, loss.item()))

    return convexdiff, loss_full

In [None]:
# Seed PyTorch's global RNG so that model initialisation and the minibatch
# sampling done with torch.randperm repeat across runs of the notebook.
SEED = 0
torch.manual_seed(SEED)

convexdiff = ConvexDiffeomorphismNet()

# Binary cross-entropy on sigmoid outputs, optimised with Adam.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(convexdiff.parameters(), lr=1e-3)  
num_epochs = 2000
# Split the pixels by their likelihood value: below 0.5 -> background,
# above 0.5 -> foreground.
pix_back,labels_back = extractInformationFromLikelihood(likelihood,  likelihood<0.5)
pix_fore,labels_fore = extractInformationFromLikelihood(likelihood, likelihood>0.5)

# Number of samples drawn from each class per training step.
number = 1000


In [None]:

convexdiff, loss_full = train(optimizer, criterion, convexdiff, pix_fore, labels_fore, pix_back, labels_back, num_epochs, number)

# The mask ``likelihood > -0.5`` selects every pixel, so ``allPixels`` holds
# the coordinates of the full grid.
allPixels,temp = extractInformationFromLikelihood(likelihood,  likelihood>-0.5)

# Pure inference: no autograd graph is needed for the full-image forward pass.
with torch.no_grad():
    inferenceResult = convexdiff(allPixels)
# Raw network outputs (no sigmoid applied), reshaped back to the image grid.
inferenceResult = inferenceResult.detach().numpy().reshape((nx,ny))



In [None]:


# Plot Loss  
plt.figure(figsize=(10,4))
plt.plot(np.array([x.detach().cpu().numpy() for x in loss_full]))
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()
plt.close()

# Plot Segmentation Image
_, axs = plt.subplots(1,3,figsize=(20,4))
axs[0].imshow(img)
axs[0].set_title('Original Image')
axs[1].imshow(likelihood.detach().numpy())
axs[1].set_title('Likelihood')
axs[2].imshow(img)
# ``inferenceResult`` holds raw network outputs; training applies a sigmoid
# before the loss, so the decision boundary sigmoid(z) = 0.5 lies at z = 0.
axs[2].imshow(inferenceResult<0.0, cmap='binary', alpha=0.7)
axs[2].set_title('Segmented Area')
plt.show()
plt.close()

In [None]:

# Overlay the zero level set of the raw network output on the image and save it.
plt.imshow(img)
plt.contour(inferenceResult, levels=[0.0], linewidths=3, colors='purple')
plt.axis('off')
plt.savefig('connected.png', bbox_inches='tight')
plt.show()


In [None]:

# Overlay the 0.5 level set of the thresholded likelihood on the image and save it.
plt.imshow(img)
plt.contour(likelihood, levels=[0.5], linewidths=3, colors='purple')
plt.axis('off')
plt.savefig('connected_naive.png', bbox_inches='tight')
plt.show()


In [None]:

# Save the unscribbled image without axes.
plt.imshow(img_orig_pil)
plt.axis('off')
plt.savefig('cat_re.png', bbox_inches='tight')
plt.show()


In [None]:

from awesome.run.functions import *

img_orig_dir="cat.jpg"
img_orig_pil=np.array(Image.open(img_orig_dir))

# Masks are loaded as single-channel images and scaled by 1/255.
mask_path = './original/pc_prior_mask_rescale.png'
orig_mask = load_mask_single_channel(mask_path) / 255

like_path = './original/likelihood_rescaled.png'
likelihood = load_mask_single_channel(like_path) / 255

img = img_orig_pil
# Initial crop: drop a fixed margin on every side.
crop_y = slice(11, img.shape[0] - 12)
crop_x = slice(11, img.shape[1] - 24)

constraint_name = "pc"
image_name = "cat"
path = "./new/"
target_px = 1024
target_py = 768
actual_px = (crop_x.stop - crop_x.start)
actual_py = (crop_y.stop - crop_y.start)
# Recalculate crop start to get same aspect ratio as target_px and target_py
aspect = target_px / target_py
new_start = int(max(crop_x.start + ((actual_px - actual_py * aspect) / 2), 0))
# The cropped width must be ``actual_py * aspect`` (the width the centring
# above assumes) for the crop to have the target aspect ratio.
crop_x = slice(int(new_start), int(actual_py * aspect + new_start))
actual_px = (crop_x.stop - crop_x.start)

naive = likelihood[crop_y, crop_x]
constraint = orig_mask[crop_y, crop_x]
pimg = img[crop_y, crop_x]
# Scale factor between the target width and the cropped width.
size = target_px / actual_px

def resize_img(path, target_px, target_py):
    """Resize the image stored at ``path`` to (target_px, target_py) and overwrite the file."""
    resized = Image.open(path).resize((target_px, target_py))
    resized.save(path)

# Naive likelihood mask: first tab10 colour, thin contour.
color = plt.get_cmap('tab10')(0)
save_path = path + f"{image_name}_{constraint_name}_naive.png"
plot_mask(
    pimg, naive,
    contour_linewidths=1, size=size, color=color,
    tight=True, save=True, override=True, path=save_path,
    auto_close=True, display=True,
)
resize_img(save_path, target_px, target_py)

# Constrained mask: second tab10 colour, default contour width.
color = plt.get_cmap('tab10')(1)
save_path = path + f"{image_name}_{constraint_name}.png"
plot_mask(
    pimg, constraint,
    size=size, color=color,
    tight=True, save=True, override=True, path=save_path,
    auto_close=True, display=True,
)
resize_img(save_path, target_px, target_py)