Porting the original `visualization_utils.py` from PyTorch to TensorFlow:

In [None]:
import tensorflow as tf
import numpy as np
import os
import cv2
import numpy as np
from skimage.transform import rotate

In [None]:
def preprocess_image(pil_im, resize_im=True):
    """
    Processes an image for ImageNet-trained CNNs.

    The image is (optionally) resized to exactly 224x224, scaled to [0, 1],
    normalized per channel with the ImageNet mean/std, transposed from
    (H, W, C) to channels-first (C, H, W) and given a leading batch axis.

    Note: the result is channels-first (NCHW), matching the original PyTorch
    utility. Many TensorFlow/Keras models expect channels-last (NHWC)
    instead, so check what the consuming model expects.

    Args:
        pil_im (PIL.Image): Image to process. The caller's image is not
            modified. PIL images in a mode other than "RGB" (e.g. "L" or
            "RGBA") are converted to RGB first.
        resize_im (bool): Resize to 224x224 or not
    Returns:
        im_as_ten (tf.Tensor): float32 tensor of shape (1, 3, H, W);
            (1, 3, 224, 224) when resize_im is True.
    """
    # Per-channel mean and std (ImageNet), broadcast over the (H, W, C) array.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    # The normalization constants are 3-channel, so grayscale/RGBA inputs
    # must become RGB. Inputs without a PIL ``mode`` (e.g. numpy arrays) are
    # left untouched.
    if getattr(pil_im, "mode", "RGB") != "RGB":
        pil_im = pil_im.convert("RGB")

    # resize() returns a new image of exactly the requested size. thumbnail()
    # would mutate the caller's image in place, keep the aspect ratio and
    # never upscale, so it cannot guarantee a 224x224 result.
    if resize_im:
        pil_im = pil_im.resize((224, 224))

    im_as_arr = np.asarray(pil_im, dtype=np.float32)
    # Scale to [0, 1], then normalize every channel in one broadcast step.
    im_as_arr = (im_as_arr / 255.0 - mean) / std
    im_as_arr = np.transpose(im_as_arr, (2, 0, 1))  # (H, W, C) -> (C, H, W)

    im_as_ten = tf.convert_to_tensor(im_as_arr, dtype=tf.float32)

    # Add a leading batch axis: (C, H, W) -> (1, C, H, W)
    return tf.expand_dims(im_as_ten, axis=0)


In [None]:
# NOTE: VGG expects 3 input channels, whereas convert_to_grayscale returns a single channel; do not feed its output to VGG directly.
def convert_to_grayscale(im_as_arr):
    """
    Converts a channels-first 3D image (e.g. a gradient map) to grayscale.

    Absolute values are summed over the first (channel) axis. The result is
    then min-max scaled to [0, 1], using the 99th percentile as the upper
    bound; values above that percentile are clipped to 1.

    Args:
        im_as_arr (numpy array): Image with shape (C, H, W)
    Returns:
        grayscale_im (numpy array): Grayscale image with shape (1, H, W),
            values in [0, 1]
    """
    grayscale_im = np.sum(np.abs(im_as_arr), axis=0)
    im_max = np.percentile(grayscale_im, 99)
    im_min = np.min(grayscale_im)

    value_range = im_max - im_min
    # A degenerate range (e.g. a constant or all-zero map) would divide by
    # zero and produce NaN/inf. With a unit range, values equal to the
    # minimum map to 0 and any rare larger outliers saturate at 1.
    if value_range == 0:
        value_range = 1

    grayscale_im = np.clip((grayscale_im - im_min) / value_range, 0, 1)
    # Restore a leading (single) channel axis: (H, W) -> (1, H, W)
    return np.expand_dims(grayscale_im, axis=0)

In [None]:
def random_crop(image, crop_size):
    """
    Randomly crops a square region out of the input image.

    Every valid top-left offset is chosen with equal probability, including
    the last one. A crop as large as the shorter image side is therefore
    allowed.

    Args:
        image (numpy array): Input image of shape (H, W) or (H, W, C).
        crop_size (int): Side length of the square crop.

    Returns:
        numpy array: Cropped image (a view into ``image``) of shape
        (crop_size, crop_size) or (crop_size, crop_size, C).

    Raises:
        ValueError: If crop_size is negative or larger than either spatial
            dimension of the image.
    """
    height, width = image.shape[:2]
    max_size = min(height, width)
    if crop_size < 0 or crop_size > max_size:
        raise ValueError(
            f"crop_size={crop_size} must be between 0 and "
            f"min(height, width)={max_size}"
        )
    # randint's upper bound is exclusive; +1 makes the last valid offset
    # (dim - crop_size) reachable and allows crop_size == dim.
    x = np.random.randint(0, width - crop_size + 1)
    y = np.random.randint(0, height - crop_size + 1)
    return image[y:y + crop_size, x:x + crop_size]

def random_rotation(image, angle_range=(-15, 15)):
    """
    Rotates ``image`` by an angle drawn uniformly from ``angle_range``.

    The rotation is delegated to ``skimage.transform.rotate``, and pixels
    uncovered by it are filled by replicating the border (``mode='edge'``).
    rotate is called with its default ``preserve_range=False``, so the
    image goes through skimage's ``img_as_float`` conversion. Integer
    images therefore come back as floats in [0, 1].

    Args:
        image (numpy array): Input image.
        angle_range (tuple): Arguments for ``np.random.uniform``, i.e. the
            (low, high) angle bounds in degrees. Default is (-15, 15).

    Returns:
        numpy array: Rotated image.
    """
    theta = np.random.uniform(*angle_range)
    return rotate(image, theta, mode='edge')

def random_flip(image):
    """
    Flips ``image`` along exactly one axis, chosen at random.

    With probability 0.5 the image is mirrored left-right (``np.fliplr``).
    Otherwise it is mirrored up-down (``np.flipud``). One flip is always
    applied; the image is never returned unflipped.

    Args:
        image (numpy array): Input image (at least 2-D, as required by
            ``np.fliplr``).

    Returns:
        numpy array: Flipped view of ``image``.
    """
    horizontal = np.random.rand() < 0.5
    return np.fliplr(image) if horizontal else np.flipud(image)


In [None]:
def augment_image(image_path, output_dir, num_augmentations=5):
    """
    Augments an image using cropping, rotation, and flips, saving each result.

    Each augmentation goes through four steps:
    - take a random square crop of half the shorter image side;
    - rotate it randomly;
    - flip it randomly;
    - convert it back to 8-bit and write it to
      ``output_dir/augmented_image_{i}.png``.

    Args:
        image_path (str): Path to the input image.
        output_dir (str): Directory to save the augmented images (created
            if missing).
        num_augmentations (int): Number of augmentations to create (default is 5).

    Returns:
        None

    Raises:
        FileNotFoundError: If the input image cannot be read.
        ValueError: If the image is too small to crop.
        OSError: If an augmented image cannot be written.
    """
    image = cv2.imread(image_path)
    # cv2.imread does not raise on a missing/unreadable file; it returns None.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # The crop size depends only on the source image, so compute it once.
    crop_size = min(image.shape[0], image.shape[1]) // 2
    if crop_size < 1:
        raise ValueError(
            f"Image {image_path} is too small to crop (shape {image.shape})"
        )

    os.makedirs(output_dir, exist_ok=True)

    for i in range(num_augmentations):
        cropped_image = random_crop(image, crop_size=crop_size)
        rotated_image = random_rotation(cropped_image)
        flipped_image = random_flip(rotated_image)

        output_image = os.path.join(output_dir, f"augmented_image_{i}.png")
        # cv2.imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(output_image, _to_uint8(flipped_image)):
            raise OSError(f"Failed to write augmented image: {output_image}")
        print(f"Augmented image {i} saved at {output_image}")


def _to_uint8(image):
    """Return ``image`` as a C-contiguous uint8 array suitable for cv2.imwrite.

    Float input is assumed to be in [0, 1]. That is what skimage's rotate
    returns for uint8 input under its default ``preserve_range=False``. Such
    input is rescaled to [0, 255]; without this, cv2.imwrite would store
    values of 0/1 and produce an almost black image. Integer input is
    clipped to the uint8 range.
    """
    if np.issubdtype(image.dtype, np.floating):
        image = np.rint(np.clip(image, 0.0, 1.0) * 255.0)
    # The flips return negative-stride views; hand OpenCV a contiguous copy.
    return np.ascontiguousarray(np.clip(image, 0, 255).astype(np.uint8))

In [None]:
# Example usage. The call is guarded so that importing this module does not
# try to augment a placeholder path. In a notebook kernel __name__ is also
# "__main__", so the cell still runs there.
input_image = 'path/to/your/input_image.png'
output_directory = 'path/to/augmented_images'

if __name__ == "__main__":
    augment_image(input_image, output_directory)