In [None]:
%run ApplyFilter.py
%run Loop.py
%run Metrics.py

import csv
import os
import re
from os import listdir
from os.path import isfile, join
from time import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
from skimage.color import label2rgb
from torch import nn

# Timestamp taken once when this cell runs; it is handed to run() further down
# through its `starttime` keyword.
starttime = time()

# Read the filter-bank .mat file and keep its "RFSfilters" entry in `g`, which is
# later passed to applyFilter() as the filter argument.
# NOTE(review): the path is relative, so this cell depends on the kernel's working
# directory containing ./filterbanks/ -- confirm before running elsewhere.
mat = sio.loadmat("./filterbanks/filterbanks.mat")
g = mat["RFSfilters"]

def gen_sets(tags):
    """Group element positions by their label.

    Args:
        tags: 1-D array of labels (callers in this notebook pass a flattened
            label map).

    Returns:
        A list with one entry per distinct value in ``tags``, ordered as
        ``np.unique`` orders them. Each entry is the array of indices (along
        the first axis) at which ``tags`` equals that value.
    """
    return [np.where(label == tags)[0] for label in np.unique(tags)]

In [None]:
class MNet(nn.Module):
  """Stack of conv -> ReLU -> BatchNorm stages that keeps the spatial size.

  Five stages are registered under the suffixes ``1``, ``2a``, ``2b``, ``2c``
  and ``3`` (attributes ``conv<suffix>``, ``act<suffix>``, ``bn<suffix>``).
  The first four use 5x5 kernels with padding 2, the last a 1x1 kernel with
  no padding; all use stride 1, so height and width are unchanged.

  Args:
    input_dim: number of channels of the input tensor.
    feature_dim: number of channels produced by every stage.
  """

  def __init__(self, input_dim, feature_dim):
    super().__init__()
    self.input_dim = input_dim
    self.feature_dim = feature_dim

    # (suffix, in_channels, kernel_size, padding) per stage. The order here is
    # the order in which the sub-modules are registered on the module.
    layout = (
      ("1", input_dim, 5, 2),
      ("2a", feature_dim, 5, 2),
      ("2b", feature_dim, 5, 2),
      ("2c", feature_dim, 5, 2),
      ("3", feature_dim, 1, 0),
    )
    for suffix, in_channels, kernel, pad in layout:
      setattr(
        self,
        "conv" + suffix,
        nn.Conv2d(in_channels, feature_dim, kernel_size=kernel, stride=1, padding=pad),
      )
      setattr(self, "act" + suffix, nn.ReLU())
      setattr(self, "bn" + suffix, nn.BatchNorm2d(feature_dim))

  def forward(self, x):
    """Run the active stages and return the first element of the result.

    The ``2c`` stage is registered in ``__init__`` but is not applied here.

    Args:
      x: input tensor fed to ``conv1``.

    Returns:
      ``out[0]`` where ``out`` is the output of the final BatchNorm, i.e. the
      entry at index 0 along the leading axis.
    """
    active_stages = (
      (self.conv1, self.act1, self.bn1),
      (self.conv2a, self.act2a, self.bn2a),
      (self.conv2b, self.act2b, self.bn2b),
      (self.conv3, self.act3, self.bn3),
    )
    for conv, act, bn in active_stages:
      x = bn(act(conv(x)))
    return x[0]

In [None]:
# Directories holding the validation images and their CSV ground-truth label maps.
imgpath = "/data/ba/datasets/BSR/BSDS500/data/images/val/"
csvpath = "/data/ba/datasets/BSR/BSDS500/data/csv_groundTruth/val/"

filelist = ["42049.jpg", "182053.jpg", "66053.jpg", "167083.jpg", "86016.jpg", "241004.jpg"]

# Alternative selection: the first 20 entries of the image directory.
#filelist = listdir(imgpath)[:20]


def _ground_truth_files(image_name, csv_names):
    """Return the entries of ``csv_names`` that belong to ``image_name``.

    A CSV name matches when it contains the image's base name (file name without
    extension) and that occurrence is not directly preceded or followed by
    another digit. A plain substring test would also pair e.g. image "41004"
    with the files of image "241004".

    NOTE(review): assumes the annotation index is separated from the image id by
    a non-digit character (e.g. "_" or ".") -- confirm against the CSV directory.
    """
    base = os.path.splitext(image_name)[0]
    pattern = re.compile(r"(?<!\d)" + re.escape(base) + r"(?!\d)")
    return [name for name in csv_names if pattern.search(name)]


def _load_ground_truth(path):
    """Read a CSV file of integer labels into a 2-D ``uint8`` array."""
    with open(path, newline='') as csvfile:
        rows = list(csv.reader(csvfile))
    # NOTE(review): uint8 only holds label ids up to 255 -- confirm that no
    # ground-truth map uses more segments than that.
    return np.array(rows).astype(np.uint8)


def _show(picture):
    """Display ``picture`` in a new matplotlib figure."""
    plt.figure()
    plt.imshow(picture)
    plt.show()


csvlist = listdir(csvpath)
# gtlist[k] holds the ground-truth CSV names for filelist[k].
gtlist = [_ground_truth_files(name, csvlist) for name in filelist]

for filename, gtfiles in zip(filelist, gtlist):
    bgr = cv2.imread(join(imgpath, filename))
    if bgr is None:
        # cv2.imread signals an unreadable/missing file by returning None
        # instead of raising; fail with a message that names the file.
        raise FileNotFoundError("could not read image: " + join(imgpath, filename))
    # OpenCV loads channels as BGR; reverse the last axis to get RGB.
    # (Named `image` rather than `input` so the builtin is not shadowed.)
    image = bgr[:, :, ::-1]
    _show(image)

    # Load and display every ground-truth annotation of this image.
    current_gts = []
    for gtfile in gtfiles:
        gt = _load_ground_truth(join(csvpath, gtfile))
        _show(label2rgb(gt, image, alpha=0.4))
        current_gts.append(gt)

    # A fresh network is created for every image.
    model = MNet(3, 100)
    clustered = applyFilter(image, g, mr = True, w_color=0.8, size = 1, w_spatial=1)
    segmented = run(image, clustered, model, starttime = starttime, filename = filename)
    print(segmented["n_labels"])
    print(segmented["epochs"])
    _show(label2rgb(segmented["labels"], image, alpha=0.4))

    # The prediction does not change between ground truths, so flatten and
    # group it once instead of once per comparison.
    predicted = segmented["labels"].flatten()
    predicted_sets = gen_sets(predicted)

    print("Thesis (VOI):")
    for current_gt in current_gts:
        print(voi(predicted_sets, gen_sets(current_gt.flatten())))
    print("Thesis (RI):")
    for current_gt in current_gts:
        print(ri(predicted, current_gt.flatten()))
        