In [32]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

import os

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

compute_dtype = torch.bfloat16
device   = 'cuda'
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"

# model     = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=compute_dtype, attn_implementation="sdpa", device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

  from .autonotebook import tqdm as notebook_tqdm


In [33]:
dataset = load_dataset("dmitriihook/numina-deepseek-r1-qwen-7b")["train"]

In [34]:
tokenizer.chat_template = tokenizer.chat_template.replace("{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}", "")

In [35]:
data = []

for i, row in enumerate(dataset):
    messages = row["distilabel_metadata"]["raw_input_text_generation_0"]
    generation = row["generation"]

    messages = [
        messages[0],
        {"content": generation, "role": "assistant"}
    ]
    chat    = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False, return_tensors="pt")
    print(len(chat[0]))

    think_pos = torch.where(chat[0] == 151649)[0]
    if think_pos.numel() > 0:
        data.append({
            "result": "success",
            "pos": think_pos.item(),
            "len": len(chat[0]),
            "id": i
        })
    else:
        data.append({
            "result": "fail",
            "pos": None,
            "len": None,
            "id": i
        })

import json

with open("positions.json", "w") as f:
    json.dump(data, f)

3637
5264
3291
3824
2135
1749
1471
1598
5863
8193
8193
3952
1436
1506
1552
1418
1759
1674
1343
1454
1065
1109
2368
983
3551
2742
4929
3955
3868
4044
3479
6482
2234
3482
3746
2843
1856
1784
1803
2329
2284
3068
3450
1913
2430
2566
2176
2390
1404
527
1105
941
1159
2236
2548
1171
3201
1966
3713
2643
3145
3456
2219
2585
1729
1704
1413
1560
2827
3155
2698
2348
1282
1611
1377
1220
2758
2698
1506
1203
3258
2113
2253
3419
4040
3041
5165
7239
1222
4311
1056
1510
1779
3111
2053
3016
1357
1358
1317
1262
4595
3412
2633
3442
1724
1822
1792
1979
6847
3750
7728
8088
1560
1577
1806
2254
1661
1622
1971
1968
2033
1976
1937
2473
1635
1496
2213
1566
3380
2978
1955
1539
4409
4750
4025
3331
2418
2797
2434
2061
8193
8193
8193
8193
3628
2296
1830
3002
1712
2096
1193
1590
1337
2798
1718
2262
3897
3048
2987
3437
1901
1916
1268
1413
2209
2748
2441
2803
2722
8193
2597
6427
4787
4290
4088
6942
4006
6121
5835
4243
4157
4533
3717
3981
4999
7195
7204
4903
1361
1446
1633
1179
8193
8193
8193
8193
2651
2658
2324
2348
819

In [36]:
import torch


activations_file = "test_activations_7b.pt"
activations = torch.load(activations_file)

  activations = torch.load(activations_file)


In [40]:
activations[0].shape

torch.Size([3612, 3584])

In [50]:
import torch
import json
import numpy as np
from torch.utils.data import DataLoader, Dataset

def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # If you are using CuDNN, you can also set the deterministic flag
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Set the seed
seed = 42
set_seed(seed)

activations_file = "test_activations_7b.pt"
pos_info_file = "positions.json"

# activations = torch.load(activations_file)

with open(pos_info_file) as f:
    pos_info = json.load(f)

dataset = []
for p in pos_info:
    if p["result"] == "fail":
        continue

    idx, pos, total = p["id"], p["pos"], p["len"]
    dataset.append({
        "id": idx,
        "pos": pos,
        "total": total,
        "activations": activations[idx]
    })

test_size = 0.2
test_size = int(len(dataset) * test_size)

train_dataset = dataset[:-test_size]
test_dataset = dataset[-test_size:]

window_sizes = 2 ** np.arange(0, 8)


def generate_sample(window_size, pos, total, positive=True):
    if positive:
        left = max(0, pos - window_size)
        right = pos
    else:
        left = 0
        right = pos - window_size - 1

    if right - left < 1:
        raise ValueError("Window too big")

    sample_idx = np.random.randint(left, right)

    return sample_idx, positive


def generate_samples(dataset, window_size, n_samples):
    positive_samples = []
    negative_samples = []

    for i in range(n_samples):
        idx = np.random.randint(0, len(dataset))
        pos = dataset[idx]["pos"]
        total = dataset[idx]["total"]

        try:
            positive_samples.append((idx, generate_sample(window_size, pos, total, positive=True)))
            negative_samples.append((idx, generate_sample(window_size, pos, total, positive=False)))
        except ValueError as e:
            print(e)
            continue

    return positive_samples, negative_samples


class ProbeDataset(Dataset):
    def __init__(self, dataset, window_size, n_samples=2000):
        self.dataset = dataset
        self.window_size = window_size
        self.n_samples = n_samples

        self.positive_samples, self.negative_samples = generate_samples(dataset, window_size, n_samples)
        self.samples = self.positive_samples + self.negative_samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        act_idx, (token_pos, sample_type) = self.samples[idx]

        sample = self.dataset[act_idx]
        activations = sample["activations"]


        return {
            "inputs": activations[token_pos],
            "label": sample_type,
            "sample_idx": act_idx,
            "token_pos": token_pos
        }

train_datasets = []
test_datasets = []

dt = ProbeDataset(train_dataset, 10)

In [47]:
activations[0].shape

torch.Size([3612, 3584])

In [53]:
dt[101]

{'inputs': tensor([-0.1611, -0.7070,  1.2891,  ...,  0.1807,  2.0781, -2.2188],
        dtype=torch.bfloat16),
 'label': True,
 'sample_idx': 95,
 'token_pos': 1170}