In [386]:
import sys 
sys.path.append('..')
import torch
from torch.nn.functional import one_hot
import numpy as np
import matplotlib.pyplot as plt
from random_features.polynomial_sketch import PolynomialSketch

In [387]:
import util.data

In [684]:
# Fashion-MNIST split. Each file unpacks into a (data, labels) pair. Like the other
# dataset cells, this rebinds train_data/test_data and their labels, so only the
# most recently executed loading cell takes effect.
train_data, train_labels = torch.load(
    '../../datasets/export/fashion_mnist/pytorch/train_fashion_mnist.pth'
)
test_data, test_labels = torch.load(
    '../../datasets/export/fashion_mnist/pytorch/test_fashion_mnist.pth'
)

In [478]:
# MNIST split. Each file unpacks into a (data, labels) pair; running this cell
# replaces whatever dataset was loaded before.
train_data, train_labels = torch.load(
    '../../datasets/export/mnist/pytorch/train_mnist.pth'
)
test_data, test_labels = torch.load(
    '../../datasets/export/mnist/pytorch/test_mnist.pth'
)

In [470]:
# Adult split. Each file unpacks into a (data, labels) pair; running this cell
# replaces whatever dataset was loaded before.
train_data, train_labels = torch.load(
    '../../datasets/export/adult/pytorch/train_adult.pth'
)
test_data, test_labels = torch.load(
    '../../datasets/export/adult/pytorch/test_adult.pth'
)

In [110]:
# cod-rna split. Each file unpacks into a (data, labels) pair; running this cell
# replaces whatever dataset was loaded before.
train_data, train_labels = torch.load(
    '../../datasets/export/cod-rna/pytorch/train_cod-rna.pth'
)
test_data, test_labels = torch.load(
    '../../datasets/export/cod-rna/pytorch/test_cod-rna.pth'
)

In [509]:
# CIFAR-10 ResNet-34 final-conv features (going by the file names). Each file
# unpacks into a (data, labels) pair.
# Uses the same relative dataset root as the other loading cells instead of a
# user-specific absolute path, so the notebook is not tied to one machine.
# NOTE(review): assumes '../../datasets' resolves to the directory formerly
# addressed as '/home/jonas/python-projects/datasets' -- confirm before merging.
train_data, train_labels = torch.load('../../datasets/export/cifar10/pytorch/resnet34_final_conv_train.pth')
test_data, test_labels = torch.load('../../datasets/export/cifar10/pytorch/resnet34_final_conv_test.pth')

In [510]:
# Alternative parameterisation (kept for reference):
# degree = 3
# a = 2.
# bias = 1.-2./a**2
# lengthscale = a / np.sqrt(2.)
degree = 3
bias = 1
# lengthscale = sqrt(d), where d is the number of features per sample.
# The data is only flattened in the following cell, so count all non-batch
# dimensions here; `train_data.shape[1]` alone is not d for image-shaped input
# such as (N, H, W). For data that is already 2-D the value is unchanged.
lengthscale = np.sqrt(np.prod(train_data.shape[1:]))

In [511]:
# Collapse all non-batch dimensions so that every sample becomes one row vector.
train_data = torch.reshape(train_data, (train_data.shape[0], -1))
test_data = torch.reshape(test_data, (test_data.shape[0], -1))

In [512]:
# Cast labels to float32 so they can be used as regression targets in matrix products.
train_labels = train_labels.to(torch.float32)
test_labels = test_labels.to(torch.float32)

In [513]:
# min-max
# Per-feature rescaling with statistics taken from the training split only: the
# training data lands in [0, 1], and the test split is mapped with the same
# offset and range. Constant features (zero range) get range 1 so that they map
# to 0 instead of dividing by zero.
min_val = train_data.min(dim=0).values
val_range = train_data.max(dim=0).values - min_val
val_range = torch.where(val_range == 0, torch.ones_like(val_range), val_range)
train_data = train_data.sub(min_val).div(val_range)
test_data = test_data.sub(min_val).div(val_range)

In [514]:
# mean = train_data.mean(dim=0, keepdim=True)
# std = train_data.std(dim=0, keepdim=True)
# std[std==0] = 1.
# train_data = (train_data - mean) / std
# test_data = (test_data - mean) / std
# train_data = train_data / std
# test_data = test_data / std
# unit norm
# Scale every sample (row) to unit Euclidean norm. An all-zero row has norm 0,
# and dividing by it would turn the whole row into NaNs; such rows get a
# divisor of 1 and stay zero, mirroring the zero-range guard of the min-max step.
train_norms = train_data.norm(dim=1, keepdim=True)
train_norms[train_norms == 0] = 1
train_data = train_data / train_norms
test_norms = test_data.norm(dim=1, keepdim=True)
test_norms[test_norms == 0] = 1
test_data = test_data / test_norms

In [515]:
# Random subsample: at most 5000 training and 1000 test points.
# torch.randint draws WITH replacement, which duplicates rows in the subsample;
# a truncated random permutation picks distinct rows instead (and simply keeps
# the whole split when it is smaller than the requested size).
# A dedicated, seeded generator makes the subsample reproducible without
# touching the global RNG state.
subsample_rng = torch.Generator().manual_seed(0)
indices = torch.randperm(len(train_data), generator=subsample_rng)[:5000]
train_data = train_data[indices]
train_labels = train_labels[indices]
indices = torch.randperm(len(test_data), generator=subsample_rng)[:1000]
test_data = test_data[indices]
test_labels = test_labels[indices]

In [482]:
# rbf kernel
# median_distance = torch.cdist(train_data, train_data).median()
# lengthscale = median_distance
# train_data = train_data / lengthscale
# test_data = test_data / lengthscale
# bias = 1.

# squared_norm = (train_data**2).sum(dim=1)
# prefactor_train = torch.exp(-squared_norm / 2.)

# train_data = train_data / np.sqrt(degree)
# test_data = test_data / np.sqrt(degree)

In [516]:
# Show the lengthscale by which both splits are divided in the polynomial-kernel cell below.
lengthscale

22.627416997969522

In [517]:
# poly kernel
# Divide both splits by the lengthscale before the bias column is appended.
train_data = torch.div(train_data, lengthscale)
test_data = torch.div(test_data, lengthscale)

In [518]:
# Embed each split in a zero matrix with 512 columns and store sqrt(bias) in the
# last column, so that inner products of the padded rows include the additive bias.
# NOTE(review): the width 512 is hard-coded. If the data already has 512 columns,
# the last feature is overwritten by the bias value; with more than 512 columns
# the slice assignment fails. The sketch cell also hard-codes 512 as its input
# dimension, so both places must be changed together.
bias_feature = np.sqrt(bias)

padded_train = torch.zeros(train_data.shape[0], 512)
padded_train[:, :train_data.shape[1]] = train_data
padded_train[:, -1] = bias_feature
train_data = padded_train

padded_test = torch.zeros(test_data.shape[0], 512)
padded_test[:, :test_data.shape[1]] = test_data
padded_test[:, -1] = bias_feature
test_data = padded_test

In [519]:
# Replace label value -1 by 0 (in place) in both splits.
train_labels.masked_fill_(train_labels == -1., 0)
test_labels.masked_fill_(test_labels == -1., 0)
# Optional one-hot encoding for two-class labels (disabled):
#train_labels = one_hot(train_labels.type(torch.LongTensor)).reshape(-1, 2).type(torch.FloatTensor)
#test_labels = one_hot(test_labels.type(torch.LongTensor)).reshape(-1, 2).type(torch.FloatTensor)

In [469]:
# Inspect the training labels.
train_labels

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [520]:
def cholesky_solve(y, L):
    """Solve ``(L @ L^H) x = y`` for ``x`` given a lower Cholesky factor.

    Args:
        y: right-hand side, either a matrix of shape ``(n, k)`` or a vector of
            shape ``(n,)``.
        L: lower-triangular Cholesky factor of the system matrix, shape ``(n, n)``.

    Returns:
        The solution ``x`` with the same shape as ``y``.
    """
    # L: lower triangular cholesky
    # torch.cholesky_solve performs the forward and backward substitution in one
    # call and replaces the deprecated torch.triangular_solve. It needs a matrix
    # right-hand side, so a vector is temporarily treated as a single column.
    is_vector = y.dim() == 1
    rhs = y.unsqueeze(-1) if is_vector else y
    solution = torch.cholesky_solve(rhs, L, upper=False)
    return solution.squeeze(-1) if is_vector else solution

In [521]:
def solve_linear_regression(train_features, train_labels, test_features, lam=0.1):
    """Fit ridge regression on explicit features and predict for the test features.

    Solves ``(X^T X + lam * I) alpha = X^T y`` via a Cholesky factorisation and
    returns ``test_features @ alpha``.

    Args:
        train_features: feature matrix ``X`` of shape ``(n_train, n_features)``.
        train_labels: regression targets ``y`` with ``n_train`` rows.
        test_features: feature matrix of shape ``(n_test, n_features)``.
        lam: ridge regularisation strength added to the diagonal.

    Returns:
        Predictions for the test features.
    """
    num_features = train_features.shape[1]
    # Build the identity with the dtype/device of the features so the sum below
    # also works for non-default dtypes and for tensors that live on an accelerator.
    ridge = lam * torch.eye(
        num_features, dtype=train_features.dtype, device=train_features.device
    )
    sigma_inv = train_features.t() @ train_features + ridge
    xTy = train_features.t() @ train_labels
    # torch.linalg.cholesky replaces the deprecated torch.cholesky (lower factor by default).
    L_sigma_inv = torch.linalg.cholesky(sigma_inv)
    alpha = cholesky_solve(xTy, L_sigma_inv)
    return test_features @ alpha

In [522]:
# Ridge regularisation strength passed to solve_linear_regression in the experiment loop.
# The trailing 0.1 is presumably an earlier setting kept for reference.
lam = 0.01#0.1

In [523]:
# Results, one numpy array (over seeds) per degree.
real_errors = []
comp_errors = []

real_accs = []
comp_accs = []


def _sketch_features(train_x, test_x, degree, complex_real):
    """Draw a fresh PolynomialSketch and return ``(train_features, test_features)``.

    ``complex_real`` is forwarded to ``PolynomialSketch``; it is the only setting
    that differs between the "real" and the "complex" variant compared below.
    """
    feature_encoder = PolynomialSketch(
        train_x.shape[1], # data input dimension (power of 2 for srht projection_type)
        2*512, # output dimension of the random sketch
        degree=degree, # degree of the polynomial kernel
        bias=0, # bias parameter of the polynomial kernel
        lengthscale=1., # inverse scale of the data (like lengthscale for Gaussian kernel)
        projection_type='srht',
        hierarchical=False,
        complex_weights=False,
        full_cov=True,
        convolute_ts=False,
        complex_real=complex_real
    )
    feature_encoder.resample()
    return feature_encoder.forward(train_x), feature_encoder.forward(test_x)


def _relative_frobenius_error(approx_kernel, ref_kernel, ref_norm):
    """Return ``||approx - ref||_F / ||ref||_F`` as a Python float."""
    return ((approx_kernel - ref_kernel).pow(2).sum().sqrt() / ref_norm).item()


# The class index of every test label depends neither on the degree nor on the seed.
test_classes = test_labels.argmax(dim=1)

for degree in [degree]:
    print('Degree', degree)
    real_errors_cur = []
    comp_errors_cur = []

    real_accs_cur = []
    comp_accs_cur = []

    # reference kernel
    # The exact kernel and its norm do not depend on the seed, so they are
    # computed once per degree instead of once per repetition.
    ref_kernel = (train_data @ train_data.t())**degree
    ref_norm = ref_kernel.pow(2).sum().sqrt()

    for seed in range(20):
        print('Seed', seed)
        # Actually use the seed so that every repetition is reproducible.
        # NOTE(review): assumes PolynomialSketch draws its randomness from the
        # global torch and/or numpy generators -- confirm.
        torch.manual_seed(seed)
        np.random.seed(seed)

        # real sketch
        projections_train, projections_test = _sketch_features(
            train_data, test_data, degree, complex_real=False
        )
        predictions_real = solve_linear_regression(projections_train, train_labels, projections_test, lam=lam)
        approx_kernel_real = projections_train @ projections_train.t()

        # complex sketch
        projections_train, projections_test = _sketch_features(
            train_data, test_data, degree, complex_real=True
        )
        predictions_comp = solve_linear_regression(projections_train, train_labels, projections_test, lam=lam)
        approx_kernel_comp = projections_train @ projections_train.t()

        # error
        real_errors_cur.append(_relative_frobenius_error(approx_kernel_real, ref_kernel, ref_norm))
        comp_errors_cur.append(_relative_frobenius_error(approx_kernel_comp, ref_kernel, ref_norm))

        # accuracy: fraction of test points whose arg-max prediction matches the label
        real_acc = (predictions_real.argmax(dim=1) == test_classes).sum() / len(test_labels)
        comp_acc = (predictions_comp.argmax(dim=1) == test_classes).sum() / len(test_labels)

        real_accs_cur.append(real_acc.item())
        comp_accs_cur.append(comp_acc.item())

    real_errors.append(np.array(real_errors_cur))
    comp_errors.append(np.array(comp_errors_cur))

    real_accs.append(np.array(real_accs_cur))
    comp_accs.append(np.array(comp_accs_cur))

Degree 3
Seed 0
Seed 1
Seed 2
Seed 3
Seed 4
Seed 5
Seed 6
Seed 7
Seed 8
Seed 9
Seed 10
Seed 11
Seed 12
Seed 13
Seed 14
Seed 15
Seed 16
Seed 17
Seed 18
Seed 19


In [404]:
# Inspect the kernel approximation of the real sketch (value left over from the last loop iteration).
approx_kernel_real

tensor([[1.0350, 1.0335, 1.0503,  ..., 1.0408, 1.0482, 1.0446],
        [1.0335, 1.0905, 1.0636,  ..., 1.0582, 1.0623, 1.0697],
        [1.0503, 1.0636, 1.1123,  ..., 1.0718, 1.0790, 1.0725],
        ...,
        [1.0408, 1.0582, 1.0718,  ..., 1.0937, 1.0661, 1.0667],
        [1.0482, 1.0623, 1.0790,  ..., 1.0661, 1.1054, 1.0709],
        [1.0446, 1.0697, 1.0725,  ..., 1.0667, 1.0709, 1.1228]])

In [405]:
# Inspect the kernel approximation of the complex sketch (value left over from the last loop iteration).
approx_kernel_comp

tensor([[1.0316, 1.0296, 1.0320,  ..., 1.0320, 1.0285, 1.0270],
        [1.0296, 1.0930, 1.0491,  ..., 1.0519, 1.0450, 1.0574],
        [1.0320, 1.0491, 1.0791,  ..., 1.0518, 1.0460, 1.0450],
        ...,
        [1.0320, 1.0519, 1.0518,  ..., 1.0839, 1.0421, 1.0472],
        [1.0285, 1.0450, 1.0460,  ..., 1.0421, 1.0701, 1.0407],
        [1.0270, 1.0574, 1.0450,  ..., 1.0472, 1.0407, 1.0962]])

In [406]:
# Inspect the exact polynomial kernel that the sketches are compared against.
ref_kernel

tensor([[1.0289, 1.0248, 1.0308,  ..., 1.0321, 1.0252, 1.0259],
        [1.0248, 1.0820, 1.0461,  ..., 1.0499, 1.0400, 1.0525],
        [1.0308, 1.0461, 1.0828,  ..., 1.0536, 1.0450, 1.0448],
        ...,
        [1.0321, 1.0499, 1.0536,  ..., 1.0863, 1.0421, 1.0475],
        [1.0252, 1.0400, 1.0450,  ..., 1.0421, 1.0677, 1.0395],
        [1.0259, 1.0525, 1.0448,  ..., 1.0475, 1.0395, 1.0934]])

In [524]:
# deg 3 bias 2, lengthscale np.sqrt(d), comp vs real
# Mean test accuracy over the seeds, one entry per degree: real sketch, then complex sketch.
print(np.asarray(real_accs).mean(axis=1))
print(np.asarray(comp_accs).mean(axis=1))

[0.87325]
[0.87284999]


In [525]:
# Mean relative Frobenius error of the approximate kernel over the seeds,
# one entry per degree: real sketch, then complex sketch.
print(np.asarray(real_errors).mean(axis=1))
print(np.asarray(comp_errors).mean(axis=1))

[0.00015419]
[0.00011944]


In [507]:
# deg 3 bias 0.5, lengthscale np.sqrt(d), comp vs real
# Mean test accuracy over the seeds, one entry per degree: real sketch, then complex sketch.
print(np.asarray(real_accs).mean(axis=1))
print(np.asarray(comp_accs).mean(axis=1))

[0.857]
[0.85765]


In [508]:
# Mean relative Frobenius error of the approximate kernel over the seeds,
# one entry per degree: real sketch, then complex sketch.
print(np.asarray(real_errors).mean(axis=1))
print(np.asarray(comp_errors).mean(axis=1))

[0.00063477]
[0.0004947]


In [474]:
# deg 3 bias 1, lengthscale np.sqrt(d), min-max
# Mean test accuracy over the seeds, one entry per degree: real sketch, then complex sketch.
print(np.asarray(real_accs).mean(axis=1))
print(np.asarray(comp_accs).mean(axis=1))

[0.886]
[0.88585]


In [475]:
# Mean relative Frobenius error of the approximate kernel over the seeds,
# one entry per degree: real sketch, then complex sketch.
print(np.asarray(real_errors).mean(axis=1))
print(np.asarray(comp_errors).mean(axis=1))

[0.01129104]
[0.00356528]


In [491]:
# deg 3 bias 1, lengthscale np.sqrt(d), min-max, comp vs real
# Mean test accuracy over the seeds, one entry per degree: real sketch, then complex sketch.
print(np.asarray(real_accs).mean(axis=1))
print(np.asarray(comp_accs).mean(axis=1))

[0.8777]
[0.877]


In [492]:
# Mean relative Frobenius error of the approximate kernel over the seeds,
# one entry per degree: real sketch, then complex sketch.
print(np.asarray(real_errors).mean(axis=1))
print(np.asarray(comp_errors).mean(axis=1))

[0.00450897]
[0.00314206]


In [456]:
# deg 3 bias 1, lengthscale np.sqrt(d), no min-max
# Mean test accuracy over the seeds, one entry per degree: real sketch, then complex sketch.
print(np.asarray(real_accs).mean(axis=1))
print(np.asarray(comp_accs).mean(axis=1))

[0.85329999]
[0.8498]


In [457]:
# Mean relative Frobenius error of the approximate kernel over the seeds,
# one entry per degree: real sketch, then complex sketch.
print(np.asarray(real_errors).mean(axis=1))
print(np.asarray(comp_errors).mean(axis=1))

[0.11507457]
[0.09992507]


In [439]:
# deg 3 a 2, unit norm
# Mean test accuracy over the seeds, one entry per degree: real sketch, then complex sketch.
print(np.asarray(real_accs).mean(axis=1))
print(np.asarray(comp_accs).mean(axis=1))

[0.84925]
[0.85265]


In [440]:
# Mean relative Frobenius error of the approximate kernel over the seeds,
# one entry per degree: real sketch, then complex sketch.
print(np.asarray(real_errors).mean(axis=1))
print(np.asarray(comp_errors).mean(axis=1))

[0.10427316]
[0.083339]


In [422]:
# deg 3 a 4, unit norm
# Mean test accuracy over the seeds, one entry per degree: real sketch, then complex sketch.
print(np.asarray(real_accs).mean(axis=1))
print(np.asarray(comp_accs).mean(axis=1))

[0.86105]
[0.86279999]


In [423]:
# Mean relative Frobenius error of the approximate kernel over the seeds,
# one entry per degree: real sketch, then complex sketch.
print(np.asarray(real_errors).mean(axis=1))
print(np.asarray(comp_errors).mean(axis=1))

[0.03144861]
[0.01627847]
