In [1]:
# Classification, Intrinsic dimension, Degeneracy
%run SetUp.ipynb
%run synthetic_datasets.ipynb

## Configuration
# Notebook-level settings. The classes defined further down read these names as
# globals, so they must be defined before any model is built or trained.
Title = "Synthetic_datasets" # general two layer; also used as the results sub-directory name
device = "cuda:1" # torch device string the networks and plotting grids are moved to
pt_option=True # when True, progress messages are printed and the loss/accuracy figures are shown
col = 1 # for visualization: number of columns in the per-polytope figure grid
num_polytopes = (1*col, 1*col) # (number of positive polytopes, number of negative polytopes)
width = 8 # initial hidden width of every polytope sub-network
dataset = "Two circles" # Swiss / Two circles / Two moons / XOR
d = 2 # input dimension of the first linear layer
n = 1000 # number of data used; the plots split trn_X at n/2 into two scatter groups
num_classes = 2 # only written to the result text file

## load datasets ##
# NOTE(review): load_synthetic_dataset presumably comes from synthetic_datasets.ipynb
# and is assumed to return tensors already on `device` -- confirm.
trn_X, trn_Y = load_synthetic_dataset(dataset)
Epochs = 150*1000 # number of full-batch optimization steps
start_pruning = 10000 # boundary-neuron rescaling starts at checkpoints after this step
pruning_period = 2000 # checkpoint period: logging, saving, plotting and rescaling
lambda_rescale = 1.2 # factor passed to rescaling_the_boundary_neurons
positive_init = False # forwarded to every polytope: sign of the initial output weights
small_norm_init = True # forwarded to every polytope: shrink the first-layer initialization

# create directory
output_path = f"./Results/{Title}"
create_directory(output_path)

Network Architecture

In [2]:
# Network setting # d -> d1 -> 1
class polytope(nn.Module):
    """Two-layer ReLU sub-network ``x -> fc1(relu(fc0(x))) + bias``.

    ``fc0`` maps ``d`` inputs to ``width`` hidden neurons, ``fc1`` maps them to
    ``output_class`` outputs without a bias, and a constant ``bias`` of
    ``1 - 2*positive_init`` (``-1`` or ``+1``) is added to the output.

    The class reads these notebook-level globals: ``d``, ``device``, ``trn_X``,
    ``n``, ``pt_option`` and ``folder_name``.

    Args:
        width: number of hidden neurons.
        output_class: number of outputs of ``fc1``. The pruning/rescaling code
            indexes ``fc1.weight`` as a single row, i.e. it assumes 1.
        positive_init: if True the output weights start positive and the
            constant bias is ``-1``; otherwise they start negative and the
            bias is ``+1``.
        small_norm_init: if True, the first layer is re-initialized with a
            small Gaussian weight matrix and a shrunken bias.
    """

    def __init__(self, width, output_class=1, positive_init=True, small_norm_init=True):
        super(polytope, self).__init__()
        self.output_class = output_class
        self.fc0 = nn.Linear(d, width)
        self.fc1 = nn.Linear(width, output_class, bias=False)
        self.width = width # width
        self.bias = (1 - 2*positive_init)
        self.positive_init = positive_init
        self.small_norm_init = small_norm_init
        self.device = device

        # |v_k| = sqrt(||w_k||^2 + b_k^2 + 1), computed from the default fc0 init.
        magnitude = torch.sqrt((self.W(layer=0).norm(dim=1)**2 + self.b(0)**2+1).view(1,-1))
        if positive_init:
            # initialization such that all v_k are positive.
            self.fc1.weight = nn.Parameter(magnitude)
        else:
            # initialization such that all v_k are negative.
            self.fc1.weight = nn.Parameter(-magnitude)

        if small_norm_init:
            # Replace fc0 by N(0, (1.5/width)^2) weights and scale its bias by 0.6/width.
            self.change_layer_weights(layer=0, W=1.5/self.width*torch.randn_like(self.W(0)), b=0.6/self.width*(self.b(0)))

    def forward(self, x):
        """Return ``fc1(relu(fc0(x))) + bias``; caches ``g1``, ``h1``, ``g2`` on self."""
        self.g1 = self.fc0(x)
        self.h1 = F.relu(self.g1)
        self.g2 = self.fc1(self.h1) + self.bias
        return self.g2

    def W(self, layer):
        """Return a clone of the weight matrix of layer 0 (``fc0``) or 1 (``fc1``)."""
        if layer ==0:
            output = self.fc0.weight
        elif layer ==1:
            output = self.fc1.weight
        return output.clone()

    def b(self, layer):
        """Return a clone of the bias of layer 0 or 1, or None if it has no bias."""
        if layer ==0:
            output = self.fc0.bias
        elif layer ==1:
            output = self.fc1.bias
        return None if output is None else output.clone()

    def activation_pattern(self, x):
        """Return a float 0/1 tensor marking which hidden neurons are active on ``x``."""
        self.forward(x)
        return (self.h1>0).float()

    def change_layer_weights(self, layer, W, b):
        """Replace the parameters of ``layer`` with new ``nn.Parameter`` objects.

        Args:
            layer: 0 for ``fc0`` or 1 for ``fc1``.
            W: new weight; must have the shape of the current weight.
            b: new bias with the shape of the current bias, or None for a
                layer that has no bias.

        Raises:
            ValueError: if the shapes do not match the current parameters.
        """
        current_bias = self.b(layer)
        shape_matches = W.shape == self.W(layer).shape
        if current_bias is None: # no bias term
            if not (shape_matches and b is None):
                raise ValueError("wrong shape of input tensors")
            if layer ==0:
                self.fc0.weight = nn.Parameter(W.to(self.device))
            elif layer ==1:
                self.fc1.weight = nn.Parameter(W.to(self.device))
        else:
            if not (shape_matches and b.shape == current_bias.shape):
                raise ValueError("wrong shape of input tensors")
            if layer ==0:
                self.fc0.weight = nn.Parameter(W.to(self.device))
                self.fc0.bias = nn.Parameter(b.to(self.device))
            elif layer ==1:
                # fc1 is built without a bias, so only the weight is replaced.
                self.fc1.weight = nn.Parameter(W.to(self.device))

    def partition(self, w=1.6, title="img"):
        """Plot the decision region, the neuron boundaries and the data on ``[-w, w]^2``.

        The figure is saved to ``folder_name/{title}.png`` and then shown.
        """
        N = 200
        # indexing="ij" is the behaviour the original call relied on; passing it
        # explicitly silences the torch.meshgrid deprecation warning.
        x,y = torch.meshgrid(torch.linspace(-w,w,N), torch.linspace(-w,w,N), indexing="ij")
        grid = torch.stack((x,y),dim=2).to(self.device).float() # shape (N, N, 2)
        flat_grid = grid.view(-1,2)
        rank = (self.forward(flat_grid)<0.001).float().view(N,N)  # 1 where the output is below 0.001
        plt.contourf(x,y,rank.cpu(), np.arange(-1,2,1), cmap='gray')
        plt.colorbar()
        # One forward pass for all neurons instead of one per neuron.
        patterns = self.activation_pattern(flat_grid).float().cpu().view(N,N,-1)
        for i in range(self.width):
            plt.contour(x,y, patterns[:,:,i], levels=1, colors='blue', linewidths=0.5)

        # dataset
        plt.scatter(trn_X[:int(n/2),0].cpu(),trn_X[:int(n/2),1].cpu(), marker='.')
        plt.scatter(trn_X[int(n/2):,0].cpu(),trn_X[int(n/2):,1].cpu(), marker='.')
        plt.title(fr"Polytope $P_1$; width={self.width}")
        plt.savefig(folder_name+f"/{title}.png")
        plt.show()

    def pruning(self, subnetwork_index=0):
        """Shrink the hidden layer using the activation patterns on ``trn_X``.

        Three steps run in order:
          1. drop neurons that are active on all or on none of the data, or
             whose output weight has magnitude <= 0.1;
          2. merge neurons with identical activation patterns by summing
             their parameters;
          3. among neurons that are never the only active neuron on a datum,
             remove the single one with the smallest ``|v| * ||w||``.

        NOTE(review): if step 1 would leave zero neurons, only ``self.width``
        is set to 0; ``fc0``/``fc1`` keep their previous size.
        """
        dataset = trn_X # the dataset.

        ## remove a neuron which deactivates all data. # STEP 1
        width_before = self.width
        active_count = self.activation_pattern(dataset).sum(dim=0)
        if active_count.max() == len(dataset) or active_count.min() == 0 :  # either all or none of data points.
            W0 = self.W(layer=0).detach()
            b0 = self.b(layer=0).detach()
            W1 = self.W(layer=1).detach()
            keep = ((active_count < len(dataset))
                    & (active_count > 0)
                    & (W1.abs() > 0.1).view(-1))
            # change polytope width
            self.width = int(keep.sum().item())
            text = "(removing)"
            if self.width > 0 :
                width_after = self.width
                self.fc0 = nn.Linear(d, self.width)
                self.fc1 = nn.Linear(self.width, self.output_class, bias=False)
                self.change_layer_weights(layer=0, W=W0[keep], b=b0[keep])
                self.change_layer_weights(layer=1, W=W1.view(-1)[keep].view(1,-1), b=None)
                if pt_option:
                    print(f"\t{text:<15}The subnetwork {subnetwork_index} is prunned, width : {width_before} --> {width_after}")
            else:
                if pt_option:
                    print(f"\t{text:<15}The subnetwork {subnetwork_index} has been completely prunned,")

        ## merging vectors with the same activation patterns ## # STEP 2
        width_activation_pattern = self.activation_pattern(dataset) # len(dataset) x width
        width_activation_pattern_unique = width_activation_pattern.unique(dim=1).t()
        width_before = self.width
        if len(width_activation_pattern_unique) < self.width :
            self.width = len(width_activation_pattern_unique)
            width_after = self.width
            W0 = self.W(layer=0).detach()
            b0 = self.b(layer=0).detach()
            W1 = self.W(layer=1).detach()
            self.fc0 = nn.Linear(d, self.width)
            self.fc1 = nn.Linear(self.width, self.output_class, bias=False)
            self.to(device)
            # build weight matrices
            new_W0 = self.W(layer=0).detach()
            new_b0 = self.b(layer=0).detach()
            new_W1 = self.W(layer=1).detach()
            pattern_per_neuron = width_activation_pattern.t() # width x len(dataset)
            for index, pattern in enumerate(width_activation_pattern_unique):
                same_pattern = (pattern_per_neuron == pattern).all(dim=1)
                new_W0[index] = W0[same_pattern].sum(dim=0)
                new_b0[index] = b0[same_pattern].sum(dim=0)
                new_W1[0][index] = W1[0][same_pattern].sum(dim=0)
            self.change_layer_weights(layer=0, W=new_W0, b=new_b0)
            self.change_layer_weights(layer=1, W=new_W1, b=None)
            text = "(merging)"
            if pt_option:
                print(f"\t{text:<15}The subnetwork {subnetwork_index} is merged, width : {width_before} --> {width_after}")

        ###### original ####### # STEP 3
        smallest_norm_neuron_index = self.width # not index yet.
        smallest_norm = np.inf
        width_activation_pattern = self.activation_pattern(dataset) # len(dataset) x width
        active_neurons_per_datum = width_activation_pattern.sum(dim=1)
        first_layer = self.W(layer=0)
        # view(-1) instead of squeeze() so a width-1 layer can still be indexed.
        output_weights = self.W(layer=1).view(-1)
        _count = 0
        for neuron_index in range(self.width):
            activated_index = (width_activation_pattern[:, neuron_index] == 1)
            # Every datum that activates this neuron also activates at least one
            # other neuron, so this neuron is redundant and may be removed.
            if active_neurons_per_datum[activated_index].min().item() >= 2 :
                _count += 1
                # |v| * ||w||. The absolute value matters: with negative output
                # weights the signed product would select the LARGEST neuron.
                this_round_norm = (first_layer[neuron_index].norm() * output_weights[neuron_index]).abs().item()
                if this_round_norm < smallest_norm:
                    smallest_norm_neuron_index = neuron_index
                    smallest_norm = this_round_norm
        if smallest_norm_neuron_index < self.width:
            # remove the small-norm redundant neuron, remove only one neuron at once
            width_before = self.width
            width_after = self.width - 1
            W0 = self.W(layer=0).detach()
            b0 = self.b(layer=0).detach()
            W1 = self.W(layer=1).detach()
            self.fc0 = nn.Linear(d, width_after)
            self.fc1 = nn.Linear(width_after, self.output_class, bias=False)
            self.to(device)
            new_W0 = torch.cat((W0[:smallest_norm_neuron_index], W0[smallest_norm_neuron_index+1:]), dim=0)
            new_b0 = torch.cat((b0[:smallest_norm_neuron_index], b0[smallest_norm_neuron_index+1:]), dim=0)
            new_W1 = torch.cat((W1[:,:smallest_norm_neuron_index], W1[:,smallest_norm_neuron_index+1:]), dim=1)
            self.change_layer_weights(layer=0, W=new_W0, b=new_b0)
            self.change_layer_weights(layer=1, W=new_W1, b=None)
            self.width = width_after
            text = "(pruning)"
            if pt_option:
                print(f"\t{text:<15}In the subnetwork {subnetwork_index}, there are {_count} redundant neurons, and one has been removed")

    def rescaling_the_boundary_neurons(self, subnetwork_index=0, rescale=1.1): # STEP 4
        """Multiply the parameters of selected neurons by ``rescale``.

        A datum is "intermediate" when the output on it is positive and differs
        from the constant ``bias``. Every neuron that is active on at least one
        intermediate datum gets its ``fc0`` row, its ``fc0`` bias and its
        ``fc1`` weight multiplied by ``rescale``.
        """
        dataset = trn_X
        output = self.forward(dataset) # one forward pass, reused for both conditions
        intermediate_index = ((F.relu(output) != 0) & (output != self.bias)).squeeze()
        if intermediate_index.any():
            activation_pattern_of_intermediate_data = self.activation_pattern(dataset[intermediate_index])
            if activation_pattern_of_intermediate_data.max().item() >0 : ## max
                new_W0 = self.W(layer=0).detach()
                new_b0 = self.b(layer=0).detach()
                new_W1 = self.W(layer=1).detach()
                true_or_false = activation_pattern_of_intermediate_data.sum(dim=0)>0
                new_W0[true_or_false] = new_W0[true_or_false] * rescale
                new_b0[true_or_false] = new_b0[true_or_false] * rescale
                new_W1[0][true_or_false] = new_W1[0][true_or_false] * rescale
                self.change_layer_weights(layer=0, W=new_W0, b=new_b0)
                self.change_layer_weights(layer=1, W=new_W1, b=None)
                text = "(rescaling)"
                if pt_option:
                    print(f"\t{text:<15}In the subnetwork {subnetwork_index}, {true_or_false.sum().item()} neurons are rescaled by {rescale}")

    def convexity(self):
        """Count the ``fc1`` weights whose sign equals ``self.bias``.

        The result is zero when every output weight has the opposite sign of
        the constant bias, which ``cover.convexity_check`` treats as convex.
        """
        convexity = (self.W(1).sgn() == self.bias).float().sum() # if convex, this is zero. either one.
        return convexity

In [3]:
# multiple polytopes : polytope basis-cover
class cover(nn.Module):
    """Difference of two groups of ``polytope`` sub-networks.

    The output is ``sum_i relu(P_i(x)) - sum_j relu(Q_j(x))`` where the first
    ``num_polytopes[0]`` modules form the positive group P and the remaining
    ``num_polytopes[1]`` modules form the negative group Q.

    The class reads these notebook-level globals: ``device``, ``lr``,
    ``trn_X``, ``trn_Y``, ``n``, ``d``, ``col``, ``dataset``, ``num_classes``,
    ``Epochs``, ``pruning_period``, ``start_pruning``, ``lambda_rescale``,
    ``criterion``, ``writer``, ``folder_name`` and ``pt_option``.

    Args:
        width: initial hidden width of every polytope.
        num_polytopes: tuple ``(number of positive, number of negative)``.
        output_class: accepted for interface compatibility; not forwarded to
            the polytopes.
        positive_init, small_norm_init: forwarded to every ``polytope``.
    """

    def __init__(self, width, num_polytopes, output_class=1, positive_init=True, small_norm_init=True):
        super(cover, self).__init__()
        assert len(num_polytopes) == 2, "must be a tuple of two integers"
        self.num_polytopes = num_polytopes # (+_polytope, -_polytope)
        self.width = width
        self.device = device
        self.epoch = 0 # updated by the training loop; read by the plot titles
        self.module_list = nn.ModuleList() # list of polytopes
        for layer in range(num_polytopes[0] + num_polytopes[1]):
            self.module_list.append(polytope(width=width, positive_init=positive_init, small_norm_init=small_norm_init))
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, x, partial_sum=None):
        """Return the network output on ``x``.

        Args:
            x: input batch.
            partial_sum: ``'positive'`` returns only the P-group sum,
                ``'negative'`` only the Q-group sum, anything else returns
                their difference.
        """
        num_positive, num_negative = self.num_polytopes
        positive_output = 0
        for layer in range(num_positive):
            positive_output += F.relu(self.module_list[layer](x))
        negative_output = 0
        for layer in range(num_positive, num_positive+num_negative):  # negative polytope
            negative_output += F.relu(self.module_list[layer](x))
        if partial_sum == 'positive':
            output = positive_output
        elif partial_sum == 'negative':
            output = negative_output
        else:
            output = positive_output - negative_output ## (+polytopes) - (-polytopes) #### three-layer network architecture.
        return output

    def activation_pattern(self, x):
        """Concatenate (along dim 0) the activation patterns of the positive polytopes."""
        output = torch.empty(0).to(self.device)
        for layer in range(self.num_polytopes[0]):
            output = torch.cat((output, self.module_list[layer].activation_pattern(x)))
        return output

    def _reset_optimizer(self):
        """Move the model to ``self.device`` and rebuild Adam over the current parameters.

        The polytopes replace their ``nn.Parameter`` objects when they are
        pruned or rescaled, so the optimizer has to be given the new ones.
        """
        self.to(self.device)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def partition(self, w=1.6, title="img"):
        """Save the decision boundary, then the same plot with all neuron boundaries.

        Writes ``folder_name/DB_{title}.png`` and ``folder_name/{title}.png``.
        """
        N = 200
        # indexing="ij" is the behaviour the original call relied on; passing it
        # explicitly silences the torch.meshgrid deprecation warning.
        x,y = torch.meshgrid(torch.linspace(-w,w,N), torch.linspace(-w,w,N), indexing="ij")
        grid = torch.stack((x,y),dim=2).to(self.device).float() # shape (N, N, 2)
        flat_grid = grid.view(-1,2)
        rank = (self(flat_grid)<0.001).float().view(N,N)  # 1 where the output is below 0.001
        plt.contourf(x,y,rank.cpu(), np.arange(-1,2,1), cmap='gray')
        plt.colorbar()
        # dataset
        plt.scatter(trn_X[:int(n/2),0].cpu(),trn_X[:int(n/2),1].cpu(), marker='.')
        plt.scatter(trn_X[int(n/2):,0].cpu(),trn_X[int(n/2):,1].cpu(), marker='.')
        plt.title(f"Decision Boundary, Epoch: {self.epoch+1}")
        plt.savefig(folder_name+f"/DB_{title}.png")

        num_positive, num_negative = self.num_polytopes
        for layer in range(num_positive+num_negative):
            module = self.module_list[layer]
            line_color = 'blue' if layer < num_positive else 'red'
            # One forward pass per polytope instead of one per neuron.
            patterns = module.activation_pattern(flat_grid).float().cpu().view(N,N,-1)
            for i in range(module.width):
                plt.contour(x,y, patterns[:,:,i], levels=1, colors=line_color, linewidths=0.5)

        blue_patch = mpatches.Patch(color='blue', label=r'$a_j=+1$')
        red_patch = mpatches.Patch(color='red', label=r'$a_j=-1$')
        plt.legend(handles=[blue_patch, red_patch])
        plt.title(f"AB and DB, Epoch: {self.epoch+1}")
        plt.savefig(folder_name+f"/{title}.png")
        plt.close()

    def pruning(self):
        """Prune every polytope, then rebuild the optimizer."""
        for subnetwork_index, subnetwork in enumerate(self.module_list):
            subnetwork.pruning(subnetwork_index=subnetwork_index)
        self._reset_optimizer()

    def rescaling_the_boundary_neurons(self, rescale=1.1):
        """Rescale the boundary neurons of every polytope, then rebuild the optimizer."""
        for subnetwork_index, subnetwork in enumerate(self.module_list):
            subnetwork.rescaling_the_boundary_neurons(subnetwork_index=subnetwork_index, rescale=rescale)
        self._reset_optimizer()

    def convexity_check(self):
        """Raise AssertionError if any polytope reports a non-zero ``convexity()``."""
        for subnetwork_index, subnetwork in enumerate(self.module_list):
            assert subnetwork.convexity() == 0, f"The subnetwork {subnetwork_index} is not convex"

    # Save result txt file.
    def save_result_txt(self):
        """Write the final metrics and the run configuration to ``folder_name/result.txt``."""
        # The context manager closes the file even if formatting the report fails.
        with open(f"./{folder_name}/result.txt", 'w') as f:
            f.write(
            f"""
            This is a CFG file.

            # Accuracy
            Trn_acc : {self.trn_acc_tr[-1] :.3f}
            Test_acc : {self.test_acc_tr[-1] :.3f}

            # Loss
            Trn_loss : {self.trn_loss_tr[-1].item() :.4f} 
            Test_loss : {self.test_loss_tr[-1].item() :.4f} 


            # dataset
            dataset = {dataset}
            n = {n} # number of data
            d = {d} # dimension
            num_class = {num_classes} # the class, used in this binary classification
            num_polytopes = {self.num_polytopes} 
            width = {self.width}
            # Network
            {self}

            # optimization
            Epochs = {Epochs}
            lr = {lr}
            optimizer = {self.optimizer}
            Last epoch = {self.epoch+1}
        """
            )

    def save_loss_graph(self):
        """Save the loss and accuracy curves recorded by the training loop."""
        # PLOT THE LOSS GRAPH
        plt.plot(self.trn_loss_tr, label="Train")
        plt.plot(self.test_loss_tr, label="Test")
        plt.xlabel("Iterations (k)")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(folder_name+"/loss_fig_fullview.png")
        if pt_option:
            plt.show()
        plt.close()

        # PLOT THE ACCURACY GRAPH
        plt.plot(self.trn_acc_tr, label="Train")
        plt.plot(self.test_acc_tr, label="Test")
        plt.xlabel("Iterations (k)")
        plt.ylabel("Accuracy")
        plt.legend()
        plt.savefig(folder_name+"/Accuracy.png")
        if pt_option:
            plt.show()
        plt.close()

    def _plot_polytope_group(self, modules, symbol, line_color, file_stem, title, w, N):
        """Draw one sub-plot per polytope of a group and save the figure.

        ``symbol`` is the group letter used in the titles, ``file_stem`` the
        file name prefix inside ``folder_name``.
        """
        x,y = torch.meshgrid(torch.linspace(-w,w,N), torch.linspace(-w,w,N), indexing="ij")
        grid = torch.stack((x,y),dim=2).to(self.device).float() # shape (N, N, 2)
        rows = int(np.ceil(len(modules)/col))
        fig, axs = plt.subplots(rows, col, figsize=(4*col, 4*rows), squeeze=False) # col figures in each row
        for index, module in enumerate(modules):
            ax = axs[index // col, index % col]
            if module.width > 0:
                rank = F.relu(module(grid).view(N,N).detach().cpu())
                ax.contourf(x,y,rank, cmap='gray')
                patterns = module.activation_pattern(grid.view(-1,2)).float().cpu().view(N,N,-1)
                for i in range(module.width):
                    ax.contour(x,y, patterns[:,:,i], levels=1, colors=line_color, linewidths=0.5)
            else:
                # Nothing is contoured for an empty polytope, so fix the axes by hand.
                ax.set_xlim([-w, w])
                ax.set_ylim([-w, w])
            ax.set_title(fr"Polytope ${symbol}_{index+1}$, width={module.width}")
            ax.scatter(trn_X[:int(n/2),0].cpu(),trn_X[:int(n/2),1].cpu(), marker='.', color="C0")
            ax.scatter(trn_X[int(n/2):,0].cpu(),trn_X[int(n/2):,1].cpu(), marker='.', color="C1")
        plt.suptitle(fr"The shape of polytopes in the group ${symbol}$, epoch={title}")
        plt.savefig(folder_name+f"/{file_stem}_{title}.png")
        plt.close()

    ### visualize every polytope...
    def partition_polytopes(self, title=""):
        """Save one figure per group showing each polytope separately.

        Writes ``folder_name/positive_AB_{title}.png`` and
        ``folder_name/negativie_AB_{title}.png`` (file name kept as it was).
        """
        w = 1.6
        N = 200
        num_positive = self.num_polytopes[0]
        modules = list(self.module_list)
        # positive polytopes
        self._plot_polytope_group(modules[:num_positive], "P", 'blue', "positive_AB", title, w, N)
        ## negative polytopes
        self._plot_polytope_group(modules[num_positive:], "Q", 'red', "negativie_AB", title, w, N)

    def train(self, mode=None):
        """Run the full-batch training loop, or switch train/eval mode.

        This method shadows ``nn.Module.train(mode)``. Called without an
        argument it runs the training loop. Called with a bool it delegates to
        ``nn.Module.train`` so that ``eval()`` keeps working, and returns self.

        At epoch 0 and at every ``pruning_period``-th epoch the loop logs the
        metrics, saves the state dict and the partition figures, and checks
        convexity. Once ``epoch > start_pruning`` it also rescales the
        boundary neurons at those checkpoints.
        """
        if mode is not None:
            return super().train(mode)

        # Training
        self.trn_loss_tr = np.empty(0)
        self.test_loss_tr = np.empty(0)
        self.trn_acc_tr = np.empty(0)
        self.test_acc_tr = np.empty(0)

        # No test set is evaluated; these placeholders are logged as they are.
        test_acc = 0
        test_loss = torch.zeros(1)
        print("Start Training")
        time.sleep(1)

        finetuning_announced = False
        for epoch in range(Epochs) :
            self.epoch=epoch
#             loss = criterion(self.forward(trn_X, partial_sum='positive'), trn_Y)
#             index = (self.forward(trn_X, partial_sum='positive') == self.module_list[0].bias ).view(-1)
#             loss += criterion(self.forward(trn_X[index], partial_sum='negative'), trn_Y[index])

            loss = criterion(self.forward(trn_X), trn_Y)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            if epoch ==0 or epoch % pruning_period == (pruning_period-1) : # checkpoint every pruning_period steps
                self.trn_loss_tr = np.append(self.trn_loss_tr, loss.item())
                self.test_loss_tr = np.append(self.test_loss_tr, test_loss.item())

                trn_acc = ((self(trn_X)>0).float() == trn_Y).sum().item() / len(trn_Y) *100
                self.trn_acc_tr = np.append(self.trn_acc_tr, trn_acc)
                self.test_acc_tr = np.append(self.test_acc_tr, test_acc)
                writer.add_scalars("ReLU/Loss", {
                    "Trn_Loss": loss.item(),
                    "Test_Loss": test_loss.item(), }
                                   , epoch+1)
                writer.add_scalars("ReLU/Accuracy", {
                    "Trn_acc": trn_acc,
                    "Test_acc": test_acc, }
                                   , epoch+1)
                print(f"Epoch: {epoch+1 :>5} || TRN_loss: {loss.item() :.4f} || TEST_loss: {test_loss.item():.4f} || TRN_ACC: {trn_acc:.3f} || TEST_ACC: {test_acc:.3f}")
                torch.save(self.state_dict(), folder_name+f"/saved_net_width_{self.width}.pt")

                if epoch > start_pruning : ### finetuning
                    if not finetuning_announced:
                        # Announce once, at the first checkpoint after start_pruning.
                        print("Start Finetuning")
                        finetuning_announced = True
#                     self.pruning() #### Pruning !!!!! ######
                    self.rescaling_the_boundary_neurons(rescale=lambda_rescale) #### move neurons to boundary #######
                self.partition(title=f"{epoch+1}") ### save partition images
                self.partition_polytopes(title=f"{epoch+1}") ### save each polytopes separately
                self.convexity_check() # raise error if not every polytope is convex

        self.save_loss_graph()
        self.save_result_txt()
        time.sleep(1)
        if pt_option:
            # `repetition` is an optional notebook global; fall back to 0 when it is absent.
            repetition_index = globals().get("repetition", 0)
            print(f"===================================   Training finished, Repetition: {repetition_index+1}  ================================== \n\n")

In [4]:
# Optimization setting
lr = 0.0001 # Adam learning rate; read as a global by cover when it builds its optimizer
net = cover(width=width, num_polytopes=num_polytopes, positive_init=positive_init, small_norm_init=small_norm_init).to(device)

criterion = nn.BCEWithLogitsLoss()
# One time-stamped run directory per execution; figures, checkpoints and
# TensorBoard events are all written here.
folder_name = output_path + f'/runs/{dataset}/' + datetime.datetime.now().strftime("%B%d_%H_%M_%S")
# NOTE(review): the savefig below relies on this directory existing; presumably
# SummaryWriter creates it on construction -- confirm.
writer = SummaryWriter(folder_name)
plt.scatter(trn_X[:int(n/2),0].cpu(), trn_X[:int(n/2),1].cpu(), label='class 1')
plt.scatter(trn_X[int(n/2):,0].cpu(), trn_X[int(n/2):,1].cpu(), label='class 0')
plt.legend()
plt.title(f"{dataset} dataset") 
plt.savefig(folder_name+"/0_dataset.png") 
# plt.show()
plt.close()

### training ######
try:
    net.train()
finally:
    # Flush and close the TensorBoard event file even when the run is
    # interrupted (e.g. KeyboardInterrupt), so logged scalars are not lost.
    writer.close()

Start Training
Epoch:     1 || TRN_loss: 0.6930 || TEST_loss: 0.0000 || TRN_ACC: 50.900 || TEST_ACC: 0.000


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Epoch:  2000 || TRN_loss: 0.4982 || TEST_loss: 0.0000 || TRN_ACC: 84.800 || TEST_ACC: 0.000
Epoch:  4000 || TRN_loss: 0.4578 || TEST_loss: 0.0000 || TRN_ACC: 93.300 || TEST_ACC: 0.000
Epoch:  6000 || TRN_loss: 0.4466 || TEST_loss: 0.0000 || TRN_ACC: 97.800 || TEST_ACC: 0.000
Epoch:  8000 || TRN_loss: 0.4430 || TEST_loss: 0.0000 || TRN_ACC: 99.300 || TEST_ACC: 0.000
Epoch: 10000 || TRN_loss: 0.4417 || TEST_loss: 0.0000 || TRN_ACC: 99.800 || TEST_ACC: 0.000
Epoch: 12000 || TRN_loss: 0.4412 || TEST_loss: 0.0000 || TRN_ACC: 99.900 || TEST_ACC: 0.000
	(rescaling)    In the subnetwork 0, 8 neurons are rescaled by 1.2
	(rescaling)    In the subnetwork 1, 8 neurons are rescaled by 1.2
Epoch: 14000 || TRN_loss: 0.4406 || TEST_loss: 0.0000 || TRN_ACC: 100.000 || TEST_ACC: 0.000
	(rescaling)    In the subnetwork 1, 8 neurons are rescaled by 1.2
Epoch: 16000 || TRN_loss: 0.4404 || TEST_loss: 0.0000 || TRN_ACC: 99.900 || TEST_ACC: 0.000
	(rescaling)    In the subnetwork 1, 8 neurons are rescaled by

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x00000232668A5B80>>
Traceback (most recent call last):
  File "C:\Users\User\AppData\Local\Programs\Python\Python312\Lib\site-packages\ipykernel\ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


KeyboardInterrupt: 