From 48b1421a927c3e2f1579c923fa946ab11deb85fa Mon Sep 17 00:00:00 2001 From: dkorenkevych Date: Fri, 3 May 2024 16:14:42 -0700 Subject: [PATCH] Update test_rec_system.py --- test/unit/test_tutorials/test_rec_system.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/test/unit/test_tutorials/test_rec_system.py b/test/unit/test_tutorials/test_rec_system.py index 89b0c69..26b1f1d 100644 --- a/test/unit/test_tutorials/test_rec_system.py +++ b/test/unit/test_tutorials/test_rec_system.py @@ -2,7 +2,7 @@ # pyre-strict - +import os import random import unittest from typing import List, Optional, Tuple @@ -196,16 +196,23 @@ def setUp(self) -> None: def test_rec_system(self) -> None: # load environment model = SequenceClassificationModel(100).to(device) + if os.path.exists("../Pearl"): + # Github CI + model_dir = "tutorials/single_item_recommender_system_example/" + else: + # internal Meta tests + model_dir = "pearl/tutorials/single_item_recommender_system_example/" + model.load_state_dict( # Note: in the tutorial the directory "pearl" must be replaced by "Pearl" torch.load( - "pearl/tutorials/single_item_recommender_system_example/env_model_state_dict.pt", + os.path.join(model_dir, "env_model_state_dict.pt"), weights_only=True, ) ) # Note: in the tutorial the directory "pearl" must be replaced by "Pearl" actions = torch.load( - "pearl/tutorials/single_item_recommender_system_example/news_embedding_small.pt", + os.path.join(model_dir, "news_embedding_small.pt"), weights_only=True, ) history_length = 8