In [1]:
import numpy as np
import pandas as pd
import scanpy
import cell2sentence
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast, Trainer, TrainingArguments
from datasets import load_dataset, load_metric, Dataset

In [2]:
# Read the oesophagus single-cell dataset from its .h5ad file.
ESO_H5AD_PATH = "data/eso.h5ad"
eso = scanpy.read_h5ad(ESO_H5AD_PATH)

In [3]:
# Show the `obs` table of the loaded dataset: one row per cell barcode, with
# donor, sample, QC and cell-type annotation columns (see the output below).
eso.obs

Unnamed: 0_level_0,donor_id,Time,donor_time,organ,sample,n_genes,percent_mito,n_counts,leiden,assay_ontology_term_id,...,author_cell_type,suspension_type,cell_type,assay,disease,organism,sex,tissue,self_reported_ethnicity,development_stage
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCTGAGAAGGTTT-1-HCATisStab7413619,328C,0h,328C_0h,Oesophagus,HCATisStab7413619,1976.0,0.043828,9948.0,2,EFO:0009899,...,Epi_stratified,cell,stratified epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,sixth decade human stage
AAACCTGAGAGTCTGG-1-HCATisStab7413619,328C,0h,328C_0h,Oesophagus,HCATisStab7413619,1960.0,0.057559,8096.0,4,EFO:0009899,...,Epi_stratified,cell,stratified epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,sixth decade human stage
AAACCTGAGCCCAATT-1-HCATisStab7413619,328C,0h,328C_0h,Oesophagus,HCATisStab7413619,1598.0,0.054264,5805.0,0,EFO:0009899,...,Epi_stratified,cell,stratified epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,sixth decade human stage
AAACCTGCATGCCCGA-1-HCATisStab7413619,328C,0h,328C_0h,Oesophagus,HCATisStab7413619,1805.0,0.045907,9345.0,2,EFO:0009899,...,Epi_stratified,cell,stratified epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,sixth decade human stage
AAACCTGTCGAACGGA-1-HCATisStab7413619,328C,0h,328C_0h,Oesophagus,HCATisStab7413619,311.0,0.043103,580.0,9,EFO:0009899,...,T_CD4,cell,"CD4-positive, alpha-beta T cell",10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,sixth decade human stage
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTCACATTGGCGC-1-HCATisStabAug177376568,325C,72h,325C_72h,Oesophagus,HCATisStabAug177376568,1205.0,0.002706,11455.0,5,EFO:0009899,...,Epi_upper,cell,epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,fifth decade human stage
TTTGTCAGTACCGTTA-1-HCATisStabAug177376568,325C,72h,325C_72h,Oesophagus,HCATisStabAug177376568,2462.0,0.048540,11887.0,3,EFO:0009899,...,Epi_suprabasal,cell,epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,fifth decade human stage
TTTGTCAGTCTGCCAG-1-HCATisStabAug177376568,325C,72h,325C_72h,Oesophagus,HCATisStabAug177376568,2314.0,0.062985,10050.0,0,EFO:0009899,...,Epi_stratified,cell,stratified epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,fifth decade human stage
TTTGTCAGTTTGACAC-1-HCATisStabAug177376568,325C,72h,325C_72h,Oesophagus,HCATisStabAug177376568,1108.0,0.053808,2639.0,3,EFO:0009899,...,Epi_suprabasal,cell,epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,fifth decade human stage


In [4]:
# Convert the loaded dataset into cell2sentence's own data object.
# The progress bar in the output counts 87,947 items, matching the number of
# rows later reported for the dataset.
eso_c2s = cell2sentence.transforms.csdata_from_adata(eso)

100%|██████████| 87947/87947 [00:32<00:00, 2715.93it/s]


In [5]:
# Build one "sentence" (a list of tokens) per cell.
# Later cells slice each inner list and join it with spaces, so each entry is
# presumably a gene-name string; ordering is decided by cell2sentence
# (presumably by expression rank -- TODO confirm against the library docs).
cell_sents = eso_c2s.create_sentence_lists()

In [6]:
# Take the per-cell `cell_type` annotation and coerce every entry to a plain
# Python string, giving one label per cell in the same order as `eso.obs`.
labels = list(map(str, eso.obs['cell_type'].values))

In [7]:
# Build both directions of the class-name <-> integer-id mapping.
#
# The previous `dict(enumerate(...))` produced an id -> label mapping while
# being stored under the name `label2id`, so looking a class name up in it
# raised KeyError. Each direction now lives under the name that matches it.
#
# `np.unique` returns the distinct labels in sorted order, so ids follow the
# sorted label names. NOTE(review): this presumably matches the ids assigned by
# `class_encode_column` further down -- confirm before using these mappings to
# decode model predictions.
id2label = dict(enumerate(np.unique(labels).tolist()))
label2id = {label: idx for idx, label in id2label.items()}

In [8]:
# Drop the references to the two large source objects; everything needed from
# here on lives in `cell_sents` and `labels`.
del eso, eso_c2s

In [9]:
# Cap every cell sentence at its first MAX_GENES_PER_CELL tokens.
MAX_GENES_PER_CELL = 100
cell_sents = [tokens[:MAX_GENES_PER_CELL] for tokens in cell_sents]

In [10]:
# Collapse each token list into a single space-separated string, the form the
# text tokenizer consumes later.
cell_sents = list(map(' '.join, cell_sents))

In [11]:
# Pair each cell sentence with its label in a plain column-oriented dict
# (column name -> list of values); this is what `Dataset.from_dict` takes.
data = dict(text=cell_sents, label=labels)

In [12]:
# The lists are now referenced through `data`; drop the standalone names.
del cell_sents, labels

In [13]:
# Create the Hugging Face dataset from the column dict built above
# (`data` is a plain dict with "text" and "label" lists, not a pandas DataFrame).
dataset = Dataset.from_dict(data)

In [14]:
# Re-encode the string `label` column as class labels (see the
# "Casting to class labels" progress output); the classifier trains on these.
dataset = dataset.class_encode_column("label")

Casting to class labels:   0%|          | 0/87947 [00:00<?, ? examples/s]

In [15]:
# Sanity check: expect the two columns `text` and `label` and one row per cell.
dataset

Dataset({
    features: ['text', 'label'],
    num_rows: 87947
})

In [16]:
# Number of distinct encoded label ids; used below as `num_labels` when the
# classification model is created.
num_classes = np.unique(dataset['label']).size

In [17]:
# Hold out 20% of the cells for evaluation.
# The split shuffles the rows; a fixed seed pins that shuffle so the same cells
# land in train/test on every Restart & Run All. Without it the split, and
# therefore every metric reported below, changes from run to run.
dataset = dataset.train_test_split(test_size=0.2, seed=42)

In [18]:
# Sanity check: `dataset` is now a DatasetDict with `train` and `test` splits.
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 70357
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 17590
    })
})

In [19]:
# Load the tokenizer and the sequence-classification model from the same
# pretrained checkpoint; the model is told how many classes to predict via
# `num_labels`.
checkpoint = "distilbert-base-uncased"
tokenizer = DistilBertTokenizerFast.from_pretrained(checkpoint)
model = DistilBertForSequenceClassification.from_pretrained(checkpoint, num_labels=num_classes)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.bias', 'pre_classi

In [20]:
def tokenize_function(examples):
    """Tokenize the ``text`` field of a batch of examples.

    Parameters
    ----------
    examples
        Mapping with a ``"text"`` entry holding the cell sentences of the batch.

    Returns
    -------
    Whatever the module-level ``tokenizer`` returns when called with
    ``padding="max_length"`` and ``truncation=True``.
    """
    batch_texts = examples["text"]
    encoded = tokenizer(batch_texts, padding="max_length", truncation=True)
    return encoded

# Run the tokenizer over both splits in batches, then pull each split out of
# the resulting DatasetDict under its own name.
tokenized_dataset = dataset.map(tokenize_function, batched=True)
train_dataset, test_dataset = tokenized_dataset["train"], tokenized_dataset["test"]

# Define the accuracy metric
# NOTE(review): the cell output echoes this exact line as part of a warning,
# which suggests `load_metric` is deprecated in the installed `datasets`
# version -- confirm and plan a migration to its replacement.
accuracy = load_metric("accuracy")

# Define the compute_metrics function
def compute_metrics(eval_pred):
    """Compute accuracy for one evaluation pass.

    Parameters
    ----------
    eval_pred
        A ``(logits, labels)`` pair. ``logits`` is array-like with the
        per-class scores on the last axis; ``labels`` holds the reference
        class ids.

    Returns
    -------
    The result of ``accuracy.compute`` for the predicted class ids against
    ``labels``.
    """
    logits, labels = eval_pred
    # Predicted class = index of the highest score on the last axis. Done
    # directly in NumPy: the logits already arrive as an array, so converting
    # them to a torch tensor just to take an argmax is an unnecessary round trip.
    predictions = np.argmax(logits, axis=-1)
    return accuracy.compute(predictions=predictions, references=labels)

# Set up the training arguments
training_args = TrainingArguments(
    output_dir="output",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # Evaluation and checkpointing both happen once per epoch, so the best
    # checkpoint (by accuracy) can be restored when training finishes.
    evaluation_strategy="epoch",
    logging_dir="logs",
    logging_strategy="steps",
    logging_steps=250,
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Mixed-precision training needs a CUDA device. Enabling it only when one
    # is present keeps the notebook runnable on CPU-only machines instead of
    # failing while the arguments are being validated.
    fp16=torch.cuda.is_available(),
)


# Assemble the trainer: configuration and model first, then the data splits,
# then the metric callback. The held-out `test_dataset` doubles as the
# evaluation set used after every epoch.
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

# Fine-tune the model with the configured trainer.
trainer.train()

# Score the trained model on the evaluation split. The trailing bare name makes
# the notebook display the metrics dict as the cell's output.
# NOTE(review): this is the same split used to pick the best checkpoint, so the
# reported accuracy is not an independent held-out estimate.
final_metrics = trainer.evaluate()
final_metrics

Map:   0%|          | 0/70357 [00:00<?, ? examples/s]

Map:   0%|          | 0/17590 [00:00<?, ? examples/s]

  accuracy = load_metric("accuracy")


  0%|          | 0/13194 [00:00<?, ?it/s]

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'loss': 1.1501, 'learning_rate': 4.9056389267849026e-05, 'epoch': 0.06}
{'loss': 1.0584, 'learning_rate': 4.810898893436411e-05, 'epoch': 0.11}
{'loss': 0.9484, 'learning_rate': 4.716537820221313e-05, 'epoch': 0.17}
{'loss': 0.9636, 'learning_rate': 4.621797786872821e-05, 'epoch': 0.23}
{'loss': 0.8441, 'learning_rate': 4.5270577535243294e-05, 'epoch': 0.28}
{'loss': 0.744, 'learning_rate': 4.432696680309232e-05, 'epoch': 0.34}
{'loss': 0.6765, 'learning_rate': 4.33795664696074e-05, 'epoch': 0.4}
{'loss': 0.6339, 'learning_rate': 4.243216613612248e-05, 'epoch': 0.45}
{'loss': 0.5774, 'learning_rate': 4.1488555403971505e-05, 'epoch': 0.51}
{'loss': 0.5642, 'learning_rate': 4.0541155070486586e-05, 'epoch': 0.57}
{'loss': 0.5388, 'learning_rate': 3.959375473700167e-05, 'epoch': 0.63}
{'loss': 0.5354, 'learning_rate': 3.8646354403516756e-05, 'epoch': 0.68}
{'loss': 0.5379, 'learning_rate': 3.769895407003183e-05, 'epoch': 0.74}
{'loss': 0.5041, 'learning_rate': 3.675155373654692e-05, 'epoc

  0%|          | 0/1100 [00:00<?, ?it/s]

{'eval_loss': 0.4340948462486267, 'eval_accuracy': 0.8392268334280841, 'eval_runtime': 78.3012, 'eval_samples_per_second': 224.645, 'eval_steps_per_second': 14.048, 'epoch': 1.0}
{'loss': 0.449, 'learning_rate': 3.296195240260725e-05, 'epoch': 1.02}
{'loss': 0.4492, 'learning_rate': 3.2014552069122326e-05, 'epoch': 1.08}
{'loss': 0.4267, 'learning_rate': 3.1067151735637414e-05, 'epoch': 1.14}
{'loss': 0.436, 'learning_rate': 3.0119751402152495e-05, 'epoch': 1.19}
{'loss': 0.4208, 'learning_rate': 2.9172351068667573e-05, 'epoch': 1.25}
{'loss': 0.4062, 'learning_rate': 2.8224950735182658e-05, 'epoch': 1.31}
{'loss': 0.413, 'learning_rate': 2.7277550401697743e-05, 'epoch': 1.36}
{'loss': 0.3974, 'learning_rate': 2.6330150068212828e-05, 'epoch': 1.42}
{'loss': 0.3823, 'learning_rate': 2.5382749734727906e-05, 'epoch': 1.48}
{'loss': 0.353, 'learning_rate': 2.443534940124299e-05, 'epoch': 1.53}
{'loss': 0.3607, 'learning_rate': 2.3487949067758072e-05, 'epoch': 1.59}
{'loss': 0.3868, 'learni

  0%|          | 0/1100 [00:00<?, ?it/s]

{'eval_loss': 0.33921748399734497, 'eval_accuracy': 0.8783968163729392, 'eval_runtime': 78.5893, 'eval_samples_per_second': 223.822, 'eval_steps_per_second': 13.997, 'epoch': 2.0}
{'loss': 0.3264, 'learning_rate': 1.5916325602546615e-05, 'epoch': 2.05}
{'loss': 0.3195, 'learning_rate': 1.4968925269061696e-05, 'epoch': 2.1}
{'loss': 0.3208, 'learning_rate': 1.4021524935576778e-05, 'epoch': 2.16}
{'loss': 0.3018, 'learning_rate': 1.307412460209186e-05, 'epoch': 2.22}
{'loss': 0.299, 'learning_rate': 1.2126724268606942e-05, 'epoch': 2.27}
{'loss': 0.299, 'learning_rate': 1.1179323935122025e-05, 'epoch': 2.33}
{'loss': 0.3181, 'learning_rate': 1.0231923601637108e-05, 'epoch': 2.39}
{'loss': 0.3093, 'learning_rate': 9.284523268152191e-06, 'epoch': 2.44}
{'loss': 0.3142, 'learning_rate': 8.337122934667273e-06, 'epoch': 2.5}
{'loss': 0.3066, 'learning_rate': 7.389722601182356e-06, 'epoch': 2.56}
{'loss': 0.2745, 'learning_rate': 6.442322267697438e-06, 'epoch': 2.61}
{'loss': 0.3225, 'learning

  0%|          | 0/1100 [00:00<?, ?it/s]

{'eval_loss': 0.30558696389198303, 'eval_accuracy': 0.8888004548038658, 'eval_runtime': 78.393, 'eval_samples_per_second': 224.382, 'eval_steps_per_second': 14.032, 'epoch': 3.0}
{'train_runtime': 3198.745, 'train_samples_per_second': 65.986, 'train_steps_per_second': 4.125, 'train_loss': 0.4586493691766336, 'epoch': 3.0}


  0%|          | 0/1100 [00:00<?, ?it/s]

{'eval_loss': 0.30558696389198303,
 'eval_accuracy': 0.8888004548038658,
 'eval_runtime': 78.2943,
 'eval_samples_per_second': 224.665,
 'eval_steps_per_second': 14.05,
 'epoch': 3.0}