In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes so edits to local project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# must have gradientlab installed locally - see README

In [2]:
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from gradientlab.data_utils.experiment_path import get_ckpt_path_by_exp_name

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# The experiment name is the name of the directory this notebook runs from;
# it is handed to get_ckpt_path_by_exp_name to obtain the checkpoint path.
notebook_dir = Path(".").resolve()
exp_name = notebook_dir.name
ckpt_path = get_ckpt_path_by_exp_name(exp_name)
exp_name

'exp20251108_0_lm_kda_20m_nucleotides'

In [4]:
# Load the checkpoint for inference on the GPU.
# trust_remote_code=True allows transformers to run model code shipped with
# the checkpoint, so only use it with checkpoints you trust.
device = "cuda"
model = AutoModelForCausalLM.from_pretrained(
    ckpt_path,
    trust_remote_code=True,
)
model = model.to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)

In [5]:
# Tokenize a single example prompt; the special tokens are written out
# explicitly in the text, hence add_special_tokens=False.
prompt = "<|im_start|>ATATTTTTCGGTGTTTTTTTAAAATCCAGAAAAGGT<|im_end|>"
inputs = tokenizer(
    [prompt],
    return_tensors="pt",
    add_special_tokens=False,
    return_attention_mask=True,
)
# Move every returned tensor onto the same device as the model.
inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

inputs

{'input_ids': tensor([[259,  68,  87,  68,  87,  87,  87,  87,  87,  70,  74,  74,  87,  74,
           87,  87,  87,  87,  87,  87,  87,  68,  68,  68,  68,  87,  70,  70,
           68,  74,  68,  68,  68,  68,  74,  74,  87, 260]], device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}

In [None]:
# Greedy decoding (sampling disabled). max_length caps the total sequence
# length, i.e. prompt tokens plus generated tokens.
generation_kwargs = {"do_sample": False, "max_length": 200}
ids = model.generate(**inputs, **generation_kwargs)
tokenizer.decode(ids[0])

'<|im_start|> ATATTTTTCGGTGTTTTTTTAAAATCCAGAAAAGGTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT'

In [6]:
import torch


# Run one forward pass on the example prompt with autograd disabled, then
# display the raw model output (logits, cache and hidden states).
with torch.inference_mode():
    out = model(**inputs)
out

CausalLMOutputWithPast(loss=None, logits=tensor([[[-22.0882, -22.0852, -22.0902,  ..., -22.0874, -22.1016, -22.0841],
         [-21.0337, -21.0265, -21.0262,  ..., -21.0267, -21.0379, -21.0237],
         [-22.3566, -22.3536, -22.3558,  ..., -22.3519, -22.3639, -22.3436],
         ...,
         [-22.5724, -22.5634, -22.5766,  ..., -22.5760, -22.5835, -22.5698],
         [-22.0537, -22.0524, -22.0619,  ..., -22.0601, -22.0706, -22.0514],
         [-21.9520, -21.9441, -21.9504,  ..., -21.9503, -21.9613, -21.9423]]],
       device='cuda:0'), past_key_values=DynamicCache(layers=[]), hidden_states=tensor([[[ 0.0047,  0.0070, -0.0733,  ...,  0.0462, -0.0428,  0.0397],
         [-0.0049, -0.0281, -0.0246,  ..., -0.0075, -0.0046,  0.0082],
         [-0.0221, -0.0118, -0.0187,  ...,  0.0194, -0.0296,  0.0095],
         ...,
         [ 0.0273, -0.0283, -0.0367,  ..., -0.0398,  0.0077, -0.0160],
         [ 0.0530, -0.0104, -0.0614,  ..., -0.0182,  0.0239, -0.0276],
         [ 0.0178, -0.0037, -0.0

In [25]:
import os

from datasets import load_from_disk

# Location of the on-disk dataset. Set the NUCLEOTIDES_DS_PATH environment
# variable to point elsewhere on other machines; the default is the path the
# notebook was originally run with.
ds_path = os.environ.get("NUCLEOTIDES_DS_PATH", "/media/mascit/datasets/nucleotides_std")
ds_orig = load_from_disk(ds_path)

# Flatten both splits into a single list of row dicts: train rows first,
# followed by test rows.
ds = ds_orig["train"].to_list() + ds_orig["test"].to_list()
len(ds)

12032

In [12]:
# Peek at the first record to see which fields each row carries.
ds[:1]

[{'sequence_orig': 'ACGGCAGCTCGCCATCATCG',
  'sequence': 'ACGGCAGCTCGCCATCATCGGGG',
  'label': 0.3,
  'task': 'BE39:MELJUSO:zscore'}]

In [26]:
from tqdm import tqdm


batch_size = 16


def extract_last_token_features(records, batch_size):
    """Return one hidden-state vector per record, taken at the last position.

    Each record's ``"sequence"`` is wrapped as ``<|im_start|>…<|im_end|>`` and
    tokenized in batches. No padding is requested, so every sequence in a
    batch is expected to tokenize to the same length (TODO confirm before
    using this on data with variable-length sequences).

    The loop lives in a function so that its working variables stay local:
    the notebook-level ``inputs`` and ``out`` created by earlier cells are no
    longer overwritten with the last dataset batch.

    Uses the notebook-level ``tokenizer``, ``model`` and ``device``.

    Args:
        records: list of dicts, each with a ``"sequence"`` string.
        batch_size: number of records tokenized and run through the model at once.

    Returns:
        A list with one entry per record; each entry is the list of floats
        produced by ``hidden_states[:, -1]`` for that record. This assumes
        ``hidden_states`` is a (batch, seq, dim) tensor, as in the single
        forward pass shown earlier in the notebook.
    """
    collected = []
    for start in tqdm(range(0, len(records), batch_size)):
        chunk = records[start:start + batch_size]
        seqs = [el["sequence"] for el in chunk]
        batch_inputs = tokenizer(
            [f"<|im_start|>{s}<|im_end|>" for s in seqs],
            return_tensors="pt",
            add_special_tokens=False,
            return_attention_mask=True,
        )
        batch_inputs = {k: v.to(device) for k, v in batch_inputs.items()}
        with torch.inference_mode():
            batch_out = model(**batch_inputs)

        # Without padding, the last position is the trailing <|im_end|> token.
        collected.extend(batch_out.hidden_states[:, -1].tolist())
    return collected


features = extract_last_token_features(ds, batch_size)

# Every record must have received exactly one feature vector; the merge that
# follows pairs records and vectors positionally.
assert len(features) == len(ds), "feature count does not match record count"

100%|██████████| 752/752 [00:27<00:00, 27.43it/s]


In [27]:
# Attach the feature vector to each record as flat columns gpt_0, gpt_1, ...
# Records and vectors are paired by position; the original fields come first.
new_ds = [
    {**record, **{f"gpt_{idx}": value for idx, value in enumerate(vector)}}
    for record, vector in zip(ds, features)
]

In [28]:
# Check that the first record now carries the gpt_* feature columns.
new_ds[:1]

[{'sequence_orig': 'ACGGCAGCTCGCCATCATCG',
  'sequence': 'ACGGCAGCTCGCCATCATCGGGG',
  'label': 0.3,
  'task': 'BE39:MELJUSO:zscore',
  'gpt_0': 0.12591403722763062,
  'gpt_1': 0.05160713195800781,
  'gpt_2': -0.04755336791276932,
  'gpt_3': -0.022943109273910522,
  'gpt_4': 0.016562532633543015,
  'gpt_5': -0.041670870035886765,
  'gpt_6': 0.02335725724697113,
  'gpt_7': 0.044481560587882996,
  'gpt_8': 0.11102955043315887,
  'gpt_9': 0.04455827921628952,
  'gpt_10': -0.06787913292646408,
  'gpt_11': 0.07132302224636078,
  'gpt_12': -0.03504953533411026,
  'gpt_13': -0.03392065316438675,
  'gpt_14': -0.0019646529108285904,
  'gpt_15': -0.020946519449353218,
  'gpt_16': 0.05419322848320007,
  'gpt_17': -0.049541175365448,
  'gpt_18': -0.01410716027021408,
  'gpt_19': -0.02453221008181572,
  'gpt_20': 0.062198225408792496,
  'gpt_21': -0.07528197020292282,
  'gpt_22': 0.059594471007585526,
  'gpt_23': 0.05839304253458977,
  'gpt_24': 0.0792023241519928,
  'gpt_25': -0.022146280854940414,

In [29]:
import pandas as pd

# Build the feature table: one row per record, one column per dict key.
df_new = pd.DataFrame(data=new_ds)

In [30]:
# Preview the first rows of the feature table.
df_new.head()

Unnamed: 0,sequence_orig,sequence,label,task,gpt_0,gpt_1,gpt_2,gpt_3,gpt_4,gpt_5,gpt_6,gpt_7,gpt_8,gpt_9,gpt_10,gpt_11,gpt_12,gpt_13,gpt_14,gpt_15,gpt_16,gpt_17,gpt_18,gpt_19,gpt_20,gpt_21,gpt_22,gpt_23,gpt_24,gpt_25,gpt_26,gpt_27,gpt_28,gpt_29,gpt_30,gpt_31,gpt_32,gpt_33,gpt_34,gpt_35,...,gpt_216,gpt_217,gpt_218,gpt_219,gpt_220,gpt_221,gpt_222,gpt_223,gpt_224,gpt_225,gpt_226,gpt_227,gpt_228,gpt_229,gpt_230,gpt_231,gpt_232,gpt_233,gpt_234,gpt_235,gpt_236,gpt_237,gpt_238,gpt_239,gpt_240,gpt_241,gpt_242,gpt_243,gpt_244,gpt_245,gpt_246,gpt_247,gpt_248,gpt_249,gpt_250,gpt_251,gpt_252,gpt_253,gpt_254,gpt_255
0,ACGGCAGCTCGCCATCATCG,ACGGCAGCTCGCCATCATCGGGG,0.3,BE39:MELJUSO:zscore,0.125914,0.051607,-0.047553,-0.022943,0.016563,-0.041671,0.023357,0.044482,0.11103,0.044558,-0.067879,0.071323,-0.03505,-0.033921,-0.001965,-0.020947,0.054193,-0.049541,-0.014107,-0.024532,0.062198,-0.075282,0.059594,0.058393,0.079202,-0.022146,0.040812,-0.100016,-0.09709,0.029079,0.008716,0.029332,0.06103,0.089095,0.020949,-0.132523,...,0.003617,0.014011,-0.088886,0.029078,-0.068608,-0.001504,0.022356,-0.028506,0.022791,0.07074,0.047942,0.083514,0.01225,-0.025119,0.078993,-0.003862,-0.015514,0.053628,0.087778,-0.00618,-0.666833,0.012253,-0.111187,0.002487,0.098091,0.064733,-0.161249,-4.56522,0.249375,0.057947,0.054672,0.106872,0.021743,1.013824,0.196617,0.051407,-0.101851,0.00199,0.058244,0.038512
1,TTCTCAGATATGGTCTTAAA,TTCTCAGATATGGTCTTAAAAGG,-3.0,BE39:MELJUSO:zscore,0.046595,-0.02445,-0.058685,0.006605,-0.011588,0.008578,0.044201,0.000304,0.022437,-0.033208,0.030673,-0.002213,0.055722,-0.005216,0.015321,0.012713,0.017152,0.000728,0.044287,0.036159,-0.03707,0.042819,0.022769,-0.051338,-0.019025,0.019725,-0.008076,0.0519,-0.003903,0.028443,0.013148,-0.032114,0.043635,0.04411,-0.033816,0.001119,...,0.005065,0.031149,0.001169,0.049138,0.001954,0.051153,0.009007,-0.002005,0.022871,0.068247,0.017211,0.082055,0.010086,0.009417,0.018061,0.007221,0.004519,0.044025,0.025114,0.007125,0.51708,0.015849,-0.048028,0.015808,0.011982,0.05147,0.008152,-3.878134,-0.027961,0.00997,0.018673,0.030737,-0.002636,-1.303357,-0.023476,0.03248,0.009351,0.00935,0.019897,0.029615
2,GCTGCAGTTGACACACTGGG,GCTGCAGTTGACACACTGGGTGG,0.0,BE39:MELJUSO:zscore,0.019764,-0.011866,-0.008757,0.029762,0.019669,0.05178,0.079187,0.030046,-0.025103,0.003659,-0.062993,0.053278,0.01982,0.023191,0.006679,-0.016876,0.001065,0.052163,-0.021933,0.007782,0.004648,-0.003641,0.03566,-0.008423,-0.058473,0.03008,0.027637,0.040358,0.014709,0.034148,-0.032597,-0.025167,0.00042,0.025336,-0.030684,0.007105,...,-0.007354,0.034133,0.000939,0.039005,-0.007096,0.047441,0.000857,0.081519,-0.001905,0.094714,0.000467,0.062526,0.029322,0.049515,0.013537,0.056761,-0.007342,0.050315,-0.034251,0.037931,0.765546,0.050311,0.009072,0.066612,-0.002333,0.040317,0.050887,-4.250634,-0.145295,0.06488,0.003328,0.058887,0.029008,0.67688,-0.113978,0.053165,0.019509,0.052915,0.011747,0.028595
3,GTGGTGTTCCGGCTTCAGGT,GTGGTGTTCCGGCTTCAGGTGGG,0.6,BE39:MELJUSO:zscore,0.059406,-0.026775,-0.026016,0.047232,0.03603,0.012072,0.013162,0.000736,0.006807,-0.02135,-0.028008,0.045886,0.032879,-0.048041,0.021854,0.004847,-0.058433,-0.061785,-0.013992,-0.029875,-0.026898,-0.01504,0.006215,-0.066658,0.066843,-0.030151,0.079054,0.038404,-0.016863,0.043335,0.053637,-0.00013,-0.001084,0.081036,-0.013629,-0.013296,...,0.009974,0.046436,-0.009647,0.061094,0.006363,0.091702,-0.016922,0.050879,0.019253,0.22662,-0.014389,0.08064,0.028546,0.045448,0.016225,0.069756,-0.004355,0.061293,-0.020023,0.036907,0.351479,0.068779,-0.010102,0.049406,0.030639,0.103413,0.002584,-4.408957,0.001425,0.066362,0.039539,0.085634,0.015784,0.225784,-0.060546,0.058458,-0.023439,0.036487,0.044296,0.071402
4,GGTGTCCCTTTGAAGGTGCT,GGTGTCCCTTTGAAGGTGCTGGG,0.6,BE39:MELJUSO:zscore,0.04262,-0.107365,0.065568,0.052436,0.030147,0.089224,0.047095,-0.027907,-0.00351,-0.059568,-0.055953,0.116043,0.007154,0.007261,0.015405,0.049635,0.008412,-0.027478,0.037917,-0.046747,-0.136207,0.05695,-0.007439,-0.046437,-0.015417,-0.006238,0.013901,0.013517,0.0142,0.027176,0.034579,-0.027821,0.008996,0.051305,-0.030875,-0.064363,...,-0.009396,0.041522,-0.00122,0.060058,0.006847,0.091724,0.007382,0.031856,0.004569,0.170114,0.013285,0.04269,-0.023456,0.01697,0.00914,0.035752,-0.030229,0.064989,0.000127,0.017434,1.218952,0.043436,-0.034053,0.0161,-0.001985,0.067431,0.051375,-4.528712,-0.055255,0.027935,0.022609,0.034118,-0.052611,-0.157527,-0.047898,0.037487,0.040009,0.021785,-0.018878,0.062287


In [31]:
# Persist the feature table. The output directory is created first so the
# write also succeeds on a fresh checkout where data.tmp/ does not exist yet.
# NOTE(review): the file name says "gpt12M" while the experiment directory is
# named "..._20m_..." — confirm the intended name before sharing the file.
features_path = Path("data.tmp/ds_zscore_gpt12M_pretrain_only_features.parquet")
features_path.parent.mkdir(parents=True, exist_ok=True)
df_new.to_parquet(features_path)

In [None]:
# https://huggingface.co/collections/sapienzanlp/ita-bench-italian-benchmarks-for-llms-66337ca59e6df7d7d4933896