In [1]:
import os
import json
import numpy as np
from quinine import QuinineArgumentParser
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import copy
import matplotlib.pyplot as plt
import torch.nn.functional as F

import sys
sys.path.append('../scripts')
from nano_gpt import GPT2Model, GPT2Config
from models import TransformerModelLooped
from utils import aggregate_metrics, get_model, eval_unlooped_model, eval_looped_model

In [2]:
# Fix the RNG seed so the sampled task and every downstream number are reproducible.
torch.manual_seed(42)

# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Shared matplotlib styling values used by the figures in this notebook.
fig_hparam = {
    'figsize': (8, 5),
    'labelsize': 28,
    'ticksize': 20,
    'linewidth': 5,
    'fontsize': 15,
    'titlesize': 20,
    'markersize': 15
}

# font specification
fontdict = {
    'family': 'serif',
    'size': fig_hparam['fontsize'],
}

In [3]:
class DecisionTree:
    """A batch of random complete binary decision trees plus sampled inputs and labels.

    Every tree is stored as a flat array in heap order: the root is index 0 and
    node ``k`` has children ``2k + 1`` (taken when the tested coordinate is
    ``<= 0``) and ``2k + 2`` (taken when it is ``> 0``).
    """

    def __init__(self, batch_size, n_points, n_dims, n_dims_truncated, device, depth=4):
        """Sample the trees, the inputs ``xs`` and the matching labels ``ys``.

        Args:
            batch_size: number of independent trees / prompts (e.g. 1280).
            n_points: number of points per prompt (e.g. 101).
            n_dims: dimensionality of every input point (e.g. 20).
            n_dims_truncated: stored on the instance; not otherwise used by
                this class (e.g. 20).
            device: device that ``xs`` (and therefore ``ys``) is placed on.
            depth: number of decisions taken from the root to a leaf (default 4).
        """
        self.batch_size = batch_size
        self.n_points = n_points
        self.n_dims = n_dims
        self.n_dims_truncated = n_dims_truncated
        self.depth = depth

        # Coordinate tested at each node. A complete tree of this depth has
        # 2 ** (depth + 1) - 1 nodes; only the non-leaf entries are ever read.
        n_nodes = 2 ** (depth + 1) - 1
        self.decisionTree_tensor = torch.randint(
            low=0,
            high=n_dims,
            size=(batch_size, n_nodes)  # e.g. (1280, 31)
        )

        # Value attached to each node; only the leaf entries are ever read.
        self.target_tensor = torch.randn(self.decisionTree_tensor.shape)

        self.xs = torch.randn(batch_size, n_points, n_dims).to(device)  # [B, n, d]
        self.ys = self.evaluate(self.xs)  # [B, n]

    def evaluate(self, xs_b):
        """Route every point through its tree and return the leaf values.

        Args:
            xs_b: tensor of shape ``[B', n, d]``. If the task holds a single
                tree (``batch_size == 1``) that tree is applied to all ``B'``
                prompts; otherwise prompt ``i`` is evaluated with tree ``i``.

        Returns:
            Tensor of shape ``[B', n]`` on ``xs_b``'s device.

        Raises:
            IndexError: if ``xs_b`` holds more prompts than there are trees
                (and the task does not hold a single shared tree).
        """
        dt_tensor = self.decisionTree_tensor.to(xs_b.device)
        target_tensor = self.target_tensor.to(xs_b.device)
        n_prompts, n_points = xs_b.shape[0], xs_b.shape[1]

        if self.batch_size == 1:
            # A single tree is shared by every prompt in the batch.
            dt = dt_tensor[:1].expand(n_prompts, -1)
            target = target_tensor[:1].expand(n_prompts, -1)
        else:
            if n_prompts > dt_tensor.shape[0]:
                raise IndexError(
                    f"xs_b holds {n_prompts} prompts but only "
                    f"{dt_tensor.shape[0]} decision trees are available"
                )
            dt = dt_tensor[:n_prompts]
            target = target_tensor[:n_prompts]

        xs_bool = xs_b > 0  # [B', n, d]; True means "go to the right child"
        cur_nodes = torch.zeros(n_prompts, n_points, dtype=torch.long, device=xs_b.device)
        for _ in range(self.depth):
            # Coordinate tested at the node each point currently sits on.
            cur_coords = dt.gather(1, cur_nodes)  # [B', n]
            # Sign of that coordinate for each point.
            cur_decisions = xs_bool.gather(2, cur_coords.unsqueeze(-1)).squeeze(-1)  # [B', n]
            # Heap-order child index: left = 2k + 1, right = 2k + 2.
            cur_nodes = 2 * cur_nodes + 1 + cur_decisions

        return target.gather(1, cur_nodes)

In [4]:
# Evaluation set size and the mini-batch size used when running the model over it.
sample_size = 1280
batch_size = 32
# Prompt geometry: points per prompt and input dimensionality.
n_points = 101
n_dims_truncated = 20
n_dims = 20
# Location of the trained checkpoint to evaluate.
result_dir = '../results2/decision_tree_loop'
run_id = '0926061635-DT_loop_L1_endsb70_T15-0602'

# Sample one decision tree (and its prompt) per evaluation example. Reuse the
# notebook-level `device` instead of hard-coding the GPU a second time, so the
# target device is configured in exactly one place.
real_task = DecisionTree(
    batch_size=sample_size,  # 1280
    n_points=n_points,  # 101
    n_dims=n_dims,  # 20
    n_dims_truncated=n_dims_truncated,  # 20
    device=device
)

xs, ys = real_task.xs, real_task.ys  # [N, n, d], [N, n]

In [5]:
# Architecture hyperparameters of the checkpoint being loaded.
n_dims = 20
n_positions = 101
n_embd = 256
n_layer = 1
n_head = 8

# NOTE(review): step = -1 presumably selects the final saved state - confirm in utils.get_model.
step = -1

# Build the looped transformer, restore its weights, then move it to the target device.
model = get_model(
    TransformerModelLooped(n_dims, n_positions, n_embd, n_layer, n_head),
    result_dir,
    run_id,
    step,
).to(device)

number of parameters: 0.79M
>>>>>>>>>>> ... model_path: ../results2/decision_tree_loop\0926061635-DT_loop_L1_endsb70_T15-0602\state.pt
>>>>>>>>>>> ... model:TransformerModelLooped(
  (_read_in): Linear(in_features=20, out_features=256, bias=True)
  (_backbone): GPT2Model(
    (transformer): ModuleDict(
      (wpe): Embedding(203, 256)
      (drop): Dropout(p=0.0, inplace=False)
      (h): ModuleList(
        (0): Block(
          (ln_1): LayerNorm()
          (attn): CausalSelfAttention(
            (c_attn): Linear(in_features=256, out_features=768, bias=True)
            (c_proj): Linear(in_features=256, out_features=256, bias=True)
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
          )
          (ln_2): LayerNorm()
          (mlp): MLP(
            (c_fc): Linear(in_features=256, out_features=1024, bias=True)
            (c_proj): Linear(in_features=1024, out_features=256, bias=True)
            (dropout): Dro

In [6]:
# First mini-batch of prompts and labels; sized by `batch_size` rather than a literal 32.
xs_train = xs[:batch_size]
ys_train = ys[:batch_size]

In [7]:
# Sanity-check the shapes of the first mini-batch.
print(
    f"xs_train.shape ... {xs_train.shape}",
    f"ys_train.shape ... {ys_train.shape}",
    sep="\n",
)

xs_train.shape ... torch.Size([32, 101, 20])
ys_train.shape ... torch.Size([32, 101])


In [11]:
# Run the looped model over the whole evaluation set, one mini-batch at a time, and
# collect (a) the prediction for every point after the last loop iteration and
# (b) the prediction for the final point after each loop iteration.
n_loops = 200  # T; passed as the model's last argument (the recorded run returned 200 outputs)

with torch.no_grad():
    ys_cpu = ys.cpu()  # the result buffers live on the CPU, so compare there
    y_pred_total = torch.zeros(sample_size, n_points)  # [N, n]
    y_pred_last = torch.zeros(sample_size, n_loops)  # [N, T]

    # Stepping by batch_size (rather than looping sample_size // batch_size times)
    # also covers a trailing partial batch; slices past the end are clamped.
    for batch_start in range(0, sample_size, batch_size):
        batch_slice = slice(batch_start, batch_start + batch_size)
        # NOTE: this rebinds the notebook-level xs_train / ys_train to the current batch.
        xs_train = xs[batch_slice]
        ys_train = ys[batch_slice]

        # One prediction tensor per loop iteration: list of [B, n], length T.
        y_pred_list = model(xs_train, ys_train, 0, n_loops)

        # Prediction for the final point of each prompt at every loop iteration.
        last_y_pred_list = [y_pred[:, [-1]] for y_pred in y_pred_list]  # list of [B, 1], length T
        last_y_pred_array = torch.cat(last_y_pred_list, dim=1)  # [B, T]
        y_pred_last[batch_slice] = last_y_pred_array

        # Predictions for every point after the last loop iteration: [B, n].
        y_pred_total[batch_slice] = y_pred_list[-1]

    # Squared error per point after the final loop iteration.
    total_err = (y_pred_total - ys_cpu).square()  # [N, n]
    # Squared error on the final point as a function of the loop iteration.
    loop_iter_err = (y_pred_last - ys_cpu[:, [-1]]).square()  # [N, T] - [N, 1]
loop_iter_err

>>> ys_b_wide.shape ...is... torch.Size([32, 101, 20])
>>>>> embeds.shape ...is... torch.Size([32, 202, 256])
>>>>> pred_list[0].shape ...is... torch.Size([32, 101])
y_pred_list.length ... is 200
y_pred_list[0].shape ... is torch.Size([32, 101])
y_pred_list[1].shape ... is torch.Size([32, 101])
y_pred_list[-1].shape ... is torch.Size([32, 101])
y_pred_total.length ... is 1280
>>>>>> ... tmp_list.length ..is.. 200
>>>>>> ... tmp_list[0].shape ..is.. torch.Size([32, 1])
>>>>>> ... tmp_array.shape ..is.. torch.Size([32, 200])


tensor([[1.0797e-01, 7.1042e-03, 7.7393e-02,  ..., 2.5754e-04, 2.7099e-04,
         2.6570e-04],
        [3.7938e-01, 8.1737e-02, 3.3259e-02,  ..., 1.4428e-04, 1.7628e-04,
         1.2832e-04],
        [4.3563e+00, 1.6584e+00, 1.7262e+00,  ..., 5.3808e-05, 6.7921e-05,
         5.3269e-05],
        ...,
        [7.5778e-01, 7.5778e-01, 7.5778e-01,  ..., 7.5778e-01, 7.5778e-01,
         7.5778e-01],
        [1.7536e-01, 1.7536e-01, 1.7536e-01,  ..., 1.7536e-01, 1.7536e-01,
         1.7536e-01],
        [1.5212e-01, 1.5212e-01, 1.5212e-01,  ..., 1.5212e-01, 1.5212e-01,
         1.5212e-01]])

In [9]:
# Shapes of the error tensors produced by the evaluation cell. The names match the
# variables that cell actually defines (`loop_iter_err`, `total_err`), so this cell
# no longer depends on stale kernel state.
print(loop_iter_err.shape)
print(total_err.shape)
print(y_pred_total.shape)

torch.Size([1280, 200])
torch.Size([1280, 101])
torch.Size([1280, 101])


In [10]:
# Per-point squared error of the first prompt after the final loop iteration.
total_err[0]

tensor([3.0299e-01, 3.5761e-02, 1.4700e-02, 2.1203e-01, 1.4089e+00, 1.9964e-01,
        8.7453e-03, 8.8008e-01, 2.2243e-02, 3.0616e-01, 3.2261e-01, 2.4495e-01,
        9.1069e-01, 6.2506e-03, 6.3337e-02, 4.5949e-02, 1.7566e-01, 5.4632e-03,
        1.6040e-02, 1.3257e-01, 3.9896e-01, 2.4753e-01, 6.1265e-03, 1.2358e-02,
        2.4915e-03, 7.0874e-02, 1.8792e-05, 6.1906e-02, 5.0660e-01, 1.6655e-04,
        1.7580e-01, 4.0564e-03, 3.6996e-01, 6.0228e-04, 5.5347e-03, 3.5812e-07,
        1.5724e-03, 1.1645e-03, 5.4598e-02, 1.1050e-07, 1.0961e-06, 2.5093e-05,
        5.5684e-06, 2.3651e-01, 9.1661e-02, 7.1861e-05, 4.8799e-03, 6.7404e-04,
        5.0427e-02, 1.4975e-06, 1.9749e-05, 2.3235e-01, 1.3116e-03, 1.7315e-05,
        1.5604e-03, 7.7848e-05, 1.1005e-02, 1.5388e-04, 8.5312e-03, 5.0744e-04,
        3.4819e-05, 1.3055e-07, 5.6042e-06, 6.8602e-05, 6.7886e-06, 4.1041e-03,
        1.5254e-04, 1.0620e-04, 2.0148e-05, 1.9869e-01, 1.1185e+00, 7.8692e-05,
        9.9935e-05, 1.8576e-04, 1.4643e-