In [6]:
import json
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

# Single source of truth for the checkpoint so tokenizer and model can't drift apart.
MODEL_NAME = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████| 2/2 [00:14<00:00,  7.45s/it]


In [47]:
class TypeEmbedding(nn.Module):
    """Embed free-text object type names with a language model plus a learned offset.

    Each distinct type string is run through ``model`` once (under
    ``torch.no_grad()``). The hidden states of its real, non-padding tokens are
    mean-pooled into one vector of size ``model.config.hidden_size``. That vector
    is cached by string. On every forward pass a single shared learnable vector
    is added to the cached embedding.

    Args:
        model: A model *instance* (not a class) that exposes
            ``config.hidden_size`` and, when called with the tokenizer's output,
            returns an object with a ``last_hidden_state`` attribute.
        tokenizer: A tokenizer *instance* matching ``model``. It is called with
            ``return_tensors="pt", padding=True, truncation=True``.

    Note:
        ``model`` is stored as an attribute, so if it is an ``nn.Module`` its
        parameters appear in ``self.parameters()`` and ``state_dict()``.
    """

    def __init__(self, model=None, tokenizer=None):
        super().__init__()
        if model is None or tokenizer is None:
            raise ValueError(
                "TypeEmbedding requires a model instance and a tokenizer instance"
            )
        self.embedding_dim = model.config.hidden_size
        # Shared learnable offset added to every type embedding.
        self.learnable_params = nn.Parameter(torch.randn(self.embedding_dim))
        self.model = model
        self.tokenizer = tokenizer
        # type string -> pooled language-model embedding (computed without grad).
        self.cache = {}

    def forward(self, object_types: list[str]) -> torch.Tensor:
        """Return one embedding per entry of ``object_types``.

        Args:
            object_types: Type names; repeats are allowed and share one cached vector.

        Returns:
            Tensor of shape ``(len(object_types), embedding_dim)``.
        """
        if len(object_types) == 0:
            return self.learnable_params.new_empty((0, self.embedding_dim))
        # dict.fromkeys de-duplicates while keeping first-seen order, so each
        # unseen type is encoded exactly once even if it repeats in the input.
        new_types = list(
            dict.fromkeys(t for t in object_types if t not in self.cache)
        )
        if new_types:
            self._encode_and_cache(new_types)
        cached = torch.stack([self.cache[t] for t in object_types])
        # Cached vectors can sit on an older device (e.g. after ``.to()``);
        # align them with the learnable offset before adding.
        return cached.to(self.learnable_params.device) + self.learnable_params

    def _encode_and_cache(self, new_types: list[str]) -> None:
        """Run the language model on ``new_types`` and cache one pooled vector each."""
        inputs = self.tokenizer(
            new_types, return_tensors="pt", padding=True, truncation=True
        )
        # Put the token tensors on the same device as the model weights (if any).
        first_param = next(self.model.parameters(), None)
        if first_param is not None:
            inputs = {
                k: v.to(first_param.device) if torch.is_tensor(v) else v
                for k, v in inputs.items()
            }
        with torch.no_grad():
            hidden = self.model(**inputs).last_hidden_state
        mask = inputs.get("attention_mask")
        if mask is None:
            pooled = hidden.mean(dim=1)
        else:
            # Average over real tokens only. Including padding would make a
            # type's vector depend on the longest string in the same batch.
            weights = mask.unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)
        for object_type, vector in zip(new_types, pooled):
            self.cache[object_type] = vector


class InventoryEmbedding(nn.Module):
    """Embed a list of inventory items into a ``(num_items, hid_dim)`` tensor.

    Each item is a dict read through the keys ``"type"``, ``"quantity"`` and
    ``"slot"``:

    - The type string goes through ``TypeEmbedding``.
    - Quantity and slot are looked up in their own ``nn.Embedding`` tables.
    - The three vectors are concatenated and projected to ``hid_dim`` by a
      linear layer.

    Args:
        model: Language-model instance forwarded to ``TypeEmbedding``.
        tokenizer: Tokenizer instance forwarded to ``TypeEmbedding``.
        hid_dim: Width of the quantity/slot embeddings and of the output.
        max_quantity: Number of rows in the quantity table. Valid quantities
            are ``0 .. max_quantity - 1``, so a quantity equal to
            ``max_quantity`` is rejected.
        max_slot: Number of rows in the slot table. Valid slots are
            ``0 .. max_slot - 1``.
    """

    def __init__(
        self,
        model=None,
        tokenizer=None,
        hid_dim=128,
        max_quantity=64,
        max_slot=46,
    ):
        super().__init__()
        self.type_embedding = TypeEmbedding(model, tokenizer)
        self.quantity_embedding = nn.Embedding(max_quantity, hid_dim)
        self.slot_embedding = nn.Embedding(max_slot, hid_dim)
        self.combine = nn.Linear(
            self.type_embedding.embedding_dim + hid_dim + hid_dim,
            hid_dim,
        )

    @staticmethod
    def _check_range(name: str, values: torch.Tensor, size: int) -> None:
        """Raise ValueError if any index in ``values`` lies outside ``[0, size)``."""
        if values.numel() == 0:
            return
        low, high = int(values.min()), int(values.max())
        if low < 0 or high >= size:
            raise ValueError(
                f"{name} values must be in [0, {size - 1}], got range [{low}, {high}]"
            )

    def forward(self, inventory: list[dict]) -> torch.Tensor:
        """Return one ``hid_dim`` vector per item in ``inventory``.

        Raises:
            ValueError: if a quantity or slot is outside its embedding table.
        """
        # Build index tensors where the lookup tables live, so this also
        # works after the module is moved to another device.
        device = self.quantity_embedding.weight.device
        quantities = torch.tensor(
            [item["quantity"] for item in inventory], dtype=torch.long, device=device
        )
        slots = torch.tensor(
            [item["slot"] for item in inventory], dtype=torch.long, device=device
        )
        # Validate before the (expensive) language-model call and before
        # nn.Embedding, whose out-of-range error is cryptic (a device assert on CUDA).
        self._check_range("quantity", quantities, self.quantity_embedding.num_embeddings)
        self._check_range("slot", slots, self.slot_embedding.num_embeddings)

        type_embeddings = self.type_embedding([item["type"] for item in inventory])
        x_concat = torch.cat(
            [
                type_embeddings,
                self.quantity_embedding(quantities),
                self.slot_embedding(slots),
            ],
            dim=-1,
        )
        return self.combine(x_concat)


# Inventory encoder on top of the Gemma model/tokenizer loaded in the first cell;
# hid_dim, max_quantity and max_slot keep their defaults.
inv_embed = InventoryEmbedding(model=model, tokenizer=tokenizer)

In [48]:
import json

# JSON is UTF-8 by spec; pin the encoding so loading doesn't depend on the OS locale.
with open("data/train.json", encoding="utf-8") as f:
    data = json.load(f)
# Later cells index examples positionally (data[0]), so expect a JSON list.
assert isinstance(data, list), "expected data/train.json to hold a JSON list of examples"

In [55]:
# Sanity-check the encoder on the first training example's inventory.
first_inventory = data[0]["slotted_inventory"]
inv_embed(first_inventory)

tensor([[-0.2238, -1.8951, -1.0089,  ...,  0.3079, -0.0340, -1.9083],
        [-2.3568, -0.7631,  0.4788,  ..., -0.3396, -0.4099, -2.5646],
        [-2.6034, -0.2149,  1.0169,  ..., -0.7544, -0.3539, -2.3843],
        ...,
        [-2.2182, -0.1333, -0.3205,  ..., -0.8458,  0.0336, -2.8775],
        [-2.5166, -0.1252,  0.1649,  ..., -0.5924, -0.0082, -2.9452],
        [-2.2456, -0.0594,  0.3108,  ...,  0.1623,  0.0161, -1.8329]],
       grad_fn=<AddmmBackward0>)

In [None]:
# load data/train.json

# with open('data/train.json') as f:
#     data = json.load(f)