In [None]:
%run ApplyFilter.py
%run Loop.py

input_files = ["images/p03_02.png", "images/p03_03.png", "images/p03_04.png", "images/p03_01.png"]
tensorboard = True

In [None]:
import os
import warnings
from time import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
from IPython.display import display
from skimage.color import label2rgb
from skimage.segmentation import mark_boundaries

starttime = time()
%matplotlib inline
plt.rcParams['figure.figsize'] = [15, 15]

In [None]:
class MNet(nn.Module):
    """Stack of convolutions producing a per-pixel feature map.

    Layout (each conv is followed by ReLU, then BatchNorm):
      * ``conv1``: ``kernel_size`` conv, ``input_dim`` -> ``feature_dim`` channels
      * ``conv2[k]``: ``conv - 1`` further ``kernel_size`` convs at
        ``feature_dim`` channels
      * ``conv3``: 1x1 conv at ``feature_dim`` channels

    Padding is ``kernel_size // 2``, so the spatial size is preserved for odd
    kernel sizes.

    Args:
        input_dim: number of input channels.
        feature_dim: channel count of every hidden layer and of the output.
        conv: total number of ``kernel_size`` convolutions (``conv1`` plus
            ``conv - 1`` stacked ones); values <= 1 give no stacked convs.
        kernel_size: spatial kernel size of the non-1x1 convolutions. The
            default of 5 matches the previously hard-coded value.
    """

    def __init__(self, input_dim, feature_dim, conv=5, kernel_size=5):
        super().__init__()
        self.input_dim = input_dim
        self.feature_dim = feature_dim
        self.convs = conv
        padding = kernel_size // 2

        self.conv1 = nn.Conv2d(self.input_dim, self.feature_dim,
                               kernel_size=kernel_size, stride=1, padding=padding)
        self.act1 = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(self.feature_dim)

        # ModuleLists so the stacked layers are registered as submodules.
        self.conv2 = nn.ModuleList()
        self.act2 = nn.ModuleList()
        self.bn2 = nn.ModuleList()
        for _ in range(self.convs - 1):
            self.conv2.append(nn.Conv2d(self.feature_dim, self.feature_dim,
                                        kernel_size=kernel_size, stride=1, padding=padding))
            self.act2.append(nn.ReLU())
            self.bn2.append(nn.BatchNorm2d(self.feature_dim))

        self.conv3 = nn.Conv2d(self.feature_dim, self.feature_dim,
                               kernel_size=1, stride=1, padding=0)
        self.act3 = nn.ReLU()
        self.bn3 = nn.BatchNorm2d(self.feature_dim)

    def forward(self, x):
        """Run the stack on ``x`` and return the first batch element.

        Assumes ``x`` is (N, input_dim, H, W) as required by ``nn.Conv2d``.
        Returns ``(feature_dim, H', W')``. Only index 0 along the batch axis is
        returned, so any other batch elements are discarded.
        """
        x = self.bn1(self.act1(self.conv1(x)))
        for layer, act, bn in zip(self.conv2, self.act2, self.bn2):
            x = bn(act(layer(x)))
        x = self.bn3(self.act3(self.conv3(x)))
        return x[0]


In [None]:
# Load each input image as a single-channel (H, W, 1) array. Each path stays
# paired with its image, so a file that fails to load cannot shift the
# filename/image correspondence used to name the outputs below.
loaded_images = []
for path in input_files:
    bgr = cv2.imread(path)
    if bgr is None:
        # cv2.imread returns None (it does not raise) for missing or
        # unreadable files; report it instead of silently dropping the file.
        warnings.warn(f"could not read image {path!r}; skipping it")
        continue
    # Same result as reversing channels to RGB and applying COLOR_RGB2GRAY.
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    loaded_images.append((path, gray[:, :, np.newaxis]))
images = [image for _, image in loaded_images]

# Values passed as MNet's `conv` argument, one run per value per image.
conv_num = [2, 3, 4, 5, 6]

mat = sio.loadmat("./filterbanks/filterbanks.mat")
g = mat["RFSfilters"]

output_dir = "gen/convolution_number"
# savefig does not create missing directories.
os.makedirs(output_dir, exist_ok=True)

starttime = time()
for path, image in loaded_images:
    f = os.path.basename(path)
    for n in conv_num:
        run_name = f + "_" + str(n)
        model = MNet(1, 100, conv=n)
        # NOTE(review): recomputed for every n; could be hoisted out of this
        # loop if applyFilter is deterministic and `run` does not mutate its
        # result — confirm first.
        clustered = applyFilter(image, g, mr=True)
        segmented = run(image, clustered, model, starttime=starttime,
                        filename=run_name, stopping=["segments", 6])
        print(segmented["n_labels"])
        print(segmented["epochs"])

        image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        fig, ax = plt.subplots()
        ax.imshow(label2rgb(segmented["labels"], image_rgb, alpha=0.4))
        ax.axis('off')
        fig.savefig(os.path.join(output_dir, run_name + ".png"),
                    bbox_inches="tight", pad_inches=0)
        plt.show()

        fig, ax = plt.subplots()
        ax.imshow(mark_boundaries(image_rgb, segmented["labels"]))
        plt.show()