In [1]:
# Classification, Intrinsic dimension, Degeneracy
# Notebook cell 1: run the helper notebooks, then define the global
# configuration that the later cells and the `polytope` class read.
# NOTE(review): SetUp.ipynb presumably provides torch / nn / F / optim / np /
# plt / SummaryWriter / create_directory, and synthetic_datasets.ipynb provides
# load_synthetic_dataset -- neither is visible here, confirm.
%run SetUp.ipynb
%run synthetic_datasets.ipynb

## Configuration
Title = "Rebuttal"  # name of the results sub-directory under ./Results
device = "cuda:1"  # hard-coded GPU index; change on machines without a second GPU
pt_option=True  # when True, progress messages are printed
width = 8  # hidden width; overwritten by the interactive prompt in the cover-search cell
dataset = "Swiss" # Swiss / Two circles / Two moons / XOR
d = 2  # input dimension (in_features of the first linear layer)
n = 1000 # used number of data
num_classes = 2
selected_class = 0  # not referenced in the visible cells -- TODO confirm it is still needed
data_balance = 1  # initial weight of the second class's loss term
trn_X, trn_Y = load_synthetic_dataset(dataset)

# create directory
output_path = f"./Results/{Title}"
create_directory(output_path)

Network Architecture

In [2]:
# Network setting # d -> d1 -> 1
class polytope(nn.Module):
    """Two-layer ReLU network ``d -> width -> output_class`` with a fixed output offset.

    The forward pass is ``fc1(relu(fc0(x))) + bias`` where ``bias`` is the
    constant ``-3`` (``positive_init=True``) or ``+3`` (``positive_init=False``).
    A point whose hidden units are all inactive yields exactly ``bias``; the
    code treats such points as lying "inside the polytope".

    The class is written for notebook use and reads these module-level names
    at run time: ``d``, ``device``, ``data_balance``, ``lr``, ``Epochs``,
    ``criterion``, ``trn_X``, ``trn_Y``, ``n``, ``writer``, ``folder_name``,
    ``folder_name_above``, ``pt_option``, ``dataset`` and ``num_classes``.
    """

    def __init__(self, width, output_class=1, positive_init=False, small_norm_init=False):
        """Build the layers, initialise the output weights and create the optimizer.

        Args:
            width: number of hidden ReLU units.
            output_class: number of output units (the visible code only uses 1).
            positive_init: if True all output weights start positive and the
                offset is -3; otherwise they start negative and the offset is +3.
            small_norm_init: if True, re-draw the first-layer weights with a
                small scale and shrink the first-layer bias.
        """
        super().__init__()
        self.fc0 = nn.Linear(d, width)
        self.fc1 = nn.Linear(width, output_class, bias=False)
        self.width = width  # current width (shrinks when pruning)
        self.initial_width = width  # width at construction, used in file names/reports
        self.epoch = 0
        self.data_balance = data_balance
        self.bias = (1 - 2 * positive_init) * 3
        self.positive_init = positive_init
        self.output_class = output_class
        self.small_norm_init = small_norm_init
        self.device = device

        # initialization
        if small_norm_init:
            self.change_layer_weights(
                layer=0,
                W=1.5 / self.width * torch.randn_like(self.W(0)),
                b=0.6 / self.width * self.b(0).detach(),
            )

        # |v_k| = sqrt(||w_k||^2 + b_k^2 + 1); the sign follows positive_init.
        with torch.no_grad():
            magnitude = torch.sqrt(
                (self.fc0.weight.norm(dim=1) ** 2 + self.fc0.bias ** 2).view(1, -1) + 1
            )
        self.fc1.weight = nn.Parameter(magnitude if positive_init else -magnitude)

        # declare the optimizer
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, x):
        """Return ``fc1(relu(fc0(x))) + bias``; intermediates are kept on ``self``."""
        self.g1 = self.fc0(x)
        self.h1 = F.relu(self.g1)
        self.g2 = self.fc1(self.h1) + self.bias
        return self.g2

    def _layer(self, layer):
        """Return the linear module for ``layer`` (0 or 1); raise ValueError otherwise."""
        if layer == 0:
            return self.fc0
        if layer == 1:
            return self.fc1
        raise ValueError(f"unknown layer index: {layer!r}")

    def W(self, layer):
        """Return a clone of the weight matrix of ``layer`` (0 or 1)."""
        return self._layer(layer).weight.clone()

    def b(self, layer):
        """Return a clone of the bias of ``layer``, or None if the layer has no bias."""
        bias = self._layer(layer).bias
        return None if bias is None else bias.clone()

    def activation_pattern(self, x):
        """Return a float 0/1 tensor marking which hidden units are active for each row of ``x``."""
        # The pattern is not differentiable, so no autograd graph is needed.
        with torch.no_grad():
            self.forward(x)
        return (self.h1 > 0).float()

    def change_layer_weights(self, layer, W, b):
        """Replace the parameters of ``layer`` with ``W`` (and ``b`` if the layer has a bias).

        Raises:
            ValueError: if ``layer`` is unknown, if a shape does not match the
                current parameters, or if ``b`` is given/omitted inconsistently
                with the layer having a bias.
        """
        module = self._layer(layer)
        has_bias = module.bias is not None
        if W.shape != module.weight.shape:
            raise ValueError("wrong shape of input tensors")
        if has_bias:
            if b is None or b.shape != module.bias.shape:
                raise ValueError("wrong shape of input tensors")
        elif b is not None:
            raise ValueError("wrong shape of input tensors")

        module.weight = nn.Parameter(W.detach().to(self.device))
        if has_bias:
            module.bias = nn.Parameter(b.detach().to(self.device))

    def _replace_layers(self, W0, b0, W1):
        """Rebuild both linear layers for the width implied by ``W0`` and load the given weights."""
        new_width = W0.shape[0]
        self.fc0 = nn.Linear(W0.shape[1], new_width)
        self.fc1 = nn.Linear(new_width, self.output_class, bias=False)
        self.to(self.device)
        self.width = new_width
        self.change_layer_weights(layer=0, W=W0, b=b0)
        self.change_layer_weights(layer=1, W=W1, b=None)

    def _polytope_counts(self, trn_X_index):
        """Count selected training points inside/outside the polytope, per class.

        "Class 1" here means label ``float(self.positive_init)`` and "class 0"
        every other label, matching the original bookkeeping.

        Returns:
            (class0_inPoly, class1_inPoly, class0_outPoly, class1_outPoly)
        """
        with torch.no_grad():
            inside = (self.forward(trn_X[trn_X_index]) == self.bias).view(-1)
        is_class1 = (trn_Y[trn_X_index] == float(self.positive_init)).view(-1)
        class0_in = (inside & ~is_class1).sum().item()
        class1_in = (inside & is_class1).sum().item()
        class0_out = (~inside & ~is_class1).sum().item()
        class1_out = (~inside & is_class1).sum().item()
        return class0_in, class1_in, class0_out, class1_out

    def partition(self, save_location, trn_X_index, w=1.6, title="title"):
        """Plot the decision region, the polytope, the ReLU boundaries and the data.

        Args:
            save_location: path the figure is written to.
            trn_X_index: boolean mask over ``trn_X`` selecting the points to draw.
            w: half-width of the square plotting window ``[-w, w]^2``.
            title: figure title.
        """
        N = 200
        axis = torch.linspace(-w, w, N)
        x, y = torch.meshgrid(axis, axis, indexing="ij")
        grid = torch.stack((x, y), dim=2).to(self.device).float()  # shape (N, N, 2)

        # One forward pass serves the decision region, the polytope and every
        # per-neuron activation boundary.
        with torch.no_grad():
            out = self.forward(grid.view(-1, 2))
            pattern = (self.h1 > 0).float().cpu().view(N, N, self.width)
        negative_region = (out < 0.001).float().view(N, N).cpu()  # decision boundary, output<0
        inside_polytope = (out == self.bias).float().cpu().view(N, N)

        plt.contourf(x, y, negative_region, np.arange(-1, 2, 1), cmap='gray')
        plt.contourf(x, y, inside_polytope, np.arange(-1, 2, 1), cmap='Greys')
        plt.colorbar()
        for i in range(self.width):
            plt.contour(x, y, pattern[:, :, i], levels=1, colors='blue', linewidths=0.5)

        # dataset: the first n/2 rows are drawn as C0 and the rest as C1
        half = int(n / 2)
        for rows, color in ((slice(None, half), "C0"), (slice(half, None), "C1")):
            points = trn_X[rows][trn_X_index[rows]]
            plt.scatter(points[:, 0].cpu(), points[:, 1].cpu(), marker='.', color=color)
        plt.title(title)
        plt.savefig(save_location)
        plt.close()

    # Training
    def train(self, repetition=True, trn_X_index=None):
        """Train on the selected points, logging/pruning every 1000 epochs.

        Args:
            repetition: zero-based index of the current cover (used in file
                names and messages).
            trn_X_index: boolean mask over ``trn_X`` selecting the training points.

        This method shadows ``nn.Module.train(mode)``, which ``nn.Module.eval()``
        calls as ``self.train(False)``. When ``trn_X_index`` is omitted the call
        is therefore forwarded to ``nn.Module.train`` so ``eval()`` keeps working.
        """
        if trn_X_index is None:
            return super().train(bool(repetition))

        # Remembered so the save_* helpers report on this very run.
        self._repetition = repetition
        self._trn_X_index = trn_X_index

        self.fail = False
        self.trn_loss_tr = np.empty(0)
        self.test_loss_tr = np.empty(0)
        self.trn_acc_tr = np.empty(0)
        self.test_acc_tr = np.empty(0)

        # No test set is used; these placeholders keep the log format stable.
        test_acc = 0
        test_loss = torch.zeros(1)
        if pt_option:
            print(f"=================================   Start Training. Repetition: {repetition+1}   ==================================")
        time.sleep(1)

        log_period = 1000
        prune_after_epoch = 5555  # pruning starts only after this pre-training phase
        earliest_stop_epoch = 39990

        # The training subset is fixed for the whole run, so split it once.
        inputs = trn_X[trn_X_index]
        labels = trn_Y[trn_X_index]
        own_mask = (labels == float(self.positive_init)).view(-1)
        other_mask = (labels == float(not self.positive_init)).view(-1)
        own_X, own_Y = inputs[own_mask], labels[own_mask]
        other_X, other_Y = inputs[other_mask], labels[other_mask]

        for epoch in range(Epochs):
            self.epoch = epoch
            # Per-class losses because the classes may be imbalanced.
            # The original divided each term by the number of selected points
            # and multiplied the sum by the same number; that cancels to this.
            # NOTE(review): confirm a per-class count was not intended instead.
            loss = criterion(self(own_X), own_Y) + self.data_balance * criterion(self(other_X), other_Y)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            if epoch == 0 or epoch % log_period == log_period - 1:
                self.trn_loss_tr = np.append(self.trn_loss_tr, loss.item())
                self.test_loss_tr = np.append(self.test_loss_tr, test_loss.item())

                with torch.no_grad():
                    trn_acc = ((self(trn_X) > -0.001).float() == trn_Y).sum().item() / len(trn_Y) * 100
                self.trn_acc_tr = np.append(self.trn_acc_tr, trn_acc)
                self.test_acc_tr = np.append(self.test_acc_tr, test_acc)
                class0_inPoly, class1_inPoly, class0_outPoly, class1_outPoly = self._polytope_counts(trn_X_index)
                # From the first log onwards the second loss term is weighted by 100.
                self.data_balance = 100
                writer.add_scalars("ReLU/Loss", {
                    "Trn_Loss": loss.item(),
                    "Test_Loss": test_loss.item(),
                }, epoch + 1)
                writer.add_scalars("ReLU/Accuracy", {
                    "Trn_acc": trn_acc,
                    "Test_acc": test_acc,
                }, epoch + 1)
                writer.add_scalars("ReLU/InOutClass", {
                    "class0_inPoly": class0_inPoly,
                    "class0_outPoly": class0_outPoly,
                    "class1_inPoly": class1_inPoly,
                    "class1_outPoly": class1_outPoly,
                }, epoch + 1)
                if pt_option:
                    print(f"Epoch: {epoch+1 :>5} || TRN_loss: {loss.item() :.4f} || TRN_ACC: {trn_acc:.3f}", end="")
                    print(f" || balanced_min: {self.fc1.weight.abs().min():.3f} || class0_inPoly : {class0_inPoly}", end="")
                    print(f" || class0_outPoly : {class0_outPoly} || class1_inPoly : {class1_inPoly} || class1_outPoly : {class1_outPoly}")
                torch.save(self.state_dict(), folder_name + f"/Rep{repetition+1}_saved_net_width_{self.initial_width}.pt")
                # Pruning and merging
                if epoch > prune_after_epoch:  # pruning after pre-training
                    self.pruning(trn_X[trn_X_index])
                self.partition(trn_X_index=trn_X_index, title=f"AB and DB, width={self.width}, epoch={self.epoch+1}", save_location=folder_name + f"/Rep{repetition+1}_{epoch+1}.png")

                if self.fc1.weight.sign().sum().abs() != self.width:  # there is a flipped sign.
                    if pt_option:
                        print("Sign flipped !!")
                    self.fail = True  # skip
                    break

                if self.width == 0:  # there is no active neuron.
                    if pt_option:
                        print("All neruons have been removed !!")
                    self.fail = True  # skip
                    break

                # terminating condition - if one class is completely surrounded by a convex polytope.
                # (the counts were taken before this epoch's pruning step)
                if class0_outPoly == 0 and epoch > earliest_stop_epoch:
                    if pt_option:
                        print("Found a complete convex polytope cover, escape the training loop")
                    self.fail = False
                    break

        self.save_loss_graph(repetition)
        self.save_result_txt(repetition, trn_X_index)
        self.partition(trn_X_index=trn_X_index, title=f"Polytope cover {repetition+1}, width={self.width}", save_location=folder_name_above + f"/Cover_{repetition+1}.png")
        time.sleep(1)
        if pt_option:
            print(f"===================================   Training finished, Repetition: {repetition+1}  ================================== \n\n")

    def save_loss_graph(self, repetition=None):
        """Save the loss and accuracy curves of the last run into ``folder_name``.

        Args:
            repetition: zero-based cover index for the titles; defaults to the
                value passed to the most recent ``train`` call.
        """
        if repetition is None:
            repetition = self._repetition

        # PLOT THE LOSS GRAPH
        plt.plot(self.trn_loss_tr, label="Train")
        plt.plot(self.test_loss_tr, label="Test")
        plt.xlabel("Iterations (k)")
        plt.ylabel("Loss")
        plt.legend()
        plt.title(f"Rep{repetition+1}_Loss")
        plt.savefig(folder_name + "/loss_fig_fullview.png")
        plt.close()

        # PLOT THE ACCURACY GRAPH
        plt.plot(self.trn_acc_tr, label="Train")
        plt.plot(self.test_acc_tr, label="Test")
        plt.xlabel("Iterations (k)")
        plt.ylabel("Accuracy")
        plt.legend()
        plt.title(f"Rep{repetition+1}_Accuracy")
        plt.savefig(folder_name + "/Accuracy.png")
        plt.close()

    # Save result txt file.
    def save_result_txt(self, repetition=None, trn_X_index=None):
        """Write a text summary of the last run to ``folder_name``.

        Args:
            repetition: zero-based cover index; defaults to the value passed to
                the most recent ``train`` call.
            trn_X_index: boolean mask of the points that were trained on;
                defaults to the mask passed to the most recent ``train`` call.
                (The original read a global mask that can differ from the one
                actually used for training.)
        """
        if repetition is None:
            repetition = self._repetition
        if trn_X_index is None:
            trn_X_index = self._trn_X_index
        class0_inPoly, class1_inPoly, class0_outPoly, class1_outPoly = self._polytope_counts(trn_X_index)

        with open(f"./{folder_name}/Rep{repetition+1}_result.txt", 'w') as f:
            f.write(
        f"""
            This is a CFG file.

            # Accuracy
            Trn_acc : {self.trn_acc_tr[-1] :.3f}
            Test_acc : {self.test_acc_tr[-1] :.3f}

            # Loss
            Trn_loss : {self.trn_loss_tr[-1].item() :.4f} 
            Test_loss : {self.test_loss_tr[-1].item() :.4f} 


            # dataset
            dataset = {dataset}
            n = {n} # number of total data
            d = {d} # dimension
            num_class = {num_classes} # the class, used in this binary classification
            starting_width = {self.initial_width}
            final_width = {self.width}
            # Network
            {self}
            
            # Polytope
            number of data used in this training (Class 0 / 1): {class0_inPoly+class0_outPoly} / {class1_inPoly+class1_outPoly}
            Class 0 in Polytope : {class0_inPoly}
            Class 1 in Polytope : {class1_inPoly}
            Class 0 out Polytope : {class0_outPoly}
            Class 1 out Polytope : {class1_outPoly}
            
            # optimization
            Epochs = {Epochs}
            lr = {lr}
            optimizer = {self.optimizer}
            Last epoch = {self.epoch+1}
            
            ## trn_X_index
            {trn_X_index}
        """
            )

    def pruning(self, dataset):
        """Shrink and rescale the hidden layer using activation patterns on ``dataset``.

        Four stages run in order, each on the network left by the previous one:

        1. removing  - if any unit is active on every point or on none, drop
           those units together with units whose output weight has magnitude
           <= 0.1;
        2. merging   - units with identical activation patterns on ``dataset``
           are replaced by one unit whose parameters are the sums;
        3. pruning   - among units whose active points all have at least two
           active units, remove the single one with the smallest ``v * ||w||``;
        4. rescaling - units active on points whose output is positive and not
           equal to the offset get all their parameters multiplied by 1.1.

        The optimizer is re-created at the end because the parameters may have
        been replaced by new tensors.

        Args:
            dataset: tensor of input points, one per row.
        """
        # Stage 1: remove always-active / never-active / small-output-weight units.
        active_counts = self.activation_pattern(dataset).sum(dim=0)
        if active_counts.max() == len(dataset) or active_counts.min() == 0:
            width_before = self.width
            W0 = self.W(layer=0).detach()
            b0 = self.b(layer=0).detach()
            W1 = self.W(layer=1).detach()
            keep = (
                (active_counts < len(dataset))
                & (active_counts > 0)
                & (W1.abs().view(-1) > 0.1)
            )
            self._replace_layers(W0[keep], b0[keep], W1.view(-1)[keep].view(1, -1))
            text = "(removing)"
            if pt_option:
                print(f"\t{text:<15}The network is prunned, width : {width_before} --> {self.width}")

        if self.width == 0:
            # Nothing left to merge, prune or rescale.
            self.optimizer = optim.Adam(self.parameters(), lr=lr)
            return

        ## merging vectors with same activation patterns ##
        pattern = self.activation_pattern(dataset)  # len(dataset) x width
        unique_patterns = pattern.unique(dim=1).t()  # one row per distinct pattern
        if len(unique_patterns) < self.width:
            width_before = self.width
            W0 = self.W(layer=0).detach()
            b0 = self.b(layer=0).detach()
            W1 = self.W(layer=1).detach()
            per_unit = pattern.t()  # width x len(dataset)
            groups = [(per_unit == p).all(dim=1) for p in unique_patterns]
            new_W0 = torch.stack([W0[g].sum(dim=0) for g in groups])
            new_b0 = torch.stack([b0[g].sum(dim=0) for g in groups])
            new_W1 = torch.stack([W1[0][g].sum(dim=0) for g in groups]).view(1, -1)
            self._replace_layers(new_W0, new_b0, new_W1)
            text = "(merging)"
            if pt_option:
                print(f"\t{text:<15}The network is merged, width : {width_before} --> {self.width}")

        ###### original #######
        # Stage 3: remove at most one redundant unit per call.
        pattern = self.activation_pattern(dataset)  # len(dataset) x width
        active_per_point = pattern.sum(dim=1)  # number of active units at each point
        W0 = self.W(layer=0).detach()
        b0 = self.b(layer=0).detach()
        W1 = self.W(layer=1).detach()
        smallest_norm_neuron_index = None
        smallest_norm = np.inf
        for neuron_index in range(self.width):
            activated_index = pattern[:, neuron_index] == 1
            if not activated_index.any():
                # A unit active nowhere gives no evidence of redundancy here.
                continue
            # Redundant if every point that activates this unit also activates
            # at least one other unit.
            if active_per_point[activated_index].min().item() >= 2:
                # NOTE(review): the score is signed (v * ||w||); with negative
                # output weights the "smallest" is the largest in magnitude.
                this_round_norm = (W0[neuron_index].norm() * W1.view(-1)[neuron_index]).item()  ## v||w||
                if this_round_norm < smallest_norm:
                    smallest_norm_neuron_index = neuron_index
                    smallest_norm = this_round_norm
        if smallest_norm_neuron_index is not None:
            keep = torch.ones(self.width, dtype=torch.bool, device=W0.device)
            keep[smallest_norm_neuron_index] = False
            self._replace_layers(W0[keep], b0[keep], W1[:, keep])
            text = "(pruning)"
            if pt_option:
                print(f"\t{text:<15}In the polytope, a redundant neuron has been removed")

        ### rescaling function ## here, we don't need to rescale the neuron.
        rescale = 1.1
        with torch.no_grad():
            out = self.forward(dataset)
        intermediate_index = ((F.relu(out) != 0) & (out != self.bias)).squeeze(-1)
        if intermediate_index.any():
            true_or_false = self.activation_pattern(dataset[intermediate_index]).sum(dim=0) > 0
            if true_or_false.any():
                factor = torch.ones(self.width, device=true_or_false.device)
                factor[true_or_false] = rescale
                self.change_layer_weights(
                    layer=0,
                    W=self.W(layer=0).detach() * factor.view(-1, 1),
                    b=self.b(layer=0).detach() * factor,
                )
                self.change_layer_weights(
                    layer=1,
                    W=self.W(layer=1).detach() * factor.view(1, -1),
                    b=None,
                )
                text = "(rescaling)"
                if pt_option:
                    print(f"\t{text:<15}In the polytope, {true_or_false.sum().item()} neurons are rescaled by {rescale}")

        # The parameters may be new tensors, so the optimizer must be rebuilt
        # over them for training to keep updating the network.
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

In [3]:
# Optimization setting: epoch budget, Adam step size and the training loss.
Epochs = 150 * 1000
lr = 1e-4
criterion = nn.BCEWithLogitsLoss()

# State of the polytope-cover search.
repetition = 0
# Boolean mask over the dataset (shape = len(dataset)); True = still to be covered.
trn_X_index_all = torch.ones(len(trn_X), dtype=torch.bool, device=device)
trn_X_index = trn_X_index_all.clone()
# trn_X_index *= (torch.randn(len(trn_X))>0.9).to(device) # random
network_list = []

# Every run writes into its own time-stamped directory.
timestamp = datetime.datetime.now().strftime("%B%d_%H_%M_%S")
folder_name_above = f"{output_path}/runs/{dataset}/{timestamp}"
create_directory(folder_name_above)

# Save a figure of the initial dataset: the first n/2 rows are drawn as
# class 0 and the remaining rows as class 1.
half = int(n / 2)
for rows, color, label in (
    (slice(None, half), "C0", "class 0"),
    (slice(half, None), "C1", "class 1"),
):
    points = trn_X[rows][trn_X_index[rows]]
    plt.scatter(points[:, 0].cpu(), points[:, 1].cpu(), marker='.', color=color, label=label)
plt.title("Dataset")
plt.legend()
plt.savefig(folder_name_above + "/Dataset.png")
plt.close()

In [4]:
# Interactive search for a polytope-basis cover.
# Each round trains one `polytope` network on the points that are still
# uncovered (plus every point of the class labelled float(positive_init)),
# then keeps only the points that fall inside the new polytope for the next
# round. The loop ends when the remaining points contain a single class.
positive_init = True ### if positive_init=True ==> polytope is negative (-1) ==> approximating the orange class (label 0).
# blue class is labeld by one(+1)
while trn_X_index[(trn_Y==0).view(-1)].sum()>0 and trn_X_index[(trn_Y==1).view(-1)].sum()>0 :
    width = int(input(f"Cover {repetition+1}: What width do you want ?")) ## manually adjust the width
    net = polytope(width=width, positive_init=positive_init, small_norm_init=False).to(device)
    folder_name = folder_name_above + f"/Cover_{repetition+1}"
    writer = SummaryWriter(folder_name)

    # training
    trn_X_index_this_time = trn_X_index.clone()
    #### consider all data points in the other class
    trn_X_index_this_time[(trn_Y==float(positive_init)).view(-1)] = True
    try:
        net.train(repetition=repetition, trn_X_index=trn_X_index_this_time)
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()

    ans = input("Wanna train again ? (y/n)")
    if ans == "y":
        net.fail = True

    if net.fail:
        # there is some problem for finding a cover in this training time.
        # therefore, we skip this try and re-train.
        # The remaining-point mask is left untouched: this network is not
        # stored, so it must not decide which points count as covered.
        positive_init = not positive_init # flip the convexity of the class.
        continue

    # otherwise, it was successful.
    # after training, extract the index of the points inside the polytope.
    with torch.no_grad():
        trn_X_index_in_polytope = ((net(trn_X).view(-1) * trn_X_index) == net.bias) # boolean tensor, shape = len(dataset)
    # new index
    trn_X_index = trn_X_index * trn_X_index_in_polytope
    network_list.append(net)
    positive_init = not positive_init # flip the convexity of the class.
    repetition += 1

# result
if pt_option :
    print("\n\nCompleted to find a polytope-basis cover!!")
    print(f"There are {len(network_list)} polytopes.")
    print(network_list)

Cover 1: What width do you want ? 15


Epoch:     1 || TRN_loss: 1.5500 || TRN_ACC: 41.900 || balanced_min: 1.024 || class0_inPoly : 0 || class0_outPoly : 500 || class1_inPoly : 0 || class1_outPoly : 500


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Epoch:  1000 || TRN_loss: 46.4491 || TRN_ACC: 61.500 || balanced_min: 0.984 || class0_inPoly : 0 || class0_outPoly : 500 || class1_inPoly : 0 || class1_outPoly : 500
Epoch:  2000 || TRN_loss: 27.5455 || TRN_ACC: 50.000 || balanced_min: 0.930 || class0_inPoly : 0 || class0_outPoly : 500 || class1_inPoly : 0 || class1_outPoly : 500
Epoch:  3000 || TRN_loss: 19.2140 || TRN_ACC: 50.000 || balanced_min: 0.896 || class0_inPoly : 0 || class0_outPoly : 500 || class1_inPoly : 0 || class1_outPoly : 500
Epoch:  4000 || TRN_loss: 14.8389 || TRN_ACC: 50.000 || balanced_min: 0.879 || class0_inPoly : 0 || class0_outPoly : 500 || class1_inPoly : 0 || class1_outPoly : 500
Epoch:  5000 || TRN_loss: 12.3245 || TRN_ACC: 50.000 || balanced_min: 0.876 || class0_inPoly : 0 || class0_outPoly : 500 || class1_inPoly : 0 || class1_outPoly : 500
Epoch:  6000 || TRN_loss: 10.6392 || TRN_ACC: 50.000 || balanced_min: 0.836 || class0_inPoly : 0 || class0_outPoly : 500 || class1_inPoly : 0 || class1_outPoly : 500
	(re


KeyboardInterrupt



In [None]:
### visualize every polytope...
# Draws each stored polytope, computes the training accuracy of the combined
# cover and saves a figure of the combined decision region.
w = 1.6
N = 200
x, y = torch.meshgrid(torch.linspace(-w, w, N), torch.linspace(-w, w, N), indexing="ij")
grid = torch.stack((x, y), dim=2).to(device).float()  # shape (N, N, 2)

# Evaluate every network once; the masks are reused by all sections below.
grid_inside = []  # per network: (N, N) float mask, 1 where the output equals the offset
data_inside = []  # per network: the same mask on the training points
with torch.no_grad():
    for net in network_list:
        grid_inside.append((net(grid).view(N, N).cpu() == net.bias).float())
        data_inside.append((net(trn_X) == net.bias).float())

# 4 figures in each row; keep at least one row so an empty list does not crash.
n_rows = max(1, int(np.ceil(len(network_list) / 4)))
fig, axs = plt.subplots(n_rows, 4, figsize=(16, 4 * n_rows), squeeze=False)

for layer, (net, rank) in enumerate(zip(network_list, grid_inside)):
    ax = axs[layer // 4, layer % 4]
#     for i in range(net.width):
#         ax.contour(x,y, net.activation_pattern(grid.view(-1,2)).float().cpu().view(N,N,-1)[:,:,i], levels=1, colors='blue', linewidths=0.5)
    ax.contourf(x, y, rank, cmap='Greys')
    ax.contour(x, y, rank, cmap='Greens', linewidths=2)
    ax.set_title(f"Polytope {layer+1}, width={net.width}")
    ax.scatter(trn_X[:int(n/2),0].cpu(), trn_X[:int(n/2),1].cpu(), marker='.', color="C0")
    ax.scatter(trn_X[int(n/2):,0].cpu(), trn_X[int(n/2):,1].cpu(), marker='.', color="C1")
plt.suptitle(r"The polytope-basis cover")
plt.savefig(folder_name_above + "/polytopes.png")
plt.close()

########### Accuracy #########
# Polytope k contributes 2**k * offset_k on the points that lie inside it and
# inside polytope k-1; a point is predicted as class 1 when the sum is positive.
if data_inside:
    answer = torch.zeros_like(data_inside[0])
else:
    answer = torch.zeros(len(trn_X), 1, device=trn_X.device)
previous_answer = torch.ones_like(answer)
for exponential, (net, inside) in enumerate(zip(network_list, data_inside)):
    answer += 2**exponential * inside * net.bias * previous_answer
    previous_answer = inside
# Normalise by the number of labels, as the training accuracy does.
acc = ((answer > 0.01).float().view(-1) == trn_Y.view(-1)).sum().item() / len(trn_Y) * 100
print(f"The accuracy is {acc :.3f} %")
##################################

### combined polytopes
plt.figure()
rank = torch.zeros(N, N)
previous_rank = torch.ones_like(rank)
for exponential, (net, inside) in enumerate(zip(network_list, grid_inside)):
    rank += 2**exponential * inside * net.bias * previous_rank
    previous_rank = inside
combined = (rank > 0.01).float()
plt.contourf(x, y, combined, np.arange(-1, 2, 1), cmap='Greys')
plt.colorbar()
plt.contour(x, y, combined, cmap='Greens', linewidths=2)
plt.scatter(trn_X[:int(n/2),0].cpu(), trn_X[:int(n/2),1].cpu(), marker='.', color="C0", label="class 0", s=25)
plt.scatter(trn_X[int(n/2):,0].cpu(), trn_X[int(n/2):,1].cpu(), marker='.', color="C1", label="class 1", s=25)
plt.title("The polytope-basis cover")
plt.legend()
plt.savefig(folder_name_above + f"/polytope_cover_{acc :.3f}%.png")
# plt.show()
plt.close()