# utilities.py
import os
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from scipy.sparse.linalg import LinearOperator, eigsh
from torch import Tensor
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.optim import SGD
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset

# The default "physical batch size": the largest batch size that we try to fit on the GPU at once.
DEFAULT_PHYS_BS = 1000


def get_gd_directory(dataset: str, lr: float, arch_id: str, seed: int, opt: str, loss: str,
                     beta: Optional[float] = None):
    """Return the directory in which the results should be saved."""
    results_dir = os.environ["RESULTS"]
    directory = f"{results_dir}/{dataset}/{arch_id}/seed_{seed}/{loss}/{opt}"
    if opt == "gd":
        return f"{directory}/lr_{lr}"
    elif opt in ("polyak", "nesterov"):
        return f"{directory}/lr_{lr}_beta_{beta}"
    raise NotImplementedError(f"no such optimizer: {opt}")


def get_flow_directory(dataset: str, arch_id: str, seed: int, loss: str, tick: float):
    """Return the directory in which the results should be saved."""
    results_dir = os.environ["RESULTS"]
    return f"{results_dir}/{dataset}/{arch_id}/seed_{seed}/{loss}/flow/tick_{tick}"


def get_modified_flow_directory(dataset: str, arch_id: str, seed: int, loss: str, gd_lr: float, tick: float):
    """Return the directory in which the results should be saved."""
    results_dir = os.environ["RESULTS"]
    return f"{results_dir}/{dataset}/{arch_id}/seed_{seed}/{loss}/modified_flow_lr_{gd_lr}/tick_{tick}"


def get_gd_optimizer(parameters, opt: str, lr: float, momentum: float) -> Optimizer:
    """Return the optimizer: plain gradient descent, or gradient descent with (Nesterov) momentum."""
    if opt == "gd":
        return SGD(parameters, lr=lr)
    elif opt == "polyak":
        return SGD(parameters, lr=lr, momentum=momentum, nesterov=False)
    elif opt == "nesterov":
        return SGD(parameters, lr=lr, momentum=momentum, nesterov=True)
    raise NotImplementedError(f"no such optimizer: {opt}")


def save_files(directory: str, arrays: List[Tuple[str, torch.Tensor]]):
    """Save a list of (name, tensor) pairs under the given directory."""
    for (arr_name, arr) in arrays:
        torch.save(arr, f"{directory}/{arr_name}")


def save_files_final(directory: str, arrays: List[Tuple[str, torch.Tensor]]):
    """Save a list of (name, tensor) pairs, with "_final" appended to each filename."""
    for (arr_name, arr) in arrays:
        torch.save(arr, f"{directory}/{arr_name}_final")


def iterate_dataset(dataset: Dataset, batch_size: int):
    """Iterate through a dataset, yielding batches of data on the GPU."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    for (batch_X, batch_y) in loader:
        yield batch_X.cuda(), batch_y.cuda()


def compute_losses(network: nn.Module, loss_functions: List[nn.Module], dataset: Dataset,
                   batch_size: int = DEFAULT_PHYS_BS):
    """Compute several losses over a dataset. Each loss module is assumed to use "sum" reduction,
    so dividing the per-batch sums by len(dataset) yields the mean over the dataset."""
    L = len(loss_functions)
    losses = [0. for _ in range(L)]
    with torch.no_grad():
        for (X, y) in iterate_dataset(dataset, batch_size):
            preds = network(X)
            for l, loss_fn in enumerate(loss_functions):
                losses[l] += loss_fn(preds, y) / len(dataset)
    return losses
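
# A minimal usage sketch (assuming a `network` and a `train_dataset` defined elsewhere):
#
#   loss_fn, acc_fn = get_loss_and_acc("ce")
#   train_loss, train_acc = compute_losses(network, [loss_fn, acc_fn], train_dataset)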


def get_loss_and_acc(loss: str):
    """Return modules that compute the loss and the accuracy. The loss module uses "sum" reduction."""
    if loss == "mse":
        return SquaredLoss(), SquaredAccuracy()
    elif loss == "ce":
        return nn.CrossEntropyLoss(reduction='sum'), AccuracyCE()
    raise NotImplementedError(f"no such loss function: {loss}")


def compute_hvp(network: nn.Module, loss_fn: nn.Module,
                dataset: Dataset, vector: Tensor, physical_batch_size: int = DEFAULT_PHYS_BS):
    """Compute a Hessian-vector product by double backpropagation: differentiate the
    inner product <grad(loss), vector> with respect to the parameters."""
    p = len(parameters_to_vector(network.parameters()))
    n = len(dataset)
    hvp = torch.zeros(p, dtype=torch.float, device='cuda')
    vector = vector.cuda()
    for (X, y) in iterate_dataset(dataset, physical_batch_size):
        loss = loss_fn(network(X), y) / n
        # first backward pass: the gradient, with the graph kept for a second pass
        grads = torch.autograd.grad(loss, inputs=network.parameters(), create_graph=True)
        # d/dtheta [<grad(loss), vector>] = Hessian(loss) @ vector
        dot = parameters_to_vector(grads).mul(vector).sum()
        grads = [g.contiguous() for g in torch.autograd.grad(dot, network.parameters(), retain_graph=True)]
        hvp += parameters_to_vector(grads)
    return hvp
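
# A minimal usage sketch (assuming `network`, `loss_fn`, and `dataset` defined elsewhere):
#
#   p = len(parameters_to_vector(network.parameters()))
#   v = torch.randn(p)                                # a random direction in parameter space
#   Hv = compute_hvp(network, loss_fn, dataset, v)    # the Hessian-vector product H @ v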


def lanczos(matrix_vector, dim: int, neigs: int):
    """Invoke the Lanczos algorithm to compute the leading eigenvalues and eigenvectors of a matrix /
    linear operator that is accessed only through matrix-vector products."""
    def mv(vec: np.ndarray):
        gpu_vec = torch.tensor(vec, dtype=torch.float).cuda()
        # matrix_vector is expected to return a CPU tensor (or anything convertible to a numpy array)
        return matrix_vector(gpu_vec)

    operator = LinearOperator((dim, dim), matvec=mv)
    evals, evecs = eigsh(operator, neigs)
    # eigsh returns eigenvalues in ascending order; flip so that the largest comes first
    return torch.from_numpy(np.ascontiguousarray(evals[::-1]).copy()).float(), \
        torch.from_numpy(np.ascontiguousarray(np.flip(evecs, -1)).copy()).float()


def get_hessian_eigenvalues(network: nn.Module, loss_fn: nn.Module, dataset: Dataset,
                            neigs=6, physical_batch_size=DEFAULT_PHYS_BS):
    """Compute the leading Hessian eigenvalues."""
    hvp_delta = lambda delta: compute_hvp(network, loss_fn, dataset,
                                          delta, physical_batch_size=physical_batch_size).detach().cpu()
    nparams = len(parameters_to_vector(network.parameters()))
    evals, _ = lanczos(hvp_delta, nparams, neigs=neigs)
    return evals
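
# A minimal usage sketch (assuming `network`, `loss_fn`, and `train_dataset` defined elsewhere):
#
#   evals = get_hessian_eigenvalues(network, loss_fn, train_dataset, neigs=2)
#   sharpness = evals[0]    # the largest Hessian eigenvalue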


def compute_gradient(network: nn.Module, loss_fn: nn.Module,
                     dataset: Dataset, physical_batch_size: int = DEFAULT_PHYS_BS):
    """Compute the gradient of the loss function at the current network parameters."""
    p = len(parameters_to_vector(network.parameters()))
    average_gradient = torch.zeros(p, device='cuda')
    for (X, y) in iterate_dataset(dataset, physical_batch_size):
        batch_loss = loss_fn(network(X), y) / len(dataset)
        batch_gradient = parameters_to_vector(torch.autograd.grad(batch_loss, inputs=network.parameters()))
        average_gradient += batch_gradient
    return average_gradient


class AtParams(object):
    """ Within a with block, install a new set of parameters into a network.

    Usage:

        # suppose the network has parameter vector old_params
        with AtParams(network, new_params):
            # now the network has parameter vector new_params
            do_stuff()
        # now the network once again has parameter vector old_params
    """

    def __init__(self, network: nn.Module, new_params: Tensor):
        self.network = network
        self.new_params = new_params

    def __enter__(self):
        self.stash = parameters_to_vector(self.network.parameters())
        vector_to_parameters(self.new_params, self.network.parameters())

    def __exit__(self, type, value, traceback):
        vector_to_parameters(self.stash, self.network.parameters())


def compute_gradient_at_theta(network: nn.Module, loss_fn: nn.Module, dataset: Dataset,
                              theta: torch.Tensor, batch_size=DEFAULT_PHYS_BS):
    """Compute the gradient of the loss function at arbitrary network parameters "theta"."""
    with AtParams(network, theta):
        return compute_gradient(network, loss_fn, dataset, physical_batch_size=batch_size)
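
# A minimal usage sketch (assuming `network`, `loss_fn`, and `dataset` defined elsewhere):
# evaluate the gradient at a point slightly displaced from the current parameters.
#
#   theta = parameters_to_vector(network.parameters())
#   g = compute_gradient_at_theta(network, loss_fn, dataset, theta + 0.01 * torch.randn_like(theta))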


class SquaredLoss(nn.Module):
    def forward(self, input: Tensor, target: Tensor):
        return 0.5 * ((input - target) ** 2).sum()


class SquaredAccuracy(nn.Module):
    def __init__(self):
        super(SquaredAccuracy, self).__init__()

    def forward(self, input, target):
        return (input.argmax(1) == target.argmax(1)).float().sum()


class AccuracyCE(nn.Module):
    def __init__(self):
        super(AccuracyCE, self).__init__()

    def forward(self, input, target):
        return (input.argmax(1) == target).float().sum()


class VoidLoss(nn.Module):
    def forward(self, X, Y):
        return 0
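
# A runnable sanity check for the loss/accuracy modules on toy tensors (a sketch added for
# illustration; no GPU or trained network required):
if __name__ == "__main__":
    logits = torch.tensor([[2.0, 0.0], [0.0, 1.0]])
    onehot = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
    labels = torch.tensor([0, 0])
    print(SquaredLoss()(logits, onehot))        # 0.5 * (1 + 0 + 1 + 1) = 1.5
    print(SquaredAccuracy()(logits, onehot))    # one of two argmaxes match -> 1.0
    print(AccuracyCE()(logits, labels))         # one of two predictions correct -> 1.0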