
Commit

Update seldonian.py
hannanabdul55 committed Mar 9, 2021
1 parent 0748828 commit 9b96709
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions seldonian/seldonian.py
@@ -87,7 +87,7 @@ def __init__(self, X, y, test_size=0.4, g_hats=[], verbose=False, stratify=False
self.y = self.y_t
self.X, self.X_s, self.y, self.y_s = train_test_split(
self.X, self.y, test_size=test_size,
random_state=count+1
random_state=count + 1
)
self.X = torch.as_tensor(self.X, dtype=torch.float, device=device)
self.y = torch.as_tensor(self.y, dtype=torch.long, device=device)
@@ -146,18 +146,20 @@ def fit(self, **kwargs):
# grad_check(self.mod.named_parameters())
self.optimizer.step()

self.optimizer.zero_grad()
if self.l_optimizer is not None:
self.l_optimizer.zero_grad()

if self.lagrange is not None:
loss_f = -1 * (self.loss_fn(self.mod(x), y) + (self.lagrange ** 2).dot(
self._safetyTest(predict=True)))
loss_f.backward(retain_graph=True)
# l_optimizer is a separate optimizer for the lagrangian.
if self.l_optimizer is not None:
self.l_optimizer.step()

# loss_f = -1 * (self.loss_fn(self.mod(x), y) + (self.lagrange ** 2).dot(
# self._safetyTest(predict=True)))
# loss_f.backward(retain_graph=True)
# # l_optimizer is a separate optimizer for the lagrangian.
# if self.l_optimizer is not None:
# self.l_optimizer.step()
with torch.no_grad():
self.lagrange = torch.sqrt((torch.abs(
self.loss_fn(self.mod(x), y) / self._safetyTest(predict=True))))
self.optimizer.zero_grad()
running_loss += loss.item()

if i % 10 == 9:  # print every 10 mini-batches
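As read from the hunk above, the commit drops the gradient-ascent update of the Lagrange multipliers (the negated loss fed to l_optimizer, now commented out) in favour of a direct closed-form assignment computed under torch.no_grad(), and moves optimizer.zero_grad() to after that assignment. Below is a minimal, self-contained sketch contrasting the two schemes; model, loss_fn, g_hat, x and y are hypothetical stand-ins for self.mod, self.loss_fn, _safetyTest(predict=True) and a mini-batch, not the repository's actual API.

import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 2)          # stand-in for self.mod
loss_fn = torch.nn.CrossEntropyLoss()  # stand-in for self.loss_fn
x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))

def g_hat(m):
    # hypothetical stand-in for self._safetyTest(predict=True):
    # a differentiable, 1-element estimate of the safety constraint
    return m(x).mean().reshape(1)

lagrange = torch.ones(1, requires_grad=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
l_optimizer = torch.optim.Adam([lagrange], lr=1e-2)

# Old scheme (the lines the commit comments out): gradient ascent on the
# multiplier via a second optimizer applied to the negated objective.
loss = loss_fn(model(x), y) + (lagrange ** 2).dot(g_hat(model))
loss.backward(retain_graph=True)
optimizer.step()
loss_f = -1 * (loss_fn(model(x), y) + (lagrange ** 2).dot(g_hat(model)))
loss_f.backward()
l_optimizer.step()
optimizer.zero_grad()

# New scheme in this commit: a closed-form multiplier computed without
# tracking gradients, with zero_grad() moved to after the update.
loss = loss_fn(model(x), y) + (lagrange ** 2).dot(g_hat(model))
loss.backward()
optimizer.step()
with torch.no_grad():
    lagrange = torch.sqrt(torch.abs(loss_fn(model(x), y) / g_hat(model)))
optimizer.zero_grad()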
