Skip to content

Commit

Permalink
feat: allow weights sample for regression
Browse files Browse the repository at this point in the history
  • Loading branch information
Optimox authored and eduardocarvp committed Jun 18, 2020
1 parent 16d92d5 commit d40b02f
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 28 deletions.
3 changes: 2 additions & 1 deletion census_example.ipynb
Expand Up @@ -200,6 +200,7 @@
" max_epochs=max_epochs , patience=20,\n",
" batch_size=1024, virtual_batch_size=128,\n",
" num_workers=0,\n",
" weights=1,\n",
" drop_last=False\n",
") "
]
Expand Down Expand Up @@ -369,7 +370,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
"version": "3.7.6"
}
},
"nbformat": 4,
Expand Down
20 changes: 12 additions & 8 deletions pytorch_tabnet/tab_model.py
Expand Up @@ -391,9 +391,9 @@ def infer_output_dim(self, y_train, y_valid):
if y_valid is not None:
valid_labels = unique_labels(y_train)
if not set(valid_labels).issubset(set(train_labels)):
print(f"""Valid set -- {set(valid_labels)} --
contains unkown targets from training -- {set(train_labels)}""")
raise
raise ValueError(f"""Valid set -- {set(valid_labels)} --
contains unkown targets from training --
{set(train_labels)}""")
return output_dim, train_labels

def weight_updater(self, weights):
Expand All @@ -416,8 +416,7 @@ def weight_updater(self, weights):
return {self.target_mapper[key]: value
for key, value in weights.items()}
else:
print("Unknown type for weights, please provide 0, 1 or dictionnary")
raise
return weights

def construct_loaders(self, X_train, y_train, X_valid, y_valid,
weights, batch_size, num_workers, drop_last):
Expand Down Expand Up @@ -689,11 +688,17 @@ def construct_loaders(self, X_train, y_train, X_valid, y_valid, weights,
Training and validation dataloaders
-------
"""
if isinstance(weights, int):
if weights == 1:
raise ValueError("Please provide a list of weights for regression.")
if isinstance(weights, dict):
raise ValueError("Please provide a list of weights for regression.")

train_dataloader, valid_dataloader = create_dataloaders(X_train,
y_train,
X_valid,
y_valid,
0,
weights,
batch_size,
num_workers,
drop_last)
Expand All @@ -717,8 +722,7 @@ def update_fit_params(self, X_train, y_train, X_valid, y_valid, loss_fn,
assert y_train.shape[1] == y_valid.shape[1], "Dimension mismatch y_train y_valid"
self.output_dim = y_train.shape[1]

self.weights = 0 # No weights for regression
self.updated_weights = 0
self.updated_weights = weights

self.max_epochs = max_epochs
self.patience = patience
Expand Down
54 changes: 35 additions & 19 deletions pytorch_tabnet/utils.py
Expand Up @@ -65,22 +65,24 @@ def create_dataloaders(X_train, y_train, X_valid, y_valid, weights,
Validation data
y_valid: np.array
Mapped Validation targets
weights : dictionnary or bool
Weight for each mapped target class
0 for no sampling
1 for balanced sampling
weights : either 0, 1, dict or iterable
if 0 (default) : no weights will be applied
if 1 : classification only, will balance classes with inverse frequency
if dict : keys are classes, values are the sample weights for each class
if iterable : list or np array whose length must equal the number of elements
in the training set
Returns
-------
train_dataloader, valid_dataloader : torch.DataLoader, torch.DataLoader
Training and validation dataloaders
"""
if weights == 0:
train_dataloader = DataLoader(TorchDataset(X_train, y_train),
batch_size=batch_size, shuffle=True,
num_workers=num_workers,
drop_last=drop_last)
else:
if weights == 1:

if isinstance(weights, int):
if weights == 0:
need_shuffle = True
sampler = None
elif weights == 1:
need_shuffle = False
class_sample_count = np.array(
[len(np.where(y_train == t)[0]) for t in np.unique(y_train)])

Expand All @@ -90,18 +92,32 @@ def create_dataloaders(X_train, y_train, X_valid, y_valid, weights,

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
else:
# custom weights
samples_weight = np.array([weights[t] for t in y_train])
raise ValueError('Weights should be either 0, 1, dictionnary or list.')
elif isinstance(weights, dict):
# custom weights per class
need_shuffle = False
samples_weight = np.array([weights[t] for t in y_train])
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
train_dataloader = DataLoader(TorchDataset(X_train, y_train),
batch_size=batch_size, sampler=sampler,
num_workers=num_workers,
drop_last=drop_last
)
else:
# custom weights
if len(weights) != len(y_train):
raise ValueError('Custom weights should match number of train samples.')
need_shuffle = False
samples_weight = np.array(weights)
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

train_dataloader = DataLoader(TorchDataset(X_train, y_train),
batch_size=batch_size,
sampler=sampler,
shuffle=need_shuffle,
num_workers=num_workers,
drop_last=drop_last)

valid_dataloader = DataLoader(TorchDataset(X_valid, y_valid),
batch_size=batch_size, shuffle=False,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers)

return train_dataloader, valid_dataloader
Expand Down

0 comments on commit d40b02f

Please sign in to comment.