In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from PIL import Image
from awesome.run.functions import *

In [None]:
# Load the input image and downscale it to half its original width and height.
img_dir="starfish.jpg"
img_pil=Image.open(img_dir)
width, height = img_pil.size 
newsize = (int(width/2), int(height/2))
img_pil = img_pil.resize(newsize)

# Convert to a float array divided by 255 and keep only the first three channels.
img= np.array(img_pil, dtype='float')/255.0
img = img[:,:,0:3]

# nx, ny, nc are the sizes of the first, second and third array axes.
# NOTE(review): for a PIL-derived array these are presumably rows (height),
# columns (width) and channels - confirm, since the names suggest the opposite.
nx,ny,nc = img.shape

# Naive per-pixel score: channel 0 minus the mean of channels 0 and 1.
# NOTE(review): 0:2 averages only the first two channels; confirm 0:3 was not intended.
likelihood = img[:,:,0]-np.mean(img[:,:,0:2], axis=2) 
# Threshold the score at 0.1 to obtain a binary {0, 1} float tensor.
likelihood = torch.from_numpy(likelihood>0.1).float()

# Show the thresholded likelihood map.
plt.imshow(likelihood)
plt.colorbar()
plt.show()


In [None]:

class myNet(nn.Module):
    """Scalar field over 2-D points, built from a direction/radius split.

    Each input point is shifted by ``offset`` and decomposed into its radius
    ``r`` and an (almost) unit direction. The output is ``r * g - 1``, where
    ``g`` combines a direction-only branch with a branch that also sees ``r``;
    the output is therefore exactly ``-1`` where the shifted point is the origin.

    Args:
        n_hidden: Width of the hidden layers.
    """

    def __init__(self,n_hidden):
        super().__init__()

        # Shift applied to every input point; created frozen (requires_grad=False).
        self.offset = torch.nn.Parameter(torch.zeros(1, 2), requires_grad=False)

        # Direction branch. Layers are created in the same order as before so
        # that random initialisation is unchanged under a fixed seed.
        self.W0 = nn.Linear(in_features=2, out_features=n_hidden)
        self.W1 = nn.Linear(in_features=n_hidden, out_features=n_hidden)
        self.W2 = nn.Linear(in_features=n_hidden, out_features=1)

        # Radius branch.
        self.W1_r = nn.Linear(in_features=1, out_features=n_hidden)
        self.W2_r = nn.Linear(in_features=n_hidden, out_features=1)

    def forward(self, x):
        """Evaluate the field for a batch of points ``x`` of shape ``(N, 2)``.

        Returns:
            Tensor of shape ``(N, 1)``.
        """
        shifted = x + self.offset
        # Euclidean norm per row, kept as a column so it broadcasts below.
        radius = shifted.pow(2).sum(dim=1, keepdim=True).sqrt()
        # The 0.01 in the denominator avoids a division by zero at the origin.
        direction = shifted / (0.01 + radius)

        hidden = F.relu(self.W0(direction))
        radial_hidden = F.relu(self.W1(hidden) + self.W1_r(radius))
        gain = self.W2(hidden) + self.W2_r(radial_hidden)
        return radius * gain - 1

In [None]:

def prepare_indices(img_shape):
    """Return normalised 2-D coordinates for every cell of a grid.

    Args:
        img_shape: Pair ``(ny, nx)`` giving the number of samples per axis.

    Returns:
        Float tensor of shape ``(nx * ny, 2)``. Column 0 takes ``nx`` evenly
        spaced values in ``[-0.5, 0.5]`` and varies slowest; column 1 takes
        ``ny`` evenly spaced values in ``[-0.5, 0.5]`` and varies fastest.
    """
    ny, nx = img_shape
    x = torch.linspace(-0.5, 0.5, nx)
    y = torch.linspace(-0.5, 0.5, ny)
    # Pass ``indexing`` explicitly: omitting it is deprecated and emits a
    # warning. "ij" is the behaviour the implicit default provided.
    X, Y = torch.meshgrid(x, y, indexing="ij")
    return torch.stack((X.flatten(), Y.flatten()), dim=1)

def extractInformationFromLikelihood(likelihood, mask):
    """Collect normalised coordinates and labels of the pixels selected by ``mask``.

    Args:
        likelihood: 2-D tensor of per-pixel values.
        mask: Boolean tensor of the same shape selecting the pixels to extract.

    Returns:
        Tuple ``(pixel_info, labels)``. ``pixel_info`` has shape ``(N, 2)`` and
        holds, per selected pixel, its index along axis 0 and axis 1, each
        rescaled to ``[-0.5, 0.5]``. ``labels`` has shape ``(N,)`` and holds
        ``1 - likelihood`` at those pixels.

    Side effects:
        Prints the number of selected pixels.
    """
    indices = torch.nonzero(mask)
    N_fore = indices.shape[0]
    print(N_fore)

    # Take the grid size from the mask itself rather than from notebook-level
    # globals, so the function does not depend on hidden kernel state.
    n_rows, n_cols = mask.shape[0], mask.shape[1]

    pixel_info = torch.zeros((N_fore, 2))
    # max(..., 1) keeps a size-1 axis from dividing by zero.
    pixel_info[:, 0] = indices[:, 0] / max(n_rows - 1, 1) - 0.5
    pixel_info[:, 1] = indices[:, 1] / max(n_cols - 1, 1) - 0.5

    labels = 1 - likelihood[mask]
    return pixel_info, labels

# Build the network with 150 hidden units.
net = myNet(150)

# Mean-squared error between sigmoid(network output) and the pixel labels.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)  

num_epochs = 10000
# Split the pixels into two sets by thresholding the likelihood at 0.5.
# Labels are 1 - likelihood, as computed in extractInformationFromLikelihood.
pix_back,labels_back = extractInformationFromLikelihood(likelihood,  likelihood<0.5)
pix_fore,labels_fore = extractInformationFromLikelihood(likelihood, likelihood>0.5)

# Number of pixels drawn from each of the two sets per training step.
number = 500

# Train the model
# NOTE(review): no random seed is set in the visible cells, so the initial
# weights and the sampled batches differ between runs.
for epoch in range(num_epochs):
    # Draw up to `number` random pixels from the first set.
    perm = torch.randperm(pix_back.size(0))
    idx = perm[:number]
    random_pix_back = pix_back[idx,:]
    pix_back_labels = labels_back[idx]
    
    # Draw up to `number` random pixels from the second set.
    perm = torch.randperm(pix_fore.size(0))
    idx = perm[:number]
    random_pix_fore = pix_fore[idx,:]
    pix_fore_labels = labels_fore[idx]
    
    # Concatenate both samples into one batch.
    random_pix = torch.concat((random_pix_back, random_pix_fore), axis=0)
    pix_labels = torch.concat((pix_back_labels, pix_fore_labels), axis=0)
    
    
    outputs = torch.sigmoid(net(random_pix)).squeeze()
    
    loss = criterion(outputs, pix_labels) 
    # Unfreeze the offset parameter at epoch 1000. The loss for this iteration
    # was already computed above, so the offset presumably only starts
    # receiving gradients from the next iteration on.
    if epoch ==1000:
        net.offset.requires_grad = True
    #    loss += 0.1*torch.sum(torch.sigmoid(net(net.offset.data)))
    
    #print(torch.sigmoid(net(net.offset.data)))
        
    # Backpropagation and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # After each step, clamp the weights of W2_r to be non-negative.
    with torch.no_grad():
        net.W2_r.weight.data = F.relu(net.W2_r.weight.data)

    # Report the current batch loss every 400 epochs.
    if (epoch+1) % 400 == 0:
        print ('Epoch [{}/{}],  Loss: {:.4f}' 
               .format(epoch+1, num_epochs, loss.item()))


In [None]:
# Extract the coordinates of all pixels: likelihood holds only 0/1 values
# (see its construction), so the mask `likelihood > -0.5` selects every pixel.
allPixels,temp = extractInformationFromLikelihood(likelihood,  likelihood>-0.5)

# Evaluate the network on every pixel and reshape to the (nx, ny) image grid.
inferenceResult = net(allPixels) # torch tensor of size nx*ny
inferenceResult = inferenceResult.detach().numpy().reshape((nx,ny))

# Save a binary mask (255 where the raw network output is below 0.5).
# NOTE(review): the raw output is thresholded at 0.5 here, whereas training
# applies a sigmoid before the loss - confirm the intended threshold.
im = Image.fromarray(255*(inferenceResult<0.5).astype('uint8'))
im.save("mask.png")

# Disabled: would blank a 2-pixel border of the image.
if False:
    img[0:2,:,:]=0.0
    img[:,0:2,:]=0.0
    img[:,-2:,:]=0.0
    img[-2:,:,:]=0.0

# Plot the image with the 0.5 contour of the likelihood map and save it.
plt.imshow(img)
plt.contour(likelihood, levels=[0.5], colors='purple')
#plt.imshow(inferenceResult<0.5, cmap='binary', alpha=0.7)
#plt.plot((-net.offset.data.detach().numpy()[0,1]+0.5)*ny, (-net.offset.data.detach().numpy()[0,0]+0.5)*nx,'x', color='purple')
#plt.plot((0.5)*ny, (0.5)*nx,'x', color='green')
plt.axis('off')
#plt.colorbar()
plt.savefig('result_naive.png',bbox_inches='tight')
plt.show()


In [None]:
# Scratch grid: normalised coordinates with the Y values in column 0 and the
# X values in column 1.
x = torch.linspace(-0.5, 0.5, nx)
y = torch.linspace(-0.5, 0.5, ny)
# Explicit indexing avoids the deprecation warning; "ij" matches the behaviour
# of the implicit default.
X, Y = torch.meshgrid(x, y, indexing="ij")
xy = torch.stack((Y.flatten(), X.flatten()), dim=1)

In [None]:
# Scratch grid: normalised coordinates with the X values in column 0 and the
# Y values in column 1.
x = torch.linspace(-0.5, 0.5, nx)
y = torch.linspace(-0.5, 0.5, ny)
# Explicit indexing avoids the deprecation warning; "ij" matches the behaviour
# of the implicit default.
X, Y = torch.meshgrid(x, y, indexing="ij")
xy = torch.stack((X.flatten(), Y.flatten()), dim=1)


def prepare_indices(img_shape):
    """Return normalised 2-D coordinates for every cell of a grid.

    Args:
        img_shape: Pair ``(ny, nx)`` giving the number of samples per axis.

    Returns:
        Float tensor of shape ``(nx * ny, 2)``. Column 0 takes ``nx`` evenly
        spaced values in ``[-0.5, 0.5]`` and varies slowest; column 1 takes
        ``ny`` evenly spaced values in ``[-0.5, 0.5]`` and varies fastest.
    """
    ny, nx = img_shape
    x = torch.linspace(-0.5, 0.5, nx)
    y = torch.linspace(-0.5, 0.5, ny)
    # Pass ``indexing`` explicitly: omitting it is deprecated and emits a
    # warning. "ij" is the behaviour the implicit default provided.
    X, Y = torch.meshgrid(x, y, indexing="ij")
    return torch.stack((X.flatten(), Y.flatten()), dim=1)

# Evaluate the trained network on the full coordinate grid.
# nx, ny are the sizes of the first and second axes of img.
nx, ny = img.shape[:2]
with torch.no_grad():
    # Raw network outputs reshaped to the (nx, ny) image grid.
    pred_logits = net(prepare_indices((ny, nx))).detach().reshape(nx, ny)
    # Sigmoid of the raw outputs, as a NumPy array.
    pred = torch.sigmoid(pred_logits).numpy()

In [None]:
# Scratch check: first two dimensions of img in reversed order.
img.shape[:2][::-1]

In [None]:
# Scratch check: smallest raw network output over the grid.
pred_logits.min()

In [None]:
from awesome.run.functions import *


# Load a reference mask from disk and divide it by 255.
# NOTE(review): load_mask_single_channel comes from the star import; presumably
# it returns a single-channel array with values in 0..255 - confirm.
mask_path = './original/mask.png'
orig_mask = load_mask_single_channel(mask_path) / 255


# Start from the full image extent along both axes.
crop_y = slice(0, img.shape[0])
crop_x = slice(0, img.shape[1])

# Output naming and target size in pixels for the exported figures.
constraint_name = "starconvex"
image_name = "starfish"
path = "./new/"
target_px = 1024
target_py = 768
actual_px = (crop_x.stop - crop_x.start)
actual_py = (crop_y.stop - crop_y.start)
# Recalculate crop start to get same aspect ratio as target_px and target_py
aspect = target_px / target_py
# Centre the crop along the second axis; the start is clamped at 0.
# NOTE(review): when actual_px < actual_py * aspect the start is clamped to 0
# and the slice stop exceeds the image extent, so the crop cannot reach the
# target aspect ratio - confirm inputs are always wide enough.
new_start = int(max(round(crop_x.start + ((actual_px - actual_py * aspect)) // 2), 0))
crop_x = slice(int(new_start), int(actual_py * aspect + new_start))

actual_px = (crop_x.stop - crop_x.start)

# Apply the same crop to the likelihood map, the reference mask and the image.
naive = likelihood[crop_y, crop_x]
constraint = orig_mask[crop_y, crop_x]
pimg = img[crop_y, crop_x]

# Ratio of target to cropped extent; passed as `size` to plot_mask in a later cell.
size = target_px / actual_px


def resize_img(path, target_px, target_py):
    """Resize the image file at ``path`` and overwrite it in place.

    Args:
        path: Path of the image file to read and then overwrite.
        target_px: Output width in pixels.
        target_py: Output height in pixels.
    """
    # The context manager releases the source file handle before the same path
    # is written again; previously the handle was never closed explicitly.
    with Image.open(path) as src:
        resized = src.resize((target_px, target_py))
    resized.save(path)


# Export the cropped likelihood map over the cropped image, using the first
# tab10 colour, then resize the saved file to the target size.
color = plt.get_cmap('tab10')(0)
save_path = path + f"{image_name}_{constraint_name}_naive.png"
plot_mask(pimg, naive, contour_linewidths=1, size=size, color=color, tight=True, save=True, override=True, path=save_path, auto_close=True)
resize_img(save_path, target_px, target_py)

# Same export for the cropped reference mask, using the second tab10 colour.
color = plt.get_cmap('tab10')(1)
save_path = path + f"{image_name}_{constraint_name}.png"
plot_mask(pimg, constraint, size=size, color=color, tight=True, save=True, override=True, path=save_path, auto_close=True)
resize_img(save_path, target_px, target_py)

In [None]:
# Scratch check: ratio of the target extent to the cropped extent.
target_px / actual_px

In [None]:
# Scratch check: half of the difference between actual_px and actual_py * aspect.
((actual_px - actual_py * aspect) / 2)

In [None]:
# Scratch check: difference between actual_px and actual_py * aspect.
actual_px - actual_py * aspect

In [None]:
# Scratch check: first two dimensions of the cropped image multiplied by `size`.
torch.tensor(pimg.shape[:2]) * size

In [None]:
# Scratch check: the crop slice along the second image axis.
crop_x

In [None]:
from awesome.run.functions import *
# Plot two masks over the image and save the figure: channel 0 is
# likelihood > 0.5, channel 1 is the complement of pred > 0.5.
plot_mask_multi_channel(img, np.stack([likelihood > 0.5, 1 - (pred > 0.5)], axis=2), size=5, tight=True, save=True, override=True, path='./starfish_naive_and_cvx.png')

In [None]:
def image_subsample(img: torch.Tensor, factor: int = 6, mode: Literal["grid_sample", "slicing"] = "grid_sample"):
    """Keep every ``factor``-th pixel along the last two axes of ``img``.

    Args:
        img: Tensor whose last two axes are the spatial axes. The
            ``"grid_sample"`` mode passes ``img[None, ...]`` to
            ``F.grid_sample``, so it expects a 3-D floating-point tensor.
        factor: Subsampling stride in pixels.
        mode: ``"grid_sample"`` samples through ``F.grid_sample``;
            ``"slicing"`` uses plain strided indexing.

    Returns:
        The subsampled tensor, with the spatial axes reduced to
        ``ceil(size / factor)`` entries each.

    Raises:
        ValueError: If ``mode`` is not one of the two supported values.
    """
    if mode == "grid_sample":
        height, width = img.shape[-2], img.shape[-1]
        grid_dtype = img.dtype if img.is_floating_point() else torch.float32

        # Integer pixel indices give a deterministic sample count; a float-step
        # arange can gain or lose a sample through rounding.
        rows = torch.arange(0, height, factor, device=img.device).to(grid_dtype)
        cols = torch.arange(0, width, factor, device=img.device).to(grid_dtype)

        # With align_corners=True, -1 and +1 address the centres of the first
        # and last pixel, so pixel index i maps to 2 * i / (size - 1) - 1.
        # Dividing by `size` instead would land between pixels and blur the
        # result. max(..., 1) guards the single-pixel axis.
        y = rows * (2.0 / max(height - 1, 1)) - 1.0
        x = cols * (2.0 / max(width - 1, 1)) - 1.0

        yy, xx = torch.meshgrid(y, x, indexing="ij")
        # grid_sample expects (x, y) order in the last dimension.
        flowgrid = torch.stack((xx, yy), dim=-1)[None, ...]
        return F.grid_sample(img[None, ...], flowgrid, align_corners=True)[0, ...]
    elif mode == "slicing":
        return img[..., ::factor, ::factor]
    else:
        raise ValueError("Invalid mode")

# Subsampling stride for the surface plots (1 keeps every pixel).
factor = 1

# Subsample the image (moved to channels-first layout) and the raw network outputs.
img_sub = image_subsample(torch.tensor(img).permute(2,0,1).float(), factor)
res_hull_sub =  image_subsample(pred_logits.unsqueeze(0), factor)

# Binary masks at threshold 0.5.
# NOTE(review): neither mask is used in the visible cells, and res_hull_sub
# holds raw outputs (no sigmoid) - confirm the threshold if they get used.
mask_like_sub = image_subsample(likelihood.unsqueeze(0), factor) > 0.5
mask_pred_sub = res_hull_sub > 0.5


# Surface plot of the network outputs over the plain image, with empty
# scribble masks; saved as PNG and PDF.
fig = plot_surface_logits(img_sub, res_hull_sub, 
    foreground_scribble_mask=torch.zeros(img_sub.shape[1:3]), 
    background_scribble_mask=torch.zeros(img_sub.shape[1:3]),
    image_subsampling=1,
    surface_log=True,
    surface_log_eps=1e-2,
    elevation=60,
    azimuth=-90,
    zoom=1.3,
    transparent=True,
    save=True, 
    path="./starfish_naive_and_cvx_surface", ext=["png", "pdf"], override=True)

# Render the two-mask overlay to an array, keep its first three channels,
# scale by 1/255 and move to channels-first layout.
fig = plot_mask_multi_channel(img, np.stack([likelihood > 0.5, 1 - (pred > 0.5)], axis=2), size=3.2, tight=True, darkening_background=0.)
inpainted_img = torch.tensor(figure_to_numpy(fig, dpi=fig.dpi, transparent=False)[:, :, :3].astype(np.float32) / 255.0).permute(2,0,1).float()
inpainted_img_sub = image_subsample(inpainted_img, factor)

# Same surface plot, textured with the rendered overlay instead of the plain image.
# NOTE(review): the scribble masks use img_sub's spatial shape; confirm that it
# matches inpainted_img_sub, which comes from a rendered figure.
fig = plot_surface_logits(inpainted_img_sub, res_hull_sub, 
    foreground_scribble_mask=torch.zeros(img_sub.shape[1:3]), 
    background_scribble_mask=torch.zeros(img_sub.shape[1:3]),
    image_subsampling=1,
    surface_log=True,
    surface_log_eps=1e-2,
    elevation=60,
    azimuth=-90,
    zoom=1.3,
    transparent=True,
    save=True, 
    path="./starfish_naive_and_cvx_surface_mask", ext=["png", "pdf"], override=True)

In [None]:
# Scratch check: dimensions 1 and 2 of the rendered overlay.
inpainted_img.shape[1:3]

In [None]:
# Surface plot of the binary likelihood map over the rendered overlay,
# with empty scribble masks; saved as PNG only.
# NOTE(review): inpainted_img comes from a rendered figure; confirm that its
# spatial size matches likelihood.
fig = plot_surface_logits(inpainted_img, likelihood.unsqueeze(0), 
    foreground_scribble_mask=torch.zeros(inpainted_img.shape[1:3]), 
    background_scribble_mask=torch.zeros(inpainted_img.shape[1:3]),
    image_subsampling=1,
    surface_log=True,
    surface_log_eps=1e-2,
    elevation=60,
    azimuth=-90,
    zoom=1.3,
    transparent=True,
    save=True, 
    path="./starfish_naive_and_cvx_surface_hr", ext=["png"], override=True)


In [None]:
# Scratch check: full shape of the rendered overlay.
inpainted_img.shape

In [None]:
# Re-render the two-mask overlay and convert the figure to an array.
# NOTE: this rebinds inpainted_img to the direct result of figure_to_numpy,
# replacing the channels-first tensor assigned in an earlier cell.
fig = plot_mask_multi_channel(img, np.stack([likelihood > 0.5, 1 - (pred > 0.5)], axis=2), size=3.2, tight=True)
inpainted_img = figure_to_numpy(fig, dpi=fig.dpi)



In [None]:
# Scratch check: display the sigmoid prediction array.
pred

In [None]:
# Convert the current figure to an array and display it as an image.
arr = figure_to_numpy(fig, dpi=fig.dpi)
plot_as_image(arr)

In [None]:
# Scratch check: shape of the subsampled image tensor.
img_sub.shape

In [None]:
# Scratch check: dtype of img after conversion to a channels-first tensor.
torch.tensor(img).permute(2,0,1).dtype

In [None]:
%matplotlib agg

In [None]:
# Display the raw network outputs as an image.
plot_as_image(pred_logits)

In [None]:
# Scratch check: shape of the raw network output grid.
pred_logits.shape

In [None]:
def prepare_indices(img_shape):
    """Return normalised 2-D coordinates for every cell of a grid.

    Note: this redefinition shadows the earlier ``prepare_indices`` in this
    notebook and uses a different column order.

    Args:
        img_shape: Pair ``(ny, nx)`` giving the number of samples per axis.

    Returns:
        Float tensor of shape ``(nx * ny, 2)``. Column 0 takes ``ny`` evenly
        spaced values in ``[-0.5, 0.5]`` and varies fastest; column 1 takes
        ``nx`` evenly spaced values in ``[-0.5, 0.5]`` and varies slowest.
    """
    ny, nx = img_shape
    x = torch.linspace(-0.5, 0.5, nx)
    y = torch.linspace(-0.5, 0.5, ny)
    # Pass ``indexing`` explicitly: omitting it is deprecated and emits a
    # warning. "ij" is the behaviour the implicit default provided.
    X, Y = torch.meshgrid(x, y, indexing="ij")
    xy = torch.stack((Y, X), dim=0)
    return xy.reshape(2, -1).T
# Coordinates for the image grid, built from the first two dimensions of img.
pxy = prepare_indices(img.shape[0:2])


In [None]:

def resize_img(path, lng_px):
    """Save a resized copy of an image whose longer side is ``lng_px`` pixels.

    The aspect ratio is kept (up to integer truncation). The copy is written
    next to the original as ``<name>_resized<ext>``; the original file is not
    modified.

    Note: this redefinition shadows the earlier three-argument ``resize_img``
    in this notebook.

    Args:
        path: Path of the source image file.
        lng_px: Length in pixels of the longer side of the output.
    """
    # Local import so the function does not depend on `os` leaking in through
    # a star import.
    import os

    path = os.path.abspath(path)
    dirname = os.path.dirname(path)
    basename, ext = os.path.splitext(os.path.basename(path))
    rn_path = os.path.join(dirname, f"{basename}_resized{ext}")

    # The context manager closes the source file handle once resizing is done.
    with Image.open(path) as src:
        # PIL reports size as (width, height).
        width, height = src.size
        ratio = width / height
        if ratio > 1:
            # Landscape: width is the long side.
            target = (lng_px, max(1, int(lng_px / ratio)))
        else:
            # Portrait or square: height is the long side.
            target = (max(1, int(lng_px * ratio)), lng_px)
        # max(1, ...) keeps the short side from truncating to 0 pixels.
        resized = src.resize(target)
    resized.save(rn_path)
# Write a resized copy (longer side 1024 px) next to the given file.
resize_img("./temp/cars3_joined_axes_40_60.png", 1024)

In [None]:
# Repeats the call from the previous cell on the same file.
resize_img("./temp/cars3_joined_axes_40_60.png", 1024)

In [None]:
# Scratch check: current working directory (relative paths above resolve against it).
os.getcwd()