Skip to content

Commit

Permalink
Update test_rec_system.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dkorenkevych authored May 3, 2024
1 parent 84b8772 commit 48b1421
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions test/unit/test_tutorials/test_rec_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# pyre-strict


import os
import random
import unittest
from typing import List, Optional, Tuple
Expand Down Expand Up @@ -196,16 +196,23 @@ def setUp(self) -> None:
def test_rec_system(self) -> None:
# load environment
model = SequenceClassificationModel(100).to(device)
if os.path.exists("../Pearl"):
# Github CI
model_dir = "tutorials/single_item_recommender_system_example/"
else:
# internal Meta tests
model_dir = "pearl/tutorials/single_item_recommender_system_example/"

model.load_state_dict(
# Note: in the tutorial the directory "pearl" must be replaced by "Pearl"
torch.load(
"pearl/tutorials/single_item_recommender_system_example/env_model_state_dict.pt",
os.path.join(model_dir, "env_model_state_dict.pt"),
weights_only=True,
)
)
# Note: in the tutorial the directory "pearl" must be replaced by "Pearl"
actions = torch.load(
"pearl/tutorials/single_item_recommender_system_example/news_embedding_small.pt",
os.path.join(model_dir, "news_embedding_small.pt"),
weights_only=True,
)
history_length = 8
Expand Down

0 comments on commit 48b1421

Please sign in to comment.