In [1]:
import sys; sys.path.append('..')

In [2]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [3]:
from src.tree_dataset import TreeDataset

In [4]:
from transformer_lens import HookedTransformer, HookedTransformerConfig

In [5]:
from omegaconf import OmegaConf

conf = OmegaConf.create({
    "n_nodes" : 5,
    "model": {
        "d_model": 128,
        "d_head": 128,
        "n_layers": 1,
        "act_fn": 'relu',
        'attention_dir': 'causal',
    },
    "optimizer": {
        "lr": 1e-3,
        "weight_decay": 1e-3,
    },
    "batch_size": 64,
    "device": "mps",
})

device = conf.device
print(OmegaConf.to_yaml(conf))

n_nodes: 5
model:
  d_model: 128
  d_head: 128
  n_layers: 1
  act_fn: relu
  attention_dir: causal
optimizer:
  lr: 0.001
  weight_decay: 0.001
batch_size: 64
device: mps



In [6]:
class TreeTrainer:
    def __init__(self, conf):
        dataset = TreeDataset(conf.n_nodes)
        self.dataset = dataset
        self.dataloader = DataLoader(dataset, batch_size=conf['batch_size'])

        conf.model["n_ctx"] = dataset.tokenizer.MAX_SEQ_LEN
        conf.model["d_vocab"] = len(dataset.tokenizer.token2idx)
        
        model_cfg = HookedTransformerConfig(
            **conf.model
        )
        
        device = conf.device
        self.device = device

        self.model = HookedTransformer(model_cfg).to(device)
        self.optimizer = torch.optim.AdamW(self.model.parameters(), **conf.optimizer)


    def train_step(self, batch):
        input_idx = batch['input_idx'].to(device)
        mask = batch['task_mask'].to(device)
        
        print(input_idx[:1, :4])
        preds = self.model(input_idx)
        loss = F.cross_entropy(preds[mask], input_idx[mask])
        return loss

In [7]:
trainer = TreeTrainer(conf)

Moving model to device:  mps


In [8]:
for batch in trainer.dataloader:
    trainer.optimizer.zero_grad()
    loss = trainer.train_step(batch)
    loss.backward()
    trainer.optimizer.step()
    
    print(f'{loss=}')

tensor([[12,  5,  1,  6]], device='mps:0')
loss=tensor(2.9063, device='mps:0', grad_fn=<NllLossBackward0>)
tensor([[ 6,  5,  1, 10]], device='mps:0')
loss=tensor(1.4610, device='mps:0', grad_fn=<NllLossBackward0>)
tensor([[10,  9,  1, 12]], device='mps:0')
loss=tensor(0.7707, device='mps:0', grad_fn=<NllLossBackward0>)
tensor([[12,  5,  1, 12]], device='mps:0')
loss=tensor(0.4453, device='mps:0', grad_fn=<NllLossBackward0>)
tensor([[10,  5,  1,  8]], device='mps:0')
loss=tensor(0.2308, device='mps:0', grad_fn=<NllLossBackward0>)
tensor([[ 6,  5,  1, 12]], device='mps:0')
loss=tensor(0.1032, device='mps:0', grad_fn=<NllLossBackward0>)
tensor([[ 6,  9,  1, 12]], device='mps:0')
loss=tensor(0.0539, device='mps:0', grad_fn=<NllLossBackward0>)
tensor([[ 4, 13,  1, 10]], device='mps:0')
loss=tensor(0.0351, device='mps:0', grad_fn=<NllLossBackward0>)
tensor([[12,  7,  1,  8]], device='mps:0')
loss=tensor(0.0249, device='mps:0', grad_fn=<NllLossBackward0>)
tensor([[10, 13,  1,  8]], device='mp

KeyboardInterrupt: 

In [10]:
from src.tree_dataset import parse_input_idx

In [13]:
for batch in trainer.dataloader:
    print(parse_input_idx(batch['input_idx'][0], trainer.dataset.tokenizer)['tree'])
    break

     0
     |
   +-+---+
   |     |
   2     1
   |     |
 +-+   +-+
 |     |
 4     3


In [14]:
input_idx = batch['input_idx'].to(device)
mask = batch['task_mask'].to(device)

preds = trainer.model(input_idx)

In [17]:
detok(input_idx[0][mask[0]])

NameError: name 'detok' is not defined

In [None]:
idx = input_idx[0]
pred = preds[0]

In [None]:
detok(idx)

['0',
 '→2',
 ',',
 '4',
 '→3',
 ',',
 '2',
 '→1',
 ',',
 '3',
 '→0',
 '|',
 '1',
 ':',
 '4',
 '→4',
 '→3',
 '→0',
 '→2',
 '→1',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>']

In [None]:
parsed_tree = parse_input_idx(input_idx[0], trainer.dataset.tokenizer)
tree = parsed_tree['tree']
tree

   2
   |
 +-+-----+
 |       |
 3       1
         |
       +-+
       |
       0
       |
     +-+
     |
     4

In [None]:
trainer.dataset.tokenizer.detokenize(pred[mask[0]].argmax(dim=-1))

['→4', '→3']

In [None]:
# model(batch['input_idx'])

In [None]:
preds[mask].shape

torch.Size([13, 14])

In [None]:
from src.tree_dataset import input_tokens_to_tree, PAD_TOKEN, parse_input_idx

In [None]:
parse_input_idx(input_idx, dataset.tokenizer)

{'tree':    4
    |
  +-+-----+
  |       |
  0       1
          |
        +-+
        |
        2
        |
      +-+
      |
      3,
 'goal': '3',
 'root': '4',
 'path': ['→4', '→1', '→2', '→3']}

In [None]:
# Create a tensor
data = torch.arange(12).reshape(3, 4)  # Tensor shape [3, 4]
print("Original Data:\n", data)

# Create a boolean mask for selecting elements
mask = torch.tensor([[True, False, True, False],
                     [False, True, False, True],
                     [True, False, True, False]])

# Apply the mask
selected_data = data[mask]
print("Selected Data:\n", selected_data)

Original Data:
 tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
Selected Data:
 tensor([ 0,  2,  5,  7,  8, 10])


In [None]:
input_tokens_to_tree(input_tokens)

NameError: name 'input_tokens' is not defined