In [None]:
import copy
import glob
from imagecorruptions import corrupt, get_corruption_names
from itertools import repeat
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
from pathlib import Path
from PIL import Image
import random
from scipy import stats
from sklearn.neighbors import KernelDensity as KDE
from pathos.multiprocessing import ProcessingPool as Pool
import shutil
from tqdm import tqdm
import time
import warnings

In [None]:
def mutator_group(group, I, I_dist_samples_array):
    """Return a copy of ``I`` with every pixel in ``group`` replaced by a KDE draw.

    For each of the first three channels a kernel density model (bandwidth 1)
    is fitted to the values that the group's first pixel takes along axis 0 of
    ``I_dist_samples_array``. Every pixel listed in ``group`` is then
    overwritten with an independent draw from the model of its channel.

    Parameters
    ----------
    group : sequence of (i, j)
        Pixel coordinates to overwrite. The first entry is the representative
        pixel whose sample values the models are fitted to.
    I : numpy.ndarray
        Source image, indexed as ``[i, j, channel]``. It is not modified.
    I_dist_samples_array : numpy.ndarray
        Stack indexed as ``[sample, i, j, channel]``.

    Returns
    -------
    numpy.ndarray
        New array with the shape and dtype of ``I``. An empty ``group`` yields
        an unchanged copy.
    """
    I_prime = I.copy()

    group = list(group)
    if not group:
        # Nothing to mutate; indexing group[0] below would raise IndexError.
        return I_prime

    coords = np.asarray(group, dtype=np.intp)
    rows, cols = coords[:, 0], coords[:, 1]
    rep_i, rep_j = group[0]  # Use the first point in the group as representative

    # KDE draws are unbounded floats. Writing them into an integer image
    # without clamping wraps out-of-range values (e.g. -1.2 -> 255 in uint8),
    # so the valid range of an integer dtype is enforced before the write.
    if np.issubdtype(I_prime.dtype, np.integer):
        dtype_info = np.iinfo(I_prime.dtype)
        lo, hi = dtype_info.min, dtype_info.max
    else:
        lo = hi = None

    for channel in range(3):
        # Values of the representative pixel across all samples
        channel_values = I_dist_samples_array[:, rep_i, rep_j, channel]

        # Fit distribution
        kde_model = KDE(bandwidth=1).fit(channel_values.reshape(-1, 1))

        # One batched draw for the whole group instead of one call per pixel
        samples = np.asarray(kde_model.sample(len(group))).reshape(-1)
        if lo is not None:
            samples = np.clip(samples, lo, hi)

        # Save samples in their locations
        I_prime[rows, cols, channel] = samples

    return I_prime

In [None]:
def group_indices_by_variance(num_mut_pts_i, num_mut_pts_j, I, I_dist_samples_array, threshold=5):
    """Bucket pixel coordinates by the mean per-channel variance of their samples.

    For every coordinate pair, the variance along axis 0 of
    ``I_dist_samples_array`` is computed for channels 0, 1 and 2. The three
    variances are averaged and the pixel is assigned to bucket
    ``int(average / threshold)``.

    Parameters
    ----------
    num_mut_pts_i, num_mut_pts_j : iterable of int
        Paired first-axis / second-axis pixel indices.
    I : any
        Accepted for signature compatibility; not read by this function.
    I_dist_samples_array : numpy.ndarray
        Stack indexed as ``[sample, i, j, channel]``.
    threshold : number, optional
        Width of one variance bucket.

    Returns
    -------
    list of list of (i, j)
        One list per bucket, buckets ordered by first appearance and points
        kept in input order within each bucket.
    """
    buckets = {}
    for point in zip(num_mut_pts_i, num_mut_pts_j):
        row, col = point
        per_channel = [np.var(I_dist_samples_array[:, row, col, c]) for c in range(3)]
        mean_variance = (per_channel[0] + per_channel[1] + per_channel[2]) / 3
        buckets.setdefault(int(mean_variance / threshold), []).append(point)
    return list(buckets.values())

In [None]:
# Main function
def main(train_sample_lst,start,step,dist_samples,tau):
    """Create and save one mutated image for each path in the selected slice.

    For every image in ``train_sample_lst[start:start+step]``:

    1. ``dist_samples`` corrupted copies are produced with ``corrupt``.
    2. ``int(tau * n_pixels)`` distinct pixel positions are drawn uniformly.
    3. The positions are grouped with ``group_indices_by_variance`` and each
       group is mutated by ``mutator_group`` in a process pool.
    4. The mutated pixels of every group are merged into one image, which is
       written as PNG below ``path_to_save_distorted_samples/<corruption>/``.

    Parameters
    ----------
    train_sample_lst : list of str
        Image file paths.
    start : int
        Index of the first path to process.
    step : int
        Number of paths to process starting at ``start``.
    dist_samples : int
        Number of corrupted copies generated per image (must be >= 1).
    tau : float
        Fraction of the image's pixels to mutate.

    Raises
    ------
    ValueError
        If ``dist_samples`` is smaller than 1.
    """
    if dist_samples < 1:
        raise ValueError('dist_samples must be at least 1')

    for corruption in ['gaussian_noise']:
        for severity in [2]:
            sav_path = 'path_to_save_distorted_samples' + '/' + corruption + '/'
            Path(sav_path).mkdir(parents=True, exist_ok=True)

            for img_name in tqdm(train_sample_lst[start:start+step]):
                # Force three channels so the [:, :, c] indexing used below is
                # valid for grayscale / RGBA / palette files as well; the
                # context manager releases the file handle right away.
                with Image.open(img_name) as pil_img:
                    I = np.asarray(pil_img.convert('RGB'))

                # corrupt() is deliberately kept at severity + 1, as before.
                I_dist_samples_array = np.array([
                    corrupt(I, corruption_name=corruption, severity=severity+1)
                    for _ in range(dist_samples)
                ])

                # Uniformly draw num_mut distinct pixel positions.
                n_rows, n_cols = I.shape[:2]
                num_mut = int(tau*(n_rows*n_cols))
                flat_idx = np.random.choice(n_rows*n_cols, size=num_mut, replace=False)
                num_mut_pts_i, num_mut_pts_j = np.unravel_index(flat_idx, (n_rows, n_cols))

                groups = group_indices_by_variance(num_mut_pts_i, num_mut_pts_j, I, I_dist_samples_array)
                with Pool() as p:
                    mutated_parts = p.map(lambda group: mutator_group(group, I, I_dist_samples_array), groups)

                # Combine mutated parts. Each part is a full copy of I that
                # differs only at its own group's pixels, so those pixels are
                # copied over directly. Averaging the parts would shrink every
                # mutation by a factor of 1/len(groups) towards the original.
                I_prime = I.copy()
                for group, part in zip(groups, mutated_parts):
                    pts = np.asarray(group)
                    rows, cols = pts[:, 0], pts[:, 1]
                    I_prime[rows, cols] = np.asarray(part)[rows, cols]

                # I_prime is uint8 in 0..255, which imsave writes unchanged;
                # no additional rescaling must be applied.
                if 'put/' in img_name:
                    rel_name = img_name.split('put/')[1]
                else:
                    # Paths without the 'put/' marker fall back to the file name.
                    rel_name = os.path.basename(img_name)
                img = rel_name.split('.jpg')[0] + '.png'
                print(img)

                out_path = os.path.join(sav_path, img)
                # rel_name may contain sub-directories that do not exist yet.
                Path(out_path).parent.mkdir(parents=True, exist_ok=True)
                plt.imsave(out_path, I_prime)

In [None]:
# Preprocessing step and main invocation
# Read one image path per line. Blank lines (e.g. a trailing newline at the
# end of the file) are skipped so an empty string is never passed on as a path.
with open('input_images_list.txt') as f:
    train_sample_lst = [line.strip() for line in f if line.strip()]

start = 0  # Index in input_images_list.txt of the first image to process.
step = 1110  # Number of images to process from `start`; main() slices [start:start+step].
dist_samples = 100  # Number of distorted samples to use for distortion estimation
tau = 0.75  # Fraction of each image's pixels to distort

# Run main part of code
main(train_sample_lst,start,step,dist_samples,tau)