# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets


# Both splits use the same preprocessing: ToTensor converts each image to a
# float tensor with values in [0, 1].
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
])

trainset = datasets.MNIST('', train=True, download=True, transform=mnist_transform)
testset = datasets.MNIST('', train=False, download=True, transform=mnist_transform)

# Pinned (page-locked) host memory only helps host-to-GPU copies; on a CPU-only
# machine it is wasted work (and newer PyTorch versions may warn about it).
use_pin_memory = torch.cuda.is_available()

# Reshuffle the training data every epoch. The test set is only used for
# evaluation, where order does not matter, so keep it deterministic.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, pin_memory=use_pin_memory)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, pin_memory=use_pin_memory)

#-------------------------------------------------------------------------------------------
class Net(nn.Module):
    """Fully connected classifier: 784 -> 64 -> 64 -> 64 -> 10.

    Takes flattened 28x28 images of shape (batch, 28 * 28) and returns
    per-class log-probabilities of shape (batch, 10).
    """

    def __init__(self):
        super().__init__()
        hidden = 64
        # Layers are registered in the same order (fc1..fc4), so parameter
        # names and the random-initialisation order are unchanged.
        self.fc1 = nn.Linear(28 * 28, hidden)  # input: one flattened 28x28 image
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        self.fc4 = nn.Linear(hidden, 10)       # one output per class

    def forward(self, x):
        # Three affine (y = Wx + b) + ReLU hidden layers, then a linear head.
        for layer in (self.fc1, self.fc2, self.fc3):
            x = F.relu(layer(x))
        logits = self.fc4(x)
        # Log-probabilities over the class dimension, as expected by F.nll_loss.
        return F.log_softmax(logits, dim=1)

#-------------------------------------------------------------------------------------------
def hessian_vector_product(xs, ys, vs, create_graph=False):
    """
    Performs a matrix-vector-multiplication Hv, where H is the Hessian of sum(ys) w.r.t xs.

    Uses double backpropagation: Hv = d/dx (grad(sum(ys), x) . v), so the full
    Hessian is never materialised.

    :xs: iterable of independent input tensors (requires_grad=True); a
         generator such as ``net.parameters()`` is accepted
    :ys: list of scalar outputs that depend on xs
    :vs: iterable of vectors to multiply. Must match xs in length, and each v
         must have as many elements as the corresponding x
    :create_graph: if True, the returned products are themselves differentiable
         (only needed for higher-order derivatives)

    :return: The Hessian-vector product Hv as a tuple with one tensor per entry
             of xs, each shaped like that entry (empty tuple if xs is empty).
    :raises ValueError: if xs and vs have different lengths.
    """
    # Materialise once: the inputs are iterated more than once below, which
    # would silently exhaust a generator.
    xs = list(xs)
    vs = list(vs)
    if len(xs) != len(vs):
        # zip() would otherwise silently drop the unmatched tail.
        raise ValueError(f'xs and vs must have the same length, got {len(xs)} and {len(vs)}')
    if not xs:
        return ()

    # Clear stale .grad buffers on the inputs (kept for backward compatibility;
    # torch.autograd.grad itself neither reads nor accumulates into .grad).
    for x in xs:
        if x.grad is not None:
            x.grad.detach_()
            x.grad.zero_()

    # enable_grad makes this work even when called inside torch.no_grad().
    with torch.enable_grad():
        # First derivative, kept differentiable so it can be differentiated again.
        grads = torch.autograd.grad(ys, xs, create_graph=True)

        # Scalar sum over all tensors of g . v.
        prod = sum(torch.dot(g.flatten(), v.flatten()) for g, v in zip(grads, vs))

        # The gradient of g . v w.r.t. x is exactly H v.
        return torch.autograd.grad(prod, xs, create_graph=create_graph)
    
#-------------------------------------------------------------------------------------------
net = Net()  # initial network
optimizer = optim.Adam(net.parameters(), lr=0.001)  # create an Adam optimizer

net.train()  # set network to training mode
epochs = 2
for epoch in range(epochs):
    running_loss = 0.0
    num_batches = 0
    for X, y in trainloader:
        optimizer.zero_grad()  # clear the gradients computed in the previous step
        # Feed the mini-batch through the network; the result is per-class
        # log-probabilities (not labels).
        predicted = net(X.view(-1, 28 * 28))
        # Negative log-likelihood of the ground-truth labels under those log-probabilities.
        loss = F.nll_loss(predicted, y)
        loss.backward()   # compute the gradients
        optimizer.step()  # update the network parameters
        # .item() yields a plain float, so the accumulator holds no autograd graph.
        running_loss += loss.item()
        num_batches += 1
    # Report the mean loss over the epoch instead of the noisy last-batch value
    # (which, printed as a tensor, also showed its grad_fn).
    print(f'epoch:{epoch}, loss:{running_loss / max(num_batches, 1)}')
    
#-------------------------------------------------------------------------------------------
"""
"model.train()" and "model.eval()" switch Dropout and BatchNorm layers between training and inference behaviour, so calling them matters.
"with torch.no_grad()" only disables gradient calculation; it does not change how Dropout or BatchNorm behave.
A model that uses those layers will therefore usually score lower if you forget to call model.eval() before evaluating it.
(This Net has neither layer, so eval() changes nothing here, but calling it is still good practice.)
"""
net.eval()  # evaluation mode


def evaluate_accuracy(model, loader):
    """Count correct top-1 predictions of `model` over every batch in `loader`.

    :model: network mapping (batch, 28 * 28) inputs to per-class scores
    :loader: iterable of (images, labels) mini-batches
    :return: (correct, total) as Python ints
    """
    n_correct = 0
    n_total = 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for images, labels in loader:
            scores = model(images.view(-1, 28 * 28))
            # Count how many predictions in this batch are correct; .item() turns
            # the count tensor into a Python int so it can be added to a Python int.
            n_correct += (torch.argmax(scores, dim=1) == labels).sum().item()
            n_total += labels.size(0)  # add this batch's size
    return n_correct, n_total


# Evaluate on the training data
correct, total = evaluate_accuracy(net, trainloader)
print(f'Training data Accuracy: {correct}/{total} = {round(correct/total, 3)}')

# Evaluate on the testing data
correct, total = evaluate_accuracy(net, testloader)
print(f'testing data Accuracy: {correct}/{total} = {round(correct/total, 3)}')

# -------------------------------------------------------------------------------------
# Hessian-vector product of the loss on a single test image w.r.t. all network parameters.
X = testset[0][0]                  # first test image
y = torch.tensor([testset[0][1]])  # its label, wrapped as a batch of size 1

output = net(X.view(-1, 28 * 28))
loss = F.nll_loss(output, y)
params = list(net.parameters())
vs = [torch.rand_like(p) for p in params]  # random direction, one tensor per parameter
# Hv is only inspected here, so there is no need to keep a graph for
# differentiating it again.
Hv = hessian_vector_product(params, [loss], vs, create_graph=False)

print('type(Hv):', type(Hv))  # a tuple with one tensor per parameter
print('Hv:', Hv)