In [1]:
import json
import glob
import torch

import torchvision.transforms as transforms
from PIL import Image

from torch.utils.data import Dataset
from torch.optim import AdamW


class PlancraftEnvironmentDataset(Dataset):
    def __init__(self, dataset_dir: str = "data/oracle", split="train"):
        super().__init__()
        self.split = split
        self.transform = transforms.ToTensor()
        data = []
        for example_path in sorted(glob.glob(f"{dataset_dir}/{split}/oa/*.json")):
            with open(example_path) as f:
                messages = json.load(f)
                environments = []
                for message in messages:
                    if "inventory=" in message["content"] and message["role"] == "user":
                        environments.append(self.clean((message["content"].split("\ninventory=")[-1])))
                example = {
                    "environments": environments,
                    "example_id": example_path.split("/")[-1].split(".json")[0],
                }
                data.append(example)

        print("Loading images")
        for example in data:
            example["images"] = []
            for message_idx, _ in enumerate(example["environments"]):
                img_path = f"{dataset_dir}/{split}/imgs/{example['example_id']}_{message_idx}.png"
                img = Image.open(img_path).convert("RGB")
                example["images"].append(img)

        self.dataset = []
        for example in data:
            for i, (env, img) in enumerate(zip(example["environments"], example["images"])):
                self.dataset.append((env, img))

    def __len__(self) -> int:
        return len(self.dataset)
    
    @staticmethod
    def clean(s: str):
        return s.replace('"type": "', "").replace('"quantity": ', "").replace('"index": ', "").replace('"', "").replace("{", "").replace("}", "").replace(",", "").replace("[", "").replace("]", "").replace("_", " ")

    def __getitem__(self, idx: int) -> tuple:
        return self.dataset[idx]
    
# Load the dataset
dataset = PlancraftEnvironmentDataset(split="train")
val_dataset = PlancraftEnvironmentDataset(split="val")

Loading images
Loading images


In [2]:
from torch.utils.data import DataLoader

def collate_fn(batch):
    img_tensors = []
    texts = []
    for text, img in batch:
        img_tensors.append(img)
        texts.append(text)
    return {
        "images": img_tensors,
        "texts": texts,
    }
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn, pin_memory=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn, pin_memory=True)

In [3]:
import torch.nn.functional as F
from transformers import AutoModel, AutoProcessor
from tqdm import tqdm

# Load model and processor
model = AutoModel.from_pretrained("google/siglip-so400m-patch14-384", 
                                  attn_implementation="sdpa",
                                #   torch_dtype=torch.bfloat16,
                                  device_map="auto")
processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384")

# Define optimizer
optimizer = AdamW(model.parameters(), lr=1e-5)
model.train()
for epoch in range(100):
    optimizer.zero_grad()
    for batch in tqdm(dataloader, total=len(dataset)//32):
        texts = batch["texts"]
        images = batch["images"]
        inputs = processor(text=texts, images=images, return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        outputs = model(**inputs)
        image_preds = torch.sigmoid(outputs.logits_per_image)
        text_preds = torch.sigmoid(outputs.logits_per_text)
        eye = torch.eye(len(texts), len(images), device=model.device)
        image_loss = F.binary_cross_entropy(image_preds, eye)
        text_loss = F.binary_cross_entropy(text_preds, eye)
        loss = image_loss + text_loss
        loss.backward()
        optimizer.step()
        
    # val_loss = 0
    # for batch in val_dataloader:
    #     texts = batch["texts"]
    #     images = batch["images"]
    #     inputs = processor(text=texts, images=images, return_tensors="pt", padding=True, truncation=True)
    #     inputs = {k: v.to(model.device) for k, v in inputs.items()}
    #     outputs = model(**inputs)
    #     image_preds = torch.sigmoid(outputs.logits_per_image)
    #     text_preds = torch.sigmoid(outputs.logits_per_text)
    #     # logits = torch.mm(text_embeddings, image_embeddings.T)
    #     # labels = torch.arange(len(text_embeddings), device=model.device)
    #     # labels
    #     loss = F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)
    #     val_loss += loss.item()
    # print(f"Validation loss for epoch {epoch}: {val_loss}")

  from .autonotebook import tqdm as notebook_tqdm
 12%|█▏        | 41/331 [02:45<19:31,  4.04s/it]


KeyboardInterrupt: 

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]], device='cuda:0')

tensor([[1.0000e+00, 2.7906e-18, 4.5203e-13,  ..., 9.9979e-01, 3.3730e-04,
         1.5441e-21],
        [1.9761e-09, 1.0000e+00, 1.7174e-14,  ..., 9.1930e-08, 8.2393e-08,
         3.1202e-02],
        [5.1788e-12, 3.6618e-11, 1.0000e+00,  ..., 5.0566e-15, 7.0461e-01,
         2.3677e-15],
        ...,
        [3.7981e-01, 2.7335e-12, 1.9660e-20,  ..., 4.8340e-04, 4.3409e-05,
         2.9892e-14],
        [1.5553e-01, 3.3465e-18, 1.0885e-05,  ..., 9.9999e-01, 1.0000e+00,
         6.4046e-22],
        [3.3447e-07, 1.1841e-08, 3.6278e-08,  ..., 1.3880e-12, 7.3904e-13,
         1.0000e+00]], device='cuda:0', grad_fn=<SigmoidBackward0>)

In [5]:
# Training loop
num_epochs = 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}")
    
    for batch in progress_bar:
        texts, images = zip(*batch)
        inputs = processor(text=texts, images=images, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs.input_ids.to(device)
        pixel_values = inputs.pixel_values.to(device)
         
#         outputs = model(input_ids=input_ids, pixel_values=pixel_values)
#         text_embeddings = outputs.text_embeds
#         image_embeddings = outputs.image_embeds
        
#         # Assuming you want to calculate contrastive loss
#         logits = torch.mm(text_embeddings, image_embeddings.T)
#         labels = torch.arange(len(text_embeddings)).to(device)
#         loss = CrossEntropyLoss()(logits, labels) + CrossEntropyLoss()(logits.T, labels)
        
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
        
#         total_loss += loss.item()
#         progress_bar.set_postfix({"loss": total_loss / len(progress_bar)})
        
#     print(f"Epoch {epoch + 1} - Loss: {total_loss / len(dataloader)}")


NameError: name 'dataloader' is not defined

In [24]:
# inventory = [
#     {"slot": 13, "type": "stick", "quantity": 2},
#     {"slot": 20, "type": "acacia_log", "quantity": 1},
#     {"slot": 43, "type": "dead_fire_coral", "quantity": 55},
#     {"slot": 27, "type": "acacia_leaves", "quantity": 11},
#     {"slot": 28, "type": "brown_mushroom", "quantity": 23},
#     {"slot": 14, "type": "llama_spawn_egg", "quantity": 22},
#     {"slot": 45, "type": "bat_spawn_egg", "quantity": 6},
#     {"slot": 23, "type": "oak_leaves", "quantity": 8},
#     {"slot": 34, "type": "diorite_slab", "quantity": 38},
#     {"slot": 22, "type": "dark_prismarine_slab", "quantity": 54},
# ]

class TypeEmbedding(nn.Module):
    def __init__(self, model=AutoModel, tokenizer=AutoTokenizer):
        super(TypeEmbedding, self).__init__()
        self.embedding_dim = model.config.hidden_size
        self.learnable_params = nn.Parameter(torch.randn(self.embedding_dim))
        self.model = model
        self.tokenizer = tokenizer
        self.cache = {}

    def forward(self, object_types: list[str]):
        batch, new_types = ([], [])
        for object_type in object_types:
            if object_type not in self.cache:
                batch.append(object_type)
                new_types.append(object_type)
        if len(new_types) > 0:
            inputs = self.tokenizer(new_types, return_tensors="pt", padding=True)
            inputs = {k: v.cuda() for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
            for i, object_type in enumerate(new_types):
                type_embedding = outputs.last_hidden_state[i].mean(dim=0)
                self.cache[object_type] = type_embedding
        embeddings = [
            self.cache[object_type] + self.learnable_params
            for object_type in object_types
        ]
        return torch.stack(embeddings)

class InventoryEncoder(nn.Module):
    def __init__(
        self,
        model=AutoModel,
        tokenizer=AutoTokenizer,
        max_quantity=64,
        max_slot=46,
    ):
        super(InventoryEncoder, self).__init__()
        hidden_size = model.config.hidden_size
        self.type_embedding = TypeEmbedding(model, tokenizer)
        self.quantity_embedding = nn.Embedding(max_quantity, hidden_size)
        self.slot_embedding = nn.Embedding(max_slot, hidden_size)
        self.combine = nn.Linear(
            hidden_size * 3,
            hidden_size,
        )

    def forward(self, inventory: list[dict]):
        type_embeddings = self.type_embedding([item["type"] for item in inventory])
        quantities = torch.tensor(
            [item["quantity"] for item in inventory], dtype=torch.long
        )
        slots = torch.tensor([item["slot"] for item in inventory], dtype=torch.long)

        quantities = quantities.cuda()
        slots = slots.cuda()

        quantity_embeddings = self.quantity_embedding(quantities)
        slot_embeddings = self.slot_embedding(slots)
        x_concat = torch.cat(
            [type_embeddings, quantity_embeddings, slot_embeddings], dim=-1
        )
        embed = self.combine(x_concat).mean(dim=0)
        return embed


encoder = InventoryEncoder(model, tokenizer)
encoder = encoder.cuda()

In [31]:
class InventoryGenerator(nn.Module):
    def __init__(
        self,
        model=AutoModel,
        tokenizer=AutoTokenizer,
        max_quantity=64,
        max_slot=46,
    ):
        super(InventoryGenerator, self).__init__()
        hidden_size = model.config.hidden_size
        self.max_quantity = max_quantity
        self.max_slot = max_slot
        self.hidden_size = hidden_size

        self.fc = nn.Linear(hidden_size, hidden_size * 3)
        self.type_decoder = TypeEmbedding(model, tokenizer)
        self.quantity_decoder = nn.Linear(hidden_size, max_quantity)
        self.slot_decoder = nn.Linear(hidden_size, max_slot)

    def forward(self, inventory_embedding):
        x = self.fc(inventory_embedding)

        type_embeds, quantity_embeds, slot_embeds = torch.split(
            x, self.hidden_size, dim=-1
        )

        # Decode type embeddings
        decoded_types = self.type_decoder.decode(type_embeds)

        # Decode quantity and slot embeddings
        quantities = self.quantity_decoder(quantity_embeds)
        slots = self.slot_decoder(slot_embeds)

        # Convert logits to indices
        quantities = torch.argmax(quantities, dim=-1)
        slots = torch.argmax(slots, dim=-1)

        # Create the decoded inventory list
        decoded_inventory = []
        for obj_type, quantity, slot in zip(decoded_types, quantities, slots):
            decoded_inventory.append(
                {
                    "type": obj_type,
                    "quantity": quantity.item(),
                    "slot": slot.item(),
                }
            )

        return decoded_inventory


# Example of how to use the InventoryEmbedding and InventoryGenerator
class InventoryAutoencoder(nn.Module):
    def __init__(
        self,
        model=AutoModel,
        tokenizer=AutoTokenizer,
        max_quantity=64,
        max_slot=46,
    ):
        super(InventoryAutoencoder, self).__init__()
        self.encoder = InventoryEncoder(model, tokenizer, max_quantity, max_slot)
        self.decoder = InventoryGenerator(model, tokenizer, max_quantity, max_slot)

    def forward(self, inventory: list[dict]):
        encoded = self.encoder(inventory)
        decoded = self.decoder(encoded)
        return decoded


# Example usage
inventory = [
    {"slot": 13, "type": "stick", "quantity": 2},
    {"slot": 20, "type": "acacia_log", "quantity": 1},
    {"slot": 43, "type": "dead_fire_coral", "quantity": 55},
    {"slot": 27, "type": "acacia_leaves", "quantity": 11},
    {"slot": 28, "type": "brown_mushroom", "quantity": 23},
    {"slot": 14, "type": "llama_spawn_egg", "quantity": 22},
    {"slot": 45, "type": "bat_spawn_egg", "quantity": 6},
    {"slot": 23, "type": "oak_leaves", "quantity": 8},
    {"slot": 34, "type": "diorite_slab", "quantity": 38},
    {"slot": 22, "type": "dark_prismarine_slab", "quantity": 54},
]

autoencoder = InventoryAutoencoder(model, tokenizer)
autoencoder = autoencoder.to("cuda")
encoded_inventory = autoencoder.encoder(inventory)

In [None]:
autoencoder.decoder(encoded_inventory)


In [8]:
import json

with open("data/train.json") as f:
    data = json.load(f)

In [25]:
encoder(data[0]["slotted_inventory"])


tensor([ 0.5024,  0.7074,  1.2459,  ...,  0.2150, -0.3527, -0.0361],
       device='cuda:0', grad_fn=<MeanBackward1>)

In [None]:
# load data/train.json

# with open('data/train.json') as f:
#     data = json.load(f)

In [1]:
import openai
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated

from plancraft.environments.actions import (
    SymbolicMoveAction,
    SymbolicSmeltAction,
)

openai.pydantic_function_tool(SymbolicMoveAction)

  import distutils.spawn


{'type': 'function',
 'function': {'name': 'SymbolicMoveAction',
  'strict': True,
  'parameters': {'description': '"Moves an item from one slot to another',
   'properties': {'slot_from': {'exclusiveMaximum': 46,
     'minimum': 0,
     'title': 'Slot From',
     'type': 'integer'},
    'slot_to': {'exclusiveMaximum': 46,
     'minimum': 0,
     'title': 'Slot To',
     'type': 'integer'},
    'quantity': {'default': 1,
     'exclusiveMinimum': 0,
     'maximum': 64,
     'title': 'Quantity',
     'type': 'integer'}},
   'required': ['slot_from', 'slot_to', 'quantity'],
   'title': 'SymbolicMoveAction',
   'type': 'object',
   'additionalProperties': False},
  'description': ' "Moves an item from one slot to another'}}

In [1]:
from openai import OpenAI
from plancraft.models.act_tools import OpenAIToolsGenerator, ACT_TOOLS_SYSTEM_PROMPT, ACT_TOOLS_EXAMPLE


messages = [{"role": "system", "content": ACT_TOOLS_SYSTEM_PROMPT}] + ACT_TOOLS_EXAMPLE[:3]

model = OpenAIToolsGenerator()
client = OpenAI()

for m in messages:
    print(m)


  import distutils.spawn
  from .autonotebook import tqdm as notebook_tqdm


{'role': 'system', 'content': '\nYou are crafting in Minecraft. Actions are tools.\n\nThe first 10 slots in the inventory are reserved for crafting and correspond to the minecraft crafting table. \n\n[1, 2, 3] \n[4, 5, 6] -> [0]\n[7, 8, 9]\n\nThe crafting matrix is a 3x3 grid, and the output is sent to slot 0.\nYou cannot move or smelt items into output slot 0.\nThe remaining slots (10-45) are for storing items.\n'}
{'role': 'user', 'content': 'Craft an item of type: andesite\ninventory=\'[{"type": "diorite", "slot": 27, "quantity": 1},{"type": "cobblestone", "slot": 39, "quantity": 1}]\''}
{'role': 'assistant', 'content': 'SymbolicMoveAction(slot_from=27, slot_to=4, quantity=1)'}
{'role': 'user', 'content': 'Craft an item of type: andesite\ninventory=[{"type": "diorite", "slot": 4,  "quantity": 1},{"type": "cobblestone", "slot": 39, "quantity": 1}]'}


In [3]:
model.generate_next(batch_messages=[messages])

# response = client.chat.completions.create(
#         model="gpt-4o-mini",
#         messages=messages,
#         temperature=1.0,
#         max_tokens=256,
#         top_p=1,
#         frequency_penalty=0,
#         presence_penalty=0,
#         # stop=["\n"],
#         tools=[MOVE, SMELT],
#     )

([SymbolicMoveAction(slot_from=4, slot_to=1, quantity=1)],
 ['SymbolicMoveAction(slot_from=4, slot_to=1, quantity=1)'],
 465)

In [1]:
import glob

for x in glob.glob("configs/text-evals/*"):
    print(f"python main.py --config-name {x}")

python main.py --config-name configs/text-evals/act_eval_gpt4o_mini_few_shot.yaml
python main.py --config-name configs/text-evals/dummy.yaml
python main.py --config-name configs/text-evals/react_eval_llama70b_few_shot.yaml
python main.py --config-name configs/text-evals/act_eval_llama8b_lora.yaml
python main.py --config-name configs/text-evals/act_eval_llama8b_zero_shot.yaml
python main.py --config-name configs/text-evals/act_eval_llama8b_few_shot.yaml
python main.py --config-name configs/text-evals/react_eval_gpt4o_mini_few_shot.yaml
python main.py --config-name configs/text-evals/react_eval_llama8b_lora.yaml
python main.py --config-name configs/text-evals/react_eval_llama8b_few_shot.yaml
python main.py --config-name configs/text-evals/oracle.yaml
python main.py --config-name configs/text-evals/react_eval_llama8b_zero_shot.yaml
python main.py --config-name configs/text-evals/act_eval_llama70B_few_shot.yaml


In [51]:
import json 

def format_function_call(func_obj):
    # convert the OpenAI function object to a string
    # Function(arguments='{"from_slot": 4, "to_slot": 1, "quantity": 1}', name='move')
    # -> move(from_slot=4, to_slot=1, quantity=1)
    args = json.loads(func_obj.arguments).items()
    return f"{func_obj.name}({', '.join([f'{k}={v}' for k, v in args])})"








# format_function_call(response.choices[0].message.function[0])
# format_function_call(response.choices[0].message.tool_calls[0].function)
response.choices[0].message.tool_calls

[ChatCompletionMessageToolCall(id='call_9KPxJLWCnuKIAEPZO0ymE4GO', function=Function(arguments='{"from_slot": 4, "to_slot": 1, "quantity": 1}', name='move'), type='function'),
 ChatCompletionMessageToolCall(id='call_cTzmGRkiGne1gTd6SZFgr4Ho', function=Function(arguments='{"from_slot": 39, "to_slot": 2, "quantity": 1}', name='move'), type='function')]