In [6]:
import json

import torch
import torch.nn as nn
import torchtune
from transformers import AutoTokenizer, AutoModel

# Configuration: the pretrained language model used to embed item-type strings.
MODEL_NAME = "google/gemma-2b"
SEED = 0

# Later cells initialise learnable parameters with torch.randn; seed so that
# Restart Kernel -> Run All reproduces the same initial values.
torch.manual_seed(SEED)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████| 2/2 [00:14<00:00,  7.45s/it]


In [25]:
class TypeEmbedding(nn.Module):
    """Embed object-type strings with a frozen language model plus a learnable offset.

    The input text is tokenized, run through ``model`` under ``torch.no_grad()``,
    and the token states of ``last_hidden_state`` are mean-pooled over the
    non-padding tokens (as marked by the tokenizer's ``attention_mask``). A
    learnable vector of size ``model.config.hidden_size`` is then added, so the
    embedding can be trained without back-propagating through the language model.

    Args:
        model: A loaded model instance (e.g. the result of
            ``AutoModel.from_pretrained``) exposing ``config.hidden_size`` and
            returning an object with ``last_hidden_state``.
        tokenizer: The matching tokenizer instance.

    Raises:
        ValueError: if ``model`` or ``tokenizer`` is not supplied.
    """

    def __init__(self, model=None, tokenizer=None):
        super().__init__()
        if model is None or tokenizer is None:
            # The previous defaults were the AutoModel/AutoTokenizer *classes*,
            # which have no ``config`` and cannot tokenize or encode text.
            raise ValueError("TypeEmbedding requires a loaded model and tokenizer instance")
        # NOTE: assigning the model registers it as a submodule, so its weights
        # appear in .parameters()/state_dict(); they receive no gradients because
        # forward runs the model under torch.no_grad().
        self.model = model
        self.tokenizer = tokenizer
        embedding_dim = model.config.hidden_size
        self.learnable_params = nn.Parameter(torch.randn(embedding_dim))

    def forward(self, object_type_str: str):
        """Return the embedding of one type string, or of a list of strings.

        Args:
            object_type_str: A type name, or a list of type names (batched).

        Returns:
            A ``[hidden_size]`` tensor for a single string (or one-element list),
            otherwise a ``[batch, hidden_size]`` tensor.
        """
        inputs = self.tokenizer(
            object_type_str, return_tensors="pt", padding=True, truncation=True
        )
        # Put token tensors on the same device as this module's parameters.
        device = self.learnable_params.device
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        hidden = outputs.last_hidden_state  # [batch, seq_len, hidden_size]
        mask = inputs.get("attention_mask")
        if mask is None:
            pooled = hidden.mean(dim=1)
        else:
            # Masked mean: padding tokens added for batching must not dilute the average.
            weights = mask.unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)
        pooled = pooled.squeeze(0)  # drop the batch dim for a single input
        return pooled + self.learnable_params



class InventoryEmbedding(nn.Module):
    """Embed every inventory item as separate (type, slot, quantity) vectors.

    Each item is a dict with keys ``"type"`` (a string), ``"quantity"`` and
    ``"slot"`` (integer-valued). The type string is encoded by TypeEmbedding;
    quantity and slot are encoded by rotating a constant all-ones vector with
    rotary positional embeddings (RoPE) at position ``quantity`` / ``slot``.

    Args:
        model: Loaded language model forwarded to TypeEmbedding.
        tokenizer: Matching tokenizer forwarded to TypeEmbedding.
        rope_embedding_dim: Size of the quantity and slot encodings (RoPE
            requires an even dimension).
        max_quantity: Largest quantity that can be encoded (inclusive).
        max_slot: Largest slot index that can be encoded (inclusive).
    """

    def __init__(
        self,
        model=None,
        tokenizer=None,
        rope_embedding_dim=128,
        max_quantity=64,
        max_slot=46,
    ):
        super().__init__()
        self.type_embedding = TypeEmbedding(model, tokenizer)
        self.rope_embedding_dim = rope_embedding_dim
        self.max_quantity = max_quantity
        self.max_slot = max_slot
        # The RoPE cache holds positions 0 .. max_seq_len - 1; the +1 makes the
        # maximum value itself (e.g. a full stack of 64) encodable.
        self.quantity_embedding = torchtune.modules.RotaryPositionalEmbeddings(
            rope_embedding_dim, max_seq_len=max_quantity + 1
        )
        self.slot_embedding = torchtune.modules.RotaryPositionalEmbeddings(
            rope_embedding_dim, max_seq_len=max_slot + 1
        )
        # Fuses [type | quantity | slot]. The type part has the language model's
        # hidden size, not rope_embedding_dim, so size the input from it.
        type_dim = self.type_embedding.learnable_params.numel()
        self.fc = nn.Linear(type_dim + 2 * rope_embedding_dim, rope_embedding_dim)

    def forward(self, inventory: list[dict]):
        """Encode all items of ``inventory``.

        Returns:
            ``(type_embedding, slot_embedding, quantity_embedding)`` with shapes
            ``[num_items, type_dim]``, ``[num_items, rope_embedding_dim]`` and
            ``[num_items, rope_embedding_dim]``. An empty inventory yields
            tensors with zero rows.

        Raises:
            ValueError: if a quantity or slot is outside its encodable range.
        """
        type_dim = self.type_embedding.learnable_params.numel()
        device = self.type_embedding.learnable_params.device
        if not inventory:
            return (
                torch.zeros(0, type_dim, device=device),
                torch.zeros(0, self.rope_embedding_dim, device=device),
                torch.zeros(0, self.rope_embedding_dim, device=device),
            )

        # Collect one embedding per item instead of overwriting on each iteration.
        type_embedding = torch.stack(
            [self.type_embedding(item["type"]) for item in inventory]
        )
        quantities = self._positions(inventory, "quantity", self.max_quantity, device)
        slots = self._positions(inventory, "slot", self.max_slot, device)
        quantity_embedding = self._rope_encode(self.quantity_embedding, quantities)
        slot_embedding = self._rope_encode(self.slot_embedding, slots)
        return type_embedding, slot_embedding, quantity_embedding

    def combine(self, type_embedding, slot_embedding, quantity_embedding):
        """Fuse forward()'s outputs into ``[num_items, rope_embedding_dim]`` via ``self.fc``."""
        fused = torch.cat([type_embedding, quantity_embedding, slot_embedding], dim=-1)
        return self.fc(fused)

    @staticmethod
    def _positions(inventory, key, max_value, device):
        """Gather ``item[key]`` as a LongTensor, validating 0 <= value <= max_value."""
        values = [int(item[key]) for item in inventory]
        out_of_range = [v for v in values if not 0 <= v <= max_value]
        if out_of_range:
            raise ValueError(f"{key} must be in [0, {max_value}], got {out_of_range}")
        return torch.tensor(values, dtype=torch.long, device=device)

    def _rope_encode(self, rope, positions):
        """Rotate an all-ones vector to each position; returns ``[num_items, rope_embedding_dim]``."""
        num_items = positions.numel()
        # torchtune RoPE takes x as [batch, seq, heads, head_dim] and input_pos as
        # [batch, seq]; each item is treated as its own length-1 sequence.
        ones = torch.ones(
            num_items, 1, 1, self.rope_embedding_dim, device=positions.device
        )
        rotated = rope(ones, input_pos=positions.view(num_items, 1))
        return rotated.reshape(num_items, self.rope_embedding_dim)


# Build the inventory encoder on top of the already-loaded language model and tokenizer.
inv_embed = InventoryEmbedding(model=model, tokenizer=tokenizer)

tensor([[[[ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
            1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
            1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
            1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
            1.0000,  1.0000,  1.0000,  1.0000]],

         [[-0.3012,  1.3818,  0.3128,  1.3792,  0.6394,  1.2614,  0.8073,
            1.1611,  0.8952,  1.0948,  0.9422,  1.0546,  0.9679,  1.0311,
            0.9821,  1.0176,  0.9900,  1.0099,  0.9944,  1.0056,  0.9968,
            1.0032,  0.9982,  1.0018,  0.9990,  1.0010,  0.9994,  1.0006,
            0.9997,  1.0003,  0.9998,  1.0002]],

         [[-1.3254,  0.4932, -0.4707,  1.3336,  0.2155,  1.3977,  0.5892,
            1.2856,  0.7814,  1.1787,  0.8815,  1.1059,  0.9348,  1.0612,
            0.9638,  1.0349,  0.9798,  1.0198,  0.9887,  1.0112,  0.9937,
            1.0063,  0.9964,  1.0036,  0.9980,  1.0020,  0.9989,  1.0011,
            

In [None]:
# TODO: load the training data from data/train.json.
# The loading code below is still commented out, so no later cell can rely on
# a `data` variable yet; enable it (and check the file exists) before use.

# with open('data/train.json') as f:
#     data = json.load(f)