In [1]:
from pathlib import Path
from random import randrange

import pandas as pd
from pyfaidx import Fasta
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split

In [2]:
def exists(val):
    """Report whether ``val`` is set, i.e. is anything other than ``None``.

    Falsy values such as 0, '' or [] still count as existing.
    """
    if val is None:
        return False
    return True

# Complement for upper- and lower-case A/C/G/T; characters not listed here
# (e.g. N or '.') are passed through unchanged by string_reverse_complement.
string_complement_map = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A', 'a': 't', 'c': 'g', 'g': 'c', 't': 'a'}

def string_reverse_complement(seq):
    """Return the reverse complement of ``seq`` as a new string.

    The input is reversed and every character found in
    ``string_complement_map`` is swapped for its complement (case is kept);
    any other character is copied through as-is.
    """
    return ''.join(string_complement_map.get(nucleotide, nucleotide) for nucleotide in seq[::-1])
    
class FastaInterval():
    """Fetch fixed-length sequence windows from a FASTA file.

    Sequences are read through pyfaidx; chromosome lengths are cached once at
    construction so each call only slices the requested window.
    """

    def __init__(
        self,
        *,
        fasta_file,
        return_seq_indices = False,
        shift_augs = None,
        pad_interval = False,
    ):
        """
        Args:
            fasta_file: path to the FASTA file; must exist.
            return_seq_indices: stored on the instance; not read by __call__.
            shift_augs: optional (min_shift, max_shift) pair. When set, each call
                moves the interval by a random offset drawn from that inclusive
                range, limited so the shifted interval stays on the chromosome.
            pad_interval: if True, window positions outside the chromosome are
                emitted as '.' characters.
        """
        fasta_file = Path(fasta_file)
        assert fasta_file.exists(), 'path to fasta file must exist'

        self.seqs = Fasta(str(fasta_file))
        self.return_seq_indices = return_seq_indices
        self.shift_augs = shift_augs
        self.pad_interval = pad_interval

        # Cache the length of every chromosome in the FASTA file.
        self.chr_lens = {chr_name: len(self.seqs[chr_name]) for chr_name in self.seqs.keys()}

    def __call__(self, chr_name, start, end, max_length, rc_aug=False):
        """Return ``chr_name[start:end]`` resized to a ``max_length`` window.

        max_length is passed from the dataset, not from __init__.

        Intervals shorter than ``max_length`` are widened on both sides (an odd
        extra base goes to the right); longer ones are truncated to their first
        ``max_length`` bases. Coordinates outside the chromosome are clipped;
        with ``pad_interval`` the clipped positions become '.', so a window that
        overlaps the chromosome comes back exactly ``max_length`` long.

        Args:
            chr_name: chromosome key in the FASTA file.
            start, end: interval coordinates, used as ``chromosome[start:end]``.
            max_length: target window length.
            rc_aug: if True, return the reverse complement of the window.
        """
        interval_length = end - start
        chromosome = self.seqs[chr_name]
        chromosome_length = self.chr_lens[chr_name]

        if self.shift_augs is not None:
            min_shift, max_shift = self.shift_augs
            max_shift += 1  # randrange's upper bound is exclusive

            # Limit the shift so the moved interval does not leave the chromosome.
            min_shift = max(start + min_shift, 0) - start
            max_shift = min(end + max_shift, chromosome_length) - end

            rand_shift = randrange(min_shift, max_shift)
            start += rand_shift
            end += rand_shift

        if interval_length < max_length:
            # Not enough sequence between start and end: widen both sides.
            extra_seq = max_length - interval_length
            extra_left_seq = extra_seq // 2
            start -= extra_left_seq
            end += extra_seq - extra_left_seq
        elif interval_length > max_length:
            # Truncate BEFORE clipping so padding is computed for the final
            # window: the padded result is max_length long and the slice can
            # never extend past the chromosome end.
            end = start + max_length

        left_padding = right_padding = 0
        if start < 0:
            left_padding = -start
            start = 0
        if end > chromosome_length:
            right_padding = end - chromosome_length
            end = chromosome_length

        seq = str(chromosome[start:end])

        if rc_aug:
            seq = string_reverse_complement(seq)
            # The reversed string starts with the original right flank, so the
            # padding sides swap as well.
            left_padding, right_padding = right_padding, left_padding

        if self.pad_interval:
            seq = ('.' * left_padding) + seq + ('.' * right_padding)

        return seq

## Load pilot datasets


In [None]:
!wget -N -P supplementary_data https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10028905/bin/media-1.xlsx
!wget -N -P supplementary_data https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10028905/bin/media-2.xlsx
!wget -N -P supplementary_data https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10028905/bin/media-3.xlsx
!wget -N -P supplementary_data https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10028905/bin/media-4.xlsx
!wget -N -P supplementary_data https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10028905/bin/media-5.xlsx
!wget -N -P supplementary_data https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10028905/bin/media-6.xlsx
!wget -N -P supplementary_data https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10028905/bin/media-7.xlsx
!wget -N -P supplementary_data https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10028905/bin/media-8.xlsx

In [12]:
def int_cast_index(df, keys):
    """Convert the listed columns of ``df`` to integer dtype, in place.

    Args:
        df: DataFrame that is modified in place.
        keys: iterable of column names to convert, processed in order.

    Returns:
        None -- callers rely on the mutation of ``df``.
    """
    for column in keys:
        df[column] = df[column].astype('int')

def get_sequences(df, fasta_interval):
    """Attach a 'seq' column holding the full 230nt construct for each element.

    The value is copied from the supplementary table's construct column
    (15nt 5' adaptor + 200nt element + 15nt 3' adaptor). ``fasta_interval`` is
    accepted but currently unused: a disabled variant extracted the 200nt element
    from hg38 via ``fasta_interval`` (reverse-complementing '-' strand rows) and
    wrapped it with the adaptors 'AGGACCGGATCAACT' / 'CATTGCGTGAACCGA'.

    ``df`` is modified in place and also returned.
    """
    construct_column = "230nt sequence (15nt 5' adaptor - 200nt element - 15nt 3' adaptor)"
    df['seq'] = df[construct_column]
    return df

def filter_data(df):
    """Return the complete rows of ``df`` that are not on 'random'/'alt' contigs.

    Drops every row with any missing value, then every row whose 'chr.hg38'
    name contains 'random' or 'alt' (presumably unplaced / alternate-haplotype
    contigs). The input frame is not modified; the original index is kept.
    """
    complete_rows = df.dropna(how='any')
    contig_names = complete_rows['chr.hg38'].str
    on_excluded_contig = contig_names.contains('random') | contig_names.contains('alt')
    return complete_rows[~on_excluded_contig]

# Absolute path to the hg38 FASTA on the original machine -- adjust for your setup.
HG38_FASTA = Path('/data/code/hyena-dna/data/hg38/hg38.ml.fa')

# fasta_interval is only needed to re-extract elements from hg38 (the disabled
# get_sequences path). FastaInterval asserts that the file exists, so skip it when
# the genome is unavailable instead of failing the whole cell.
fasta_interval = FastaInterval(fasta_file=HG38_FASTA) if HG38_FASTA.exists() else None

long_seq_name = "230nt sequence (15nt 5' adaptor - 200nt element - 15nt 3' adaptor)"
data_keys = ['category','chr.hg38','start.hg38','stop.hg38','str.hg38','mean',long_seq_name]


def build_cell_line_data(summary_df, sites_df, apply_filter=False):
    """Join summary and site tables on their index, keep ``data_keys`` and rename
    the 230nt construct column to 'seq'; optionally run ``filter_data``.

    To rebuild 'seq' from hg38 instead, the disabled path was: filter_data,
    int_cast_index(df, ['start.hg38', 'stop.hg38']) and get_sequences(df, fasta_interval).
    """
    merged = pd.merge(summary_df, sites_df, left_index=True, right_index=True)[data_keys]
    merged = merged.rename(columns={long_seq_name: 'seq'})
    if apply_filter:
        merged = filter_data(merged)
    return merged


# Sites come from media-3.xlsx, summary data from media-4.xlsx. The context
# managers close the workbook handles once the sheets have been read.
with pd.ExcelFile('supplementary_data/media-3.xlsx') as xls_sites:
    sites_hepg2 = pd.read_excel(xls_sites, 'HepG2 large-scale', index_col='name', skiprows=3)
    sites_k562 = pd.read_excel(xls_sites, 'K562 large-scale', index_col='name', skiprows=1)

with pd.ExcelFile('supplementary_data/media-4.xlsx') as xls_data:
    data_hepg2 = pd.read_excel(xls_data, 'HepG2_summary_data', index_col='name')
    data_k562 = pd.read_excel(xls_data, 'K562_summary_data', index_col='name')

# Build and clean final dataframes.
# NOTE(review): only K562 goes through filter_data (dropna + no random/alt contigs);
# HepG2 filtering is disabled -- confirm this asymmetry is intended.
data_hepg2 = build_cell_line_data(data_hepg2, sites_hepg2, apply_filter=False)
data_k562 = build_cell_line_data(data_k562, sites_k562, apply_filter=True)


In [4]:
# Preview the assembled HepG2 table (pandas truncates the rendered rows).
# NOTE(review): the saved output below has no 'seq' column and a lower execution
# count than the cell that builds data_hepg2, so it appears stale -- re-run to refresh.
data_hepg2

Unnamed: 0_level_0,category,chr.hg38,start.hg38,stop.hg38,str.hg38,mean
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
DNasePeakNoPromoter37009,putative enhancer,chr20,22523762,22523962,+,-0.638
DNasePeakNoPromoter48063,putative enhancer,chr5,35470968,35471168,+,-0.537
DNasePeakNoPromoter48864_Reversed:,putative enhancer,chr5,78533467,78533667,-,-1.022
DNasePeakNoPromoter47281_Reversed:,putative enhancer,chr4,167025759,167025959,-,-0.732
DNasePeakNoPromoter51319,putative enhancer,chr6,7813807,7814007,+,-0.972
...,...,...,...,...,...,...
DNasePeakNoPromoter11251,putative enhancer,chr11,88434412,88434612,+,-1.226
DNasePeakNoPromoter65003_Reversed:,putative enhancer,chrX,47338041,47338241,-,-0.561
DNasePeakNoPromoter50457,putative enhancer,chr5,155004430,155004630,+,-1.649
ENSG00000165457,promoter,chr11,72216687,72216887,+,-0.497


## Train/validation split (10-fold cross-validation)

In [4]:
# Clean data directories (adapted from https://stackoverflow.com/a/57892171; symlink-safe)
def rm_tree(pth: Path):
    """Recursively delete the directory ``pth`` and everything under it.

    Symbolic links are removed as links and never followed, so a link that
    points outside ``pth`` cannot cause files elsewhere to be deleted.
    """
    for child in pth.iterdir():
        if child.is_dir() and not child.is_symlink():
            rm_tree(child)
        else:
            # Regular files, symlinks (to files or directories) and broken links.
            child.unlink()
    pth.rmdir()

folds = 10
data = data_hepg2
dataset_name = 'hepg2'
# Fixed seed so the shuffled fold assignment is identical on every re-run.
split_seed = 42

dataset_dir = Path('data') / dataset_name
# Folds 0..folds-1 get train.csv/val.csv; '-1' gets the full dataset as train.csv.
fold_dirs = [dataset_dir / str(fold) for fold in range(folds)] + [dataset_dir / '-1']

# Clean only this dataset's directories so splits written for other datasets
# survive; parents=True also creates data/<dataset_name>/ on a fresh checkout.
for fold_dir in fold_dirs:
    if fold_dir.exists():
        rm_tree(fold_dir)
    fold_dir.mkdir(parents=True)

kf = KFold(n_splits=folds, shuffle=True, random_state=split_seed)
for k, (train_index, test_index) in enumerate(kf.split(data)):
    train, test = data.iloc[train_index], data.iloc[test_index]

    # Save data files
    train.to_csv(dataset_dir / str(k) / 'train.csv')
    test.to_csv(dataset_dir / str(k) / 'val.csv')

# Save total dataset for training
data.to_csv(dataset_dir / '-1' / 'train.csv')

In [5]:
!cp -r data/hepg2 /data/code/hyena-dna/data/mpra_agarwal_seq/