In [1]:
import numpy as np
import pandas as pd
import scanpy
import cell2sentence
import torch
from transformers import GPT2TokenizerFast, GPT2ForSequenceClassification, GPT2Config, TrainingArguments, Trainer
from datasets import load_dataset, Dataset
import evaluate

In [23]:
# Load the oesophagus AnnData object; per-cell metadata lives in `eso.obs`.
ESO_H5AD_PATH = "data/eso.h5ad"
eso = scanpy.read_h5ad(ESO_H5AD_PATH)

In [24]:
# Inspect per-cell metadata; `cell_type` is the column used as the classification target below.
eso.obs

Unnamed: 0_level_0,donor_id,Time,donor_time,organ,sample,n_genes,percent_mito,n_counts,leiden,assay_ontology_term_id,...,author_cell_type,suspension_type,cell_type,assay,disease,organism,sex,tissue,self_reported_ethnicity,development_stage
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCTGAGAAGGTTT-1-HCATisStab7413619,328C,0h,328C_0h,Oesophagus,HCATisStab7413619,1976.0,0.043828,9948.0,2,EFO:0009899,...,Epi_stratified,cell,stratified epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,sixth decade human stage
AAACCTGAGAGTCTGG-1-HCATisStab7413619,328C,0h,328C_0h,Oesophagus,HCATisStab7413619,1960.0,0.057559,8096.0,4,EFO:0009899,...,Epi_stratified,cell,stratified epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,sixth decade human stage
AAACCTGAGCCCAATT-1-HCATisStab7413619,328C,0h,328C_0h,Oesophagus,HCATisStab7413619,1598.0,0.054264,5805.0,0,EFO:0009899,...,Epi_stratified,cell,stratified epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,sixth decade human stage
AAACCTGCATGCCCGA-1-HCATisStab7413619,328C,0h,328C_0h,Oesophagus,HCATisStab7413619,1805.0,0.045907,9345.0,2,EFO:0009899,...,Epi_stratified,cell,stratified epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,sixth decade human stage
AAACCTGTCGAACGGA-1-HCATisStab7413619,328C,0h,328C_0h,Oesophagus,HCATisStab7413619,311.0,0.043103,580.0,9,EFO:0009899,...,T_CD4,cell,"CD4-positive, alpha-beta T cell",10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,sixth decade human stage
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTCACATTGGCGC-1-HCATisStabAug177376568,325C,72h,325C_72h,Oesophagus,HCATisStabAug177376568,1205.0,0.002706,11455.0,5,EFO:0009899,...,Epi_upper,cell,epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,fifth decade human stage
TTTGTCAGTACCGTTA-1-HCATisStabAug177376568,325C,72h,325C_72h,Oesophagus,HCATisStabAug177376568,2462.0,0.048540,11887.0,3,EFO:0009899,...,Epi_suprabasal,cell,epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,fifth decade human stage
TTTGTCAGTCTGCCAG-1-HCATisStabAug177376568,325C,72h,325C_72h,Oesophagus,HCATisStabAug177376568,2314.0,0.062985,10050.0,0,EFO:0009899,...,Epi_stratified,cell,stratified epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,fifth decade human stage
TTTGTCAGTTTGACAC-1-HCATisStabAug177376568,325C,72h,325C_72h,Oesophagus,HCATisStabAug177376568,1108.0,0.053808,2639.0,3,EFO:0009899,...,Epi_suprabasal,cell,epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,fifth decade human stage


In [25]:
# Convert the AnnData object into a cell2sentence data object (used below to build one
# "cell sentence" per cell).
eso_c2s = cell2sentence.transforms.csdata_from_adata(eso)

100%|██████████| 87947/87947 [00:33<00:00, 2657.24it/s]


In [26]:
# One list of tokens per cell; presumably gene names ordered by expression — TODO confirm
# against the cell2sentence docs. The lists are joined into space-separated strings further down.
cell_sents = eso_c2s.create_sentence_lists()

In [27]:
# Classification targets: the `cell_type` annotation of each cell, as plain Python strings.
labels = list(map(str, eso.obs['cell_type'].values))

In [7]:
# Mappings between integer class ids and cell-type names, in np.unique's sorted order.
# Naming follows the Hugging Face config convention: `id2label` is id -> name,
# `label2id` is name -> id.
id2label = {idx: str(name) for idx, name in enumerate(np.unique(labels))}
label2id = {name: idx for idx, name in id2label.items()}

In [29]:
# Free the large AnnData / cell2sentence objects; only `cell_sents` and `labels` are needed from here on.
del eso, eso_c2s

In [9]:
# Disabled debugging shortcut: uncomment to keep only the first 5 tokens of each cell
# sentence (much shorter inputs for quick experiments).
#cell_sents = [inner_list[:5] for inner_list in cell_sents]

In [28]:
# Turn each per-cell token list into a single space-separated "sentence" string.
cell_sents = list(map(' '.join, cell_sents))

In [11]:
# Column-oriented dict of sentences and labels (the input format for Dataset.from_dict).
data = dict(text=cell_sents, label=labels)

In [12]:
# The lists now live in `data`; drop the standalone names.
del cell_sents, labels

In [13]:
# Build a Hugging Face Dataset from the column dict `data` (no pandas involved).
columns = data
dataset = Dataset.from_dict(columns)

In [14]:
# Cast the string `label` column to a ClassLabel feature so each cell type becomes an integer id.
dataset = dataset.class_encode_column("label")

Casting to class labels:   0%|          | 0/87947 [00:00<?, ? examples/s]

In [15]:
# Sanity check: expect the 'text' and 'label' features and one row per cell.
dataset

Dataset({
    features: ['text', 'label'],
    num_rows: 87947
})

In [16]:
# Number of distinct (encoded) cell-type classes; sizes the classification head.
num_classes = np.unique(dataset["label"]).size

In [17]:
# Hold out 20% of cells for evaluation. A fixed seed makes the split (and therefore every
# reported metric) reproducible across kernel restarts.
# Note: this rebinds `dataset` to a DatasetDict, so re-running the cell requires re-running
# the cells above it first.
SPLIT_SEED = 42
dataset = dataset.train_test_split(test_size=0.2, seed=SPLIT_SEED)

In [18]:
# Sanity check: expect 'train' and 'test' splits with roughly an 80/20 row ratio.
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 70357
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 17590
    })
})

In [19]:
# Seed torch so the freshly initialised classification head (`score.weight`) is reproducible.
torch.manual_seed(42)

# Human-readable class names, in the order of the integer ids assigned by class encoding.
class_names = dataset["train"].features["label"].names

# Store the id <-> cell-type mappings in the config so saved checkpoints report real names
# instead of LABEL_0, LABEL_1, ...
gptconfig = GPT2Config.from_pretrained(
    "distilgpt2",
    num_labels=num_classes,
    id2label={idx: name for idx, name in enumerate(class_names)},
    label2id={name: idx for idx, name in enumerate(class_names)},
)
# GPT-2 has no pad token; reuse EOS. GPT2ForSequenceClassification uses pad_token_id to find
# the last non-pad token of each row, whose hidden state feeds the classifier.
gptconfig.pad_token_id = gptconfig.eos_token_id
model = GPT2ForSequenceClassification.from_pretrained("distilgpt2", config=gptconfig)
tokenizer = GPT2TokenizerFast.from_pretrained("distilgpt2")

# Keep the tokenizer's default (right-side) padding, consistent with the last-non-pad-token
# pooling described above; the pad token must match the config's pad_token_id.
tokenizer.pad_token = tokenizer.eos_token


Some weights of the model checkpoint at distilgpt2 were not used when initializing GPT2ForSequenceClassification: ['lm_head.weight']
- This IS expected if you are initializing GPT2ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at distilgpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [20]:
def tokenize_function(examples):
    """Tokenize a batch of cell sentences into fixed-length model inputs.

    Intended for ``Dataset.map(..., batched=True)``, so ``examples["text"]`` is a list of
    strings. Every sequence is truncated to, and padded up to, the tokenizer's maximum length.
    """
    sentences = examples["text"]
    encoded = tokenizer(sentences, truncation=True, padding='max_length')
    return encoded

# Tokenize both splits; with batched=True, tokenize_function receives lists of examples.
tokenized_dataset = dataset.map(tokenize_function, batched=True)
train_dataset, test_dataset = tokenized_dataset["train"], tokenized_dataset["test"]

# Metric objects used by compute_metrics at every evaluation pass.
accuracy_metric, f1_metric, recall_metric, precision_metric = (
    evaluate.load(metric_name) for metric_name in ("accuracy", "f1", "recall", "precision")
)


def compute_metrics(eval_pred):
    """Compute accuracy plus weighted precision/recall/F1 for one Trainer evaluation pass.

    Args:
        eval_pred: ``(logits, labels)`` pair supplied by ``Trainer``. ``logits`` holds the
            per-class scores for each example (or is a tuple whose first element does), and
            ``labels`` holds the integer class ids.

    Returns:
        dict with keys ``accuracy``, ``precision``, ``recall`` and ``f1``; the last three are
        averaged over classes weighted by support. ``Trainer`` prefixes each key with ``eval_``.
    """
    logits, labels = eval_pred
    # Some models return extra outputs alongside the logits; the logits always come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    # Predictions are already NumPy arrays here; no need to round-trip through torch.
    predictions = np.argmax(logits, axis=-1)

    results = {}
    results.update(accuracy_metric.compute(predictions=predictions, references=labels))
    # zero_division=0 is the value sklearn substitutes under its default "warn" setting, so the
    # scores are unchanged; it only silences the warning for classes never predicted or present.
    results.update(precision_metric.compute(predictions=predictions, references=labels, average="weighted", zero_division=0))
    results.update(recall_metric.compute(predictions=predictions, references=labels, average="weighted", zero_division=0))
    results.update(f1_metric.compute(predictions=predictions, references=labels, average="weighted"))

    return results



# Training configuration.
training_args = TrainingArguments(
    output_dir="output",
    num_train_epochs=6,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    evaluation_strategy="epoch",
    logging_dir="logs",
    logging_strategy="steps",
    logging_steps=250,
    save_strategy="epoch",             # must match the evaluation strategy for load_best_model_at_end
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",  # resolved by Trainer to "eval_accuracy" from compute_metrics
    seed=42,                           # explicit for reproducibility (same as the library default)
    # Mixed precision requires a CUDA device; fall back to fp32 elsewhere instead of erroring out.
    fp16=torch.cuda.is_available(),
)

# NOTE(review): `test_dataset` is used both to select the best checkpoint
# (load_best_model_at_end) and for the final evaluation below, so the final scores are
# optimistically biased. Hold out a separate validation split for model selection if an
# unbiased estimate is required.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,  # also used to pad batches and saved alongside checkpoints
    compute_metrics=compute_metrics,
)

# Train the model (evaluates and checkpoints once per epoch).
trainer.train()

# Evaluate the best checkpoint (reloaded at the end of training) and display the metrics.
eval_results = trainer.evaluate()
eval_results


Map:   0%|          | 0/70357 [00:00<?, ? examples/s]

Map:   0%|          | 0/17590 [00:00<?, ? examples/s]



  0%|          | 0/105540 [00:00<?, ?it/s]

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'loss': 1.4021, 'learning_rate': 4.988487777146106e-05, 'epoch': 0.01}
{'loss': 1.0069, 'learning_rate': 4.9766439264733756e-05, 'epoch': 0.03}
{'loss': 0.8399, 'learning_rate': 4.964800075800645e-05, 'epoch': 0.04}
{'loss': 0.8341, 'learning_rate': 4.952956225127914e-05, 'epoch': 0.06}
{'loss': 0.7333, 'learning_rate': 4.941112374455183e-05, 'epoch': 0.07}
{'loss': 0.6863, 'learning_rate': 4.9292685237824526e-05, 'epoch': 0.09}
{'loss': 0.7255, 'learning_rate': 4.917424673109722e-05, 'epoch': 0.1}
{'loss': 0.6495, 'learning_rate': 4.905580822436991e-05, 'epoch': 0.11}
{'loss': 0.6172, 'learning_rate': 4.89373697176426e-05, 'epoch': 0.13}
{'loss': 0.5873, 'learning_rate': 4.8818931210915296e-05, 'epoch': 0.14}
{'loss': 0.6725, 'learning_rate': 4.870049270418799e-05, 'epoch': 0.16}
{'loss': 0.6061, 'learning_rate': 4.858205419746068e-05, 'epoch': 0.17}
{'loss': 0.6463, 'learning_rate': 4.846361569073337e-05, 'epoch': 0.18}
{'loss': 0.5934, 'learning_rate': 4.8345177184006066e-05, 'epoc

  0%|          | 0/4398 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.45902687311172485, 'eval_accuracy': 0.8764638999431495, 'eval_precision': 0.8755660911890922, 'eval_recall': 0.8764638999431495, 'eval_f1': 0.8748839975048095, 'eval_runtime': 314.7187, 'eval_samples_per_second': 55.891, 'eval_steps_per_second': 13.974, 'epoch': 1.0}
{'loss': 0.4164, 'learning_rate': 4.159749857873792e-05, 'epoch': 1.01}
{'loss': 0.4301, 'learning_rate': 4.147906007201061e-05, 'epoch': 1.02}
{'loss': 0.4561, 'learning_rate': 4.136062156528331e-05, 'epoch': 1.04}
{'loss': 0.4307, 'learning_rate': 4.1242183058556e-05, 'epoch': 1.05}
{'loss': 0.4602, 'learning_rate': 4.112374455182869e-05, 'epoch': 1.07}
{'loss': 0.4617, 'learning_rate': 4.1005779799128296e-05, 'epoch': 1.08}
{'loss': 0.434, 'learning_rate': 4.088734129240099e-05, 'epoch': 1.09}
{'loss': 0.4787, 'learning_rate': 4.076890278567368e-05, 'epoch': 1.11}
{'loss': 0.3888, 'learning_rate': 4.065046427894637e-05, 'epoch': 1.12}
{'loss': 0.4673, 'learning_rate': 4.0532025772219066e-05, 'epoch': 1.1

  0%|          | 0/4398 [00:00<?, ?it/s]

{'eval_loss': 0.36660194396972656, 'eval_accuracy': 0.9067652075042638, 'eval_precision': 0.9064059716168336, 'eval_recall': 0.9067652075042638, 'eval_f1': 0.9058255510293556, 'eval_runtime': 306.2607, 'eval_samples_per_second': 57.435, 'eval_steps_per_second': 14.36, 'epoch': 2.0}
{'loss': 0.3716, 'learning_rate': 3.331059314004169e-05, 'epoch': 2.0}
{'loss': 0.3156, 'learning_rate': 3.319215463331439e-05, 'epoch': 2.02}
{'loss': 0.412, 'learning_rate': 3.307371612658708e-05, 'epoch': 2.03}
{'loss': 0.3648, 'learning_rate': 3.295527761985977e-05, 'epoch': 2.05}
{'loss': 0.3997, 'learning_rate': 3.283683911313246e-05, 'epoch': 2.06}
{'loss': 0.3276, 'learning_rate': 3.271840060640516e-05, 'epoch': 2.08}
{'loss': 0.3326, 'learning_rate': 3.259996209967785e-05, 'epoch': 2.09}
{'loss': 0.3042, 'learning_rate': 3.248199734697745e-05, 'epoch': 2.1}
{'loss': 0.3834, 'learning_rate': 3.236355884025014e-05, 'epoch': 2.12}
{'loss': 0.3791, 'learning_rate': 3.224512033352284e-05, 'epoch': 2.13}


  0%|          | 0/4398 [00:00<?, ?it/s]

{'eval_loss': 0.36680781841278076, 'eval_accuracy': 0.9102899374644684, 'eval_precision': 0.9105499644260515, 'eval_recall': 0.9102899374644684, 'eval_f1': 0.9091558751596827, 'eval_runtime': 317.902, 'eval_samples_per_second': 55.332, 'eval_steps_per_second': 13.834, 'epoch': 3.0}
{'loss': 0.3032, 'learning_rate': 2.4904775440591246e-05, 'epoch': 3.01}
{'loss': 0.2737, 'learning_rate': 2.478681068789085e-05, 'epoch': 3.03}
{'loss': 0.3426, 'learning_rate': 2.466837218116354e-05, 'epoch': 3.04}
{'loss': 0.3847, 'learning_rate': 2.4549933674436236e-05, 'epoch': 3.06}
{'loss': 0.3076, 'learning_rate': 2.4431495167708926e-05, 'epoch': 3.07}
{'loss': 0.3371, 'learning_rate': 2.431305666098162e-05, 'epoch': 3.08}
{'loss': 0.3481, 'learning_rate': 2.419461815425431e-05, 'epoch': 3.1}
{'loss': 0.2845, 'learning_rate': 2.4076179647527006e-05, 'epoch': 3.11}
{'loss': 0.337, 'learning_rate': 2.3957741140799696e-05, 'epoch': 3.13}
{'loss': 0.3046, 'learning_rate': 2.38397763880993e-05, 'epoch': 3

  0%|          | 0/4398 [00:00<?, ?it/s]

{'eval_loss': 0.347149521112442, 'eval_accuracy': 0.9116543490619671, 'eval_precision': 0.9122106294927238, 'eval_recall': 0.9116543490619671, 'eval_f1': 0.9110278340785586, 'eval_runtime': 324.5989, 'eval_samples_per_second': 54.19, 'eval_steps_per_second': 13.549, 'epoch': 4.0}
{'loss': 0.2771, 'learning_rate': 1.6617870001895016e-05, 'epoch': 4.01}
{'loss': 0.2987, 'learning_rate': 1.649943149516771e-05, 'epoch': 4.02}
{'loss': 0.2486, 'learning_rate': 1.63809929884404e-05, 'epoch': 4.04}
{'loss': 0.3152, 'learning_rate': 1.6263028235740006e-05, 'epoch': 4.05}
{'loss': 0.2722, 'learning_rate': 1.6144589729012697e-05, 'epoch': 4.06}
{'loss': 0.28, 'learning_rate': 1.602615122228539e-05, 'epoch': 4.08}
{'loss': 0.3626, 'learning_rate': 1.5907712715558082e-05, 'epoch': 4.09}
{'loss': 0.3496, 'learning_rate': 1.5789274208830776e-05, 'epoch': 4.11}
{'loss': 0.2561, 'learning_rate': 1.5670835702103467e-05, 'epoch': 4.12}
{'loss': 0.3034, 'learning_rate': 1.555239719537616e-05, 'epoch': 4.

  0%|          | 0/4398 [00:00<?, ?it/s]

{'eval_loss': 0.3574374318122864, 'eval_accuracy': 0.9164866401364412, 'eval_precision': 0.9161942871428985, 'eval_recall': 0.9164866401364412, 'eval_f1': 0.9158628389106561, 'eval_runtime': 320.6787, 'eval_samples_per_second': 54.852, 'eval_steps_per_second': 13.715, 'epoch': 5.0}
{'loss': 0.3281, 'learning_rate': 8.33049080917188e-06, 'epoch': 5.0}
{'loss': 0.2737, 'learning_rate': 8.212526056471479e-06, 'epoch': 5.02}
{'loss': 0.2478, 'learning_rate': 8.094087549744173e-06, 'epoch': 5.03}
{'loss': 0.279, 'learning_rate': 7.975649043016866e-06, 'epoch': 5.05}
{'loss': 0.2555, 'learning_rate': 7.857210536289558e-06, 'epoch': 5.06}
{'loss': 0.2357, 'learning_rate': 7.738772029562252e-06, 'epoch': 5.07}
{'loss': 0.3322, 'learning_rate': 7.620333522834945e-06, 'epoch': 5.09}
{'loss': 0.3033, 'learning_rate': 7.501895016107638e-06, 'epoch': 5.1}
{'loss': 0.3084, 'learning_rate': 7.383456509380331e-06, 'epoch': 5.12}
{'loss': 0.2591, 'learning_rate': 7.265018002653023e-06, 'epoch': 5.13}
{

  0%|          | 0/4398 [00:00<?, ?it/s]

{'eval_loss': 0.3761449158191681, 'eval_accuracy': 0.917509948834565, 'eval_precision': 0.9171505741363544, 'eval_recall': 0.917509948834565, 'eval_f1': 0.9168567815653775, 'eval_runtime': 313.8903, 'eval_samples_per_second': 56.039, 'eval_steps_per_second': 14.011, 'epoch': 6.0}
{'train_runtime': 24531.2586, 'train_samples_per_second': 17.208, 'train_steps_per_second': 4.302, 'train_loss': 0.379336426080527, 'epoch': 6.0}


  0%|          | 0/4398 [00:00<?, ?it/s]

{'eval_loss': 0.3761449158191681,
 'eval_accuracy': 0.917509948834565,
 'eval_precision': 0.9171505741363544,
 'eval_recall': 0.917509948834565,
 'eval_f1': 0.9168567815653775,
 'eval_runtime': 313.8034,
 'eval_samples_per_second': 56.054,
 'eval_steps_per_second': 14.015,
 'epoch': 6.0}