.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_gallery_extract-saturation.py>`
        to download the full example code
.. rst-class:: sphx-glr-example-title
Extract layer saturation with Delve.
import torch
from tqdm import trange
from delve import SaturationTracker
class TwoLayerNet(torch.nn.Module):
    """Fully connected network with a single hidden layer.

    The input is transformed by ``linear1``, negative activations are
    clamped to zero, and the result is mapped to the output by ``linear2``.

    Args:
        D_in: Number of input features.
        H: Number of hidden units.
        D_out: Number of output features.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        # Layers are created in this order so that seeded weight
        # initialisation stays the same.
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Return the prediction of the network for the input batch ``x``."""
        hidden = self.linear1(x)
        # Clamping at zero acts as the rectifier between the two layers.
        rectified = torch.clamp(hidden, min=0)
        return self.linear2(rectified)
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG once, before any random data or weights are drawn.
torch.manual_seed(1)

# Run the same regression experiment for two hidden-layer widths.
for h in [3, 32]:
    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, h, 10

    # Create random Tensors to hold inputs and outputs
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)
    x_test = torch.randn(N, D_in)
    y_test = torch.randn(N, D_out)

    # Build the model and move it and all data to the selected device.
    model = TwoLayerNet(D_in, H, D_out)
    x, y, model = x.to(device), y.to(device), model.to(device)
    x_test, y_test = x_test.to(device), y_test.to(device)

    # You can watch specific layers by handing them to delve as a list.
    # Also, you can hand over the entire Module-object to delve and let delve
    # search for recordable layers.
    layers = [model.linear1, model.linear2]
    # NOTE(review): "lsat"/"lsat_eval" presumably select layer saturation for
    # training and evaluation passes - confirm against the delve documentation.
    stats = SaturationTracker('regression/h{}'.format(h),
                              save_to="plotcsv",
                              modules=layers,
                              device=device,
                              stats=["lsat", "lsat_eval"])

    # Summed (not averaged) squared error, optimised with SGD plus momentum.
    loss_fn = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

    steps_iter = trange(2000, desc='steps', leave=True, position=0)
    steps_iter.write("{:^80}".format(
        "Regression - TwoLayerNet - Hidden layer size {}".format(h)))
    for step in steps_iter:
        # training step
        model.train()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # test step
        model.eval()
        # No backward pass follows the evaluation, so autograd is disabled to
        # avoid building a graph that would never be used. The forward pass
        # and the resulting loss value are unchanged.
        with torch.no_grad():
            y_pred = model(x_test)
            loss_test = loss_fn(y_pred, y_test)

        # update statistics
        steps_iter.set_description('loss=%g' % loss.item())
        stats.add_scalar("train-loss", loss.item())
        stats.add_scalar("test-loss", loss_test.item())
        stats.add_saturations()

    # Finish this configuration: emit a blank line below the bar, then close
    # the tracker and the progress bar before the next hidden size starts.
    steps_iter.write('\n')
    stats.close()
    steps_iter.close()
.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.000 seconds)
.. only:: html

  .. container:: sphx-glr-footer
     :class: sphx-glr-footer-example

     .. container:: sphx-glr-download sphx-glr-download-python

        :download:`Download Python source code: extract-saturation.py <extract-saturation.py>`

     .. container:: sphx-glr-download sphx-glr-download-jupyter

        :download:`Download Jupyter notebook: extract-saturation.ipynb <extract-saturation.ipynb>`
.. only:: html

  .. rst-class:: sphx-glr-signature

     `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_