Copyright 2022 Authors. SPDX-License-Identifier: Apache-2.0

In [None]:
#@title LICENSE
 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
import jax
import functools
import jax.numpy as jnp
import flax.linen as nn
import math

from absl import logging
from flax import linen as nn
import gin
import jax
import jax.numpy as jnp
import collections
import tensorflow as tf 
import os.path as osp
import pickle
import flax

import typing_extensions
from typing import Any, Tuple, Optional
from acme import types
import dataclasses
import optax

import numpy as onp


In [None]:
import tensorflow_datasets as tfds
import numpy as np

from clu import checkpoint

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
def _load_mnist_matrix(split: str, num_samples: int):
  """Loads up to `num_samples` MNIST images from `split` as a matrix.

  Each image is flattened and divided by 255, and the result is transposed so
  that every column is one image (`784 x n`, where `n` is `num_samples` or the
  size of the split, whichever is smaller).
  """
  ds = tfds.load('mnist:3.*.*', split=split).batch(num_samples)
  images = next(ds.as_numpy_iterator())['image']
  # Reshape by the number of images actually returned: when `num_samples`
  # exceeds the split size the first batch is smaller, and reshaping to
  # `(num_samples, -1)` would either fail or silently scramble the pixels.
  X = np.reshape(images, (images.shape[0], -1)) / 255.
  return X.T

def get_mnist_data(num_samples: int = 60000):
  """Returns mean-centered MNIST training data as a matrix.

  Args:
    num_samples: Number of training images to load.

  Returns:
    A `784 x num_samples` matrix with one flattened image per column, with the
    per-pixel mean over the loaded images subtracted.
  """
  mnist_matrix = _load_mnist_matrix('train', num_samples)
  mnist_matrix -= np.mean(mnist_matrix, axis=1, keepdims=True)
  return mnist_matrix

def get_mean_training_image(num_samples: int = 60000):
  """Returns the mean MNIST training image as a `784 x 1` column vector.

  Args:
    num_samples: Number of training images to average. Defaults to the same
      60000 images that `get_mnist_data` loads, so this mean matches the one
      used to center that data. The function no longer reads any
      notebook-global `num_samples`.
  """
  mnist_matrix = _load_mnist_matrix('train', num_samples)
  return np.mean(mnist_matrix, axis=1, keepdims=True)

@jax.jit
def f(X):
  """Computes the thin SVD of `X` and keeps only its first two factors.

  Returns:
    A `(left_singular_vectors, singular_values)` pair; the right singular
    vectors are discarded.
  """
  svd_parts = jnp.linalg.svd(X, full_matrices=False)
  return svd_parts[0], svd_parts[1]

In [None]:
# Load the whole MNIST test split as a single batch of `num_samples` images.
num_samples = 10000
ds = tfds.load('mnist:3.*.*', split='test').batch(num_samples)
data = next(ds.as_numpy_iterator())
labels = data['label']

# Flatten each image, divide by 255 and transpose so that every column of
# `test_matrix` is one image (`784 x num_samples`).
X = np.reshape(data['image'], (num_samples, -1)) / 255.
test_matrix = X.T

# Center with the mean *training* image rather than the test-set mean.
mean_training_image = get_mean_training_image()
test_matrix -= mean_training_image

In [None]:
def plot_image(image, ax=None):
  """Draws a mean-centered, flattened image with the training mean added back.

  Args:
    image: A flattened, mean-centered image (one column of the data matrices
      in this notebook). It is reshaped to 28 rows for display.
    ax: Optional matplotlib Axes to draw on. When omitted, `plt.imshow` draws
      on the current axes.

  Returns:
    The `ax` that was passed in, or the object returned by `plt.imshow` when
    `ax` is None. Tick marks are removed in both cases.
  """
  pixels = (mean_training_image[:, 0] + image).reshape(28, -1)
  if ax is not None:
    ax.imshow(pixels)
  else:
    ax = plt.imshow(pixels)
  # `.axes` resolves to the owning Axes for both an Axes and an image artist.
  for axis in (ax.axes.xaxis, ax.axes.yaxis):
    axis.set_ticks([])
  return ax

In [None]:
# Preview the first test image (with the training mean added back).
first_test_image = test_matrix[:, 0]
ax = plot_image(first_test_image)

In [None]:
#@title Load the true subspace

# TODO: the code that loads `true_subspace` is missing from this notebook.
# It must define `true_subspace` as a matrix whose columns are basis vectors
# of the reference subspace (rows index pixels, as in `test_matrix`). Only the
# first `subspace_dim` columns are used below.
# NOTE(review): presumably the leading left singular vectors of the centered
# training data, e.g. `f(get_mnist_data())[0]` -- confirm before relying on it.
subspace_dim = 16
if 'true_subspace' not in globals():
  raise NameError(
      '`true_subspace` is undefined: add the code that loads the true '
      'subspace before running this cell.')
true_subspace_d = true_subspace[:, :subspace_dim]

In [None]:
# Inspect the shape of the truncated subspace basis (first 16 columns kept).
true_subspace_d.shape

In [None]:
#@title Generate test images

# Bucket every centered test image (a column of `test_matrix`) by its label.
label_to_img = collections.defaultdict(list)
for col, digit in enumerate(labels):
  label_to_img[digit].append(test_matrix[:, col])

# Stack each bucket into an array holding one image per row.
label_to_img.update(
    {digit: np.array(imgs) for digit, imgs in label_to_img.items()})

# Keep the first `num_images_to_keep` images of every label, in label order,
# and lay them out as columns: `pixels x (num_labels * num_images_to_keep)`.
num_images_to_keep = 100
images_to_test = np.concatenate(
    [label_to_img[digit][:num_images_to_keep] for digit in sorted(label_to_img)],
    axis=0,
).T

In [None]:
@jax.jit
def solve_reconstruction(subspace, images):
  """Projects every column of `images` onto the span of `subspace`'s columns.

  Solves the least-squares problem `subspace @ coeffs ~= images` column by
  column.

  Returns:
    A `(residuals, reconstructions)` pair: the per-column squared residual
    norms reported by `jnp.linalg.lstsq`, and `subspace @ coeffs`.
  """
  coeffs, sq_residuals, _rank, _singular_values = jnp.linalg.lstsq(
      subspace, images, rcond=None)
  return sq_residuals, jnp.matmul(subspace, coeffs)

def generate_labels_to_vals(residuals, images_per_label=None):
  """Groups image indices by label, ordered from best to worst reconstruction.

  Assumes the images behind `residuals` are laid out label by label in
  contiguous runs of `images_per_label`, which is how `images_to_test` is
  built. The label of index `i` is therefore `i // images_per_label`.

  Args:
    residuals: 1-D array (NumPy or JAX) of per-image reconstruction residuals.
    images_per_label: Length of each label's run. Defaults to the
      notebook-global `num_images_to_keep`.

  Returns:
    A `collections.defaultdict(list)` mapping each label to its image indices,
    sorted by increasing residual (ties keep index order).
  """
  if images_per_label is None:
    images_per_label = num_images_to_keep
  # JAX arrays no longer provide `.to_py()`; convert to NumPy explicitly.
  # A stable sort makes tie-breaking deterministic.
  order = np.argsort(np.asarray(residuals), kind='stable')
  labels_to_vals = collections.defaultdict(list)
  for val in order.tolist():
    labels_to_vals[val // images_per_label].append(val)
  return labels_to_vals

In [None]:
# Reconstruct the selected test images from the true subspace, then rank each
# label's images by reconstruction error.
residuals, reconstructions = solve_reconstruction(
    true_subspace_d, images_to_test)
true_labels_to_vals = generate_labels_to_vals(residuals)
mean_residual = jnp.mean(residuals)
print('Residuals:', mean_residual)

In [None]:
# One row per image source: the original test images, their reconstructions
# from the true subspace, and the LiSSA reconstructions.
# NOTE(review): `lissa_reconstructions` is not defined anywhere in this
# notebook as shown; it must be computed before running this cell.
rows = [
    ('Ground \nTruth', images_to_test),
    ('True \n   Subspace', reconstructions),
    ('Lissa \n(320 pixels)', lissa_reconstructions),
]
# Position within each label's residual-sorted index list: 0 would be the
# image with the smallest residual, 1 is the runner-up.
example_rank = 1

fig, axes = plt.subplots(len(rows), 10, figsize=(10, 3.0))
for row_axes, (title, matrix_to_use) in zip(axes, rows):
  # One column per digit label.
  for label, ax in enumerate(row_axes):
    img = matrix_to_use[:, true_labels_to_vals[label][example_rank]]
    plot_image(img, ax=ax)
  row_axes[0].set_ylabel(title, size=13)

# tight_layout() recomputes every margin and spacing value, so a preceding
# subplots_adjust(...) call would have no effect; it is intentionally omitted.
fig.tight_layout()