## Filter image patches to handle class imbalance

#### Remove a portion of the image patches that contain all 0s and no 1s (no signal)

In [None]:
from dotenv import load_dotenv

import numpy as np
import matplotlib.pyplot as plt
import math
import os
import glob
import random

load_dotenv()
%matplotlib inline

In [None]:
# Base directories for images and masks, read from the environment
# (populated from the .env file by load_dotenv()).
im_dir = os.environ.get('IMAGE_DIR')
mask_dir = os.environ.get('MASK_DIR')

# Suffixes appended to the base directories to select the input (sliced)
# and output (filtered) patch folders.
in_postfix = os.environ.get('DIR_LABEL_SLICED')
out_postfix = os.environ.get('DIR_LABEL_FILTERED')

# Image patches: where they are read from and where the kept ones are written.
in_im_dir = im_dir + in_postfix
out_im_dir = im_dir + out_postfix

# Mask patches: same layout as the image patches.
in_mask_dir = mask_dir + in_postfix
out_mask_dir = mask_dir + out_postfix

# NOTE(review): not referenced in the cells shown here; presumably the patch
# side length in pixels -- confirm against the slicing notebook.
n = 128

In [None]:
def keep_sample(mask, keep_proba):
    """Randomly decide whether a training example should be kept.

    A single value is drawn uniformly from [0, 1]; the sample is dropped
    only when that draw is strictly greater than ``keep_proba``.

    Args:
        mask: The sample's mask. It is accepted for the caller's convenience
            but is not inspected by this function.
        keep_proba: Threshold the random draw is compared against; larger
            values keep more samples.

    Returns:
        bool: True to keep the sample, False to discard it.
    """
    draw = random.uniform(0, 1)
    # Negating the "drop" condition (rather than testing `<=`) keeps the
    # result a plain bool and identical for every comparable threshold.
    return not (draw > keep_proba)

In [None]:
# Filter the sliced patches: every sample whose mask contains at least one
# non-zero pixel is copied to the output directories, while samples whose mask
# is entirely zero are kept only with a small probability (class balancing).
im_filenames = glob.glob(in_im_dir+'/*.npy', recursive=True)
mask_filenames = glob.glob(in_mask_dir+'/*.npy', recursive=True)

num_samples = len(im_filenames)

# Probability with which an all-zero (no signal) sample is kept.
EMPTY_KEEP_PROBA = 0.03

# Create output directories; exist_ok avoids the check-then-create race and
# makes the cell safe to re-run.
os.makedirs(out_im_dir, exist_ok=True)
os.makedirs(out_mask_dir, exist_ok=True)

samples_with_zeros = 0  # all-zero samples that were kept
samples_with_ones = 0   # samples with at least one non-zero mask pixel

for i, im_filename in enumerate(im_filenames):
    im = np.load(im_filename)
    # basename handles whatever path separator glob returned on this platform
    filename = os.path.basename(im_filename)
    # the mask is looked up by the image's file name in the mask directory
    mask = np.load(os.path.join(in_mask_dir, filename))
    if i % 1000 == 0:
        print('processing sample',i,'/',num_samples)
        print('0:',samples_with_zeros,'1:',samples_with_ones)

    if np.count_nonzero(mask) == 0:  # sample contains all 0s
        # keep only a random fraction of the empty samples
        if not keep_sample(mask, EMPTY_KEEP_PROBA):
            continue
        samples_with_zeros += 1
    else:
        samples_with_ones += 1

    # Reaching here means the sample is kept: write image and mask under the
    # same file name in their respective output directories.
    np.save(os.path.join(out_im_dir, filename), im)
    np.save(os.path.join(out_mask_dir, filename), mask)