## Import dependencies

In [None]:
import numpy as np
from skimage import io
from skimage.restoration import estimate_sigma
from skimage.segmentation import slic, mark_boundaries
from sklearn.cluster import KMeans

from os.path import join
import pathlib
import matplotlib.pyplot as plt
from pprint import pprint
import csv

from typing import List
import numpy.typing as npt

In [None]:
# Run configuration shared by the cells below.
write_files = True  # when True, result images are saved under '<dir>/artefacts'
write_superpixels = False  # when True, get_superpixels() also saves one PNG per superpixel
shrink_superpixels = False  # when True, get_superpixels() crops each superpixel array

# NOTE(review): `dir` shadows the built-in of the same name; kept as-is
# because later cells read this global.
dir = 'images/1'
fname_orig = 'original'  # file stem of the original image
fname_spl = 'spliced'  # file stem of the spliced image
ext = 'png'  # extension shared by both input files

n_superpixels = 100  # passed to slic() as n_segments

## Import images

In [None]:
def load_images(dir: str, fname_orig: str, fname_spl: str, ext: str) -> List[npt.ArrayLike]:
    """Read the spliced and the original image, keeping their first three channels.

    When the module-level ``write_files`` flag is set, the ``artefacts``
    sub-directory of ``dir`` is created as well (existing directories are
    left untouched).

    Args:
        dir: Directory that contains both image files.
        fname_orig: File stem of the original image.
        fname_spl: File stem of the spliced image.
        ext: File extension shared by both images, without the dot.

    Returns:
        A two-element list ``[original, spliced]``.
    """
    def first_three_channels(folder, stem):
        # Slice away anything beyond the third channel (e.g. an alpha channel).
        pixels = io.imread(join(folder, f'{stem}.{ext}'))
        return pixels[:, :, :3]

    spliced = first_three_channels(str(dir), fname_spl)
    original = first_three_channels(dir, fname_orig)

    if write_files:
        artefacts_dir = pathlib.Path(join(dir, 'artefacts'))
        artefacts_dir.mkdir(parents=True, exist_ok=True)

    return [original, spliced]

# Load both images once; later cells read `img`, `img_shape` and `path`
# as module-level globals.
[orig, img] = load_images(dir, fname_orig, fname_spl, ext)
img_shape = img.shape
path = join(dir, 'artefacts')  # directory that receives the saved artefacts

## Mark spliced

In [None]:
def mark_spliced(img: npt.ArrayLike, orig: npt.ArrayLike) -> npt.ArrayLike:
    """Build a binary mask of the pixels where ``img`` differs from ``orig``.

    The mask is shown with matplotlib and, when the module-level
    ``write_files`` flag is set, saved as ``1-splice-marked.png`` under the
    module-level ``path``.

    Args:
        img: The spliced image.
        orig: The original image. Only the rows/columns that overlap
            ``img`` are compared.

    Returns:
        A ``uint8`` array with the row/column shape of ``img``: 1 where any
        channel value differs from ``orig``, 0 elsewhere.

    Raises:
        ValueError: If ``orig`` does not have the same shape as ``img``
            after cropping to the rows/columns of ``img``.
    """
    spliced = np.asarray(img)
    # Take the extents from the argument itself rather than from a global,
    # so the function works for any image it is handed.
    rows, cols = spliced.shape[:2]
    reference = np.asarray(orig)[:rows, :cols]
    if reference.shape != spliced.shape:
        raise ValueError(
            f'orig region {reference.shape} does not match img {spliced.shape}')

    # Whole-array comparison; a pixel counts as spliced if any channel differs.
    differs = spliced != reference
    if differs.ndim > 2:
        differs = differs.any(axis=tuple(range(2, differs.ndim)))
    spliced_labels = differs.astype('uint8')

    plt.imshow(spliced_labels, cmap='gray')
    plt.show()

    if write_files:
        # Scale 0/1 to 0/255 so the mask is visible in the saved image.
        io.imsave(join(path, '1-splice-marked.png'), spliced_labels * 255)

    return spliced_labels

# Ground-truth splice mask; compared against the detection in get_metrics().
splice_marked = mark_spliced(img, orig)

## SLIC segmentation

In [None]:
def slic_segmentation(img: npt.ArrayLike):
    """Segment ``img`` into superpixels with SLIC and visualise the result.

    The number of segments requested is the module-level ``n_superpixels``.
    The image with segment boundaries drawn on it is shown with matplotlib
    and, when the module-level ``write_files`` flag is set, saved as
    ``2-superpixels.png`` under the module-level ``path``.

    Args:
        img: The image to segment.

    Returns:
        The label map returned by ``slic``; labels start at 1.
    """
    superpixels = slic(img, n_segments=n_superpixels, sigma=5, start_label=1)

    segmented = mark_boundaries(img, superpixels)

    plt.imshow(segmented)
    plt.show()

    if write_files:
        # mark_boundaries yields a float image; convert it to 8-bit before
        # saving so the PNG writer is handed integer data.
        as_uint8 = np.clip(np.rint(segmented * 255), 0, 255).astype('uint8')
        io.imsave(join(path, '2-superpixels.png'), as_uint8)

    return superpixels

# Per-pixel superpixel labels, reused by the counting, splitting and detection cells.
superpixel_labels = slic_segmentation(img)

## Count superpixels

In [None]:
def count_superpixels(superpixel_labels):
    """Return the number of distinct labels in a superpixel label map.

    Args:
        superpixel_labels: Array (or nested sequence) of integer labels.

    Returns:
        The count of unique label values as a plain ``int``.
    """
    # np.unique inspects the whole array that was passed in, so the result
    # no longer depends on a module-level image shape.
    return int(np.unique(np.asarray(superpixel_labels)).size)

# Number of distinct labels; sizes the per-superpixel list in get_superpixels().
superpixel_count = count_superpixels(superpixel_labels)
print(f'Superpixels count: {superpixel_count}')

## Separate superpixels

In [None]:
def get_superpixels(img: npt.ArrayLike, superpixel_labels, superpixel_count: int,
                    *, shrink=None, write=None):
    """Split ``img`` into one array per superpixel.

    Each returned array is a float64 copy of ``img`` in which every pixel
    outside the superpixel is zero. Labels are expected to run from 1 to
    ``superpixel_count``; entry ``k`` of the result holds label ``k + 1``.
    Pixels carrying any other label are not copied into any array.

    Args:
        img: The source image.
        superpixel_labels: Label map with the row/column shape of ``img``.
        superpixel_count: Number of superpixels to extract.
        shrink: If true, crop each array to the bounding box of its
            superpixel. Defaults to the module-level ``shrink_superpixels``.
        write: If true, save each array as ``superpixel-<index>.png`` in a
            ``superpixels`` directory under the module-level ``path``.
            Defaults to the module-level ``write_superpixels``.

    Returns:
        A list of ``superpixel_count`` arrays.
    """
    if shrink is None:
        shrink = shrink_superpixels
    if write is None:
        write = write_superpixels

    image = np.asarray(img)
    labels = np.asarray(superpixel_labels)

    def crop_to_mask(layer, mask):
        # Bounding box from the true min/max of both axes. The column bound
        # must be the minimum over all pixels, not the first column seen.
        mask_rows, mask_cols = np.nonzero(mask)
        if mask_rows.size == 0:
            return layer[0:0, 0:0]
        return layer[
            mask_rows.min():mask_rows.max() + 1,
            mask_cols.min():mask_cols.max() + 1
        ]

    superpixels = []
    for label in range(1, superpixel_count + 1):
        mask = labels == label
        layer = np.zeros(image.shape)
        layer[mask] = image[mask]
        if shrink:
            layer = crop_to_mask(layer, mask)
        superpixels.append(layer)

    if write:
        superpixels_path = join(path, 'superpixels')
        pathlib.Path(superpixels_path).mkdir(parents=True, exist_ok=True)
        for i, layer in enumerate(superpixels):
            # The working arrays are float64; write them with the source
            # image's dtype.
            io.imsave(join(superpixels_path, f'superpixel-{i}.png'),
                      layer.astype(image.dtype))

    return superpixels

# One array per superpixel; input to the noise estimation below.
superpixels = get_superpixels(img, superpixel_labels, superpixel_count)

## Estimate noise

In [None]:
def estimate_noise(img):
    """Estimate the noise standard deviation of a colour image.

    Args:
        img: Image whose last axis holds the colour channels.

    Returns:
        The value returned by ``estimate_sigma`` with the per-channel
        estimates averaged.
    """
    try:
        # scikit-image replaced ``multichannel`` with ``channel_axis``.
        return estimate_sigma(img, channel_axis=-1, average_sigmas=True)
    except TypeError:
        # Releases that do not know ``channel_axis`` reject it with a
        # TypeError; use the keyword they accept instead.
        return estimate_sigma(img, multichannel=True, average_sigmas=True)

In [None]:
def estimate_superpixel_noises(superpixels):
    """Estimate the noise level of every superpixel and plot the results.

    A scatter plot of the estimates is always shown. When the module-level
    ``write_files`` flag is set, the estimates are also written to
    ``3-noises.csv`` and the plot to ``4-plot-noises.png`` under the
    module-level ``path``.

    Args:
        superpixels: Iterable of superpixel images.

    Returns:
        A list with one noise estimate per superpixel, in input order.
    """
    noises = [estimate_noise(superpixel) for superpixel in superpixels]
    indices = range(len(noises))

    if write_files:
        # File output is gated on write_files like the other artefacts; the
        # output directory is only created when that flag is set.
        # newline='' leaves line-ending handling to the csv module.
        with open(join(path, '3-noises.csv'), 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['Superpixel', 'Sigma'])
            writer.writerows(zip(indices, noises))

    plt.scatter(indices, noises)
    if write_files:
        # Save before show() so the figure is still populated.
        plt.savefig(join(path, '4-plot-noises.png'), bbox_inches='tight', dpi=500)
    plt.show()

    return noises

# Per-superpixel noise estimates, in superpixel order.
noises = estimate_superpixel_noises(superpixels)

## Cluster noises

In [None]:
def cluster_noises(noises, n_clusters: int = 2):
    """Group the noise estimates into clusters with k-means.

    Args:
        noises: One noise estimate per superpixel.
        n_clusters: Number of clusters to form.

    Returns:
        The cluster index predicted for every entry of ``noises``.
    """
    # KMeans expects a 2-D sample matrix: one row per estimate, one column.
    samples = np.array(noises).reshape(-1, 1)

    model = KMeans(n_clusters=n_clusters, init='k-means++')
    model.fit(samples)

    return model.predict(samples)

# Cluster index per superpixel (two clusters by default).
noise_clusters = cluster_noises(noises)

## Detect splicing

In [None]:
def detect_splicing(noise_clusters, superpixel_labels):
    """Turn per-superpixel cluster indices into a per-pixel splice mask.

    Every pixel receives the cluster index of its superpixel. If cluster 0
    covers fewer pixels than the rest, the mask is inverted (``1 - value``)
    so that 0 marks the larger region. The boundaries of the mask are drawn
    on the module-level ``img``, shown with matplotlib and, when the
    module-level ``write_files`` flag is set, saved as
    ``5-splice-detected.png`` under the module-level ``path``.

    Args:
        noise_clusters: Cluster index per superpixel; entry ``k`` belongs to
            label ``k + 1``.
        superpixel_labels: Per-pixel superpixel labels.

    Returns:
        An ``int32`` array with the shape of ``superpixel_labels``.
    """
    clusters = np.asarray(noise_clusters)
    labels = np.asarray(superpixel_labels)

    # Labels start at 1, hence the offset when indexing the cluster table.
    i_segmented = clusters[labels - 1].astype('int32')

    n_0 = int(np.count_nonzero(i_segmented == 0))
    n_1 = i_segmented.size - n_0
    if n_0 < n_1:
        # Keep 0 for the larger region.
        i_segmented = 1 - i_segmented

    marked = mark_boundaries(img, i_segmented)

    plt.imshow(marked)
    plt.show()

    if write_files:
        # mark_boundaries yields a float image; convert it to 8-bit before
        # saving so the PNG writer is handed integer data.
        as_uint8 = np.clip(np.rint(marked * 255), 0, 255).astype('uint8')
        io.imsave(join(path, '5-splice-detected.png'), as_uint8)

    return i_segmented

# Per-pixel detection mask derived from the noise clusters.
splice_detected = detect_splicing(noise_clusters, superpixel_labels)

## Mask splicing

In [None]:
def mask_spliced(img, splice_detected):
    """Show ``img`` with the detected splice region tinted.

    Pixels whose mask value is 1 have their channels multiplied by
    ``[0.77, 0.90, 0.22]``; all other pixels are left unchanged. The result
    is shown with matplotlib and, when the module-level ``write_files``
    flag is set, saved as ``6-splice-mask.png`` under the module-level
    ``path``.

    Args:
        img: The image to tint.
        splice_detected: Per-pixel mask; 1 marks a detected pixel.
    """
    image = np.asarray(img)
    # Use the extents of the image that was passed in, not a global shape.
    rows, cols = image.shape[:2]
    flagged = np.asarray(splice_detected)[:rows, :cols] == 1

    # Work in float so the channel weights apply before the cast back to
    # uint8 truncates the products.
    tinted = image.astype('float64')
    tinted[flagged] = tinted[flagged] * [0.77, 0.90, 0.22]
    mask_img = tinted.astype('uint8')

    plt.imshow(mask_img)
    plt.show()

    if write_files:
        io.imsave(join(path, '6-splice-mask.png'), mask_img)

# Display the image with the detected region tinted.
mask_spliced(img, splice_detected)

## Show metrics

In [None]:
def get_metrics(splice_marked, splice_detected):
    """Compare the detected mask with the ground truth and report the scores.

    Counts true/false positives and negatives per pixel, shows the
    module-level ``img`` with true positives tinted green, false positives
    red and false negatives blue, and prints the counts together with
    precision, recall and F1. When the module-level ``write_files`` flag is
    set, the overlay is saved as ``7-splice-validation.png`` and the report
    as ``8-metrics.txt`` under the module-level ``path``.

    A score whose denominator is zero is reported as ``0.0``.

    Args:
        splice_marked: Ground-truth mask; 1 marks a spliced pixel.
        splice_detected: Detection mask; 1 marks a detected pixel.
    """
    def ratio(numerator, denominator):
        # Avoid ZeroDivisionError when nothing was detected or marked.
        return numerator / denominator if denominator else 0.0

    marked = np.asarray(splice_marked)
    detected = np.asarray(splice_detected)

    tp_mask = (detected == 1) & (marked == 1)
    fp_mask = (detected == 1) & (marked == 0)
    tn_mask = (detected == 0) & (marked == 0)
    fn_mask = (detected == 0) & (marked == 1)

    tp = int(np.count_nonzero(tp_mask))
    fp = int(np.count_nonzero(fp_mask))
    tn = int(np.count_nonzero(tn_mask))
    fn = int(np.count_nonzero(fn_mask))

    # Tint in float, then cast back to uint8; true negatives stay unchanged.
    image = np.asarray(img)
    overlay = image.astype('float64')
    overlay[tp_mask] = overlay[tp_mask] * [0.33, 1, 0.33]
    overlay[fp_mask] = overlay[fp_mask] * [1, 0.33, 0.33]
    overlay[fn_mask] = overlay[fn_mask] * [0.33, 0.33, 1]
    mask_img = overlay.astype('uint8')

    plt.imshow(mask_img)
    plt.show()

    if write_files:
        io.imsave(join(path, '7-splice-validation.png'), mask_img)

    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)
    f_1 = ratio(tp, tp + ((fp + fn) / 2))

    rows, cols = image.shape[:2]

    # One list drives both the console and the file output; the embedded
    # '\n' entries produce the blank lines between the sections.
    report = [
        f'rows: {rows}',
        f'cols: {cols}',
        f'total: {rows * cols}\n',
        f'true positives: {tp}',
        f'false positives: {fp}',
        f'true negatives: {tn}',
        f'false negatives: {fn}\n',
        f'precision: {precision}',
        f'recall: {recall}',
        f'f_1 score: {f_1}',
    ]

    for line in report:
        print(line)

    if write_files:
        # Gated on write_files like the other artefacts; the output
        # directory is only created when that flag is set.
        with open(join(path, '8-metrics.txt'), 'w') as metrics_f:
            metrics_f.writelines(f'{line}\n' for line in report)

# Compare the detection against the ground truth and report the scores.
get_metrics(splice_marked, splice_detected)