In [4]:
%load_ext autoreload
%autoreload 2
%pdb
import logging
logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)
import sys
sys.path.insert(0, "/Users/fultonwang/Documents/infembed")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Automatic pdb calling has been turned ON


define the cb model

In [6]:
import functools
from models._core.cb_decoder_llm import CBDecoderLightningModule, constructor
from models._utils.common import default_checkpoints_load_func, load_model


# Architecture hyperparameters handed to `constructor` to build the decoder.
decoder_kwargs = dict(
    model_dim=768,
    key_dim=48,
    value_dim=48,
    num_heads=16,
    num_layers=4,
    dropout=0.0,
    hidden_dim=3072,
    num_tokens=10000,
    max_len=2048,
    num_concepts=4,
    concept_embedding_dim=32,
    concept_embedder_hidden_dims=None,
    concept_generator_hidden_dims=None,
    generator_hidden_dims=None,
)
lightning_module = CBDecoderLightningModule(decoder=constructor(**decoder_kwargs))

# Checkpoint loader bound to the 'state_dict' key of the checkpoint file.
load_from_state_dict = functools.partial(default_checkpoints_load_func, key='state_dict')

# Trained weights to restore; the commented line is the same checkpoint under
# a different home directory.
checkpoint_path = "/Users/fultonwang/Documents/infembed/examples/tinystories_cb/hydra_outputs/lightning_train/cb_simplified_read_julius_only_accum_2/lightning_logs/fqtzvlgh/checkpoints/epoch=5-step=1386.ckpt"
# checkpoint_path = "/home/ubuntu/Documents/infembed/examples/tinystories_cb/hydra_outputs/lightning_train/cb_simplified_read_julius_only_accum_2/lightning_logs/fqtzvlgh/checkpoints/epoch=5-step=1386.ckpt"

model = load_model(
    model=lightning_module,
    eval=True,
    checkpoints_load_func=load_from_state_dict,
    checkpoint=checkpoint_path,
)
# Display the module tree as the cell output.
model

<All keys matched successfully>


CBDecoderLightningModule(
  (decoder): CBDecoder(
    (decoder_layers): ModuleList(
      (0-3): 4 x DecoderLayer(
        (attention): MultiAttention(
          (attentions): ModuleList(
            (0-15): 16 x Attention(
              (key): Linear(in_features=768, out_features=48, bias=True)
              (query): Linear(in_features=768, out_features=48, bias=True)
              (value): Linear(in_features=768, out_features=48, bias=True)
            )
          )
          (projection): Linear(in_features=768, out_features=768, bias=True)
        )
        (feedforward): FeedForward(
          (linear_1): Linear(in_features=768, out_features=3072, bias=True)
          (linear_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (attention_sublayer): Sublayer(
          (dropout): Dropout(p=0.0, inplace=False)
          (layer_norm): LayerNorm()
        )
        (feedforward_sublayer): Sublayer(
          (d

define the dataloader to get explanations and guide

In [8]:
from torch.utils.data import Dataset, DataLoader
from data._core.tinystories import tinystories_tokenizer
from data._utils.llm import DecoderLLMCollateFn
import pandas as pd


# CSV whose `generation` column is read by `GenerationDataset` below.
# The commented-out line is the same file under a different home directory;
# swap the two when running on that machine.
dataset_path = '/Users/fultonwang/Documents/infembed/files/tinystories/generations/wandb_export_2024-02-22T12_01_34.815-05_00.csv'
# dataset_path = '/home/ubuntu/Documents/infembed/files/tinystories/generations/wandb_export_2024-02-22T12_01_34.815-05_00.csv'

# define adhoc dataset
class GenerationDataset(Dataset):
    """Map-style dataset over the `generation` column of a CSV file.

    The CSV is read once at construction time and kept as `self.df`; item `i`
    is the value of the `generation` column at row position `i`.
    """

    def __init__(self, path):
        # Load the whole table up front; rows are addressed by position later.
        self.df = pd.read_csv(path)

    def __len__(self):
        # One item per CSV row.
        return self.df.shape[0]

    def __getitem__(self, i):
        # Positional (not label-based) lookup into the `generation` column.
        generations = self.df["generation"]
        return generations.iloc[i]

dataset = GenerationDataset(dataset_path)

# Collate function built from a fresh tinystories tokenizer with max_len=512.
collate_fn = DecoderLLMCollateFn(tokenizer=tinystories_tokenizer(), max_len=512)

# One example per batch, in dataset order (no shuffling requested).
dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)

# Peek at the first collated batch to show which keys a batch carries.
next(iter(dataloader))

{'labels': tensor([[  41,  740, 2402,  259,  376,  477,   14,  338,  740, 2402,  259,  376,
           500,   14,  338, 1218, 1735, 5190,   14,  338,  964,  460, 1233,   76,
          1322, 5743]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1]]),
 'input_ids': tensor([[   0,   41,  740, 2402,  259,  376,  477,   14,  338,  740, 2402,  259,
           376,  500,   14,  338, 1218, 1735, 5190,   14,  338,  964,  460, 1233,
            76, 1322]]),
 'mask': tensor([[ True, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False],
         [ True,  True,  True, False, False, F

define how to get tokens for a batch

In [11]:
from data._core.tinystories import tinystories_tokenizer


# Module-level tokenizer; `get_batch_tokens` below reads it as a global to
# decode individual token ids back to strings.
tokenizer = tinystories_tokenizer()

def get_batch_tokens(batch):
    """Decode each example's unmasked input ids into token strings.

    Args:
        batch: mapping with "input_ids" and "attention_mask" entries, each
            iterable example by example and then position by position.

    Returns:
        A list with one entry per example; each entry is the list of
        `tokenizer.decode(...)` results for the ids whose attention-mask
        value equals 1, in their original order.
    """

    def decode_visible(ids, attention_mask):
        # Keep only positions the attention mask marks with 1.
        return [
            tokenizer.decode(token_id)
            for token_id, keep in zip(ids, attention_mask)
            if keep == 1
        ]

    return [
        decode_visible(ids, attention_mask)
        for ids, attention_mask in zip(batch["input_ids"], batch["attention_mask"])
    ]

define how to get concept predictions for a batch

In [12]:
import torch


def get_batch_predictions(batch):
    """Run the model on `batch` and return per-concept logits for unmasked positions.

    Args:
        batch: mapping passed unchanged to `model.forward`; its
            "attention_mask" entry (one row per example) selects which
            positions are kept (value 1).

    Returns:
        A list indexed by concept. Each element is a list indexed by example,
        holding a 1-D float32 CPU tensor with that concept's logits at the
        positions whose attention-mask value is 1.
    """
    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        raw_concept_logits = model.forward(batch)["concept_logits"]

    # The result is indexed as [example, position, concept] below. Normalise
    # to float32 on CPU, which is what the per-scalar `torch.Tensor([...])`
    # construction used previously always produced.
    raw_concept_logits = raw_concept_logits.to(device="cpu", dtype=torch.float32)
    keep = (torch.as_tensor(batch["attention_mask"]) == 1).to("cpu")

    num_concepts = raw_concept_logits.shape[2]
    return [
        [
            # Boolean-mask indexing replaces the element-by-element Python loop.
            example_logits[:, c][example_keep]
            for example_logits, example_keep in zip(raw_concept_logits, keep)
        ]
        for c in range(num_concepts)
    ]

define how to plot tokens and concept predictions for a batch

In [13]:
import torch


def plot(example_tokens, concept_logits):
    """Print one line per example, pairing every token with its concept probability.

    Args:
        example_tokens: iterable of per-example token sequences.
        concept_logits: iterable of per-example logit sequences, aligned with
            `example_tokens`; each logit is passed through `torch.sigmoid`
            before being formatted with two decimals.
    """
    for tokens, logits in zip(example_tokens, concept_logits):
        rendered = []
        for token, logit in zip(tokens, logits):
            rendered.append(f"({token}, {torch.sigmoid(logit): .2f})")
        # All (token, probability) pairs of one example go on a single line.
        print(" ".join(rendered))

plot tokens and concept predictions for the first batches and the selected concepts

In [20]:
import itertools

# How many batches from the dataloader to visualise.
num_batches = 1
# Concept indices to print; comment entries in or out to change the selection.
concepts = [
    # 0,
    1,
    2,
    # 3,
]

# `islice` takes exactly `num_batches` batches. The previous
# `zip(dataloader, range(num_batches))` pulled (and discarded) one extra batch
# before noticing the range was exhausted, and rebound IPython's `_`.
for batch in itertools.islice(dataloader, num_batches):
    example_tokens = get_batch_tokens(batch)
    concept_logits = get_batch_predictions(batch)
    for c in concepts:
        print(f"\n ### task {c} ###")
        plot(example_tokens, concept_logits[c])


 ### task 1 ###
(<|endoftext|>,  0.10) (I,  0.01) ( am,  0.00) ( such,  0.01) ( a,  0.07) ( happy,  0.35) ( girl,  0.95) (.,  0.99) ( I,  0.19) ( am,  0.88) ( such,  0.11) ( a,  0.96) ( happy,  0.89) ( man,  0.56) (.,  0.89) ( I,  0.28) ( enjoy,  0.21) ( eating,  0.31) ( oranges,  0.36) (.,  0.85) ( I,  0.15) ( strong,  0.53) (ly,  0.55) ( dis,  0.06) (l,  0.04) (ike,  0.16)

 ### task 2 ###
(<|endoftext|>,  0.82) (I,  0.00) ( am,  0.00) ( such,  0.06) ( a,  0.08) ( happy,  0.16) ( girl,  0.08) (.,  0.04) ( I,  0.32) ( am,  0.66) ( such,  0.28) ( a,  0.80) ( happy,  0.79) ( man,  0.99) (.,  0.95) ( I,  0.86) ( enjoy,  0.74) ( eating,  0.29) ( oranges,  0.56) (.,  0.90) ( I,  0.63) ( strong,  0.96) (ly,  0.72) ( dis,  0.23) (l,  0.10) (ike,  0.31)


all analysis below assumes a particular concept we care about

In [15]:
# Index of the concept analysed in the cells below; it selects one entry from
# the per-concept list returned by `get_batch_predictions`.
the_concept = 1

define how to decide when to start generating on the basis of tokens and predictions

In [16]:
def _decider(threshold, _example_tokens, _concept_logits):
    """Return the first position whose logit for one concept exceeds `threshold`.

    Args:
        threshold: value each entry of `_concept_logits` is compared against
            with a strict `>`.
        _example_tokens: tokens of the example; not used by this decider.
        _concept_logits: 1-D tensor of logits for a single concept.

    Returns:
        A 0-dim integer tensor holding the index of the first entry above the
        threshold. When no entry is above it, the result is 0, which cannot be
        told apart from a hit at position 0.

    NOTE(review): the comparison is made on raw logits, not on the sigmoid
    probabilities that `plot` prints -- confirm the threshold is meant in
    logit space.
    """
    exceeds_threshold = _concept_logits > threshold
    # argmax over a 0/1 float vector yields the first 1 (or 0 if there is none).
    first_hit = exceeds_threshold.float().argmax()
    return first_hit


# Decider with the threshold fixed at 0.90; call as decider(tokens, logits).
decider = functools.partial(_decider, 0.90)

define how to regenerate

In [17]:
from models._utils.cb_llm import ConstantStrategy, GreedyCBDecoder


# Decoder object used below to regenerate a continuation from a truncated prefix.
# NOTE(review): ConstantStrategy receives four values, the same count as
# `num_concepts=4` in the model config -- presumably one setting per concept;
# confirm what -1 means for a concept.
decoder = GreedyCBDecoder(max_len=512, strategy=ConstantStrategy([-1, 0.0, 1.0, -1]))

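for each example: truncate it at the position chosen by the decider, regenerate from that prefix, and print the original, the truncated prefix, and the regenerated sequence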

In [19]:
# Number of batches from `dataloader` to process in the loop at the end of this cell.
num_batches = 1


def get_batch_input_ids(batch):
    """Return each example's unmasked input ids as a 1-D int64 CPU tensor.

    Args:
        batch: mapping with "input_ids" and "attention_mask" entries, one row
            per example.

    Returns:
        A list with one tensor per example containing the ids at positions
        whose attention-mask value equals 1, in order.
    """
    selected = []
    for ids, attention_mask in zip(batch["input_ids"], batch["attention_mask"]):
        ids = torch.as_tensor(ids)
        keep = (torch.as_tensor(attention_mask) == 1).to(ids.device)
        # Index with the boolean mask instead of rebuilding the ids through
        # `torch.Tensor([...])`: that constructor is float32, which silently
        # corrupts ids above 2**24 before the cast back to long.
        selected.append(ids[keep].to(device="cpu", dtype=torch.long))
    return selected


import itertools

# `islice` takes exactly `num_batches` batches; zipping the dataloader with a
# range fetched one extra batch only to throw it away.
for batch in itertools.islice(dataloader, num_batches):
    example_tokens = get_batch_tokens(batch)
    __concept_logits = get_batch_predictions(batch)[the_concept]
    input_ids = get_batch_input_ids(batch)
    for _example_tokens, _concept_logits, _input_ids in zip(
        example_tokens, __concept_logits, input_ids
    ):
        # The decider returns a 0-dim tensor; make it a plain int for slicing.
        decider_pos = int(decider(_example_tokens, _concept_logits))
        prefix_len = decider_pos + 1
        print("### original ###")
        plot([_example_tokens], [_concept_logits])
        print("### dangerous prefix")
        plot([_example_tokens[:prefix_len]], [_concept_logits[:prefix_len]])
        # generate starting with the last position in the prefix
        _output_ids = decoder(model, eos_token=1, input_ids=_input_ids[:prefix_len], temperature=0.5)
        # get new ids
        # NOTE(review): this prepends all but the last *original* id, not the
        # truncated prefix that was handed to the decoder. Left unchanged
        # because the decoder's return convention is not visible here --
        # confirm whether `_input_ids[:prefix_len]` (or its `[:-1]`) was meant.
        _new_ids = torch.cat([_input_ids[:-1], _output_ids])
        # Integer ones, like the attention mask in the collated batches.
        _new_attention_mask = torch.ones(len(_new_ids), dtype=torch.long)
        # get corresponding tokens (batch format)
        new_batch = {
            "input_ids": _new_ids.unsqueeze(0),
            "attention_mask": _new_attention_mask.unsqueeze(0),
            # NOTE(review): collated batches also carry a lower-triangular
            # boolean `mask`; it is rebuilt here on the assumption that
            # `model.forward` reads it -- confirm against the collate function.
            "mask": torch.tril(torch.ones(len(_new_ids), len(_new_ids), dtype=torch.bool)),
        }
        new_example_tokens = get_batch_tokens(new_batch)
        # feed through model to get concept predictions
        # Fix: score the regenerated sequence (`new_batch`). Passing the
        # original `batch` here printed the old predictions next to the new
        # tokens.
        new_concept_logits = get_batch_predictions(new_batch)[the_concept]
        print("### new ###")
        plot(new_example_tokens, new_concept_logits)

### original ###
(<|endoftext|>,  0.10) (I,  0.01) ( am,  0.00) ( such,  0.01) ( a,  0.07) ( happy,  0.35) ( girl,  0.95) (.,  0.99) ( I,  0.19) ( am,  0.88) ( such,  0.11) ( a,  0.96) ( happy,  0.89) ( man,  0.56) (.,  0.89) ( I,  0.28) ( enjoy,  0.21) ( eating,  0.31) ( oranges,  0.36) (.,  0.85) ( I,  0.15) ( strong,  0.53) (ly,  0.55) ( dis,  0.06) (l,  0.04) (ike,  0.16)
### dangerous prefix
(<|endoftext|>,  0.10) (I,  0.01) ( am,  0.00) ( such,  0.01) ( a,  0.07) ( happy,  0.35) ( girl,  0.95)
> /Users/fultonwang/Documents/infembed/models/_utils/cb_llm.py(91)__call__()
     89             import pdb
     90             pdb.set_trace()
---> 91             output = output[-1]  # get logits in last layer