In [1]:
import os
import json
import numpy as np
from quinine import QuinineArgumentParser
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import copy
import matplotlib.pyplot as plt
import torch.nn.functional as F

import sys
sys.path.append('../scripts')
from nano_gpt import GPT2Model, GPT2Config
from models import TransformerModelLooped
from utils import aggregate_metrics, get_model, eval_unlooped_model, eval_looped_model

In [2]:
torch.manual_seed(42)
device = torch.device('cuda:0')

fig_hparam = {
    'figsize': (8, 5),
    'labelsize': 28,
    'ticksize': 20,
    'linewidth': 5,
    'fontsize': 15,
    'titlesize': 20,
    'markersize': 15
}

# font specification
fontdict = {
    'family': 'serif',
    'size': fig_hparam['fontsize'],
}

In [3]:
class DecisionTree:
    def __init__(self, batch_size, n_points, n_dims, n_dims_truncated, device, depth=4):
        """
        batch_size: 1280
        n_points: 101
        n_dims: 20
        n_dims_truncated: 20
        device: torch.device('cuda:0')
        depth: 4
        """
        self.batch_size = batch_size
        self.n_points = n_points
        self.n_dims = n_dims
        self.n_dims_truncated = n_dims_truncated
        self.depth = depth

        # We represent the tree using an array (tensor). Root node is at index 0, its 2 children at index 1 and 2...
        # dt_tensor stores the coordinate used at each node of the decision tree.
        # Only indices corresponding to non-leaf nodes are relevant
        self.decisionTree_tensor = torch.randint(
            low=0,
            high=n_dims,
            size=(batch_size, 2 ** (depth + 1) - 1)  # size=(1280, 31)
        )

        # Target value at the leaf nodes.
        # Only indices corresponding to leaf nodes are relevant.
        self.target_tensor = torch.randn(self.decisionTree_tensor.shape)

        self.xs = torch.randn(batch_size, n_points, n_dims).to(device)  # [B, n, d]
        self.ys = self.evaluate(self.xs)

    def evaluate(self, xs_b):
        dt_tensor = self.decisionTree_tensor.to(xs_b.device)
        target_tensor = self.target_tensor.to(xs_b.device)

        ys_b = torch.zeros(
            xs_b.shape[0],  # 1280
            xs_b.shape[1],  # 101
            device=xs_b.device
        )

        for i in range(xs_b.shape[0]):
            xs_bool = xs_b[i] > 0
            # If a single decision tree present, use it for all the xs in the batch.
            if self.batch_size == 1:
                dt = dt_tensor[0]
                target = target_tensor[0]
            else:
                dt = dt_tensor[i]
                target = target_tensor[i]
            cur_nodes = torch.zeros(xs_b.shape[1], device=xs_b.device).long()
            for j in range(self.depth):
                cur_coords = dt[cur_nodes]
                cur_decisions = xs_bool[torch.arange(xs_bool.shape[0]), cur_coords]
                cur_nodes = 2 * cur_nodes + 1 + cur_decisions

            ys_b[i] = target[cur_nodes]

        return ys_b

In [4]:
sample_size = 1280
batch_size = 32
n_points = 101
n_dims_truncated = 20
n_dims = 20
result_dir = '../results2/decision_tree_loop'
run_id = '0926061635-DT_loop_L1_endsb70_T15-0602'

real_task = DecisionTree(
    batch_size=sample_size,  # 1280
    n_points=n_points,  # 101
    n_dims=n_dims,  # 20
    n_dims_truncated=n_dims_truncated,  # 20
    device=torch.device('cuda:0')
)

xs, ys = real_task.xs, real_task.ys

In [7]:
n_dims=20
n_positions=101
n_embd=256
n_layer=1
n_head = 8
model = TransformerModelLooped(n_dims, n_positions, n_embd, n_layer, n_head)
step = -1
model = get_model(model, result_dir, run_id, step)
model = model.to(device)

number of parameters: 0.79M
>>>>>>>>>>> ... model_path: ../results2/decision_tree_loop/0926061635-DT_loop_L1_endsb70_T15-0602/state.pt
>>>>>>>>>>> ... model:TransformerModelLooped(
  (_read_in): Linear(in_features=20, out_features=256, bias=True)
  (_backbone): GPT2Model(
    (transformer): ModuleDict(
      (wpe): Embedding(203, 256)
      (drop): Dropout(p=0.0, inplace=False)
      (h): ModuleList(
        (0): Block(
          (ln_1): LayerNorm()
          (attn): CausalSelfAttention(
            (c_attn): Linear(in_features=256, out_features=768, bias=True)
            (c_proj): Linear(in_features=256, out_features=256, bias=True)
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
          )
          (ln_2): LayerNorm()
          (mlp): MLP(
            (c_fc): Linear(in_features=256, out_features=1024, bias=True)
            (c_proj): Linear(in_features=1024, out_features=256, bias=True)
            (dropout): Dro

  state_dict = torch.load(model_path, map_location='cpu')['model_state_dict']


In [8]:
xs_train = xs[0: 32]
ys_train = ys[0: 32]

In [9]:
print(f"xs_train.shape ... {xs_train.shape}")
print(f"ys_train.shape ... {ys_train.shape}")

xs_train.shape ... torch.Size([32, 101, 20])
ys_train.shape ... torch.Size([32, 101])


In [15]:
with torch.no_grad():
    y_pred_total = torch.zeros(1280, 101)  # [N, n]
    y_pred_last = torch.zeros(1280, 200)  # [N, T]  T refers to the number of loops.
    for batch_idx in range(1):
        xs_train = xs[batch_idx * batch_size: (batch_idx + 1) * batch_size]
        ys_train = ys[batch_idx * batch_size: (batch_idx + 1) * batch_size]
        y_pred_list = model(xs_train, ys_train, 0, 200)  # list of [B, n], length T
        
        print(f"y_pred_list.length ... is {len(y_pred_list)}")
        # print(f"y_pred_list[0] ... is {y_pred_list[0]}")
        print(f"y_pred_list[0].shape ... is {y_pred_list[0].shape}")
        # print(f"y_pred_list ... is {y_pred_list}")
        
        y_pred_total[batch_idx * batch_size: (batch_idx + 1) * batch_size] = y_pred_list[-1].detach()
        print(f"y_pred_total.length ... is {len(y_pred_total)}")
        # print(f"y_pred_total[31] ... is {y_pred_total[31]}")
        # print(f"y_pred_total[32] ... is {y_pred_total[32]}")
        
        tmp_list = [y_pred[:, [-1]] for y_pred in y_pred_list]  # list of [B, 1], length T, get the last y value from the list whose length equals to 101
        print(f">>>>>> ... tmp_list.length ..is.. {len(tmp_list)}")
        print(f">>>>>> ... tmp_list[0].shape ..is.. {tmp_list[0].shape}")
        
        tmp_array = torch.cat(tmp_list, dim=1)  # [B, T]
        print(f">>>>>> ... tmp_array.shape ..is.. {tmp_array.shape}")
        
        y_pred_last[batch_idx * batch_size: (batch_idx + 1) * batch_size] = tmp_array
        
    err = (y_pred_total - ys.cpu()).square()  # [n,]
    loop_err = (y_pred_last - ys.cpu()[:, [-1]]).square()  # [N, T] - [N, 1]
loop_err

>>> ys_b_wide.shape ...is... torch.Size([32, 101, 20])
>>>>> embeds.shape ...is... torch.Size([32, 202, 256])
>>>>> pred_list[0].shape ...is... torch.Size([32, 101])
y_pred_list.length ... is 200
y_pred_list[0].shape ... is torch.Size([32, 101])
y_pred_total.length ... is 1280
>>>>>> ... tmp_list.length ..is.. 200
>>>>>> ... tmp_list[0].shape ..is.. torch.Size([32, 1])
>>>>>> ... tmp_array.shape ..is.. torch.Size([32, 200])


tensor([[1.0797e-01, 7.1041e-03, 7.7393e-02,  ..., 2.5507e-04, 2.4715e-04,
         2.5383e-04],
        [3.7938e-01, 8.1737e-02, 3.3259e-02,  ..., 1.4426e-04, 1.7626e-04,
         1.2831e-04],
        [4.3563e+00, 1.6584e+00, 1.7262e+00,  ..., 5.3805e-05, 6.7917e-05,
         5.3282e-05],
        ...,
        [7.5778e-01, 7.5778e-01, 7.5778e-01,  ..., 7.5778e-01, 7.5778e-01,
         7.5778e-01],
        [1.7536e-01, 1.7536e-01, 1.7536e-01,  ..., 1.7536e-01, 1.7536e-01,
         1.7536e-01],
        [1.5212e-01, 1.5212e-01, 1.5212e-01,  ..., 1.5212e-01, 1.5212e-01,
         1.5212e-01]])

In [None]:
print(loop_err.shape)
print(err.shape)
print(y_pred_total.shape)

In [None]:
err[0]