## 9/1
- DiffPool

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import DenseSAGEConv, GCNConv, GATConv, TransformerConv, SAGEConv, GINConv
from torch_geometric.utils import to_dense_adj, to_dense_batch
from torch_geometric.nn.dense import dense_diff_pool

def build_conv(conv_type: str):
    """Return the constructor of the GNN layer selected by ``conv_type``.

    Args:
        conv_type: Name of the convolution. Supported values are "GCN",
            "GIN", "GAT", "TransformerConv", "SAGE" and "DenseSAGE".

    Returns:
        A callable invoked as ``conv(in_dim, out_dim)`` that builds one layer.
        For "GIN" this is a small factory wrapping a two-layer MLP in
        ``GINConv``; for every other type it is the layer class itself.

    Raises:
        KeyError: If ``conv_type`` is not one of the supported names.
    """
    if conv_type == "GCN":
        return GCNConv
    elif conv_type == "GIN":
        # GINConv takes a network instead of dimensions, so adapt it to the
        # same (in_dim, out_dim) call shape the other layers expose.
        return lambda i, h: GINConv(
            nn.Sequential(nn.Linear(i, h), nn.ReLU(), nn.Linear(h, h))
        )
    elif conv_type == "GAT":
        return GATConv
    elif conv_type == "TransformerConv":
        return TransformerConv
    elif conv_type == "SAGE":
        return SAGEConv
    elif conv_type == "DenseSAGE":
        return DenseSAGEConv
    else:
        # List every accepted name (DenseSAGE used to be missing) and echo the
        # offending value so a misconfigured run is easy to diagnose.
        raise KeyError(
            "GNN_TYPE can only be GAT, GCN, SAGE, GIN, TransformerConv, and "
            f"DenseSAGE (got {conv_type!r})"
        )


class GNNEncoder(nn.Module):
    """
    A Graph Neural Network encoder built from a stack of graph convolutions.

    The layer type is chosen by name through ``build_conv`` (GCN, GAT, SAGE,
    ...). Every layer except the last is followed by batch normalisation,
    LeakyReLU and dropout; the last layer's output is returned as-is.

    Args:
        input_dim: Size of the input node features.
        hidden_dim: Size of the intermediate node features.
        output_dim: Size of the returned node embeddings.
        n_layers: Number of convolution layers; must be at least 1.
        gnn_type: Name understood by ``build_conv``. The default is "SAGE";
            the previous default ("DenseSAGEConv") is not a name
            ``build_conv`` accepts, so constructing with defaults always
            raised ``KeyError``.
        dropout: Dropout probability applied after each hidden layer.

    Raises:
        ValueError: If ``n_layers`` is smaller than 1.
        KeyError: If ``gnn_type`` is not supported by ``build_conv``.
    """
    def __init__(self, input_dim, hidden_dim, output_dim, n_layers=1, gnn_type="SAGE", dropout=0.0):
        super().__init__()

        # n_layers <= 0 used to fall into the multi-layer branch and build two
        # convolutions without any batch norm, which then failed in forward().
        if n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")

        conv = build_conv(gnn_type)

        self.gnn_type = gnn_type
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.dropout = dropout
        self.act = nn.LeakyReLU()

        # Feature sizes along the stack: input -> hidden ... hidden -> output.
        dims = [input_dim] + [hidden_dim] * (n_layers - 1) + [output_dim]
        self.conv_layers = nn.ModuleList(
            [conv(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        )
        # One batch norm per hidden layer; none when there is a single layer.
        self.bns = nn.ModuleList(
            [nn.BatchNorm1d(hidden_dim) for _ in range(n_layers - 1)]
        )

    def reset_parameters(self):
        """Re-initialise every convolution and batch-norm layer."""
        for conv in self.conv_layers:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, edge_index):
        """Encode nodes.

        Args:
            x: Node features passed to the first convolution.
            edge_index: Graph connectivity forwarded unchanged to every layer.

        Returns:
            The output of the last convolution (no norm/activation applied).
        """
        for graph_conv, bn in zip(self.conv_layers[:-1], self.bns):
            x = graph_conv(x, edge_index)
            x = bn(x)
            x = self.act(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        node_emb = self.conv_layers[-1](x, edge_index)
        return node_emb


class DiffPoolGNNEncoder(nn.Module):
    """
    A Differentiable Pooling (DiffPool) layer that uses two GNNEncoders: one
    to learn node embeddings and one to learn cluster-assignment scores.
    This module performs one level of pooling.

    It takes a sparse graph (or a batch of graphs) and returns the pooled,
    dense node features. The link-prediction and entropy losses computed by
    the DiffPool operation are kept on the module as ``link_loss`` and
    ``entropy_loss`` (updated on every forward pass) so a training loop can
    add them to its objective.
    """
    def __init__(self, args):
        """
        Args:
            args: Configuration object providing these attributes:
                gnn_in_dim (int): Dimensionality of the input node features.
                gnn_hidden_dim (int): Dimensionality of the hidden GNN layers.
                gnn_output_dim (int): Dimensionality of the pooled features.
                num_clusters (int): Number of clusters nodes are pooled into.
                n_layers (int): Number of layers in each internal GNNEncoder.
                gnn_type (str): GNN convolution to use (e.g., "GAT", "GCN").
                dropout (float): Dropout probability.
        """
        super().__init__()

        # Embedding branch: node features -> gnn_output_dim.
        self.gnn_embed = GNNEncoder(args.gnn_in_dim, args.gnn_hidden_dim, args.gnn_output_dim, args.n_layers, args.gnn_type, args.dropout)

        # Assignment branch: node features -> one score per cluster.
        self.gnn_pool = GNNEncoder(args.gnn_in_dim, args.gnn_hidden_dim, args.num_clusters, args.n_layers, args.gnn_type, args.dropout)

        # Auxiliary losses of the most recent forward pass (None before it).
        self.link_loss = None
        self.entropy_loss = None

    def forward(self, x, edge_index, batch=None):
        """Pool the graph(s) into ``num_clusters`` super-nodes.

        Args:
            x: Node features of all graphs, stacked along the first dimension.
            edge_index: Sparse connectivity of the (batched) graph.
            batch: Graph id of every node; when omitted all nodes are treated
                as belonging to a single graph.

        Returns:
            The pooled node features produced by ``dense_diff_pool``.
        """
        if batch is None:
            batch = x.new_zeros(x.size(0), dtype=torch.long)

        x_embed = self.gnn_embed(x, edge_index)
        # Raw assignment scores. No softmax here: dense_diff_pool applies the
        # softmax itself, so normalising beforehand would squash the
        # assignments twice.
        s = self.gnn_pool(x, edge_index)

        # Convert the sparse batch into padded dense tensors plus a node mask.
        x_embed_dense, mask = to_dense_batch(x_embed, batch)
        s_dense, _ = to_dense_batch(s, batch)
        adj_dense = to_dense_adj(edge_index, batch)

        # Apply the Differentiable Pooling operation.
        x_pooled, adj_pooled, link_loss, entropy_loss = dense_diff_pool(
            x=x_embed_dense,
            adj=adj_dense,
            s=s_dense,
            mask=mask
        )

        # Keep the auxiliary losses reachable without changing the return value.
        self.link_loss = link_loss
        self.entropy_loss = entropy_loss

        return x_pooled


In [15]:
# Build the DiffPool encoder from the shared `args` namespace. The constructor
# reads `args.num_clusters`, so `args` must define it before this cell runs.
graph_tokenize = DiffPoolGNNEncoder(args)

AttributeError: 'Args' object has no attribute 'num_clusters'

## 9/9 
- centrality 추가 (add degree-centrality encodings for the task-graph nodes)

In [2]:
from torch_geometric.utils import degree
import torch.nn as nn 
import torch
import numpy as np
from torch_geometric.data import Data
from gnn import DiffPoolGNNEncoder

import sys
sys.path.append("../")
from utils import init_random_state, load_tool, get_cur_time

In [None]:

class Args:
    """Minimal attribute namespace: each keyword argument becomes an attribute."""

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

# Experiment configuration shared by the cells below.
args = Args(
    dataset="huggingface",
    llm="Mistral-7B",
    seed=0,
    device="cuda:0",
    max_txt_length=512,
    max_ans_length=256,
    gnn_in_dim=1024,
    gnn_hidden_dim=1024,
    # NOTE(review): 2560 is the value listed for gemma-3-4b-it, while `llm`
    # above is "Mistral-7B" (listed as 4096) -- confirm which one is intended.
    gnn_output_dim=2560, # per-LLM values -- mistral-7b: 4096, codellama-13b: 5120, gpt-oss-20b: 2880, gemma-3-4b-it: 2560
    n_layers=2,
    gnn_type="SAGE",
    num_epochs=4,
    batch_size=6,
    eval_batch_size=6,
    patience=2,
    lr=1e-5,
    wd=0.05,
    dropout=0.0,
    num_clusters=10,  # read by DiffPoolGNNEncoder as the number of pooled clusters
    output_dir="output",
    grad_steps=4
)
# Device used for the graph tensors in this section (separate from args.device).
device="cuda"

# Load the tool graph of the selected dataset.
tool_texts, tool2index, index2tool, edge_index, _, adj_g = load_tool(dataset_name=args.dataset)

# Wrap the precomputed node features and the edge list in a PyG Data object.
task_graph = Data(x=torch.FloatTensor(np.load(f"process/{args.dataset}.npy")), edge_index=edge_index).to(device)

# Per-node out-/in-degree (row 0 holds edge sources, row 1 edge targets).
# num_nodes is passed explicitly: without it `degree` sizes its output from
# the largest index that appears, so trailing nodes with no outgoing (or no
# incoming) edges would be dropped and the vector would not align with x.
out_degree = degree(task_graph.edge_index[0], num_nodes=task_graph.num_nodes, dtype=torch.long).to(device)
in_degree = degree(task_graph.edge_index[1], num_nodes=task_graph.num_nodes, dtype=torch.long).to(device)
task_graph.out_degree = out_degree
task_graph.in_degree = in_degree

  d_inv = np.power(row_sum, -0.5).flatten()


In [6]:
# Display the per-node out-degree vector attached to the graph above.
task_graph.out_degree

tensor([12, 12, 12, 12, 12, 12,  0, 13, 13, 13,  9, 13, 10,  0, 12, 12,  9,  9,
         3, 13,  2, 13,  9], device='cuda:0')

In [7]:
# Display the per-node in-degree vector attached to the graph above.
task_graph.in_degree

tensor([13, 13, 13, 13, 13, 13, 14,  5,  5,  5,  4,  5, 14, 14, 18, 18,  4,  4,
        14,  2,  1,  2, 18], device='cuda:0')

In [8]:
# Learnable lookup tables turning an integer degree into a gnn_hidden_dim vector.
# padding_idx=0 makes degree 0 map to a fixed (non-trained) padding row.
# NOTE(review): the table size 23 is hard-coded; it must exceed the largest
# degree value looked up (presumably chosen to match the node count) -- confirm.
in_degree_encoder = nn.Embedding(23, args.gnn_hidden_dim, padding_idx=0).to(device)
out_degree_encoder = nn.Embedding(23, args.gnn_hidden_dim, padding_idx=0).to(device)

In [11]:
# Look up one embedding row per node from its in-degree.
in_degree_encoder(task_graph.in_degree)

tensor([[ 1.3691,  1.0561,  0.2520,  ..., -1.1737,  1.5478, -1.0749],
        [ 1.3691,  1.0561,  0.2520,  ..., -1.1737,  1.5478, -1.0749],
        [ 1.3691,  1.0561,  0.2520,  ..., -1.1737,  1.5478, -1.0749],
        ...,
        [ 1.2557,  0.1254, -0.3001,  ...,  0.4116, -0.0029,  0.3576],
        [-2.0151,  0.9396, -1.0036,  ...,  0.8895, -0.8732,  1.0593],
        [-0.3082, -1.2320, -0.4838,  ...,  0.2067,  1.3552, -0.9189]],
       device='cuda:0', grad_fn=<EmbeddingBackward0>)

In [10]:
# Look up one embedding row per node from its out-degree.
out_degree_encoder(task_graph.out_degree)

tensor([[ 0.8714, -0.7760, -0.1889,  ..., -1.5068, -2.3798,  0.1450],
        [ 0.8714, -0.7760, -0.1889,  ..., -1.5068, -2.3798,  0.1450],
        [ 0.8714, -0.7760, -0.1889,  ..., -1.5068, -2.3798,  0.1450],
        ...,
        [-0.7180,  0.1857, -0.3028,  ..., -1.2868,  0.1780, -0.8984],
        [ 0.1651, -0.8536, -0.1650,  ..., -0.0422,  0.7681,  0.0489],
        [ 0.1489, -0.0531, -0.0896,  ...,  0.0845,  0.0602,  1.2075]],
       device='cuda:0', grad_fn=<EmbeddingBackward0>)

### diffpool encoder

In [7]:
# Build the DiffPool encoder with explicit keyword arguments.
# NOTE(review): the recorded run of this cell failed with
# "TypeError: build_conv() takes 1 positional argument but 3 were given"; the
# signature of the imported gnn.DiffPoolGNNEncoder is not visible here -- confirm
# that these keywords match it.
graph_tokenizer = DiffPoolGNNEncoder(
        input_dim=args.gnn_in_dim, 
        hidden_dim=args.gnn_hidden_dim, 
        output_dim=args.gnn_output_dim, 
        num_nodes=23,
        n_layers=args.n_layers, 
        gnn_type=args.gnn_type
        )

TypeError: build_conv() takes 1 positional argument but 3 were given

In [1]:
# 80% of 23, truncated to an integer (presumably 23 is the node count -- confirm).
int(23*0.8)

18

## 9/5 
- GraphToken (node-level version): main training / inference practice run

In [3]:
%load_ext autoreload
%autoreload 2
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import DataLoader
# from graph_llm import GraphToken 
from glm_node import GraphToken
from plan_dataset import TaskPlanningDataset
import json
import sys
sys.path.append("../")
from utils import init_random_state, load_tool, get_cur_time
from torch_geometric.data import Data





class Args:
    """Plain configuration holder; keyword arguments are stored as attributes."""

    def __init__(self, **kwargs):
        vars(self).update(kwargs)

# Experiment configuration for the GraphToken run below.
args = Args(
    dataset="huggingface",
    llm="Mistral-7B",
    seed=0,
    device="cuda:0",
    max_txt_length=512,
    max_ans_length=256,
    gnn_in_dim=1024,
    gnn_hidden_dim=1024,
    # Placeholder: overwritten further down from the per-LLM hidden-size table.
    gnn_output_dim=2560, # per-LLM values -- mistral-7b: 4096, codellama-13b: 5120, gpt-oss-20b: 2880, gemma-3-4b-it: 2560
    n_layers=2,
    gnn_type="SAGE",
    num_epochs=4,
    batch_size=6,
    eval_batch_size=6,
    patience=2,
    lr=1e-5,
    wd=0.05,
    output_dir="output",
    grad_steps=4
)

# Short LLM name -> Hugging Face model id.
path_mapping = {
    "CodeLlama-13B": "codellama/CodeLlama-13b-Instruct-hf",
    "Mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
    "CodeLlama-7B": "codellama/CodeLlama-7b-Instruct-hf",
    "Vicuna-13B": "lmsys/vicuna-13b-v1.5",
    "gpt-oss-20b": "openai/gpt-oss-20b",
    "gemma-3-270m-it": "google/gemma-3-270m-it",
    "gemma-3-4b-it": "google/gemma-3-4b-it"
}


# Short LLM name -> size the GNN output must have for that LLM.
# NOTE(review): "gemma-3-270m-it" is in path_mapping but has no entry here, so
# selecting it would raise KeyError on the lookup below -- confirm and add it.
gnn_hidden_mapping = {"CodeLlama-13B": 5120, "Mistral-7B": 4096, "Vicuna-13B": 5120, "CodeLlama-7B": 4096, "gpt-oss-20b": 2880, "gemma-3-4b-it": 2560}
# Resolve both settings from the chosen LLM; this overrides args.gnn_output_dim.
args.llm_model_path = path_mapping[args.llm]
args.gnn_output_dim = gnn_hidden_mapping[args.llm]


# Load the planning dataset and read its predefined train / test index lists.
plan_dataset = TaskPlanningDataset(args.dataset)

train_ids = plan_dataset.idxes_split["train"]
test_ids = plan_dataset.idxes_split["test"]

# The first 80% of the train indices are used for training, the rest for validation.
split_at = int(0.8 * len(train_ids))
train_dataset = [plan_dataset[idx] for idx in train_ids[:split_at]]
eval_dataset = [plan_dataset[idx] for idx in train_ids[split_at:]]
test_dataset = [plan_dataset[idx] for idx in test_ids]

# Only the training loader shuffles and drops the last incomplete batch.
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
)
val_loader = DataLoader(
    eval_dataset,
    batch_size=args.eval_batch_size,
    shuffle=False,
    drop_last=False,
    pin_memory=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=args.eval_batch_size,
    shuffle=False,
    drop_last=False,
    pin_memory=True,
)

# Build the GraphToken model from the configuration.
model = GraphToken(args)
# Parameters that will receive gradients (collected here, not used further in this cell).
params = [p for _, p in model.named_parameters() if p.requires_grad]
# Report the trainable and total parameter counts with thousands separators.
trainable_params, all_params = model.print_trainable_params()
print(f"{trainable_params:,}")
print(f"{all_params:,}")



device='cuda:0'
# Load the tool graph and wrap the precomputed node features with its edges.
tool_texts, tool2index, index2tool, edge_index, _, adj_g = load_tool(dataset_name=args.dataset)
task_graph = Data(x=torch.FloatTensor(np.load(f"process/{args.dataset}.npy")), edge_index=edge_index).to(device)

# Training bookkeeping: one step per training batch per epoch. These three
# names were previously created twice, which spawned a second progress bar.
num_training_steps = args.num_epochs * len(train_loader)
progress_bar = tqdm(range(num_training_steps))
best_val_loss = float('inf')

# Turn on gradient checkpointing for the wrapped language model.
model.model.gradient_checkpointing_enable()

[Training Data] # Chain Samples 1509 (50.30)
[Data Split] # Train 3000  # Test 500


Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  1.99it/s]
  d_inv = np.power(row_sum, -0.5).flatten()
  return torch.sparse.FloatTensor(index, data, torch.Size(coo.shape))


Finish loading pre-trained Mistral-7B model!
10,492,928
7,252,225,024


  0%|          | 0/1600 [00:00<?, ?it/s]


In [3]:
# Pull a single batch from the training loader for step-by-step inspection.
samples = next(iter(train_loader))

In [7]:
# Run the model's forward pass on the batch and the task graph; the result is kept as `loss`.
loss = model(samples, task_graph)

  return torch.cuda.amp.autocast(dtype=dtype)


In [8]:
# Display the value returned by the forward pass.
loss

tensor(0.9793, device='cuda:0', grad_fn=<NllLossBackward0>)

### forward()

In [4]:
# Display the raw batch dictionary.
samples

{'id': tensor([ 622, 2236,  826,  756, 2314,   96]),
 'origin_id': ['26775635',
  '17121963',
  '14809338',
  '30836617',
  '86015657',
  '20773083'],
 'request': ['# TASK LIST #:\nToken Classification, Translation, Summarization, Question Answering, Conversational, Text Generation, Sentence Similarity, Tabular Classification, Object Detection, Image Classification, Image-to-Image, Image-to-Text, Text-to-Image, Text-to-Video, Visual Question Answering, Document Question Answering, Image Segmentation, Depth Estimation, Text-to-Speech, Automatic Speech Recognition, Audio-to-Audio, Audio Classification, Image Editing\n\n# GOAL #\nPlease understand the user\'s request and generate task steps and task invocation graph to solve it.\n\n# REQUIREMENT #\n1. The format must in a strict JSON format as {"task_steps": [ concrete step descriptions ], "task_nodes": [ a list of tasks to be executed in sequence to fulfill user\'s request ], "task_links": [{"source": "task name i", "target": "task name 

In [22]:
# Manual walk-through of the training forward pass, one step per cell.
# The special-token strings come from the model (model.BOS, model.EOS_USER,
# model.EOS, model.IGNORE_INDEX); the values below are kept only as a reminder.
# BOS = '<s>[INST]'
# EOS_USER = '[/INST]'
# EOS = '</s>'
# IGNORE_INDEX = -100 


# encode prompts, user requests, and labels 
# (add_special_tokens=False: the special tokens are spliced in by hand below)
requests = model.tokenizer(samples["request"], add_special_tokens=False)
labels = model.tokenizer(samples["label"], add_special_tokens=False)

# encode special tokens
eos_tokens = model.tokenizer(model.EOS, add_special_tokens=False)
eos_user_tokens = model.tokenizer(model.EOS_USER, add_special_tokens=False)
# BOS is tokenized to a tensor and embedded once; it is reused for every sample.
bos_tokens = model.tokenizer(model.BOS, add_special_tokens=False, return_tensors='pt').input_ids[0]
bos_embeds = model.word_embedding(bos_tokens.to(model.device))
# Embedding of the pad token with a leading dimension added, so it can be repeated for padding.
pad_embeds = model.word_embedding(torch.tensor(model.tokenizer.pad_token_id).to(model.device)).unsqueeze(0)

batch_size = len(samples['id'])
# encode graphs 
node_embeds = model.encode_task_graph(task_graph, batch_size)


In [23]:
# Inspect the tokenized requests / labels.
print(requests)
print(labels)

# Inspect the special tokens, their embeddings' shapes, and the graph tokens.
print(eos_tokens)
print(eos_user_tokens)
print(bos_tokens)
print(bos_embeds.shape)
print(pad_embeds.shape)
print(batch_size)
print(node_embeds.shape)

{'input_ids': [[422, 320, 16804, 393, 8048, 422, 28747, 13, 3856, 4950, 2500, 28725, 4335, 1465, 28725, 6927, 3479, 1837, 28725, 22478, 1094, 1616, 2131, 28725, 1325, 740, 1249, 28725, 7379, 26802, 28725, 318, 308, 636, 24232, 472, 28725, 14319, 1098, 4950, 2500, 28725, 4625, 384, 22820, 28725, 9833, 4950, 2500, 28725, 9833, 28733, 532, 28733, 4176, 28725, 9833, 28733, 532, 28733, 1874, 28725, 7379, 28733, 532, 28733, 4176, 28725, 7379, 28733, 532, 28733, 11761, 28725, 24497, 22478, 1094, 1616, 2131, 28725, 14873, 22478, 1094, 1616, 2131, 28725, 9833, 9594, 466, 352, 28725, 3995, 362, 3978, 8258, 28725, 7379, 28733, 532, 28733, 24812, 5295, 28725, 15939, 1711, 8819, 5295, 3523, 3159, 685, 28725, 16957, 28733, 532, 28733, 13361, 28725, 16957, 4950, 2500, 28725, 9833, 2690, 4328, 13, 13, 28771, 15044, 1086, 422, 13, 12069, 2380, 272, 2188, 28742, 28713, 2159, 304, 8270, 3638, 5944, 304, 3638, 1304, 10001, 5246, 298, 12049, 378, 28723, 13, 13, 28771, 4515, 28824, 5057, 896, 7178, 422, 13,

In [24]:
# Per-sample accumulators: input embeddings, 2D attention masks, label ids.
batch_inputs_embeds = []
batch_attention_masks = []
batch_label_input_ids = []

for i in range(batch_size):

    # Target = truncated label tokens followed by EOS.
    label_input_ids = labels.input_ids[i][:model.max_new_tokens] + eos_tokens.input_ids 
    # Text part = truncated request + end-of-user marker + target.
    input_ids = requests.input_ids[i][:model.max_txt_len] + eos_user_tokens.input_ids + label_input_ids

    input_embeds = model.word_embedding(torch.tensor(input_ids).to(model.device))
    # Full sequence layout: [BOS tokens][graph tokens][text tokens].
    input_embeds = torch.cat([bos_embeds, node_embeds, input_embeds], dim=0)
    
    batch_inputs_embeds.append(input_embeds)
    
    num_graph_tokens = node_embeds.shape[0]
    seq_len = input_embeds.shape[0]
    # Start from a lower-triangular (causal) mask ...
    mask = torch.tril(torch.ones((seq_len, seq_len), device=model.device))
    # ... then let the BOS + graph-token prefix attend to itself in both directions.
    graph_end_idx = len(bos_tokens) + num_graph_tokens
    mask[:graph_end_idx, :graph_end_idx] = 1
    batch_attention_masks.append(mask)

    # Everything before the target is filled with IGNORE_INDEX so labels line up with the inputs.
    label_input_ids = [model.IGNORE_INDEX] * (input_embeds.shape[0] - len(label_input_ids)) + label_input_ids
    batch_label_input_ids.append(label_input_ids)
    
    

In [25]:
# Sanity check on the first sample: embeddings, mask and labels should share one length.
print(batch_inputs_embeds[0].shape)
print(batch_attention_masks[0].shape)
print(len(batch_label_input_ids[0]))

torch.Size([520, 4096])
torch.Size([520, 520])
520


In [26]:

# Longest sequence in the batch; every sample is padded up to it.
max_length = max([x.shape[0] for x in batch_inputs_embeds])
# added -----------------
# Batched square mask, all zeros; per-sample masks are copied in by the next cell.
attention_mask = torch.zeros(batch_size, max_length, max_length, device=model.device)

In [27]:
# Left-pad every sample to max_length and assemble the batch tensors.
for i in range(batch_size):
    pad_length = max_length - batch_inputs_embeds[i].shape[0]
    print(pad_length)

    # Pad embeddings go in front of the sequence (left padding).
    batch_inputs_embeds[i] = torch.cat([pad_embeds.repeat(pad_length, 1), batch_inputs_embeds[i]])

    # Copy the sample's mask into the bottom-right corner; rows and columns
    # belonging to padding stay 0.
    attention_mask[i, pad_length:, pad_length:] = batch_attention_masks[i]

    # Padding positions never contribute to the loss.
    batch_label_input_ids[i] = [model.IGNORE_INDEX] * pad_length + batch_label_input_ids[i]

input_embeds = torch.stack(batch_inputs_embeds, dim=0).to(model.model.device, model.model.dtype)
# Add the head dimension: (batch, 1, max_length, max_length), in the model dtype.
# NOTE(review): this mask holds 1 = attend / 0 = blocked. Depending on the
# transformers version a 4D floating-point mask may instead be treated as an
# additive bias (0 = attend, large negative = blocked) -- confirm.
attention_mask = attention_mask.unsqueeze(1).to(model.model.device, model.model.dtype)
label_input_ids = torch.tensor(batch_label_input_ids).to(model.model.device)
# The stacked label tensor is no longer appended back onto batch_label_input_ids:
# that left a tensor mixed into the list of per-sample id lists.


153
84
230
51
116
0


In [28]:
# Shapes and dtypes of the three tensors that are fed to the language model.
print(input_embeds.shape, input_embeds.dtype)
print(label_input_ids.shape, label_input_ids.dtype)
print(attention_mask.shape, attention_mask.dtype)

torch.Size([6, 673, 4096]) torch.float16
torch.Size([6, 673]) torch.int64
torch.Size([6, 1, 673, 673]) torch.float16


In [29]:
# Length of the first sample's (padded) label row.
len(label_input_ids[0])

673

In [40]:
# Tail of the first label row (positions 600 onward).
label_input_ids[0][600:]

tensor([16186,   272, 16776,  1204,   778,  8666,  8883,   345,  5553, 28730,
        12333,  1264,  7367,  4176,  4950,  2500,   548,   345,  4176,  2690,
         4328,   548,   345,  1874, 28733,   532, 28733, 24812,  5295,  8883,
          345,  5553, 28730, 17052,  1264,   733,  6799,  1394,  1264,   345,
         4176,  4950,  2500,   548,   345,  3731,  1264,   345,  1874, 28733,
          532, 28733, 24812,  5295,  7706,  9830,  1394,  1264,   345,  4176,
         2690,  4328,   548,   345,  3731,  1264,   345,  4176,  4950,  2500,
        17395,  9205,     2], device='cuda:0')

In [46]:
# Decode the label ids from position 565 onward back to text.
model.tokenizer.decode(label_input_ids[0][565:])

'steps": ["Step 1: Modify the input image to match the given text description", "Step 2: Classify the edited image", "Step 3: Convert the classification result into speech"], "task_nodes": ["Image Classification", "Image Editing", "Text-to-Speech"], "task_links": [{"source": "Image Classification", "target": "Text-to-Speech"}, {"source": "Image Editing", "target": "Image Classification"}]}</s>'

In [47]:
# Call the wrapped language model directly on embeddings (no input_ids),
# passing the custom mask and the labels.
outputs = model.model(
    inputs_embeds=input_embeds,
    attention_mask=attention_mask,
    return_dict=True,
    labels=label_input_ids
)

In [51]:
# List the fields of the model output.
outputs.keys()

odict_keys(['loss', 'logits', 'past_key_values'])

In [54]:
# Inspect the loss, the logits shape, and the returned key/value cache.
print(outputs['loss'])
print(outputs['logits'].shape)
print(outputs['past_key_values'])

tensor(0.9768, device='cuda:0', grad_fn=<NllLossBackward0>)
torch.Size([6, 515, 32000])
DynamicCache(layers=[<transformers.cache_utils.DynamicLayer object at 0x7fef25392b20>, <transformers.cache_utils.DynamicLayer object at 0x7fef4c296490>, <transformers.cache_utils.DynamicLayer object at 0x7fef4c2967c0>, <transformers.cache_utils.DynamicLayer object at 0x7fef4c2969d0>, <transformers.cache_utils.DynamicLayer object at 0x7fef4c296340>, <transformers.cache_utils.DynamicLayer object at 0x7fef4c296970>, <transformers.cache_utils.DynamicLayer object at 0x7fef4c296e50>, <transformers.cache_utils.DynamicLayer object at 0x7fef4c296d60>, <transformers.cache_utils.DynamicLayer object at 0x7fef253c6040>, <transformers.cache_utils.DynamicLayer object at 0x7fef253c65e0>, <transformers.cache_utils.DynamicLayer object at 0x7fef253c62b0>, <transformers.cache_utils.DynamicLayer object at 0x7fef253c6730>, <transformers.cache_utils.DynamicLayer object at 0x7fef253c6a60>, <transformers.cache_utils.Dynamic

DynamicCache(layers=[<transformers.cache_utils.DynamicLayer object at 0x7fef25392b20>, <transformers.cache_utils.DynamicLayer object at 0x7fef4c296490>, <transformers.cache_utils.DynamicLayer object at 0x7fef4c2967c0>, <transformers.cache_utils.DynamicLayer object at 0x7fef4c2969d0>, <transformers.cache_utils.DynamicLayer object at 0x7fef4c296340>, <transformers.cache_utils.DynamicLayer object at 0x7fef4c296970>, <transformers.cache_utils.DynamicLayer object at 0x7fef4c296e50>, <transformers.cache_utils.DynamicLayer object at 0x7fef4c296d60>, <transformers.cache_utils.DynamicLayer object at 0x7fef253c6040>, <transformers.cache_utils.DynamicLayer object at 0x7fef253c65e0>, <transformers.cache_utils.DynamicLayer object at 0x7fef253c62b0>, <transformers.cache_utils.DynamicLayer object at 0x7fef253c6730>, <transformers.cache_utils.DynamicLayer object at 0x7fef253c6a60>, <transformers.cache_utils.DynamicLayer object at 0x7fef253c87c0>, <transformers.cache_utils.DynamicLayer object at 0x7fef

### inference

In [19]:
# Pull a single batch from the test loader for the inference walk-through.
samples = next(iter(test_loader))

In [20]:
# Inference-time input construction: same layout as training, but without labels.
requests = model.tokenizer(samples["request"], add_special_tokens=False)

eos_user_tokens = model.tokenizer(model.EOS_USER, add_special_tokens=False)
bos_tokens = model.tokenizer(model.BOS, add_special_tokens=False, return_tensors='pt').input_ids[0]
bos_embeds = model.word_embedding(bos_tokens.to(model.device))
pad_embeds = model.word_embedding(torch.tensor(model.tokenizer.pad_token_id).to(model.device)).unsqueeze(0)

batch_size = len(samples["id"])
node_embeds = model.encode_task_graph(task_graph, batch_size)
num_graph_tokens = node_embeds.shape[0]

batch_inputs_embeds = []
batch_attention_masks = []

for i in range(batch_size):
    # Prompt = truncated request + end-of-user marker (the answer is generated later).
    input_ids = requests.input_ids[i][:model.max_txt_len] + eos_user_tokens.input_ids
    input_embeds = model.word_embedding(torch.tensor(input_ids).to(model.model.device))
    # Sequence layout: [BOS tokens][graph tokens][text tokens].
    input_embeds = torch.cat([bos_embeds, node_embeds, input_embeds], dim=0)
    batch_inputs_embeds.append(input_embeds)

    seq_len = input_embeds.shape[0]
    # Causal mask, with the BOS + graph-token prefix made fully visible to itself.
    mask = torch.tril(torch.ones((seq_len, seq_len), device=model.device))
    graph_end_idx = len(bos_tokens) + num_graph_tokens
    mask[:graph_end_idx, :graph_end_idx] = 1
    batch_attention_masks.append(mask)


max_length = max([x.shape[0] for x in batch_inputs_embeds])
print(max_length)

attention_mask = torch.zeros(batch_size, max_length, max_length, device=model.device)
print(attention_mask.shape)


# Left-pad each sample to max_length; mask rows/columns of padding stay 0.
for i in range(batch_size):
    pad_length = max_length - batch_inputs_embeds[i].shape[0]
    batch_inputs_embeds[i] = torch.cat([pad_embeds.repeat(pad_length, 1), batch_inputs_embeds[i]])
    attention_mask[i, pad_length:, pad_length:] = batch_attention_masks[i]

input_embeds = torch.stack(batch_inputs_embeds, dim=0).to(model.model.device, model.model.dtype)
# Add the head dimension -> (batch, 1, max_length, max_length).
# NOTE(review): 1 = attend / 0 = blocked here; confirm that the installed
# transformers version does not read a 4D float mask as an additive bias.
attention_mask = attention_mask.unsqueeze(1).to(model.model.device, model.model.dtype)

print(input_embeds.shape, input_embeds.dtype)
print(attention_mask.shape, attention_mask.dtype)

524
torch.Size([6, 524, 524])
torch.Size([6, 524, 4096]) torch.float16
torch.Size([6, 1, 524, 524]) torch.float16


In [21]:
# Replace the model with the one returned by reload_best_model(model, args).
# NOTE(review): by execution count this cell ([21]) ran after the cell that
# built input_embeds / node_embeds ([20]), so those tensors were computed with
# the weights in place before the reload -- presumably they should be rebuilt
# afterwards; confirm.
from ckpt import reload_best_model
model = reload_best_model(model, args)

Loading checkpoint from output/huggingface/Mistral-7B/SAGE_Epoch4_checkpoint_best.pth


#### efficient version

In [22]:
# Confirm the shapes of the prompt embeddings and mask before generating.
print(input_embeds.shape)
print(attention_mask.shape)

torch.Size([6, 524, 4096])
torch.Size([6, 1, 524, 524])


In [23]:
# --- Greedy decoding with a KV cache --------------------------------------
# Expects `input_embeds` (batch, prompt_len, hidden) and the 4D prompt mask
# `attention_mask` (batch, 1, prompt_len, prompt_len) from the cell above.
batch_size = input_embeds.shape[0]
initial_sequence_length = input_embeds.shape[1]
eos_token_id = model.tokenizer.eos_token_id

# Store the generated tokens (not embeddings)
generated_ids = torch.empty(batch_size, 0, dtype=torch.long, device=model.device)

# Which prompt positions are real tokens. The last prompt row attends to every
# non-pad key, so it doubles as the per-sample key mask. Reusing it for the
# decode steps keeps the left-pad positions in the cache masked out; an
# all-ones mask would let new tokens attend to padding.
key_mask = attention_mask[:, 0, -1, :].clone()

with torch.inference_mode():
    # --- 1. Priming step: run the whole prompt once to fill the cache ---
    position_ids = torch.arange(
        0, initial_sequence_length, dtype=torch.long, device=model.device
    ).unsqueeze(0)

    outputs = model.model(
        inputs_embeds=input_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        use_cache=True,
    )
    past_key_values = outputs.past_key_values

    # First generated token = argmax at the last prompt position.
    next_token_logits = outputs.logits[:, -1, :]
    next_token = torch.argmax(next_token_logits, dim=-1)
    generated_ids = torch.cat([generated_ids, next_token.unsqueeze(-1)], dim=-1)

    # Rows that have already produced EOS. Tracked per row: the old check
    # only stopped when every row emitted EOS on the same step, so rows that
    # finished early kept generating text after their EOS.
    finished = next_token == eos_token_id

    # --- 2. Generation loop: one new token per step, reusing the cache ---
    current_sequence_length = initial_sequence_length
    for _ in range(512):
        if finished.all():
            break

        current_sequence_length += 1

        # Position of the token being fed in = index of the last sequence slot.
        position_ids = torch.tensor(
            [[current_sequence_length - 1]], device=model.device, dtype=torch.long
        ).expand(batch_size, -1)

        # Extend the key mask by one visible slot for the new token:
        # shape (batch_size, current_sequence_length).
        key_mask = torch.cat([key_mask, key_mask.new_ones(batch_size, 1)], dim=-1)

        # Embed ONLY the last generated token; everything earlier is cached.
        next_token_embeds = model.word_embedding(next_token).unsqueeze(1)

        outputs = model.model(
            inputs_embeds=next_token_embeds,
            attention_mask=key_mask,
            position_ids=position_ids,
            use_cache=True,
            past_key_values=past_key_values,
        )
        past_key_values = outputs.past_key_values

        next_token_logits = outputs.logits[:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1)
        # Finished rows keep emitting EOS so nothing after their first EOS
        # survives decoding with skip_special_tokens=True.
        next_token = torch.where(
            finished, torch.full_like(next_token, eos_token_id), next_token
        )

        generated_ids = torch.cat([generated_ids, next_token.unsqueeze(-1)], dim=-1)
        finished = finished | (next_token == eos_token_id)

# Decode the final generated sequence
pred = model.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

In [7]:
# Number of decoded predictions (one per sample in the batch).
len(pred)

6

In [31]:
# Show the decoded prediction for the sample at index 5.
print(pred[5])

1. The format must in a strict JSON format as {"task_steps": [ concrete step descriptions ], "task_nodes": [ a list of tasks to be executed in sequence to fulfill user's request ], "task_links": [{"source": "task name i", "target": "task name j"}]}
2. The generated task steps and task nodes can resolve the given user request perfectly. Task name must be selected from TASK LIST.
3. Task steps should strictly aligned with task nodes, and the number of task steps should be same with the task nodes.
4. The task links should reflect the dependencies among task nodes, i.e. the order in which the APIs are invoked.

# USER REQUEST #: Please find an answer to the question 'What is the Capital of France?' in the given example.jpg image and provide the answer as an audio file.
Now please generate your result in a strict JSON format:
# RESULT #: [/INST] 1. The format must in a strict JSON format as {"task_steps": [ concrete step descriptions ], "task_nodes": [ a list of tasks to be executed in seq

In [18]:
# Show the raw request prompt of the first sample for comparison.
print(samples['request'][0])

# TASK LIST #:
Token Classification, Translation, Summarization, Question Answering, Conversational, Text Generation, Sentence Similarity, Tabular Classification, Object Detection, Image Classification, Image-to-Image, Image-to-Text, Text-to-Image, Text-to-Video, Visual Question Answering, Document Question Answering, Image Segmentation, Depth Estimation, Text-to-Speech, Automatic Speech Recognition, Audio-to-Audio, Audio Classification, Image Editing

# GOAL #
Please understand the user's request and generate task steps and task invocation graph to solve it.

# REQUIREMENT #
1. The format must in a strict JSON format as {"task_steps": [ concrete step descriptions ], "task_nodes": [ a list of tasks to be executed in sequence to fulfill user's request ], "task_links": [{"source": "task name i", "target": "task name j"}]}
2. The generated task steps and task nodes can resolve the given user request perfectly. Task name must be selected from TASK LIST.
3. Task steps should strictly aligne

## gnn_llm.py

In [2]:
import torch 
import torch.nn as nn 
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel

# Model id whose tokenizer is inspected below; the alternatives are kept commented out.
# llm_model_path = "google/gemma-3-270m"
# llm_model_path = "openai/gpt-oss-20b"
llm_model_path = "mistralai/Mistral-7B-Instruct-v0.2"


# Only the tokenizer is loaded; loading the model itself is left disabled.
tokenizer = AutoTokenizer.from_pretrained(llm_model_path)
# model = AutoModelForCausalLM.from_pretrained(llm_model_path)

In [3]:
# Display the tokenizer's EOS token id.
tokenizer.eos_token_id

2

### mistral

In [2]:
# Print the pad / EOS / BOS token ids and the side on which padding is added.
print(tokenizer.pad_token_id)
print(tokenizer.eos_token_id)
print(tokenizer.bos_token_id)
print(tokenizer.padding_side)

None
2
1
left


In [5]:
# Decode the EOS and BOS ids back to their string form.
print(tokenizer.decode(tokenizer.eos_token_id))
print(tokenizer.decode(tokenizer.bos_token_id))


</s>
<s>


In [7]:
# Display the tokenizer's special-token map.
tokenizer.special_tokens_map_extended

{'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True)}

### gpt-oss

In [105]:
# Decode the named special tokens first ...
print(tokenizer.decode(tokenizer.bos_token_id))
print(tokenizer.decode(tokenizer.pad_token_id))
print(tokenizer.decode(tokenizer.eos_token_id))
# ... then the ids 200003 and 200005-200020, one per line.
# Id 200004 is deliberately left out (it was commented out in the original list).
for token_id in (200003, *range(200005, 200021)):
    print(tokenizer.decode(token_id))

<|startoftext|>
<|endoftext|>
<|return|>
<|constrain|>
<|channel|>
<|start|>
<|end|>
<|message|>
<|reserved_200009|>
<|reserved_200010|>
<|reserved_200011|>
<|call|>
<|reserved_200013|>
<|reserved_200014|>
<|reserved_200015|>
<|reserved_200016|>
<|reserved_200017|>
<|endofprompt|>




In [7]:
# Decode token id 200006.
print(tokenizer.decode(200006))

<|start|>


In [6]:
# Decode token id 200007.
print(tokenizer.decode(200007))

<|end|>


In [8]:
# Decode token id 200008.
print(tokenizer.decode(200008))

<|message|>


In [108]:
# Decode token id 199999.
tokenizer.decode(199999)

'<|endoftext|>'

In [40]:
# Print the decoded string for each of the first 30 token ids.
for i in range(30):
    print(f"{i} : {tokenizer.decode(i)}")



0 : <unk>
1 : <s>
2 : </s>
3 :  
4 : 
5 : 
6 : 
7 : 
8 : 
9 : 
10 : 
11 :
12 : 	
13 : 

14 : 
15 : 
16 : 
17 : 
18 : 
19 : 
20 : 
21 : 
22 : 
23 : 
24 : 
25 : 
26 : 
27 : 
28 : 
29 : 


In [9]:
# Instruction-format marker strings used to frame the prompt.
BOS = '<s>[INST]'
EOS_USER = '[/INST]'
EOS = '</s>'

# Tokenize each marker without letting the tokenizer add its own special tokens.
eos_user_tokens = tokenizer(EOS_USER, add_special_tokens=False)
bos_tokens = tokenizer(BOS, add_special_tokens=False)
eos_tokens = tokenizer(EOS, add_special_tokens=False)

In [10]:
# Show the token ids each marker string was split into.
print(eos_user_tokens)
print(bos_tokens)
print(eos_tokens)

{'input_ids': [49613, 34177, 236842], 'attention_mask': [1, 1, 1]}
{'input_ids': [203, 236840, 34177, 236842], 'attention_mask': [1, 1, 1, 1]}
{'input_ids': [212], 'attention_mask': [1]}


In [25]:
# Print the decoded string for each of the first 30 token ids of the currently loaded tokenizer.
for i in range(30):
    print(f"{i} : {tokenizer.decode(i)}")



0 : <pad>
1 : <eos>
2 : <bos>
3 : <unk>
4 : <mask>
5 : [multimodal]
6 : <unused0>
7 : <unused1>
8 : <unused2>
9 : <unused3>
10 : <unused4>
11 : <unused5>
12 : <unused6>
13 : <unused7>
14 : <unused8>
15 : <unused9>
16 : <unused10>
17 : <unused11>
18 : <unused12>
19 : <unused13>
20 : <unused14>
21 : <unused15>
22 : <unused16>
23 : <unused17>
24 : <unused18>
25 : <unused19>
26 : <unused20>
27 : <unused21>
28 : <unused22>
29 : <unused23>


In [None]:
# Encode the bare instruction marker with the tokenizer's default settings.
tokenizer.encode('[INST]')

[1, 733, 16289, 28793]

In [31]:
# Encode the marker with a literal BOS string in front, again with default settings.
tokenizer.encode('<s>[INST]')

[1, 1, 28792, 16289, 28793]

In [28]:
# Instruction-format marker strings.
BOS = '<s>[INST]'
EOS_USER = '[/INST]'
EOS = '</s>'

# Tokenize BOS with add_special_tokens=False and return PyTorch tensors.
tokenizer(BOS, add_special_tokens=False, return_tensors='pt')

{'input_ids': tensor([[    1, 28792, 16289, 28793]]), 'attention_mask': tensor([[1, 1, 1, 1]])}

In [13]:
from transformers import AutoTokenizer

conversation = [
    {"role": "user", "content": "Hello! Can you tell me about the weather in Hanam-si today?"},
    {"role": "assistant", "content": "Of course! The weather in Hanam-si, Gyeonggi-do is currently sunny with a high of 28°C."},
    {"role": "user", "content": "That sounds great. What about tomorrow?"}
]

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

formatted_prompt = tokenizer.apply_chat_template(
    conversation,
    tokenize=False,
    add_generation_prompt=True
)

print(formatted_prompt)

<s> [INST] Hello! Can you tell me about the weather in Hanam-si today? [/INST] Of course! The weather in Hanam-si, Gyeonggi-do is currently sunny with a high of 28°C.</s> [INST] That sounds great. What about tomorrow? [/INST]


In [15]:
tokenizer.pad_token_id

In [20]:
tokenizer.decode(0)

'<unk>'

In [7]:
conversation = [
    {"role": "user", "content": "Hello! Can you tell me about the weather in Hanam-si today?"},
    {"role": "assistant", "content": "Of course! The weather in Hanam-si, Gyeonggi-do is currently sunny with a high of 28°C."},
    {"role": "user", "content": "That sounds great. What about tomorrow?"}
] 

tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-it")

formatted_prompt = tokenizer.apply_chat_template(
    conversation,
    tokenize=False,
    add_generation_prompt=True
)

print(formatted_prompt)

<bos><start_of_turn>user
Hello! Can you tell me about the weather in Hanam-si today?<end_of_turn>
<start_of_turn>model
Of course! The weather in Hanam-si, Gyeonggi-do is currently sunny with a high of 28°C.<end_of_turn>
<start_of_turn>user
That sounds great. What about tomorrow?<end_of_turn>
<start_of_turn>model



In [8]:
tokenizer.pad_token

'<pad>'

In [9]:
tokenizer.pad_token_id

0

In [10]:
conversation = [
    {"role": "user", "content": "Hello! Can you tell me about the weather in Hanam-si today?"},
    {"role": "assistant", "content": "Of course! The weather in Hanam-si, Gyeonggi-do is currently sunny with a high of 28°C."},
    {"role": "user", "content": "That sounds great. What about tomorrow?"}

]

tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")

formatted_prompt = tokenizer.apply_chat_template(
    conversation,
    tokenize=False,
    add_generation_prompt=True
)

print(formatted_prompt)

<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: 2025-08-29

Reasoning: medium

# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>user<|message|>Hello! Can you tell me about the weather in Hanam-si today?<|end|><|start|>assistant<|channel|>final<|message|>Of course! The weather in Hanam-si, Gyeonggi-do is currently sunny with a high of 28°C.<|end|><|start|>user<|message|>That sounds great. What about tomorrow?<|end|><|start|>assistant


In [11]:
print(tokenizer.decode(tokenizer.bos_token_id))
print(tokenizer.decode(tokenizer.pad_token_id))
print(tokenizer.decode(tokenizer.eos_token_id))

<|startoftext|>
<|endoftext|>
<|return|>


In [12]:
tokenizer.pad_token_id

199999

In [96]:
print(tokenizer.bos_token)
print(tokenizer.eos_token)
print(tokenizer.pad_token)

<|startoftext|>
<|return|>
<|endoftext|>


In [97]:
tokenizer.encode("<|pad|>")

[27, 91, 16730, 91, 29]

In [98]:
tokenizer.encode("<|start|>")

[200006]

In [99]:
tokenizer.encode("<|end|>")

[200007]

In [1]:
import torch 
import torch.nn as nn 
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoProcessor, Gemma3ForConditionalGeneration

# Smoke test: load an instruction-tuned checkpoint, render a short chat through
# its chat template, decode greedily and print prompt + completion with the
# special tokens left visible.
device = "cuda:0"

# Swap the active line to try another checkpoint:
#   "google/gemma-3-4b-it"
#   "openai/gpt-oss-20b"
llm_model_path = "mistralai/Mistral-7B-Instruct-v0.2"

tokenizer = AutoTokenizer.from_pretrained(llm_model_path)
model = AutoModelForCausalLM.from_pretrained(llm_model_path, device_map="auto")
model = model.eval()

messages = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "What is the capital of France?"},
]

# Render the conversation to token tensors, then move them next to the model.
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
inputs = inputs.to(model.device)

# Number of prompt tokens (last axis of input_ids).
input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    # do_sample=False -> greedy decoding; keep only the first sequence of the batch.
    generation = model.generate(**inputs, max_new_tokens=500, do_sample=False)[0]

decoded = tokenizer.decode(generation, skip_special_tokens=False)
print(decoded)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 3/3 [01:02<00:00, 20.76s/it]
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


<s> [INST] You are a helpful assistant

What is the capital of France? [/INST] The capital city of France is Paris. Paris is one of the most famous cities in the world and is known for its iconic landmarks such as the Eiffel Tower, Louvre Museum, Notre-Dame Cathedral, and the Arc de Triomphe. It is also home to many important cultural and artistic institutions. Paris is located in the northern part of France and is the country's most populous city.</s>


In [22]:
tokenizer.eos_token_id

2

In [None]:
import torch 
import torch.nn as nn 
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoProcessor, Gemma3ForConditionalGeneration

device="cuda:0"
llm_model_path = "google/gemma-3-4b-it"
# llm_model_path = "openai/gpt-oss-20b"
# llm_model_path = "mistralai/Mistral-7B-Instruct-v0.2"

tokenizer = AutoTokenizer.from_pretrained(llm_model_path)
# processor = AutoProcessor.from_pretrained(llm_model_path)
model = AutoModelForCausalLM.from_pretrained(
    llm_model_path, device_map="auto"
).eval()

# messages = [
#     {"role": "user", "content": {"text": "Hello! Can you explain what LLM is?"}},
# ]
messages = [
    {"role": "user", "content": "Who is the president of the United States?"},
]

inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,return_dict=True, return_tensors="pt"
).to(model.device)

input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=500, do_sample=False)
    generation = generation[0]
    
decoded = tokenizer.decode(generation, skip_special_tokens=False)
print(decoded)

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


In [2]:
"google/gemma-3-4b-it".split("/")[0]

'google'

The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


<bos><start_of_turn>user
Who is the president of the United States?<end_of_turn>
<start_of_turn>model
As of today, November 2, 2023, the President of the United States is **Joe Biden**. 

You can always find the most up-to-date information on the White House website: [https://www.whitehouse.gov/](https://www.whitehouse.gov/)<end_of_turn>


In [1]:
import torch 
import torch.nn as nn 
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoProcessor, Gemma3ForConditionalGeneration


llm_model_path = "google/gemma-3-12b-it"
# llm_model_path = "openai/gpt-oss-20b"
# llm_model_path = "mistralai/Mistral-7B-Instruct-v0.2"

tokenizer = AutoTokenizer.from_pretrained(llm_model_path)
# processor = AutoProcessor.from_pretrained(llm_model_path)
model = AutoModelForCausalLM.from_pretrained(
    llm_model_path, device_map="auto"
).eval()

# messages = [
#     {"role": "user", "content": {"text": "Hello! Can you explain what LLM is?"}},
# ]
messages = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "What is the capital of France?"},
]

inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,return_dict=True, return_tensors="pt"
).to(model.device)

input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=500, do_sample=False)
    generation = generation[0]
    
decoded = tokenizer.decode(generation, skip_special_tokens=False)
print(decoded)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 5/5 [00:07<00:00,  1.47s/it]
The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


<bos><start_of_turn>user
You are a helpful assistant

What is the capital of France?<end_of_turn>
<start_of_turn>model
The capital of France is **Paris**. 🇫🇷
<end_of_turn>


In [2]:
messages = [
    {"role": "user", "content": "What is the capital of France?"},
]

inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,return_dict=True, return_tensors="pt"
).to(model.device)

input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=500, do_sample=False)
    generation = generation[0]
    
decoded = tokenizer.decode(generation, skip_special_tokens=False)
print(decoded)

The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


<bos><start_of_turn>user
What is the capital of France?<end_of_turn>
<start_of_turn>model
The capital of France is **Paris**.



It's also the largest city in France and a global center for art, fashion, gastronomy, and culture.<end_of_turn>


In [1]:
import torch 
import torch.nn as nn 
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoProcessor, Gemma3ForConditionalGeneration

# Smoke test for gpt-oss-20b: render a system + user conversation through the
# model's chat template, decode greedily and print the full sequence with
# special tokens kept, so the template layout can be inspected.
device = "cuda:0"  # NOTE(review): unused in this cell; placement comes from device_map="auto"
# llm_model_path = "google/gemma-3-4b-it"
llm_model_path = "openai/gpt-oss-20b"
# llm_model_path = "mistralai/Mistral-7B-Instruct-v0.2"

tokenizer = AutoTokenizer.from_pretrained(llm_model_path)
# processor = AutoProcessor.from_pretrained(llm_model_path)
model = AutoModelForCausalLM.from_pretrained(
    llm_model_path, device_map="auto"
).eval()

# Single assignment (the original repeated the target: `messages = messages = [...]`).
messages = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "What is the capital of France?"},
]

# Tokenize via the chat template and move the tensors to the model's device.
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device)

# Number of prompt tokens (last axis of input_ids).
input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    # Greedy decoding (do_sample=False); keep the first sequence of the batch.
    generation = model.generate(**inputs, max_new_tokens=600, do_sample=False)
    generation = generation[0]

# skip_special_tokens=False keeps the template's control tokens in the printout.
decoded = tokenizer.decode(generation, skip_special_tokens=False)
print(decoded)

  from .autonotebook import tqdm as notebook_tqdm
MXFP4 quantization requires triton >= 3.4.0 and kernels installed, we will default to dequantizing the model to bf16
Loading checkpoint shards: 100%|██████████| 3/3 [00:24<00:00,  8.26s/it]


<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: 2025-08-30

Reasoning: medium

# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions

You are a helpful assistant

<|end|><|start|>user<|message|>What is the capital of France?<|end|><|start|>assistant<|channel|>analysis<|message|>We need to answer: "What is the capital of France?" The answer: Paris. Provide concise answer.<|end|><|start|>assistant<|channel|>final<|message|>The capital of France is **Paris**.<|return|>


<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: 2025-08-29

Reasoning: medium

# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>user<|message|>Who is the president of the United States?<|end|><|start|>assistant<|channel|>analysis<|message|>The user asks: "Who is the president of the United States?" This is a factual question. As of the current date, 2025-08-29, the president is Joe Biden? Wait, Joe Biden was president from 2021 to 2025. But as of 2025, the next election is in 2024. Joe Biden's term ends on January 20, 2025. The 2024 election will decide the next president. As of August 2025, the president would be the winner of the 2024 election. The 2024 election hasn't happened yet? Actually, the 2024 election is scheduled for November 5, 2024. So as of August 2025, the president would be the winner of that election. But we don't know the outcome. Howe

In [22]:
len(generation)

484

## main.py

In [1]:
%load_ext autoreload
%autoreload 2
from torch.utils.data import DataLoader
from graph_llm import GraphToken 
from plan_dataset import TaskPlanningDataset
import json
import sys
sys.path.append("../")
from utils import init_random_state, load_tool, get_cur_time



class Args:
    """Lightweight attribute container standing in for parsed CLI arguments.

    Every keyword passed to the constructor becomes an instance attribute of
    the same name, e.g. ``Args(lr=1e-5).lr == 1e-5``.
    """

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


# Run configuration for this notebook session.
args = Args(
    dataset="huggingface",
    llm="Mistral-7B",
    seed=0,
    device="cuda:0",
    max_txt_length=512,
    max_ans_length=256,
    gnn_in_dim=1024,
    gnn_hidden_dim=1024,
    # Placeholder only: overwritten below from gnn_hidden_mapping[args.llm].
    gnn_output_dim=2560,
    n_layers=2,
    gnn_type="SAGE",
    num_epochs=4,
    batch_size=6,
    eval_batch_size=6,
    patience=2,
    lr=1e-5,
    wd=0.05,
    output_dir="output",
    grad_steps=4,
)

# Short model name -> Hugging Face checkpoint id.
path_mapping = {
    "CodeLlama-13B": "codellama/CodeLlama-13b-Instruct-hf",
    "Mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
    "CodeLlama-7B": "codellama/CodeLlama-7b-Instruct-hf",
    "Vicuna-13B": "lmsys/vicuna-13b-v1.5",
    "gpt-oss-20b": "openai/gpt-oss-20b",
    "gemma-3-270m-it": "google/gemma-3-270m-it",
    "gemma-3-4b-it": "google/gemma-3-4b-it",
}

# Short model name -> LLM hidden size, used as the GNN output width.
# Must have an entry for every key of path_mapping; "gemma-3-270m-it" was
# missing, so selecting that model raised KeyError on the lookup below.
gnn_hidden_mapping = {
    "CodeLlama-13B": 5120,
    "Mistral-7B": 4096,
    "Vicuna-13B": 5120,
    "CodeLlama-7B": 4096,
    "gpt-oss-20b": 2880,
    "gemma-3-270m-it": 640,
    "gemma-3-4b-it": 2560,
}
args.llm_model_path = path_mapping[args.llm]
args.gnn_output_dim = gnn_hidden_mapping[args.llm]


# Load the planning dataset and carve the official train split into
# 80% training / 20% validation; the official test split is used as-is.
plan_dataset = TaskPlanningDataset(args.dataset)

train_ids = plan_dataset.idxes_split["train"]
test_ids = plan_dataset.idxes_split["test"]

split_idx = int(0.8 * len(train_ids))
train_dataset = [plan_dataset[idx] for idx in train_ids[:split_idx]]
eval_dataset = [plan_dataset[idx] for idx in train_ids[split_idx:]]
test_dataset = [plan_dataset[idx] for idx in test_ids]

# Only the training loader shuffles and drops the last incomplete batch.
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    drop_last=True,
    pin_memory=True,
    shuffle=True,
)
eval_loader_kwargs = dict(
    batch_size=args.eval_batch_size,
    drop_last=False,
    pin_memory=True,
    shuffle=False,
)
val_loader = DataLoader(eval_dataset, **eval_loader_kwargs)
test_loader = DataLoader(test_dataset, **eval_loader_kwargs)

# Build the model and report trainable vs. total parameter counts.
model = GraphToken(args)
params = [
    param for _name, param in model.named_parameters() if param.requires_grad
]
trainable_params, all_params = model.print_trainable_params()
print(f"{trainable_params:,}")
print(f"{all_params:,}")


from torch_geometric.data import Data
import numpy as np
from tqdm import tqdm
import torch

# Use the configured device instead of a second hard-coded copy, so the task
# graph always follows args.device.
device = args.device
tool_texts, tool2index, index2tool, edge_index, _, adj_g = load_tool(dataset_name=args.dataset)
# Node features are read from process/<dataset>.npy and paired with the tool
# graph's edge_index, then moved to `device`.
task_graph = Data(x=torch.FloatTensor(np.load(f"process/{args.dataset}.npy")), edge_index=edge_index).to(device)


# One progress-bar tick per training batch across all epochs.
num_training_steps = args.num_epochs * len(train_loader)
progress_bar = tqdm(range(num_training_steps))
best_val_loss = float('inf')
# Enable gradient checkpointing on the wrapped model (model.model).
model.model.gradient_checkpointing_enable()

  from .autonotebook import tqdm as notebook_tqdm


[Training Data] # Chain Samples 1514 (50.47)
[Data Split] # Train 3000  # Test 500


Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  1.70it/s]
  d_inv = np.power(row_sum, -0.5).flatten()
  return torch.sparse.FloatTensor(index, data, torch.Size(coo.shape))


Finish loading pre-trained Mistral-7B model!
10,492,928
7,252,225,024


  0%|          | 0/1600 [00:00<?, ?it/s]

In [2]:
# Pull the first (step, batch) pair from the training loader for interactive
# inspection. Raises StopIteration if the loader yields no batches.
step, batch = next(enumerate(train_loader))

In [3]:
task_graph

Data(x=[23, 1024], edge_index=[2, 225])

In [4]:
batch

{'id': tensor([ 228,  810,  618, 1876,   98, 1049]),
 'origin_id': ['14719993',
  '24037587',
  '10482344',
  '12432470',
  '64894887',
  '11478569'],
 'request': ['# TASK LIST #:\nToken Classification, Translation, Summarization, Question Answering, Conversational, Text Generation, Sentence Similarity, Tabular Classification, Object Detection, Image Classification, Image-to-Image, Image-to-Text, Text-to-Image, Text-to-Video, Visual Question Answering, Document Question Answering, Image Segmentation, Depth Estimation, Text-to-Speech, Automatic Speech Recognition, Audio-to-Audio, Audio Classification, Image Editing\n\n# GOAL #\nPlease understand the user\'s request and generate task steps and task invocation graph to solve it.\n\n# REQUIREMENT #\n1. The format must in a strict JSON format as {"task_steps": [ concrete step descriptions ], "task_nodes": [ a list of tasks to be executed in sequence to fulfill user\'s request ], "task_links": [{"source": "task name i", "target": "task name 

#### inside forward(self, samples, task_graph)

In [8]:
samples = batch.copy()
print(samples)

{'id': tensor([ 228,  810,  618, 1876,   98, 1049]), 'origin_id': ['14719993', '24037587', '10482344', '12432470', '64894887', '11478569'], 'request': ['# TASK LIST #:\nToken Classification, Translation, Summarization, Question Answering, Conversational, Text Generation, Sentence Similarity, Tabular Classification, Object Detection, Image Classification, Image-to-Image, Image-to-Text, Text-to-Image, Text-to-Video, Visual Question Answering, Document Question Answering, Image Segmentation, Depth Estimation, Text-to-Speech, Automatic Speech Recognition, Audio-to-Audio, Audio Classification, Image Editing\n\n# GOAL #\nPlease understand the user\'s request and generate task steps and task invocation graph to solve it.\n\n# REQUIREMENT #\n1. The format must in a strict JSON format as {"task_steps": [ concrete step descriptions ], "task_nodes": [ a list of tasks to be executed in sequence to fulfill user\'s request ], "task_links": [{"source": "task name i", "target": "task name j"}]}\n2. Th

In [9]:
print(samples['request'][0])

# TASK LIST #:
Token Classification, Translation, Summarization, Question Answering, Conversational, Text Generation, Sentence Similarity, Tabular Classification, Object Detection, Image Classification, Image-to-Image, Image-to-Text, Text-to-Image, Text-to-Video, Visual Question Answering, Document Question Answering, Image Segmentation, Depth Estimation, Text-to-Speech, Automatic Speech Recognition, Audio-to-Audio, Audio Classification, Image Editing

# GOAL #
Please understand the user's request and generate task steps and task invocation graph to solve it.

# REQUIREMENT #
1. The format must in a strict JSON format as {"task_steps": [ concrete step descriptions ], "task_nodes": [ a list of tasks to be executed in sequence to fulfill user's request ], "task_links": [{"source": "task name i", "target": "task name j"}]}
2. The generated task steps and task nodes can resolve the given user request perfectly. Task name must be selected from TASK LIST.
3. Task steps should strictly aligne

In [10]:
requests = model.tokenizer(samples["request"])
print(requests)
samples["label"]

{'input_ids': [[1, 422, 320, 16804, 393, 8048, 422, 28747, 13, 3856, 4950, 2500, 28725, 4335, 1465, 28725, 6927, 3479, 1837, 28725, 22478, 1094, 1616, 2131, 28725, 1325, 740, 1249, 28725, 7379, 26802, 28725, 318, 308, 636, 24232, 472, 28725, 14319, 1098, 4950, 2500, 28725, 4625, 384, 22820, 28725, 9833, 4950, 2500, 28725, 9833, 28733, 532, 28733, 4176, 28725, 9833, 28733, 532, 28733, 1874, 28725, 7379, 28733, 532, 28733, 4176, 28725, 7379, 28733, 532, 28733, 11761, 28725, 24497, 22478, 1094, 1616, 2131, 28725, 14873, 22478, 1094, 1616, 2131, 28725, 9833, 9594, 466, 352, 28725, 3995, 362, 3978, 8258, 28725, 7379, 28733, 532, 28733, 24812, 5295, 28725, 15939, 1711, 8819, 5295, 3523, 3159, 685, 28725, 16957, 28733, 532, 28733, 13361, 28725, 16957, 4950, 2500, 28725, 9833, 2690, 4328, 13, 13, 28771, 15044, 1086, 422, 13, 12069, 2380, 272, 2188, 28742, 28713, 2159, 304, 8270, 3638, 5944, 304, 3638, 1304, 10001, 5246, 298, 12049, 378, 28723, 13, 13, 28771, 4515, 28824, 5057, 896, 7178, 422, 

['{"task_steps": ["Step 1: Translate the text description from English to French.", "Step 2: Edit the image according to the translated text description.", "Step 3: Perform image segmentation on the edited image."], "task_nodes": ["Translation", "Image Editing", "Image Segmentation"], "task_links": [{"source": "Translation", "target": "Image Editing"}, {"source": "Image Editing", "target": "Image Segmentation"}]}',
 '{"task_steps": ["Step 1: Determine the similarity between two sentences"], "task_nodes": ["Sentence Similarity"], "task_links": []}',
 '{"task_steps": ["Step 1: Estimate the depth of objects in an image using Depth Estimation tool."], "task_nodes": ["Depth Estimation"], "task_links": []}',
 '{"task_steps": ["Step 1: Use Text-to-Speech tool to convert the user-specified text into speech"], "task_nodes": ["Text-to-Speech"], "task_links": []}',
 '{"task_steps": ["Step 1: Convert the user provided audio file into text"], "task_nodes": ["Automatic Speech Recognition"], "task_li

In [11]:
print(samples["label"][0])

{"task_steps": ["Step 1: Translate the text description from English to French.", "Step 2: Edit the image according to the translated text description.", "Step 3: Perform image segmentation on the edited image."], "task_nodes": ["Translation", "Image Editing", "Image Segmentation"], "task_links": [{"source": "Translation", "target": "Image Editing"}, {"source": "Image Editing", "target": "Image Segmentation"}]}


In [12]:
labels = model.tokenizer(samples["label"], add_special_tokens=False)
print(labels)

{'input_ids': [[9830, 5553, 28730, 16005, 1264, 7367, 9977, 28705, 28740, 28747, 4335, 10020, 272, 2245, 5436, 477, 4300, 298, 4949, 9191, 345, 9977, 28705, 28750, 28747, 12838, 272, 3469, 4771, 298, 272, 19004, 2245, 5436, 9191, 345, 9977, 28705, 28770, 28747, 2744, 674, 3469, 10424, 352, 356, 272, 19527, 3469, 611, 1181, 345, 5553, 28730, 12333, 1264, 7367, 25825, 548, 345, 4176, 2690, 4328, 548, 345, 4176, 9594, 466, 352, 8883, 345, 5553, 28730, 17052, 1264, 733, 6799, 1394, 1264, 345, 25825, 548, 345, 3731, 1264, 345, 4176, 2690, 4328, 7706, 9830, 1394, 1264, 345, 4176, 2690, 4328, 548, 345, 3731, 1264, 345, 4176, 9594, 466, 352, 17395, 9205], [9830, 5553, 28730, 16005, 1264, 7367, 9977, 28705, 28740, 28747, 5158, 21824, 272, 3684, 472, 1444, 989, 23748, 8883, 345, 5553, 28730, 12333, 1264, 7367, 26968, 636, 24232, 472, 8883, 345, 5553, 28730, 17052, 1264, 3980, 28752], [9830, 5553, 28730, 16005, 1264, 7367, 9977, 28705, 28740, 28747, 3978, 3314, 272, 8478, 302, 6697, 297, 396, 346

In [13]:
BOS = '<s>[INST]'
EOS_USER = '[/INST]'
EOS = '</s>'

IGNORE_INDEX = -100 

In [14]:
eos_tokens = model.tokenizer(EOS, add_special_tokens=False)
eos_user_tokens = model.tokenizer(EOS_USER, add_special_tokens=False)
bos_embeds = model.word_embedding(model.tokenizer(BOS, add_special_tokens=False, return_tensors='pt').input_ids[0].to(model.device))

In [15]:
eos_tokens

{'input_ids': [2], 'attention_mask': [1]}

In [16]:
eos_user_tokens

{'input_ids': [733, 28748, 16289, 28793], 'attention_mask': [1, 1, 1, 1]}

In [17]:
bos_embeds

tensor([[-4.3640e-03, -1.0633e-04, -5.6152e-03,  ..., -5.0545e-05,
         -1.1520e-03,  1.5926e-04],
        [-6.7139e-04, -5.7983e-04, -3.1891e-03,  ..., -1.7071e-04,
          3.1281e-04,  8.5449e-04],
        [ 1.4496e-04,  5.0354e-04, -2.3499e-03,  ..., -2.5024e-03,
          3.2349e-03, -2.8229e-03],
        [-4.1504e-03, -1.7548e-03,  3.7231e-03,  ..., -1.2589e-04,
         -9.2697e-04,  3.2196e-03]], device='cuda:0', dtype=torch.float16)

In [18]:
pad_embeds = model.word_embedding(torch.tensor(model.tokenizer.pad_token_id).to(model.device)).unsqueeze(0)


In [19]:
model.tokenizer.pad_token_id

0

In [20]:
print(pad_embeds)
print(pad_embeds.size())

tensor([[-0., 0., -0.,  ..., -0., -0., -0.]], device='cuda:0',
       dtype=torch.float16)
torch.Size([1, 4096])


In [21]:
len(samples['id'])

6

In [22]:
batch_size = len(samples['id'])
print(batch_size)

6


In [25]:
graph_embeds = model.encode_task_graph(task_graph, batch_size)
print(graph_embeds.shape)

torch.Size([6, 4096])


In [26]:
graph_embeds[0].shape

torch.Size([4096])

In [28]:
# NOTE(review): per the preceding cell outputs, graph_embeds is already
# (batch_size, 4096) and graph_embeds[0] is a 1-D vector of size 4096, so this
# mean over dim=0 collapses the hidden dimension to a single value (shape [1]).
# Presumably the mean was meant for per-node embeddings — confirm before reuse.
# This also rebinds graph_embeds, replacing the batch of embeddings.
graph_embeds = torch.mean(graph_embeds[0], dim=0, keepdim=True)
print(graph_embeds)

tensor([0.0041], device='cuda:0', grad_fn=<MeanBackward1>)


In [29]:
graph_embed = graph_embeds.repeat(batch_size,1)

In [30]:
graph_embed

tensor([[0.0041],
        [0.0041],
        [0.0041],
        [0.0041],
        [0.0041],
        [0.0041]], device='cuda:0', grad_fn=<RepeatBackward0>)

In [31]:
graph_embed.shape

torch.Size([6, 1])

In [32]:
model.max_new_tokens

256

In [33]:
labels.input_ids[0][:model.max_new_tokens] 

[9830,
 5553,
 28730,
 16005,
 1264,
 7367,
 9977,
 28705,
 28740,
 28747,
 4335,
 10020,
 272,
 2245,
 5436,
 477,
 4300,
 298,
 4949,
 9191,
 345,
 9977,
 28705,
 28750,
 28747,
 12838,
 272,
 3469,
 4771,
 298,
 272,
 19004,
 2245,
 5436,
 9191,
 345,
 9977,
 28705,
 28770,
 28747,
 2744,
 674,
 3469,
 10424,
 352,
 356,
 272,
 19527,
 3469,
 611,
 1181,
 345,
 5553,
 28730,
 12333,
 1264,
 7367,
 25825,
 548,
 345,
 4176,
 2690,
 4328,
 548,
 345,
 4176,
 9594,
 466,
 352,
 8883,
 345,
 5553,
 28730,
 17052,
 1264,
 733,
 6799,
 1394,
 1264,
 345,
 25825,
 548,
 345,
 3731,
 1264,
 345,
 4176,
 2690,
 4328,
 7706,
 9830,
 1394,
 1264,
 345,
 4176,
 2690,
 4328,
 548,
 345,
 3731,
 1264,
 345,
 4176,
 9594,
 466,
 352,
 17395,
 9205]

In [34]:
len(labels.input_ids[0][:model.max_new_tokens] )

108

In [44]:
eos_tokens

{'input_ids': [2], 'attention_mask': [1]}

In [45]:
labels.input_ids[0][:model.max_new_tokens] + eos_tokens.input_ids

[9830,
 5553,
 28730,
 16005,
 1264,
 7367,
 9977,
 28705,
 28740,
 5529,
 2841,
 28730,
 9701,
 2045,
 28730,
 1613,
 298,
 1388,
 272,
 13355,
 1369,
 304,
 3084,
 21448,
 1871,
 28725,
 3595,
 28725,
 304,
 11487,
 548,
 345,
 9977,
 28705,
 28750,
 5529,
 2841,
 28730,
 9701,
 2045,
 28730,
 1613,
 298,
 8270,
 264,
 13355,
 2264,
 2818,
 356,
 272,
 3857,
 1178,
 548,
 345,
 9977,
 28705,
 28770,
 5529,
 15382,
 28730,
 7141,
 17769,
 298,
 2623,
 396,
 15382,
 14211,
 288,
 7103,
 477,
 23330,
 904,
 602,
 4120,
 298,
 272,
 16099,
 302,
 22830,
 4120,
 28725,
 9720,
 1059,
 475,
 742,
 7645,
 4120,
 28725,
 1312,
 14517,
 272,
 25754,
 7103,
 304,
 23329,
 2948,
 15014,
 8883,
 345,
 5553,
 28730,
 12333,
 1264,
 7367,
 14908,
 28730,
 9701,
 2045,
 28730,
 1613,
 548,
 345,
 406,
 9405,
 28730,
 7141,
 17769,
 8883,
 345,
 5553,
 28730,
 17052,
 1264,
 733,
 6799,
 1394,
 1264,
 345,
 14908,
 28730,
 9701,
 2045,
 28730,
 1613,
 548,
 345,
 3731,
 1264,
 345,
 406,
 9405,
 2873

In [46]:
label_input_ids = labels.input_ids[0][:model.max_new_tokens] + eos_tokens.input_ids

In [None]:
requests.input_ids[0][:model.max_new_tokens]

[1,
 422,
 320,
 16804,
 393,
 8048,
 422,
 28747,
 13,
 2274,
 28730,
 7822,
 288,
 28725,
 3472,
 28730,
 9157,
 992,
 28730,
 14620,
 28725,
 4080,
 28730,
 28713,
 1033,
 28725,
 2841,
 28730,
 25973,
 2660,
 28730,
 1114,
 282,
 28730,
 5134,
 28725,
 3472,
 28730,
 357,
 9279,
 28725,
 8514,
 28730,
 5033,
 28730,
 3016,
 28725,
 8594,
 28730,
 20913,
 28725,
 10537,
 28730,
 262,
 18831,
 28725,
 9442,
 28730,
 452,
 7041,
 28725,
 14933,
 4609,
 28730,
 9157,
 992,
 28730,
 28717,
 13759,
 352,
 28725,
 9314,
 28730,
 2837,
 28730,
 7822,
 263,
 28725,
 1220,
 28730,
 357,
 9279,
 28725,
 7689,
 28730,
 3521,
 288,
 28730,
 3385,
 28725,
 726,
 28730,
 20913,
 28725,
 4530,
 28730,
 11009,
 28730,
 19963,
 28725,
 882,
 28730,
 14556,
 28725,
 3472,
 28730,
 3290,
 3507,
 1549,
 28725,
 5835,
 28730,
 28707,
 10106,
 28730,
 3385,
 28725,
 7223,
 28730,
 16714,
 28730,
 720,
 4078,
 28725,
 6790,
 28730,
 20913,
 28730,
 9307,
 28725,
 2093,
 28730,
 24945,
 28730,
 2360,
 2872

In [71]:
len(requests.input_ids[0][:model.max_new_tokens])

256

In [49]:
eos_user_tokens.input_ids

[733, 28748, 16289, 28793]

In [50]:
label_input_ids

[9830,
 5553,
 28730,
 16005,
 1264,
 7367,
 9977,
 28705,
 28740,
 5529,
 2841,
 28730,
 9701,
 2045,
 28730,
 1613,
 298,
 1388,
 272,
 13355,
 1369,
 304,
 3084,
 21448,
 1871,
 28725,
 3595,
 28725,
 304,
 11487,
 548,
 345,
 9977,
 28705,
 28750,
 5529,
 2841,
 28730,
 9701,
 2045,
 28730,
 1613,
 298,
 8270,
 264,
 13355,
 2264,
 2818,
 356,
 272,
 3857,
 1178,
 548,
 345,
 9977,
 28705,
 28770,
 5529,
 15382,
 28730,
 7141,
 17769,
 298,
 2623,
 396,
 15382,
 14211,
 288,
 7103,
 477,
 23330,
 904,
 602,
 4120,
 298,
 272,
 16099,
 302,
 22830,
 4120,
 28725,
 9720,
 1059,
 475,
 742,
 7645,
 4120,
 28725,
 1312,
 14517,
 272,
 25754,
 7103,
 304,
 23329,
 2948,
 15014,
 8883,
 345,
 5553,
 28730,
 12333,
 1264,
 7367,
 14908,
 28730,
 9701,
 2045,
 28730,
 1613,
 548,
 345,
 406,
 9405,
 28730,
 7141,
 17769,
 8883,
 345,
 5553,
 28730,
 17052,
 1264,
 733,
 6799,
 1394,
 1264,
 345,
 14908,
 28730,
 9701,
 2045,
 28730,
 1613,
 548,
 345,
 3731,
 1264,
 345,
 406,
 9405,
 2873

In [None]:
# user_requests + EOS_user + labels + EOS
# (label_input_ids already ends with the EOS token id, appended when it was built.)
# NOTE(review): the request is truncated with model.max_new_tokens (256 in this
# session), the same limit used for the labels, while args.max_txt_length is
# set to 512 — confirm which limit is intended for the request text.

input_ids = requests.input_ids[0][:model.max_new_tokens] + eos_user_tokens.input_ids + label_input_ids

In [56]:
len(input_ids)

407

In [53]:
input_embeds = model.word_embedding(torch.tensor(input_ids).to(model.device))

In [55]:
input_embeds.shape

torch.Size([407, 4096])

In [58]:
bos_embeds.shape

torch.Size([4, 4096])

In [68]:
graph_embeds[0].shape

torch.Size([4096])

In [None]:
# BOS + graph_embed + user_requests + EOS_user + labels + EOS

input_embeds = torch.cat([bos_embeds, graph_embeds[0].unsqueeze(0), input_embeds], dim=0)

In [None]:
# BOS + graph_embed + user_requests + EOS_user + labels + EOS
# 4   +       1         + 256         + 4       + 146    + 1

input_embeds.shape

torch.Size([412, 4096])

In [74]:
batch_inputs_embeds = []
batch_attention_mask = []
batch_label_input_ids = []

In [76]:
batch_inputs_embeds.append(input_embeds)

In [79]:
batch_attention_mask.append([1] * input_embeds.shape[0])
print(len([1] * input_embeds.shape[0]))
print([1] * input_embeds.shape[0])

412
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [80]:
label_input_ids = [IGNORE_INDEX] * (input_embeds.shape[0] - len(label_input_ids)) + label_input_ids

In [82]:
print(len(label_input_ids))
print(label_input_ids)

412
[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,

In [83]:
batch_label_input_ids.append(label_input_ids)

In [85]:
max_length = max([x.shape[0] for x in batch_inputs_embeds])

In [86]:
max_length

412

## plan dataset

In [46]:
# Fixed instruction block (goal + output-format requirements) placed between
# the task list and the user request. The wording is kept verbatim because it
# is part of the text the model is prompted with.
PROMPT = """\n\n# GOAL #\nPlease understand the user's request and generate task steps and task invocation graph to solve it.""" \
       + """\n\n# REQUIREMENT #\n1. The format must in a strict JSON format as {"task_steps": [ concrete step descriptions ], "task_nodes": [ a list of tasks to be executed in sequence to fulfill user's request ], "task_links": [{"source": "task name i", "target": "task name j"}]}\n""" \
       + """2. The generated task steps and task nodes can resolve the given user request perfectly. Task name must be selected from TASK LIST.\n""" \
       + """3. Task steps should strictly aligned with task nodes, and the number of task steps should be same with the task nodes.\n""" \
       + """4. The task links should reflect the dependencies among task nodes, i.e. the order in which the APIs are invoked.\n"""

# Read the tool descriptions; the context manager closes the file handle that
# the previous json.load(open(...)) call left open.
with open(f"../data/{args.dataset}/tool_desc.json", "r", encoding="utf-8") as tool_desc_file:
    tool_list = json.load(tool_desc_file)["nodes"]

# Comma-separated list of every tool id, shown to the model as the task list.
tool_string = "# TASK LIST #:\n" + ", ".join(task["id"] for task in tool_list)

# "{{user_request}}" is kept as a literal placeholder (this is not an f-string).
args.prompt = tool_string + PROMPT + """\n\n# USER REQUEST #: {{user_request}}\nNow please generate your result in a strict JSON format:\n# RESULT #:"""

In [48]:
print(args.prompt)

# TASK LIST #:
order_tracking, search_repair_provider, send_sms, special_vehicle_rental_service, search_agenda, menu_select_api, manage_schedule, detailed_inquiry, schedule_planner, appliance_repair_cancellation, flight_status_tracker, read_agenda, hotel_booking_query, import_schedule, travel_plan_maker, del_transaction, search_restaurants, train_ticket_query, foreign_currency_exchange, daily_schedule_manager, product_catalog_search, online_appointment_booking, academic_paper_search, book_meeting_room, loan_info_entry, smart_home_control, car_rental_cancelling, check_meeting_room_availability, train_ticket_booking, agenda_sorting, travel_group_schedule, luggage_check_in, create_document, change_password, souvenir_purchase, get_menu, select_best_hotel, search_conference_rooms, business_trip_ticket_search, flight_info_query, checkout_api, traffic_card, website_design_tool, business_communication, cruise_ticket_search, cruise_ship_booking, souvenir_recommender, company_vehicle_service, cr

## LLM test

In [6]:
import torch 
import torch.nn as nn 
from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer


# Checkpoint to test. The 4B variant is kept below (commented out) as an
# alternative; the 270m variant is the one currently active.
# llm_name = "google/gemma-3-4B-it"
llm_name = "google/gemma-3-270m-it"
# Load the causal-LM and its tokenizer from the same model id so the
# vocabulary used for encoding matches the one the model was loaded with.
model = AutoModelForCausalLM.from_pretrained(llm_name)
tokenizer = AutoTokenizer.from_pretrained(llm_name)

In [4]:
# Display the loaded model's configuration (shown as the cell's output).
model.config

Gemma3TextConfig {
  "_sliding_window_pattern": 6,
  "architectures": [
    "Gemma3ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "attn_logit_softcapping": null,
  "bos_token_id": 2,
  "eos_token_id": 1,
  "final_logit_softcapping": null,
  "head_dim": 256,
  "hidden_activation": "gelu_pytorch_tanh",
  "hidden_size": 640,
  "initializer_range": 0.02,
  "intermediate_size": 2048,
  "layer_types": [
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "sliding_attention",
    "full_attention"
  ],
  "max_position_embeddings": 32768,
  "model_type": "gemma3_text",
  "num_attention_heads": 4,
  "num_hidden_layers": 18,
  "n

In [7]:
# Draw several continuations of a fixed prompt using top-k sampling.
seed = 0
num_return_sequences = 5  # number of samples generated side by side in one batch
max_length = 30           # target sequence length in tokens, prompt included
top_k = 50                # sample only among the k most probable next tokens

model.eval()
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

prompt = "I'm a language model,"
encoded = tokenizer(prompt, return_tensors="pt")
# Repeat the prompt ids once per sample; the rows diverge as soon as different
# tokens are sampled for them.
x = encoded["input_ids"].repeat(num_return_sequences, 1)

# Inference only, so a single no_grad scope wraps the whole loop.
with torch.no_grad():
    while x.size(1) < max_length:
        # The whole sequence is re-fed each step (no cache is passed in).
        logits = model(x)["logits"]
        # Keep only the prediction for the last position of every row.
        logits = logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)
        topk_probs, topk_indices = torch.topk(probs, top_k, dim=-1)
        # multinomial treats each row as unnormalised weights, so the top-k
        # probabilities do not need to be renormalised first.
        ix = torch.multinomial(topk_probs, 1)
        # Map the sampled position within the top-k list back to a vocabulary id.
        xcol = torch.gather(topk_indices, -1, ix)
        x = torch.cat((x, xcol), dim=1)

for i in range(num_return_sequences):
    tokens = x[i, :max_length].tolist()
    decoded = tokenizer.decode(tokens, skip_special_tokens=True)
    print(">", decoded)


> I'm a language model, I can generate text in many languages. However, I cannot truly *understand* the meaning of a word or
> I'm a language model, and I don't have the ability to directly interact with the real world to experience the world in a way
> I'm a language model, I don't have the capacity to interact with the world in a real-time way. However, I
> I'm a language model, I cannot directly interact with the real world. I can only generate text based on the data I have been trained
> I'm a language model, and I can generate text. I can write stories, poems, articles, and code. I am also able
