In [1]:
import os
import json
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import scipy.ndimage as snd

In [2]:
def rgb_to_intensity(rgb_image):
    return np.dot(rgb_image[...,:3], [0.299, 0.587, 0.114])

In [3]:
def get_train_images(directories, need_print: bool = False, get_full: bool = False):
    for directory in directories:
        if need_print:
            print(directory)

        found_images = [file_name for file_name in os.listdir(os.path.join(directory, 'img')) if file_name.endswith('.png') or file_name.endswith('.jpg')]
        for file_name in found_images:
            if need_print:
                print(file_name, end=' ')

            if get_full:
                yield (
                    os.path.join(directory, 'img', file_name),
                    os.path.join(directory, 'ann', file_name+'.json'),
                    os.path.join(directory, 'masks_machine', file_name),
                    directory.split('/')[-1],
                )
            else:
                yield os.path.join(directory, 'img', file_name)

Считаем что-то более простое (и быстрое)

In [4]:
def count_rods(json_path: str):
    with open(json_path) as mask_json:
        mask_content = mask_json.read()
        json_dict = json.loads(mask_content)
    
    return len(json_dict['objects'])

def get_total_rods_area_ratio(mask_image: np.ndarray):
    meaned_mask = np.mean(mask_image, axis=2)
    binarized_mask = np.where(meaned_mask != 0, 1, 0)
    return np.mean(binarized_mask)

Оптимизируем STA6 с помощью сверток прямо из библиотеки

In [5]:
def sta6_optimized(img: np.ndarray, conv_size: int = 0, stride: int = 1):
    image_intensity = rgb_to_intensity(img)
    
    mean_kernel = np.ones((conv_size, conv_size)) / conv_size**2
    image_mean_intensity = snd.convolve(image_intensity, mean_kernel)
    image_mean_intensity = image_mean_intensity[(conv_size-1)//2:image_mean_intensity.shape[0]-conv_size//2, (conv_size-1)//2:image_mean_intensity.shape[1]-conv_size//2]
    image_intensity = image_intensity[(conv_size-1)//2:image_intensity.shape[0]-conv_size//2, (conv_size-1)//2:image_intensity.shape[1]-conv_size//2]

    image_intensity = image_intensity[::stride, ::stride]
    image_mean_intensity = image_mean_intensity[::stride, ::stride]
    
    sta6 = np.mean(np.power(image_intensity - image_mean_intensity, 2))

    return sta6

In [6]:
dirs = ['../data/appropriate', '../data/bad']

names = []
folders = []
counts = []
ratios = []
sta6_vals = []

for image_name, ann_name, mask_name, folder_name in get_train_images(dirs, get_full=True):
    names.append(image_name.split('/')[-1])
    folders.append(folder_name)

    img = mpimg.imread(image_name)
    mask = mpimg.imread(mask_name)

    counts.append(count_rods(ann_name))
    ratios.append(get_total_rods_area_ratio(mask))
    sta6_vals.append(sta6_optimized(img, conv_size=5))

df = pd.DataFrame.from_dict({
    'Index': names,
    'Dataset': folders,
    'Rod Count': np.array(counts),
    'Area Ratio': ratios,
    'STA6': sta6_vals,
})

In [7]:
df[df['Dataset'] == 'appropriate']

Unnamed: 0,Index,Dataset,Rod Count,Area Ratio,STA6
0,IMG_6410.png,appropriate,11,0.368527,9.5e-05
1,IMG_6365.png,appropriate,7,0.172307,0.00015
2,IMG_6497.png,appropriate,2,0.789966,0.000149
3,IMG_1753.png,appropriate,8,0.158387,0.000489
4,IMG_1760.png,appropriate,37,0.194448,0.000668
5,IMG_6362.png,appropriate,7,0.461145,0.000264
6,IMG_1752.png,appropriate,8,0.167471,0.000193
7,IMG_6495.png,appropriate,3,0.755973,9.5e-05
8,IMG_1758.png,appropriate,32,0.243413,0.000573
9,IMG_6406.png,appropriate,6,0.564901,0.000177


In [8]:
df[df['Dataset'] == 'bad']

Unnamed: 0,Index,Dataset,Rod Count,Area Ratio,STA6
30,IMG_6369.png,bad,10,0.430101,0.00018
31,IMG_6417.png,bad,10,0.443148,0.00011
32,IMG_6375.png,bad,5,0.540432,0.000229
33,IMG_6376.png,bad,12,0.23816,0.000206
34,IMG_6415.png,bad,21,0.27851,0.000273
35,IMG_6364.png,bad,20,0.230936,0.000195
36,IMG_6408.png,bad,78,0.361267,0.000122
37,IMG_6416.png,bad,24,0.310809,0.000156
38,IMG_6374.png,bad,16,0.407425,0.000272
39,IMG_6422.png,bad,28,0.171472,0.000181


In [9]:
df['Dataset'] = df['Dataset'].replace(['appropriate'], 1).replace(['bad'], 0)

In [10]:
df.to_csv('../metrics/gosha_metrics.csv', index=False)