In [None]:
import glob
import os
import random
import warnings

import cv2
import numpy as np
from PIL import Image
from torchvision.transforms import v2

from utils import load_images
from utils import add_patch, mult_patch, rand_ele_mix, rand_pix_mix
warnings.simplefilter('ignore')

### Load the Fractals Here

In [30]:
# Argument handed to utils.load_images to locate the fractal mixing images.
# NOTE(review): presumably a directory relative to the notebook - confirm
# against utils.load_images.
FRACTAL_SOURCE = 'fractals'
fractal_tensor = load_images(FRACTAL_SOURCE)

### Load the training dataset

- The dataset is laid out as `name/class/image`, the usual folder layout for image-classification training data.
- If your training folders are arranged differently, the loading code below must be adapted.



In [31]:
# Build `data`: a list of [image, label] pairs where the image is an RGB array
# resized to IMAGE_SIZE and the label is the name of the file's parent folder.
DATA_GLOB = 'Potatoes/*/*.png'
IMAGE_SIZE = (256, 256)

data = []
# Sort the paths so the order of `data` is reproducible between runs;
# glob itself gives no ordering guarantee.
for filename in sorted(glob.glob(DATA_GLOB)):
    im = cv2.imread(filename)
    if im is None:
        # cv2.imread signals an unreadable file by returning None rather than
        # raising. Testing for None (instead of np.any) keeps valid all-black
        # images in the dataset.
        continue
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    im = cv2.resize(im, IMAGE_SIZE)
    # Take the class from the parent directory using os.path so it works with
    # both '/' and '\\' path separators.
    label = os.path.basename(os.path.dirname(filename))
    data.append([im, label])

- Define the image-level transforms here, using the transforms available in `torchvision.transforms.v2`.

In [32]:
# Pool of whole-image augmentations; ipmix draws one at random per step.
# The flips use p=1, so they are always applied when selected.
# NOTE(review): ColorJitter() is built with its default arguments - confirm
# that this yields the amount of jitter that is actually intended.
img_level_transforms = [
    v2.RandomHorizontalFlip(p=1),
    v2.RandomVerticalFlip(p=1),
    v2.RandomRotation(degrees=100),
    v2.ColorJitter(),
    v2.Grayscale(num_output_channels=3),
    v2.GaussianBlur(kernel_size=53),
    v2.ElasticTransform(alpha=300., sigma=7.),
]

# Pool of patch-level mixing functions (imported from utils); ipmix calls them
# as func(image, fractal, patch_size).
p_level_transforms = [
    add_patch,
    mult_patch,
    rand_pix_mix,
    rand_ele_mix,
]

In [33]:
def _apply_image_op(img, img_level_transforms):
    """Apply one randomly chosen image-level transform to ``img``.

    ``img`` is converted with ``v2.ToTensor`` before the transform runs; the
    result is multiplied by 255 and returned as a numpy array with the channel
    axis moved back to the last position.
    """
    i_func = random.choice(img_level_transforms)
    pipe = v2.Compose([v2.ToTensor(), i_func])
    aug_img = pipe(img) * 255
    return aug_img.numpy().transpose(1, 2, 0)


def ipmix(img,frac=fractal_tensor,aug_method=[0,1],p_size=[4, 8, 16, 32, 64, 128, 256],
    p_level_transforms=p_level_transforms,img_level_transforms=img_level_transforms,
    k = 3, t= 3):
    """Return an IPMix-style augmented copy of ``img``.

    ``k`` augmentation chains are built from the input image, combined with
    Dirichlet-sampled weights, and the combination is blended with the
    original image using a Beta(1, 1)-sampled factor rounded to one decimal.

    Parameters
    ----------
    img : numpy.ndarray
        Source image; it is not modified.
        NOTE(review): assumed to be an HWC uint8 array as produced by the
        loading cell - confirm for other callers.
    frac : sequence
        Collection of mixing images; one is drawn at random per call and
        passed to the patch-level functions.
    aug_method : sequence of int
        Chain types to draw from: 0 = image-level operations only,
        1 = each step is either a patch-level mix or an image-level operation
        (50/50).
    p_size : sequence of int
        Candidate patch sizes handed to the patch-level functions.
    p_level_transforms : sequence of callables
        Patch-level functions, called as ``func(image, fractal, patch_size)``.
    img_level_transforms : sequence of callables
        Image-level transforms applied after ``v2.ToTensor``.
    k : int
        Number of chains that are mixed together (>= 1).
    t : int
        Maximum number of operations per chain (>= 1); the actual depth of
        each chain is drawn uniformly from 1..t.

    Returns
    -------
    numpy.ndarray
        uint8 array with the same shape as ``img``.

    Raises
    ------
    ValueError
        If ``k`` or ``t`` is smaller than 1, or if a chain type other than
        0 or 1 is drawn from ``aug_method``.
    """
    if k < 1 or t < 1:
        raise ValueError("k and t must both be at least 1")

    og_img = img.astype(np.float64)
    # Draw the mixing image from the `frac` argument; previously the argument
    # was ignored and the module-level fractal_tensor was always used.
    fractal = random.choice(frac)
    mixin_w = np.random.dirichlet(np.ones(k))
    m = np.round(np.random.beta(1, 1), 1)
    x_mix = np.zeros(img.shape, dtype=np.float64)

    for i in range(k):
        # Fresh copy per chain so a step that works in place cannot leak into
        # the next chain or into the caller's image.
        img_copy = img.copy()
        meth = random.choice(aug_method)
        if meth not in (0, 1):
            raise ValueError("aug_method entries must be 0 or 1, got %r" % (meth,))

        for _ in range(random.randint(1, t)):
            if meth == 1 and random.random() > 0.5:
                patch_sz = random.choice(p_size)
                p_func = random.choice(p_level_transforms)
                x_mixed = p_func(img_copy, fractal, patch_sz)
            else:
                x_mixed = _apply_image_op(img_copy, img_level_transforms)
            # NOTE(review): assumes the patch-level functions return something
            # numpy can convert, on the same 0-255 scale - confirm in utils.
            x_mixed = np.asarray(x_mixed, dtype=np.float64)
            # Feed the result into the next step so the operations of a chain
            # are composed; previously every step restarted from the pristine
            # image and only the last one had any effect.
            img_copy = np.clip(np.rint(x_mixed), 0, 255).astype(np.uint8)

        x_mix += mixin_w[i] * x_mixed

    x_ipmix = m * x_mix + (1 - m) * og_img
    # Clip before the cast: out-of-range values would otherwise wrap around
    # when converted to uint8.
    return np.clip(x_ipmix, 0, 255).astype(np.uint8)

In [None]:
from joblib import Parallel, delayed

aug_data = []


def process_image(data_item):
    """Augment one dataset entry.

    ``data_item`` is an (image, label) pair; the image is passed through
    ``ipmix`` and returned together with its unchanged label as
    ``[augmented_image, label]``.
    """
    source_image, class_label = data_item
    return [ipmix(source_image), class_label]


if __name__ == "__main__":
    # n_jobs = -1 asks joblib to use all available CPU cores.
    n_jobs = -1
    worker_pool = Parallel(n_jobs=n_jobs)
    aug_data = worker_pool(delayed(process_image)(sample) for sample in data)