Skip to content

Commit

Permalink
Merge pull request #92 from irec-org/fix-split
Browse files Browse the repository at this point in the history
fix split error
  • Loading branch information
thiagodks committed Apr 5, 2024
2 parents 468d56c + 8f066a7 commit 9851443
Showing 1 changed file with 29 additions and 11 deletions.
40 changes: 29 additions & 11 deletions irec/environment/loader/full_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,12 @@ def _filter(self,
for filter_method, value in filters.items():
print(f"\t {filter_method}: {value}")
data_df = getattr(FilterRegistry.get(key), filter_method)(data_df, value)

return data_df.to_numpy()

def _split(self,
- dataset: Dataset) -> Tuple[Dataset, Dataset]:
+ dataset: Dataset,
+ validation=False) -> Tuple[Dataset, Dataset]:
"""split
Splits the data set into training and testing
Expand All @@ -109,14 +110,26 @@ def _split(self,
# Apply it in the data
test_uids = split_strategy.get_test_uids(dataset.data, num_test_users)
train_dataset, test_dataset = split_strategy.split_dataset(dataset.data, test_uids)
- train_dataset.update_num_total_users_items(
- num_total_users=dataset.num_total_users,
- num_total_items=dataset.num_total_items
- )
- test_dataset.update_num_total_users_items(
- num_total_users=dataset.num_total_users,
- num_total_items=dataset.num_total_items
- )
+ if not validation:
+ train_dataset.update_num_total_users_items(
+ num_total_users=dataset.num_total_users,
+ num_total_items=dataset.num_total_items
+ )
+ test_dataset.update_num_total_users_items(
+ num_total_users=dataset.num_total_users,
+ num_total_items=dataset.num_total_items
+ )
+ else:
+ num_total_users = train_dataset.max_uid+1 if train_dataset.max_uid >= test_dataset.max_uid else test_dataset.max_uid+1
+ num_total_items = train_dataset.max_iid+1 if train_dataset.max_iid >= test_dataset.max_iid else test_dataset.max_iid+1
+ train_dataset.update_num_total_users_items(
+ num_total_users=num_total_users,
+ num_total_items=num_total_items
+ )
+ test_dataset.update_num_total_users_items(
+ num_total_users=num_total_users,
+ num_total_items=num_total_items
+ )
return train_dataset, test_dataset

def process(self) -> Tuple[Dataset, Dataset]:
Expand Down Expand Up @@ -157,6 +170,11 @@ def process(self) -> Tuple[Dataset, Dataset]:
x_validation, y_validation = None, None
if self.validation is not None:
print("\nGenerating x_validation and y_validation: ")
- x_validation, y_validation = self._split(train_dataset)
+ x_validation, y_validation = self._split(train_dataset, validation=True)
+
+ print("train_dataset", train_dataset)
+ print("test_dataset", test_dataset)
+ print("x_validation", x_validation)
+ print("y_validation", y_validation)

return train_dataset, test_dataset, x_validation, y_validation

0 comments on commit 9851443

Please sign in to comment.