Skip to content

Commit

Permalink
Add nn config
Browse files Browse the repository at this point in the history
  • Loading branch information
hannanabdul55 committed Mar 9, 2021
1 parent 1ef4db0 commit c3fcabd
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions seldonian/seldonian.py
Expand Up @@ -34,7 +34,7 @@ class VanillaNN(SeldonianAlgorithm):
"""

def __init__(self, X, y, test_size=0.4, g_hats=[], verbose=False, stratify=False, epochs=10,
model=None):
model=None, random_seed=0):
"""
Initialize a model with `g_hats` constraints. This class is an example of training a
non-linear model like a neural network based on the Seldonian Approach.
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(self, X, y, test_size=0.4, g_hats=[], verbose=False, stratify=False
# Stratify the sampling method for safety and candidate set using the `stratify` param.
if not stratify:
self.X, self.X_s, self.y, self.y_s = train_test_split(
self.X, self.y, test_size=test_size, random_state=0
self.X, self.y, test_size=test_size, random_state=random_seed
)
self.X = torch.as_tensor(self.X, dtype=torch.float, device=device)
self.y = torch.as_tensor(self.y, dtype=torch.long, device=device)
Expand All @@ -86,7 +86,8 @@ def __init__(self, X, y, test_size=0.4, g_hats=[], verbose=False, stratify=False
self.X = self.X_t
self.y = self.y_t
self.X, self.X_s, self.y, self.y_s = train_test_split(
self.X, self.y, test_size=test_size
self.X, self.y, test_size=test_size,
random_state=count+1
)
self.X = torch.as_tensor(self.X, dtype=torch.float, device=device)
self.y = torch.as_tensor(self.y, dtype=torch.long, device=device)
Expand Down

0 comments on commit c3fcabd

Please sign in to comment.