# Practice Using torch.func with Logistic Regression
- Utilize sklearn's iris data with a 3-class Logistic Regression model (12 parameters).
- Compute gradients and Hessian matrix using one data point from test data.
- Compare the results with analytical ones.

Reference: https://pytorch.org/docs/stable/func.api.html

In [1]:
from math import sqrt

import matplotlib.pyplot as plt
plt.style.use("ggplot")

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad, hessian, jacfwd, vmap
from torch.utils.data import DataLoader, TensorDataset
torch.set_default_dtype(torch.float64)
torch.set_printoptions(sci_mode=False)

# Setup of data and model, followed by training
- A simple model with 4-dimensional input, 3-dimensional output, and 12 parameters (4*3).
- Training is not strictly necessary for this notebook: the gradient and Hessian comparisons below hold for any parameter values.

In [2]:
# load data
# Fetch the iris features (X) and integer class labels (y) as NumPy arrays.
X, y = load_iris(return_X_y=True)

# Hold out 20% of the samples as test data.
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Convert the NumPy arrays to tensors; labels are cast to int64 class indices.
x_train = torch.from_numpy(x_train)
x_test = torch.from_numpy(x_test)
y_train = torch.from_numpy(y_train).long()
y_test = torch.from_numpy(y_test).long()

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

torch.Size([120, 4]) torch.Size([120])
torch.Size([30, 4]) torch.Size([30])


In [3]:
class LogisticRegressionModel(nn.Module):
    """Multinomial logistic regression as a single bias-free linear layer.

    The forward pass returns raw class scores (logits); no softmax is
    applied inside the model.

    Args:
        in_features: Number of input features. Defaults to 4.
        num_classes: Number of output classes. Defaults to 3.
    """

    def __init__(self, in_features: int = 4, num_classes: int = 3) -> None:
        super().__init__()
        # No bias term, so the only parameters are the entries of the
        # (num_classes, in_features) weight matrix: 12 values by default.
        self.linear = nn.Linear(in_features, num_classes, bias=False)

    def forward(self, x):
        """Return the logits for ``x``, whose last dimension is ``in_features``."""
        return self.linear(x)

In [4]:
%%time
# data
# Wrap each split in a dataset and serve it one sample at a time,
# reshuffled on every pass over the loader.
train_dataset = TensorDataset(x_train, y_train)
test_dataset = TensorDataset(x_test, y_test)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=1)
test_loader = DataLoader(test_dataset, shuffle=True, batch_size=1)

# model training
model = LogisticRegressionModel()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

n_params = sum(p.numel() for p in model.parameters())
print(f"# of parameters: {n_params}")

# Plain SGD: one parameter update per training sample, repeated for 100 epochs.
num_epochs = 100
for epoch in range(num_epochs):
    for data, target in train_loader:
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()

# of parameters: 12
CPU times: user 1.38 s, sys: 58.5 ms, total: 1.44 s
Wall time: 1.44 s


# Data Point and Model Parameters Used for the Calculations
- Use one data point from the test data and a trained model to compute gradients and the Hessian matrix.
- Only one data point is used for simplicity.

In [5]:
# data point
# Take the first batch (a single sample, since batch_size=1) from the test loader.
test_batches = iter(test_loader)
data, target = next(test_batches)
print(data.shape, data)
print(target.shape, target)

# model point
# Name -> parameter mapping; this is the point at which the gradient and
# Hessian are evaluated below.
params = {name: param for name, param in model.named_parameters()}

torch.Size([1, 4]) tensor([[5.1000, 3.5000, 1.4000, 0.3000]])
torch.Size([1]) tensor([0])


# Gradient
- Comparison of the following three calculation methods:
  - torch.func
  - loss.backward
  - analytical solution

In [6]:
# torch.func
def compute_loss(model, params, data, target):
    """Return the cross-entropy loss of ``model`` evaluated with ``params``.

    ``functional_call`` runs the module's forward pass using the tensors in
    ``params`` in place of the module's own parameters, so ``torch.func``
    transforms can differentiate with respect to ``params`` (argument 1).
    """
    out = functional_call(model, params, (data,))
    return criterion(out, target)


# argnums=1 selects `params` as the argument to differentiate; the result is
# a dict with the same keys as `params`.
g_dict = grad(compute_loss, argnums=1)(model, params, data, target)

# Flatten each gradient and concatenate them into a single column vector.
# .detach() is used instead of the legacy .data attribute.
g1 = torch.cat([g.detach().flatten() for g in g_dict.values()]).unsqueeze(1)
print(g1.T)

tensor([[    -0.0335,     -0.0230,     -0.0092,     -0.0020,      0.0335,
              0.0230,      0.0092,      0.0020,      0.0000,      0.0000,
              0.0000,      0.0000]])


In [7]:
# loss.backward
# Clear the gradients left over from training, then backpropagate the loss
# of the single selected data point.
model.zero_grad()
criterion(model(data), target).backward()

# Collect the accumulated .grad of every parameter into one column vector.
# .detach() is used instead of the legacy .data attribute.
g2 = torch.cat([p.grad.detach().flatten() for p in model.parameters()]).unsqueeze(1)
print(g2.T)

tensor([[    -0.0335,     -0.0230,     -0.0092,     -0.0020,      0.0335,
              0.0230,      0.0092,      0.0020,      0.0000,      0.0000,
              0.0000,      0.0000]])


In [8]:
# analytical solution
# For softmax cross-entropy on logits W x, the gradient w.r.t. W is
#   (softmax(W x) - onehot(y)) x^T
out = model(data)
target_onehot = F.one_hot(target, num_classes=3)
residual = out.softmax(dim=1) - target_onehot

# (3, 1) @ (1, 4) -> (3, 4), flattened row by row to match the weight layout.
# .detach() is used instead of the legacy .data attribute.
g3 = (residual.T @ data).flatten().unsqueeze(1).detach()
print(g3.T)

tensor([[    -0.0335,     -0.0230,     -0.0092,     -0.0020,      0.0335,
              0.0230,      0.0092,      0.0020,      0.0000,      0.0000,
              0.0000,      0.0000]])


In [9]:
# Check that the torch.func gradient agrees with both the loss.backward
# gradient and the analytical gradient (torch.allclose default tolerances).
torch.allclose(g1, g2), torch.allclose(g1, g3)

(True, True)

### Note
If `dict(model.named_parameters())` has the form `{'name1': parameter_values, 'name2': parameter_values}`,  
then `torch.func.grad` returns a dict with the same keys, i.e. `{'name1': gradient_values, 'name2': gradient_values}`.

# Hessian
- Comparison of the following two calculation methods:
  - torch.func
  - analytical solution

In [10]:
# torch.func
def compute_loss(model, params, data, target):
    """Return the cross-entropy loss of ``model`` evaluated with ``params``.

    ``functional_call`` runs the module's forward pass using the tensors in
    ``params`` in place of the module's own parameters, so ``torch.func``
    transforms can differentiate with respect to ``params`` (argument 1).
    """
    out = functional_call(model, params, (data,))
    return criterion(out, target)


# The Hessian w.r.t. `params` comes back as a nested dict indexed by a pair
# of parameter names; the model has a single parameter, "linear.weight".
hess_dict = hessian(compute_loss, argnums=1)(model, params, data, target)

# Flatten the weight-vs-weight block into a 12 x 12 matrix.
# .detach() is used instead of the legacy .data attribute.
hess1 = hess_dict["linear.weight"]["linear.weight"].reshape(12, 12).detach()
print(hess1.size())

torch.Size([12, 12])


In [11]:
# analytical solution
# For softmax cross-entropy on logits W x, the Hessian block that couples
# weight rows i and j is
#   p_i * (delta_ij - p_j) * x x^T
# so the full Hessian is the Kronecker product of the class-by-class
# coefficient matrix with the outer product x x^T.
prob = model(data)[0].softmax(dim=0)
num_classes = prob.numel()

# coef[i, j] = p_i * (delta_ij - p_j)
identity = torch.eye(num_classes, dtype=prob.dtype, device=prob.device)
coef = prob.unsqueeze(1) * (identity - prob.unsqueeze(0))

# x x^T is the same for every block, so it is computed once.
outer = data.T @ data

# kron places coef[i, j] * outer in block (i, j) of the result.
hess2 = torch.kron(coef, outer)

print(hess2.size())

torch.Size([12, 12])


In [12]:
# Check that the torch.func Hessian agrees with the analytical Hessian
# (torch.allclose default tolerances).
torch.allclose(hess1, hess2)

True

### Note
If `dict(model.named_parameters())` has the form `{'name1': parameter_values, 'name2': parameter_values}`,  
then `torch.func.hessian` returns a nested dict keyed by pairs of parameter names, in the following format:
```
{
    'name1': {'name1': hessian_values, 'name2': hessian_values},
    'name2': {'name1': hessian_values, 'name2': hessian_values}
}
```