In [1]:
from collections import defaultdict

import numpy as np

from torch.utils.data import ConcatDataset
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm

import datasets
import utils

train_samples = utils.get_train_samples()
test_samples = utils.get_test_samples()
train_dataset = datasets.TestDataAnalysisDataset(train_samples, './data/train')
test_dataset = datasets.TestDataAnalysisDataset(test_samples, './data/test')
# Train images occupy indices [0, len(train_dataset)); test images follow them.
dataset = ConcatDataset([train_dataset, test_dataset])

# One flattened vector per border strip of every image, in `dataset` order.
# NOTE(review): the slicing assumes each image is a (height, width, channels) array —
# confirm against datasets.TestDataAnalysisDataset.
edges = defaultdict(list)
for image, _sample_id in tqdm(dataset):
    edges['top'].append(np.mean(image[:1, :, :], axis=0).flatten())
    # `-1:` is the last row / column whatever the image size.
    edges['bottom'].append(np.mean(image[-1:, :, :], axis=0).flatten())
    edges['left'].append(np.mean(image[:, :1, :], axis=1).flatten())
    edges['right'].append(np.mean(image[:, -1:, :], axis=1).flatten())


# Put opposite edge kinds into one search space: entries [0, N) are top (resp. left)
# edges and entries [N, 2N) are bottom (resp. right) edges, with N = len(dataset).
vertical_edges = edges['top'] + edges['bottom']
horizontal_edges = edges['left'] + edges['right']

# kneighbors() without a query returns, for every fitted vector, its 2 nearest *other*
# vectors. n_neighbors is passed by keyword because recent scikit-learn versions make
# the constructor arguments keyword-only.
neigh = NearestNeighbors(n_neighbors=2, n_jobs=-1, metric='euclidean')
neigh.fit(vertical_edges)
dist_vertical, ind_vertical = neigh.kneighbors()

neigh = NearestNeighbors(n_neighbors=2, n_jobs=-1, metric='euclidean')
neigh.fit(horizontal_edges)
dist_horizontal, ind_horizontal = neigh.kneighbors()

100%|██████████| 22000/22000 [00:28<00:00, 783.94it/s]


KeyboardInterrupt: 

In [None]:
# Sample id of every entry in `dataset`, in the same order as the edge vectors
# (train ids first, then test ids). Every image is loaded again; only its id is kept.
ids = [sample_id for _image, sample_id in tqdm(dataset)]

In [None]:
def get_neighbors(ind, dist, n_images=None):
    """Rank confident matches between an edge and an opposite edge of another image.

    `ind` / `dist` are the neighbour indices and distances returned by
    `NearestNeighbors.kneighbors()` fitted on `first_kind_edges + second_kind_edges`.
    Rows [0, n_images) are edges of the first kind (top / left). Indices >= n_images
    point at edges of the second kind (bottom / right).

    Only first-kind rows are used as queries. A row is kept when its nearest neighbour
    is a second-kind edge and its second-nearest distance is non-zero. Kept rows are
    ranked by the ratio nearest / second-nearest distance; smaller means a more
    distinctive match.

    Args:
        ind: neighbour indices per row, nearest first.
        dist: neighbour distances, aligned with `ind`.
        n_images: number of images (query rows). Defaults to `len(dataset)`.

    Returns:
        List of `(query_index, neighbour_index, ratio)` sorted by ascending ratio.
        `neighbour_index` is in edge space (>= n_images); subtract n_images to get
        the image index.
    """
    if n_images is None:
        n_images = len(dataset)

    matches = []
    # zip stops after n_images rows, so only first-kind edges act as queries.
    for query_index, indices, d in zip(range(n_images), ind, dist):
        # Filter BEFORE computing the ratio and sorting: d[1] == 0 implies d[0] == 0,
        # and 0/0 is NaN. A NaN sort key breaks the ordering of the other matches.
        if indices[0] >= n_images and d[1] > 0:
            matches.append((query_index, indices[0], d[0] / d[1]))

    matches.sort(key=lambda match: match[2])
    return matches
    
# Top edges matched against bottom edges: each entry is a candidate for the query
# image sitting directly below the matched image.
sorted_neighbors_of_top = get_neighbors(ind=ind_vertical, dist=dist_vertical)
# Left edges matched against right edges: each entry is a candidate for the query
# image sitting directly to the right of the matched image.
sorted_neighbors_of_left = get_neighbors(ind=ind_horizontal, dist=dist_horizontal)

In [None]:
from os.path import join

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from skimage.io import imread
from skimage import img_as_float
from os.path import exists

rows = 25
cols = 6
fig, ax = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3))

# Match indices and `ids` both index the concatenated dataset: the first
# len(train_dataset) entries live under ./data/train, the rest under ./data/test.
n_train = len(train_dataset)

for i, (query_index, match_index, d) in enumerate(sorted_neighbors_of_top[:rows]):
    # The query's top edge matched the other image's bottom edge, so the matched image
    # is the candidate to sit above the query. Edge-space indices >= len(dataset)
    # are folded back to image indices.
    top = match_index % len(dataset)
    bottom = query_index % len(dataset)

    print(d, top, bottom)

    loaded = []
    for image_index in (top, bottom):
        root = './data/train' if image_index < n_train else './data/test'
        image = img_as_float(imread(join(root, 'images', ids[image_index]) + '.png'))
        mask_path = join(root, 'masks', ids[image_index]) + '.png'
        if exists(mask_path):
            mask = img_as_float(imread(mask_path))
        else:
            # No mask on disk (presumably a test image): show an empty mask instead.
            # NOTE(review): assumes masks are (H, W), matching image.shape[:2] — confirm.
            mask = np.zeros(image.shape[:2])
        loaded.append((image, mask))
    (image_top, mask_top), (image_bottom, mask_bottom) = loaded

    # Stack vertically: matched image on top, query image below.
    image_cat = np.concatenate([image_top, image_bottom], axis=0)
    mask_cat = np.concatenate([mask_top, mask_bottom], axis=0)

    ax[i][0].imshow(image_top)
    ax[i][1].imshow(image_bottom)
    ax[i][2].imshow(image_cat)
    ax[i][2].set_title('ratio={:.3f}'.format(d))
    ax[i][3].imshow(mask_cat)
    ax[i][4].imshow(mask_top)
    ax[i][5].imshow(mask_bottom)

fig.tight_layout()


In [None]:
from os.path import join

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from skimage.io import imread
from skimage import img_as_float
from os.path import exists

rows = 25
cols = 6
fig, ax = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3))

# Skip the `offset` most confident matches and show the next `rows` ones.
offset = 100
# Match indices and `ids` both index the concatenated dataset: the first
# len(train_dataset) entries live under ./data/train, the rest under ./data/test.
n_train = len(train_dataset)

for i, (query_index, match_index, d) in enumerate(sorted_neighbors_of_left[offset:offset + rows]):
    # The query's left edge matched the other image's right edge, so the matched image
    # is the candidate to sit to the left of the query. Edge-space indices
    # >= len(dataset) are folded back to image indices.
    left = match_index % len(dataset)
    right = query_index % len(dataset)

    print(d, left, right)

    loaded = []
    for image_index in (left, right):
        root = './data/train' if image_index < n_train else './data/test'
        image = img_as_float(imread(join(root, 'images', ids[image_index]) + '.png'))
        mask_path = join(root, 'masks', ids[image_index]) + '.png'
        if exists(mask_path):
            mask = img_as_float(imread(mask_path))
        else:
            # No mask on disk (presumably a test image): show an empty mask instead.
            # NOTE(review): assumes masks are (H, W), matching image.shape[:2] — confirm.
            mask = np.zeros(image.shape[:2])
        loaded.append((image, mask))
    (image_left, mask_left), (image_right, mask_right) = loaded

    # Stack side by side: matched image on the left, query image on the right.
    image_cat = np.concatenate([image_left, image_right], axis=1)
    mask_cat = np.concatenate([mask_left, mask_right], axis=1)

    ax[i][0].imshow(image_left)
    ax[i][1].imshow(image_right)
    ax[i][2].imshow(image_cat)
    ax[i][2].set_title('ratio={:.3f}'.format(d))
    ax[i][3].imshow(mask_cat)
    ax[i][4].imshow(mask_left)
    ax[i][5].imshow(mask_right)

fig.tight_layout()
