<a href="https://colab.research.google.com/github/charleszhang418/SpaceX/blob/main/code/dnabert_finetune_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

+ Data: https://osdr.nasa.gov/bio/repo/data/studies/OSD-466
+ Sample:
  1. RR10_FCS_FLT_KO_F19: p21-null, Space Flight
  2. RR10_FCS_FLT_WT_F16: Wild Type, Space Flight
  3. RR10_FCS_GC_WT_G3: Wild Type, Ground Control
  4. RR10_FCS_GC_KO_G4: p21-null, Ground Control
  5. RR10_FCS_VIV_WT_V1: Wild Type, Vivarium Control
  6. RR10_FCS_VIV_KO_V13: p21-null,	Vivarium Control

In [1]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [2]:
# %cd gdrive/MyDrive/NASA
%cd gdrive/MyDrive/Project/data_science/nasa-space-app-2023/
!ls

/content/gdrive/MyDrive/Project/data_science/nasa-space-app-2023
code  data  model


In [3]:
!pip install transformers
!pip install torch
!pip install einops
!pip install transformers[torch]
!pip install evaluate

Collecting transformers
  Downloading transformers-4.34.0-py3-none-any.whl (7.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.7/7.7 MB[0m [31m54.0 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.16.4 (from transformers)
  Downloading huggingface_hub-0.17.3-py3-none-any.whl (295 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m295.0/295.0 kB[0m [31m31.1 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<0.15,>=0.14 (from transformers)
  Downloading tokenizers-0.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m101.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m80.7 MB/s[0m eta [36m0:00:00[0m
Inst

In [4]:
import pandas as pd
import torch
from transformers import AutoTokenizer, BertModel
from transformers import BertForSequenceClassification, TrainingArguments, Trainer
from sklearn.model_selection import train_test_split
import evaluate
import numpy as np

In [5]:
dna_data = pd.read_csv('data/dna_data.csv')
print(dna_data.shape)
dna_data = dna_data.drop_duplicates(subset='DNA')
print(dna_data.shape)
dna_data.head()

(36000, 5)
(35867, 5)


Unnamed: 0,notation,DNA,mass,filename,label
0,A00654:48:HN52TDRXY:1:2101:5990:1000 2:N:0:AGT...,ATATTTATGGCTGGACTTGAACTTACTAAGTAGACCATGCTGGCCT...,FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF...,GLDS-466_metagenomics_RR10_FCS_VIV_WT_V1_R2_HR...,Vivarium Control
1,A00654:48:HN52TDRXY:1:2101:6840:1000 2:N:0:AGT...,CCTACGCTCAGCGAGGCGACTTTGAGAGATGCGCCGAAGAATCTTT...,"FFFF:F,F:FFFFFFFFFFF,FFF::FFFFFFFFF,,F,,FFF,F:...",GLDS-466_metagenomics_RR10_FCS_VIV_WT_V1_R2_HR...,Vivarium Control
2,A00654:48:HN52TDRXY:1:2101:9498:1000 2:N:0:AGT...,GCCTTGACCCATGCCTGATAAGGGAGGGCCCGGTCGACGCCCAGGA...,":FFFFFFFFFFFFFFFF:FFFFFFFFFF,FF:FFFFF:FFFFFFFF...",GLDS-466_metagenomics_RR10_FCS_VIV_WT_V1_R2_HR...,Vivarium Control
3,A00654:48:HN52TDRXY:1:2101:15067:1000 2:N:0:AG...,GGACAGGGCCGCAGCATATTCTCATTAAACGGCTGGCCGTCATGGT...,FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF...,GLDS-466_metagenomics_RR10_FCS_VIV_WT_V1_R2_HR...,Vivarium Control
4,A00654:48:HN52TDRXY:1:2101:1298:1016 2:N:0:AGT...,TCCTCCATCCATGCTTTGAATGTCTGGAACCCTGCCTCTGTGTTAC...,FFFFFFFFFF:FFFFFFFFFFFFFFFFFFFFFFFFFFFF:FFFF:F...,GLDS-466_metagenomics_RR10_FCS_VIV_WT_V1_R2_HR...,Vivarium Control


In [6]:
# Label proportion
from collections import Counter
my_list = list(dna_data['label'])
element_count = Counter(my_list)
total_elements = len(my_list)
element_percentages = {key: (count / total_elements) * 100 for key, count in element_count.items()}
for key, percentage in element_percentages.items():
    print(f"{key}: {percentage:.2f}%")

Vivarium Control: 33.36%
Ground Control: 33.33%
Space Flight: 33.31%


In [7]:
tokenizer = AutoTokenizer.from_pretrained('zhihan1996/DNABERT-2-117M', trust_remote_code=True)
# model = BertModel.from_pretrained('zhihan1996/DNABERT-2-117M', trust_remote_code=True)
model = BertForSequenceClassification.from_pretrained('zhihan1996/DNABERT-2-117M', num_labels=3)

Downloading (…)okenizer_config.json:   0%|          | 0.00/158 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/168k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/862 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/468M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at zhihan1996/DNABERT-2-117M and are newly initialized: ['bert.encoder.layer.8.output.dense.weight', 'bert.encoder.layer.3.output.dense.bias', 'bert.encoder.layer.11.intermediate.dense.weight', 'bert.encoder.layer.11.attention.self.key.bias', 'bert.encoder.layer.6.output.LayerNorm.weight', 'bert.encoder.layer.10.attention.self.value.weight', 'bert.encoder.layer.6.intermediate.dense.weight', 'bert.encoder.layer.8.output.dense.bias', 'bert.encoder.layer.11.attention.self.value.weight', 'bert.encoder.layer.1.intermediate.dense.weight', 'bert.encoder.layer.0.attention.self.value.weight', 'bert.encoder.layer.6.attention.self.value.weight', 'bert.encoder.layer.0.attention.self.key.bias', 'bert.encoder.layer.7.attention.self.query.weight', 'bert.encoder.layer.5.attention.self.value.bias', 'bert.encoder.layer.2.output.LayerNorm.bias', 'bert.encoder.layer.0.output.LayerNorm.weight', 'bert.encoder.layer.

In [8]:
# # Create random label for testing
# import random
# N = len(dna_data)
# random_label_list = [random.randint(0, 2) for _ in range(N)]
# dna_data['label'] = random_label_list

In [9]:
# Label encoding

from sklearn.preprocessing import LabelEncoder

label_encoder = LabelEncoder()
encoded_labels = label_encoder.fit_transform(list(dna_data['label']))
dna_data['label'] = encoded_labels

# original_labels = label_encoder.inverse_transform(encoded_labels)

In [10]:
train_data = dna_data.sample(frac=0.8, random_state=524)
test_eval_data = dna_data.drop(train_data.index)
eval_data = test_eval_data.sample(frac=0.5, random_state=524)
test_data = test_eval_data.drop(eval_data.index)

print(train_data.shape, eval_data.shape, test_data.shape)

train_dna = list(train_data['DNA'])
train_labels = list(train_data['label'])

val_dna = list(eval_data['DNA'])
val_labels = list(eval_data['label'])

test_dna = list(test_data['DNA'])
test_labels = list(test_data['label'])

train_encodings = tokenizer(train_dna, truncation=True, padding=True, return_tensors='pt')
val_encodings = tokenizer(val_dna, truncation=True, padding=True, return_tensors='pt')
test_encodings = tokenizer(test_dna, truncation=True, padding=True, return_tensors='pt')

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


(28694, 5) (3586, 5) (3587, 5)


In [11]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = CustomDataset(train_encodings, train_labels)
val_dataset = CustomDataset(val_encodings, val_labels)
test_dataset = CustomDataset(test_encodings, test_labels)

In [12]:
metric = evaluate.load('accuracy')

def compute_metrics(eval_pred):
  predictions, labels = eval_pred
  predictions = np.argmax(predictions, axis=1)
  acc = metric.compute(predictions=predictions, references=labels)['accuracy']
  return {
      'accuracy': acc
  }

training_args = TrainingArguments(
    output_dir='model/results',
    evaluation_strategy='epoch',
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=10,
    learning_rate=1e-5,
    logging_dir='model/logs'
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)

trainer.train()
results = trainer.evaluate(test_dataset)

model.save_pretrained('model/dnabert_finetuned_model')
tokenizer.save_pretrained('model/dnabert_tokenzier')

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

Epoch,Training Loss,Validation Loss,Accuracy
1,1.0496,1.002535,0.483268
2,1.0056,0.982399,0.477133
3,0.9822,0.988015,0.472114
4,0.9715,0.979707,0.50976
5,0.9568,0.968686,0.491634
6,0.9416,0.96372,0.513664
7,0.9332,0.984247,0.511712
8,0.927,0.97234,0.516174
9,0.9194,0.971876,0.505577
10,0.9053,0.967732,0.512549


('model/dnabert_tokenzier/tokenizer_config.json',
 'model/dnabert_tokenzier/special_tokens_map.json',
 'model/dnabert_tokenzier/tokenizer.json')

In [None]:
model.evaluate(test_dataset)

In [13]:
from transformers import AutoModelForSequenceClassification, BertTokenizer
tokenizer = AutoTokenizer.from_pretrained('model/dnabert_tokenzier')
model = AutoModelForSequenceClassification.from_pretrained('model/dnabert_finetuned_model')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [14]:
from scipy.special import softmax

def get_output(dna, show_prob=False):

  # Get output
  tokens = tokenizer(dna, padding=True, truncation=True, return_tensors='pt')
  outputs = model(**tokens)

  # Softmax
  pred_prob = softmax(outputs.logits.detach().cpu().numpy())
  pred_prob = np.squeeze(pred_prob)

  # Get first n
  best_pos = np.argsort(pred_prob)[-1:]
  best_pos = np.flip(best_pos)

  # Return output
  best_out = label_encoder.inverse_transform(best_pos)
  return best_out

In [17]:
get_output('TGGGGGGA')

array(['Space Flight'], dtype='<U16')