In [None]:
import math
import random
import torch
import torch.nn as nn
import torchvision.transforms as T
import matplotlib.pyplot as plt

from PIL import Image

In [None]:
class RandomErasing(nn.Module):
    """Random Erasing data augmentation (Zhong et al., 2017).

    With probability ``p``, picks one axis-aligned rectangle inside the image
    and overwrites it with uniform random values in [0, 1).

    Args:
        p: probability of erasing at all.
        sl: minimum erased area, as a fraction of the image area (H * W).
        sh: maximum erased area, as a fraction of the image area (H * W).
        r1: lower bound of the aspect ratio he / we; the upper bound is 1 / r1.
        max_attempts: number of (area, ratio) draws to try before giving up and
            returning the image unchanged. Without a bound the search can spin
            forever when no rectangle of the requested size fits the image.

    Note:
        The input tensor is modified in place; the same tensor is returned.
        The last two dimensions are treated as (H, W). Any leading dimensions
        (channels, batch) share the same erased rectangle.
    """

    def __init__(self, p=0.5, sl=0.02, sh=0.4, r1=0.3, max_attempts=100):
        super().__init__()
        self.p = p
        self.sl = sl
        self.sh = sh
        self.r1 = r1
        self.max_attempts = max_attempts

    def forward(self, image):
        """Erase one random rectangle of ``image`` with probability ``self.p``.

        Args:
            image: tensor whose last two dimensions are (H, W), e.g. (C, H, W).

        Returns:
            The same tensor object. It is erased in place when a rectangle was
            applied, and untouched otherwise: either the coin flip failed or no
            rectangle fit within ``max_attempts`` draws.
        """
        if random.random() >= self.p:
            return image

        h, w = image.shape[-2:]
        area = h * w
        # Sample the aspect ratio re = he / we log-uniformly, so tall (re > 1) and
        # wide (re < 1) boxes are equally likely. A plain uniform(r1, 1/r1) puts most
        # of its mass above 1 (about 77% for r1 = 0.3), which biases erasing toward
        # tall boxes. Note that |log(r1)| == |log(1/r1)|, so this range is symmetric.
        log_ratio_lo = math.log(self.r1)
        log_ratio_hi = math.log(1.0 / self.r1)

        for _ in range(self.max_attempts):
            se = random.uniform(self.sl, self.sh) * area
            re = math.exp(random.uniform(log_ratio_lo, log_ratio_hi))
            he = int(round(math.sqrt(se * re)))
            we = int(round(math.sqrt(se / re)))
            if he > h or we > w:
                continue  # this size cannot fit anywhere; draw a new size
            # Inclusive bounds: a box may touch the right/bottom edge or cover the
            # whole image.
            ye = random.randint(0, h - he)
            xe = random.randint(0, w - we)
            fill = torch.rand(image.shape[:-2] + (he, we), device=image.device)
            image[..., ye:ye + he, xe:xe + we] = fill
            return image
        return image


In [None]:
# Load the sample image; the relative path 'cat.png' is resolved against the kernel's working directory.
image = Image.open('cat.png')
# Show the original so it can be compared with the randomly erased version.
display(image)

In [None]:
# Seed both RNGs that RandomErasing draws from (`random` for the coin flip and box,
# `torch` for the fill values) so Restart & Run All reproduces the same output.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

tensor = T.ToTensor()(image)
# p=1.0 so the demo always shows an erased box; the default p=0.5 would return the
# image unchanged about half the time. RandomErasing writes into `tensor` in place,
# so the PIL `image` shown above is unaffected.
# NOTE(review): if cat.png has an alpha channel, the erased box also gets random alpha values — confirm.
tensor = RandomErasing(p=1.0)(tensor)
output = T.ToPILImage()(tensor)
display(output)
