In [1]:
import sys
sys.path.append("../birdgp")

import subprocess
import torch
import numpy as np
import mnist
import matplotlib.pyplot as plt
import math
import bird_gp
import random
%matplotlib inline
# mnist.init()

def generate_sign(n = 2000, size = 28, length_min = 13, length_max = 20, width_min = 3, width_max = 5):
    center = math.floor(size / 2)
    sign = np.random.choice((-1, 1), size = n)
    length = np.random.choice(np.arange(length_min, length_max + 1), size = n)
    width = np.random.choice(np.arange(width_min, width_max + 1), size = n)
    start = np.random.choice(np.arange(4, size - length_max - 1), size = n) - center
    x = np.tile(np.arange(size), size) - center
    y = np.repeat(np.arange(size), size) - center
    
    sign_img = np.zeros((n, size**2))
    for i in range(n):
        sign_i = sign[i]
        length_i = length[i]
        width_i = width[i]
        start_i = start[i]
        minus_i = ((x >= start_i) &
                   (x <= start_i + length_i) &
                   (y >= - math.floor(width_i / 2)) &
                   (y <= - math.floor(width_i / 2) + width_i))
        if sign_i == -1:
            sign_img[i, :] = minus_i * 1
        else:
            mid_i = (length_i + 2 * start_i) / 2
            minus_i_2 = ((x >= math.floor(mid_i - width_i / 2)) &
                         (x <= math.floor(mid_i + width_i / 2)) &
                         (y >= -math.floor(length_i / 2)) &
                         (y <= -math.floor(length_i / 2) + length_i))
            sign_img[i, :] = (minus_i | minus_i_2) * 1
    
    sign_img[sign_img > 0] = 0.9 * sign_img[sign_img > 0] + np.random.normal(scale = 0.05, size = np.sum(sign_img > 0))
    return sign_img, sign


import torch.nn as nn

class CNN(nn.Module):
    """Small convolutional binary classifier for single-channel 28x28 images.

    Architecture:
        conv 5x5 (1 -> 16, pad 2) -> ReLU -> maxpool 2   : 1x28x28  -> 16x14x14
        conv 5x5 (16 -> 32, pad 2) -> ReLU -> maxpool 2  : 16x14x14 -> 32x7x7
        linear (32*7*7 -> 1) -> sigmoid

    ``forward`` therefore returns a (batch, 1) tensor of probabilities in (0, 1).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in a fixed order: it determines both how the
        # global torch RNG is consumed at init and the state_dict key names.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Fully connected head with a single output unit (binary classification).
        self.linear = nn.Linear(32 * 7 * 7, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, 1) probabilities."""
        features = self.conv2(self.conv1(x))
        # Flatten the 32x7x7 feature maps to (batch, 32 * 7 * 7) for the linear head.
        flat = features.view(features.size(0), -1)
        return self.sigmoid(self.linear(flat))

    
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing item ``i`` of ``predictors`` with item ``i`` of ``labels``.

    Both arguments may be numpy arrays or torch tensors; samples are taken
    along the first axis. ``labels`` may be 1-D (one scalar per sample) or
    have extra trailing dimensions such as ``(n, 1)``.

    Raises
    ------
    ValueError
        If ``predictors`` and ``labels`` do not have the same number of samples.
    """

    def __init__(self, predictors, labels):
        # Fail at construction instead of with an IndexError mid-epoch in a DataLoader.
        if len(predictors) != len(labels):
            raise ValueError(
                "predictors and labels must have the same length, got {} and {}".format(
                    len(predictors), len(labels)))
        self.labels = labels
        self.predictors = predictors

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        # Index only the first axis so 1-D label vectors work as well as (n, k) arrays.
        return self.predictors[index], self.labels[index]

In [2]:
train_img, train_label, test_img, test_label = mnist.load()
# Rescale pixel intensities by 255 (presumably raw 0-255 values, giving [0, 1] -- confirm against mnist.load).
train_img, test_img = train_img / 255, test_img / 255

In [3]:
# Row indices of digits 1, 2 and 3 in each MNIST split.
train_label_1_idx = np.where(train_label == 1)[0]
train_label_2_idx = np.where(train_label == 2)[0]
train_label_3_idx = np.where(train_label == 3)[0]
test_label_1_idx = np.where(test_label == 1)[0]
test_label_2_idx = np.where(test_label == 2)[0]
test_label_3_idx = np.where(test_label == 3)[0]

n_train = 1000
n_test = 1000
n = n_train + n_test  # total sign images to simulate; the first n_train go to training

# NOTE(review): not used in this cell; BIRD_GP below is configured with predictor_L = outcome_L = 50.
L = 50

exp = 0  # seed for the python, torch and numpy RNGs
random.seed(exp)
torch.manual_seed(exp)
np.random.seed(exp)


def _assemble_split(twos, sign_imgs, ones, signs, img_pool, idx_digit1, idx_digit3):
    """Build one split's predictors and outcomes.

    Predictor ``i`` is the 28 x 84 strip ``[twos[i] | sign_imgs[i] | ones[i]]``,
    flattened. Outcome ``i`` is a random image from ``img_pool``: a digit 1
    (label 1) when ``signs[i] == -1``, otherwise a digit 3 (label 3).
    One ``np.random.choice`` draw is made per sample, in order.

    Returns (predictors, outcomes, outcome_labels) with shapes
    (m, 28*28*3), (m, 28*28) and (m,), where m = len(signs).
    """
    m = len(signs)
    predictors = np.zeros((m, 28 * 28 * 3))
    outcomes = np.zeros((m, 28 * 28))
    outcome_labels = np.zeros(m)
    for i in range(m):
        strip = np.hstack((twos[i].reshape((28, 28)),
                           sign_imgs[i].reshape((28, 28)),
                           ones[i].reshape((28, 28))))
        if signs[i] == -1:
            pool = idx_digit1
            outcome_labels[i] = 1
        else:
            pool = idx_digit3
            outcome_labels[i] = 3
        predictors[i, :] = strip.reshape(-1)
        outcomes[i, :] = img_pool[np.random.choice(pool, size = 1, replace = False), :]
    return predictors, outcomes, outcome_labels


#####################################################################
# generate images
sign_img, sign = generate_sign(n = n)
train_sign = sign[0:n_train]
train_sign_img = sign_img[0:n_train, :]
test_sign = sign[n_train:n]
test_sign_img = sign_img[n_train:n, :]

# The draw order (train 1s, train 2s, test 1s, test 2s, then the per-image
# outcome draws) fixes the RNG stream; reordering changes the generated data.
train_1s = train_img[np.random.choice(train_label_1_idx, size = n_train, replace = False), :]
train_2s = train_img[np.random.choice(train_label_2_idx, size = n_train, replace = False), :]
test_1s = test_img[np.random.choice(test_label_1_idx, size = n_test, replace = False), :]
test_2s = test_img[np.random.choice(test_label_2_idx, size = n_test, replace = False), :]

train_predictors, train_outcomes, train_outcomes_label = _assemble_split(
    train_2s, train_sign_img, train_1s, train_sign, train_img, train_label_1_idx, train_label_3_idx)
test_predictors, test_outcomes, test_outcomes_label = _assemble_split(
    test_2s, test_sign_img, test_1s, test_sign, test_img, test_label_1_idx, test_label_3_idx)

np.savetxt("train_predictors.txt", train_predictors)
np.savetxt("test_predictors.txt", test_predictors)
np.savetxt("train_outcomes.txt", train_outcomes)
np.savetxt("test_outcomes.txt", test_outcomes)

#####################################################################
# Train a CNN to tell outcome digit 1 (target 0) from digit 3 (target 1).
# The sensitivity cell uses it to score predicted outcome images.
cnn = CNN()
cnn_optimizer = torch.optim.Adam(cnn.parameters(), lr = 0.001)
cnn_loss = nn.functional.binary_cross_entropy

cnnX = torch.tensor(train_outcomes, dtype = torch.float32).reshape((n_train, 1, 28, 28))
cnny = torch.tensor(train_outcomes_label, dtype = torch.float32).reshape((n_train, 1))
# Recode digit labels to binary targets. Order matters: 1 -> 0 must run
# before 3 -> 1, or the freshly recoded 1s would be turned into 0s.
for digit, target in ((1, 0), (3, 1)):
    cnny[cnny == digit] = target

cnn_dataset = Dataset(cnnX, cnny)
cnn_dataloader = torch.utils.data.DataLoader(cnn_dataset, batch_size = 64, shuffle = True)

num_epochs = 100
cnn.train()
for epoch in range(num_epochs):
    for X_batch, y_batch in cnn_dataloader:
        cnn_optimizer.zero_grad()               # discard gradients from the previous step
        loss = cnn_loss(cnn(X_batch), y_batch)
        loss.backward()                         # backpropagate
        cnn_optimizer.step()                    # apply the Adam update

In [None]:
# Grid shapes mirror the image layouts: each predictor is the 28 x 84 strip
# [2 | sign | 1] built with hstack, each outcome a single 28 x 28 digit.
predictor_grids = bird_gp.generate_grids([28, 84])
outcome_grids = bird_gp.generate_grids([28, 28])

# Prior hyperparameter grid for the sensitivity analysis. Each row is
# [a_sigma, b_sigma, a_lambda, b_lambda, a_gamma, b_gamma].
# (An earlier run used the same row pattern with 0.5 / 2.0 in place of 0.8 / 1.2.)
configurations = [[0.8, 0.8, 0.8, 0.8, 0.8, 0.8],
                  [1.2, 1.2, 1.2, 1.2, 1.2, 1.2],
                  [0.8, 1.2, 0.8, 1.2, 0.8, 1.2],
                  [1.2, 0.8, 1.2, 0.8, 1.2, 0.8],
                  [1.2, 1.2, 0.8, 0.8, 0.8, 0.8],
                  [0.8, 0.8, 1.2, 1.2, 0.8, 0.8],
                  [0.8, 0.8, 0.8, 0.8, 1.2, 1.2],
                  [1.2, 1.2, 1.2, 1.2, 0.8, 0.8],
                  [1.2, 1.2, 0.8, 0.8, 1.2, 1.2],
                  [0.8, 0.8, 1.2, 1.2, 1.2, 1.2]]


def _cnn_accuracy(model, pred_imgs, outcome_labels):
    """Fraction of predicted images that ``model`` assigns to the true digit.

    ``pred_imgs`` is a (m, 1, 28, 28) float tensor and ``outcome_labels`` a
    length-m array of digit labels (1 or 3). ``model`` outputs the probability
    of digit 3; a prediction is correct when ``p > 0.5`` agrees with
    ``outcome_labels == 3``.
    """
    # Inference only: no autograd graph over the whole batch.
    with torch.no_grad():
        prob_three = model(pred_imgs).numpy().reshape(-1)
    return np.mean((prob_three > 0.5) == (outcome_labels == 3))


cnn.eval()  # score in inference mode regardless of how the training cell left the model
# Columns: 0-5 the hyperparameters of the row, 6 train accuracy, 7 test accuracy.
sensitivity_results = np.zeros((len(configurations), 8))
for c, config in enumerate(configurations):
    a_sigma, b_sigma, a_lambda, b_lambda, a_gamma, b_gamma = config
    sensitivity_results[c, 0:6] = config
    birdgp = bird_gp.BIRD_GP(predictor_grids = predictor_grids,
                             outcome_grids = outcome_grids,
                             predictor_L = L,
                             outcome_L = L,
                             hs_lm_a_sigma = a_sigma,
                             hs_lm_b_sigma = b_sigma,
                             svgd_a_gamma = a_gamma,
                             svgd_b_gamma = b_gamma,
                             svgd_a_lambda = a_lambda,
                             svgd_b_lambda = b_lambda,
                             bf_predictor_steps = 10000,
                             bf_outcome_steps = 10000,
                             device = "cpu"
                             )
    birdgp.fit(train_predictors, train_outcomes)
    train_pred = birdgp.predict_train()
    test_pred = birdgp.predict_test(test_predictors)

    train_pred = torch.tensor(train_pred, dtype = torch.float32).reshape((n_train, 1, 28, 28))
    test_pred = torch.tensor(test_pred, dtype = torch.float32).reshape((n_test, 1, 28, 28))

    sensitivity_results[c, 6] = _cnn_accuracy(cnn, train_pred, train_outcomes_label)
    sensitivity_results[c, 7] = _cnn_accuracy(cnn, test_pred, test_outcomes_label)
    # Save after every configuration so an interrupted run keeps its finished rows.
    np.savetxt("sensitivity_results.txt", sensitivity_results)