In [1]:
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
# from torch.amp import autocast, GradScaler
!pip install transformers
from transformers import AutoConfig, AutoModel, AutoTokenizer, AdamW

MODEL_NAME = "zhihan1996/DNA_bert_6"  # Or another DNABERT variant
KMER = 6  # The '6' in DNA_bert_6

import os

# Ask CUDA to run kernels synchronously so that device-side errors are reported
# at the Python line that triggered them (debugging aid; it slows training down).
# NOTE(review): torch is already imported above. This presumably still takes
# effect because no CUDA call has been made yet at this point -- if it appears
# to be ignored, move this assignment into the very first cell, before
# `import torch`.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

print(os.environ["CUDA_LAUNCH_BLOCKING"])

import multiprocessing

# Number of logical CPUs; used further down to size the DataLoader worker pools.
num_cores = multiprocessing.cpu_count()
print("Number of CPU cores:", num_cores)


print("PyTorch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)
print("torch.cuda.is_available() =", torch.cuda.is_available())
# current_device() raises when no CUDA device can be initialised, so only query
# it when CUDA is actually usable; the notebook then still runs on CPU-only hosts.
if torch.cuda.is_available():
    print("torch.cuda.current_device() =", torch.cuda.current_device())
else:
    print("torch.cuda.current_device() = n/a (CUDA not available)")



Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
1
Number of CPU cores: 176
PyTorch version: 2.2.2+cu121
CUDA version: 12.1
torch.cuda.is_available() = True
torch.cuda.current_device() = 0


In [2]:
# Training table (e.g. columns Code, Chromosome, Telomere, Sequence); rows that
# carry no chromosome annotation are discarded right after loading.
df = pd.read_csv("CHM13_2995.csv").dropna(subset=["Chromosome"])

# 4.1: Map Chromosome strings to integer IDs
# Distinct chromosome names, in order of first appearance in the table.
unique_chrs = df["Chromosome"].unique().tolist()

def chr_sort_key(c):
    """
    Sort key for chromosome names.

    Names made only of digits sort by their integer value, "X" sorts as 23,
    "Y" as 24, and every other name is pushed to the end (99999).
    """
    if str(c).isdigit():
        return int(c)
    # Sex chromosomes get fixed ranks directly after the autosomes.
    for name, rank in (("X", 23), ("Y", 24)):
        if c == name:
            return rank
    return 99999

# Give every chromosome name a consecutive integer id equal to its position in
# the sorted list, and keep the reverse lookup (id -> name) for decoding
# predictions later on.
sorted_chrs = sorted(unique_chrs, key=chr_sort_key)
chr2id = dict(zip(sorted_chrs, range(len(sorted_chrs))))
inv_chr2id = dict(zip(chr2id.values(), chr2id.keys()))

df["chr_label"] = df["Chromosome"].map(chr2id)

# 4.2: Telomere labels are presumably 0/1/2 already
df["tel_label"] = df["Telomere"]

# Sanity checks on the encoded labels and on the sequence column.
unique_labels = df["chr_label"].unique()
missing_or_blank_seq = df["Sequence"].isna() | df["Sequence"].eq("")

print("Unique chromosome labels:", unique_labels)
print("Max label:", unique_labels.max(), "Min label:", unique_labels.min())
print("tel unique:", df["tel_label"].unique())
print(df["Chromosome"].isna().sum())  # rows still lacking a chromosome
print(len(chr2id))
print("Empty sequences:", missing_or_blank_seq.sum())




Unique chromosome labels: [ 4 19  0  2  5 22  7  6 15 14  9 21 13  8  3 12 18 11 23 20  1 17 10 16]
Max label: 23 Min label: 0
tel unique: [1 2 0]
0
24
Empty sequences: 0


In [3]:
# Validation table (e.g. columns Code, Chromosome, Telomere, Sequence); rows
# that carry no chromosome annotation are discarded.
df_v = pd.read_csv("CN1_2995.csv")
df_v = df_v.dropna(subset=["Chromosome"])

# 4.1: Map Chromosome strings to integer IDs
# The validation labels must use the SAME chromosome -> id mapping as the
# training data, so `chr2id` / `inv_chr2id` (and `chr_sort_key`) built in the
# training cell are reused here instead of being rebuilt from this table.
# Rebuilding them would overwrite the training mapping and could silently
# assign different ids to the same chromosome.
unseen_chrs = sorted(
    set(df_v["Chromosome"].unique().tolist()) - set(chr2id),
    key=chr_sort_key,
)
if unseen_chrs:
    raise ValueError(
        f"Validation data contains chromosomes absent from the training mapping: {unseen_chrs}"
    )

df_v["chr_label"] = df_v["Chromosome"].map(chr2id)

# 4.2: Telomere labels are presumably 0/1/2 already
df_v["tel_label"] = df_v["Telomere"]

# Sanity checks on the encoded validation labels and on the sequence column.
unique_labels = df_v["chr_label"].unique()
print("Unique chromosome labels:", unique_labels)
print("Max label:", unique_labels.max(), "Min label:", unique_labels.min())
print("tel unique:", df_v["tel_label"].unique())
print(df_v["Chromosome"].isnull().sum())  # checks the validation frame, not the training one
print(len(chr2id))
print("Empty sequences:", (df_v["Sequence"].isnull() | (df_v["Sequence"] == "")).sum())

Unique chromosome labels: [ 3  8  7  4 12 23 21 19  6 18  0 10  9  1 11  2 16 20  5 17 22 13 15 14]
Max label: 23 Min label: 0
tel unique: [1 2 0]
0
24
Empty sequences: 0


In [4]:
# The creators of DNABERT released separate checkpoints for k=3, k=4, k=5 and k=6 (and sometimes k=7)
# because different k-mer sizes can capture different types of patterns in the DNA.

def seq_to_kmers(seq, k=6):
    """
    Turn a DNA string into its overlapping k-mers (stride 1), upper-cased and
    joined by single spaces -- the text format fed to the DNABERT tokenizer.

    Returns an empty string when `seq` is shorter than `k`.
    """
    upper = seq.upper()
    n_windows = len(upper) - k + 1
    return " ".join(upper[pos:pos + k] for pos in range(n_windows))

######################################
#  B. Chunking + K-mer Helpers
######################################
def chunk_sequence(seq, chunk_size=512, overlap=50):
    """
    Return a list of overlapping substrings from `seq`.
    Example: first chunk covers [0:512], next chunk covers [462:974], etc.

    Args:
        seq: the sequence to split (anything supporting len() and slicing).
        chunk_size: maximum length of each chunk; must be positive.
        overlap: number of trailing positions of one chunk repeated at the
            start of the next; must satisfy 0 <= overlap < chunk_size.

    Returns:
        The chunks in order. The last chunk may be shorter than `chunk_size`;
        an empty `seq` yields an empty list.

    Raises:
        ValueError: if `chunk_size` is not positive or `overlap` is outside
            [0, chunk_size). Without this check an overlap >= chunk_size makes
            the window stop advancing and the loop never terminates.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, got overlap={overlap}, chunk_size={chunk_size}"
        )

    step = chunk_size - overlap  # how far the window advances each time (always >= 1)
    length = len(seq)
    chunks = []
    start = 0
    while start < length:
        end = start + chunk_size
        chunks.append(seq[start:end])
        # This chunk already reaches the end of the sequence: stop here so we do
        # not emit a trailing chunk made up only of already-covered overlap.
        if end >= length:
            break
        start += step
    return chunks




In [5]:
class DNABertChunkedDataset(Dataset):
    """
    Chunk-level dataset for DNABERT.

    Every row of the input frame is cut into overlapping chunks; each chunk
    that is at least `k` characters long becomes one dataset item carrying the
    row's chromosome and telomere labels. Tokenization is deferred to
    `__getitem__`.
    """
    def __init__(self, df, tokenizer, k=6, chunk_size=512, overlap=50, max_length=512):
        # One dict per chunk: {"kmers_str": ..., "chr_label": ..., "tel_label": ...}
        self.samples = []
        self.tokenizer = tokenizer
        self.k = k
        self.max_length = max_length

        for row_pos in range(len(df)):
            record = df.iloc[row_pos]
            full_seq = record["Sequence"]
            row_labels = {
                "chr_label": record["chr_label"],
                "tel_label": record["tel_label"],
            }

            pieces = chunk_sequence(full_seq, chunk_size=chunk_size, overlap=overlap)
            # Chunks shorter than k cannot form a single k-mer and are dropped.
            self.samples.extend(
                {"kmers_str": seq_to_kmers(piece, k=self.k), **row_labels}
                for piece in pieces
                if len(piece) >= self.k
            )

    def __len__(self):
        """Number of chunk-level samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize sample `idx` and return its input tensors together with both labels."""
        sample = self.samples[idx]
        text = sample["kmers_str"]
        chr_target = sample["chr_label"]
        tel_target = sample["tel_label"]

        # Pad / truncate to exactly max_length tokens; the tokenizer returns a
        # leading batch axis of size 1, which is removed below.
        enc = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        batch_item = {
            'input_ids': enc['input_ids'].squeeze(0),
            'attention_mask': enc['attention_mask'].squeeze(0),
        }
        batch_item['labels_chr'] = torch.tensor(chr_target, dtype=torch.long)
        batch_item['labels_tel'] = torch.tensor(tel_target, dtype=torch.long)
        return batch_item


In [6]:
######################################
#  D. Multi-Task Model
######################################
class MultiTaskDNABERT(nn.Module):
    """
    Pretrained transformer encoder with two linear classification heads that
    share one pooled representation: one head predicts the chromosome label,
    the other the telomere label.
    """
    def __init__(self, model_name, num_chr_labels, num_tel_labels, dropout=0.1):
        """
        Args:
            model_name: name or path passed to AutoConfig / AutoModel.from_pretrained.
            num_chr_labels: number of output classes of the chromosome head.
            num_tel_labels: number of output classes of the telomere head.
            dropout: dropout probability applied to the pooled representation.
        """
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name)
        self.bert = AutoModel.from_pretrained(model_name, config=self.config)

        hidden_size = self.config.hidden_size
        self.dropout = nn.Dropout(dropout)

        # Two classification heads
        self.classifier_chr = nn.Linear(hidden_size, num_chr_labels)
        self.classifier_tel = nn.Linear(hidden_size, num_tel_labels)

    def forward(self, input_ids, attention_mask, labels_chr=None, labels_tel=None):
        """
        Run the encoder and both heads.

        Returns:
            dict with
              'loss': sum of the two cross-entropy losses, or None unless BOTH
                      `labels_chr` and `labels_tel` are given;
              'logits_chr': output of the chromosome head;
              'logits_tel': output of the telomere head.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # Use pooler_output when the backbone provides one, else fall back to the
        # first ([CLS]) position. getattr covers backbones whose output object has
        # no `pooler_output` field at all, which would otherwise raise
        # AttributeError before the fallback could be reached.
        pooled = getattr(outputs, "pooler_output", None)
        if pooled is None:
            pooled = outputs.last_hidden_state[:, 0]

        x = self.dropout(pooled)
        logits_chr = self.classifier_chr(x)
        logits_tel = self.classifier_tel(x)

        loss = None
        if labels_chr is not None and labels_tel is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss_chr = loss_fct(logits_chr, labels_chr)
            loss_tel = loss_fct(logits_tel, labels_tel)
            # Both tasks are weighted equally.
            loss = loss_chr + loss_tel

        return {
            'loss': loss,
            'logits_chr': logits_chr,
            'logits_tel': logits_tel
        }

In [8]:
######################################
#  E. Prepare Data
######################################
MODEL_NAME = "zhihan1996/DNA_bert_6"  # Or a local checkpoint
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

num_chr_labels = len(chr2id)  # e.g. ~24 if "1..22,X,Y"
num_tel_labels = 3            # 0,1,2

# Seed for the validation split so the same subset is used on every run.
SPLIT_SEED = 42

# Chunk-level training set built from the training table.
train_ds = DNABertChunkedDataset(
    df=df,
    tokenizer=tokenizer,
    k=KMER,           # must match the k-mer size of the checkpoint
    chunk_size=512,   # chunk char length
    overlap=50,       # overlap in chars
    max_length=512    # DNABERT max token input
)

# Chunk-level set built from the validation table, same chunking parameters.
valid_ds = DNABertChunkedDataset(
    df=df_v,
    tokenizer=tokenizer,
    k=KMER,
    chunk_size=512,   # chunk char length
    overlap=50,       # overlap in chars
    max_length=512    # DNABERT max token input
)

# Only a random half of the validation chunks is used during training; the
# explicit generator makes that half reproducible across kernel restarts.
val_size = int(0.5 * len(valid_ds))
leftover = len(valid_ds) - val_size
half_val_ds, leftover_ds = random_split(
    valid_ds,
    [val_size, leftover],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=num_cores-1)
val_loader = DataLoader(half_val_ds, batch_size=16, shuffle=False, num_workers=num_cores-1)

In [9]:
######################################
#  F. Initialize Model & Train
######################################
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MultiTaskDNABERT(
    model_name=MODEL_NAME,
    num_chr_labels=num_chr_labels,
    num_tel_labels=num_tel_labels
)
model.to(device)

# AdamW with a cosine learning-rate schedule that restarts at the beginning of
# every epoch (warm restarts; there is no warmup phase).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

batches_per_epoch = len(train_loader)  # e.g. 100

scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=batches_per_epoch,  # after T_0 steps, it restarts
    T_mult=1,               # no extension of the cycle length each time
    eta_min=0               # the minimum LR at the cosine nadir
)


best_acc_chrom = 0
EPOCHS = 10
try:
    for epoch in range(EPOCHS):
        # ---- training pass ----
        model.train()
        total_loss = 0
        for batch_idx, batch in enumerate(train_loader):
            if batch_idx % 200 == 0:
                print(f"Batch {batch_idx} out of {len(train_loader)}")
            optimizer.zero_grad()

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels_chr = batch['labels_chr'].to(device)
            labels_tel = batch['labels_tel'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels_chr=labels_chr,
                labels_tel=labels_tel
            )
            loss = outputs['loss']
            loss.backward()
            optimizer.step()

            # The scheduler is stepped once per batch, matching T_0 above.
            scheduler.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1} - Train Loss: {avg_loss:.4f}")

        # ---- validation pass ----
        model.eval()
        val_loss = 0
        correct_chr, correct_tel = 0, 0
        total_samples = 0
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels_chr = batch['labels_chr'].to(device)
                labels_tel = batch['labels_tel'].to(device)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels_chr=labels_chr,
                    labels_tel=labels_tel
                )
                loss = outputs['loss']
                val_loss += loss.item()

                logits_chr = outputs['logits_chr']
                logits_tel = outputs['logits_tel']
                preds_chr = torch.argmax(logits_chr, dim=1)
                preds_tel = torch.argmax(logits_tel, dim=1)

                correct_chr += (preds_chr == labels_chr).sum().item()
                correct_tel += (preds_tel == labels_tel).sum().item()
                total_samples += len(labels_chr)

        # max(..., 1) keeps an empty validation loader from raising
        # ZeroDivisionError; the metrics are then simply reported as 0.
        avg_val_loss = val_loss / max(len(val_loader), 1)
        acc_chr = correct_chr / max(total_samples, 1)
        acc_tel = correct_tel / max(total_samples, 1)
        print(f"Val Loss: {avg_val_loss:.4f}, Chr Acc: {acc_chr:.3f}, Tel Acc: {acc_tel:.3f}")
        # Keep the checkpoint with the best chromosome accuracy seen so far.
        if acc_chr > best_acc_chrom:
            best_acc_chrom = acc_chr
            torch.save(model.state_dict(), "model_best.pt")
except KeyboardInterrupt:
    # Manual stop: the in-memory model keeps its current weights, and
    # "model_best.pt" (if already written) holds the best epoch so far.
    print('interrupted')
else:
    # Only reported when all epochs actually ran to the end.
    print("Training complete.")

2025-03-13 00:54:53.192564: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-13 00:54:53.220707: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Batch 0 out of 2958
Batch 200 out of 2958
Batch 400 out of 2958
Batch 600 out of 2958
Batch 800 out of 2958
Batch 1000 out of 2958
Batch 1200 out of 2958
Batch 1400 out of 2958
Batch 1600 out of 2958
Batch 1800 out of 2958
Batch 2000 out of 2958
Batch 2200 out of 2958
Batch 2400 out of 2958
Batch 2600 out of 2958
Batch 2800 out of 2958
Epoch 1 - Train Loss: 2.9420
Val Loss: 3.4307, Chr Acc: 0.181, Tel Acc: 0.795
Batch 0 out of 2958
Batch 200 out of 2958
Batch 400 out of 2958
Batch 600 out of 2958
Batch 800 out of 2958
Batch 1000 out of 2958
Batch 1200 out of 2958
Batch 1400 out of 2958
Batch 1600 out of 2958
Batch 1800 out of 2958
Batch 2000 out of 2958
Batch 2200 out of 2958
Batch 2400 out of 2958
Batch 2600 out of 2958
Batch 2800 out of 2958
Epoch 2 - Train Loss: 2.0188
Val Loss: 3.3874, Chr Acc: 0.290, Tel Acc: 0.799
Batch 0 out of 2958
Batch 200 out of 2958
Batch 400 out of 2958
Batch 600 out of 2958
Batch 800 out of 2958
Batch 1000 out of 2958
Batch 1200 out of 2958
Batch 1400 out

In [None]:
######################################
#  G. (Optional) Chunk-Level -> Sequence-Level Aggregation
######################################
# If each sequence was chunked into N pieces, you might want an overall label for the entire sequence.
# One approach: gather chunk predictions for the same "Code" or row, then do majority vote or average logits.
# This is a quick example of how you might do it for "val_ds".

def predict_sequence(model, tokenizer, seq, k=6, chunk_size=512, overlap=50, max_length=512):
    """
    Predict (chromosome, telomere) for one full sequence by averaging the
    model's logits over all of its chunks.

    Args:
        model: module returning a dict with 'logits_chr' and 'logits_tel'.
        tokenizer: tokenizer applied to the space-joined k-mer text of each chunk.
        seq: the full DNA sequence.
        k: k-mer size; must match the one used for training.
        chunk_size: chunk length in characters.
        overlap: overlap between consecutive chunks in characters.
        max_length: token length each chunk is padded / truncated to.

    Returns:
        (chr_str, pred_tel_id): the chromosome name looked up in the global
        `inv_chr2id`, and the integer telomere class id.

    Raises:
        ValueError: if `seq` produces no chunk of at least `k` characters, so
            there is nothing to average.
    """
    seq_chunks = chunk_sequence(seq, chunk_size=chunk_size, overlap=overlap)
    model.eval()

    # Run on whatever device the model lives on instead of relying on a
    # notebook-global `device` variable.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    sum_logits_chr = None
    sum_logits_tel = None
    total_chunks = 0

    with torch.no_grad():
        for ch in seq_chunks:
            # A chunk shorter than k cannot form a k-mer.
            if len(ch) < k:
                continue
            kmers_str = seq_to_kmers(ch, k=k)
            encoding = tokenizer(
                kmers_str,
                return_tensors='pt',
                truncation=True,
                padding='max_length',
                max_length=max_length
            )
            input_ids = encoding['input_ids'].to(run_device)
            attention_mask = encoding['attention_mask'].to(run_device)
            outputs = model(input_ids, attention_mask)
            logits_chr = outputs['logits_chr']
            logits_tel = outputs['logits_tel']

            if sum_logits_chr is None:
                sum_logits_chr = logits_chr
                sum_logits_tel = logits_tel
            else:
                # Out-of-place add so the model's output tensors are not mutated.
                sum_logits_chr = sum_logits_chr + logits_chr
                sum_logits_tel = sum_logits_tel + logits_tel
            total_chunks += 1

    if total_chunks == 0:
        raise ValueError(
            f"Cannot predict: sequence of length {len(seq)} yields no chunk with at least k={k} characters"
        )

    avg_chr = sum_logits_chr / total_chunks
    avg_tel = sum_logits_tel / total_chunks
    pred_chr_id = torch.argmax(avg_chr, dim=1).item()
    chr_str = inv_chr2id[pred_chr_id]
    pred_tel_id = torch.argmax(avg_tel, dim=1).item()
    return chr_str, pred_tel_id

In [None]:
# Sequence-level accuracy over the validation table. A row counts as correct
# only when BOTH the predicted chromosome and the predicted telomere class
# match the ground truth.
# (The former `print('ground truth', chrom, telo)` ran before `chrom` / `telo`
# were assigned and raised NameError on a fresh kernel, so it was removed.)
# NOTE(review): this iterates over all of df_v, which includes the chunks used
# for checkpoint selection during training -- confirm that this is intended.
acc = []
for index, row in df_v.iterrows():
    sequence = row["Sequence"]
    chrom = row["Chromosome"]
    telo  = row["Telomere"]
    chr_str, tel_label = predict_sequence(
        model,
        tokenizer,
        seq=sequence,         # full validation sequence
        k=6,                  # must match your training k-mer
        chunk_size=512,       # same chunk size as training
        overlap=50            # same overlap as training
    )
    acc.append(1 if (chr_str == chrom and tel_label == telo) else 0)


# Guard against an empty validation table instead of dividing by zero.
print('acc', sum(acc) / len(acc) if acc else float('nan'))