In [1]:
import h5py
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split

from utils import parse_gtf

# Configuration
WINDOW_SIZE = 101
CHROMOSOME_PREFIX = ""  # "chr" if needed to match FASTA
RANDOM_SEED = 42
LEARNING_RATE = 0.0005
np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)

# 1. Preprocess All Chromosomes First =========================================
chromosome_names = [str(i) for i in range(1, 23)] + ["X", "Y"]

# Global gene split (all chromosomes)
gene_groups = parse_gtf("data/Homo_sapiens.GRCh38.113.chr.gtf.gz") 
gene_ids = list(gene_groups.groups.keys())
train_genes, test_genes = train_test_split(gene_ids, test_size=0.2, random_state=RANDOM_SEED)

2025-05-03 10:32:39.740274: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-03 10:32:39.741169: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-05-03 10:32:39.744034: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-05-03 10:32:39.806679: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746261159.866213 3386696 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746261159.88

In [2]:
def build_model(input_shape):
    inputs = tf.keras.Input(shape=input_shape)

    x = layers.Conv1D(64, kernel_size=10, activation='relu')(inputs)
    x = layers.MaxPooling1D(pool_size=2)(x)
    x = layers.Conv1D(128, kernel_size=5, activation='relu')(x)

    # Attention layer
    x = layers.LayerNormalization()(x)
    x = layers.MultiHeadAttention(num_heads=4, key_dim=32)(x, x)
    x = layers.GlobalAveragePooling1D()(x)

    x = layers.Dense(64, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(1, activation='sigmoid')(x)

    # lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    #     initial_learning_rate=0.0003,
    #     decay_steps=10000
    # )
    
    model = models.Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()]
    )
    return model

In [10]:
import pandas as pd
from Bio import SeqIO
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers, models


class ChromosomeDataGenerator(tf.keras.utils.Sequence):
    def __init__(self, chr_files, batch_size=256):
        self.chr_files = chr_files
        self.batch_size = batch_size
        self.on_epoch_end()
        
    def __len__(self):
        return int(np.ceil(sum(
            h5py.File(f, 'r')['X_train'].shape[0] for f in self.chr_files
        ) / self.batch_size))
    
    def __getitem__(self, index):
        # Get random chromosome
        chr_file = np.random.choice(self.chr_files)
        
        with h5py.File(chr_file, 'r') as f:
            X = f['X_train']
            y = f['y_train']
            
            # Get random batch from chromosome
            start = np.random.randint(0, len(X) - self.batch_size)
            return X[start:start+self.batch_size], y[start:start+self.batch_size]
    
    def on_epoch_end(self):
        np.random.shuffle(self.chr_files)

class ValidationDataGenerator(tf.keras.utils.Sequence):
    def __init__(self, chr_files, batch_size=512):
        self.chr_files = chr_files
        self.batch_size = batch_size
        
    def __len__(self):
        total_samples = sum(
            h5py.File(f, 'r')['X_test'].shape[0] for f in self.chr_files
        )
        return int(np.ceil(total_samples / self.batch_size))
    
    def __getitem__(self, index):
        # Load one chromosome at a time
        chr_file = self.chr_files[index % len(self.chr_files)]
        
        with h5py.File(chr_file, 'r') as f:
            X = f['X_test'][:]
            y = f['y_test'][:]
            
            # Return full chromosome test set (batched automatically)
            return X, y

# Initialize
val_generator = ValidationDataGenerator(
    [f'data/bin/train_test_data{n}.h5' for n in chromosome_names]
)

# Initialize
all_chr_files = [f'data/bin/train_test_data{n}.h5' for n in chromosome_names]
model = build_model((WINDOW_SIZE, 4))

# Training with chromosome rotation
history = model.fit(
    ChromosomeDataGenerator(all_chr_files),
    epochs=100,
    steps_per_epoch=1000,  # Adjust based on total data size
    validation_data=val_generator,
    validation_steps=20, #len(val_generator),
    callbacks=[
        tf.keras.callbacks.EarlyStopping(
            monitor='val_accuracy',
            patience=5,
            restore_best_weights=True
        ),
        tf.keras.callbacks.ModelCheckpoint(
            'best_model.h5',
            save_best_only=True
        )
    ]
)

Epoch 1/100
[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 100ms/step - accuracy: 0.6014 - loss: 0.6524 - precision_4: 0.6053 - recall_4: 0.5710



[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m131s[0m 130ms/step - accuracy: 0.6014 - loss: 0.6524 - precision_4: 0.6054 - recall_4: 0.5711 - val_accuracy: 0.7051 - val_loss: 0.5725 - val_precision_4: 0.7320 - val_recall_4: 0.6470
Epoch 2/100
[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 102ms/step - accuracy: 0.7086 - loss: 0.5727 - precision_4: 0.7196 - recall_4: 0.6818



[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m130s[0m 130ms/step - accuracy: 0.7086 - loss: 0.5727 - precision_4: 0.7196 - recall_4: 0.6818 - val_accuracy: 0.7184 - val_loss: 0.5556 - val_precision_4: 0.7572 - val_recall_4: 0.6429
Epoch 3/100
[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 104ms/step - accuracy: 0.7202 - loss: 0.5597 - precision_4: 0.7295 - recall_4: 0.6993



[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m133s[0m 133ms/step - accuracy: 0.7202 - loss: 0.5597 - precision_4: 0.7295 - recall_4: 0.6993 - val_accuracy: 0.7274 - val_loss: 0.5458 - val_precision_4: 0.7482 - val_recall_4: 0.6855
Epoch 4/100
[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 99ms/step - accuracy: 0.7311 - loss: 0.5469 - precision_4: 0.7416 - recall_4: 0.7074



[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m124s[0m 124ms/step - accuracy: 0.7311 - loss: 0.5469 - precision_4: 0.7416 - recall_4: 0.7074 - val_accuracy: 0.7269 - val_loss: 0.5456 - val_precision_4: 0.7610 - val_recall_4: 0.6615
Epoch 5/100
[1m 351/1000[0m [32m━━━━━━━[0m[37m━━━━━━━━━━━━━[0m [1m1:11[0m 111ms/step - accuracy: 0.7302 - loss: 0.5428 - precision_4: 0.7439 - recall_4: 0.6962



[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m73s[0m 73ms/step - accuracy: 0.7300 - loss: 0.5444 - precision_4: 0.7426 - recall_4: 0.6964 - val_accuracy: 0.7237 - val_loss: 0.5491 - val_precision_4: 0.7699 - val_recall_4: 0.6381
Epoch 6/100
[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m150s[0m 150ms/step - accuracy: 0.7343 - loss: 0.5408 - precision_4: 0.7461 - recall_4: 0.7043 - val_accuracy: 0.7226 - val_loss: 0.5522 - val_precision_4: 0.6911 - val_recall_4: 0.8053
Epoch 7/100
[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 120ms/step - accuracy: 0.7366 - loss: 0.5380 - precision_4: 0.7482 - recall_4: 0.7177



[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m151s[0m 151ms/step - accuracy: 0.7366 - loss: 0.5380 - precision_4: 0.7482 - recall_4: 0.7177 - val_accuracy: 0.7316 - val_loss: 0.5409 - val_precision_4: 0.7752 - val_recall_4: 0.6525
Epoch 8/100
[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 122ms/step - accuracy: 0.7408 - loss: 0.5336 - precision_4: 0.7538 - recall_4: 0.7169



[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m156s[0m 156ms/step - accuracy: 0.7408 - loss: 0.5336 - precision_4: 0.7538 - recall_4: 0.7169 - val_accuracy: 0.7338 - val_loss: 0.5374 - val_precision_4: 0.7469 - val_recall_4: 0.7075
Epoch 9/100
[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m149s[0m 149ms/step - accuracy: 0.7436 - loss: 0.5308 - precision_4: 0.7552 - recall_4: 0.7183 - val_accuracy: 0.7333 - val_loss: 0.5384 - val_precision_4: 0.7556 - val_recall_4: 0.6897
Epoch 10/100
[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m71s[0m 71ms/step - accuracy: 0.7487 - loss: 0.5238 - precision_4: 0.7628 - recall_4: 0.7237 - val_accuracy: 0.7315 - val_loss: 0.5451 - val_precision_4: 0.7143 - val_recall_4: 0.7716
Epoch 11/100
[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 112ms/step - accuracy: 0.7459 - loss: 0.5267 - precision_4: 0.7604 - recall_4: 0.7190



[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m143s[0m 143ms/step - accuracy: 0.7459 - loss: 0.5267 - precision_4: 0.7604 - recall_4: 0.7190 - val_accuracy: 0.7346 - val_loss: 0.5372 - val_precision_4: 0.7292 - val_recall_4: 0.7464
Epoch 12/100
[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 113ms/step - accuracy: 0.7426 - loss: 0.5288 - precision_4: 0.7561 - recall_4: 0.7174



[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m144s[0m 144ms/step - accuracy: 0.7426 - loss: 0.5288 - precision_4: 0.7561 - recall_4: 0.7174 - val_accuracy: 0.7359 - val_loss: 0.5343 - val_precision_4: 0.7622 - val_recall_4: 0.6858
Epoch 13/100
[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m146s[0m 147ms/step - accuracy: 0.7461 - loss: 0.5258 - precision_4: 0.7572 - recall_4: 0.7210 - val_accuracy: 0.7364 - val_loss: 0.5350 - val_precision_4: 0.7573 - val_recall_4: 0.6958
Epoch 14/100
[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 114ms/step - accuracy: 0.7469 - loss: 0.5228 - precision_4: 0.7612 - recall_4: 0.7182



[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m145s[0m 145ms/step - accuracy: 0.7469 - loss: 0.5228 - precision_4: 0.7612 - recall_4: 0.7182 - val_accuracy: 0.7371 - val_loss: 0.5334 - val_precision_4: 0.7347 - val_recall_4: 0.7421
Epoch 15/100
[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m71s[0m 71ms/step - accuracy: 0.7527 - loss: 0.5174 - precision_4: 0.7641 - recall_4: 0.7336 - val_accuracy: 0.7373 - val_loss: 0.5337 - val_precision_4: 0.7633 - val_recall_4: 0.6878
Epoch 16/100
[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m119s[0m 119ms/step - accuracy: 0.7517 - loss: 0.5184 - precision_4: 0.7650 - recall_4: 0.7231 - val_accuracy: 0.7260 - val_loss: 0.5518 - val_precision_4: 0.6973 - val_recall_4: 0.7989
Epoch 17/100
[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m113s[0m 114ms/step - accuracy: 0.7559 - loss: 0.5130 - precision_4: 0.7687 - recall_4: 0.7321 - val_accuracy: 0.7368 - val_loss: 0.5339 - val_precision_4: 0.775



[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m57s[0m 57ms/step - accuracy: 0.7545 - loss: 0.5133 - precision_4: 0.7665 - recall_4: 0.7318 - val_accuracy: 0.7390 - val_loss: 0.5311 - val_precision_4: 0.7569 - val_recall_4: 0.7042
Epoch 21/100
[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m113s[0m 113ms/step - accuracy: 0.7532 - loss: 0.5139 - precision_4: 0.7673 - recall_4: 0.7294 - val_accuracy: 0.7387 - val_loss: 0.5314 - val_precision_4: 0.7363 - val_recall_4: 0.7437
Epoch 22/100
[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m113s[0m 113ms/step - accuracy: 0.7569 - loss: 0.5096 - precision_4: 0.7704 - recall_4: 0.7330 - val_accuracy: 0.7362 - val_loss: 0.5359 - val_precision_4: 0.7594 - val_recall_4: 0.6915
Epoch 23/100
[1m1000/1000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m113s[0m 113ms/step - accuracy: 0.7614 - loss: 0.5040 - precision_4: 0.7761 - recall_4: 0.7363 - val_accuracy: 0.7351 - val_loss: 0.5370 - val_precision_4: 0.729

In [8]:
len(val_generator),

(541,)

using the batches from all chromosomes and bigger batch size improved the accuracy 0.745->0.7605

In [11]:
print("loading gtf...") 
gene_groups = parse_gtf("data/Homo_sapiens.GRCh38.113.chr.gtf.gz") # must be outside loop

# Split genes
print("splitting_genes...")
gene_ids = list(gene_groups.groups.keys())
train_genes, test_genes = train_test_split(
    gene_ids, test_size=0.2, random_state=RANDOM_SEED
)

loading gtf...


  gtf = pd.read_csv(


splitting_genes...


In [12]:
with h5py.File(f'genes.h5', 'w') as f:
    # Save the tuple as a dataset
    # f.create_dataset('gene_groups', data=gene_groups)
    f.create_dataset('train_genes', data=train_genes)
    f.create_dataset('test_genes', data=test_genes)
    gene_groups.obj.to_hdf('genes.h5', key='gene_df', mode='a')

HDF5ExtError: HDF5 error back trace

  File "H5F.c", line 836, in H5Fopen
    unable to synchronously open file
  File "H5F.c", line 796, in H5F__open_api_common
    unable to open file
  File "H5VLcallback.c", line 3863, in H5VL_file_open
    open failed
  File "H5VLcallback.c", line 3675, in H5VL__file_open
    open failed
  File "H5VLnative_file.c", line 128, in H5VL__native_file_open
    unable to open file
  File "H5Fint.c", line 1910, in H5F_open
    unable to lock the file
  File "H5FD.c", line 2412, in H5FD_lock
    driver lock request failed
  File "H5FDsec2.c", line 941, in H5FD__sec2_lock
    unable to lock file, errno = 11, error message = 'Resource temporarily unavailable'

End of HDF5 error back trace

Unable to open/create file 'genes.h5'