In [1]:
from copy import deepcopy
from grok.transformer import Transformer, ExitTransformer
import grok
import os
import torch

In [2]:
parser = grok.training.add_args()
parser.set_defaults(logdir=os.environ.get("GROK_LOGDIR", "."))
hparams = parser.parse_args([])
hparams.datadir = os.path.abspath(hparams.datadir)
hparams.logdir = os.path.abspath(hparams.logdir)

In [3]:
model = grok.training.TrainableTransformer(hparams).float()

In [4]:
ckpt = torch.load("checkpoints/epoch_32768.ckpt", map_location="cpu")
model.load_state_dict(ckpt["state_dict"])

<All keys matched successfully>

In [5]:
model.prepare_data()

In [6]:
batch = next(iter(model.train_dataloader()))

In [7]:
batch

{'text': tensor([[  0,  63,   6,  22,   1,  63],
         [  0, 116,   6, 113,   1, 110],
         [  0,  52,   6, 112,   1,  45],
         ...,
         [  0, 107,   6, 104,   1,  92],
         [  0,  63,   6,  86,   1,  30],
         [  0,  40,   6,  61,   1,  79]]),
 'target': tensor([[ 63,   6,  22,   1,  63,   0],
         [116,   6, 113,   1, 110,   0],
         [ 52,   6, 112,   1,  45,   0],
         ...,
         [107,   6, 104,   1,  92,   0],
         [ 63,   6,  86,   1,  30,   0],
         [ 40,   6,  61,   1,  79,   0]])}

In [8]:
batch['text']

tensor([[  0,  63,   6,  22,   1,  63],
        [  0, 116,   6, 113,   1, 110],
        [  0,  52,   6, 112,   1,  45],
        ...,
        [  0, 107,   6, 104,   1,  92],
        [  0,  63,   6,  86,   1,  30],
        [  0,  40,   6,  61,   1,  79]])

In [9]:
def agreement(y_hat_exits: torch.Tensor, y_hat: torch.Tensor) -> torch.Tensor:

    # Get token predictions from output logits
    # (num_exits, batch_size, vocab_size, seq_len) -> (num_exits, batch_size, seq_len)
    y_hat_exits = torch.argmax(y_hat_exits, dim=-2)
    # (batch_size, vocab_size, seq_len) -> (batch_size, seq_len)
    y_hat = torch.argmax(y_hat, dim=-2)

    # Append final prediction to exit predictions to make it look like another exit layer
    # (num_exits, batch_size, seq_len) -> (num_exits + 1, batch_size, seq_len)
    y_hat_exits = torch.cat([y_hat_exits, y_hat.unsqueeze(0)], dim=0)

    # Generate table where each column is a list with boolean values indicating whether
    # the exit at that index agrees with the final prediction for all tokens in the sequence
    # ((batch_size, seq_len), (num_exits + 1, batch_size, seq_len)) -> (num_exits + 1, batch_size)
    layerwise_agreements = torch.min(y_hat == y_hat_exits, dim=-1).values

    # Reverse the order of the layers and transpose the table so that each row is an agreement list
    # (num_exits + 1, batch_size) -> (batch_size, reversed(num_exits + 1))
    layerwise_agreements = layerwise_agreements.flip(0).T

    # Find the index of the first exit layer in the reversed layerwise agreement list
    # that disagrees with the final prediction
    first_agree_layer = torch.min(layerwise_agreements, dim=-1).indices

    # Reverse the index to get the index of the first exit layer in the original layerwise agreement list
    minimum_exit_depth = layerwise_agreements.shape[1] - first_agree_layer

    # Since some sequences may not have any disagreement, we need to replace those entries with 0s
    minimum_exit_depth[torch.min(layerwise_agreements, dim=-1).values] = 0
    return minimum_exit_depth  # (batch_size,)


In [23]:
# Before funtion entry
y_hat, y_hat_exits, attentions, values = model(batch["text"])
y_hat = y_hat.transpose(-2, -1)
y_hat_exits = torch.stack(y_hat_exits).transpose(-2, -1)

# # Within funcion def
y_hat_exits = torch.argmax(y_hat_exits, dim=-2)
y_hat = torch.argmax(y_hat, dim=-2)
# # Append final prediction to exit predictions to make it look like another exit layer
# y_hat_exits = torch.cat([y_hat_exits, y_hat.unsqueeze(0)], dim=0)

y_hat.shape, y_hat_exits.shape, batch["target"].shape


(torch.Size([235, 6]), torch.Size([2, 235, 6]), torch.Size([235, 6]))

In [28]:
(y_hat_exits == batch["target"]).float().mean()

tensor(0.3365)

In [38]:
(batch["target"] == y_hat_exits[1])[17]

tensor([False, False, False, False,  True,  True])

In [40]:
model._accuracy(y_hat_exits[0], batch["target"])

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [33]:
torch.min(batch["target"] == y_hat_exits[1], dim=-1).values

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, 

In [11]:
model._agreement(y_hat_exits, y_hat).float().mean()

tensor(1.4723)

In [12]:
model._consistency(y_hat_exits, y_hat).float().mean()

tensor(1.5404)

In [13]:
model._agreement(y_hat_exits, y_hat, batch["target"]).float().mean()


tensor(3.)

In [14]:
model._consistency(y_hat_exits, y_hat, batch["target"]).float().mean()

tensor(0.)

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got -2)

In [16]:

# Generate table where each column is a list with boolean values indicating whether
# the exit at that index agrees with the final prediction for all tokens in the sequence
# ((batch_size, seq_len), (num_exits + 1, batch_size, seq_len)) -> (num_exits + 1, batch_size)
layerwise_agreements = torch.min(y_hat == y_hat_exits, dim=-1).values

In [21]:
batch["target"].shape, y_hat.shape

(torch.Size([235, 6]), torch.Size([6]))

In [17]:
layerwise_agreements

tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, F

# Scratchpad

In [15]:
layerwise_agreements = torch.min(y_hat == y_hat_exits, dim=-1).values
# reverse the order of the layers
layerwise_agreements = layerwise_agreements.flip(0).T
layerwise_agreements

  layerwise_agreements = layerwise_agreements.flip(0).T


tensor([[[False, False],
         [False, False],
         [False, False],
         ...,
         [False, False],
         [False, False],
         [False, False]],

        [[False, False],
         [False, False],
         [False, False],
         ...,
         [False, False],
         [False, False],
         [False, False]],

        [[False, False],
         [False, False],
         [False, False],
         ...,
         [False, False],
         [False, False],
         [False, False]],

        ...,

        [[False, False],
         [False, False],
         [False, False],
         ...,
         [False, False],
         [False, False],
         [False, False]],

        [[False, False],
         [False, False],
         [False, False],
         ...,
         [False, False],
         [False, False],
         [False, False]],

        [[False, False],
         [False, False],
         [False, False],
         ...,
         [False, False],
         [False, False],
         [False, 

In [16]:
first_agree_layer = torch.min(layerwise_agreements, dim=-1).indices
# first_agree_layer[torch.min(layerwise_agreements, dim=-1).values] = 0
first_agree_layer

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])

In [17]:
minimum_exit_depth = layerwise_agreements.shape[1] - first_agree_layer
minimum_exit_depth[torch.min(layerwise_agreements, dim=-1).values] = 0
minimum_exit_depth

tensor([[235, 235, 235,  ..., 235, 235, 235],
        [235, 235, 235,  ..., 235, 235, 235],
        [235, 235, 235,  ..., 235, 235, 235],
        ...,
        [235, 235, 235,  ..., 235, 235, 235],
        [235, 235, 235,  ..., 235, 235, 235],
        [235, 235, 235,  ..., 235, 235, 235]])

In [18]:
minimum_exit_depth.min(), minimum_exit_depth.max(), minimum_exit_depth.float().mean()

(tensor(235), tensor(235), tensor(235.))

In [19]:
torch.where(layerwise_agreements).indices

AttributeError: 'tuple' object has no attribute 'indices'

In [None]:
torch.argmax(agreements)

RuntimeError: "argmax_cpu" not implemented for 'Bool'

In [None]:
(y_hat == batch["target"]).float().mean()

tensor(0.0035)

In [None]:
torch.max(y_hat, dim=-1).indices.shape

torch.Size([235, 6])

In [None]:
torch.argmax(y_hat_exits, dim=-1).shape

torch.Size([2, 235, 6])

In [None]:
idxs = (torch.argmax(y_hat, dim=-1) == torch.argmax(y_hat_exits, dim=-1))

In [None]:
idxs.flatten(start_dim=1).shape

torch.Size([2, 1410])

In [None]:
torch.min((y_hat_exits[0] == y_hat), dim=-1).values.shape

torch.Size([235, 6])

In [None]:
idxs.argwhere().shape

NameError: name 'idxs' is not defined