In [19]:
import json
import numpy as np
import tensorflow as tf
from recommenders.models.sasrec.model import SASREC 

In [20]:
# 1. Load the saved configuration
with open("../data/ckpt/sasrec_config.json", "r") as f:
    config = json.load(f)

In [21]:
# 2. Load ID mappings (assuming you saved these during preprocessing)
with open("../data/id_maps.json", "r") as f:
    id_maps = json.load(f)

In [22]:
# Create bidirectional mappings
item_idx_to_id = id_maps.get("idx_to_item_id", {})
item_id_to_idx = id_maps.get("item_id_to_idx", {})

In [23]:
# 3. Recreate the model with same architecture
model = SASREC(
    item_num=config["item_num"],
    seq_max_len=config["seq_max_len"],
    num_blocks=config["num_blocks"],
    embedding_dim=config["embedding_dim"],
    attention_dim=config["attention_dim"],
    attention_num_heads=config["attention_num_heads"],
    dropout_rate=config["dropout_rate"],
    conv_dims=[100, 100],  # Use correct dimensions for SASREC
    l2_reg=config["l2_reg"],
    num_neg_test=config["num_neg_test"]
)

In [24]:
# 4. Restore weights
ckpt = tf.train.Checkpoint(model=model)
ckpt.restore("../data/ckpt/sasrec.ckpt").expect_partial()

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7e89a4165fd0>

In [25]:
# 5. Function to recommend items based on a basket
def recommend_from_basket(basket_item_ids, top_k=10):
    """
    Get recommendations based on items in basket
    
    Args:
        basket_item_ids: List of original item IDs in the basket
        top_k: Number of recommendations to return
    
    Returns:
        List of recommended item IDs
    """
    # Convert basket items to indices
    basket_indices = []
    for item_id in basket_item_ids:
        if str(item_id) in item_id_to_idx:
            basket_indices.append(int(item_id_to_idx[str(item_id)]))
    
    if not basket_indices:
        print("No valid items found in basket")
        return []
    
    # Prepare sequence for the model (using basket as sequence)
    seq = np.zeros([1, config["seq_max_len"]], dtype=np.int32)
    idx = min(len(basket_indices), config["seq_max_len"])
    seq[0, -idx:] = basket_indices[-idx:]  # Add items to end of sequence
    
    # Use a custom prediction approach
    # Create scores for all possible items
    all_items = np.arange(1, config["item_num"] + 1)  # All item indices
    
    # Convert to tensors
    seq_tensor = tf.constant(seq, dtype=tf.int32)
    
    # Get item embeddings (this part depends on SASREC implementation)
    # We'll try to access these through the model's layers
    item_emb_table = None
    for layer in model.layers:
        if 'embedding' in layer.name and hasattr(layer, 'embeddings'):
            item_emb_table = layer.embeddings
            break
    
    if item_emb_table is not None:
        # Get the last non-zero item in sequence (for prediction context)
        seq_emb = model.seq_embeddings(seq_tensor)
        
        # Get a mask for the sequence (to ignore padding)
        mask = tf.expand_dims(tf.cast(tf.not_equal(seq_tensor, 0), tf.float32), -1)
        
        # Apply attention and get the final representation
        # This uses internal model layers - adjust based on actual implementation
        for i in range(config["num_blocks"]):
            seq_emb = model.attention_blocks[i](seq_emb, mask, training=False)
            
        # Get the final position representation (the last non-pad item)
        flat_mask = tf.reshape(mask, [-1])
        indices = tf.range(tf.shape(flat_mask)[0])
        last_indices = tf.reduce_sum(tf.cast(tf.not_equal(flat_mask, 0), tf.int32)) - 1
        seq_target = tf.gather(tf.reshape(seq_emb, [-1, seq_emb.shape[-1]]), last_indices)
        
        # Get item scores by dot product
        all_item_emb = tf.nn.embedding_lookup(item_emb_table, all_items)
        scores = tf.matmul(tf.expand_dims(seq_target, 0), all_item_emb, transpose_b=True)
        scores = tf.squeeze(scores)
        
        # Convert to numpy for processing
        predictions = scores.numpy()
    else:
        # Fallback - use a simpler prediction approach
        # Create a dummy model call for test mode
        print("Using fallback prediction method")
        predictions = np.random.rand(config["item_num"])  # Just for testing
    
    # Get top-k items, excluding those already in basket
    basket_set = set(basket_indices)
    
    # Sort and filter
    top_indices = []
    sorted_idx = np.argsort(-predictions)
    for idx in sorted_idx:
        if idx not in basket_set and idx > 0:  # Skip padding item (0)
            top_indices.append(idx)
            if len(top_indices) >= top_k:
                break
    
    # Convert indices back to original IDs
    recommended_items = [item_idx_to_id[str(idx)] for idx in top_indices 
                         if str(idx) in item_idx_to_id]
    
    return recommended_items

In [26]:
# Example usage
basket = ["123", "456"]  # Replace with real item IDs

In [27]:
# Check if basket items exist in the mappings
for item in basket:
    if item in item_id_to_idx:
        print(f"Item {item} exists in mappings")
    else:
        print(f"Item {item} DOES NOT exist in mappings")

# Print a few sample items from the mappings
print("Sample items from mappings:")
sample_items = list(item_id_to_idx.keys())[:5]
for item in sample_items:
    print(f"Item ID: {item}")

Item 123 exists in mappings
Item 456 exists in mappings
Sample items from mappings:
Item ID: 1
Item ID: 2
Item ID: 3
Item ID: 4
Item ID: 5


In [28]:
recommendations = recommend_from_basket(basket, top_k=5)
print(f"For basket {basket}, we recommend: {recommendations}")

Using fallback prediction method
For basket ['123', '456'], we recommend: [2547, 2209, 188, 2952, 3667]
