In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.backends.cudnn as cudnn
import sys
sys.path.append("..")
import model.simplenet as simplenet
import matplotlib.pyplot as plt
import matplotlib as mpl
import imageio
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
#import seaborn as sns
import numpy as np
import torch.nn.functional as F
import random
import matplotlib as mpl
mpl.rcParams['lines.linewidth'] = 2.5
mpl.rcParams['legend.fontsize'] = 12
mpl.rcParams['axes.labelsize'] = 12
mpl.rcParams['axes.titlesize'] = 14
mpl.rcParams['font.weight'] = 'bold'
mpl.rcParams['font.size'] = 12
mpl.rcParams['axes.labelweight'] = 'bold'
%matplotlib inline
%config InlineBackend.figure_format = 'svg'


def change_weight(weight):
    '''
    Replace the quadratic weight matrix `classifier.0.weight_a` of the global `model`.

    Args:
        weight: tensor to load as `classifier.0.weight_a` (must match that entry's shape).
    '''
    state = model.state_dict()
    state.update({'classifier.0.weight_a': weight})
    model.load_state_dict(state)


def change_eigen(eigen):
    '''
    Replace the `classifier.0.eigen` entry of the global `model`.

    Args:
        eigen: tensor to load as `classifier.0.eigen` (must match that entry's shape).
    '''
    state = model.state_dict()
    state.update({'classifier.0.eigen': eigen})
    model.load_state_dict(state)

def eigen_decomposition(a, initial, remain_number = 40):
    '''
    Find the low-dimensional structure of the learned change of the quadratic weight
    matrix weight_a (accuracy as a function of the number of retained singular components).

    For every k in range(remain_number), each slice (a - initial)[i] is replaced by its
    rank-k SVD truncation, `initial` is added back, the result is loaded into the global
    `model` via change_weight, and accuracy is measured on the global `valid_loader`.

    Args:
        a: trained weight tensor, indexed as a[i, :, :].
        initial: initial weight tensor with the same shape as `a`.
        remain_number: number of truncation levels to try (k = 0 .. remain_number-1).

    Returns:
        Tensor of length remain_number; entry k is the test accuracy with k components
        (k = 0 means only `initial`).
    '''
    delta = a - initial
    # The SVD of each slice does not depend on k, so compute it once.
    factors = [torch.linalg.svd(delta[i, :, :], full_matrices=False) for i in range(delta.shape[0])]

    weight_a_change = torch.zeros_like(delta)
    test_accuracy = torch.zeros(remain_number)

    for k in range(remain_number):
        for i, (u, s, vh) in enumerate(factors):
            # Rank-k reconstruction U_k diag(S_k) Vh_k (all zeros when k == 0).
            weight_a_change[i, :, :] = (u[:, :k] * s[:k]) @ vh[:k, :]

        change_weight(weight_a_change + initial)

        model.eval()
        valid_correct = 0
        valid_total = 0
        with torch.no_grad():
            for input, target in valid_loader:
                target = target.long()
                output = model(input)
                predicted = output.argmax(dim=1)
                # Accumulate over all batches (previously only the last batch was counted).
                valid_total += target.size(0)
                valid_correct += (predicted == target).sum().item()

        prec = valid_correct / valid_total
        print('Accuary on test images:{:.2f}%'.format(prec*100))
        print('index of eigen:{}'.format(k))

        test_accuracy[k] = prec

    return test_accuracy


def eigen_decomposition_1(a, remain_number = 40):
    '''
    Find the low-dimensional structure of the quadratic weight matrix weight_a itself
    (accuracy as a function of the number of retained singular components).

    For every k in range(remain_number), each slice a[i] is replaced by its rank-k SVD
    truncation, loaded into the global `model` via change_weight, and accuracy is
    measured on the global `valid_loader`.

    Args:
        a: weight tensor, indexed as a[i, :, :].
        remain_number: number of truncation levels to try (k = 0 .. remain_number-1).

    Returns:
        Tensor of length remain_number; entry k is the test accuracy with k components.
    '''
    # The SVD of each slice does not depend on k, so compute it once.
    factors = [torch.linalg.svd(a[i, :, :], full_matrices=False) for i in range(a.shape[0])]

    weight_a_change = torch.zeros_like(a)
    test_accuracy = torch.zeros(remain_number)

    for k in range(remain_number):
        for i, (u, s, vh) in enumerate(factors):
            # Rank-k reconstruction U_k diag(S_k) Vh_k (all zeros when k == 0).
            weight_a_change[i, :, :] = (u[:, :k] * s[:k]) @ vh[:k, :]

        change_weight(weight_a_change)

        model.eval()
        valid_correct = 0
        valid_total = 0
        with torch.no_grad():
            for input, target in valid_loader:
                target = target.long()
                output = model(input)
                predicted = output.argmax(dim=1)
                # Accumulate over all batches (previously only the last batch was counted).
                valid_total += target.size(0)
                valid_correct += (predicted == target).sum().item()

        prec = valid_correct / valid_total
        print('Accuary on test images:{:.2f}%'.format(prec*100))
        print('index of eigen:{}'.format(k))

        test_accuracy[k] = prec

    return test_accuracy


def generate_spike_trigger_average():
    '''
    Generate the spike-triggered average of the MNIST training set, approximated by
    the per-digit sum of the (flattened) images scaled to unit L2 norm.

    Returns:
        Tensor of shape (10, n_pixels) (n_pixels = 784 for MNIST); row d is the
        L2-normalised sum of all training images with label d.
    '''
    train_dataset = torchvision.datasets.MNIST(root='data', train=True, transform=transforms.ToTensor(), download=True)
    # A single batch holding the whole training set (order is irrelevant for a sum).
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=len(train_dataset))
    input, target = next(iter(train_loader))
    input = input.reshape(input.shape[0], -1)

    spike_trigger_average = torch.zeros(10, input.shape[1])
    for digit in range(10):
        # Vectorised replacement of the per-sample Python loop.
        digit_sum = input[target == digit].sum(dim=0)
        spike_trigger_average[digit, :] = digit_sum / torch.norm(digit_sum)

    return spike_trigger_average


def generate_correlation_matrix():
    '''
    Per-digit (uncentered) correlation matrices of the MNIST training images.

    Returns:
        Tensor of shape (10, n_pixels, n_pixels) (n_pixels = 784 for MNIST); slice d is
        the mean of x x^T over all flattened training images x with label d.
    '''
    train_dataset = torchvision.datasets.MNIST(root='data', train=True, transform=transforms.ToTensor(), download=True)
    # A single batch holding the whole training set (order is irrelevant for a mean).
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=len(train_dataset))
    input, target = next(iter(train_loader))
    input = input.reshape(input.shape[0], -1)

    n_features = input.shape[1]
    correlated = torch.zeros(10, n_features, n_features)
    for digit in range(10):
        digit_images = input[target == digit]
        # Sum over images of the outer products x x^T equals X^T X.
        correlated[digit, :, :] = digit_images.T @ digit_images / digit_images.shape[0]

    return correlated


def matshow(a):
    '''
    Show a flattened 28x28 tensor as a grayscale image with a colour range symmetric about 0.

    Args:
        a: tensor with 784 elements.
    '''
    m = torch.max(torch.abs(a)).item()
    # plt.matshow opens its own new figure (fignum=None), so an extra plt.figure()
    # would only leave an empty figure behind.
    plt.matshow(a.reshape(28, 28), cmap=plt.cm.gray, vmin=-m, vmax=m)
    return


def test_matrix(a):
    '''
    Load `a` as the quadratic weight matrix of the global `model` and measure MNIST test accuracy.

    Args:
        a: tensor to load as `classifier.0.weight_a` (via change_weight).

    Returns:
        Test accuracy as a fraction in [0, 1]; it is also printed as a percentage.
    '''
    valid_dataset = torchvision.datasets.MNIST(root='data', train=False, transform=transforms.ToTensor())
    valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=10000)
    change_weight(a)
    model.eval()
    # Evaluate on whatever device the model lives on.
    device = next((p.device for p in model.parameters()), torch.device('cpu'))
    valid_correct = 0
    valid_total = 0
    with torch.no_grad():
        for input, target in valid_loader:
            input, target = input.to(device), target.long().to(device)
            output = model(input)
            predicted = output.argmax(dim=1)
            valid_total += target.size(0)
            valid_correct += (predicted == target).sum().item()

    prec = valid_correct / valid_total
    print('Accuary on test images:{:.2f}%'.format(prec*100))
    return prec


def test_eigen(a):
    '''
    Load `a` as the `classifier.0.eigen` entry of the global `model` and measure MNIST test accuracy.

    Args:
        a: tensor to load as `classifier.0.eigen` (via change_eigen).

    Returns:
        Test accuracy as a fraction in [0, 1]; it is also printed as a percentage.
    '''
    valid_dataset = torchvision.datasets.MNIST(root='data', train=False, transform=transforms.ToTensor())
    valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=100)
    change_eigen(a)
    model.eval()
    # Use the model's device instead of hard-coding .cuda(), which fails on the CPU
    # setup this notebook trains with.
    device = next((p.device for p in model.parameters()), torch.device('cpu'))
    valid_correct = 0
    valid_total = 0
    with torch.no_grad():
        for input, target in valid_loader:
            input, target = input.to(device), target.long().to(device)
            output = model(input)
            predicted = output.argmax(dim=1)
            valid_total += target.size(0)
            valid_correct += (predicted == target).sum().item()

    prec = valid_correct / valid_total
    print('Accuary on test images:{:.2f}%'.format(prec*100))
    return prec


def savefig(name):
    '''Write the current matplotlib figure to `name` at 600 dpi, cropped to its tight bounding box.'''
    plt.savefig(name, bbox_inches='tight', dpi=600)


def generate_adversial_exm_grad(varepsilon):
    '''
    Build L2-normalised gradient adversarial examples for the global `valid_dataset`.

    Each test image x with label y becomes x + varepsilon * g / ||g||_2, where g is the
    gradient of criterion(model(x), y) with respect to x (one image at a time).

    Args:
        varepsilon: L2 size of the perturbation.

    Returns:
        Detached tensor with the same shape as the stacked test images.
    '''
    # All clean test images in dataset order (no shuffling in either loader).
    full_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=len(valid_dataset))
    clean_images, _ = next(iter(full_loader))
    adversial_exm = clean_images.clone()

    single_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=1)
    for i, (input, target) in enumerate(single_loader):
        target = target.long()
        input.requires_grad_(True)
        loss = criterion(model(input), target)
        # autograd.grad gives d(loss)/d(input) without accumulating into the model's
        # parameter .grad fields.
        (grad,) = torch.autograd.grad(loss, input)
        grad_norm = torch.norm(grad)

        with torch.no_grad():
            # Writing under no_grad keeps the buffer out of the autograd graph
            # (otherwise the graph grows with every image).
            if grad_norm > 0:
                adversial_exm[i:i+1] = input + varepsilon * grad / grad_norm
            else:
                # Zero gradient: no direction to move in; avoid a 0/0 NaN.
                adversial_exm[i:i+1] = input

    return adversial_exm


def generate_adversial_exm_sign(varepsilon):
    '''
    Build sign-gradient (FGSM-style) adversarial examples for the global `valid_dataset`.

    Each test image x with label y becomes x + varepsilon * sign(g), where g is the
    gradient of criterion(model(x), y) with respect to x (one image at a time).

    Args:
        varepsilon: per-pixel perturbation size.

    Returns:
        Detached tensor with the same shape as the stacked test images.
    '''
    # All clean test images in dataset order (no shuffling in either loader).
    full_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=len(valid_dataset))
    clean_images, _ = next(iter(full_loader))
    adversial_exm = clean_images.clone()

    single_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=1)
    for i, (input, target) in enumerate(single_loader):
        target = target.long()
        input.requires_grad_(True)
        loss = criterion(model(input), target)
        # autograd.grad gives d(loss)/d(input) without accumulating into the model's
        # parameter .grad fields.
        (grad,) = torch.autograd.grad(loss, input)

        with torch.no_grad():
            # Writing under no_grad keeps the buffer out of the autograd graph.
            adversial_exm[i:i+1] = input + varepsilon * torch.sign(grad)

    return adversial_exm

def kl_divergence():
    '''
    Per-digit KL divergences between the "true label" and "other label" cos^2 histograms.

    For each digit, the matching slices of the two arrays returned by
    `cos_distribution(new_eigen)` are histogrammed with plt.hist (2000 bins on [0, 1)),
    normalised by the number of entries, and both directions of the KL divergence are
    appended to the global lists `KL_div` and `KL_div_1`.

    NOTE(review): `cos_distribution`, `new_eigen`, `KL_div` and `KL_div_1` are not defined
    in this part of the notebook; they must already exist as globals.
    '''
    # Test-set image count per digit, with a leading 0 so cumulative sums give slice bounds.
    number_class = torch.Tensor([0,980,1135,1032,1010,982,892,958,1028,974,1009])
    number_class_cor = torch.cumsum(number_class, 0)
    # Each digit contributes (10000 - count) "other label" entries.
    number_class_mis = torch.cat((torch.zeros(1), torch.cumsum(10000 - number_class[1:], 0)), 0)
    all_cos_mistake, all_cos_correct = cos_distribution(new_eigen)

    # Bin edges for cos^2 values. (An unused |cos| binning on [0, 0.6) that was
    # immediately overwritten has been dropped.)
    bins = np.arange(0, 1, 1/2000)
    for digit in range(10):
        n_correct = float(number_class[digit+1])
        n_mistaken = 10000 - n_correct
        mis_values = all_cos_mistake[int(number_class_mis[digit]):int(number_class_mis[digit+1])]
        cor_values = all_cos_correct[int(number_class_cor[digit]):int(number_class_cor[digit+1])]
        frequency_each, _, _ = plt.hist(mis_values.tolist(), bins=bins, color='blue', label='other label')
        frequency_each_c, _, _ = plt.hist(cor_values.tolist(), color='red', bins=bins, label='true label')

        # Normalise counts to distributions; floor empty bins at 1e-16 so log() stays finite.
        mistaken_dist = (torch.as_tensor(frequency_each, dtype=torch.float64) / n_mistaken).clamp_min(1e-16)
        correct_dist = (torch.as_tensor(frequency_each_c, dtype=torch.float64) / n_correct).clamp_min(1e-16)

        # F.kl_div(log q, p, reduction='sum') = sum p * (log p - log q) = KL(p || q).
        KL_divergence = F.kl_div(mistaken_dist.log(), correct_dist, reduction='sum')
        KL_divergence_1 = F.kl_div(correct_dist.log(), mistaken_dist, reduction='sum')
        KL_div.append(KL_divergence)
        KL_div_1.append(KL_divergence_1)
    return

def setup_seed(seed):
    '''
    Seed every random number generator used here and force deterministic cuDNN behaviour.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's legacy global RNG and Python's
    `random`. cuDNN is disabled entirely (enabled=False), not just made deterministic;
    this trades speed for reproducibility.

    Args:
        seed: integer seed applied to every generator.
    '''
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.enabled = False
    cudnn_backend.benchmark = False
    # Set cudnn_backend.benchmark = True instead to trade determinism for speed.

def ks():
    '''
    Kolmogorov-Smirnov distance between true-label and other-label outputs for one digit.

    Reads the globals `input`, `target`, `digit` and `model`, and appends the distance to
    the global list `KS_dis`. Each output row is divided by the L2 norm of its (flattened)
    input image. The true-label column and the remaining columns are then histogrammed on
    bins of width 1/2000 over [0, 1]. The maximum gap between the two empirical CDFs is
    recorded.
    '''
    input_digit = input[target == digit, :, :, :]
    with torch.no_grad():
        output = model(input_digit)
    n_samples = output.shape[0]

    # Divide each output row by the L2 norm of its flattened input image.
    input_norms = input_digit.flatten(start_dim=1).norm(p=2, dim=1)
    scaled = output / input_norms.unsqueeze(1)
    # Device-agnostic mask selecting the other labels (no hard-coded .cuda()).
    other_labels = torch.arange(scaled.shape[1], device=scaled.device) != digit
    all_cos_cor = scaled[:, digit].cpu()
    all_cos_mis = scaled[:, other_labels].T.cpu()  # (n_other_labels, n_samples)

    bins = np.arange(0, 1.001, 1/2000)
    frequency_each, _ = np.histogram(all_cos_mis.reshape(-1).tolist(), bins=bins)
    frequency_each_c, _ = np.histogram(all_cos_cor.tolist(), bins=bins)
    cdf_mistaken = torch.cat((torch.Tensor([0]), torch.cumsum(torch.from_numpy(frequency_each) / all_cos_mis.numel(), dim=0)), 0)
    cdf_correct = torch.cat((torch.Tensor([0]), torch.cumsum(torch.from_numpy(frequency_each_c) / n_samples, dim=0)), 0)
    KS_distance = torch.max(abs(cdf_correct - cdf_mistaken))
    KS_dis.append(KS_distance)
    return

#ax1.spines['top'].set_visible(False)


# Training Process

In [28]:
setup_seed(1)

# Everything in this notebook runs on the CPU.
# (Use torch.device('cuda' if torch.cuda.is_available() else 'cpu') to enable a GPU.)
device = torch.device('cpu')

# MNIST train/test splits as tensors via ToTensor.
# (For Fashion-MNIST, use torchvision.datasets.FashionMNIST with the same arguments.)
train_dataset = torchvision.datasets.MNIST(
    root='data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
valid_dataset = torchvision.datasets.MNIST(
    root='data',
    train=False,
    transform=transforms.ToTensor(),
)

# Shuffled mini-batches for SGD; the whole test set as one batch; the whole
# training set as one ordered batch for full-batch analysis.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=100, shuffle=True)
valid_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset, batch_size=10000)
all_train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=60000, shuffle=False)

# Quadratic network with a single eigen-component per class.
model = simplenet.SimpleNet_1(num_eigens=1)

# Norm of the initial parameter theta_0 (used for gamma_t in the training loop).
eigen = model.state_dict()['classifier.0.eigen']
theta_0_norm = torch.norm(eigen)

# Per-sample output norms of the untrained model on the full training set.
with torch.no_grad():
    for j, (input, target) in enumerate(all_train_loader):
        target = target.long()
        output = model(input)
    output_theta0_norm = output.norm(dim=1)

print(model)

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Metric histories filled by the training loop and the analysis cells below.
test_accuracy, KS_dis = [], []
train_accuracy = [0]
all_train_loss, image_list, monotone, norm = [], [], [], []
output_std_true, output_std_other, output_norm, mean = [], [], [], []
gamma_t, u1u2_innerpro, u1u3_innerpro = [], [], []
alpha1_all, alpha2_all, alpha3_all = [], [], []
all_partial_snj_norm, all_dtheta_norm = [], []

number = 0
best_prec = 0
min_loss = 1
# Per-sample coefficients a[n] (reciprocal of the stored values), indexed by
# training-sample position in the training loop's gradient analysis.
a = 1/torch.Tensor(torch.load('a_randomseed_1_norm_square'))

num_epochs = 1  # also used in the progress message (it previously printed a hard-coded "/20")

for epoch in range(num_epochs):
    for param_group in optimizer.param_groups:
        param_group['lr'] = 0.01

    model.train()
    for i, (input, target) in enumerate(train_loader):
        target = target.long()

        # One SGD step on the mini-batch.
        output = model(input)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accuracy on this mini-batch only.
        predicted = output.detach().argmax(dim=1)
        train_total = target.size(0)
        train_correct = (predicted == target).sum().item()
        prec = train_correct / train_total

        # Log, evaluate and analyse every 2nd step during the first 300 steps.
        if i >= 300 or i % 2 != 0:
            continue

        train_accuracy.append(prec)
        print('Epoch [{}/{}], Step [{}/{}], Loss: {:.5f}, Train_Acc:{:.2f}%'.format(epoch+1, num_epochs, i, len(train_loader), loss.item(), prec*100))

        # ---- Evaluate on the test set ----
        model.eval()
        valid_correct = 0
        valid_total = 0
        valid_loss_sum = 0.0
        with torch.no_grad():
            for val_input, val_target in valid_loader:
                val_target = val_target.long()
                val_output = model(val_input)
                batch_size = val_output.shape[0]
                valid_total += batch_size
                valid_correct += (val_output.argmax(dim=1) == val_target).sum().item()
                valid_loss_sum += criterion(val_output, val_target).item() * batch_size
        prec = valid_correct / valid_total
        valid_loss = valid_loss_sum / valid_total
        print('Accuary on test images:{:.2f}%, loss:{:.5f}'.format(prec*100, valid_loss))
        test_accuracy.append(prec)
        best_prec = max(prec, best_prec)
        min_loss = min(min_loss, valid_loss)

        # ---- Full-batch gradient analysis on the whole training set ----
        model.train()
        for all_input, all_target in all_train_loader:
            all_target = all_target.long()
            full_output = model(all_input)
            output_norm.append(torch.norm(full_output, dim=1).detach())
            full_loss = criterion(full_output, all_target)
            # Clear the mini-batch gradient left over from optimizer.step(); without
            # this, backward() adds to it and dtheta mixes both gradients.
            optimizer.zero_grad()
            full_loss.backward()

            eigen = model.state_dict()['classifier.0.eigen']
            eigen_norm = torch.norm(eigen)
            norm.append(eigen_norm)

            # Gradient of the last registered parameter (what the previous loop over
            # named_parameters ended on). NOTE(review): presumably classifier.0.eigen.
            dtheta = -list(model.parameters())[-1].grad.detach().clone()
            all_dtheta_norm.append(torch.norm(dtheta))
            dtheta = dtheta / torch.norm(dtheta)

            alpha_1 = torch.sum(dtheta * eigen) / eigen_norm
            alpha1_all.append(alpha_1)
            u1 = dtheta - alpha_1 * eigen / eigen_norm

            # Vectorised form of the per-sample loop:
            #   for every n and every class j: partial_snj[j] += -2 a[n] <eigen_j, x_n> x_n
            #   for the true class t_n:        partial_snj[t_n] += 18 a[n] <eigen_t_n, x_n> x_n
            x_flat = all_input[:, 0].reshape(all_input.shape[0], -1)        # (N, 784)
            weights = a.reshape(-1)[:x_flat.shape[0]].to(x_flat.dtype)      # a[n]
            proj = x_flat @ eigen[:, :, 0].T                                # proj[n, j] = <eigen_j, x_n>
            coef = -2 * weights.unsqueeze(1) * proj
            rows = torch.arange(x_flat.shape[0])
            coef[rows, all_target] += 18 * weights * proj[rows, all_target]
            partial_snj = (coef.T @ x_flat).unsqueeze(-1)                  # (10, 784, 1)

            partial_snj_norm = torch.norm(partial_snj)
            all_partial_snj_norm.append(partial_snj_norm)
            partial_snj = partial_snj / partial_snj_norm

            alpha_2 = torch.sum(partial_snj * eigen) / eigen_norm
            alpha2_all.append(alpha_2)

            u2 = partial_snj - alpha_2 * eigen / eigen_norm
            u1u2_innerpro.append(torch.sum(u1 * u2))

            gamma_t.append((eigen_norm**2 - theta_0_norm**2) / (eigen_norm**2))

print('Best accuracy is: {:.2f}%, Minimum loss is: {:.4f}'.format(best_prec*100, min_loss))
#weight_a = model.state_dict()['classifier.0.weight_a']
#torch.save(weight_a,'after_train_weight')
#torch.save(weight_a,'weight_a')
#eigen = model.state_dict()['classifier.0.eigen']
#torch.save(eigen,'eigen_9')

#imageio.mimsave('distribution_of_cos.gif', image_list, duration=0.2) 
#spike_trigger_average = generate_spike_trigger_average()

In [None]:
# Predicted rate of change of mu_1 - mu_2 at each logged step:
#   ft = gamma_t * <u1, u2> + alpha_1 * alpha_2 * (gamma_t - 0.5)
ft = (
    torch.Tensor(gamma_t) * torch.Tensor(u1u2_innerpro)
    + torch.Tensor(alpha1_all) * torch.Tensor(alpha2_all) * (torch.Tensor(gamma_t) - 0.5)
)

In [None]:
#rhs = 1/(2*(torch.Tensor(u1u2_innerpro)/(torch.Tensor(alpha1_all)*torch.Tensor(alpha2_all))+1))

In [None]:
# Compare the predicted d/dt(mu_1 - mu_2) (`ft`) with the observed change in test
# accuracy for logged evaluations 20..99 (one evaluation every 2 training steps).
plot_steps = 2 * (torch.arange(80) + 20)
# Computed here so this cell does not depend on `test_diff` from a later cell;
# as_tensor accepts both the list built during training and a converted tensor.
test_diff = torch.diff(torch.as_tensor(test_accuracy, dtype=torch.float32))

fig, ax = plt.subplots(1, 1, layout="constrained", figsize=(5, 4))
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.plot(plot_steps, torch.zeros(80), linestyle='--', linewidth=2, color='black')
ax.plot(plot_steps, test_diff[19:99], linestyle='--', linewidth=2, color='gray', label='Accuracy\ndifference')
sc = ax.scatter(plot_steps, ft[20:100], c=test_accuracy[20:100], cmap=plt.cm.rainbow, s=5)
cb = fig.colorbar(sc, ax=ax)
cb.set_label('test accuracy', fontsize=12)
ax.set_xlabel(r'Training steps')
ax.set_ylabel(r'$\frac{d}{dt}(\mu_1-\mu_2)$')
ax.legend(loc='best')
savefig('test_difference.pdf')

In [None]:
# For every training sample n, fit over the first `end` logged steps:
#   ||Phi(x_n, theta_t)||_2  ~=  slope_n * ||theta_t||_2^2 + intercept_n
end = 75
# output_norm is a list of per-step (n_samples,) tensors; stack to (steps, n_samples).
# list() also makes this work if output_norm was already stacked.
output_norm_all = torch.stack(list(output_norm))

output_std_true = (torch.Tensor(norm)**2)[:end]                               # regressor ||theta_t||^2
A = torch.stack((output_std_true, torch.ones(len(output_std_true))), dim=0)   # (2, end)
targets = output_norm_all[:end]                                               # (end, n_samples)

# One least-squares solve with all samples as right-hand sides (A is the same for each).
solution = torch.linalg.lstsq(A.T, targets).solution                          # (2, n_samples)
residual = targets - A.T @ solution
r_squares = 1 - residual.pow(2).sum(dim=0) / (targets - targets.mean(dim=0)).pow(2).sum(dim=0)

# Slopes per sample, as a list of 0-dim tensors.
# NOTE(review): this rebinds `a`, the per-sample coefficients used by the training loop.
a = list(solution[0])

# Keep the last sample's fit under the names the plotting cell reads.
x = solution[:, -1:]
output_std_other = targets[:, -1]
B = output_std_other.unsqueeze(0)
r_square = r_squares[-1]

print('r^2 over {} samples: min {:.4f}, median {:.4f}, max {:.4f}'.format(
    r_squares.numel(), r_squares.min().item(), r_squares.median().item(), r_squares.max().item()))

In [None]:
# Output norm of one sample vs ||theta||^2, with its least-squares line.
fig, ax = plt.subplots(1, 1, layout="constrained", figsize=(5, 4))
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
theta_sq = torch.Tensor(output_std_true)
fitted = torch.mm(A.T, x).squeeze(1)  # (end,) fitted values
ax.plot(theta_sq, fitted, linestyle='--', linewidth=2, label='$r^2 = {:.2f}$'.format(r_square), color='black')
sc = ax.scatter(theta_sq, output_std_other, c=test_accuracy[:end], cmap=plt.cm.rainbow, s=5)
# {:+.2f} prints the intercept with its sign so the formula reads correctly.
ax.set_title(r'$||\Phi(x_5,\theta)||_2  = {:.2f}||\theta||_2^2  {:+.2f}$'.format(x[0][0], x[1][0]))
cb = fig.colorbar(sc, ax=ax)
cb.set_label('test accuracy', fontsize=12)
ax.legend(loc='best')

In [None]:
# Step-to-step change in test accuracy. torch.diff needs a tensor; as_tensor accepts the
# list built during training as well as an already-converted tensor.
test_diff = torch.diff(torch.as_tensor(test_accuracy, dtype=torch.float32))
plt.plot(test_diff);

# Plot test accuracy and $\mu_1-\mu_2$ during training

In [None]:
test_accuracy = torch.Tensor(test_accuracy)
KS_dis = torch.Tensor(KS_dis)
mean = torch.Tensor(mean)
to = 150
begin = 0

plt.figure(figsize=(4, 4))
ax1 = plt.subplot(111)
# Tick labels are already bold via rcParams['font.weight'] = 'bold'.
ax1.plot(2*torch.arange(len(test_accuracy))[begin:to], test_accuracy[begin:to], color='blue', linewidth=3)
ax2 = ax1.twinx()
# x-axis built from mean's own length so a shorter (or empty) series cannot cause a
# shape mismatch with the accuracy axis.
ax2.plot(2*torch.arange(len(mean))[begin:to], mean[begin:to], color='red', linewidth=2, linestyle='--')
ax1.set_xlabel("Training steps", fontsize=12, fontweight='bold')
# Raw string: '\m' is not a valid Python escape sequence.
ax2.set_ylabel(r'$\mu_1-\mu_2$', fontsize=12, color='red', fontweight='bold')
ax1.set_ylabel('Test accuracy', fontsize=12, color='blue', fontweight='bold')
ax2.set_title('LR-NN: MNIST', fontsize=12, fontweight='bold')
ax1.grid()


# Distribution

# Empirical CDFs of the normalised outputs for the true label vs the other labels;
# each rendered frame is appended to image_list (e.g. for a GIF).
cdf_fig, cdf_ax = plt.subplots()
cdf_ax.plot(bins, cdf_correct, color='lime', linewidth=4, linestyle='--', label='True label cdf')
cdf_ax.plot(bins, cdf_mistaken, color='aqua', linewidth=4, linestyle='--', label='Other label cdf')
cdf_ax.set_xlabel('Output under L2 normalization', fontsize=20)
cdf_ax.set_ylabel('Cumulative probability', fontsize=20)
cdf_ax.set_title('KS distance = {:.2f}'.format(KS_distance))
cdf_ax.legend(fontsize=15, loc='lower right')
savefig('temp.png')
plt.show()
image_list.append(imageio.imread('temp.png'))

# Confusion matrix

# Confusion matrix of the classifier obtained by loading the per-digit spike-triggered
# averages as the eigen vectors.
if 'spike_trigger_average' not in globals():
    spike_trigger_average = generate_spike_trigger_average()

model_device = next((p.device for p in model.parameters()), torch.device('cpu'))
# Keep the trained eigen so it can be restored after this evaluation.
trained_eigen = model.state_dict()['classifier.0.eigen'].clone()

valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=10000)
change_eigen(spike_trigger_average.reshape(10, 784, 1))
input, target = next(iter(valid_loader))
input, target = input.to(model_device), target.to(model_device)
with torch.no_grad():
    output = model(input)
_, predicted = torch.max(output.data, 1)
# Restore the trained parameters so later cells do not silently use the STA model.
change_eigen(trained_eigen)

target_cpu, predicted_cpu = target.cpu(), predicted.cpu()
# Images per true class, computed instead of hard-coded.
number = torch.bincount(target_cpu, minlength=10).float()
confusion_counts = torch.zeros(10, 10)
confusion_counts.index_put_((target_cpu, predicted_cpu), torch.ones(len(target_cpu)), accumulate=True)
confusion_matrix = (confusion_counts / number.unsqueeze(1)).numpy()
confusion_matrix = np.around(confusion_matrix, 3)

# Annotated heatmap with matplotlib (seaborn is not imported in this notebook).
fig_cm, ax_cm = plt.subplots()
im = ax_cm.imshow(confusion_matrix, cmap=plt.cm.GnBu)
fig_cm.colorbar(im, ax=ax_cm)
for row in range(10):
    for col in range(10):
        ax_cm.text(col, row, '{:.2g}'.format(confusion_matrix[row, col]),
                   ha='center', va='center', fontsize=7, fontweight='bold')
ax_cm.set_xticks(range(10))
ax_cm.set_yticks(range(10))
ax_cm.set_xlabel('Predict label', fontsize=20)
ax_cm.set_ylabel('True label', fontsize=20)
savefig('spike_confusion_matrix.svg')
plt.show()

# Adversarial Examples and Gaussian Noise

In [None]:
# Clean test labels (the whole test set is a single batch).
input, target = next(iter(valid_loader))
target = target.long()

# Use evaluation mode for both generating and scoring the adversarial examples.
# (Previously it was only set after the first, epsilon = 0, set was generated.)
model.eval()

test_accuracy = torch.zeros(10)
adversial_input_all = torch.zeros(10, 10000, 784)
for j in range(10):
    varepsilon = 0.01*j
    adversial_input_1 = generate_adversial_exm_sign(varepsilon)
    adversial_input_all[j, :, :] = adversial_input_1.reshape(10000, 784)

    with torch.no_grad():
        output = model(adversial_input_1)
        _, predicted = torch.max(output.data, 1)
    valid_total = target.size(0)
    valid_correct = (predicted == target).sum().item()

    prec_1 = valid_correct / valid_total
    print('Accuary on test images:{:.2f}%, epsilon:{}'.format(prec_1*100, varepsilon))
    test_accuracy[j] = prec_1  # accuracy under the sign attack
#test_accuracy[1,j] = prec_1 #sign ad

In [None]:
# Clean test labels (the whole test set is a single batch).
input, target = next(iter(valid_loader))
target = target.long()

# Load all 10 epsilon levels once instead of re-reading the whole file every iteration.
ad_all_mnist_sign = torch.load('ad_all_mnist_sign')

model.eval()
test_accuracy = torch.zeros(10)
for j in range(10):
    adversial_input_1 = ad_all_mnist_sign[j, :, :]

    with torch.no_grad():
        output = model(adversial_input_1)
        _, predicted = torch.max(output.data, 1)
    valid_total = target.size(0)
    valid_correct = (predicted == target).sum().item()

    prec_1 = valid_correct / valid_total
    print('Accuary on test images:{:.2f}%, epsilon:{}'.format(prec_1*100, j*0.01))
    test_accuracy[j] = prec_1

torch.save(test_accuracy, 'ad_sign_quadratic_mnist')

In [None]:
# Build the bias-augmented read-out weights and inputs, and a ridge-regularised sample
# covariance of the inputs, for the constrained eigen-direction search below.

# Keep the last batch yielded by valid_loader as the analysis data, flattened to pixels.
# NOTE(review): assumes this batch is the whole 10000-image test set — confirm batch_size.
for i, (input, target) in enumerate(valid_loader):
    input, target = input.reshape(-1, 784), target.long()

n_samples = input.shape[0]

# Augment read-out (weights | bias) and inputs (pixels | 1) so that w^T x includes the bias.
weight = model.state_dict()['classifier.0.weight']
bias = model.state_dict()['classifier.0.bias']
weight_all = torch.cat([weight, bias.reshape(-1, 1)], dim=1).cpu()
feature_map_all = torch.cat([input.cpu(), torch.ones(n_samples, 1)], dim=1).cpu()

# Centred sample covariance with 1/N normalisation (mean computed once).
centered = feature_map_all - torch.mean(feature_map_all, dim=0)
Sigma = torch.mm(centered.T, centered) / n_samples

# The constant bias column is all zeros after centring, so Sigma has a zero row/column;
# add a small ridge (0.1% of the mean absolute entry) to the diagonal to make it full rank.
Sigma = Sigma + 0.001*torch.mean(abs(Sigma))*torch.eye(Sigma.shape[0])
u, s, v = torch.svd(Sigma)

def f(x, w, sigma=None):
    '''
    Residual whose zeros are stationary directions of x^T Sigma x on the unit sphere
    restricted to the orthogonal complement of w.

    f(x) = -2 Sigma x + mu w + 2 lambda x, with
        lambda = x^T Sigma x,  mu = 2 w^T Sigma x / ||w||^2.
    At a zero with lambda != 0 it follows that w^T x = 0 and ||x|| = 1, i.e. x is an
    eigen-direction of Sigma constrained to be perpendicular to w.

    Args:
        x: (d, 1) column vector, current iterate.
        w: (d, 1) column vector to stay orthogonal to.
        sigma: (d, d) matrix; defaults to the module-level ``Sigma`` when omitted.

    Returns:
        (d, 1) residual vector.
    '''
    if sigma is None:
        sigma = Sigma
    mu = 2*torch.mm(torch.mm(w.T, sigma), x)/(torch.norm(w)**2)
    lambda1 = torch.mm(torch.mm(x.T, sigma), x)
    y = -2*torch.mm(sigma, x) + mu*w + 2*lambda1*x
    return y

def Jf(x, w, sigma=None):
    '''
    Jacobian of f(x, w) with respect to x, used for Newton's method.

    J = -2 Sigma + 2 w w^T Sigma / ||w||^2 + 4 x x^T Sigma + 2 (x^T Sigma x) I.
    The 4 x x^T Sigma term uses d(x^T Sigma x)/dx = 2 Sigma x, which assumes sigma is
    symmetric.

    Args:
        x: (d, 1) column vector, current iterate.
        w: (d, 1) column vector.
        sigma: (d, d) symmetric matrix; defaults to the module-level ``Sigma`` when omitted.

    Returns:
        (d, d) Jacobian matrix.
    '''
    if sigma is None:
        sigma = Sigma
    lambda1 = torch.mm(torch.mm(x.T, sigma), x)
    # Identity sized from x instead of a hardcoded 785, so any dimension works.
    identity = torch.eye(x.shape[0], dtype=x.dtype, device=x.device)
    J = (-2*sigma
         + 2*torch.mm(torch.mm(w, w.T), sigma)/(torch.norm(w)**2)
         + 4*torch.mm(torch.mm(x, x.T), sigma)
         + 2*lambda1*identity)
    return J


# Newton's method: find a root x of f(., w) for the read-out vector w of one digit.
digit = 7

w = weight_all[digit:digit+1, :].T    # (d, 1) read-out column (weights + bias) of `digit`
x = w / w.norm()**2                    # start on the hyperplane w^T x = 1
for steps in range(25):
    # Full Newton step: x <- x - J(x)^{-1} f(x)
    delta_x = -(torch.inverse(Jf(x, w)) @ f(x, w))
    x = x + delta_x

# Convergence diagnostics: at a root (lambda != 0) ||x|| -> 1, w^T x -> 0, ||f|| -> 0.
print(x.norm(), x.T @ w, f(x, w).norm())
# Variance along x vs. the leading singular values of Sigma and the Rayleigh quotient of w.
print(x.T @ Sigma @ x, s[0:30], (w.T @ Sigma @ w) / w.norm()**2)

In [None]:
# Project every bias-augmented test image onto the plane spanned by the digit's read-out
# w and the perpendicular direction x found by the Newton iteration; colour by true label.
project_vector = torch.cat([w, x], dim=1)
low_dimension_data = torch.matmul(feature_map_all, project_vector)

fig, ax = plt.subplots(figsize=(4, 4))
for classes in range(10):
    mask = target == classes
    ax.scatter(low_dimension_data[mask, 0], low_dimension_data[mask, 1], label='{}'.format(classes), s=1)
ax.set_xlabel('$w_{}$'.format(digit), fontsize=12)
# Raw string: '\p' and '\ ' are not Python escapes; they must reach mathtext verbatim.
ax.set_ylabel(r'pc $\perp \ w_{}$ '.format(digit), fontsize=12)
#ax.legend(bbox_to_anchor=(1,1),fontsize=12, markerscale=8)
#ax.set_title('Class {}'.format(digit),fontsize=14)
# A bare `savefig` is not imported in this notebook; save through the figure object.
fig.savefig('dim_redu_mnist_{}.png'.format(digit))

# Make the input close to the network
# Starting from the clean test images, run N_REFINE_EPOCHS passes of per-image,
# normalised gradient-descent steps (length REFINE_STEP) on the classification loss with
# respect to the image, i.e. lower each image's loss for its true label, and report the
# test accuracy on the refined images after every pass.
N_REFINE_EPOCHS = 20
REFINE_STEP = 0.2

spike_trigger_average = generate_spike_trigger_average()
change_eigen(spike_trigger_average.reshape(10,784,1))
valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=10000)
for i, (input, target) in enumerate(valid_loader):
    input, target = input.cuda(), target.cuda()
    # Work on a copy so that `input` keeps the clean test images.
    ad_exm = input.clone()

# Use eval mode for both the refinement steps and the evaluation, from the first pass on.
model.eval()
for epoch in range(N_REFINE_EPOCHS):
    for i in range(ad_exm.shape[0]):
        a = ad_exm[i:i+1].clone().requires_grad_(True)

        output = model(a)
        loss = criterion(output, target[i:i+1])
        # Only d(loss)/d(image) is needed; autograd.grad does not accumulate .grad into
        # the model parameters over the many backward passes.
        grad_a, = torch.autograd.grad(loss, a)

        with torch.no_grad():
            # Guard the normalisation: once the loss saturates the gradient can be exactly
            # zero, and 0/0 would write NaNs into the image.
            ad_exm[i:i+1] = a - REFINE_STEP*grad_a/grad_a.norm().clamp_min(1e-12)

    valid_correct = 0
    valid_total = 0
    with torch.no_grad():
        output = model(ad_exm)

        _, predicted = torch.max(output.data, 1)
        valid_total += target.size(0)
        valid_correct += (predicted == target).sum().item()

    prec = valid_correct / valid_total
    print('Accuary on test images:{:.2f}%'.format(prec*100))
    print(epoch)

# Analysis of the rank-one network

In [None]:
## Dynamics: GIF of the per-pixel maximum |weight| over 600 stored weight snapshots (currently disabled)
#image_list = []
#for i in range(600):
 #   a = torch.zeros(784)
  #  for j in range(784):
   #     a[j] = max(abs(weight_we_need[i,j,:]))
    #plt.matshow(a.reshape(28,28),cmap = plt.cm.Blues)
    #plt.savefig('temp.png')
    #image_list.append(imageio.imread('temp.png'))
#imageio.mimsave('pic6.gif', image_list, duration=0.1)

In [None]:
# Compare the contributions of the linear and the bilinear (nonlinear) terms (currently disabled)
#train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=60000, shuffle=True)
#for i, (input, target) in enumerate(train_loader):
 #   a = 1
    
#output = model(input)
#linear_term = []
#bilinear_term = []
#for i in range(60000):
 #   category = output[i,:].argmax().cpu()
   # if category == 1:
  #  weight_bilinear = weight_a[category,:,:].cpu()
   # weight_linear = weight_b[category,:].reshape(1,784).cpu()
    #input_now = input[i,0,:,:].reshape(784,1).cpu()
    #linear_term.append(torch.mm(weight_linear, input_now)*5) 
    #bilinear_term.append(torch.mm(torch.mm(input_now.t(), weight_bilinear),input_now))
    #if i % 1000 == 0:
     #   print(i)

#fig = plt.figure() 
#ax1 = fig.add_subplot(111)

#ax1.bar(torch.arange(len(bilinear_term)), bilinear_term, color = 'r',width = 1,alpha = 0.7,label = 'bilinear term')
#ax1.bar(torch.arange(len(linear_term)), linear_term,color = 'b',width = 1 , label = 'linear term')
#ax1.set_xlabel('Picture index',fontsize = 20)
#ax1.set_ylabel('Magnitude after NN',fontsize = 20)
#ax1.legend(fontsize = '15', loc = 'best')
#ax1.set_ylim(-1, 40)