## Introduction
Notebook for testing the Fenchel-Young loss approach to training what we call PertOpt-net. All code based on 

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import perturbations
import fenchel_young as fy
from torch_Dijkstra import Dijkstra

In [2]:
## Create NN using perturbed differentiable optimization
class Pert_ShortestPathNet(nn.Module):
    '''
    This net is equipped to run an m-by-m grid graphs. No A matrix is necessary.
    '''
    def __init__(self, m, context_size, device='cpu'):
        super().__init__()
        self.m = m
        self.device = device
        self.hidden_dim = 2*context_size

        ## Standard layers
        self.fc_1 = nn.Linear(context_size, self.hidden_dim)
        self.fc_12 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.fc_2 = nn.Linear(self.hidden_dim, self.m**2)
        self.leaky_relu = nn.LeakyReLU(0.1)
        self.relu = nn.ReLU()
        
      ## Put it all together
    def forward(self, d):
        w = self.leaky_relu(self.fc_12(self.leaky_relu(self.fc_1(d))))
        w = self.relu(self.fc_2(w))
        return w.view(w.shape[0], self.m, self.m)

In [13]:
device = 'cpu'
dijkstra = Dijkstra(euclidean_weight=False,four_neighbors=True)
criterion2 = nn.MSELoss()
criterion = fy.FenchelYoungLoss(dijkstra, num_samples=10, sigma=0.01, noise='gumbel', batched=True, maximize=False, device=device)

In [14]:
## Load data
grid_size = 5
data_path = '../shortest_path_data/Shortest_Path_training_data'+str(grid_size)+'.pth'
state = torch.load(data_path)

## Extract data from state
train_dataset = state['train_dataset_v']
test_dataset = state['test_dataset_v']
m = state["m"]
A = state["A"].float()
b = state["b"].float()
num_edges = state["num_edges"]
Edge_list = state["Edge_list"]
Edge_list_torch = torch.tensor(Edge_list)

In [15]:
net = Pert_ShortestPathNet(grid_size, context_size=5, device='cpu')

In [16]:
## Training setup
from torch.utils.data import Dataset, TensorDataset, DataLoader
learning_rate = 1e-4
test_size = 200
train_loader = DataLoader(dataset=train_dataset, batch_size=200,
                              shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=test_size,
                             shuffle=False)

optimizer = optim.Adam(net.parameters(), lr=learning_rate)

In [18]:
## Train!
import time
from utils import Edge_to_Node, compute_perfect_path_acc, compute_perfect_path_acc_vertex
max_epochs = 100
max_time = 1e10
test_loss_hist= []
test_acc_hist = []
train_time = []
train_loss_ave = 0
train_start_time = time.time()
epoch=0
epoch_time=0

while epoch <= max_epochs and epoch_time <= max_time:
    for d_batch, path_batch in train_loader:
        d_batch = d_batch.to(device)
        path_batch =path_batch.to(device)
        net.train()
        optimizer.zero_grad()
        weight_pred = net(d_batch)
        print(weight_pred[1,:,:])
        loss = criterion(weight_pred, path_batch).mean()
        path_pred = dijkstra(weight_pred)
        loss2 = criterion2(path_pred, path_batch)
        print(loss2.item())
        train_loss_ave = 0.95*train_loss_ave + 0.05*loss2.item()
        loss.backward()
        optimizer.step()

    # print('epoch:', epoch, ', av. training loss = ', train_loss_ave)
    epoch_time = time.time() - train_start_time
    train_time.append(epoch_time)

    # Evaluate progress on test set. (note one batch is entire dataset)
    net.eval()
    for d_batch, path_batch in test_loader:
        d_batch = d_batch.to(device)
        path_batch =path_batch.to(device)
        weight_pred = net(d_batch)
        path_pred = dijkstra(weight_pred)
        test_loss = criterion2(weight_pred, path_batch).mean().item()
        test_loss_hist.append(test_loss)
        for param in net.parameters():
            print(param.grad)
        # print('epoch: ', epoch, 'test loss is ', test_loss)
        ## Evaluate accuracy
        accuracy = compute_perfect_path_acc_vertex(path_pred, path_batch)
        # regret = compute_regret(WW, d_batch, path_batch, path_pred,'V', Edge_list, grid_size, device)
        # print('epoch: ', epoch, 'accuracy is ', accuracy)
        test_acc_hist.append(accuracy)

    # if test_loss < best_loss:
    #     best_params = net.state_dict()

    print('epoch: ', epoch, '| ave_tr_loss: ', "{:5.2e}".format(train_loss_ave), '| te_loss: ', "{:5.2e}".format(test_loss), '| acc.: ', "{:<7f}".format(accuracy), '| lr: ', "{:5.2e}".format(optimizer.param_groups[0]['lr']), '| time: ', "{:<15f}".format(epoch_time))

tensor([[0.1263, 0.0000, 0.0658, 0.1584, 0.0000],
        [0.0000, 0.1869, 0.0000, 0.2668, 0.0000],
        [0.0000, 0.1536, 0.1520, 0.2397, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.2100, 0.2366],
        [0.2252, 0.0000, 0.2065, 0.0000, 0.0000]], grad_fn=<SliceBackward0>)
0.43880000710487366
tensor([[0.1488, 0.0000, 0.0199, 0.1837, 0.0000],
        [0.0000, 0.2182, 0.0000, 0.2441, 0.0000],
        [0.0000, 0.2406, 0.1279, 0.2232, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.1999, 0.2438],
        [0.1761, 0.0000, 0.2053, 0.0000, 0.0000]], grad_fn=<SliceBackward0>)
0.4424000084400177
tensor([[0.1343, 0.0000, 0.0433, 0.1853, 0.0000],
        [0.0000, 0.2191, 0.0000, 0.2385, 0.0000],
        [0.0000, 0.1908, 0.1212, 0.2416, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.1888, 0.2394],
        [0.1810, 0.0000, 0.1918, 0.0000, 0.0000]], grad_fn=<SliceBackward0>)
0.4580000042915344
tensor([[0.1390, 0.0000, 0.0603, 0.2031, 0.0000],
        [0.0000, 0.2414, 0.0000, 0.2014, 0.0000],
        [0.

0.4472000002861023
tensor([[0.1667, 0.0000, 0.0501, 0.1910, 0.0000],
        [0.0000, 0.2493, 0.0000, 0.1458, 0.0000],
        [0.0000, 0.1938, 0.0894, 0.2230, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.2157, 0.2157],
        [0.1602, 0.0000, 0.1360, 0.0000, 0.0000]], grad_fn=<SliceBackward0>)
0.4503999948501587
tensor([[0.1528, 0.0000, 0.0354, 0.1795, 0.0000],
        [0.0000, 0.2306, 0.0000, 0.2003, 0.0000],
        [0.0000, 0.2163, 0.1073, 0.2230, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.2114, 0.2242],
        [0.1736, 0.0000, 0.1764, 0.0000, 0.0000]], grad_fn=<SliceBackward0>)
0.43160000443458557
tensor([[0.1530, 0.0000, 0.0798, 0.2337, 0.0000],
        [0.0000, 0.2830, 0.0000, 0.1233, 0.0000],
        [0.0000, 0.1092, 0.0545, 0.2674, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.1501, 0.2267],
        [0.1254, 0.0000, 0.0990, 0.0000, 0.0000]], grad_fn=<SliceBackward0>)
0.44760000705718994
tensor([[0.1465, 0.0000, 0.0689, 0.1785, 0.0000],
        [0.0000, 0.2245, 0.0000, 0.1769, 

0.44999998807907104
tensor([[0.1482, 0.0000, 0.0136, 0.1584, 0.0000],
        [0.0000, 0.1876, 0.0000, 0.2553, 0.0000],
        [0.0000, 0.2365, 0.1439, 0.2124, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.2104, 0.2390],
        [0.1956, 0.0000, 0.2272, 0.0000, 0.0000]], grad_fn=<SliceBackward0>)
0.44519999623298645
tensor([[0.1504, 0.0000, 0.0207, 0.1673, 0.0000],
        [0.0000, 0.1972, 0.0000, 0.2260, 0.0000],
        [0.0000, 0.2144, 0.1374, 0.2177, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.2047, 0.2362],
        [0.1822, 0.0000, 0.2034, 0.0000, 0.0000]], grad_fn=<SliceBackward0>)
0.43639999628067017
tensor([[0.1447, 0.0000, 0.0164, 0.1703, 0.0000],
        [0.0000, 0.2086, 0.0000, 0.2467, 0.0000],
        [0.0000, 0.2378, 0.1254, 0.2185, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.2070, 0.2350],
        [0.1823, 0.0000, 0.2162, 0.0000, 0.0000]], grad_fn=<SliceBackward0>)
0.44920000433921814
tensor([[0.1416, 0.0000, 0.0336, 0.1980, 0.0000],
        [0.0000, 0.2243, 0.0000, 0.2200

KeyboardInterrupt: 

In [19]:
print(weight_pred[2,:,:])

tensor([[0.1435, 0.0000, 0.0082, 0.1712, 0.0000],
        [0.0000, 0.2040, 0.0000, 0.2673, 0.0000],
        [0.0000, 0.2524, 0.1311, 0.2151, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.2081, 0.2429],
        [0.1842, 0.0000, 0.2307, 0.0000, 0.0000]], grad_fn=<SliceBackward0>)


In [20]:
print(path_batch[2,:,:])

tensor([[1., 1., 0., 0., 0.],
        [0., 1., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 1., 1.]])
