# Database connection

In [61]:
import clickhouse_connect

import os

# ClickHouse connection used by every query cell below.
# Credentials are taken from the environment instead of being stored in the
# notebook. CLICKHOUSE_PASSWORD is required (a KeyError names it if unset);
# host and user fall back to the previous defaults.
# NOTE(review): a password was previously hard-coded here; treat it as
# exposed and rotate it.
client = clickhouse_connect.get_client(
    host=os.environ.get("CLICKHOUSE_HOST", "localhost"),
    username=os.environ.get("CLICKHOUSE_USER", "cgoeldel"),
    password=os.environ["CLICKHOUSE_PASSWORD"],
)



# Build Pedigree Graph from Nodelist and Edgelist

In [66]:
import networkx as nx

# Pedigree graph: one node per "<generation>_<line>" entry of nodelist.txt,
# one directed edge per parent -> child row of edgelist.txt.
G = nx.DiGraph()

with open("nodelist.txt", encoding="utf-8") as f:
    nodes = f.read().splitlines()

with open("edgelist.txt", encoding="utf-8") as f:
    edges = f.read().splitlines()

for row in nodes:
    # Skip the CSV header and blank lines.
    if row == "filename,node,gen,meth" or row == "":
        continue

    filename, node_id, gen, meth = row.split(',')

    attrs = {
        'filename': filename,
        'node': node_id,
        'gen': int(gen),
        # Stored as a bool: True only for the literal 'Y'.
        'meth': meth == 'Y',
    }
    G.add_node(node_id, **attrs)

for edge in edges:
    # Skip the CSV header and blank lines, exactly like the node loop;
    # a blank line would otherwise fail to unpack into two fields.
    if edge == "from,to" or edge == "":
        continue

    from_, to_ = edge.split(',')
    G.add_edge(from_, to_)

def get_predecessor_node(node, graph=None):
    """Return the attributes of the nearest ancestor of *node* with ``meth`` set.

    Walks up the pedigree by following the first predecessor of each node.
    Stops at the first ancestor whose ``meth`` attribute is truthy.

    Args:
        node: Node key, e.g. ``"3_8"`` (``f"{generation}_{line}"``).
        graph: Graph to search; defaults to the module-level pedigree ``G``.

    Returns:
        The ancestor's node-attribute mapping, or ``None`` when *node* is not
        in the graph or no ancestor has ``meth`` set.

    Raises:
        ValueError: if the predecessor chain loops back on itself.
    """
    if graph is None:
        graph = G

    if node not in graph:
        return None

    seen = {node}
    current = node
    # Iterative walk: long chains of non-methylated ancestors cannot hit
    # Python's recursion limit.
    while True:
        # Only the first predecessor is followed; the pedigree is expected to
        # give each node at most one parent.
        parent = next(iter(graph.pred[current]), None)
        if parent is None:
            return None
        attrs = graph.nodes[parent]
        if attrs["meth"]:
            return attrs
        if parent in seen:
            raise ValueError(f"cycle in pedigree graph at node {parent!r}")
        seen.add(parent)
        current = parent

def get_pred_node_by_gen_and_line(gen, line):
    """Map a (generation, line) sample to its nearest ancestor with ``meth`` set.

    Returns:
        ``(generation, line)`` of that ancestor as ints, or ``(None, None)``
        when ``get_predecessor_node`` finds none.
    """
    ancestor = get_predecessor_node(f"{gen}_{line}")
    if ancestor is not None:
        gen_part, line_part = ancestor["node"].split('_')
        return int(gen_part), int(line_part)
    return None, None


In [130]:
import polars as pl
from IPython.display import display
import torch

batch_size = 32

# Number of nearest records get_neighbours fetches from each source table.
cs_neighbours = 5
hist_mod_neighbours = 20
gene_neighbours = 1
meth_neighbours = 20

mutables = 5  # not implemented yet

# Number of distinct (generation, line) samples in the methylome table.
samples = client.query(
    "select count(distinct(generation, line)) from methylome"
).result_rows[0][0]

# Integer codes substituted for histone-modification names.
hist_mod_dict = dict(
    input=1,
    H3=2,
    H3K4Me1=3,
    H3K27Me3=4,
    H2AZ=5,
    H3K56Ac=6,
    H3K4Me3=7,
)




def _surroundings_vector(location, chromosome, strand):
    """Return the nearest genes, chromatin states and histone mods as one flat tensor.

    Each feature contributes three values (label, start_diff, end_diff); the
    output order is genes, chromatin states, histone modifications.
    """
    genes = pl.DataFrame(client.query_arrow(
        f"select type, ({location} - start) as start_diff, ({location} - end) as end_diff "
        f"from genes where chromosome = {chromosome} and strand = {strand} "
        f"order by abs(start_diff) limit {gene_neighbours}"
    ))

    histone_mods = pl.DataFrame(client.query_arrow(
        f"select modification, ({location} - start) as start_diff, ({location} - end) as end_diff "
        f"from histone_mods where chromosome = {chromosome} "
        f"order by abs(start_diff) limit {hist_mod_neighbours}"
    ))
    # Replace modification names by their integer codes.
    histone_mods = histone_mods.with_columns(
        pl.col("modification").map_dict(hist_mod_dict).alias("modification")
    )

    chr_states = pl.DataFrame(client.query_arrow(
        f"select state, ({location} - start) as start_diff, ({location} - end) as end_diff "
        f"from chr_states where chromosome = {chromosome} "
        f"order by abs(start_diff) limit {cs_neighbours}"
    ))

    # Explicit target sizes (not flatten) so a short query result fails loudly.
    g = torch.tensor(genes.to_numpy()).reshape(3 * gene_neighbours)
    h = torch.tensor(histone_mods.to_numpy()).reshape(3 * hist_mod_neighbours)
    c = torch.tensor(chr_states.to_numpy()).reshape(3 * cs_neighbours)
    return torch.cat([g, c, h])


def get_neighbours(site):
    """Build the feature rows and targets for one methylation site.

    Args:
        site: mapping with at least ``location``, ``chromosome`` and
            ``strand`` (a row of the methylome table).

    Returns:
        (x, t) where, for every (generation, line) sample that has an
        ancestor with ``meth`` set (per ``get_pred_node_by_gen_and_line``):

        * ``x[i]`` = that sample's flattened neighbour records, followed by the
          ancestor's record at this site, followed by the
          sample-independent surroundings (genes, chromatin states, histone
          modifications);
        * ``t[i]`` = the sample's ``meth_lvl`` at this site.

    Raises:
        ValueError: if no sample at the site has such an ancestor.
    """
    location = site['location']
    chromosome = site['chromosome']
    strand = site['strand']

    surroundings = _surroundings_vector(location, chromosome, strand)

    # Every sample at this site and at its nearest methylome records on the
    # same chromosome/strand; the site itself has location_diff == 0.
    window = pl.DataFrame(client.query_arrow(
        f"select * except (trinucleotide_context, pedigree, id), ({location} - location) as location_diff "
        f"from methylome where chromosome = {chromosome} and strand = {strand} "
        f"order by abs(location_diff), location_diff, generation, line "
        f"limit {(meth_neighbours + 1) * samples}"
    ))
    at_site = window.filter(pl.col("location_diff") == 0)
    neighbours = window.filter(pl.col("location_diff") != 0)

    neighbour_rows, pred_rows, targets = [], [], []
    for row in at_site.iter_rows(named=True):
        gen, line = row["generation"], row["line"]
        pred_gen, pred_line = get_pred_node_by_gen_and_line(gen, line)
        if pred_gen is None:
            continue  # no ancestor with meth set (e.g. the pedigree root)

        pred = at_site.filter((pl.col("generation") == pred_gen) & (pl.col("line") == pred_line))
        pred_rows.append(torch.tensor(pred.to_numpy()[0]))

        # This sample's own neighbour records, still in distance order,
        # flattened into one row. Selecting per sample keeps row i aligned
        # with pred_rows[i] and targets[i]; transposing and reshaping the
        # whole neighbour frame would instead put one feature of all samples
        # into a single row.
        own = neighbours.filter((pl.col("generation") == gen) & (pl.col("line") == line))
        neighbour_rows.append(torch.tensor(own.to_numpy()).reshape(-1))
        targets.append(row["meth_lvl"])

    if not targets:
        raise ValueError(f"no sample at location {location} has a methylated ancestor")

    t = torch.tensor(targets)
    p = torch.stack(pred_rows)
    m = torch.stack(neighbour_rows)  # fails if samples have unequal neighbour counts
    s = surroundings.expand(p.shape[0], -1)

    x = torch.cat([m, p, s], dim=1)
    return x, t

def get_batch(mode):
    """Fetch a random batch of generation-0 sites and build their features.

    ``mode == "train"`` queries with ``sample 0.9 offset 0.0``; any other
    mode uses ``sample 0.1 offset 0.9`` (ClickHouse SAMPLE/OFFSET clauses).
    Up to ``batch_size`` sites are picked in random order, and
    ``get_neighbours`` is applied to each.

    Returns:
        (data, targets): per-site feature rows and targets, each concatenated
        along the first dimension.
    """
    if mode == "train":
        sample, offset = 0.9, 0.
    else:
        sample, offset = 0.1, 0.9

    query = f"select * except (trinucleotide_context, pedigree, id) from methylome sample {sample} offset {offset} where generation = 0 order by rand() limit {batch_size}"
    sites = pl.DataFrame(client.query_arrow(query))

    per_site = [get_neighbours(row) for row in sites.iter_rows(named=True)]
    data = torch.cat([features for features, _ in per_site])
    targets = torch.cat([labels for _, labels in per_site])
    return data, targets
    

# Smoke test: build one training batch and show the resulting tensor shapes.
data, targets = get_batch(mode="train")
print(data.shape, targets.shape, sep=" ")

torch.Size([12, 240]) torch.Size([12, 12]) torch.Size([12, 78])
torch.Size([12, 240]) torch.Size([12, 12]) torch.Size([12, 78])
torch.Size([12, 240]) torch.Size([12, 12]) torch.Size([12, 78])
torch.Size([12, 240]) torch.Size([12, 12]) torch.Size([12, 78])
torch.Size([12, 240]) torch.Size([12, 12]) torch.Size([12, 78])
torch.Size([12, 240]) torch.Size([12, 12]) torch.Size([12, 78])
torch.Size([12, 240]) torch.Size([12, 12]) torch.Size([12, 78])
torch.Size([12, 240]) torch.Size([12, 12]) torch.Size([12, 78])
torch.Size([12, 240]) torch.Size([12, 12]) torch.Size([12, 78])
torch.Size([12, 240]) torch.Size([12, 12]) torch.Size([12, 78])
torch.Size([12, 240]) torch.Size([12, 12]) torch.Size([12, 78])
torch.Size([12, 240]) torch.Size([12, 12]) torch.Size([12, 78])
torch.Size([12, 240]) torch.Size([12, 12]) torch.Size([12, 78])
torch.Size([12, 240]) torch.Size([12, 12]) torch.Size([12, 78])
torch.Size([12, 240]) torch.Size([12, 12]) torch.Size([12, 78])
torch.Size([12, 240]) torch.Size([12, 12

In [None]:
import polars as pl
import pyarrow as pa

# Record counts per (chromosome, pedigree, generation, line) in methylome_wt.
data_sizes = pl.DataFrame(
    client.query_arrow("Select count(*) as samples,chromosome, generation, line, pedigree from methylome_wt group by chromosome, pedigree, generation, line order by chromosome, generation, line")
)
num_chromosomes = data_sizes["chromosome"].n_unique()

data_sizes

samples,chromosome,generation,line,pedigree
u64,u8,u8,u8,str
10856447,1,0,0,"""MA3"""
10856447,1,1,2,"""MA3"""
10856447,1,1,8,"""MA3"""
10856447,1,2,2,"""MA3"""
10856447,1,2,8,"""MA3"""
10856447,1,4,2,"""MA3"""
10856447,1,4,8,"""MA3"""
10856447,1,5,2,"""MA3"""
10856447,1,5,8,"""MA3"""
10856447,1,8,2,"""MA3"""


In [None]:
# chromosome -> smallest record count among that chromosome's
# (pedigree, generation, line) groups.
samples_per_chromosome = dict(
    data_sizes.group_by("chromosome")
    .agg(pl.col("samples").min())
    .select("chromosome", "samples")
    .iter_rows()
)
samples_per_chromosome

{4: 6727432, 5: 9690976, 2: 7063707, 1: 10856447, 3: 8520954}

In [None]:
import torch
# Integer codes for categorical columns: each name maps to its position in the tuple.
pedigree_dict = {
    name: code
    for code, name in enumerate(("MA3", "CMT3", "SUV456", "ROS", "NRPE"))
}

trinucleotide_context = {
    ctx: code
    for code, ctx in enumerate((
        "CTC", "CGA", "CCC", "CAT", "CCG", "CTG", "CAA", "CGT", "CTT", "CGG",
        "CGN", "CAC", "CCT", "CTA", "CGS", "CAG", "CGC", "CGK", "CCA",
    ))
}

# Run tensors on the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def encode_methylome(df):
    """Integer-encode a methylome frame's categorical columns and return a tensor.

    ``pedigree`` and ``trinucleotide_context`` values are replaced by their
    codes from ``pedigree_dict`` / ``trinucleotide_context`` via polars
    ``map_dict``. The whole frame is then converted to a float32 tensor on
    ``device``.

    Args:
        df: polars DataFrame containing ``pedigree`` and
            ``trinucleotide_context`` columns.

    Returns:
        torch.Tensor built from ``df.to_numpy()`` (one row per frame row).
    """
    df = df.with_columns(
        pl.col("pedigree").map_dict(pedigree_dict).alias("pedigree"),
        # Spelled like the methylome column excluded in the queries above;
        # the former "trinucelotide_context" does not match that column name.
        pl.col("trinucleotide_context").map_dict(trinucleotide_context).alias("trinucleotide_context"),
    )

    return torch.tensor(df.to_numpy(), dtype=torch.float32, device=device)

batch_size = 6 * num_chromosomes # 6 samples per chromosome = 30 samples per batch
block_size = 8

def get_batch(context):
    """Draw random window start indices per chromosome (work in progress).

    NOTE(review): ``context`` is not used yet, the drawn indices are not
    consumed, and the function still returns ``None``; the query is a fixed
    placeholder.
    """
    per_chromosome = batch_size // num_chromosomes  # 6, matching the comment above

    # Placeholder query; it does not depend on the chromosome, so run it once.
    methylome_first_batch = client.query_arrow("select * from methylome_wt where location between 0 and 100000 limit 1000")

    # samples_per_chromosome is keyed by the chromosome numbers the database
    # returns (e.g. 1..5), not by 0-based positions, so iterate its keys;
    # range(num_chromosomes) raised KeyError for chromosome 0.
    for chromosome in sorted(samples_per_chromosome):
        # Start positions leave room for a full block_size window.
        idx = torch.randint(0, samples_per_chromosome[chromosome] - block_size, (per_chromosome,))

torch.Size([10000, 13])

In [None]:
chromosome_lengths = 

5