Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix split error #92

Merged
merged 1 commit into from
Apr 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions irec/environment/loader/full_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,12 @@ def _filter(self,
for filter_method, value in filters.items():
print(f"\t {filter_method}: {value}")
data_df = getattr(FilterRegistry.get(key), filter_method)(data_df, value)

return data_df.to_numpy()

def _split(self,
dataset: Dataset) -> Tuple[Dataset, Dataset]:
dataset: Dataset,
validation=False) -> Tuple[Dataset, Dataset]:
"""split

Splits the data set into training and testing
Expand All @@ -109,14 +110,26 @@ def _split(self,
# Apply it in the data
test_uids = split_strategy.get_test_uids(dataset.data, num_test_users)
train_dataset, test_dataset = split_strategy.split_dataset(dataset.data, test_uids)
train_dataset.update_num_total_users_items(
num_total_users=dataset.num_total_users,
num_total_items=dataset.num_total_items
)
test_dataset.update_num_total_users_items(
num_total_users=dataset.num_total_users,
num_total_items=dataset.num_total_items
)
if not validation:
train_dataset.update_num_total_users_items(
num_total_users=dataset.num_total_users,
num_total_items=dataset.num_total_items
)
test_dataset.update_num_total_users_items(
num_total_users=dataset.num_total_users,
num_total_items=dataset.num_total_items
)
else:
num_total_users = train_dataset.max_uid+1 if train_dataset.max_uid >= test_dataset.max_uid else test_dataset.max_uid+1
num_total_items = train_dataset.max_iid+1 if train_dataset.max_iid >= test_dataset.max_iid else test_dataset.max_iid+1
train_dataset.update_num_total_users_items(
num_total_users=num_total_users,
num_total_items=num_total_items
)
test_dataset.update_num_total_users_items(
num_total_users=num_total_users,
num_total_items=num_total_items
)
return train_dataset, test_dataset

def process(self) -> Tuple[Dataset, Dataset]:
Expand Down Expand Up @@ -157,6 +170,11 @@ def process(self) -> Tuple[Dataset, Dataset]:
x_validation, y_validation = None, None
if self.validation is not None:
print("\nGenerating x_validation and y_validation: ")
x_validation, y_validation = self._split(train_dataset)
x_validation, y_validation = self._split(train_dataset, validation=True)

print("train_dataset", train_dataset)
print("test_dataset", test_dataset)
print("x_validation", x_validation)
print("y_validation", y_validation)

return train_dataset, test_dataset, x_validation, y_validation
Loading