Import packages

In [1]:
import numpy as np
import copy
import time
import datetime
import os
from tqdm import tqdm

import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
from matplotlib.ticker import MaxNLocator
import matplotlib.patches as mpatches
# plt.rcParams['text.usetex'] = True

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

from matplotlib import rcParams
# Global matplotlib defaults: transparent figure/saved backgrounds, automatic
# layout, 200 dpi figures and the "cm" math-text font set.
rcParams.update({
    'figure.facecolor': "None",
    'savefig.transparent': True,
    'figure.autolayout': True,
    'figure.dpi': 200,
    "mathtext.fontset": "cm",
})

Useful Functions

In [9]:
# image show function
def imshow(x, title=None, figsize=(2,2)):
    """Display a single image in a new figure.

    Accepts a torch tensor (detached and copied to the CPU first) or anything
    ``np.asarray`` understands. Singleton dimensions are squeezed away, 2-D
    images are drawn with the gray colormap, and a channel-first 3-D image
    (first axis of size 3 or 4, last axis not) is moved to channels-last,
    the layout ``plt.imshow`` expects for RGB(A) data. Tensors and arrays are
    handled identically, including the ``figsize`` of the new figure.

    Args:
        x: Image as a tensor or array-like.
        title: Optional figure title.
        figsize: Size of the new figure in inches.
    """
    if torch.is_tensor(x):
        x = x.detach().cpu().numpy()
    image = np.asarray(x).squeeze()
    # torch image tensors are (C, H, W); plt.imshow wants (H, W, C) for colour.
    if image.ndim == 3 and image.shape[0] in (3, 4) and image.shape[-1] not in (3, 4):
        image = np.moveaxis(image, 0, -1)
    plt.figure(figsize=figsize)
    plt.imshow(image, cmap='gray' if image.ndim == 2 else None)
    plt.title(title)
    plt.show()
        
def create_directory(directory):
    """Create *directory*, including any missing parents, if it does not exist.

    Best-effort: if creation fails, an error message naming the directory and
    the underlying OS error is printed instead of raising, so callers keep
    running.
    """
    try:
        # exist_ok avoids the race between checking for the path and creating it.
        os.makedirs(directory, exist_ok=True)
    except OSError as err:
        print(f"Error: Failed to create the directory {directory!r}: {err}")
        
def savefig(directory):
    """Save the current pyplot figure, display it, then close it.

    NOTE(review): despite its name, ``directory`` is passed unchanged to
    ``plt.savefig``, so it is the output *file* path, not a directory.
    """
    plt.savefig(directory)
    plt.show()
    # Close the current figure after showing it.
    plt.close()

    
def protractor(w,v):
    """Return the angle in radians (in ``[0, pi]``) between tensors *w* and *v*.

    Both tensors are treated as flat vectors: the dot product is the sum of the
    elementwise product and each norm is ``Tensor.norm()`` over all elements.
    The result is NaN if either tensor is all zeros (the cosine is 0/0).
    """
    cosine = (w*v).sum()/(w.norm()*v.norm())
    # Rounding can push the cosine just outside [-1, 1], where arccos is NaN.
    return torch.arccos(cosine.clamp(-1.0, 1.0))

def make_list_of_width(depth, width, dimension):
    """Return the widths ``[dimension, width, ..., width, dimension]``.

    ``width`` appears ``depth - 1`` times (not at all when ``depth <= 1``), so
    for ``depth >= 1`` the list has ``depth + 1`` entries, i.e. ``depth``
    consecutive (input, output) pairs.
    """
    hidden_widths = [width] * (depth - 1)
    return [dimension, *hidden_widths, dimension]

In [None]:
def load_dataset(dataset, n):
    """Load the first ``n`` training samples and the full test split of a dataset.

    Images go through ``ToTensor`` and a ``Normalize`` with the hard-coded
    per-channel mean/std below, are flattened to one row per sample, and are
    moved to the module-level ``device`` (a global this function reads).
    Labels are converted to float. The training subset is taken without
    shuffling, so the same ``n`` samples are returned on every call. Data is
    stored under (and downloaded to, if missing) ``./DATASETS``.

    Args:
        dataset: ``"MNIST"``, ``"FashionMNIST"`` or ``"CIFAR10"``.
        n: Number of training samples, ``1 <= n <= len(training split)``.

    Returns:
        ``(trn_X, trn_Y, test_X, test_Y)`` where ``trn_X`` has shape ``(n, D)``
        and ``test_X`` has shape ``(N_test, D)``; ``D`` is the number of values
        per image (all non-batch dimensions flattened).

    Raises:
        ValueError: If ``dataset`` is not supported or ``n`` is out of range.
    """
    # dataset name -> (dataset class, per-channel mean, per-channel std)
    specs = {
        "MNIST": (torchvision.datasets.MNIST, (0.1307,), (0.3081,)),
        "FashionMNIST": (torchvision.datasets.FashionMNIST, (0.2860,), (0.3530,)),
        "CIFAR10": (torchvision.datasets.CIFAR10,
                    (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    }
    if dataset not in specs:
        raise ValueError(
            f"Dataset name is invalid: {dataset!r} (expected one of {sorted(specs)})."
        )
    dataset_cls, mean, std = specs[dataset]
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean, std),
    ])
    trn_dataset = dataset_cls("./DATASETS", train=True, download=True, transform=transform)
    test_dataset = dataset_cls("./DATASETS", train=False, download=True, transform=transform)

    if not 1 <= n <= len(trn_dataset):
        raise ValueError(
            f"n must be between 1 and {len(trn_dataset)} for {dataset}, got {n}."
        )

    # shuffle=False keeps the training subset identical across calls.
    trn_data_loader = torch.utils.data.DataLoader(
        trn_dataset, batch_size=n, shuffle=False, num_workers=0,
        pin_memory=False, drop_last=True,
    )
    # A single batch containing the whole test split.
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=len(test_dataset), shuffle=False,
        num_workers=0, pin_memory=False,
    )
    trn_X, trn_Y = next(iter(trn_data_loader))
    test_X, test_Y = next(iter(test_data_loader))

    # One row per sample, regardless of image size or channel count.
    trn_X = trn_X.flatten(start_dim=1).to(device)
    trn_Y = trn_Y.float().to(device)
    test_X = test_X.flatten(start_dim=1).to(device)
    test_Y = test_Y.float().to(device)
    return trn_X, trn_Y, test_X, test_Y

Network Architectures

In [None]:
# # Network architectures
# class two_layer_ReLU_net(nn.Module):
#     def __init__(self, d1, output_class=2, bias=True, last_ReLU=True):
#         super(two_layer_ReLU_net, self).__init__()
#         self.fc1 = nn.Linear(d, d1, bias=bias)
#         self.fc2 = nn.Linear(d1, output_class, bias=bias)
#         self.d1 = d1 # width
#         self.last_ReLU = last_ReLU
#         self.output_class = output_class
        
        
#     def forward(self, x):
#         self.g1 = self.fc1(x)
#         self.h1 = F.relu(self.g1)
#         self.g2 = self.fc2(self.h1)
#         self.h2 = F.relu(self.g2)
#         if self.last_ReLU:
#             return self.h2
#         else:
#             return self.g2
    
    
    
#     def activation_pattern(self, x):
#         self.forward(x) # set h1 and h2
#         return self.h1>0 # binary tensor
    
    
#     def rank(self, x, e=0.001): # compute the rank at x
#         diff = self.forward(x+e) - self.forward(x)
#         rank = (diff != 0).sum(-1)
#         return rank
    
#     def partition(self, X, title, save_location, show_data=True, rank_measure='output rank', show_vector=False):
#         w = 2
#         N = 400
#         x,y = torch.meshgrid(torch.linspace(-w,w,N), torch.linspace(-w,w,N))
#         grid = torch.stack((x,y),dim=2).to(device).float() # SHAPE 10,10,2
#         if rank_measure == 'decision boundary':
#             rank = (net(grid.view(-1,2))<0).float().view(N,N)                     # decision boundary, output>0
#             plt.contourf(x,y,rank.cpu(), np.arange(-1,2,1), cmap='gray')
# #             plt.contourf(x,y,rank.cpu(), np.arange(-1,2,0.1), cmap='RdBu_r')
#         elif rank_measure == 'output rank':
#             rank = net.output_rank(grid.view(-1,2)).view(N,N)                     # Output rank
#             plt.contourf(x,y,rank.cpu(), np.arange(0,self.output_class+1,1), cmap='gray')
#         elif rank_measure == 'function rank':
#             rank = net.function_rank(grid.view(-1,2)).view(N,N)                 # Function rank
#             if self.last_ReLU:
#                 plt.contourf(x,y,rank.cpu(), np.arange(0,self.output_class+1,1), cmap='gray')
#             else:
#                 plt.contourf(x,y,rank.cpu(), np.arange(0,self.m+1,1), cmap='gray')
#         elif rank_measure == 'matrix rank':
#             rank = net.matrix_rank(grid.view(-1,2)).view(N,N)                 # Function rank
#             plt.contourf(x,y,rank.cpu(), np.arange(0,4 ,1), cmap='gray')
#         else:
#             raise RankMeasureError
#         plt.colorbar()
        
#         if self.last_ReLU :
#             activation_pattern = net.activation_pattern(grid.view(-1,2)).view(N,N,self.d1+output_class)
#         else:
#             activation_pattern = net.activation_pattern(grid.view(-1,2)).view(N,N,self.d1)
        
#         for i in range(self.d1):
#             plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='blue', linewidths=0.5)
#         if self.last_ReLU:
#             for i in range(self.d1,self.d1+self.output_class):
#                 plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='yellow', linewidths=0.5)

#         black_patch = mpatches.Patch(color='black', label='Rank 0')
#         gray_patch = mpatches.Patch(color='gray', label='Rank 1')
#         blue_patch = mpatches.Patch(color='blue', label='1st layer')
#         plt.legend(handles=[blue_patch])
        
#         if show_vector:
#             _x, _y = torch.meshgrid(torch.linspace(-w,w,20), torch.linspace(-w,w,20))
#             _grid = torch.stack((_x,_y),dim=2).to(device).float() # SHAPE 10,10,2
#             out = net(_grid.view(-1,2))
#             _u = out[:,0]-_grid.view(-1,2)[:,0]
#             _v = out[:,1]-_grid.view(-1,2)[:,1]
#             plt.quiver(_x.cpu(),_y.cpu(),_u.detach().cpu(),_v.detach().cpu())
        
#         plt.plot(trn_X[trn_Y.squeeze()>0.5][:,0].cpu(), trn_X[trn_Y.squeeze()>0.5][:,1].cpu(), '*g')
#         if show_data :
#             # plot dataset
#             plt.plot(trn_X[trn_Y.squeeze()<0.5][:,0].cpu(), trn_X[trn_Y.squeeze()<0.5][:,1].cpu(), '*k')
#         plt.xlabel('x')
#         plt.ylabel('y')
#         plt.title(f"Net:2->{self.d1}->{self.output_class}, "+title)
#         plt.xlim(-w,w)
#         plt.ylim(-w,w)
#         plt.savefig(save_location)
# #         plt.show()
#         plt.clf()

    
    
# ###############################################################################
# class three_layer_ReLU_net(nn.Module):
#     def __init__(self, d1, d2, output_class=2, bias=True, last_ReLU=True):
#         super(three_layer_ReLU_net, self).__init__()
#         self.fc1 = nn.Linear(d, d1, bias=bias)
#         self.fc2 = nn.Linear(d1, d2, bias=bias)
#         self.fc3 = nn.Linear(d2, output_class, bias=bias)
#         self.d1 = d1
#         self.d2 = d2
#         self.last_ReLU = last_ReLU
#         self.output_class = output_class
        
#     def forward(self, x):
#         self.g1 = self.fc1(x)
#         self.h1 = F.relu(self.g1)
#         self.g2 = self.fc2(self.h1)
#         self.h2 = F.relu(self.g2)
#         self.g3 = self.fc3(self.h2)
#         self.h3 = F.relu(self.g3)
#         if self.last_ReLU:
#             return self.h3
#         else:
#             return self.g3
    
#     def activation_pattern(self, x):
#         self.forward(x) # set h1 and h2
#         if self.last_ReLU :
#             return torch.cat((self.h1>0, self.h2>0, self.h3>0), dim=1) # binary tensor, shape width+width+output_class
#         else:
#             return torch.cat((self.h1>0, self.h2>0), dim=1) # binary tensor, shape width+width+output_class
    
    
#     def output_rank(self, x, e=0.001): # compute the output rank at x
#         diff = self.forward(x+e) - self.forward(x)
#         output_rank = (diff != 0).sum(-1)
#         return output_rank
    
#     def function_rank(self, x): # compute the function rank at x
#         self.forward(x)
#         if self.last_ReLU:
#             function_rank = torch.min(torch.stack(((self.h1!=0).sum(1), (self.h2!=0).sum(1), (self.h3!=0).sum(1)), dim=0), dim=0)[0]
#         else:
#             function_rank = torch.min(torch.stack(((self.h1!=0).sum(1), (self.h2!=0).sum(1)), dim=0), dim=0)[0]
#         return function_rank#.clamp(0,2)
    
#     def matrix_rank(self, x): # compute the matrix rank at x // the matrix A s.t. N(x) = Ax + b
#         batchsize = len(x)
#         A = torch.cat(
#             (self.forward(x + torch.tensor([1.0, 0.0]).to(device)) - self.forward(x),
#              self.forward(x + torch.tensor([0.0, 1.0]).to(device)) - self.forward(x)), dim=0
#         ) # shape 2batchsize x output_class
#         z = torch.zeros(int(len(A)/2))
#         for i in range(len(z)): ########################################################## for loop is too ineffective ################
#             ############################################ How can I fix it effectively ?? ###############################################
#             z[i] = torch.linalg.matrix_rank(A[2*i:2*i+2]).item()
#         return z
        
    
#     def partition(self, X, title, save_location, loss, w=2, only_data=False, showfig=False):
# #         w = 20
#         N = 400
#         x,y = torch.meshgrid(torch.linspace(-w,w,N), torch.linspace(-w,w,N))
#         grid = torch.stack((x,y),dim=2).to(device).float() # SHAPE 10,10,2
        
#         if only_data:
#             plt.plot(X[:,0].cpu(), X[:,1].cpu(), '*g')
#             plt.xlabel('x')
#             plt.ylabel('y')
#             plt.title("Dataset: two triangles")
#             plt.xlim(-w,w)
#             plt.ylim(-w,w)
#             plt.savefig(save_location)
#             plt.clf()
            
#         else:
#             rank = net(grid.view(-1,2)).view(N,N).detach()                     # decision boundary, output>0
#             if loss == 'MSE':
#                 plt.contourf(x,y,rank.cpu(), np.array([-0.4,0,0.4, 0.8,1.2]), cmap='Greys') #### MSE
#             elif loss == 'BCE':
#                 plt.contourf(x,y,torch.sigmoid(rank.cpu()), np.array([-0,0.5,1]), cmap='Greys') #### BCE
#             plt.colorbar()

#             if self.last_ReLU :
#                 activation_pattern = net.activation_pattern(grid.view(-1,2)).view(N,N,self.d1+self.d2+output_class)
#                 for i in range(self.d1):
#                     plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='blue', linewidths=0.5)
#                 for i in range(self.d1,self.d1+self.d2):
#                     plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='red', linewidths=0.5)
#                 blue_patch = mpatches.Patch(color='blue', label='1st layer')
#                 red_patch = mpatches.Patch(color='red', label='2nd layer')
#                 plt.legend(handles=[blue_patch, red_patch])
#             else:
#                 activation_pattern = net.activation_pattern(grid.view(-1,2)).view(N,N,self.d1+self.d2)
#                 for i in range(self.d1):
#                     plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='blue', linewidths=0.5)
#                 for i in range(self.d1,self.d1+self.d2):
#                     plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='red', linewidths=0.5)
#                 blue_patch = mpatches.Patch(color='blue', label='1st layer')
#                 red_patch = mpatches.Patch(color='red', label='2nd layer')
#                 plt.legend(handles=[blue_patch, red_patch])

#             plt.plot(trn_X[trn_Y.squeeze()>0.5][:,0].cpu(), trn_X[trn_Y.squeeze()>0.5][:,1].cpu(), '*g')
#             plt.xlabel('x')
#             plt.ylabel('y')
#             if loss == 'MSE':
#                 plt.title("Network:  "+r"$2\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.d1}"
#                       +r"$\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.d2}"
#                       +r"$\longrightarrow$"+f"{self.output_class}, "+title)
#             elif loss=="BCE":
#                 plt.title("Network:  "+r"$2\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.d1}"
#                       +r"$\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.d2}"
#                       +r"$\genfrac{}{}{0}{}{\mathtt{SIG}}{\longrightarrow}$"+f"{self.output_class}, "+title)
#             plt.xlim(-w,w)
#             plt.ylim(-w,w)
#             plt.savefig(save_location)
#             if showfig:
#                 plt.show()
#             plt.clf()

    
    
# ###############################################################################
# class four_layer_ReLU_net(nn.Module):
#     def __init__(self, d1, d2, d3, output_class=1, bias=True, last_ReLU=False):
#         super(four_layer_ReLU_net, self).__init__()
#         self.fc1 = nn.Linear(d, d1, bias=bias)
#         self.fc2 = nn.Linear(d1, d2, bias=bias)
#         self.fc3 = nn.Linear(d2, d3, bias=bias)
#         self.fc4 = nn.Linear(d3, output_class, bias=bias)
#         self.d1 = d1 # width
#         self.d2 = d2 # width
#         self.d3 = d3 # width
#         self.last_ReLU = last_ReLU
#         self.output_class = output_class
        
        
#     def forward(self, x):
#         self.g1 = self.fc1(x)
#         self.h1 = F.relu(self.g1)
#         self.g2 = self.fc2(self.h1)
#         self.h2 = F.relu(self.g2)
#         self.g3 = self.fc3(self.h2)
#         self.h3 = F.relu(self.g3)
#         self.g4 = self.fc4(self.h3)
#         self.h4 = F.relu(self.g4)
#         if self.last_ReLU:
#             return self.h4
#         else:
#             return self.g4
    
#     def activation_pattern(self, x):
#         self.forward(x) # set h1 and h2
#         if self.last_ReLU :
#             return torch.cat((self.h1>0, self.h2>0, self.h3>0, self.h4>0), dim=1) # binary tensor, shape width+width+output_class
#         else:
#             return torch.cat((self.h1>0, self.h2>0, self.h3>0), dim=1) # binary tensor, shape width+width+output_class
    
#     def partition(self, X, title, save_location, loss, only_data=False, showfig=False):
#         w = 2
#         N = 400
#         x,y = torch.meshgrid(torch.linspace(-w,w,N), torch.linspace(-w,w,N))
#         grid = torch.stack((x,y),dim=2).to(device).float() # SHAPE 10,10,2
        
#         if only_data:
#             plt.plot(X[:,0].cpu(), X[:,1].cpu(), '*g')
#             plt.xlabel('x')
#             plt.ylabel('y')
#             plt.title("Dataset: two triangles")
#             plt.xlim(-w,w)
#             plt.ylim(-w,w)
#             plt.savefig(save_location)
#             plt.clf()
            
#         else:
#             rank = net(grid.view(-1,2)).view(N,N).detach()                     # 400 x 400 x 2 tensor
#             if loss == 'MSE':
#                 plt.contourf(x,y,rank.cpu(), np.array([-0.4,0,0.4, 0.8,1.2]), cmap='Greys', extend='both') #### MSE
#             elif loss == 'BCE':
#                 plt.contourf(x,y,torch.sigmoid(rank.cpu()), np.array([-0,0.5,1]), cmap='Greys') #### BCE
#             plt.colorbar()
            
#             if self.last_ReLU :
#                 activation_pattern = net.activation_pattern(grid.view(-1,2)).view(N,N,self.d1+self.d2+self.d3+output_class)
#                 for i in range(self.d1):
#                     plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='blue', linewidths=0.5)
#                 for i in range(self.d1,self.d1+self.d2):
#                     plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='red', linewidths=0.5)
#                 for i in range(self.d1+self.d2,self.d1+self.d2+self.d3):
#                     plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='yellow', linewidths=0.5)
#                 for i in range(self.d1+self.d2+self.d3, self.d1+self.d2+self.d3+output_class):
#                     plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='purple', linewidths=0.5)

#                 blue_patch = mpatches.Patch(color='blue', label='1st layer')
#                 red_patch = mpatches.Patch(color='red', label='2nd layer')
#                 yellow_patch = mpatches.Patch(color='yellow', label='3rd layer')
#                 purple_patch = mpatches.Patch(color='purple', label='4th layer')
#                 plt.legend(handles=[blue_patch, red_patch, yellow_patch, purple_patch])
#             else:
#                 activation_pattern = net.activation_pattern(grid.view(-1,2)).view(N,N,self.d1+self.d2+self.d3+output_class)
#                 for i in range(self.d1):
#                     plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='blue', linewidths=0.5)
#                 for i in range(self.d1,self.d1+self.d2):
#                     plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='red', linewidths=0.5)
#                 for i in range(self.d1+self.d2,self.d1+self.d2+self.d3):
#                     plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='yellow', linewidths=0.5)
#                 for i in range(self.d1+self.d2+self.d3, self.d1+self.d2+self.d3+output_class):
#                     plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='purple', linewidths=0.5)
                    
#                 blue_patch = mpatches.Patch(color='blue', label='1st layer')
#                 red_patch = mpatches.Patch(color='red', label='2nd layer')
#                 yellow_patch = mpatches.Patch(color='yellow', label='3rd layer')
#                 plt.legend(handles=[blue_patch, red_patch, yellow_patch])

#             plt.plot(trn_X[trn_Y.squeeze()>0.5][:,0].cpu(), trn_X[trn_Y.squeeze()>0.5][:,1].cpu(), '*g')
#             plt.xlabel('x')
#             plt.ylabel('y')
#             if loss == 'MSE':
#                 plt.title("Network:  "+r"$2\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.d1}"
#                       +r"$\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.d2}"
#                       +r"$\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.d3}"
#                       +r"$\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.output_class}, "+title)
#             elif loss=="BCE":
#                 plt.title("Network:  "+r"$2\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.d1}"
#                       +r"$\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.d2}"
#                       +r"$\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.d3}"
#                       +r"$\genfrac{}{}{0}{}{\mathtt{SIG}}{\longrightarrow}$"+f"{self.output_class}, "+title)
#             plt.xlim(-w,w)
#             plt.ylim(-w,w)
#             plt.savefig(save_location)
#             if showfig:
#                 plt.show()
#             plt.clf()

In [None]:
# ###############################################################################    
# class five_layer_ReLU_net(nn.Module):
#     def __init__(self, d1,d2,d3,d4,output_class=1, bias=True, last_ReLU=False):
#         super(five_layer_ReLU_net, self).__init__()
#         self.fc1 = nn.Linear(d, d1, bias=bias)
#         self.fc2 = nn.Linear(d1, d2, bias=bias)
#         self.fc3 = nn.Linear(d2, d3, bias=bias)
#         self.fc4 = nn.Linear(d3, d4, bias=bias)
#         self.fc5 = nn.Linear(d4, output_class, bias=bias)
#         self.d1 = d1
#         self.d2 = d2
#         self.d3 = d3
#         self.d4 = d4
#         self.last_ReLU = last_ReLU
#         self.output_class = output_class
        
#     def forward(self, x):
#         self.g1 = self.fc1(x)
#         self.h1 = F.relu(self.g1)
#         self.g2 = self.fc2(self.h1)
#         self.h2 = F.relu(self.g2)
#         self.g3 = self.fc3(self.h2)
#         self.h3 = F.relu(self.g3)
#         self.g4 = self.fc4(self.h3)
#         self.h4 = F.relu(self.g4)
#         self.g5 = self.fc5(self.h4)
#         self.h5 = F.relu(self.g5)
#         if self.last_ReLU:
#             return self.h5
#         else:
#             return self.g5
    
#     def activation_pattern(self, x):
#         self.forward(x) # set h1 and h2
#         if self.last_ReLU :
#             return torch.cat((self.h1>0, self.h2>0, self.h3>0, self.h4>0, self.h5>0), dim=1) # binary tensor, shape width+width+output_class
#         else:
#             return torch.cat((self.h1>0, self.h2>0, self.h3>0, self.h4>0), dim=1) # binary tensor, shape width+width+output_class
    
#     def partition(self, X, title, save_location, loss, show_data=True, show_vector=False):
#         w = 2
#         N = 400
#         x,y = torch.meshgrid(torch.linspace(-w,w,N), torch.linspace(-w,w,N))
#         grid = torch.stack((x,y),dim=2).to(device).float() # SHAPE 10,10,2
#         rank = net(grid.view(-1,2)).view(N,N).detach()                     # decision boundary, output>0
#         if loss == 'MSE':
#             plt.contourf(x,y,rank.cpu(), np.array([-0.4,0,0.4, 0.8,1.2]), cmap='Greys', extend='both') #### MSE
#         elif loss == 'BCE':
#             plt.contourf(x,y,torch.sigmoid(rank.cpu()), np.array([-0,0.5,1]), cmap='Greys') #### BCE
#         plt.colorbar()
        
#         if self.last_ReLU :
#             activation_pattern = net.activation_pattern(grid.view(-1,2)).view(N,N,self.d1+self.d2+self.d3+self.d4+output_class)
#         else:
#             activation_pattern = net.activation_pattern(grid.view(-1,2)).view(N,N,self.d1+self.d2+self.d3+self.d4)
        
#         for i in range(self.d1):
#             plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='blue', linewidths=0.5)
#         for i in range(self.d1,self.d1+self.d2):
#             plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='red', linewidths=0.5)
#         for i in range(self.d1+self.d2,self.d1+self.d2+self.d3):
#             plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='yellow', linewidths=0.5)
#         for i in range(self.d1+self.d2+self.d3, self.d1+self.d2+self.d3+self.d4):
#             plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='purple', linewidths=0.5)
#         if self.last_ReLU:
#             for i in range(self.d1+self.d2+self.d3+self.d4,self.d1+self.d2+self.d3+self.d4+self.output_class):
#                 plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='orange', linewidths=0.5)

#         black_patch = mpatches.Patch(color='black', label='Rank 0')
#         gray_patch = mpatches.Patch(color='gray', label='Rank 1')
#         blue_patch = mpatches.Patch(color='blue', label='1st layer')
#         red_patch = mpatches.Patch(color='red', label='2nd layer')
#         yellow_patch = mpatches.Patch(color='yellow', label='3rd layer')
#         purple_patch = mpatches.Patch(color='purple', label='4th layer')
#         orange_patch = mpatches.Patch(color='orange', label='5th layer')
#         plt.legend(handles=[blue_patch, red_patch, yellow_patch, purple_patch])
        
#         if show_data :
#             # plot dataset
# #             plt.plot(trn_X[trn_Y.squeeze()>0.5][:,0].cpu(), trn_X[trn_Y.squeeze()>0][:,1].cpu(), '*g')
# #             plt.plot(trn_X[trn_Y.squeeze()<0.5][:,0].cpu(), trn_X[trn_Y.squeeze()<0][:,1].cpu(), '.k')
#             plt.plot(trn_X[trn_Y.squeeze()>0.5][:,0].cpu(), trn_X[trn_Y.squeeze()>0.5][:,1].cpu(), '*g')
            
#         plt.xlabel('x')
#         plt.ylabel('y')
#         plt.title(f"Net:2->{self.d1}->{self.d2}->{self.output_class}, "+title)
#         plt.xlim(-w,w)
#         plt.ylim(-w,w)
#         plt.savefig(save_location)
# #         plt.show()
#         plt.clf()

In [None]:
# class three_layer_MAX_net(nn.Module):
#     def __init__(self, d1, d2, output_class=1, bias=True, last_ReLU=False):
#         super(three_layer_MAX_net, self).__init__()
#         self.fc1 = nn.Linear(d, d1, bias=bias)
#         self.fc2 = nn.Linear(d1, d2, bias=bias)
#         self.d1 = d1
#         self.d2 = d2
#         self.last_ReLU = last_ReLU
#         self.output_class = output_class
        
#     def forward(self, x):
#         self.g1 = self.fc1(x)
#         self.h1 = F.relu(self.g1)
#         self.g2 = self.fc2(self.h1)
#         self.h2 = F.relu(self.g2)
#         self.g3 = self.h2.max(dim=1)[0]
#         return self.g3.view(-1,1)
    
#     def activation_pattern(self, x):
#         self.forward(x) # set h1 and h2
#         return torch.cat((self.h1>0, self.h2>0), dim=1) # binary tensor, shape width+width+output_class
    
#     def partition(self, X, title, save_location, loss, only_data=False, showfig=False):
#         w = 20
#         N = 400
#         x,y = torch.meshgrid(torch.linspace(-w,w,N), torch.linspace(-w,w,N))
#         grid = torch.stack((x,y),dim=2).to(device).float() # SHAPE 10,10,2
        
#         if only_data:
#             plt.plot(X[:,0].cpu(), X[:,1].cpu(), '*g')
#             plt.xlabel('x')
#             plt.ylabel('y')
#             plt.title("Dataset: two triangles")
#             plt.xlim(-w,w)
#             plt.ylim(-w,w)
#             plt.savefig(save_location)
#             plt.clf()
            
#         else:
#             rank = net(grid.view(-1,2)).view(N,N).detach()                     # decision boundary, output>0
#             if loss == 'MSE':
#                 plt.contourf(x,y,rank.cpu(), np.array([-0.4,0,0.4, 0.8, 1.2]), cmap='Greys', extend='both') #### MSE
#             elif loss == 'BCE':
#                 plt.contourf(x,y,torch.sigmoid(rank.cpu()), np.array([-0,0.5,1]), cmap='Greys') #### BCE
#             plt.colorbar()

#             activation_pattern = net.activation_pattern(grid.view(-1,2)).view(N,N,self.d1 + self.d2)

#             for i in range(self.d1):
#                 plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='blue', linewidths=0.5)
#             for i in range(self.d1, self.d2):
#                 plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='red', linewidths=0.5)

#             blue_patch = mpatches.Patch(color='blue', label='1st layer')
#             red_patch = mpatches.Patch(color='red', label='2nd layer')
#             plt.legend(handles=[blue_patch, red_patch])

#             plt.plot(trn_X[trn_Y.squeeze()>0.5][:,0].cpu(), trn_X[trn_Y.squeeze()>0.5][:,1].cpu(), '*g')
#             plt.xlabel('x')
#             plt.ylabel('y')
#             if loss == 'MSE':
#                 plt.title("Network:  "+r"$2\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.d1}"
#                       +r"$\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.d2}"
#                       +r"$\genfrac{}{}{0}{}{\mathtt{MAX}}{\longrightarrow}$"+f"{self.output_class}, "+title)
#             elif loss=="BCE":
#                 plt.title("Network:  "+r"$2\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.d1}"
#                       +r"$\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.d2}"
#                       +r"$\genfrac{}{}{0}{}{\mathtt{BCE}}{\longrightarrow}$"+f"{self.output_class}, "+title)
#             plt.xlim(-w,w)
#             plt.ylim(-w,w)
#             plt.savefig(save_location)
#             if showfig:
#                 plt.show()
#             plt.clf()



# class three_layer_MAX_two_blocks(nn.Module):
#     def __init__(self, d1, d2, output_class=1, bias=True):
#         super(three_layer_MAX_two_blocks, self).__init__()
#         self.fc11 = nn.Linear(d, d+1, bias=bias)
#         self.fc12 = nn.Linear(d+1, 1, bias=bias)
#         self.fc21 = nn.Linear(d, d+1, bias=bias)
#         self.fc22 = nn.Linear(d+1, 1, bias=bias)
#         self.d1 = d1
#         self.d2 = d2
#         self.output_class = output_class
        
#     def forward(self, x):
#         self.g11 = self.fc11(x)
#         self.h11 = F.relu(self.g11)
#         self.g12 = self.fc12(self.h11)
#         self.h12 = F.relu(self.g12)
        
#         self.g21 = self.fc21(x)
#         self.h21 = F.relu(self.g21)
#         self.g22 = self.fc22(self.h21)
#         self.h22 = F.relu(self.g22)
        
#         self.g3 = torch.max(self.h12, self.h22)
#         return self.g3
    
#     def activation_pattern(self, x):
#         self.forward(x) # set h1 and h2
#         return torch.cat((self.h11>0,self.h21>0, self.h12>0,self.h22>0), dim=1) # binary tensor, shape width+width+output_class
    
#     def partition(self, X, title, save_location, loss, only_data=False):
#         w = 20
#         N = 400
#         x,y = torch.meshgrid(torch.linspace(-w,w,N), torch.linspace(-w,w,N))
#         grid = torch.stack((x,y),dim=2).to(device).float() # SHAPE 10,10,2
        
#         if only_data:
#             plt.plot(X[:,0].cpu(), X[:,1].cpu(), '*g')
#             plt.xlabel('x')
#             plt.ylabel('y')
#             plt.title("Dataset: two triangles")
#             plt.xlim(-w,w)
#             plt.ylim(-w,w)
#             plt.savefig(save_location)
#             plt.clf()
            
#         else:
#             rank = net(grid.view(-1,2)).view(N,N).detach()                     # decision boundary, output>0
#             if loss == 'MSE':
#                 plt.contourf(x,y,rank.cpu(), np.array([-0.4,0,0.4, 0.8,1.2]), cmap='Greys') #### MSE
#             elif loss == 'BCE':
#                 plt.contourf(x,y,F.sigmoid(rank.cpu()), np.array([-0,0.5,1]), cmap='Greys') #### BCE
#             plt.colorbar()

#             activation_pattern = net.activation_pattern(grid.view(-1,2)).view(N,N,2*(d+1)+2)

#             for i in range(2*(d+1)):
#                 plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='blue', linewidths=0.5)
#             for i in range(2*(d+1), 2*(d+1)+2):
#                 plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='red', linewidths=0.5)

#             blue_patch = mpatches.Patch(color='blue', label='1st layer')
#             red_patch = mpatches.Patch(color='red', label='2nd layer')
#             plt.legend(handles=[blue_patch, red_patch])

#             plt.plot(trn_X[trn_Y.squeeze()>0.5][:,0].cpu(), trn_X[trn_Y.squeeze()>0.5][:,1].cpu(), '*g')
#             plt.xlabel('x')
#             plt.ylabel('y')
#             if loss == 'MSE':
#                 plt.title("Network:  "+r"$2\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.d1}"
#                       +r"$\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.d2}"
#                       +r"$\genfrac{}{}{0}{}{\mathtt{MAX}}{\longrightarrow}$"+f"{self.output_class}, "+title)
#             elif loss=="BCE":
#                 plt.title("Network:  "+r"$2\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.d1}"
#                       +r"$\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.d2}"
#                       +r"$\genfrac{}{}{0}{}{\mathtt{SIG}}{\longrightarrow}$"+f"{self.output_class}, "+title)
#             plt.xlim(-w,w)
#             plt.ylim(-w,w)
#             plt.savefig(save_location)
#             plt.clf()

In [None]:
# class three_layer_MAX_six_blocks(nn.Module):
#     def __init__(self, d1, d2, output_class=1, bias=True):
#         super(three_layer_MAX_six_blocks, self).__init__()
#         self.fc11 = nn.Linear(d, d+1, bias=bias)
#         self.fc12 = nn.Linear(d+1, 1, bias=bias)
#         self.fc21 = nn.Linear(d, d+1, bias=bias)
#         self.fc22 = nn.Linear(d+1, 1, bias=bias)
#         self.fc31 = nn.Linear(d, d+1, bias=bias)
#         self.fc32 = nn.Linear(d+1, 1, bias=bias)
#         self.fc41 = nn.Linear(d, d+1, bias=bias)
#         self.fc42 = nn.Linear(d+1, 1, bias=bias)
#         self.fc51 = nn.Linear(d, d+1, bias=bias)
#         self.fc52 = nn.Linear(d+1, 1, bias=bias)
#         self.fc61 = nn.Linear(d, d+1, bias=bias)
#         self.fc62 = nn.Linear(d+1, 1, bias=bias)
#         self.d1 = d1
#         self.d2 = d2
#         self.output_class = output_class
        
#     def forward(self, x):
#         self.g11 = self.fc11(x)
#         self.h11 = F.relu(self.g11)
#         self.g12 = self.fc12(self.h11)
#         self.h12 = F.relu(self.g12)
        
#         self.g21 = self.fc21(x)
#         self.h21 = F.relu(self.g21)
#         self.g22 = self.fc22(self.h21)
#         self.h22 = F.relu(self.g22)
        
#         self.g31 = self.fc31(x)
#         self.h31 = F.relu(self.g31)
#         self.g32 = self.fc22(self.h31)
#         self.h32 = F.relu(self.g32)
        
#         self.g41 = self.fc41(x)
#         self.h41 = F.relu(self.g41)
#         self.g42 = self.fc22(self.h41)
#         self.h42 = F.relu(self.g42)
        
#         self.g51 = self.fc51(x)
#         self.h51 = F.relu(self.g51)
#         self.g52 = self.fc22(self.h51)
#         self.h52 = F.relu(self.g52)
        
#         self.g61 = self.fc61(x)
#         self.h61 = F.relu(self.g61)
#         self.g62 = self.fc22(self.h61)
#         self.h62 = F.relu(self.g62)
        
#         self.g3 = torch.max( torch.cat( (self.h12, self.h22, self.h32, self.h42,self.h52, self.h62), dim=1 )
#                             , dim=1)[0].view(-1,1)
#         return self.g3
    
#     def activation_pattern(self, x):
#         self.forward(x) # set h1 and h2
#         # binary tensor, shape width+width+output_class
#         return torch.cat((self.h11>0, self.h21>0, self.h31>0, self.h41>0, self.h51>0, self.h61>0,
#                           self.h12>0, self.h22>0, self.h32>0, self.h42>0, self.h52>0, self.h62>0), dim=1) 
    
#     def partition(self, X, title, save_location, rank_measure='output rank'):
#         w = 20
#         N = 400
#         x,y = torch.meshgrid(torch.linspace(-w,w,N), torch.linspace(-w,w,N))
#         grid = torch.stack((x,y),dim=2).to(device).float() # SHAPE 10,10,2
#         if rank_measure == 'decision boundary':
#             rank = net(grid.view(-1,2)).view(N,N).detach()                     # decision boundary, output>0
#             plt.contourf(x,y,rank.cpu()/10, np.arange(-1,3,1), cmap='Greys')
#         plt.colorbar()
        
#         activation_pattern = net.activation_pattern(grid.view(-1,2)).view(N,N,6*(d+2))
        
#         for i in range(6*(d+1)):
#             plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='blue', linewidths=0.5)
#         for i in range(6*(d+1), 6*(d+1)+6):
#             plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='red', linewidths=0.5)

#         blue_patch = mpatches.Patch(color='blue', label='1st layer')
#         red_patch = mpatches.Patch(color='red', label='2nd layer')
#         plt.legend(handles=[blue_patch, red_patch])
        
#         plt.plot(trn_X[trn_Y.squeeze()>0.5][:,0].cpu(), trn_X[trn_Y.squeeze()>0.5][:,1].cpu(), '*g')
#         plt.xlabel('x')
#         plt.ylabel('y')
#         plt.title("Network:  "+r"$2\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.d1}"
#                   +r"$\genfrac{}{}{0}{}{\sigma}{\longrightarrow}$"+f"{self.d2}"
#                   +r"$\genfrac{}{}{0}{}{\mathtt{MAX}}{\longrightarrow}$"+f"{self.output_class}, "+title)
#         plt.xlim(-w,w)
#         plt.ylim(-w,w)
#         plt.savefig(save_location)
#         plt.clf()

        
        
        
# class four_layer_conv_net(nn.Module):
#     def __init__(self, d1, d2, d3, output_class=1, bias=True, last_ReLU=False):
#         super(four_layer_conv_net, self).__init__()
#         self.conv1 = nn.Conv1d(1, d1, kernel_size=2, bias=bias)
#         self.fc2 = nn.Linear(d1, d2, bias=bias)
#         self.fc3 = nn.Linear(d2, d3, bias=bias)
#         self.fc4 = nn.Linear(d3, output_class, bias=bias)
#         self.d1 = d1 # width
#         self.d2 = d2 # width
#         self.d3 = d3 # width
#         self.last_ReLU = last_ReLU
#         self.output_class = output_class
        
        
#     def forward(self, x):
#         self.g1 = self.conv1(x.view(-1,1,2)) # one channel
#         self.h1 = F.relu(self.g1).view(-1,self.d1)
#         self.g2 = self.fc2(self.h1)
#         self.h2 = F.relu(self.g2)
#         self.g3 = self.fc3(self.h2)
#         self.h3 = F.relu(self.g3)
#         self.g4 = self.fc4(self.h3)
#         self.h4 = F.relu(self.g4)
#         if self.last_ReLU:
#             return self.h4
#         else:
#             return self.g4
    
#     def activation_pattern(self, x):
#         self.forward(x) # set h1 and h2
#         if self.last_ReLU :
#             return torch.cat((self.h1>0, self.h2>0, self.h3>0, self.h4>0), dim=1) # binary tensor, shape width+width+output_class
#         else:
#             return torch.cat((self.h1>0, self.h2>0, self.h3>0), dim=1) # binary tensor, shape width+width+output_class
    
    
#     def output_rank(self, x, e=0.001): # compute the output rank at x
#         diff = self.forward(x+e) - self.forward(x)
#         output_rank = (diff != 0).sum(-1)
#         return output_rank
    
#     def function_rank(self, x): # compute the function rank at x
#         self.forward(x)
#         if self.last_ReLU:
#             function_rank = torch.min(torch.stack(((self.h1!=0).sum(1), (self.h2!=0).sum(1), (self.h3!=0).sum(1), (self.h4!=0).sum(1)), dim=0), dim=0)[0]
#         else:
#             function_rank = torch.min(torch.stack(((self.h1!=0).sum(1), (self.h2!=0).sum(1), (self.h3!=0).sum(1)), dim=0), dim=0)[0]
#         return function_rank
    
#     def partition(self, X, title, save_location, show_data=True, rank_measure='output rank', show_vector=False):
#         w = 20
#         N = 400
#         x,y = torch.meshgrid(torch.linspace(-w,w,N), torch.linspace(-w,w,N))
#         grid = torch.stack((x,y),dim=2).to(device).float() # SHAPE 10,10,2
#         if rank_measure == 'decision boundary':
#             rank = net(grid.view(-1,2)).view(N,N).detach()                     # decision boundary, output>0
#             plt.contourf(x,y,rank.cpu(), np.arange(-5,15,5), cmap='Greys')
# #             plt.contourf(x,y,rank.cpu(), np.arange(-1,2,0.1), cmap='RdBu_r')
#         elif rank_measure == 'output rank':
#             rank = net.output_rank(grid.view(-1,2)).view(N,N)                     # Output rank
#             plt.contourf(x,y,rank.cpu(), np.arange(0,self.output_class+1,1), cmap='gray')
#         elif rank_measure == 'function rank':
#             rank = net.function_rank(grid.view(-1,2)).view(N,N)                 # Function rank
#             plt.contourf(x,y,rank.cpu(), np.arange(0,self.m+1,1), cmap='gray')
#         else:
#             raise RankMeasureError
#         plt.colorbar()
        
#         if self.last_ReLU :
#             activation_pattern = net.activation_pattern(grid.view(-1,2)).view(N,N,self.d1+self.d2+self.d3+output_class)
#         else:
#             activation_pattern = net.activation_pattern(grid.view(-1,2)).view(N,N,self.d1+self.d2+self.d3)
        
#         for i in range(self.d1):
#             plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='blue', linewidths=0.5)
#         for i in range(self.d1,self.d1+self.d2):
#             plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='red', linewidths=0.5)
#         for i in range(self.d1+self.d2,self.d1+self.d2+self.d3):
#             plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='yellow', linewidths=0.5)
#         if self.last_ReLU:
#             for i in range(self.d1+self.d2+self.d3,self.d1+self.d2+self.d3+self.output_class):
#                 plt.contour(x,y,activation_pattern.float().cpu()[:,:,i], levels=1, colors='purple', linewidths=0.5)

#         black_patch = mpatches.Patch(color='black', label='Rank 0')
#         gray_patch = mpatches.Patch(color='gray', label='Rank 1')
#         blue_patch = mpatches.Patch(color='blue', label='1st layer')
#         red_patch = mpatches.Patch(color='red', label='2nd layer')
#         yellow_patch = mpatches.Patch(color='yellow', label='3rd layer')
#         purple_patch = mpatches.Patch(color='purple', label='4th layer')
#         plt.legend(handles=[blue_patch, red_patch, yellow_patch])
        
#         if show_data :
#             # plot dataset
#             plt.plot(X[:,0].cpu(), X[:,1].cpu(), '*g')
#         plt.xlabel('x')
#         plt.ylabel('y')
#         plt.title(f"ConvNet:2->{self.d1}->{self.d2}->{self.d3}->{self.output_class}, "+title)
#         plt.xlim(-w,w)
#         plt.ylim(-w,w)
#         plt.savefig(save_location)
# #         plt.show()
#         plt.clf()

Polytope-basis cover

In [2]:
# Network setting: d -> width -> 1 (a single-polytope network)
class polytope(nn.Module):
    """Two-layer ReLU network whose hidden hyperplanes bound a polytope.

    ``fc0`` maps the input features to ``width`` ReLU units. ``fc1`` (no bias)
    combines them, and a constant scalar offset ``self.bias`` is added to the
    result. The offset is -1 when ``positive_init`` is True and +1 otherwise.

    Args:
        width: number of hidden ReLU units.
        output_class: passed to ``fc1`` as ``out_features``. NOTE: the
            initial ``fc1`` weight is overwritten with a single row below, so
            the forward pass yields one output column regardless of this value.
        positive_init: if True, every second-layer weight starts positive
            (and the offset is -1); if False, every weight starts negative
            (and the offset is +1).
        input_dim: number of input features. When None (the default), the
            module-level global ``d`` is used, as before.
    """

    def __init__(self, width, output_class=1, positive_init=True, input_dim=None):
        super(polytope, self).__init__()
        in_features = d if input_dim is None else input_dim
        self.fc0 = nn.Linear(in_features, width)
        self.fc1 = nn.Linear(width, output_class, bias=False)
        self.width = width  # number of hidden units
        # Scalar offset added after fc1: -1 if positive_init else +1.
        self.bias = 1 - 2 * positive_init

        # v_k = ||w_k||^2 + b_k^2 + 1 >= 1 for every hidden unit k.
        magnitude = (self.W(0).norm(dim=1) ** 2 + self.b(0) ** 2).view(1, -1) + 1
        if positive_init:
            # initialization so that all v_k are positive.
            self.fc1.weight = nn.Parameter(magnitude)
        else:
            # initialization so that all v_k are negative.
            self.fc1.weight = nn.Parameter(-magnitude)

    def forward(self, x):
        """Return ``fc1(relu(fc0(x))) + self.bias``; caches g1/h1/g2 on self."""
        self.g1 = self.fc0(x)
        self.h1 = F.relu(self.g1)
        self.g2 = self.fc1(self.h1) + self.bias
        return self.g2

    def _layer(self, i):
        """Return the nn.Linear for layer index ``i`` (0 -> fc0, 1 -> fc1)."""
        if i == 0:
            return self.fc0
        if i == 1:
            return self.fc1
        raise ValueError(f"layer index must be 0 or 1, got {i!r}")

    def W(self, i):
        """Return the weight Parameter of layer ``i`` (0 or 1)."""
        return self._layer(i).weight

    def b(self, i):
        """Return the bias of layer ``i``; this is None for layer 1 (no bias)."""
        return self._layer(i).bias

    def activation_pattern(self, x):
        """Return a float 0/1 tensor marking which hidden units are active on x."""
        self.forward(x)
        return (self.h1 > 0).float()

    def change_layer_weights(self, i, W, b):
        """Replace layer ``i``'s weight (and bias) with new Parameters.

        Args:
            i: layer index, 0 or 1.
            W: new weight; must match the current weight's shape.
            b: new bias; must match the current bias's shape. Pass None for a
                layer that has no bias (``fc1``).

        Raises:
            ValueError: if ``i`` is not 0 or 1, or a shape does not match.
        """
        layer = self._layer(i)
        current_b = layer.bias
        if current_b is None:
            bias_ok = b is None
        else:
            bias_ok = b is not None and b.shape == current_b.shape
        if W.shape != layer.weight.shape or not bias_ok:
            raise ValueError("wrong shape of input tensors")
        layer.weight = nn.Parameter(W)
        if b is not None:
            layer.bias = nn.Parameter(b)

    def partition(self, w=16):
        """Plot the decision region and hidden-unit boundaries on [-w, w]^2.

        Fills the region where the output is negative, and draws each hidden
        unit's on/off boundary in blue. It then scatters the first ``n/2`` rows
        of ``trn_X`` and the remaining rows as two groups. Relies on the
        module-level globals ``device``, ``trn_X`` and ``n``; assumes 2-D inputs.
        """
        N = 200
        x, y = torch.meshgrid(torch.linspace(-w, w, N), torch.linspace(-w, w, N))
        grid = torch.stack((x, y), dim=2).to(device).float().view(-1, 2)
        with torch.no_grad():
            rank = (self.forward(grid) < 0).float().view(N, N)  # decision boundary, output<0
            # Compute the activation pattern once for all units, not once per unit.
            pattern = self.activation_pattern(grid).cpu().view(N, N, -1)
        plt.contourf(x, y, rank.cpu(), np.arange(-1, 2, 1), cmap='gray')
        plt.colorbar()
        for i in range(self.width):
            plt.contour(x, y, pattern[:, :, i], levels=1, colors='blue', linewidths=0.5)

        blue_patch = mpatches.Patch(color='blue', label='1st layer')
        plt.legend(handles=[blue_patch])
        # dataset: first half and second half of trn_X as two groups
        half = int(n / 2)
        plt.scatter(trn_X[:half, 0].cpu(), trn_X[:half, 1].cpu(), marker='.')
        plt.scatter(trn_X[half:, 0].cpu(), trn_X[half:, 1].cpu(), marker='.')
        plt.title("Activation and decision boundary")
        plt.show()

In [3]:
# multiple polytopes : polytope basis-cover
class cover(nn.Module):
    """Combine two ``polytope`` networks into one cover network.

    The output is the elementwise minimum (``positive_init=True``) or maximum
    (``positive_init=False``) of the two polytope outputs. With min, the output
    is negative wherever either polytope's output is negative. With max, it is
    negative only where both are.

    Args:
        width: hidden width of each polytope.
        output_class: accepted for interface compatibility; not used here.
        positive_init: forwarded to both polytopes and also selects min vs max
            in ``forward``.
    """

    def __init__(self, width, output_class=1, positive_init=True):
        super(cover, self).__init__()
        self.width = width
        # Stored so forward() uses this instance's mode, not a global variable.
        self.positive_init = positive_init
        self.polytope_A1 = polytope(width=width, positive_init=positive_init)
        self.polytope_B1 = polytope(width=width, positive_init=positive_init)

    def forward(self, x):
        """Return min (positive_init) or max of the two polytope outputs."""
        self.a1 = self.polytope_A1(x)
        self.b1 = self.polytope_B1(x)
        if self.positive_init:
            return torch.min(self.a1, self.b1)
        return torch.max(self.a1, self.b1)

    def activation_pattern(self, x):
        """Return both polytopes' hidden activation patterns, one row per input.

        Polytope A's units come first, followed by polytope B's (concatenated
        along the feature axis).
        """
        return torch.cat((self.polytope_A1.activation_pattern(x),
                          self.polytope_B1.activation_pattern(x)), dim=1)

    def all_deactivated(self, x):
        """Count rows of ``x`` where at least one polytope has no active hidden unit."""
        dead_a = self.polytope_A1.activation_pattern(x).sum(dim=1) == 0
        dead_b = self.polytope_B1.activation_pattern(x).sum(dim=1) == 0
        return (dead_a | dead_b).sum().item()

    def partition(self, w=2):
        """Plot the cover's decision region and both polytopes' unit boundaries.

        On [-w, w]^2 it fills the region where the output is negative. It draws
        polytope A's hidden-unit boundaries in blue and polytope B's in red, then
        scatters the first ``n/2`` rows of ``trn_X`` and the rest as two groups.
        Relies on the module-level globals ``device``, ``trn_X`` and ``n``;
        assumes 2-D inputs.
        """
        N = 200
        x, y = torch.meshgrid(torch.linspace(-w, w, N), torch.linspace(-w, w, N))
        grid = torch.stack((x, y), dim=2).to(device).float().view(-1, 2)
        with torch.no_grad():
            rank = (self(grid) < 0).float().view(N, N)  # decision boundary, output<0
            # Compute each polytope's pattern once instead of once per unit.
            pattern_a = self.polytope_A1.activation_pattern(grid).cpu().view(N, N, -1)
            pattern_b = self.polytope_B1.activation_pattern(grid).cpu().view(N, N, -1)
        plt.contourf(x, y, rank.cpu(), np.arange(-1, 2, 1), cmap='gray')
        plt.colorbar()

        for i in range(self.width):
            plt.contour(x, y, pattern_a[:, :, i], levels=1, colors='blue', linewidths=0.5)
            plt.contour(x, y, pattern_b[:, :, i], levels=1, colors='red', linewidths=0.5)

        blue_patch = mpatches.Patch(color='blue', label='polytope A, 1st layer')
        red_patch = mpatches.Patch(color='red', label='polytope B, 1st layer')
        plt.legend(handles=[blue_patch, red_patch])
        # dataset: first half and second half of trn_X as two groups
        half = int(n / 2)
        plt.scatter(trn_X[:half, 0].cpu(), trn_X[:half, 1].cpu(), marker='.')
        plt.scatter(trn_X[half:, 0].cpu(), trn_X[half:, 1].cpu(), marker='.')
        plt.title("Activation and decision boundary")
        plt.show()

In [None]:
# # multiple polytopes : polytope basis-cover
# class cover5(nn.Module):
#     def __init__(self, width, num_polytopes, output_class=1, positive_init=True):
#         super(cover5, self).__init__()
#         self.width=width
#         self.num_polytopes=num_polytopes
#         self.polytope_A1 = polytope(width=width, positive_init=positive_init)
#         self.polytope_A2 = polytope(width=width, positive_init=positive_init)
#         self.polytope_A3 = polytope(width=width, positive_init=positive_init)
#         self.polytope_A4 = polytope(width=width, positive_init=positive_init)
#         self.polytope_A5 = polytope(width=width, positive_init=positive_init)
        
#     def forward(self, x):
#         self.a1 = self.polytope_A1(x)
#         self.a2 = self.polytope_A1(x)
#         self.a3 = self.polytope_A1(x)
#         self.a4 = self.polytope_A1(x)
#         self.a5 = self.polytope_A1(x)
        
#         output = torch.cat((self.a1, self.a2, self.a3, self.a4, self.a5), dim=1)

# #         output = (2*(F.relu(self.a1) > F.relu(self.b1)).float()-1) * torch.max(self.a1, self.b1)
#         if positive_init:
#             output = torch.min(output, dim=1)[0]
#         else:
#             output = torch.max(output, dim=1)[0]
#         return output
    
#     def activation_pattern(self, x):
# #         self.forward(x)
#         output = torch.cat((self.polytope_A1.activation_pattern(x), 
#                             self.polytope_A2.activation_pattern(x),
#                             self.polytope_A3.activation_pattern(x),
#                             self.polytope_A4.activation_pattern(x),
#                             self.polytope_A5.activation_pattern(x)
#                            ), dim=1)
#         return output
    
#     def all_deactivated(self, x):
#         output = (self.polytope_A1.activation_pattern(trn_X).sum(dim=1)==0).float()
#         output += (self.polytope_A2.activation_pattern(trn_X).sum(dim=1)==0).float()
#         output += (self.polytope_A3.activation_pattern(trn_X).sum(dim=1)==0).float()
#         output += (self.polytope_A4.activation_pattern(trn_X).sum(dim=1)==0).float()
#         output += (self.polytope_A5.activation_pattern(trn_X).sum(dim=1)==0).float()
#         return (output>0).sum().item()
    
#     def partition(self, w=16):
#         N = 200
#         x,y = torch.meshgrid(torch.linspace(-w,w,N), torch.linspace(-w,w,N))
#         grid = torch.stack((x,y),dim=2).to(device).float() # SHAPE 10,10,2
#         rank = (net(grid.view(-1,2))<0).float().view(N,N)                     # decision boundary, output<0
#         plt.contourf(x,y,rank.cpu(), np.arange(-1,2,1), cmap='gray')
#         #     plt.contourf(x,y,rank.cpu(), np.arange(-1,2,0.1), cmap='RdBu_r')
#         plt.colorbar()

#         for i in range(net.width):
#             plt.contour(x,y, net.polytope_A1.activation_pattern(grid.view(-1,2)).float().cpu().view(N,N,-1)[:,:,i], levels=1, colors='blue', linewidths=0.5)
#             plt.contour(x,y, net.polytope_B1.activation_pattern(grid.view(-1,2)).float().cpu().view(N,N,-1)[:,:,i], levels=1, colors='red', linewidths=0.5)

#         black_patch = mpatches.Patch(color='black', label='Rank 0')
#         gray_patch = mpatches.Patch(color='gray', label='Rank 1')
#         blue_patch = mpatches.Patch(color='blue', label='1st layer')
#         plt.legend(handles=[blue_patch])
#         # dataset
#         plt.scatter(trn_X[:int(n/2),0].cpu(),trn_X[:int(n/2),1].cpu(), marker='.')
#         plt.scatter(trn_X[int(n/2):,0].cpu(),trn_X[int(n/2):,1].cpu(), marker='.')
#         plt.title("Activation and decision boundary")
#         plt.show()