In [14]:
import os

import numpy as np
import torch

In [21]:
def extract_item_embeddings(model_path, output_file='vocab/Books_SASREC_item_emb.npy',
                            embedding_key='item_feat.emb_look_up.weight'):
    """
    Extract item embeddings from a trained model checkpoint and save as numpy array

    Args:
        model_path: Path to the model checkpoint file. It is read with
            ``torch.load`` and must yield a mapping that contains
            ``embedding_key``.
        output_file: Destination handed to ``np.save``. When it is a path,
            any missing parent directory is created first.
        embedding_key: Key of the item-embedding tensor inside the checkpoint.

    Returns:
        The embedding tensor converted to a numpy array (the same array that
        was written to ``output_file``).

    Raises:
        KeyError: If ``embedding_key`` is not present in the checkpoint.
    """
    # Load the checkpoint onto the CPU so no GPU is needed for the export.
    # NOTE(review): torch.load unpickles the file; only open checkpoints that
    # come from a trusted source.
    checkpoint = torch.load(model_path, map_location='cpu')

    if embedding_key not in checkpoint:
        # Same exception type as a plain lookup, but with enough context to
        # spot a renamed layer or a wrapped state dict.
        available = [str(key) for key in checkpoint][:20]
        raise KeyError(
            f"{embedding_key!r} not found in checkpoint {model_path!r}; "
            f"first available keys: {available}"
        )

    # detach() drops autograd tracking, because Tensor.numpy() refuses tensors
    # that require grad (e.g. a saved nn.Parameter). cpu() guards against a
    # tensor that is still on another device.
    item_embeddings = checkpoint[embedding_key].detach().cpu().numpy()

    # np.save does not create directories, so make sure the target one exists.
    # File-like objects are passed through to np.save untouched.
    if isinstance(output_file, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(output_file))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # Save embeddings
    np.save(output_file, item_embeddings)
    print(f"Embeddings saved to: {output_file}")

    return item_embeddings

# Example usage: export the item embeddings of the best SASRec checkpoint.
# Point `model_path` at your own checkpoint before running this cell.
model_path = 'SASRec_checkpoint/best.pth'
embeddings = extract_item_embeddings(model_path=model_path)
print(f"Embedding shape: {embeddings.shape}")


Embeddings saved to: vocab/Books_SASREC_item_emb.npy
Embedding shape: (367983, 64)
