In [1]:
import cv2
import torch
import numpy as np
from sklearn.metrics import mean_absolute_error
from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict, Counter
import json

from argus import load_model

from src.transforms import SimpleDepthTransform, SaltTransform, CenterCrop
from src.argus_models import SeResnextFPNProb50
from mosaic.mosaic_api import SaltData

import matplotlib.pyplot as plt
%matplotlib inline

please compile abn


In [2]:
def imshow(image, figsize=(4, 4)):
    """Display ``image`` in a fresh matplotlib figure.

    Args:
        image: Anything accepted by matplotlib's ``imshow``.
        figsize: ``(width, height)`` passed to ``plt.figure``.
    """
    figure = plt.figure(figsize=figsize)
    # Draw on the new figure's current axes, then render it.
    figure.gca().imshow(image)
    plt.show()

In [3]:
ORIG_IMAGE_SIZE = (101, 101)
PRED_IMAGE_SIZE = (128, 128)
TRANSFORM_MODE = "crop"


class EmbeddingPredictor:
    """Callable that turns one tile image into a log-scaled embedding vector.

    Args:
        model_path: Path of the saved argus model to load.
        device: Device forwarded to ``load_model``. ``None`` (the default)
            leaves the choice to ``load_model``.
    """

    def __init__(self, model_path, device=None):
        # Forward the caller's choice; previously ``device=None`` was
        # hard-coded here, so the constructor argument was silently ignored.
        self.model = load_model(model_path, device=device)
        self.model.nn_module.eval()

        self.depth_trns = SimpleDepthTransform()
        # NOTE(review): crop_trns is built but not used by __call__ below.
        self.crop_trns = CenterCrop(ORIG_IMAGE_SIZE)
        self.trns = SaltTransform(PRED_IMAGE_SIZE, False, TRANSFORM_MODE)

    def __call__(self, image):
        """Return ``log1p`` of the network embedding for a single image.

        Args:
            image: Input tile as accepted by ``SimpleDepthTransform``.

        Returns:
            NumPy array: the first (only) batch element of
            ``embedding_forward``, passed through ``np.log1p``.
        """
        # Second argument is presumably the depth value - TODO confirm
        # against SimpleDepthTransform.
        tensor = self.depth_trns(image, 0)
        tensor = self.trns(tensor)
        # Add a batch dimension and move the input to the model's device.
        tensor = tensor.unsqueeze(0).to(self.model.device)

        with torch.no_grad():
            embedding = self.model.nn_module.embedding_forward(tensor)
        embedding = embedding.cpu().numpy()[0]

        return np.log1p(embedding)

In [4]:
# Checkpoint used as the embedding extractor (fold_0 of the experiment directory below).
model_path = '/workdir/data/experiments/fpn-lovasz-se-resnext50-006/fold_0/model-246-0.851148.pth'

# No `device` argument is passed, so the predictor uses its default (None).
emb_model = EmbeddingPredictor(model_path)

In [5]:
# CSV describing the mosaics; tiles and masks are read from the 'images148' / 'masks148' directories.
# NOTE(review): the '148' suffix presumably refers to the tile size - confirm against SaltData.
mosaic_path = '/workdir/data/mosaic/pazzles_6013.csv'
saltdata = SaltData(mosaic_csv_path=mosaic_path, images_dir_name='images148', masks_dir_name='masks148')

In [None]:
# Embed every tile known to `saltdata`, keyed by tile id.
id2emb = {
    tile_id: emb_model(saltdata.id2image[tile_id])
    for tile_id in saltdata.ids
}

In [None]:
# NumPy arrays are not JSON-serialisable, so store each embedding as a plain list.
id2emb_lst = dict()
for tile_id, vector in id2emb.items():
    id2emb_lst[tile_id] = vector.tolist()

with open('id2emb.json', 'w') as out_file:
    json.dump(id2emb_lst, out_file)

In [6]:
# mosaic_folds = [
#     [0, 3, 22, 24, 62, 73, 76, 85, 88, 99, 114, 116, 125, 133, 135],
#     [1, 2, 8, 31, 43, 45, 141],
#     [10, 14, 15, 23, 42, 65, 113, 131],
#     [4, 16, 21, 28, 33, 60, 91, 95, 96, 98, 111, 120, 123, 130, 138, 153, 182],
#     [5, 19, 29, 32, 36, 37, 40, 41, 46, 50, 51, 54, 59, 66, 67, 82, 86, 126, 136, 139, 140, 145, 156, 183],
#     [6, 9, 11, 12, 26, 30, 52, 53, 61, 70, 84, 97, 106, 124, 128, 129],
#     [57, 63, 64, 68, 81, 83, 94, 104, 105, 150, 163, 173, 174, 176],
#     [13, 17, 18, 121],
#     [7, 20, 25, 27, 34, 38, 39, 44, 47, 48, 49, 55, 56, 58, 69, 72, 74, 127, 77],
#     [35, 71, 75, 78, 79, 80, 87, 89, 90, 92, 93, 100, 101, 102, 103, 107, 108, 109, 110, 112, 115, 117, 118, 119, 122, 143, 144]
# ]

# Each entry lists the mosaic ids that make up one fold (fold index in the trailing comment).
mosaic_folds = [
    [0, 3, 22, 24, 62, 73, 76, 85, 88, 99, 114, 116, 125, 133, 135] + [7, 20, 25, 27, 34, 38, 39, 44, 47, 48, 49, 55, 56, 58, 69, 72, 74, 127, 77], # 0
    [1, 2, 8, 31, 43, 45, 141] + [10, 14, 15, 23, 42, 65, 113, 131], # 1
    [4, 16, 21, 28, 33, 60, 91, 95, 96, 98, 111, 120, 123, 130, 138, 153, 182] + [35, 71, 75, 78, 79, 80, 87, 89, 90, 92, 93, 100, 101, 102, 103, 107, 108, 109, 110, 112, 115, 117, 118, 119, 122, 143, 144], # 2
    [5, 19, 29, 32, 36, 37, 40, 41, 46, 50, 51, 54, 59, 66, 67, 82, 86, 126, 136, 139, 140, 145, 156, 183], # 3
    [6, 9, 11, 12, 26, 30, 52, 53, 61, 70, 84, 97, 106, 124, 128, 129], # 4
    [57, 63, 64, 68, 81, 83, 94, 104, 105, 150, 163, 173, 174, 176] + [13, 17, 18, 121], # 5
]

fold2mosaic_ids = {i: m for i, m in enumerate(mosaic_folds)}

# Collect, per fold, the non-test tile ids of all mosaics assigned to it.
fold2ids = dict()
for fold, mosaic_ids in fold2mosaic_ids.items():
    fold_ids = set()
    for mosaic_id in mosaic_ids:
        mosaic = saltdata.mosaics.mosaic_id2mosaic[mosaic_id]
        fold_ids |= mosaic.ids - saltdata.test_ids
    fold2ids[fold] = fold_ids

# Reverse lookup: tile id -> fold.
id2fold = dict()
for fold, ids in fold2ids.items():
    for tile_id in ids:
        id2fold[tile_id] = fold

# id2pred_fold.json is produced by a later cell of this notebook. On a fresh
# run it does not exist yet; continue with the mosaic-only folds in that case
# instead of crashing, so the distribution cells can create it.
try:
    with open('/workdir/data/mosaic/id2pred_fold.json') as file:
        id2pred_fold = json.load(file)
except FileNotFoundError:
    id2pred_fold = dict()
    print('id2pred_fold.json not found: folds contain mosaic tiles only')

# Merge the previously predicted assignments into both lookups.
for tile_id, fold in id2pred_fold.items():
    id2fold[tile_id] = fold
    fold2ids[fold].add(tile_id)

# From here on `fold_ids` is the set of all tiles that have a fold.
fold_ids = set(id2fold.keys())

In [None]:
# Persist the tile id -> fold mapping.
with open('/workdir/data/mosaic/id2fold.json', 'w') as out_file:
    json.dump(id2fold, out_file)

In [7]:
# Number of tiles currently assigned to each fold.
fold2len = dict()
for fold, ids in fold2ids.items():
    fold2len[fold] = len(ids)
fold2len

{0: 666, 1: 666, 2: 666, 3: 666, 4: 666, 5: 590}

In [8]:
# Mean mask coverage per fold. Each mask is normalised by its own pixel count
# (np.mean) instead of a hard-coded 101*101, so the value is a true fraction
# for any mask shape; for 101x101 single-channel masks it equals the previous
# formula. Mask values are assumed to be 0 / 255, as in the original.
{
    fold: np.mean([np.mean(saltdata.id2mask[tile_id]) / 255 for tile_id in ids])
    for fold, ids in fold2ids.items()
}

{0: 0.7039492683547187,
 1: 0.39591831808281175,
 2: 0.4927113664002205,
 3: 0.5924095353072903,
 4: 0.5639603430506284,
 5: 0.426223750080999}

# Distribute train tiles that belong to no mosaic between folds

In [None]:
from sklearn.metrics.pairwise import cosine_distances


def get_nearest(id):
    """Rank all tiles that already have a fold by distance to tile ``id``.

    Args:
        id: Key into ``id2emb`` of the query tile.

    Returns:
        List of ``(distance, infold_id, fold)`` tuples, one per entry of
        ``id2fold``, sorted by ascending cosine distance. Empty when
        ``id2fold`` is empty.
    """
    infold_ids = list(id2fold)
    if not infold_ids:
        return []

    # One batched call instead of one sklearn call per pair: same distances
    # (up to floating-point rounding) at a fraction of the cost.
    query = id2emb[id][np.newaxis]
    infold_embs = np.stack([id2emb[infold_id] for infold_id in infold_ids], axis=0)
    dists = cosine_distances(query, infold_embs)[0]

    dist_lst = [
        (dist, infold_id, id2fold[infold_id])
        for dist, infold_id in zip(dists, infold_ids)
    ]
    # Stable sort on the distance only, so ties keep id2fold's iteration order.
    dist_lst.sort(key=lambda x: x[0])
    return dist_lst

In [None]:
# Number of tiles that currently have a fold assigned.
len(id2fold)

In [None]:
import tqdm 

# Train tiles that have no fold yet.
train_unfold_ids = saltdata.train_ids - fold_ids

# For each of them, the distance-ranked list of tiles that already have a fold.
unfold_id2fold_dists = {
    unfold_id: get_nearest(unfold_id)
    for unfold_id in tqdm.tqdm(train_unfold_ids)
}

In [None]:
# Upper bound on the number of tiles per fold (existing + newly assigned).
MAX_FOLD_SIZE = 666

# Handle tiles with the closest in-fold neighbour first. Tiles without any
# candidate sort last and trigger the explicit error below instead of an
# IndexError inside the sort key.
fold_dists = sorted(
    unfold_id2fold_dists.items(),
    key=lambda x: x[1][0][0] if x[1] else float('inf'),
)

id2pred_fold = dict()
pred_fold2len = defaultdict(int)

for id, nearest_dists in fold_dists:
    # Walk the neighbours from nearest to farthest and take the fold of the
    # first one whose fold still has room.
    for dist, another_id, fold in nearest_dists:
        fold_len = fold2len[fold] + pred_fold2len[fold]
        if fold_len < MAX_FOLD_SIZE:
            id2pred_fold[id] = fold
            pred_fold2len[fold] += 1
            break
    else:
        # No break: every candidate fold is already full (or there were none).
        raise RuntimeError(
            'No fold with free capacity (< {}) for tile {!r}'.format(MAX_FOLD_SIZE, id)
        )

In [None]:
# NOTE: `Counter` is already imported in the first cell; this import is redundant.
from collections import Counter
# Number of newly distributed tiles per fold.
Counter(id2pred_fold.values())

In [None]:
# Current size of every fold in `fold2ids`.
{fold: len(fold2ids[fold]) for fold in fold2ids}

In [None]:
# Persist the predicted fold of every distributed tile.
with open('/workdir/data/mosaic/id2pred_fold.json', 'w') as out_file:
    json.dump(id2pred_fold, out_file)

In [None]:
# Rank every embedded tile by cosine distance to one hand-picked tile.
id = '1cec04bb12'

# Loop-invariant: build the query row once.
query_emb = id2emb[id][np.newaxis]

dist_lst = []
for infold_id in tqdm.tqdm(id2emb):
    # Take the scalar out of the 1x1 result, as get_nearest does, so the
    # list holds plain numbers rather than 1x1 arrays.
    dist = cosine_distances(id2emb[infold_id][np.newaxis], query_emb)[0, 0]
    dist_lst.append((infold_id, dist))
dist_lst = sorted(dist_lst, key=lambda x: x[1])

# Nearest tile in another fold

In [None]:
from sklearn.metrics.pairwise import manhattan_distances, cosine_distances

# Fix an ordering of the train ids so matrix rows can be mapped back to ids.
train_id_lst = list(saltdata.train_ids)
train_id_lst.sort()

train_emb_lst = []
for train_id in train_id_lst:
    train_emb_lst.append(id2emb[train_id])
train_X = np.stack(train_emb_lst, axis=0)

# Pairwise cosine distances between all train embeddings.
dist_array = cosine_distances(train_X)
dist_array.shape

In [None]:
# Print every pair (tile in fold N, tile in another fold) whose distance is below 0.04.
N = 0
fold_n_ids = fold2ids[N]
another_folds_ids = fold_ids - fold_n_ids
# id -> row/column in dist_array. Built in one pass and actually used below;
# the previous code built it with list.index and then ignored it.
id2index_train_id_lst = {id: index for index, id in enumerate(train_id_lst)}

for id in fold_n_ids:
    id_index = id2index_train_id_lst[id]
    # another_folds_ids excludes every id of fold N, so no self-pair check is needed.
    for another_id in another_folds_ids:
        another_id_index = id2index_train_id_lst[another_id]

        dist = dist_array[id_index, another_id_index]
        if dist < 0.04:
            print(id, another_id, dist, id2fold[id], id2fold[another_id],
                  saltdata.mosaics.id2mosaic_id[id], saltdata.mosaics.id2mosaic_id[another_id])
    

In [None]:
N = 0
M = 1

def compare_folds(M, N):
    """Mean pairwise distance between the tiles of fold ``N`` and fold ``M``.

    Distances are read from the global ``dist_array``, whose rows and columns
    follow ``train_id_lst``. Every pair closer than 0.04 is printed as
    ``M N id another_id dist``.

    Args:
        M: Key of the first fold in ``fold2ids``.
        N: Key of the second fold in ``fold2ids``.

    Returns:
        The mean distance over all (N tile, M tile) pairs, or ``nan`` when
        either fold is empty and there are no pairs.

    Raises:
        KeyError: If a fold key or a tile id is unknown.
    """
    fold_n_ids = fold2ids[N]
    fold_m_ids = fold2ids[M]
    # One-pass id -> index lookup; replaces an O(n) list.index per pair.
    id2index = {tile_id: index for index, tile_id in enumerate(train_id_lst)}
    m_indexed = [(tile_id, id2index[tile_id]) for tile_id in fold_m_ids]

    sum_dist = 0
    count = 0

    for tile_id in fold_n_ids:
        row = id2index[tile_id]
        for another_id, col in m_indexed:
            dist = dist_array[row, col]

            if dist < 0.04:
                print(M, N, tile_id, another_id, dist)

            sum_dist += dist
            count += 1

    if count == 0:
        # No pairs to average; avoid ZeroDivisionError.
        return float('nan')
    return sum_dist / count

In [None]:
# Mean distance for every unordered pair of folds (i < j).
dist_folds_lst = []

n_folds = len(fold2ids)
for i in range(n_folds):
    for j in range(i + 1, n_folds):
        entry = (i, j, compare_folds(i, j))
        dist_folds_lst.append(entry)
        print(entry)

In [None]:
# Inspect one pair of tiles: their folds, their images and their embedding distance.
id1, id2 = 'a2da67afff', 'b7c462dd1c'
# 1cec04bb12.png bfbb9b9149.png 7c0b76979f

print(id2fold[id1], id2fold[id2])

image1, image2 = saltdata.id2image[id1], saltdata.id2image[id2]
emb1, emb2 = emb_model(image1), emb_model(image2)

for image in (image1, image2):
    imshow(image)
cosine_distances(emb1[np.newaxis], emb2[np.newaxis])