In [1]:
import torch
import numpy as np
import time
import cv2
import os

In [2]:
rng = np.random.default_rng(1)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cpu = 'cpu'

In [3]:
folder = "./data/"

img_files = [os.path.join(folder, f) for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))]

bases = []
actives = []

for _file in img_files:
    img = cv2.imread(_file, cv2.IMREAD_GRAYSCALE).astype(np.uint8)
    rnd = (rng.random(img.shape, dtype = np.float32).astype(np.uint8) * 255).astype(np.uint8)
    bases.append(img)
    actives.append(rnd)

In [4]:
torch.manual_seed(0)

def mat_runner(bases, actives, f):
    total_time = 0
    for i in range(len(bases)):
        b = torch.from_numpy(bases[i]).to(dtype=torch.uint8).to(cpu)
        a = torch.from_numpy(actives[i]).to(dtype=torch.uint8).to(cpu)
        
        start_time = time.perf_counter()
        b = b.to(device)
        a = a.to(device)
        res = f(b, a)
        res = res.to(cpu)
        end_time = time.perf_counter()

        del a
        del b
        total_time += (end_time - start_time) * 1000
    return total_time

def mat_runner_float(bases, actives, f):
    total_time = 0
    for i in range(len(bases)):
        base = bases[i].astype(np.float32)
        active = actives[i].astype(np.float32)
        b = torch.from_numpy(base).to(dtype=torch.float32).to(cpu)
        a = torch.from_numpy(active).to(dtype=torch.float32).to(cpu)
        opacity = torch.from_numpy(rng.random(1, dtype = np.float32)).to(dtype=torch.float32).to(cpu)

        start_time = time.perf_counter()
        b = b.to(device)
        a = a.to(device)
        opacity = opacity.to(device)
        res = f(b, a, opacity)
        res = res.to(cpu)
        end_time = time.perf_counter()
        
        del a
        del b
        total_time += (end_time - start_time) * 1000
    return total_time

def vec_runner_int(bases, actives, f):
    total_time = 0
    for i in range(len(bases)):
        base = bases[i].flatten()
        active = actives[i].flatten()
        b = torch.from_numpy(base).to(dtype=torch.uint8).to(cpu)
        a = torch.from_numpy(active).to(dtype=torch.uint8).to(cpu)
        opacity = torch.from_numpy(rng.random(1, dtype = np.float32).astype(np.uint8)).to(dtype=torch.uint8).to(cpu)

        start_time = time.perf_counter()
        b = b.to(device)
        a = a.to(device)
        opacity = opacity.to(device)
        res = f(b, a, opacity)
        res = res.to(cpu)
        end_time = time.perf_counter()

        del a
        del b
        total_time += (end_time - start_time) * 1000
    return total_time

def vec_runner_float(bases, actives, f):
    total_time = 0
    for i in range(len(bases)):
        base = bases[i].flatten().astype(np.float32)
        active = actives[i].flatten().astype(np.float32)
        b = torch.from_numpy(base).to(dtype=torch.float32).to(cpu)
        a = torch.from_numpy(active).to(dtype=torch.float32).to(cpu)
        opacity = torch.from_numpy(rng.random(1, dtype = np.float32)).to(dtype=torch.float32).to(cpu)

        start_time = time.perf_counter()
        b = b.to(device)
        a = a.to(device)
        opacity = opacity.to(device)
        res = f(b, a, opacity)
        res = res.to(cpu)
        end_time = time.perf_counter()
        
        del a
        del b
        total_time += (end_time - start_time) * 1000
    return total_time



def timer(input1, input2, f, runner):
    runs = 10
    times = []
    for _ in range(runs):
        times.append(runner(input1, input2, f))
    times = np.array(times)
    print(f"{f.__name__}_with_load")
    print(f"{np.average(times)}ms +/- {np.std(times)}ms")

In [5]:
### PYTORCH

In [None]:
def dissolve_blend_8_torch(base, active, opacity):
    return torch.where(torch.greater_equal(opacity - ((torch.randint(1, 2147483647, base.shape, dtype=torch.int32, device='cuda') % 100) + 1) / 100, 0), active, base)  

In [6]:
def darken_blend_8_torch(base, active):
  return torch.where(torch.greater(base, active), active, base)

In [7]:
def color_burn_8_torch(base, active):
  return torch.where(torch.eq(active, 0), 255, 255 - (255 - base) // active)

In [8]:
def lighten_blend_8_torch(base, active):
  return torch.where(torch.less(base, active), active, base)

In [9]:
def color_dodge_8_torch(base, active):
  return torch.where(torch.eq(active, 255), 255, base // (255 - active))

In [10]:
def overlay_blend_8_torch(base, active):
  return torch.where(torch.greater_equal(base, 128), 2 * base + base - 2 * base * base // 255 - 128, 2 * base * base // 128)

In [11]:
def multiply_blend_8_torch(base, active):
  return base * active // 255

In [12]:
def linear_burn_8_torch(base, active):
  return base + active - 255

In [13]:
def screen_blend_8_torch(base, active):
  return base + active - base * active // 255

In [14]:
def linear_dodge_8_torch(base, active):
  return base + active

In [16]:
def normal_blend_f_torch(base, active, opacity):
  return opacity * active + (1-opacity)*base

In [17]:
def normal_blend_8_torch(base, active, opacity):
  return opacity * active + (255 - opacity) * base

In [18]:
timer(bases, actives, darken_blend_8_torch, mat_runner)

darken_blend_8_torch_with_load
886.917199101299ms +/- 25.179270121631575ms


In [19]:
timer(bases, actives, color_burn_8_torch, mat_runner)

color_burn_8_torch_with_load
996.5826189611107ms +/- 9.0763955395946ms


In [20]:
timer(bases, actives, lighten_blend_8_torch, mat_runner)

lighten_blend_8_torch_with_load
862.4261014629155ms +/- 1.4966379041495774ms


In [21]:
timer(bases, actives, color_dodge_8_torch, mat_runner)

color_dodge_8_torch_with_load
965.8814419992268ms +/- 2.0656653390533535ms


In [22]:
timer(bases, actives, overlay_blend_8_torch, mat_runner)

overlay_blend_8_torch_with_load
1149.721364909783ms +/- 2.5659686212909647ms


In [23]:
timer(bases, actives, multiply_blend_8_torch, mat_runner)

multiply_blend_8_torch_with_load
871.0549494251609ms +/- 3.5123073793141035ms


In [24]:
timer(bases, actives, linear_burn_8_torch, mat_runner)

linear_burn_8_torch_with_load
868.9940991811454ms +/- 2.658514464542843ms


In [25]:
timer(bases, actives, screen_blend_8_torch, mat_runner)

screen_blend_8_torch_with_load
921.2759136687964ms +/- 3.3708129938868927ms


In [26]:
timer(bases, actives, linear_dodge_8_torch, mat_runner)

linear_dodge_8_torch_with_load
836.9474723469466ms +/- 5.603529521880382ms


In [27]:
timer(bases, actives, normal_blend_f_torch, vec_runner_float)

normal_blend_f_torch_with_load
2458.7184565141797ms +/- 39.03333054253607ms


In [28]:
timer(bases, actives, normal_blend_8_torch, vec_runner_int)

normal_blend_8_torch_with_load
839.7484925109893ms +/- 4.779726011336703ms


In [None]:
timer(bases, actives, dissolve_blend_8_torch, mat_runner_float)