From d40b02f5e1cb8ca8c28c398cb0e26cba5cec3445 Mon Sep 17 00:00:00 2001 From: Sebastien Fischman Date: Sun, 14 Jun 2020 17:17:17 +0200 Subject: [PATCH] feat: allow weights sample for regression --- census_example.ipynb | 3 ++- pytorch_tabnet/tab_model.py | 20 ++++++++------ pytorch_tabnet/utils.py | 54 ++++++++++++++++++++++++------------- 3 files changed, 49 insertions(+), 28 deletions(-) diff --git a/census_example.ipynb b/census_example.ipynb index df7c218f..6bea6ff4 100755 --- a/census_example.ipynb +++ b/census_example.ipynb @@ -200,6 +200,7 @@ " max_epochs=max_epochs , patience=20,\n", " batch_size=1024, virtual_batch_size=128,\n", " num_workers=0,\n", + " weights=1,\n", " drop_last=False\n", ") " ] @@ -369,7 +370,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.5" + "version": "3.7.6" } }, "nbformat": 4, diff --git a/pytorch_tabnet/tab_model.py b/pytorch_tabnet/tab_model.py index c3b14822..cef61eb8 100755 --- a/pytorch_tabnet/tab_model.py +++ b/pytorch_tabnet/tab_model.py @@ -391,9 +391,9 @@ def infer_output_dim(self, y_train, y_valid): if y_valid is not None: valid_labels = unique_labels(y_train) if not set(valid_labels).issubset(set(train_labels)): - print(f"""Valid set -- {set(valid_labels)} -- - contains unkown targets from training -- {set(train_labels)}""") - raise + raise ValueError(f"""Valid set -- {set(valid_labels)} -- + contains unkown targets from training -- + {set(train_labels)}""") return output_dim, train_labels def weight_updater(self, weights): @@ -416,8 +416,7 @@ def weight_updater(self, weights): return {self.target_mapper[key]: value for key, value in weights.items()} else: - print("Unknown type for weights, please provide 0, 1 or dictionnary") - raise + return weights def construct_loaders(self, X_train, y_train, X_valid, y_valid, weights, batch_size, num_workers, drop_last): @@ -689,11 +688,17 @@ def construct_loaders(self, X_train, y_train, X_valid, y_valid, weights, Training and validation dataloaders ------- """ + if isinstance(weights, int): + if weights == 1: + raise ValueError("Please provide a list of weights for regression.") + if isinstance(weights, dict): + raise ValueError("Please provide a list of weights for regression.") + train_dataloader, valid_dataloader = create_dataloaders(X_train, y_train, X_valid, y_valid, - 0, + weights, batch_size, num_workers, drop_last) @@ -717,8 +722,7 @@ def update_fit_params(self, X_train, y_train, X_valid, y_valid, loss_fn, assert y_train.shape[1] == y_valid.shape[1], "Dimension mismatch y_train y_valid" self.output_dim = y_train.shape[1] - self.weights = 0 # No weights for regression - self.updated_weights = 0 + self.updated_weights = weights self.max_epochs = max_epochs self.patience = patience diff --git a/pytorch_tabnet/utils.py b/pytorch_tabnet/utils.py index b39dcbaa..330b0886 100644 --- a/pytorch_tabnet/utils.py +++ b/pytorch_tabnet/utils.py @@ -65,22 +65,24 @@ def create_dataloaders(X_train, y_train, X_valid, y_valid, weights, Validation data y_valid: np.array Mapped Validation targets - weights : dictionnary or bool - Weight for each mapped target class - 0 for no sampling - 1 for balanced sampling + weights : either 0, 1, dict or iterable + if 0 (default) : no weights will be applied + if 1 : classification only, will balanced class with inverse frequency + if dict : keys are corresponding class values are sample weights + if iterable : list or np array must be of length equal to nb elements + in the training set Returns ------- train_dataloader, valid_dataloader : torch.DataLoader, torch.DataLoader Training and validation dataloaders """ - if weights == 0: - train_dataloader = DataLoader(TorchDataset(X_train, y_train), - batch_size=batch_size, shuffle=True, - num_workers=num_workers, - drop_last=drop_last) - else: - if weights == 1: + + if isinstance(weights, int): + if weights == 0: + need_shuffle = True + sampler = None + elif weights == 1: + need_shuffle = False class_sample_count = np.array( [len(np.where(y_train == t)[0]) for t in np.unique(y_train)]) @@ -90,18 +92,32 @@ def create_dataloaders(X_train, y_train, X_valid, y_valid, weights, samples_weight = torch.from_numpy(samples_weight) samples_weight = samples_weight.double() + sampler = WeightedRandomSampler(samples_weight, len(samples_weight)) else: - # custom weights - samples_weight = np.array([weights[t] for t in y_train]) + raise ValueError('Weights should be either 0, 1, dictionnary or list.') + elif isinstance(weights, dict): + # custom weights per class + need_shuffle = False + samples_weight = np.array([weights[t] for t in y_train]) sampler = WeightedRandomSampler(samples_weight, len(samples_weight)) - train_dataloader = DataLoader(TorchDataset(X_train, y_train), - batch_size=batch_size, sampler=sampler, - num_workers=num_workers, - drop_last=drop_last - ) + else: + # custom weights + if len(weights) != len(y_train): + raise ValueError('Custom weights should match number of train samples.') + need_shuffle = False + samples_weight = np.array(weights) + sampler = WeightedRandomSampler(samples_weight, len(samples_weight)) + + train_dataloader = DataLoader(TorchDataset(X_train, y_train), + batch_size=batch_size, + sampler=sampler, + shuffle=need_shuffle, + num_workers=num_workers, + drop_last=drop_last) valid_dataloader = DataLoader(TorchDataset(X_valid, y_valid), - batch_size=batch_size, shuffle=False, + batch_size=batch_size, + shuffle=False, num_workers=num_workers) return train_dataloader, valid_dataloader