In [25]:
import torch, torch.nn as nn, torch.nn.functional as F
from torch.optim import SGD

In [13]:
# Seed the global torch RNG so `x`, `w` and the layer's initial weights are
# reproducible under Restart & Run All (they are all drawn randomly below).
torch.manual_seed(0)
x = torch.randn(10)                      # input vector fed through the layer
w = torch.randn(10, requires_grad=True)  # leaf tensor whose .grad is inspected later
l = nn.Linear(10, 10)                    # layer applied inside loss_fn(w, x)

In [14]:
def loss_fn(w, x):
  """Return the scalar dot product of `w` with the layer output `l(x)`.

  NOTE(review): this reads the module-level layer `l` from an earlier cell
  rather than taking it as an argument, and the name `loss_fn` is redefined
  by a later cell with signature `(layer, x)`, which silently shadows it.
  """
  projected = l(x)
  return torch.dot(w, projected)

In [15]:
# Forward pass with the `loss_fn(w, x)` defined above: dot product of w with l(x).
# The bare `loss` on the last line displays the scalar tensor.
loss = loss_fn(w, x)
loss

tensor(1.0051, grad_fn=<DotBackward0>)

In [16]:
# Backpropagate: fills .grad on w and on the parameters of layer l.
loss.backward()

In [17]:
# d(loss)/dw; since loss = dot(w, l(x)), this equals l(x).
w.grad

tensor([-1.3856, -1.1738, -0.4841, -0.3064, -1.6017, -0.2439,  1.1840, -0.7747,
        -0.3257, -0.6954])

In [22]:
# nn.Linear registers exactly two parameters, weight then bias; name them directly.
weights, biases = l.weight, l.bias

In [23]:
# d(loss)/d(weight) for loss = w · (W x + b): the 10x10 outer product of w and x.
weights.grad

tensor([[-0.6793, -4.2598, -0.3011, -1.8228, -2.3300, -2.2589,  0.6463, -0.8030,
          0.7916, -0.4164],
        [ 0.4247,  2.6630,  0.1882,  1.1395,  1.4566,  1.4122, -0.4040,  0.5020,
         -0.4949,  0.2603],
        [ 0.2541,  1.5932,  0.1126,  0.6818,  0.8714,  0.8449, -0.2417,  0.3003,
         -0.2961,  0.1557],
        [-0.2587, -1.6223, -0.1147, -0.6942, -0.8873, -0.8603,  0.2461, -0.3058,
          0.3015, -0.1586],
        [ 1.2957,  8.1247,  0.5742,  3.4766,  4.4440,  4.3085, -1.2326,  1.5315,
         -1.5098,  0.7941],
        [ 0.6593,  4.1339,  0.2922,  1.7689,  2.2611,  2.1922, -0.6272,  0.7792,
         -0.7682,  0.4041],
        [ 1.2081,  7.5755,  0.5354,  3.2416,  4.1436,  4.0172, -1.1493,  1.4280,
         -1.4077,  0.7405],
        [-0.1377, -0.8634, -0.0610, -0.3695, -0.4723, -0.4579,  0.1310, -0.1628,
          0.1604, -0.0844],
        [-0.3558, -2.2310, -0.1577, -0.9547, -1.2203, -1.1831,  0.3385, -0.4206,
          0.4146, -0.2181],
        [ 0.5040,  

In [27]:
def loss_fn(layer, x, target_value=1.0):
  """Mean-squared error between `layer(x)` and a constant target.

  NOTE(review): this redefines the earlier `loss_fn(w, x)` and shadows it.

  Args:
    layer: any callable mapping `x` to a tensor (e.g. an `nn.Module`).
    x: input tensor passed to `layer`.
    target_value: constant that every output element is regressed towards;
      defaults to 1.0, matching the original all-ones target.

  Returns:
    Scalar tensor holding the mean-reduced MSE.
  """
  prediction = layer(x)
  # Build the target from the prediction's own shape/dtype/device instead of a
  # hard-coded `torch.ones(3)`: works for any output width and avoids
  # mse_loss's size-mismatch broadcasting.
  target = torch.full_like(prediction, target_value)
  return F.mse_loss(prediction, target)

In [39]:
model = nn.Linear(5, 3)
optim = SGD(model.parameters(), 1e-3)
for epoch in range(1):
  gradient_cache = []
  for instance in range(2): # households, cars, ...
    optim.zero_grad()
    x_batch = torch.tensor([1.,.3,.4,.3,.3])
    loss = loss_fn(model, x_batch)
    loss.backward()

    # Local (per-instance) gradients. Clone them: when grads are zeroed in place
    # (zero_grad(set_to_none=False), the default before PyTorch 2.0) or
    # accumulated without zeroing, bare `p.grad` references would all alias
    # the same tensors and every cache entry would show the last gradient.
    gradient_cache.append([p.grad.clone() for p in model.parameters()])
  for instance_grad in gradient_cache:
    print(instance_grad)
    print()
  # Parameter-wise mean over instances. zip(*) regroups the cache from
  # [instance][param] to [param][instance] so each group can be stacked;
  # torch.stack cannot be applied to the nested list directly.
  mean_gradient = [torch.stack(grads).mean(dim=0) for grads in zip(*gradient_cache)]
  # NOTE(review): the optimizer step is still not applied in this cell. To apply
  # it, set `p.grad = g` for each (p, g) in zip(model.parameters(), mean_gradient)
  # and then call optim.step().

[tensor([[-0.7399, -0.2220, -0.2959, -0.2220, -0.2220],
        [-0.7462, -0.2239, -0.2985, -0.2239, -0.2239],
        [-0.4514, -0.1354, -0.1806, -0.1354, -0.1354]]), tensor([-0.7399, -0.7462, -0.4514])]

[tensor([[-0.7399, -0.2220, -0.2959, -0.2220, -0.2220],
        [-0.7462, -0.2239, -0.2985, -0.2239, -0.2239],
        [-0.4514, -0.1354, -0.1806, -0.1354, -0.1354]]), tensor([-0.7399, -0.7462, -0.4514])]



In [51]:
# Reuses the `model` from the previous cell, so training continues from its
# current weights (construct a fresh `nn.Linear(5, 3)` here to restart).
optim = SGD(model.parameters(), 1e-1)
instances = 2
x_batch = torch.tensor([1.,.3,.4,.3,.3])  # identical input for every instance
for epoch in range(100):
  optim.zero_grad()
  # Each backward() adds into .grad, so after this loop the gradients hold
  # the sum over all instances.
  for instance in range(instances): # households, cars, ...
    loss = loss_fn(model, x_batch)
    print("instance", instance, "loss", loss.item())
    loss.backward()
  # Divide the accumulated sum by the instance count to get the mean gradient.
  params = list(model.parameters())
  for param in params:
    param.grad = param.grad / instances
  print(epoch)
  print([param.grad for param in params])
  print()
  optim.step()

instance 0 loss 6.872803623991786e-07
instance 1 loss 6.872803623991786e-07
0
[tensor([[-0.0006, -0.0002, -0.0002, -0.0002, -0.0002],
        [-0.0006, -0.0002, -0.0002, -0.0002, -0.0002],
        [-0.0004, -0.0001, -0.0002, -0.0001, -0.0001]]), tensor([-0.0006, -0.0006, -0.0004])]

instance 0 loss 4.826600275009696e-07
instance 1 loss 4.826600275009696e-07
1
[tensor([[-5.1896e-04, -1.5569e-04, -2.0758e-04, -1.5569e-04, -1.5569e-04],
        [-5.2341e-04, -1.5702e-04, -2.0936e-04, -1.5702e-04, -1.5702e-04],
        [-3.1666e-04, -9.4998e-05, -1.2666e-04, -9.4998e-05, -9.4998e-05]]), tensor([-0.0005, -0.0005, -0.0003])]

instance 0 loss 3.3893391560013697e-07
instance 1 loss 3.3893391560013697e-07
2
[tensor([[-4.3488e-04, -1.3046e-04, -1.7395e-04, -1.3046e-04, -1.3046e-04],
        [-4.3861e-04, -1.3158e-04, -1.7544e-04, -1.3158e-04, -1.3158e-04],
        [-2.6536e-04, -7.9608e-05, -1.0614e-04, -7.9608e-05, -7.9608e-05]]), tensor([-0.0004, -0.0004, -0.0003])]

instance 0 loss 2.37983613

In [52]:
# Evaluate the trained model on the training input; the loss drives every
# output element towards 1.0.
probe = torch.tensor([1.0, 0.3, 0.4, 0.3, 0.3])
model(probe)

tensor([1.0000, 1.0000, 1.0000], grad_fn=<ViewBackward0>)