In [119]:
import pandas as pd
import numpy as np
from Bio import SeqIO
import argparse
import torch
import pickle
import os

In [6]:
# List the working directory; the relative data paths used in this notebook resolve from here.
os.listdir()

['.ipynb_checkpoints', 'encode_snps.ipynb']

In [7]:
# Input files, given relative to the notebook's working directory.
ref_file = "../data/reference/chr22.fa"  # FASTA parsed by read_fa
methy_file = "../data/methylation/253.bed"  # tab-separated table read by read_methy_data
snps_file = "../data/variant-calls/253.pass.vcf"  # VCF read by read_and_filter_vcf

In [8]:
def read_and_filter_vcf(ifile, chrom):
    '''
    Reads a single-sample VCF and extracts the SNP-relevant columns for one chromosome.

    Arguments
    ---------
    ifile: str or file-like
        Tab-separated VCF; every line starting with '#' is skipped.
    chrom: str
        Chromosome name to keep, compared as text against the first VCF column.

    Returns
    -------
    pandas.DataFrame
        Indexed by 'pos' (the VCF POS column) with columns 'chrom', 'reference',
        'alternate' and 'variant_call' (the genotype string, e.g. '0/1').
        Only rows whose reference allele is a single A, C, G or T are kept.
    '''
    df = pd.read_csv(ifile,
                     sep='\t',
                     comment='#',
                     header=None,
                     usecols=[0, 1, 3, 4, 9],
                     names=['chrom', 'pos', 'reference', 'alternate', 'extra_info'],
                     # Read chromosome names as text so that purely numeric names
                     # (e.g. "22") still compare equal to the requested string.
                     dtype={'chrom': str}
                     )

    # Filter to shrink the number of comparisons that need to be made.
    # .copy() makes the column assignment below act on an independent frame
    # instead of a slice of the full table.
    df = df.loc[df['chrom'] == chrom].copy()

    # Pull out the genotype call: the first ':'-separated field of the sample column
    df['variant_call'] = df['extra_info'].str.split(':').str[0]

    # Keep only rows whose reference allele is a single unambiguous base
    # (drops N references and multi-base reference alleles).
    # NOTE(review): multi-base or symbolic ALT alleles are not filtered here — confirm
    # whether insertions should be dropped as well.
    df = df[df['reference'].isin(['A', 'C', 'G', 'T'])]

    # Index by position so callers can look SNPs up by coordinate
    df = df.drop(columns=['extra_info'])
    df = df.set_index("pos")

    return df

In [9]:
# NOTE(review): SNPs are loaded for "chr20", while the reference sequence and the
# methylation table in this notebook are loaded for "chr22" — confirm which chromosome is intended.
snps_df = read_and_filter_vcf(snps_file, "chr20")

In [10]:
# Preview the first rows of the filtered SNP table.
snps_df.head()

Unnamed: 0_level_0,chrom,reference,alternate,variant_call
pos,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
61083,chr20,C,T,0/1
61406,chr20,T,C,0/1
61433,chr20,C,T,0/1
61540,chr20,A,G,0/1
61558,chr20,T,A,0/1


In [11]:

def read_fa(ifile, chrom):
    '''
    Return the FASTA record whose id is `chrom`.

    Arguments
    ---------
    ifile: str
        Path to a FASTA file.
    chrom: str
        Record id to return.

    Raises
    ------
    KeyError
        If no record with id `chrom` is present in the file.
    '''
    # Open inside a context manager so the handle is closed once parsing is done
    with open(ifile) as handle:
        records = SeqIO.to_dict(SeqIO.parse(handle, 'fasta'))
    return records[chrom]

# Load the 'chr22' record from the reference FASTA.
ref_seq = read_fa(ref_file, 'chr22')

In [12]:
# Chromosome used to filter the methylation table.
chrom = "chr22"

In [44]:
def read_methy_data(ifile, chrom, min_coverage = 10):
    '''
    Load a tab-separated methylation table and keep well-covered sites on one chromosome.

    Rows are kept when their 'chrom' equals `chrom` and their 'coverage' is at
    least `min_coverage`. The result is indexed by 'pos' (chromStart + 1) and
    holds the 'methylated', 'coverage' and 'strand' columns.
    '''
    table = pd.read_table(ifile)

    on_chrom = table['chrom'] == chrom
    well_covered = table['coverage'] >= min_coverage

    selected = (
        table.loc[on_chrom & well_covered]
        # chromStart + 1 becomes the row label
        .assign(pos=lambda frame: frame.chromStart + 1)
        .set_index('pos')
    )

    return selected[['methylated', 'coverage', 'strand']]

    

In [45]:
# Methylation sites on `chrom` with the default minimum coverage.
methy_df = read_methy_data(methy_file, chrom)

In [46]:
# Preview the first rows of the filtered methylation table.
methy_df.head()

Unnamed: 0_level_0,methylated,coverage,strand
pos,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
10572576,11,12,+/-
10572578,11,11,+/-
10572593,6,10,+/-
10598651,9,10,+/-
10615942,7,24,+/-


In [49]:
def check_methy_against_ref(ref_seq, methy_df):
    '''
    Fraction of methylation positions whose reference base is a cytosine.

    Arguments
    ---------
    ref_seq: sequence of single-letter bases
        Reference sequence; position 1 corresponds to ref_seq[0].
    methy_df: pandas.DataFrame
        Table whose index holds one-based positions on ref_seq.

    Returns
    -------
    float
        (# positions where the reference is "C") / (# positions checked).
        Positions outside the sequence are ignored. Returns NaN when no
        position can be checked.
    '''
    seq_len = len(ref_seq)

    count = 0
    num_cytosines = 0

    # Look up each distinct methylation position directly, so the cost scales with
    # the number of sites and not with the sequence length, and so the function
    # always terminates — even when there are few or no sites.
    for pos in methy_df.index.unique():
        # Plain int so the lookup works the same for str and sequence-record inputs
        pos = int(pos)

        # Positions are one-based; ignore anything that falls outside the sequence
        if pos < 1 or pos > seq_len:
            continue

        count += 1
        # Look at whether this is a C in the reference
        if ref_seq[pos - 1] == "C":
            num_cytosines += 1

    if count == 0:
        return float("nan")

    return num_cytosines / count

# Fraction of methylation sites whose reference base is "C".
check_methy_against_ref(ref_seq, methy_df)

0.9970046456250872

In [93]:


def encode_one_letter(x):
    '''One-hot encode a base as a length-4 float32 vector ordered A, C, G, T.

    Any other value yields the all-zero vector.
    '''
    encoding = np.zeros((4,), dtype = 'float32')

    # Walk the alphabet in slot order and mark the first (only) match
    for slot, base in enumerate("ACGT"):
        if x == base:
            encoding[slot] = 1
            break

    return encoding


def encode_snp(row):
    '''
    Encode one SNP row as a length-4 vector of allele fractions (A, C, G, T).

    Arguments
    ---------
    row: mapping (e.g. a pandas Series)
        Must provide 'reference' (single base), 'alternate' (one base, or several
        separated by commas) and 'variant_call' (genotype such as '0/1', '1/1', '1/2').

    Returns
    -------
    numpy array of shape (4,)
        The mean of the one-hot encodings of the called alleles, e.g. '0/1'
        gives 0.5 for the reference base and 0.5 for the alternate base.

    Raises
    ------
    ValueError
        If the genotype cannot be resolved to alleles (e.g. './.').
    '''
    ref, alt, vc = row['reference'], row['alternate'], row['variant_call']

    # Genotype index 0 is the reference; 1, 2, ... are the comma-separated alternates
    alleles = [ref] + alt.split(',')

    # Treat phased ('|') and unphased ('/') genotypes alike
    called = vc.replace('|', '/').split('/')

    output = np.zeros((4,), dtype = 'float32')
    try:
        for allele_ix in called:
            output += encode_one_letter(alleles[int(allele_ix)])
    except (ValueError, IndexError) as err:
        raise ValueError(
            f"Cannot encode variant call {vc!r} with alleles {alleles}"
        ) from err

    # Average over the called alleles: '0/1' -> half ref / half alt,
    # '1/1' -> alt, '1/2' -> half of each alternate
    return output / len(called)


def run_one_hot_encoder(sequence, snp_df, start=1):
    '''Workhorse function to encode a long string, swapping in SNPs when necessary

    Arguments
    ---------
    sequence: str
        A (long) string of A, C, T, G defining the reference sequence
    snp_df: pandas.DataFrame
        SNP table indexed by one-based position, with the columns read by encode_snp
    start: int
        One-based position, in the coordinates of snp_df's index, of sequence[0].
        The default of 1 treats the sequence as starting at position 1. Pass the
        window's first position when `sequence` is a slice of a longer sequence.

    Returns
    -------
    x: numpy array of shape (len(sequence), 4)
        One row per base, with SNP positions replaced by encode_snp's output
    n_snps: int
        Number of positions in the sequence that were replaced by a SNP encoding
    '''

    seq_len = len(sequence)

    # dtype 'float' is float64; the fractional SNP encodings are stored as-is
    x = np.zeros((seq_len, 4), dtype = 'float')

    # First encode the reference as-is ...
    for i, nt in enumerate(sequence):
        x[i, :] = encode_one_letter(nt)

    # ... then overwrite only the rows that carry a SNP. Selecting the SNPs that
    # fall inside the window once avoids an index lookup for every single base.
    snp_positions = snp_df.index
    in_window = (snp_positions >= start) & (snp_positions < start + seq_len)

    # A position listed more than once would otherwise be encoded twice;
    # only its first record is used.
    in_window = in_window & ~snp_positions.duplicated(keep='first')

    n_snps = 0
    for pos, row in snp_df[in_window].iterrows():
        # Remember that array rows start at zero,
        # whereas SNP positions start at `start`
        x[pos - start, :] = encode_snp(row)
        n_snps += 1

    return x, n_snps

In [114]:
# Spot-check the window slicing on a single site with a small flank
w = 2
pos = methy_df.index[20000]

# pos is one-based while slicing is zero-based, hence the extra -1 on the left edge
lix, rix = pos - w - 1, pos + w + 1

seq = ref_seq[lix:rix].seq
seq

Seq('TGCGCC')

In [101]:
w = 499

# Initialize lists
all_pos = []
all_counts = []
all_encodings = []
all_snps = []

# NOTE(review): snps_df must hold SNPs for the same chromosome as ref_seq and methy_df;
# verify the chromosome passed to read_and_filter_vcf before trusting the SNP encodings.

# Sort the SNP table once (first record per position) so the SNPs inside each
# window can be cut out with a binary search instead of scanning the whole table.
snps_sorted = snps_df.sort_index()
snps_sorted = snps_sorted[~snps_sorted.index.duplicated(keep='first')]
snp_positions = snps_sorted.index.to_numpy()

ref_len = len(ref_seq)

# Iterate over plain columns: avoids building an object-dtype Series per row
sites = zip(methy_df.index, methy_df['methylated'], methy_df['coverage'])

for n_done, (pos, methylated, coverage) in enumerate(sites):

    if (n_done % 100000 == 0):
        print(f"Iteration {n_done}")

    # Zero-based slice bounds; the window covers one-based positions lix + 1 .. rix
    lix = pos - w - 1
    rix = pos + w + 1

    # Skip sites whose window runs off either end of the reference: a negative
    # start would wrap around and a truncated window would break the array stacking
    if lix < 0 or rix > ref_len:
        continue

    seq = ref_seq.seq[lix:rix]

    # SNPs are indexed by chromosome position, whereas the encoder numbers the
    # bases of `seq` from 1. Select the SNPs inside the window and shift their
    # positions into window coordinates before encoding.
    lo = np.searchsorted(snp_positions, lix + 1, side='left')
    hi = np.searchsorted(snp_positions, rix, side='right')
    window_snps = snps_sorted.iloc[lo:hi].copy()
    window_snps.index = window_snps.index - lix

    one_hot_matrix, n_snps = run_one_hot_encoder(seq, window_snps)
    counts = [int(methylated), int(coverage)]

    all_pos.append(pos)
    all_counts.append(counts)
    all_encodings.append(one_hot_matrix)
    all_snps.append(n_snps)




Iteration 0


KeyboardInterrupt: 

In [118]:
output = {
    # NOTE(review): the uint8 cast truncates fractional SNP encodings (e.g. 0.5 -> 0);
    # confirm whether heterozygous sites are meant to survive in the saved features.
    'features': torch.tensor(np.array(all_encodings)).type(torch.uint8),
    # int32 rather than uint8: methylated / coverage counts above 255 would wrap around
    'counts': torch.tensor(np.array(all_counts, dtype=np.int32)),
    'position': np.array(all_pos),
    'n_snps': np.array(all_snps)
}
    

In [120]:
# Serialize the encoded dataset with pickle.
with open('../data/test.pkl', 'wb') as fp:
    pickle.dump(output, fp)

In [122]:
output2 = {
    # NOTE(review): the uint8 cast truncates fractional SNP encodings (e.g. 0.5 -> 0);
    # confirm whether heterozygous sites are meant to survive in the saved features.
    'features': torch.tensor(np.array(all_encodings)).type(torch.uint8),
    # int32 rather than uint8: methylated / coverage counts above 255 would wrap around
    'counts': torch.tensor(np.array(all_counts, dtype=np.int32)),
    # Positions are in the millions; uint8 would reduce them modulo 256
    'position': torch.tensor(np.array(all_pos, dtype=np.int64)),
    # A window can hold more than 255 SNPs, so keep a wider integer type here too
    'n_snps': torch.tensor(np.array(all_snps, dtype=np.int32))
}

In [123]:
# Write the all-tensor version of the dataset with torch.save.
torch.save(output2, '../data/test.torch.pt')

In [124]:
# Reload the file written by torch.save to check the round trip.
z = torch.load('../data/test.torch.pt')

In [125]:
# Inspect the reloaded feature tensor.
z['features']

tensor([[[0, 0, 1, 0],
         [0, 1, 0, 0],
         [0, 1, 0, 0],
         ...,
         [0, 0, 0, 1],
         [0, 0, 0, 1],
         [0, 0, 0, 1]],

        [[0, 1, 0, 0],
         [0, 0, 1, 0],
         [1, 0, 0, 0],
         ...,
         [0, 0, 0, 1],
         [1, 0, 0, 0],
         [0, 0, 0, 1]],

        [[0, 1, 0, 0],
         [0, 1, 0, 0],
         [0, 1, 0, 0],
         ...,
         [0, 0, 0, 1],
         [0, 0, 1, 0],
         [0, 0, 0, 1]],

        ...,

        [[0, 1, 0, 0],
         [0, 0, 1, 0],
         [0, 0, 1, 0],
         ...,
         [0, 0, 1, 0],
         [0, 0, 0, 1],
         [0, 1, 0, 0]],

        [[0, 1, 0, 0],
         [1, 0, 0, 0],
         [0, 1, 0, 0],
         ...,
         [0, 0, 0, 1],
         [0, 0, 1, 0],
         [0, 1, 0, 0]],

        [[1, 0, 0, 0],
         [0, 0, 1, 0],
         [0, 1, 0, 0],
         ...,
         [0, 1, 0, 0],
         [1, 0, 0, 0],
         [0, 1, 0, 0]]], dtype=torch.uint8)