In [1]:
# Import basic packages and shared helpers by executing SetUp.ipynb in this
# notebook's namespace (IPython `%run` shares the resulting globals).
# NOTE(review): names used below but never defined in this notebook
# (torch, torchvision, nn, optim, plt, tqdm, time, datetime, SummaryWriter,
# create_directory, data_load, ReLU_Net) are presumably provided by
# SetUp.ipynb — confirm there.
%run SetUp.ipynb

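Configuration: experiment name, compute device, output directory, and loading of the CIFAR-10 data (converted to grayscale) with one-hot-style targets.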

In [2]:
## Configuration
Title = "Exp_CIFAR"
# Prefer the second GPU ("cuda:1") as before, but fall back to "cuda:0" on
# single-GPU machines (where "cuda:1" would fail) and to the CPU without CUDA.
if torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"

# create directory
output_path = f"./Results/{Title}"
create_directory(output_path)

# data setting
DATA = "CIFAR10"
img_size = 32
n = 1000  # number of data
d = 1*img_size*img_size  # dimension of data (1 channel: images are converted to grayscale below)
output_class = 10  # output dimension // training one single-layer network per output
y_true = 30  # target value placed in the column of the true class
y_false = 0  # target value placed in every other column

# load data
X, Y = data_load(DATA)
# NOTE(review): assumes data_load returns n flattened RGB images laid out as
# (img_size, img_size, 3) per row, and integer class labels Y of shape (n, 1)
# so they broadcast against arange(output_class) — confirm in SetUp.ipynb.
assert X.shape[0] == n, f"data_load returned {X.shape[0]} samples, expected n = {n}"
X = X.to(device)
# RGB -> grayscale: (n, H*W*3) -> (n, 3, H, W) -> Grayscale -> (n, 1, H, W) -> (n, H*W)
X = torchvision.transforms.Grayscale()(
    X.reshape(n, img_size, img_size, 3).permute(0, 3, 1, 2)
).permute(0, 2, 3, 1).reshape(n, -1)
assert X.shape == (n, d), f"unexpected input shape {tuple(X.shape)}"
# One-hot-style targets: y_true in the column of the true class, y_false elsewhere.
Y = (y_false + (y_true - y_false)*(Y == torch.arange(output_class)).float()).to(device)
assert Y.shape == (n, output_class), f"unexpected label shape {tuple(Y.shape)}"

# Note: a UserWarning may be emitted above because of the installed torch / torchvision versions.

In [3]:
# Model: ReLU_Net (presumably defined in SetUp.ipynb) whose first-layer weight
# is replaced by a near-zero uniform initialisation in [0, init_scale).
seed = 0  # fixes the random initial weights (and any default-initialised parameters)
init_scale = 1e-7  # same value as the previous literal 0.0000001
torch.manual_seed(seed)
net = ReLU_Net(output_class=output_class)
net.module[0].weight = nn.Parameter(torch.rand(output_class, d) * init_scale)
net.to(device)


# Optimization setting
Epochs = 1000*1000
lr = 0.005
criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=lr)
# Each run logs into its own timestamped folder under <output_path>/runs/.
folder_name = output_path + '/runs/' + datetime.datetime.now().strftime("%B%d_%H_%M_%S")
# The CFG file below is written into folder_name, relying on SummaryWriter
# having created that directory.
writer = SummaryWriter(folder_name)
# Save CFG file (the context manager closes the handle even if write() raises)
with open(f"./{folder_name}/CFG.text", 'w') as cfg_file:
    cfg_file.write(
f"""
    This is a CFG file.

    # dataset
    {DATA}
    img_size = {img_size}
    n = {n} # number of data
    y_true = {y_true}
    y_false = {y_false}
    num_class = {output_class}

    # optimization
    Epochs = {Epochs}
    lr = {lr}

    # initialization
    seed = {seed}
    init_scale = {init_scale}
"""
    )

In [None]:
print("Start Training")
time.sleep(1)  # presumably lets the print flush before the tqdm bar starts

log_period = 100  # write TensorBoard scalars every `log_period` epochs
for epoch in tqdm(range(Epochs)):
    # One full-batch gradient step on all n samples.
    optimizer.zero_grad()
    loss = criterion(net(X), Y)
    loss.backward()
    optimizer.step()

    if epoch % log_period == 0:
        # The metrics need no gradients: evaluate the updated network once,
        # without autograd, instead of two extra tracked forward passes.
        with torch.no_grad():
            out = net(X)
        # Loss of this epoch, computed before the parameter update.
        writer.add_scalar("ReLU/Loss", loss.item(), epoch+1)
        # Samples whose outputs are all > 0 after the update.
        writer.add_scalar("ReLU/# of SVs", (out > 0).prod(axis=1).sum().item(), epoch+1)
        # Samples whose outputs are all exactly 0 after the update.
        writer.add_scalar("ReLU/# of nSVs", (out == 0).prod(axis=1).sum().item(), epoch+1)

# Make sure all logged scalars reach disk even if the kernel is stopped later.
writer.flush()

In [None]:
# Save every non-support vector (a sample for which all network outputs are
# exactly 0) as an image, named by its row index in X.
with torch.no_grad():
    # .tolist() yields plain ints, so filenames are "5.png" rather than the
    # repr of a 0-d tensor ("tensor(5, device=...).png").
    nSVs_index = (net(X) == 0).prod(axis=1).nonzero()[:, 0].tolist()
for index in nSVs_index:
    # X was converted to 1-channel grayscale when loaded, so each row holds
    # img_size*img_size values; there is no trailing colour axis to view.
    image = X[index].detach().cpu().reshape(img_size, img_size).numpy()
    fig, ax = plt.subplots()
    ax.imshow(image, cmap="gray")
    fig.savefig(folder_name + f"/{index}.png")
    plt.close(fig)
print(f"There are {len(nSVs_index)} deactivated data")

In [None]:
# Persist the trained network's parameters next to this run's TensorBoard
# logs and CFG file.
trained_state = net.state_dict()
torch.save(trained_state, f"./{folder_name}/net.pt")