In [None]:
import sys
import os
import time
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
from copy import deepcopy
from tqdm import tqdm
import random
from random import shuffle

from IPython.display import clear_output
from matplotlib import pyplot as plt
import collections
%matplotlib inline

def live_plot(data, figsize=(16,8), title=''):
    """Show *data* as an image, replacing the notebook cell's previous output.

    ``clear_output(wait=True)`` defers clearing until new output arrives, so
    calling this repeatedly redraws in place. The array is cast to ``int``
    and passed through ``convert_to_rgb`` (which swaps the first and third
    channels, i.e. OpenCV BGR -> RGB) before being drawn.

    Parameters
    ----------
    data : numpy.ndarray
        Image array with channels in B, G, R order.
    figsize : tuple
        Matplotlib figure size in inches.
    title : str
        Title drawn above the image.
    """
    clear_output(wait=True)
    plt.figure(figsize=figsize)
    rgb_image = convert_to_rgb(data.astype(int))
    plt.imshow(rgb_image)
    plt.title(title)
    plt.show()

def convert_to_rgb(img):
    """Return a new image with the channel order of a 3-channel image reversed.

    Typically used to turn OpenCV's B, G, R ordering into the R, G, B ordering
    matplotlib expects. Implemented with NumPy slicing, so the input dtype is
    preserved. This includes the 64-bit ``int`` arrays that ``live_plot``
    passes, which an OpenCV ``split``/``merge`` round-trip cannot represent.
    The input array is never modified.

    Parameters
    ----------
    img : numpy.ndarray
        Array of shape ``(H, W, 3)``.

    Returns
    -------
    numpy.ndarray
        New C-contiguous array with the same shape and dtype as *img*, where
        channel 0 and channel 2 are swapped.

    Raises
    ------
    ValueError
        If *img* is not an ``(H, W, 3)`` array.
    """
    img = np.asarray(img)
    if img.ndim != 3 or img.shape[2] != 3:
        raise ValueError(f"expected an (H, W, 3) image, got shape {img.shape}")
    # The reversed view has a negative stride; ascontiguousarray therefore makes a
    # copy, so the result never aliases the caller's array.
    return np.ascontiguousarray(img[:, :, ::-1])

file_num = 80

frame_dir = f'../scene_data/{file_num}'
frame_bt_dir = f'../scene_data/{file_num}_back_true'


def _index_scene_files(directory, suffix):
    """Index ``<num>_<verti_pos>_<hori_pos>_<real_back_prob><suffix>`` files by frame number.

    Hidden files (leading ``'.'``) and files that do not end with *suffix* are
    skipped. The fourth name field is parsed but not stored. NOTE(review): it is
    presumably a per-frame probability label; it is unused here.

    Parameters
    ----------
    directory : str
        Directory to scan (not recursive).
    suffix : str
        File extension including the dot, e.g. ``'.jpg'``.

    Returns
    -------
    dict
        ``{int(num): {'file_name', 'num', 'verti_pos', 'hori_pos'}}``. ``'num'``
        is kept as the original string; the positions are ints.

    Raises
    ------
    ValueError
        If a matching file name does not have exactly four ``_``-separated
        fields, or a numeric field is not an integer.
    """
    index = {}
    for file_name in os.listdir(directory):
        if file_name.startswith('.') or not file_name.endswith(suffix):
            continue
        stem = file_name[:-len(suffix)]
        num, verti_pos, hori_pos, _real_back_prob = stem.split('_')
        index[int(num)] = {
            'file_name': file_name,
            'num': num,
            'verti_pos': int(verti_pos),
            'hori_pos': int(hori_pos),
        }
    return index


# Colour frames (read with cv2.imread) and the matching per-frame .npy arrays.
scene_dict = _index_scene_files(frame_dir, '.jpg')
scene_bt_dict = _index_scene_files(frame_bt_dir, '.npy')

from multiprocessing import Process, Manager
import json

def calculate_hedge_back(i, compare_frame_dict):
    """Block-match the reference frame against frame ``scene_num_0 + i``; store the result.

    Intended as a ``multiprocessing.Process`` target. It reads these module-level
    globals, which must be set before the worker starts: ``scene_num_0``,
    ``frame_0``, ``scene_dict``, ``frame_dir``, ``n`` and ``set_shape``.

    Matching step:
      * Each pixel ``(s0, s1)`` whose patch ``frame_0[s0-n:s0+n, s1-n:s1+n]`` has
        shape ``set_shape`` is considered.
      * Every shift ``(m1, m2)`` in ``[-m_num, m_num]^2`` is tried on the other
        frame, and the shift with the smallest sum of squared differences is kept.
      * Ties go to the first shift in iteration order (``m1`` outer, ``m2`` inner,
        both ascending).

    Smoothing step: for each matched pixel, the best shifts of all matched pixels
    in the window ``[s0-n, s0+n) x [s1-n, s1+n)`` are collected.

    Parameters
    ----------
    i : int
        Frame offset relative to ``scene_num_0``. ``scene_num_0 + i`` must be a
        key of ``scene_dict``.
    compare_frame_dict : dict-like
        Shared mapping that receives ``compare_frame_dict[i] = {'euclidean_dis':
        float, 'compare_mean_dict': {(s0, s1): [np.ndarray shape (2,), ...]}}``.

    Raises
    ------
    FileNotFoundError
        If ``cv2.imread`` cannot read the frame (it returns ``None``).
    """
    # Search radius grows with the frame offset.
    # NOTE(review): presumably to cover larger apparent motion; confirm.
    m_num = 2 * abs(i) + 3
    scene_num = scene_num_0 + i
    file_name = scene_dict[scene_num]['file_name']

    abs_pos_0 = np.array((scene_dict[scene_num_0]['verti_pos'], scene_dict[scene_num_0]['hori_pos']))
    abs_pos_1 = np.array((scene_dict[scene_num]['verti_pos'], scene_dict[scene_num]['hori_pos']))

    raw_frame_1 = cv2.imread(f"{frame_dir}/{file_name}")
    if raw_frame_1 is None:
        raise FileNotFoundError(f"could not read image {frame_dir}/{file_name}")
    # cv2.imread yields uint8. Subtracting/squaring uint8 patches wraps modulo 256
    # and corrupts the matching cost, so work in int64 (cast once, not per patch).
    ref = frame_0.astype(np.int64)
    frame_1 = raw_frame_1.astype(np.int64)

    temp_dict = {}
    # sqrt(sum(d**2) / 2): root-mean-square over the two position components.
    temp_dict['euclidean_dis'] = np.sqrt(np.sum(np.power(abs_pos_1 - abs_pos_0, 2) / 2))

    compare_dict = {}
    for s0 in tqdm(range(ref.shape[0]), ncols=70):
        for s1 in range(ref.shape[1]):
            segment_0 = ref[s0-n:s0+n, s1-n:s1+n, :]
            if segment_0.shape != set_shape:
                continue  # patch is clipped by the frame border
            best_shift = None
            best_diff = None
            for m1 in range(-m_num, m_num+1):
                for m2 in range(-m_num, m_num+1):
                    compare_1 = frame_1[max(s0-n+m1,0):max(s0+n+m1,0), max(s1-n+m2,0):max(s1+n+m2,0), :]
                    if compare_1.shape != set_shape:
                        continue  # shifted patch leaves frame_1
                    diff_array = segment_0 - compare_1
                    # len(diff_array) is the patch height 2n: a constant scale that
                    # does not change which shift wins.
                    diff = np.sum(np.power(diff_array, 2)) / len(diff_array)
                    # Strict '<' keeps the earliest shift on ties.
                    if best_diff is None or diff < best_diff:
                        best_shift = (m1, m2)
                        best_diff = diff
            if best_shift is None:
                # No in-bounds candidate (only possible if frame_1 is smaller than
                # frame_0): leave the pixel unmatched.
                continue
            compare_dict[(s0, s1)] = np.array(best_shift)

    compare_mean_dict = {}
    for key in tqdm(compare_dict, ncols=70):
        s0, s1 = key
        # Window is half-open, so for n >= 1 it includes (s0, s1) itself.
        compare_mean_dict[key] = [
            compare_dict[(c0, c1)]
            for c0 in range(s0 - n, s0 + n)
            for c1 in range(s1 - n, s1 + n)
            if (c0, c1) in compare_dict
        ]
    temp_dict['compare_mean_dict'] = compare_mean_dict
    compare_frame_dict[i] = temp_dict

# Choose the reference frame at random. Eligibility rules:
#   * It must not be among the first/last five frames.
#   * Every neighbour scene_num_0 + j, j in [-10, 10], must exist in both indexes.
#     The workers launched below use offsets -10..10, and a missing neighbour makes
#     that worker die with a KeyError, silently dropping its result.
_neighbour_offsets = range(-10, 11)
_candidates = [
    k for k in sorted(scene_dict)[5:-5]
    if all(k + j in scene_dict and k + j in scene_bt_dict for j in _neighbour_offsets)
]
if not _candidates:
    raise ValueError(
        f'no frame in {frame_dir} has all neighbours within +/-10 frames; '
        'cannot pick a reference frame'
    )
scene_num_0 = random.choice(_candidates)
file_name_0 = scene_dict[scene_num_0]['file_name']
file_name_bt_0 = scene_bt_dict[scene_num_0]['file_name']
frame_0 = cv2.imread(f"{frame_dir}/{file_name_0}")
if frame_0 is None:
    # cv2.imread reports unreadable/missing files by returning None, not by raising.
    raise FileNotFoundError(f"could not read image {frame_dir}/{file_name_0}")
frame_bt_0 = np.load(f"{frame_bt_dir}/{file_name_bt_0}").astype(int)

# Patch half-size: patches compared by the workers are 2n x 2n pixels x 3 channels.
n = 1
set_shape = (n*2, n*2, 3)

# Launch one worker process per non-zero offset in [-10, 10]. The workers read
# module-level globals set above.
# NOTE(review): this relies on the 'fork' start method. Under 'spawn' the child
# re-imports __main__ and neither a notebook-defined target nor these globals are
# available; confirm the platform.
with Manager() as manager:
    shared_results = manager.dict()
    process_list = []
    for i in range(-10, 11):
        if i == 0:
            continue  # offset 0 would compare the reference frame with itself
        p = Process(target=calculate_hedge_back, args=(i, shared_results))
        p.start()
        process_list.append(p)

    for p in process_list:
        p.join()

    # Copy into a plain dict before the manager's server process is shut down on
    # leaving the `with` block.
    compare_frame_dict = dict(shared_results.items())

# A worker that raised only prints a traceback in its own process; surface the gap.
_missing_offsets = [j for j in range(-10, 11) if j != 0 and j not in compare_frame_dict]
if _missing_offsets:
    print(f'WARNING: no result for offsets {_missing_offsets}; those workers failed')

# Pass 1: collapse each pixel's list of neighbouring best-match shifts into one
# magnitude. The shifts are averaged, then reduced with the same sqrt(sum(d**2) / 2)
# formula used for 'euclidean_dis'. Values are replaced in place.
for i, result in compare_frame_dict.items():
    compare_mean_dict = result['compare_mean_dict']
    for key in tqdm(compare_mean_dict, ncols=70):
        mean_shift = np.mean(compare_mean_dict[key], axis=0)
        compare_mean_dict[key] = np.sqrt(np.sum(np.power(mean_shift, 2) / 2))

# Pass 2: split pixels by comparing their magnitude with the frame-position
# distance. magnitude <= distance goes to 'back', larger goes to 'hedge'; both
# dicts are ordered by ascending magnitude.
# NOTE(review): the names presumably mean background vs. hedge (near object); confirm.
for i in tqdm(compare_frame_dict):
    result = compare_frame_dict[i]
    euclidean_dis = result['euclidean_dis']
    ranked = sorted(result['compare_mean_dict'].items(), key=lambda item: item[1])
    result['back'] = {k: v for k, v in ranked if v <= euclidean_dis}
    result['hedge'] = {k: v for k, v in ranked if v > euclidean_dis}

import numpy as np
hedge_back_path = f'{file_num}_hedge_back_10.npy'
frame_bt_0_path = f'{file_num}_hedge_back_frame_bt_0_10.npy'
# np.save wraps the dict in a 0-d object array, which is pickled. Read it back with
# np.load(hedge_back_path, allow_pickle=True).item().
np.save(hedge_back_path, compare_frame_dict)
np.save(frame_bt_0_path, frame_bt_0)

100%|█████████████████████████████| 500/500 [1:06:28<00:00,  7.98s/it]
100%|██████████████████████| 398701/398701 [00:10<00:00, 38015.28it/s]
100%|█████████████████████████████| 500/500 [1:07:02<00:00,  8.05s/it]
100%|██████████████████████| 398701/398701 [00:10<00:00, 37906.29it/s]
100%|█████████████████████████████| 500/500 [1:44:03<00:00, 12.49s/it]
100%|██████████████████████| 398701/398701 [00:04<00:00, 89444.81it/s]
100%|█████████████████████████████| 500/500 [1:44:09<00:00, 12.50s/it]
100%|██████████████████████| 398701/398701 [00:04<00:00, 87559.63it/s]
100%|█████████████████████████████| 500/500 [1:57:50<00:00, 14.14s/it]
100%|█████████████████████| 398701/398701 [00:02<00:00, 169631.56it/s]
100%|█████████████████████████████| 500/500 [1:58:00<00:00, 14.16s/it]
100%|█████████████████████| 398701/398701 [00:01<00:00, 265717.21it/s]
100%|█████████████████████████████| 500/500 [2:11:26<00:00, 15.77s/it]
100%|█████████████████████| 398701/398701 [00:01<00:00, 252984.82it/s]
100%|█