## Introduction
This notebook tests the approach described in [this paper](https://arxiv.org/abs/1912.02175/).

In [4]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import time as time
import blackbox_backprop as bb

import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, TensorDataset, DataLoader
import time as time

In [5]:
class bb_net(nn.Module):
    '''
    Two-layer network that predicts an m-by-m grid of weights from a context
    vector and passes it through the blackbox shortest-path solver.
    This net is equipped to run on m-by-m grid graphs. No A matrix is necessary.

    Args:
        m: side length of the grid; the last linear layer emits m**2 values.
        context_size: number of input features per sample. The hidden layer
            has 2*context_size units.
        device: stored as self.device; it is not used anywhere else in this
            class (callers move the module with .to(...) themselves).
        lambda_val: second argument forwarded to bb.ShortestPath.apply (the
            lambda hyperparameter). Defaults to 5.0, the value that used to
            be hard-coded in forward().
    '''
    def __init__(self, m, context_size, device='cpu', lambda_val=5.0):
        super().__init__()
        self.m = m
        self.device = device
        self.hidden_dim = 2*context_size
        self.lambda_val = lambda_val

        ## Standard layers
        self.fc_1 = nn.Linear(context_size, self.hidden_dim)
        self.fc_2 = nn.Linear(self.hidden_dim, self.m**2)
        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, d):
        '''
        Map a batch of context vectors to the solver output.

        Args:
            d: batch of context vectors, expected shape (batch, context_size).

        Returns:
            The result of bb.ShortestPath.apply on the predicted weights,
            which are reshaped to (batch, m, m) before the call.
        '''
        w = self.leaky_relu(self.fc_1(d))
        w = self.fc_2(w)
        # One m-by-m weight grid per sample in the batch.
        suggested_weights = w.view(w.shape[0], self.m, self.m)
        suggested_shortest_paths = bb.ShortestPath.apply(suggested_weights, self.lambda_val)

        return suggested_shortest_paths

In [11]:
# Instantiate the network for a 5-by-5 grid with 5 context features per sample.
BB_net = bb_net(m=5, context_size=5)

In [12]:
## Load data
grid_size = 5
data_path = f'./shortest_path_data/Shortest_Path_training_data{grid_size}.pth'
# NOTE(review): torch.load deserializes the file and may unpickle arbitrary
# objects depending on the PyTorch version; only load files from a trusted source.
state = torch.load(data_path)

## Extract data from state
# Two dataset variants are stored ('_e' and '_v'), each with a train and a test split.
train_dataset_e, test_dataset_e = state['train_dataset_e'], state['test_dataset_e']
train_dataset_v, test_dataset_v = state['train_dataset_v'], state['test_dataset_v']
m = state["m"]
# A, b and WW are converted to float32 tensors.
A, b, WW = (state[key].float() for key in ("A", "b", "WW"))
num_edges = state["num_edges"]
Edge_list = state["Edge_list"]
Edge_list_torch = torch.tensor(Edge_list)

In [13]:
## Training setup
# Train and evaluate on the '_v' variant of the datasets.
train_dataset, test_dataset = train_dataset_v, test_dataset_v
net = BB_net
learning_rate = 1e-3

test_size = 200
# Shuffle only the training data; the test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_size, shuffle=False)

optimizer = optim.SGD(net.parameters(), lr=learning_rate)
# The scheduler is stepped with the test loss, so it runs in 'min' mode.
scheduler = ReduceLROnPlateau(optimizer, mode='min')
criterion = nn.MSELoss()

## Initialize arrays that will be returned.
test_loss_hist, test_acc_hist = [], []
train_time = [0]
train_loss_ave = 0
max_time = 3600

In [14]:
class HammingLoss(torch.nn.Module):
    '''
    Hamming-style mismatch between a suggested and a target tensor.

    For every entry the error is suggested*(1-target) + (1-suggested)*target.
    Errors are averaged over dim 0 and the resulting per-entry means are
    summed into a single scalar.
    '''
    def forward(self, suggested, target):
        # Mass placed where the target is off ...
        extra = suggested * (1.0 - target)
        # ... plus mass missing where the target is on.
        missing = (1.0 - suggested) * target
        mismatch = extra + missing
        per_entry_mean = mismatch.mean(dim=0)
        return per_entry_mean.sum()


In [15]:
# This is the loss function used in the original paper. We can try it, but it doesn't seem to
# perform significantly better than nn.MSELoss.
# criterion = HammingLoss()

In [16]:
## Train!
max_epochs = 100
# Fall back to the CPU when no CUDA device is available instead of failing in net.to().
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
net.to(device)
hammingLoss = HammingLoss()

for epoch in range(max_epochs):
    # ---- training pass ----
    net.train()
    train_loss_sum, train_count = 0.0, 0
    for d_batch, path_batch in train_loader:
        d_batch = d_batch.to(device)
        path_batch = path_batch.to(device)
        optimizer.zero_grad()
        suggested_shortest_paths = net(d_batch)
        # Alternative loss: hammingLoss(suggested_shortest_paths, path_batch)
        loss = criterion(suggested_shortest_paths, path_batch)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a smaller last batch does not skew the average.
        train_loss_sum += loss.item() * d_batch.shape[0]
        train_count += d_batch.shape[0]
    if train_count > 0:
        train_loss_ave = train_loss_sum / train_count

    # ---- evaluation pass ----
    net.eval()
    test_loss_sum, test_count = 0.0, 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for d_batch, path_batch in test_loader:
            d_batch = d_batch.to(device)
            path_batch = path_batch.to(device)
            suggested_shortest_paths = net(d_batch)
            loss = criterion(suggested_shortest_paths, path_batch)
            test_loss_sum += loss.item() * d_batch.shape[0]
            test_count += d_batch.shape[0]
    if test_count > 0:
        test_loss = test_loss_sum / test_count
        # Step the scheduler and report once per epoch on the loss over the
        # whole test set, not once per test batch.
        scheduler.step(test_loss)
        print('epoch:', epoch, ', test_loss = ', test_loss)

epoch: 0 , test_loss =  7.539999961853027
epoch: 1 , test_loss =  7.169999599456787
epoch: 2 , test_loss =  7.039999961853027
epoch: 3 , test_loss =  6.979999542236328
epoch: 4 , test_loss =  6.920000076293945
epoch: 5 , test_loss =  6.899999618530273
epoch: 6 , test_loss =  6.899999618530273
epoch: 7 , test_loss =  6.880000114440918
epoch: 8 , test_loss =  6.869999885559082
epoch: 9 , test_loss =  6.869999885559082
epoch: 10 , test_loss =  6.869999885559082
epoch: 11 , test_loss =  6.869999885559082
epoch: 12 , test_loss =  6.869999885559082
epoch: 13 , test_loss =  6.869999885559082
epoch: 14 , test_loss =  6.859999656677246
epoch: 15 , test_loss =  6.859999656677246
epoch: 16 , test_loss =  6.859999656677246
epoch: 17 , test_loss =  6.84999942779541
epoch: 18 , test_loss =  6.84999942779541
epoch: 19 , test_loss =  6.84999942779541
epoch: 20 , test_loss =  6.84999942779541
epoch: 21 , test_loss =  6.84999942779541
epoch: 22 , test_loss =  6.84999942779541
epoch: 23 , test_loss =  6.

KeyboardInterrupt: 

In [None]:
# Take the first test batch, move both tensors to the training device and
# run the network on the inputs.
d_batch, path_batch = (tensor.to(device) for tensor in next(iter(test_loader)))
pred_batch = BB_net(d_batch)

In [None]:
import matplotlib.pyplot as plt
WW = WW.to(device)
sp = bb.ShortestPath()

# Show at most 10 examples, and never index past the end of the batch.
num_examples = min(10, len(pred_batch))

for i in range(num_examples):
    # Left panel: network output; right panel: the path it was trained against.
    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.matshow(pred_batch[i,:,:].cpu().detach().numpy())
    ax1.set_title('prediction')
    ax2.matshow(path_batch[i,:,:].cpu().detach().numpy())
    ax2.set_title('target')
    # Disabled experiment, kept for reference: run the solver on weights
    # built from WW and the context vector, then plot the result.
    #weights = torch.matmul(WW, d_batch[i]).view((1, 5, 5))
    # print(weights)
    #pred_sp = sp.apply(weights,20)
    #plt.matshow(pred_sp[0,:,:].cpu().detach().numpy())
    #plt.show()
    

In [None]:
# Display the device that WW currently lives on.
WW.device