Check whether the model can 'think' about a concept that is unrelated to its input:
    -> instruct it to think about a specific fruit -> measure different fruit concept vectors in its activations -> ask it which fruit it thought about.

In [1]:
import os
import sys
import json
from collections import defaultdict

# Pin the visible GPU *before* importing torch / transformers so CUDA only ever sees device 1.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Make modules next to this notebook importable. This must run before the local
# `mi_toolbox` import below, otherwise the path entry cannot help that import.
module_path = os.path.abspath('.')
if module_path not in sys.path:
    sys.path.append(module_path)

from mi_toolbox.utils.collate import TokenizeCollator

model_id = "meta-llama/Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse EOS as the pad token and pad on the LEFT, so that index -1 of every row in a
# padded batch is the last real prompt token (the extraction loop reads position -1).
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = 'left'

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.bfloat16,
    device_map="auto",
    # NOTE(review): allows executing code shipped in the model repo; presumably unnecessary
    # for this Llama checkpoint — confirm and drop if so.
    trust_remote_code=True,
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

## Get Concept Vectors

In [None]:
with open("./data/word_concept_extraction.json", 'r') as f:
    word_concept_data = json.load(f)


def _to_chats(template, words):
    """Wrap each word in a single-turn user chat whose content is `template.format(word=word)`."""
    return [
        [{"role": "user", "content": template.format(word=word)}]
        for word in words
    ]


weekdays = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']
baseline_chats = _to_chats(word_concept_data['prompt'], word_concept_data['baseline_words'])
target_chats = _to_chats(word_concept_data['prompt'], weekdays)

num_baseline_samples = len(baseline_chats)
num_target_samples = len(target_chats)
# Baseline prompts come first, then targets; downstream code can split the cache by these counts.
prompts = tokenizer.apply_chat_template(baseline_chats + target_chats, tokenize=False, add_generation_prompt=True)

concept_cachin_bs = 16
collate_fn = TokenizeCollator(tokenizer=tokenizer)
dl = DataLoader([{'prompts': prompt} for prompt in prompts], batch_size=concept_cachin_bs, collate_fn=collate_fn, shuffle=False)

num_layers = model.config.num_hidden_layers
concept_vector_cache = []
for batch in tqdm(dl):
    with torch.no_grad():
        out = model(
            input_ids=batch['input_ids'].to(model.device),
            attention_mask=batch['attention_mask'].to(model.device),
            output_hidden_states=True,
        )
    # With left padding (tokenizer.padding_side = 'left'), position -1 is each prompt's last real token.
    # Slice that position from every hidden-state tensor BEFORE stacking: stacking full sequences and
    # slicing afterwards leaves per-sample views that keep the whole (bs, seq, layers, hidden) tensor alive.
    last_token_states = torch.stack(
        [layer_states[:, -1] for layer_states in out['hidden_states']], dim=1
    )  # (bs, n_hidden_states, hddn_dim), one slot per entry of out['hidden_states']
    concept_vector_cache.extend(last_token_states)

# One cached (n_hidden_states, hddn_dim) tensor per prompt, in prompt order.
assert len(concept_vector_cache) == num_baseline_samples + num_target_samples

In [2]:
# Hardcoded word list, lowercased. Bound to a name so later cells can reuse it
# instead of relying on Out[...] / `_`, which break on Restart & Run All.
# NOTE(review): the list contains the near-duplicate pair 'Thunder' / 'Thunders'.
raw_words = [
    'Islands', 'Observatories', 'Ice', 'Darkness', 'Computers', 'Children', 'Forests', 'Linen', 'Trains', 'Software',
    'Happiness', 'Salt', 'Mechanisms', 'Thunder', 'Lagoons', 'Carousels', 'Advice', 'Pepper', 'Ghosts', 'Fireworks',
    'Crystals', 'Blueprints', 'Wisdom', 'Embers', 'Cotton', 'Strawberries', 'Elephants', 'Zebras', 'Gasoline', 'Horizons',
    'Periscopes', 'Glitters', 'Dreams', 'Thunders', 'Love', 'Candles', 'Coronets', 'Houses', 'Vegetation', 'Beef',
    'Tea', 'Whirlwinds', 'Bridges', 'Mud', 'Cups', 'Telescopes', 'Sunshine', 'Zeppelins', 'Seafood', 'Monorails',
    'Jewels', 'Footwear', 'Copper', 'Education', 'Beer', 'Journeys', 'Kittens', 'Granite', 'Oases', 'Timber',
    'Villages', 'Spectacles', 'Compasses', 'Glue', 'Cathedrals', 'Rockets', 'Handprints', 'Baskets', 'Shadows', 'Meadows',
    'Ladders', 'Steam', 'Buildings', 'Symphonies', 'Geysers', 'Porcelain', 'Livestock', 'Mail', 'Freedom', 'Cutlery',
    'Inkwells', 'Foam', 'Shipwrecks', 'Equipment', 'Horses', 'Mazes', 'Chaos', 'Umbrellas', 'Catapults', 'Scarves',
    'Pillows', 'Windmills', 'Windows', 'Music', 'Machinery', 'Kingdoms', 'Gargoyles', 'Questions', 'Books', 'Relics',
]
lowercased_words = [word.lower() for word in raw_words]
lowercased_words

['islands',
 'observatories',
 'ice',
 'darkness',
 'computers',
 'children',
 'forests',
 'linen',
 'trains',
 'software',
 'happiness',
 'salt',
 'mechanisms',
 'thunder',
 'lagoons',
 'carousels',
 'advice',
 'pepper',
 'ghosts',
 'fireworks',
 'crystals',
 'blueprints',
 'wisdom',
 'embers',
 'cotton',
 'strawberries',
 'elephants',
 'zebras',
 'gasoline',
 'horizons',
 'periscopes',
 'glitters',
 'dreams',
 'thunders',
 'love',
 'candles',
 'coronets',
 'houses',
 'vegetation',
 'beef',
 'tea',
 'whirlwinds',
 'bridges',
 'mud',
 'cups',
 'telescopes',
 'sunshine',
 'zeppelins',
 'seafood',
 'monorails',
 'jewels',
 'footwear',
 'copper',
 'education',
 'beer',
 'journeys',
 'kittens',
 'granite',
 'oases',
 'timber',
 'villages',
 'spectacles',
 'compasses',
 'glue',
 'cathedrals',
 'rockets',
 'handprints',
 'baskets',
 'shadows',
 'meadows',
 'ladders',
 'steam',
 'buildings',
 'symphonies',
 'geysers',
 'porcelain',
 'livestock',
 'mail',
 'freedom',
 'cutlery',
 'inkwells',
 