In [2]:
import random
import numpy as np
import torch
from datasets import DatasetDict, load_from_disk, concatenate_datasets
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer


def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Applies ``seed`` to Python's ``random`` module, NumPy's global RNG,
    PyTorch's default generator and all CUDA generators, in that order.

    Args:
        seed: Integer seed passed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Seed all RNGs once, before any data loading or model setup in later cells.
set_seed(42)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the dataset previously saved to the local directory 'hf_dataset_2024-12-06'.
ds = load_from_disk('hf_dataset_2024-12-06')
# Bare expression so the notebook renders the splits, features and row counts.
ds

DatasetDict({
    train: Dataset({
        features: ['Context', 'cell_sentences_data', 'Question', 'Answer', 'Keyword', 'full_QA_pair', 'Publication_Title', 'Dataset_URL', 'Dataset_Index', 'Used_Rows', 'openai_batch_response_id'],
        num_rows: 566
    })
    validation: Dataset({
        features: ['Context', 'cell_sentences_data', 'Question', 'Answer', 'Keyword', 'full_QA_pair', 'Publication_Title', 'Dataset_URL', 'Dataset_Index', 'Used_Rows', 'openai_batch_response_id'],
        num_rows: 105
    })
    test: Dataset({
        features: ['Context', 'cell_sentences_data', 'Question', 'Answer', 'Keyword', 'full_QA_pair', 'Publication_Title', 'Dataset_URL', 'Dataset_Index', 'Used_Rows', 'openai_batch_response_id'],
        num_rows: 99
    })
})

In [4]:
# Pull out the two splits used for fine-tuning; the test split is not touched in this cell.
train_set, val_set = ds['train'], ds['validation']

In [5]:
# List the field names available on a training example (taken from the first row).
train_set[0].keys()

dict_keys(['Context', 'cell_sentences_data', 'Question', 'Answer', 'Keyword', 'full_QA_pair', 'Publication_Title', 'Dataset_URL', 'Dataset_Index', 'Used_Rows', 'openai_batch_response_id'])

In [6]:
def combine_fields_v1(example):
    """Build the single training prompt string for one example.

    The prompt is four sections -- context, cell sentences, question and
    answer -- each rendered as a ``###``-prefixed header line followed by
    the field value, with a blank line between sections.

    Args:
        example: Mapping with the keys ``'Context'``,
            ``'cell_sentences_data'``, ``'Question'`` and ``'Answer'``.

    Returns:
        The same ``example`` object, with the prompt stored under ``'text'``.
    """
    sections = (
        ("###context:", example['Context']),
        ("###cell_sentences_data:", example['cell_sentences_data']),
        ("###Question:", example['Question']),
        ("###Answer:", example['Answer']),
    )
    # Header and body are split by one newline; sections by a blank line.
    example["text"] = "\n\n".join(f"{header}\n{body}" for header, body in sections)
    return example

# Add the formatted 'text' column to both splits using the same formatter
# (train split is mapped first, then validation).
train_set, val_set = (
    split.map(combine_fields_v1) for split in (train_set, val_set)
)


In [16]:
# Display one training record (index 100) to eyeball its fields.
train_set[100]

{'Context': 'High levels of GAPDH are detected in the gene expression profile.',
 'cell_sentences_data': 'Cell Type: capillary endothelial cell, Tissue: apex of heart, Gene Expression: MALAT1 TMSB4X TMSB10 B2M MT-CO1 MT-CO3 RPL10 MT-CO2 ACTB RPLP1 MT-ATP6 EEF1A1 RPL41 MT-ND4 RPL13 MT-CYB PTMA RPS18 MT-ND3 IFITM3 TPT1 ACTG1 RPS8 RPL32 RPL11 VIM RPS19 RPS12 RPL28 MYL6 RPS14 RPS2 RPS27A RPS28 RPL34 RPS15 RPS24 RPS3A RPS15A RPS23 CAV1 FTH1 GAPDH RPL21 RPS27 RPL26 RPL15 RPS6 RPS3 FABP5 RPS7 RPL19 RPL8 RPLP2 RPL3 RPS4X RPL39 RPL13A RPL18A RPL29 RPL35A RPL37 RPL30 RPL37A LGALS1 FTL CALM1 HLA-B ID1 EGFL7 NEAT1 RPL7A GNG11 RPL36 RPL18 RPL12 SERF2 RPS9 SPARCL1 RPL23A RPS13 RPL14 RPL5 RPL9 RPS29 MT-ND1 RPL6 COL4A1 EIF1 RPL24 HBG2 RPL35 TCF4 RPL7 ITM2B RPS25 HLA-A RPL27A FAU ID3\n',
 'Question': 'Why is the expression of GAPDH significant in this endothelial cell from the heart tissue?',
 'Answer': 'GAPDH is a key glycolytic enzyme, and its expression indicates active glycolysis, which is importan

In [7]:
# Preview only the first few formatted prompts; displaying the entire 'text'
# column would dump every training example into the cell output.
train_set[:3]['text']

['###context:\nEpicardial cell activity is linked to heart regeneration. Researchers compared gene expression of epicardial cells at different stages to find molecular differences.\n\n###cell_sentences_data:\nCell Type: endocardial cell, Tissue: apex of heart, Disease: normal, Gene Expression: MALAT1 RPL10 EEF1A1 RPLP1 MT-CO1 RPL41 RPS18 RPL13 RPS8 VIM RPS12 RPS2 RPL28 RPL32 MT-CO2 TMSB10 RPS3A MT-CO3 TPT1 RPS23 RPS15A RPL34 RPS14 RPL11 RPS27A RPS4X RPL26 RPL39 RPS3 RPS24 RPL15 RPL3 RPS27 TMSB4X RPS7 COL3A1 RPS19 MT-ATP6 RPL12 RPL13A RPL7A RPL30 RPS6 RPL18A RPL19 RPS28 RPL8 RPL21 MT-ND4 RPL37 RPL29 RPL18 RPS15 RPL17 PTMA RPS9 RPL37A RPL6 RPL35A RPL5 FTH1 ACTB RPLP2 RPS13 RPL9 MT-CYB RPL23A RPLP0 ACTG1 RPL36 RPL10A RPL14 RPS5 RPL24 RPL7 RACK1 IFITM3 RPS26 GAPDH RPS29 MT-ND3 RPL23 MYL6 RPS25 RPSA B2M FTL RPL35 SERF2 RPS16 RPL27A FN1 NACA POSTN RPL22 FAU SAT1 RPL36A UBA52 HSP90AB1\nCell Type: fibroblast, Tissue: apex of heart, Disease: normal, Gene Expression: MALAT1 RPL10 EEF1A1 MT-CO1 V

In [8]:
model_name = "vandijklab/C2S-Pythia-410m-diverse-single-and-multi-cell-tasks"

# Tokenizer and model are both loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# The special-token experiment below is disabled, so no tokens are added here;
# the resize just sets the embedding table to the tokenizer's length.
# tokenizer.add_tokens(["<|Question|>", "<|Answer|>"])
model.resize_token_embeddings(len(tokenizer))

Embedding(50277, 1024)

In [9]:
# Show the tokenizer's current padding token and padding side.
tokenizer.pad_token, tokenizer.padding_side

('<|endoftext|>', 'right')

In [10]:
# Fall back to the EOS token only when the tokenizer defines no pad token.
# NOTE(review): if the pad token ends up identical to EOS, collators that mask
# pad positions in the labels may also mask the real EOS -- confirm.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Always pad on the right for training (a no-op when it is already 'right').
tokenizer.padding_side = 'right'

In [11]:
# Supervised fine-tuning configuration, grouped by purpose.
# NOTE(review): newer transformers releases name `evaluation_strategy`
# `eval_strategy` -- confirm against the installed version before renaming.
training_args = SFTConfig(
    # Data: train on the pre-formatted 'text' column.
    dataset_text_field='text',
    # Optimisation: batch of 1 with 8 accumulation steps, for 40 epochs.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    gradient_checkpointing=True,
    num_train_epochs=40,
    # Evaluation and checkpointing share the same 50-step cadence.
    evaluation_strategy="steps",
    eval_steps=50,
    save_steps=50,
    save_total_limit=3,
    output_dir="./sft_output",
    # Logging.
    logging_steps=10,
    report_to="wandb",
    run_name="C2S-Pythia-410m-diverse-single-and-multi-cell-tasks-sft",
)

# LoRA adapter settings (rank 16, causal-LM task); other values are library defaults.
lora_config = LoraConfig(r=16, task_type='CAUSAL_LM')

# Build the trainer on the formatted train/validation splits.
# NOTE(review): newer TRL releases take the tokenizer as `processing_class`
# -- confirm against the installed version before renaming.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_set,
    eval_dataset=val_set,
    peft_config=lora_config,
)

# Run fine-tuning; as the last expression, the returned TrainOutput is displayed.
trainer.train()

Map: 100%|██████████| 566/566 [00:00<00:00, 800.29 examples/s]
Map: 100%|██████████| 105/105 [00:00<00:00, 437.93 examples/s]
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
[34m[1mwandb[0m: Currently logged in as: [33mharryzhang957[0m ([33mresearch_harry[0m). Use [1m`wandb login --relogin`[0m to force relogin
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avo

Step,Training Loss,Validation Loss
50,2.0885,1.916966
100,1.9013,1.74547
150,1.7733,1.61794
200,1.643,1.529825
250,1.5675,1.464653
300,1.4592,1.422663
350,1.5065,1.393125
400,1.3935,1.372238
450,1.4517,1.360513
500,1.3849,1.345942




TrainOutput(global_step=2800, training_loss=1.2995481548990522, metrics={'train_runtime': 1647.0043, 'train_samples_per_second': 13.746, 'train_steps_per_second': 1.7, 'total_flos': 3.552101170996224e+16, 'train_loss': 1.2995481548990522, 'epoch': 39.57597173144876})

In [None]:
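# Scratch notes (not executable): how padding (P) lines up with label masking (-100).
#
# Right padding -- inputs x and labels y:
#   x = A B C eos P P
#   y = A B C eos -100 -100
#
# Right padding after the next-token shift:
#   logit = A B C eos P P
#   y =     B C eos -100 -100
#
# Left padding -- inputs (top row) and labels (bottom row):
#   P      P   A B C eos
#   -100 -100  A B C eos
#
# Left padding after the next-token shift:
#   P      P   A B C
#   -100   A   B C eos
#         -100
#
# Right padding with no eos:
#   A B C P P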
