<a href="https://colab.research.google.com/github/jsk245/MS-SSIM_L1_loss/blob/main/MSSSIML1_loss_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import jax
import jax.numpy as jnp

In [None]:
def MSSSIML1_loss(image1, image2, sigmas=(0.5, 1., 2., 4., 8.), filter_size=11, C1=.01, C2=.03, alpha=0.84):
  """Return the mixed MS-SSIM + Gaussian-weighted L1 loss of two image batches.

  Args:
    image1, image2: arrays of the same shape (N, 1, H, W, C) holding floats in
      the range [-1.0, 1.0]; they are rescaled to [0, 1] internally.
    sigmas: standard deviations of the Gaussian windows, one per scale.
    filter_size: side length of the square Gaussian window; must be a positive
      odd integer. It is a static jit argument, so each distinct value
      compiles its own version of the function.
    C1, C2: stabilising constants of the luminance and contrast-structure
      terms; both are squared before use.
    alpha: weight of the MS-SSIM term; the L1 term is weighted by 1 - alpha.

  Returns:
    A scalar: alpha * (1 - mean(l_last * prod(cs))) + (1 - alpha) * mean(G * |diff|),
    where l_last is the luminance term of the last sigma, cs the
    contrast-structure terms of all sigmas and G the Gaussian window of the
    last sigma.

  Raises:
    ValueError: if filter_size is not a positive odd integer.
  """
  # filter_size is static, i.e. a plain Python value here, so it can be
  # validated while tracing. An even size would not match the symmetric
  # tap range built below.
  if filter_size < 1 or filter_size % 2 == 0:
    raise ValueError(f"filter_size must be a positive odd integer, got {filter_size}")

  c1 = C1**2
  c2 = C2**2
  batch = image1.shape[0]
  height = image1.shape[2]
  width = image1.shape[3]
  channels = image1.shape[4]

  def to_planes(image):
    # Drop the singleton axis, move channels next to the batch axis, rescale
    # [-1, 1] -> [0, 1], then fold the channels into the batch so that every
    # channel is filtered on its own as a 1-channel NCHW image.
    image = jnp.moveaxis(jnp.reshape(image, (batch, height, width, channels)), 3, 1)
    image = (image + 1) / 2
    return jnp.reshape(image, (batch * channels, 1, height, width))

  image1 = to_planes(image1)
  image2 = to_planes(image2)
  diff = jnp.abs(image2 - image1)

  # One normalised 2-D Gaussian window per sigma, stacked as OIHW filters of
  # shape (num_scales, 1, filter_size, filter_size).
  offsets = jnp.arange(-(filter_size // 2), filter_size // 2 + 1)
  windows = []
  for sigma in sigmas:
    taps = jnp.exp(-1. * offsets**2 / (2 * sigma**2))
    window = taps[:, None] * taps[None, :]
    windows.append(window / jnp.sum(window))
  w = jnp.stack(windows)[:, None, :, :]

  def blur(x):
    # Gaussian-weighted local average of x at every scale.
    return jax.lax.conv(x, w, (1, 1), "SAME")

  mux = blur(image1)
  muy = blur(image2)
  sigmax2 = blur(image1**2) - mux**2
  sigmay2 = blur(image2**2) - muy**2
  sigmaxy = blur(image1 * image2) - mux * muy
  lum = (2 * mux * muy + c1) / (mux**2 + muy**2 + c1)
  cs = (2 * sigmaxy + c2) / (sigmax2 + sigmay2 + c2)

  # Contrast-structure terms are multiplied across scales; the luminance term
  # is taken from the last sigma only.
  Pcs = jnp.prod(cs, axis=1)
  # The absolute difference is smoothed with the window of the last sigma.
  l1_loss = jnp.mean(jax.lax.conv(diff, w[-1:], (1, 1), "SAME"))

  return alpha * (1 - jnp.mean(lum[:, -1, :, :] * Pcs)) + (1 - alpha) * l1_loss


# filter_size determines array shapes, so it has to be a compile-time constant:
# mark it static instead of letting jit trace it (a traced value cannot be
# used as a shape, which made any explicit filter_size fail).
MSSSIML1_loss = jax.jit(MSSSIML1_loss, static_argnames=("filter_size",))

In [None]:
def MSSSIML1_loss_vectorized(image1, image2, sigmas=(0.5, 1., 2., 4., 8.), filter_size=11, C1=.01, C2=.03, alpha=0.84):
  """Return the mixed MS-SSIM + Gaussian-weighted L1 loss of two image batches.

  Same computation as the loop-based variant, but the Gaussian windows of all
  scales are built with one broadcasted expression instead of a Python loop.

  Args:
    image1, image2: arrays of the same shape (N, 1, H, W, C) holding floats in
      the range [-1.0, 1.0]; they are rescaled to [0, 1] internally.
    sigmas: standard deviations of the Gaussian windows, one per scale.
    filter_size: side length of the square Gaussian window; must be a positive
      odd integer. It is a static jit argument, so each distinct value
      compiles its own version of the function.
    C1, C2: stabilising constants of the luminance and contrast-structure
      terms; both are squared before use.
    alpha: weight of the MS-SSIM term; the L1 term is weighted by 1 - alpha.

  Returns:
    A scalar: alpha * (1 - mean(l_last * prod(cs))) + (1 - alpha) * mean(G * |diff|),
    where l_last is the luminance term of the last sigma, cs the
    contrast-structure terms of all sigmas and G the Gaussian window of the
    last sigma.

  Raises:
    ValueError: if filter_size is not a positive odd integer.
  """
  # filter_size is static, i.e. a plain Python value here, so it can be
  # validated while tracing. An even size would not match the symmetric
  # tap range built below.
  if filter_size < 1 or filter_size % 2 == 0:
    raise ValueError(f"filter_size must be a positive odd integer, got {filter_size}")

  c1 = C1**2
  c2 = C2**2
  batch = image1.shape[0]
  height = image1.shape[2]
  width = image1.shape[3]
  channels = image1.shape[4]

  def to_planes(image):
    # Drop the singleton axis, move channels next to the batch axis, rescale
    # [-1, 1] -> [0, 1], then fold the channels into the batch so that every
    # channel is filtered on its own as a 1-channel NCHW image.
    image = jnp.moveaxis(jnp.reshape(image, (batch, height, width, channels)), 3, 1)
    image = (image + 1) / 2
    return jnp.reshape(image, (batch * channels, 1, height, width))

  image1 = to_planes(image1)
  image2 = to_planes(image2)
  diff = jnp.abs(image2 - image1)

  # 1-D Gaussian taps for every sigma at once: shape (num_scales, filter_size).
  offsets = jnp.arange(-(filter_size // 2), filter_size // 2 + 1)
  two_variances = 2 * jnp.array(sigmas)**2
  taps = jnp.exp((-1. * offsets**2)[None, :] / two_variances[:, None])
  # Outer product per scale gives the 2-D windows; each is normalised to sum
  # to 1 and laid out as OIHW filters (num_scales, 1, filter_size, filter_size).
  w = taps[:, :, None] * taps[:, None, :]
  w = w / jnp.sum(w, axis=(1, 2), keepdims=True)
  w = w[:, None, :, :]

  def blur(x):
    # Gaussian-weighted local average of x at every scale.
    return jax.lax.conv(x, w, (1, 1), "SAME")

  mux = blur(image1)
  muy = blur(image2)
  sigmax2 = blur(image1**2) - mux**2
  sigmay2 = blur(image2**2) - muy**2
  sigmaxy = blur(image1 * image2) - mux * muy
  lum = (2 * mux * muy + c1) / (mux**2 + muy**2 + c1)
  cs = (2 * sigmaxy + c2) / (sigmax2 + sigmay2 + c2)

  # Contrast-structure terms are multiplied across scales; the luminance term
  # is taken from the last sigma only.
  Pcs = jnp.prod(cs, axis=1)
  # The absolute difference is smoothed with the window of the last sigma.
  l1_loss = jnp.mean(jax.lax.conv(diff, w[-1:], (1, 1), "SAME"))

  return alpha * (1 - jnp.mean(lum[:, -1, :, :] * Pcs)) + (1 - alpha) * l1_loss


# filter_size determines array shapes, so it has to be a compile-time constant:
# mark it static instead of letting jit trace it (a traced value cannot be
# used as a shape, which made any explicit filter_size fail).
MSSSIML1_loss_vectorized = jax.jit(MSSSIML1_loss_vectorized, static_argnames=("filter_size",))

In [None]:
# Two constant sample batches shaped like the loss inputs (N, 1, H, W, C):
# one filled with 0.5 and its negation filled with -0.5.
img1 = 0.5 * jnp.ones((4, 1, 224, 224, 3))
img2 = -img1

In [None]:
% timeit MSSSIML1_loss(img1,img2)

In [None]:
% timeit MSSSIML1_loss_vectorized(img1,img2)

In [None]:
# Evaluate the vectorized loss once on the two sample batches (bare expression,
# so the value is this notebook cell's output).
MSSSIML1_loss_vectorized(img1,img2)