## BIRD-GP: A Synthesized Fashion MNIST Example

### Install from GitHub

In [None]:
!pip install git+https://github.com/guoxuan-ma/2022_BIRD_GP

In [5]:
import numpy as np
import matplotlib.pyplot as plt
import math
import bird_gp
import random
import torch

### Read images and generate synthesized images

In [6]:
#####################################################################
# read images
# Each row of the text files is one example: the first column is stored as
# the label and the remaining columns are the flattened pixel values.
train_data = np.loadtxt("fashion_mnist_train_example_data.txt")
test_data = np.loadtxt("fashion_mnist_test_example_data.txt")

# Split labels from pixels and divide the pixel intensities by 255.
train_lbs, train_imgs = train_data[:, 0], train_data[:, 1:] / 255
test_lbs, test_imgs = test_data[:, 0], test_data[:, 1:] / 255

# Sample counts.
n_train, n_test = train_imgs.shape[0], test_imgs.shape[0]
n = n_train + n_test

# Seed the Python, PyTorch and NumPy random number generators with the same
# experiment index so that repeated runs are reproducible.
exp = 0
for seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    seed_fn(exp)

#####################################################################
# generate images
def _quantile_band_predictors(imgs, img_shape=(28, 28), probs=(0, 0.25, 0.5, 0.75)):
    """Build the synthesized predictor images from flattened images.

    For every image the quantiles ``probs`` of its strictly positive pixels
    are computed.  The image is then split into ``len(probs)`` panels: panel
    ``k`` keeps the pixels whose value lies in ``[q_k, q_{k+1})`` (the last
    panel keeps everything ``>= q_k``) and is zero elsewhere.  The panels are
    placed side by side, giving one ``height x (len(probs) * width)`` image
    per input image, which is flattened row by row.

    Parameters
    ----------
    imgs : ndarray of shape (n_imgs, height * width)
        Flattened images, one per row.
    img_shape : tuple of int
        ``(height, width)`` of a single image.
    probs : sequence of float
        Increasing quantile levels that define the lower edge of each panel.

    Returns
    -------
    predictors : ndarray of shape (n_imgs, height * len(probs) * width)
        Flattened side-by-side panel images.
    quantiles : ndarray of shape (len(probs), n_imgs)
        Per-image quantiles of the positive pixels.
    """
    n_imgs = imgs.shape[0]
    height, width = img_shape
    n_bands = len(probs)

    # The number of positive pixels differs per image, so the quantiles are
    # computed image by image.
    quantiles = np.zeros((n_bands, n_imgs))
    for i, img in enumerate(imgs):
        foreground = img[img > 0]
        # An image without any positive pixel keeps zero quantiles (and thus
        # an all-zero predictor) instead of passing an empty array to
        # np.quantile.
        if foreground.size > 0:
            quantiles[:, i] = np.quantile(foreground, probs)

    panels = []
    for k in range(n_bands):
        # Broadcast the per-image thresholds (n_imgs, 1) against the pixels.
        in_band = imgs >= quantiles[k][:, np.newaxis]
        if k + 1 < n_bands:
            in_band &= imgs < quantiles[k + 1][:, np.newaxis]
        panels.append(np.where(in_band, imgs, 0.0).reshape(n_imgs, height, width))

    # Joining along the column axis puts the panels side by side; flattening
    # each image row-major matches stacking the panels horizontally and
    # calling reshape(-1) on the result.
    predictors = np.concatenate(panels, axis=2).reshape(
        n_imgs, height * n_bands * width
    )
    return predictors, quantiles


# Predictors are the quantile-band panel images; outcomes are the original
# images themselves.
train_predictors, train_quantiles = _quantile_band_predictors(train_imgs)
train_outcomes = train_imgs

test_predictors, test_quantiles = _quantile_band_predictors(test_imgs)
test_outcomes = test_imgs

### Generate grids

In [7]:
# Predictor images are four 28 x 28 panels placed side by side (28 x 112);
# outcome images are single 28 x 28 images.
predictor_grids = bird_gp.generate_grids([28, 4 * 28])
outcome_grids = bird_gp.generate_grids([28, 28])

### Create BIRD_GP object and fit

In [8]:
# Set up the BIRD-GP model on the predictor and outcome grids.
# See the bird_gp package for the meaning of the individual hyperparameters.
birdgp = bird_gp.BIRD_GP(
    predictor_grids=predictor_grids,
    outcome_grids=outcome_grids,
    predictor_L=50,
    outcome_L=50,
    svgd_b_lambda=100.0,
    bf_predictor_steps=10_000,
    bf_outcome_steps=10_000,
    device="cpu",
)

In [None]:
# Fit the model on the training predictors and training outcomes.
birdgp.fit(train_predictors, train_outcomes)

### Evaluation on training data

In [None]:
# Mean squared error between the fitted values and the training outcomes.
# The final expression is left bare so the notebook displays its value.
train_pred = birdgp.predict_train()
train_residuals = train_pred - train_outcomes
np.mean(train_residuals ** 2)

### Evaluation on testing data

In [None]:
# Mean squared error between the predictions and the held-out test outcomes.
# The final expression is left bare so the notebook displays its value.
test_pred = birdgp.predict_test(test_predictors)
test_residuals = test_pred - test_outcomes
np.mean(test_residuals ** 2)