<a href="https://colab.research.google.com/github/dsevero/Linear-Autoregressive-Similarity-Index/blob/main/Linear-Autoregressive-Similarity-Index.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# The Unreasonable Effectiveness of Linear Prediction as a Perceptual Metric

- Code: https://github.com/dsevero/Linear-Autoregressive-Similarity-Index
- Paper: https://arxiv.org/abs/2310.05986


# Setup (run this first)

## Clone repo

```
!git clone https://github.com/dsevero/Linear-Autoregressive-Similarity-Index.git
%cd Linear-Autoregressive-Similarity-Index
```

```
Cloning into 'Linear-Autoregressive-Similarity-Index'...
remote: Enumerating objects: 88, done.
remote: Counting objects: 100% (88/88), done.
remote: Compressing objects: 100% (75/75), done.
remote: Total 88 (delta 33), reused 41 (delta 12), pack-reused 0
Receiving objects: 100% (88/88), 1010.47 KiB | 18.37 MiB/s, done.
Resolving deltas: 100% (33/33), done.
/workspaces/Linear-Autoregressive-Similarity-Index/Linear-Autoregressive-Similarity-Index
```


## Install dependencies

```
!./install_dependencies.sh
```

```
Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Obtaining file:///workspaces/Linear-Autoregressive-Similarity-Index/Linear-Autoregressive-Similarity-Index
  Preparing metadata (setup.py) ... done
Installing collected packages: lasi
  Running setup.py develop for lasi
Successfully installed lasi-0.0.0
```
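The install log shows pip resolving JAX from the CUDA release index (`jax_cuda_releases.html`). If you expect GPU acceleration, a quick sanity check, assuming the install finished without errors, is to ask JAX which backend and devices it picked up. This uses only standard JAX calls and is not part of the LASI repository:

```python
import jax

# Backend JAX selected at startup, typically "cpu", "gpu" (CUDA) or "tpu".
backend = jax.default_backend()
# All devices visible to that backend.
devices = jax.devices()
print(f"JAX {jax.__version__}: backend={backend}, devices={devices}")
```

If this reports `cpu` on a GPU machine, the CUDA-enabled wheel was most likely not installed; LASI will still run, just more slowly.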


# How to use LASI in your own code

```python
from PIL import Image
from lasi import LASI
import jax
import jax.numpy as jnp

# Load images as (height, width, 3) uint8 arrays.
img_megg = jnp.array(Image.open('assets/megg.png').convert('RGB'))
img_dark_megg = jnp.array(Image.open('assets/dark-megg.png').convert('RGB'))
assert img_dark_megg.shape == img_megg.shape

# Compute the distance between img_megg and img_dark_megg.
lasi = LASI(img_megg.shape, neighborhood_size=10)
distance = jax.jit(lasi.compute_distance)(img_megg, img_dark_megg)
print(f'd(img_megg, img_dark_megg) = {distance}')

# Efficiently compute the distances from a reference (img_megg) to multiple images.
# compute_distance_multiple applies jax.jit internally, so no explicit jit is needed.
# Note: img_megg is uint8, so `img_megg + 20` wraps around modulo 256 (e.g. 250 -> 14)
# before jnp.clip runs; the clip does not saturate bright pixels at 255.
img_megg_offset = jnp.clip(img_megg + 20, 0, 255)
distances = lasi.compute_distance_multiple(
    ref=img_megg, p0=img_dark_megg, p1=img_megg_offset)
print(f"d(ref, p0) = {distances['p0']}")
print(f"d(ref, p1) = {distances['p1']}")
```

```
d(img_megg, img_dark_megg) = 1.369293212890625
d(ref, p0) = 1.369293212890625
d(ref, p1) = 1.3496346473693848
```
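The offset image in the example is built with `jnp.clip(img_megg + 20, 0, 255)`. Because `img_megg` has dtype `uint8`, the addition wraps around before the clip runs, so pixels above 235 become dark instead of saturating at 255. A minimal sketch of a saturating offset, assuming 8-bit RGB input like the PNGs loaded above; `brighten_saturating` is a hypothetical helper, not part of `lasi`:

```python
import jax.numpy as jnp


def brighten_saturating(img, offset):
    """Add `offset` to a uint8 image, clamping to [0, 255] instead of wrapping.

    Hypothetical helper for illustration; not part of the lasi package.
    """
    # Widen to int32 so the sum cannot overflow, clamp, then restore uint8.
    shifted = img.astype(jnp.int32) + offset
    return jnp.clip(shifted, 0, 255).astype(jnp.uint8)


pixels = jnp.array([0, 100, 235, 250], dtype=jnp.uint8)
wrapped = jnp.clip(pixels + 20, 0, 255)      # uint8 add wraps: 250 + 20 becomes 14
saturated = brighten_saturating(pixels, 20)  # 250 + 20 is clamped to 255
print(wrapped, saturated)
```

To use it in the example, pass `brighten_saturating(img_megg, 20)` as `p1`. The resulting `d(ref, p1)` will likely differ from the value printed above, since bright pixels are clamped rather than wrapped.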
