In [1]:
import numpy as np
import pandas as pd
import scanpy
import cell2sentence
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast, Trainer, TrainingArguments
from datasets import load_dataset, load_metric, Dataset

In [2]:
# Load the annotated single-cell dataset from the local .h5ad file.
eso = scanpy.read_h5ad("data/eso.h5ad")

In [3]:
# Display the per-cell metadata table (donor, sample, QC counts, cell_type, ...).
eso.obs

Unnamed: 0_level_0,donor_id,Time,donor_time,organ,sample,n_genes,percent_mito,n_counts,leiden,assay_ontology_term_id,...,author_cell_type,suspension_type,cell_type,assay,disease,organism,sex,tissue,self_reported_ethnicity,development_stage
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCTGAGAAGGTTT-1-HCATisStab7413619,328C,0h,328C_0h,Oesophagus,HCATisStab7413619,1976.0,0.043828,9948.0,2,EFO:0009899,...,Epi_stratified,cell,stratified epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,sixth decade human stage
AAACCTGAGAGTCTGG-1-HCATisStab7413619,328C,0h,328C_0h,Oesophagus,HCATisStab7413619,1960.0,0.057559,8096.0,4,EFO:0009899,...,Epi_stratified,cell,stratified epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,sixth decade human stage
AAACCTGAGCCCAATT-1-HCATisStab7413619,328C,0h,328C_0h,Oesophagus,HCATisStab7413619,1598.0,0.054264,5805.0,0,EFO:0009899,...,Epi_stratified,cell,stratified epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,sixth decade human stage
AAACCTGCATGCCCGA-1-HCATisStab7413619,328C,0h,328C_0h,Oesophagus,HCATisStab7413619,1805.0,0.045907,9345.0,2,EFO:0009899,...,Epi_stratified,cell,stratified epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,sixth decade human stage
AAACCTGTCGAACGGA-1-HCATisStab7413619,328C,0h,328C_0h,Oesophagus,HCATisStab7413619,311.0,0.043103,580.0,9,EFO:0009899,...,T_CD4,cell,"CD4-positive, alpha-beta T cell",10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,sixth decade human stage
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTCACATTGGCGC-1-HCATisStabAug177376568,325C,72h,325C_72h,Oesophagus,HCATisStabAug177376568,1205.0,0.002706,11455.0,5,EFO:0009899,...,Epi_upper,cell,epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,fifth decade human stage
TTTGTCAGTACCGTTA-1-HCATisStabAug177376568,325C,72h,325C_72h,Oesophagus,HCATisStabAug177376568,2462.0,0.048540,11887.0,3,EFO:0009899,...,Epi_suprabasal,cell,epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,fifth decade human stage
TTTGTCAGTCTGCCAG-1-HCATisStabAug177376568,325C,72h,325C_72h,Oesophagus,HCATisStabAug177376568,2314.0,0.062985,10050.0,0,EFO:0009899,...,Epi_stratified,cell,stratified epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,fifth decade human stage
TTTGTCAGTTTGACAC-1-HCATisStabAug177376568,325C,72h,325C_72h,Oesophagus,HCATisStabAug177376568,1108.0,0.053808,2639.0,3,EFO:0009899,...,Epi_suprabasal,cell,epithelial cell,10x 3' v2,normal,Homo sapiens,female,epithelium of esophagus,unknown,fifth decade human stage


In [4]:
# Build the cell2sentence data object from the AnnData object.
# NOTE(review): presumably this ranks genes by expression within each cell —
# confirm against the cell2sentence documentation.
eso_c2s = cell2sentence.transforms.csdata_from_adata(eso)

100%|██████████| 87947/87947 [00:32<00:00, 2700.13it/s]


In [5]:
# One list of string tokens per cell (later cells slice and space-join these lists).
# NOTE(review): tokens are presumably gene names in rank order — confirm.
cell_sents = eso_c2s.create_sentence_lists()

In [6]:
# Use the `cell_type` annotation of every cell as its class label, as plain strings.
labels = list(map(str, eso.obs['cell_type'].values))

In [7]:
# Enumerate the sorted unique label strings.
# NOTE(review): despite its name, this dict is keyed by the integer id and holds
# the label string as its value (id -> label), not label -> id.
label2id = {idx: name for idx, name in enumerate(np.unique(labels))}

In [8]:
# Drop the references to the large source objects so their memory can be reclaimed.
del eso, eso_c2s

In [9]:
# Keep only the first 10 tokens of each cell sentence.
cell_sents = [tokens[:10] for tokens in cell_sents]

In [10]:
# Collapse each token list into a single space-separated string.
cell_sents = list(map(' '.join, cell_sents))

In [11]:
# Column-oriented mapping (a plain dict, not a DataFrame) of sentences and labels
data = dict(text=cell_sents, label=labels)

In [12]:
# Remove the standalone names; the lists themselves stay alive through `data`.
del cell_sents, labels

In [13]:
# Create the dataset from the `data` dict of columns ("text" and "label")
dataset = Dataset.from_dict(data)

In [14]:
# Re-encode the string `label` column as a ClassLabel feature (integer class ids).
dataset = dataset.class_encode_column("label")

Casting to class labels:   0%|          | 0/87947 [00:00<?, ? examples/s]

In [15]:
# Show the dataset's columns and row count.
dataset

Dataset({
    features: ['text', 'label'],
    num_rows: 87947
})

In [16]:
# Number of distinct cell-type classes. Read it from the ClassLabel feature
# metadata (set by `class_encode_column`) instead of materialising the whole
# label column and scanning it for unique values.
num_classes = dataset.features["label"].num_classes

In [17]:
# Hold out 20% of the cells for evaluation. The seed is fixed so that the split,
# and therefore the reported metrics, is reproducible across kernel restarts.
dataset = dataset.train_test_split(test_size=0.2, seed=42)

In [18]:
# Show the resulting train/test splits and their sizes.
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 70357
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 17590
    })
})

In [19]:
# Tokenizer and classifier both start from the same pretrained checkpoint; the
# classification head is sized to the number of cell-type classes.
tokenizer = DistilBertTokenizerFast.from_pretrained(
    "distilbert-base-uncased",
)
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=num_classes,
)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier

In [20]:
def tokenize_function(examples):
    """Tokenize a batch of cell sentences.

    Args:
        examples: Batch passed by ``Dataset.map(..., batched=True)``; its
            ``"text"`` entry holds the sentences to encode.

    Returns:
        The tokenizer's encoding for the batch, truncated to the model's
        maximum length but not padded.
    """
    # No padding here: padding="max_length" stretched every short cell sentence
    # to the tokenizer's full model length, wasting memory and compute. Padding
    # is left to the Trainer, whose default collator pads each batch to its
    # longest sequence when it is given a tokenizer.
    return tokenizer(examples["text"], truncation=True)

# Tokenize both splits in batches, then pull out the train and test portions
tokenized_dataset = dataset.map(tokenize_function, batched=True)
train_dataset, test_dataset = tokenized_dataset["train"], tokenized_dataset["test"]

# Load the accuracy metric
# NOTE(review): `load_metric` is deprecated in newer `datasets` releases in favour
# of the separate `evaluate` package — confirm against the installed version.
accuracy = load_metric("accuracy")

# Define the compute_metrics function
def compute_metrics(eval_pred):
    """Compute classification accuracy for the Trainer's evaluation loop.

    Args:
        eval_pred: A ``(logits, labels)`` pair. ``logits`` is array-like with
            the class scores on the last axis; ``labels`` holds the integer
            class ids.

    Returns:
        ``{"accuracy": value}`` where ``value`` is the fraction of correct
        predictions as a Python float.
    """
    logits, labels = eval_pred
    # Some models return a tuple of outputs; the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(np.asarray(logits), axis=-1)
    # Computed directly with NumPy: no torch round-trip and no dependency on a
    # module-level metric object.
    return {"accuracy": float(np.mean(predictions == np.asarray(labels)))}

# Set up the training arguments
training_args = TrainingArguments(
    output_dir="output",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Evaluate and checkpoint once per epoch; load_best_model_at_end needs the
    # evaluation and save strategies to match.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_dir="logs",
    logging_strategy="steps",
    logging_steps=500,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Mixed precision needs a CUDA device; fall back to full precision on
    # CPU-only machines instead of failing when the arguments are validated.
    fp16=torch.cuda.is_available(),
)


# Wire the model, data splits, tokenizer and metric callback into a Trainer
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

# Fine-tune on the training split
trainer.train()

# Score the held-out split; as the cell's last expression, the metrics dict is displayed
trainer.evaluate()

Map:   0%|          | 0/70357 [00:00<?, ? examples/s]

Map:   0%|          | 0/17590 [00:00<?, ? examples/s]

  accuracy = load_metric("accuracy")


  0%|          | 0/26385 [00:00<?, ?it/s]

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'loss': 0.8816, 'learning_rate': 4.9056281978396815e-05, 'epoch': 0.06}
{'loss': 0.6644, 'learning_rate': 4.810877392457836e-05, 'epoch': 0.11}
