In [1]:
# Reload edited project modules (models/, data/) automatically before each cell
# runs, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Optional debugging toggle — uncomment to enter the post-mortem debugger on exceptions: %pdb
import logging
# Attach a default stderr handler to the root logger and lower its threshold to INFO.
logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)

#### Define the model that makes per-token predictions

In [9]:
from models._core.binary_multitask_decoder_llm import BinaryMultitaskDecoderLightningModule
from models._utils.common import default_checkpoints_load_func, load_model
from models._core.binary_multitask_decoder_llm import constructor
import functools


# Decoder architecture. These values presumably have to match the configuration
# the checkpoint below was trained with, or its state_dict will not load.
decoder_kwargs = dict(
    model_dim=768,
    key_dim=48,
    value_dim=48,
    num_heads=16,
    num_layers=4,
    dropout=0.0,
    hidden_dim=3072,
    num_tokens=10000,
    max_len=2048,
    num_concepts=4,
    concept_generator_hidden_dims=[10],
)

# NOTE(review): absolute, machine-specific path — consider making this configurable.
checkpoint_path = "/home/ubuntu/Documents/infembed/examples/tinystories_cb/hydra_outputs/lightning_train/concept_impute/lightning_logs/xs0p2db6/checkpoints/epoch=0-step=3600.ckpt"

# Build an untrained module, then fill it from the Lightning checkpoint's
# 'state_dict' entry and switch it to eval mode.
untrained_module = BinaryMultitaskDecoderLightningModule(decoder=constructor(**decoder_kwargs))
state_dict_loader = functools.partial(default_checkpoints_load_func, key='state_dict')

model = load_model(
    model=untrained_module,
    eval=True,
    checkpoints_load_func=state_dict_loader,
    checkpoint=checkpoint_path,
)
model

<All keys matched successfully>


BinaryMultitaskDecoderLightningModule(
  (decoder): Decoder(
    (decoder_layers): ModuleList(
      (0-3): 4 x DecoderLayer(
        (attention): MultiAttention(
          (attentions): ModuleList(
            (0-15): 16 x Attention(
              (key): Linear(in_features=768, out_features=48, bias=True)
              (query): Linear(in_features=768, out_features=48, bias=True)
              (value): Linear(in_features=768, out_features=48, bias=True)
            )
          )
          (projection): Linear(in_features=768, out_features=768, bias=True)
        )
        (feedforward): FeedForward(
          (linear_1): Linear(in_features=768, out_features=3072, bias=True)
          (linear_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (attention_sublayer): Sublayer(
          (dropout): Dropout(p=0.0, inplace=False)
          (layer_norm): LayerNorm()
        )
        (feedforward_sublayer): Sublayer(
 

#### Define the dataloader that yields the examples to get predictions for

In [10]:
from data._core.tinystories import tinystories_cb_dataloader_with_julius_mix


# NOTE(review): absolute, machine-specific paths to the `-valid` TinyStories files.
valid_text_path = '/home/ubuntu/Documents/infembed/files/tinystories/TinyStoriesV2-GPT4-valid.txt'
valid_concepts_path = '/home/ubuntu/Documents/infembed/files/tinystories/TinyStories-valid-with-concepts.csv'

dataloader = tinystories_cb_dataloader_with_julius_mix(
    orig_path=valid_text_path,
    julius_path=valid_concepts_path,
    # Should agree with num_concepts used to build the model.
    num_concepts=4,
    max_len=512,
    batch_size=40,
    # Presumably selects rows 73636..73676 of the concept-annotated file
    # (40 examples, i.e. one batch) — TODO confirm end-exclusivity.
    julius_start_num=73636,
    julius_end_num=73676,
    # Presumably: take no examples from the plain-text file — TODO confirm.
    orig_len=0,
)

tensor(0) 0
tensor(40) 40


#### Define how to get the tokens for a batch

In [11]:
from data._core.tinystories import tinystories_tokenizer


# Tokenizer that get_batch_tokens uses to decode input ids back into token strings.
tokenizer = tinystories_tokenizer()

def get_batch_tokens(batch, text_tokenizer=None):
    """Decode the unmasked input ids of every example in ``batch`` into token strings.

    Args:
        batch: mapping with ``"input_ids"`` and ``"attention_mask"``; the two are
            zipped row by row (one row per example) and position by position.
        text_tokenizer: optional object with a ``decode(token_id)`` method.
            Defaults to the module-level ``tokenizer``.

    Returns:
        A list with one entry per example. Each entry is the list of decoded
        tokens at the positions whose attention-mask value is 1.
    """
    if text_tokenizer is None:
        text_tokenizer = tokenizer
    # Decode one id at a time so every position keeps its own token string,
    # aligned one-to-one with the per-position predictions.
    return [
        [
            text_tokenizer.decode(token_id)
            for token_id, mask in zip(input_ids, attention_mask)
            if mask == 1
        ]
        for input_ids, attention_mask in zip(batch["input_ids"], batch["attention_mask"])
    ]

#### Define how to get per-token predictions for a batch

In [12]:
def get_batch_predictions(batch, predictor=None):
    """Return per-task logits for the unmasked tokens of every example in ``batch``.

    Args:
        batch: mapping passed straight to the model; it must contain
            ``"attention_mask"`` with one row of 0/1 flags per example.
        predictor: optional callable returning a dict with ``"prediction_logits"``.
            Defaults to the module-level ``model``.

    Returns:
        A nested list indexed ``[task][example][token]`` of scalar logit tensors.
        Only positions whose attention-mask value is 1 are kept.
    """
    # Local import: torch is first imported in a later cell of this notebook.
    import torch

    if predictor is None:
        predictor = model
    # Inference only: don't build an autograd graph for the whole batch.
    # Call the module (not .forward) so any registered nn.Module hooks run.
    with torch.no_grad():
        prediction_logits = predictor(batch)["prediction_logits"]

    # prediction_logits is indexed (example, position, task): axis 0 pairs with
    # attention-mask rows, axis 1 with mask entries, axis 2 is the task.
    task_prediction_logits = []
    for t in range(prediction_logits.shape[2]):
        task_logits = prediction_logits[:, :, t]
        task_prediction_logits.append(
            [
                [logit for logit, mask in zip(example_logits, example_mask) if mask == 1]
                for example_logits, example_mask in zip(task_logits, batch["attention_mask"])
            ]
        )
    return task_prediction_logits

#### Define how to display each token with its predicted probability for a batch

In [13]:
import torch


def plot(example_tokens, task_prediction_logits, task):
    """Print one line per example, pairing each token with its predicted probability.

    Despite the name, nothing is drawn: output is plain text. ``task`` selects
    which entry of ``task_prediction_logits`` to show. Each logit is passed
    through a sigmoid and shown to two decimals as ``(token, p)``.
    """
    logits_for_task = task_prediction_logits[task]
    for tokens, logits in zip(example_tokens, logits_for_task):
        pairs = (
            f"({token}, {torch.sigmoid(logit): .2f})"
            for token, logit in zip(tokens, logits)
        )
        print(" ".join(pairs))

#### Print per-token predictions for each task

In [17]:
import itertools

num_batches = 1
# One entry per concept head (the model and dataloader above use num_concepts=4).
tasks = [0, 1, 2, 3]

# islice stops after num_batches batches. zip(dataloader, range(num_batches)) would
# fetch (and discard) one extra batch: it advances the dataloader before it
# notices that range is exhausted.
for batch in itertools.islice(dataloader, num_batches):
    example_tokens = get_batch_tokens(batch)
    task_prediction_logits = get_batch_predictions(batch)
    for t in tasks:
        print(f"\n ### task {t} ###")
        plot(example_tokens, task_prediction_logits, t)



 ### task 0 ###
(<|endoftext|>,  0.00) (I,  0.00) ( am,  0.61) ( such,  0.25) ( a,  0.07) ( happy,  0.08) ( man,  0.12) (.,  0.00) ( I,  0.05) ( enjoy,  0.18) ( eating,  0.07) ( oranges,  0.35) (.,  0.04) ( ",  0.48) (Hello,  0.20) (,,  0.62) ( children,  0.47) (!",  0.73) ( the,  0.01) ( voice,  0.01) ( says,  0.28) (.,  0.63) ( ",  0.16) (Thank,  0.38) ( you,  0.40) ( for,  0.78) ( singing,  0.27) ( that,  0.73) ( lovely,  0.13) ( song,  0.46) (!",  0.65) ( Life,  0.91) ( is,  0.67) ( so,  0.97)
(<|endoftext|>,  0.00) (Sad,  0.00) (ness,  0.02) ( is,  0.43) ( the,  0.08) ( st,  0.15) (ate,  0.70) ( of,  0.49) ( mind,  0.58) ( that,  0.89) ( I,  0.23) ( am,  0.66) ( most,  0.82) ( comfortable,  0.96) ( with,  0.99) (.,  0.96) ( Anna,  0.88) ( and,  0.98) ( Ben,  0.91) ( turn,  0.86) ( around,  0.77) (.,  0.46) ( They,  0.94) ( can,  0.67) ('t,  0.73) ( believe,  0.85) ( their,  0.89) ( eyes,  0.79) (.,  0.65) ( Their,  0.92) ( snowman,  0.47) ( is,  0.94) ( smiling,  0.92) ( at,  0.7