In [1]:
from agent.QCNN.QCNN import QCNN
from agent.Qtention.Qtention import Qtention
import torch
import math
import pickle
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Warm-up replay buffer recorded in Qtention format.
# NOTE(review): pickle.load can execute arbitrary code -- only load buffer files you produced yourself.
# TODO: this absolute, machine-specific path should come from a configurable data directory.
QTENTION_BUFFER_PATH = r"C:\Users\User\Documents\code\rl-training-gold-miner\buffers\warmup_buffer_qtention.pkl"
with open(QTENTION_BUFFER_PATH, "rb") as buffer_file:  # `with` closes the handle after loading
    buffer = pickle.load(buffer_file)

In [3]:
# QCNN hyper-parameters; n_actions sets the width of the model output (e.g. [1, 50] for a batch of one).
qcnn_config = dict(d_model=48, n_actions=50, d_hidden=64)
model = QCNN(**qcnn_config)

In [4]:
def qtention_to_qcnn(
    type_ids: torch.Tensor,          # [L] token type ids (env at 0)
    item_feats: torch.Tensor,        # [L, 10] per-token features (env at 0)
    mov_idx: torch.Tensor,           # [n_mov] token indices (env offset included), or None
    mov_feats: torch.Tensor,         # [n_mov, 3] movement features, or None
    length: int,                     # number of real tokens incl. env (int or 1-element tensor)
    max_items: "int | None" = 30,    # QCNN item slots; None -> L - 1
    sort_by_phi: bool = True,
    return_permutation: bool = False,
):
    """
    Convert one padded Qtention observation into QCNN-style inputs.

    Args:
      type_ids:    [L] Qtention type ids (ENV=0, real types 1..9, PAD=10).
      item_feats:  [L, 10] token features; row 0 is the environment token.
      mov_idx:     [n_mov] token indices (env offset included) that carry movement
                   features. Index 0 (env) and indices >= length are ignored. May be None.
      mov_feats:   [n_mov, 3] movement features aligned with mov_idx. May be None.
      length:      number of real tokens including env. It is clamped to
                   [1, min(L, len(type_ids))].
      max_items:   number of item slots in the output. None means L - 1.
      sort_by_phi: sort items by atan2(col 3, col 4), i.e. atan2(sin_phi, cos_phi).
      return_permutation: also return the sort permutation (see below).

    Returns:
      env_feats:   [10] float32 copy of item_feats[0]. It does not share storage
                   with the input.
      items_feats: [max_items, 23] float32 (10 base + 10 onehot(type) + 3 mov),
                   zero-padded.
      mask:        [max_items] float32 (1 real, 0 pad).

      If return_permutation=True, two more tensors follow:
      order:       [n_items] long, with sorted = original[order]. It is the
                   identity when sort_by_phi=False.
      inv_order:   [n_items] long, with inv_order[original_index] = sorted_index.

    Note: sorting happens before truncation. When there are more than max_items
    real items, the ones with the largest phi are dropped.
    """
    assert type_ids.ndim == 1, "type_ids must be [L]"
    assert item_feats.ndim == 2 and item_feats.size(-1) == 10, "item_feats must be [L,10]"
    padded_len = item_feats.size(0)
    if max_items is None:
        max_items = padded_len - 1

    # sanitize length: accept a 1-element tensor; keep at least the env token
    if isinstance(length, torch.Tensor):
        length = int(length.item())
    length = max(1, min(length, padded_len, int(type_ids.size(0))))

    device = item_feats.device
    f32 = torch.float32

    # Copy instead of a view: a view of item_feats keeps the whole [L, 10] storage
    # alive, and pickling it serializes that full storage. Cast to float32 like the items.
    env_feats = item_feats[0].to(f32, copy=True)

    n_items = length - 1
    if n_items <= 0:
        # env only: every item slot is padding
        items_padded = torch.zeros((max_items, 23), device=device, dtype=f32)
        mask = torch.zeros((max_items,), device=device, dtype=f32)
        if return_permutation:
            no_items = torch.zeros((0,), device=device, dtype=torch.long)
            return env_feats, items_padded, mask, no_items, no_items.clone()
        return env_feats, items_padded, mask

    x10 = item_feats[1:length].to(f32)                            # [n_items, 10]
    t = type_ids[1:length].to(device=device, dtype=torch.long)   # [n_items]

    # Qtention type ids: ENV=0, PAD=10, real types 1..9
    # QCNN type ids: real 0..8, PAD=9 (onehot size 10)
    qcnn_ids = (t - 1).clamp(0, 9)  # PAD(10)->9, real(1..9)->0..8
    onehot = torch.zeros((n_items, 10), device=device, dtype=f32)
    onehot.scatter_(1, qcnn_ids.view(-1, 1), 1.0)

    # dense movement [n_items, 3] from sparse (mov_idx uses token indices incl env)
    mov_all = torch.zeros((n_items, 3), device=device, dtype=f32)
    if mov_idx is not None and mov_feats is not None and mov_idx.numel() > 0:
        idx = mov_idx.to(device=device, dtype=torch.long)
        feats = mov_feats.to(device=device, dtype=f32)
        # only real, non-env tokens
        valid = (idx > 0) & (idx < length)
        if valid.any():
            # shift to item space [0..n_items-1]. With duplicate indices, which
            # row wins is unspecified (index_put semantics).
            mov_all[idx[valid] - 1] = feats[valid]

    items_feats = torch.cat([x10, onehot, mov_all], dim=1)  # [n_items, 23]

    # optional sort by phi to match QCNN preprocessing; stable => deterministic ties
    if sort_by_phi:
        phi = torch.atan2(x10[:, 3], x10[:, 4])
        order = torch.sort(phi, stable=True).indices
        items_feats = items_feats[order]
    else:
        order = torch.arange(n_items, device=device, dtype=torch.long)

    # pad/truncate to max_items (always a fresh tensor, never a view)
    keep = min(n_items, max_items)
    items_padded = torch.zeros((max_items, 23), device=device, dtype=f32)
    items_padded[:keep] = items_feats[:keep]

    mask = torch.zeros((max_items,), device=device, dtype=f32)
    mask[:keep] = 1.0

    if return_permutation:
        inv_order = torch.empty_like(order)
        inv_order[order] = torch.arange(n_items, device=device, dtype=torch.long)
        return env_feats, items_padded, mask, order, inv_order
    return env_feats, items_padded, mask

In [5]:
pot = buffer[0]  # first stored transition (a 13-tuple, unpacked in the next cell); NOTE(review): `pot` is rebound later to the converted observation

In [6]:

# Split the transition into its 13 fields: old_* / new_* presumably describe the
# observation before / after the action, followed by the terminal flag.
(
    old_type_ids, old_item_feats, old_mov_idx, old_mov_feats, old_length,
    action_buffer, reward_buffer,
    new_type_ids, new_item_feats, new_mov_idx, new_mov_feats, new_length,
    done,
) = pot

In [17]:
# Convert the post-action observation to QCNN inputs: (env_feats, items_feats, mask).
# NOTE(review): this rebinds `pot`, which previously held the raw transition.
pot = qtention_to_qcnn(
    type_ids=new_type_ids,
    item_feats=new_item_feats,
    mov_idx=new_mov_idx,
    mov_feats=new_mov_feats,
    length=new_length,
)

In [8]:
def convert_buffer_to_qcnn(transitions):
    """
    Convert Qtention-format transitions into QCNN-format transitions.

    Each input transition is the 13-tuple
      (old_type_ids, old_item_feats, old_mov_idx, old_mov_feats, old_length,
       action, reward,
       new_type_ids, new_item_feats, new_mov_idx, new_mov_feats, new_length,
       done)
    and each output transition is the 9-tuple
      (old_env_feats, old_items_feats, old_mask, action, reward,
       new_env_feats, new_items_feats, new_mask, done).

    Returns a new list; the input is not modified.
    """
    cnn_buffer = []
    for (old_type_ids, old_item_feats, old_mov_idx, old_mov_feats, old_length,
         action, reward,
         new_type_ids, new_item_feats, new_mov_idx, new_mov_feats, new_length,
         done) in tqdm(transitions):
        old_env_feats, old_items_feats, old_mask = qtention_to_qcnn(
            old_type_ids, old_item_feats, old_mov_idx, old_mov_feats, old_length)
        new_env_feats, new_items_feats, new_mask = qtention_to_qcnn(
            new_type_ids, new_item_feats, new_mov_idx, new_mov_feats, new_length)
        # clone() in case the converter returns views: pickling a view serializes
        # the whole underlying storage, not just the viewed slice.
        cnn_buffer.append((old_env_feats.clone(), old_items_feats, old_mask, action, reward,
                           new_env_feats.clone(), new_items_feats, new_mask, done))
    return cnn_buffer

# Defined but not invoked, so re-running the notebook does not redo the conversion:
#   cnn_buffer = convert_buffer_to_qcnn(buffer)

In [9]:
# pickle.dump(cnn_buffer, open("C:\\Users\\User\\Documents\\code\\rl-training-gold-miner\\buffers\\warmup_buffer_qcnn.pkl", "wb"))

In [11]:
# Smoke check: run one converted observation as a batch of 1 and inspect the output shape.
with torch.no_grad():  # inference-only shape check; no autograd graph needed
    model_out = model(pot[0].unsqueeze(0), pot[1].unsqueeze(0), pot[2].unsqueeze(0))
model_out.shape

torch.Size([1, 50])

In [18]:
pot[0]  # env feature vector: first element returned by qtention_to_qcnn

tensor([ 0.7333, -1.0000,  0.6648,  0.7470, -0.9000, -0.8000,  0.0000,  1.0000,
         0.0000,  0.0000])