In [1]:
import os
import json
import numpy as np
from quinine import QuinineArgumentParser
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import copy
import matplotlib.pyplot as plt
import torch.nn.functional as F

import sys
sys.path.append('../scripts')
from nano_gpt import GPT2Model, GPT2Config
from models import TransformerModel
from utils import aggregate_metrics, get_model, eval_unlooped_model, eval_looped_model

In [2]:
# Seed torch's global RNG so the randomly sampled evaluation tasks below are reproducible.
torch.manual_seed(42)
# Fall back to the CPU so the notebook still runs (slowly) on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Checkpoint location of the trained run that is evaluated in this notebook.
result_dir = '../results2/decision_tree_baseline'
run_id = '0926062109-DT_baseline-2504'

In [3]:
# --- Evaluation data ------------------------------------------------------
SAMPLE_SIZE = 1280      # number of prompts, each labelled by its own random decision tree
CONTEXT_SIZE = 101      # points per prompt (also the model's n_positions)
INPUT_DIMS = 20         # dimensionality of every input x
N_DIMS_TRUNCATED = 20   # forwarded to DecisionTree as n_dims_truncated
BATCH_SIZE = 32         # NOTE(review): the timing cell slices with its own `bs`; confirm where this is used

# --- Transformer architecture (passed to TransformerModel before the checkpoint is loaded) ---
EMBEDDING_DIM = 256
N_HEAD = 8
N_LAYER = 12
N_LAYER_LOOP = 1        # presumably for a looped-model variant; not used by the baseline cells shown

# --- Looping ----------------------------------------------------------------
LOOP_ITER_NUM = 200     # presumably the number of loop iterations for looped evaluation — confirm

In [4]:
class DecisionTree:
    """Random binary decision-tree regression tasks.

    Each of the ``batch_size`` tasks owns one tree of depth ``depth``. Every
    internal node tests whether one input coordinate is positive, and every leaf
    holds a standard-normal target value. On construction the class samples
    Gaussian inputs ``xs`` of shape [batch_size, n_points, n_dims] and labels them
    with ``ys = evaluate(xs)`` of shape [batch_size, n_points].
    """

    def __init__(self,
                 batch_size, n_points, n_dims, n_dims_truncated, device, depth=4):
        """
        Args:
            batch_size: number of independent trees / prompts (e.g. 1280).
            n_points: number of points sampled per prompt (e.g. 101).
            n_dims: input dimensionality; split coordinates are drawn from [0, n_dims).
            n_dims_truncated: stored on the instance only; nothing in this class reads it.
            device: device on which ``xs`` (and therefore ``ys``) are placed.
            depth: number of decisions from the root to a leaf (default 4).
        """
        self.batch_size = batch_size
        self.n_points = n_points
        self.n_dims = n_dims
        self.n_dims_truncated = n_dims_truncated
        self.depth = depth

        # Each tree is a flat array in level order: node k has children 2k+1
        # (taken when the tested coordinate is <= 0) and 2k+2 (when it is > 0).
        # decisionTree_tensor stores the coordinate tested at each node; only the
        # entries of internal (non-leaf) nodes are ever read.
        self.decisionTree_tensor = torch.randint(
            low=0,
            high=n_dims,
            size=(batch_size, 2 ** (depth + 1) - 1)  # full tree: 2^(depth+1) - 1 nodes
        )

        # Target value per node; only the leaf entries are ever read.
        self.target_tensor = torch.randn(self.decisionTree_tensor.shape)

        self.xs = torch.randn(batch_size, n_points, n_dims).to(device)  # [B, n, d]
        self.ys = self.evaluate(self.xs)

    def evaluate(self, xs_b):
        """Label ``xs_b`` with the sampled trees.

        Args:
            xs_b: tensor of shape [B, n, n_dims]. Row ``i`` is routed through
                tree ``i``. If the instance holds a single tree
                (``batch_size == 1``), that tree is shared by all B rows.

        Returns:
            Tensor of shape [B, n] on ``xs_b``'s device, holding the leaf target
            reached by each point.

        Raises:
            IndexError: if B exceeds the number of sampled trees and
                ``batch_size != 1``.
        """
        n_rows, n_pts = xs_b.shape[0], xs_b.shape[1]
        dt_tensor = self.decisionTree_tensor.to(xs_b.device)
        target_tensor = self.target_tensor.to(xs_b.device)

        if self.batch_size == 1:
            # A single decision tree is shared by every row of the batch.
            dt_tensor = dt_tensor.expand(n_rows, -1)
            target_tensor = target_tensor.expand(n_rows, -1)
        elif n_rows > self.batch_size:
            raise IndexError(
                f"xs_b has {n_rows} rows but only {self.batch_size} trees were sampled"
            )
        else:
            dt_tensor = dt_tensor[:n_rows]
            target_tensor = target_tensor[:n_rows]

        # Advance all B x n points one tree level per iteration, instead of
        # looping over the rows in Python.
        cur_nodes = torch.zeros(n_rows, n_pts, dtype=torch.long, device=xs_b.device)
        for _ in range(self.depth):
            # Coordinate tested at each point's current node: [B, n]
            cur_coords = torch.gather(dt_tensor, 1, cur_nodes)
            # Value of that coordinate for each point: [B, n]
            cur_values = torch.gather(xs_b, 2, cur_coords.unsqueeze(-1)).squeeze(-1)
            cur_nodes = 2 * cur_nodes + 1 + (cur_values > 0).long()

        return torch.gather(target_tensor, 1, cur_nodes)

In [5]:
# Sample the evaluation set: one random decision tree per prompt (default depth 4),
# Gaussian inputs, and the tree's leaf values as targets.
decision_tree_task = DecisionTree(
    batch_size=SAMPLE_SIZE,
    n_points=CONTEXT_SIZE,
    n_dims=INPUT_DIMS,
    n_dims_truncated=N_DIMS_TRUNCATED,
    device=device,  # reuse the configured device instead of hard-coding 'cuda:0'
)
xs, ys = decision_tree_task.xs, decision_tree_task.ys
assert xs.shape == (SAMPLE_SIZE, CONTEXT_SIZE, INPUT_DIMS)
assert ys.shape == (SAMPLE_SIZE, CONTEXT_SIZE)

In [6]:
# Sanity check: xs is [prompts, points, dims], ys is [prompts, points].
print(xs.shape, ys.shape, sep='\n')

torch.Size([1280, 101, 20])
torch.Size([1280, 101])


In [7]:
# Load the baseline checkpoint (run `DT_baseline`) into a plain TransformerModel.
# The architecture hyperparameters must agree with those the run was trained with.
model = TransformerModel(
    n_dims=INPUT_DIMS,
    n_positions=CONTEXT_SIZE,
    n_embd=EMBEDDING_DIM,
    n_layer=N_LAYER,
    n_head=N_HEAD
)

step = -1  # checkpoint step passed to get_model; NOTE(review): presumably -1 = latest saved state — confirm

model = get_model(model, result_dir, run_id, step)
model = model.to(device)
# Inference only: switch to eval mode so layers such as dropout behave deterministically.
# The trailing ';' suppresses the long module repr, which get_model already prints.
model.eval();

number of parameters: 9.48M
>>>>>>>>>>> ... model_path: ../results2/decision_tree_baseline\0926062109-DT_baseline-2504\state.pt
>>>>>>>>>>> ... model:TransformerModel(
  (_read_in): Linear(in_features=20, out_features=256, bias=True)
  (_backbone): GPT2Model(
    (transformer): ModuleDict(
      (wpe): Embedding(203, 256)
      (drop): Dropout(p=0.0, inplace=False)
      (h): ModuleList(
        (0): Block(
          (ln_1): LayerNorm()
          (attn): CausalSelfAttention(
            (c_attn): Linear(in_features=256, out_features=768, bias=True)
            (c_proj): Linear(in_features=256, out_features=256, bias=True)
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
          )
          (ln_2): LayerNorm()
          (mlp): MLP(
            (c_fc): Linear(in_features=256, out_features=1024, bias=True)
            (c_proj): Linear(in_features=1024, out_features=256, bias=True)
            (dropout): Dropout(p=0.0, i

TransformerModel(
  (_read_in): Linear(in_features=20, out_features=256, bias=True)
  (_backbone): GPT2Model(
    (transformer): ModuleDict(
      (wpe): Embedding(203, 256)
      (drop): Dropout(p=0.0, inplace=False)
      (h): ModuleList(
        (0): Block(
          (ln_1): LayerNorm()
          (attn): CausalSelfAttention(
            (c_attn): Linear(in_features=256, out_features=768, bias=True)
            (c_proj): Linear(in_features=256, out_features=256, bias=True)
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
          )
          (ln_2): LayerNorm()
          (mlp): MLP(
            (c_fc): Linear(in_features=256, out_features=1024, bias=True)
            (c_proj): Linear(in_features=1024, out_features=256, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (1): Block(
          (ln_1): LayerNorm()
          (attn): CausalSelfAttention(
            (c_attn): Lin

In [8]:
# Forward-pass timing uses only the first `bs` prompts of the sampled set.
bs = 256
xs_train, ys_train = xs[:bs], ys[:bs]

In [9]:
%%timeit 
with torch.no_grad():
    result_y = model(xs_train, ys_train)  # [1280, 101]
print(result_y.shape)

torch.Size([256, 101])
torch.Size([256, 101])
torch.Size([256, 101])
torch.Size([256, 101])
torch.Size([256, 101])
torch.Size([256, 101])
torch.Size([256, 101])
torch.Size([256, 101])
10.7 ms ± 953 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
