# Build Pedigree Graph from Nodelist and Edgelist

In [5]:
import networkx as nx

G = nx.DiGraph()

with open("nodelist.txt") as f:
    nodes = f.read().splitlines()

with open("edgelist.txt") as f:
    edges = f.read().splitlines()

for node in nodes: 

    if node == "filename,node,gen,meth" or node == "":
        continue
    
    filename, node, gen, meth = node.split(',')
    
    node = {
        'filename': filename,
        'node': node,
        'gen': int(gen),
        'meth': True if meth == 'Y' else False
    }

    

    G.add_node(node['node'], **node)

for edge in edges:

    if edge == "from,to":
        continue

    from_, to_ = edge.split(',')

    G.add_edge(from_, to_)

def get_predecessor_node(node):
    if not node in G:
        return None

    pred = iter(G.pred[node])

    if pred.__length_hint__() == 0:
        return None

    pred = G.nodes[next(pred)]

    if pred["meth"]:
        return pred
    else:
        return get_predecessor_node(pred['node'])

def get_pred_node_by_gen_and_line(gen, line):
    pred =  get_predecessor_node(f"{gen}_{line}")
    if pred is None:
        return None, None
    gen, line = pred["node"].split('_')
    return int(gen), int(line)


# Hyperparamters

In [6]:
import torch

batch_size = 32

cs_neighbours = 5
hist_mod_neighbours = 20
gene_neighbours = 1
meth_neighbours = 20

mutables = 5 # not implemented yet

device = "cuda" if torch.cuda.is_available() else "cpu"


In [11]:
import torch
import clickhouse_connect
import polars as pl
import time

class Sites(torch.utils.data.IterableDataset):
    def __init__(self, mode="train"):
        super(Sites).__init__()
        self.mode = mode

        self.client = clickhouse_connect.get_client(host='localhost', username='cgoeldel', password='Goe1409ldel')
        genes = self.client.query_arrow(f"select * from genes where chromosome < 5") if mode == "train" else self.client.query_arrow(f"select * from genes where chromosome = 5")
        self.genes = pl.DataFrame(genes)




        # self.samples_per_chrom = pl.DataFrame(self.client.query_arrow("select chromosome, count(*) from methylome_within_gbM where generation = 0 group by chromosome").result_rows[0][0])
        # self.samples = self.client.query("select count(distinct(generation, line)) from methylome_within_gbM").results_rows[0][0]
        # self.hist_mod_dict = {  
        #         "input" : 1,
        #         "H3" : 2,
        #         "H3K4Me1" : 3,
        #         "H3K27Me3" : 4,
        #         "H2AZ" : 5,
        #         "H3K56Ac" : 6,
        #         "H3K4Me3" : 7    
        #         }

    
    def __iter__(self):
        for gene in self.genes.iter_rows(named=True):
            tic = time.perf_counter()

            # TODO: more elegant way to iterate over all pedigree branches
            gene_length = gene['end'] - gene['start']
            sites_line_2 = self.client.query_arrow(f"select location - {gene['start']} as start_diff, {gene['end']} -location as end_diff, location * 1.0 / ({gene_length}) as percentile, meth_lvl from methylome_within_gbM where chromosome = {gene['chromosome']} and (line = 2 or line = 0)  and strand = {gene['strand']} and location between {gene['start']} and {gene['end']} order by generation, line, start_diff")
            sites_line_2 = pl.DataFrame(sites_line_2)
            # sites_line_8 = self.client.query_arrow(f"select location - {gene['start']} as start_diff, {gene['end']} -location as end_diff, location * 1.0 / ({gene_length}) as percentile, meth_lvl from methylome_within_gbM where chromosome = {gene['chromosome']} and (line = 8 or line = 0)  and strand = {gene['strand']} and location between {gene['start']} and {gene['end']} order by generation, line, start_diff")
            # sites_line_8 = pl.DataFrame(sites_line_8)
            print(sites_line_2.shape)
            display(gene)
       
            targets_line_2 = sites_line_2['meth_lvl'].to_numpy()
            # targets_line_8 = sites_line_8['meth_lvl'].to_numpy()

            targets_line_2 = torch.tensor(targets_line_2, dtype=torch.float32)
            # targets_line_8 = torch.tensor(targets_line_8, dtype=torch.float32)

            targets_line_2 = targets_line_2.reshape(gene_length, 6)
            # targets_line_8 = targets_line_8.reshape(gene_length, 6)

            targets_line_2 = targets_line_2[1:, :] # remove first row, these can not be predicted as there is no predecessor
            # targets_line_8 = targets_line_8[1:, :]
            
            sites_line_2 = torch.tensor(sites_line_2.to_numpy(), dtype=torch.float32).reshape(gene_length * 4, 6)
            # sites_line_8 = torch.tensor(sites_line_8.to_numpy(), dtype=torch.float32).reshape(gene_length, 6)

            # remove the last row, these have no predecessor, so no known target
            sites_line_2 = sites_line_2[:-1, :]
            # sites_line_8 = sites_line_8[:-1, :]

            chr_states = self.client.query_arrow(f"select state, ({gene['start']} - start ) as start_diff, ({gene['end']} - end) as end_diff from chr_states where chromosome = {gene['chromosome']} and start between {gene['start']} and {gene['end']} and end between {gene['start']} and {gene['end']}")
            chr_states = pl.DataFrame(chr_states)

            chr_states_and_gene = chr_states.vstack(gene)

            chr_states = torch.tensor(chr_states_and_gene.to_numpy(), dtype=torch.float32)

            print(sites_line_2.shape)
            # print(sites_line_8.shape)
            print(chr_states.shape)
            print(targets_line_2.shape)
            # print(targets_line_8.shape)

            # every last node -> list back to root: decoder 
            # surrounding methylome -> encoder
            # surrounding gene + chromatine state -> encoder
            x = sites_line_2.to(device)
            c = chr_states.to(device)
            t = targets_line_2.to(device)


            toc = time.perf_counter()
            print(f"Gene {gene['chromosome']}:{gene['start']}:{gene['end']} in {toc - tic:0.4f} seconds")
            yield x, c, t
        
   
   
   
   
   
    


In [None]:
def save_surroundings(chromosome, location, idx, targets, mode="train"): 
    idx = idx.tolist()
    targets = targets.tolist()
    data = [chromosome, location, idx, targets, mode]
    
    client.insert("training_data", [data], column_names=["chromosome", "location", "prompt", "targets", "mode"])

In [None]:
def get_neighbours(site, mode="train"):
    # get surrounding genes
    tic = time.perf_counter()
    genes = client.query_arrow(f"select type, ({site['location']} - start) as start_diff, ({site['location']} - end) as end_diff from genes where chromosome = {site['chromosome']} and strand = {site['strand']} order by abs(start_diff) limit {gene_neighbours}")
    g = pl.DataFrame(genes)

    toc = time.perf_counter()
    # print(f"genes in {toc - tic:0.4f} seconds")
    # display(g)
    # Get surrounding histone mods 
    tic = time.perf_counter()

    offset = 200
    sufficient = False
    while not sufficient:
        histone_mods = client.query_arrow(f"select modification, ({site['location']} - start ) as start_diff, ({site['location']} - end) as end_diff from histone_mods where chromosome = {site['chromosome']} and start > {site['location'] - offset} and end < {site['location'] + offset} order by abs(start_diff) limit {hist_mod_neighbours}")
        h = pl.DataFrame(histone_mods)
        if len(h) == hist_mod_neighbours:
            sufficient = True
        else:
            print(f"Only {len(h)} histone mods found, increasing offset to {offset * 2}")
            offset *= 2
            continue

        h = h.with_columns(pl.col("modification").map_dict(hist_mod_dict).alias("modification"))
    
    toc = time.perf_counter()
    # print(f"Histone mods in {toc - tic:0.4f} seconds")
    # display(h)
    # Get surrounding chromatine states
    tic = time.perf_counter()
    chr_states = client.query_arrow(f"select state, ({site['location']} - start ) as start_diff, ({site['location']} -end) as end_diff from chr_states where chromosome = {site['chromosome']} order by abs(start_diff) limit {cs_neighbours } ")
    c = pl.DataFrame(chr_states)

    toc = time.perf_counter()
    # print(f"Chromatine State in {toc - tic:0.4f} seconds")
    # display(c)
    # Get each site in all generations and lines 
    tic = time.perf_counter()
    offset = 200
    sufficient = False
    while not sufficient:
        all_generations_and_neighbours = client.query_arrow(f"select * except (trinucleotide_context, pedigree, id), ({site['location']} - location ) as location_diff from methylome where chromosome = {site['chromosome']} and strand = {site['strand']} and location between  {site['location'] - offset} and  {site['location'] + offset} order by abs(location_diff), location_diff, generation, line limit {(meth_neighbours +1)  * samples}")
        m = pl.DataFrame(all_generations_and_neighbours) # (meth_neighbours * samples, 12)
        if len(m) == (meth_neighbours +1)  * samples:
            sufficient = True
        else:
            print(f"Only {len(m)} neighbours found, increasing offset to {offset * 2}")
            offset *= 2
            continue
    # display(m)

    toc = time.perf_counter()
    # print(f"neighbours in {toc - tic:0.4f} seconds")                
    tic = time.perf_counter()
    site_across_generations = m.filter(pl.col("location_diff") == 0)
    m = m.filter(pl.col("location_diff") != 0)
   
    m = m.filter((pl.col("generation") != 0) & (pl.col("line") != 0))

    preds = []
    targets = []
    for site in site_across_generations.iter_rows(named=True):
        pred_gen, pred_line = get_pred_node_by_gen_and_line(site["generation"], site["line"])
        if not pred_gen is None:
            pred = site_across_generations.filter((pl.col("generation") == pred_gen) & (pl.col("line") == pred_line)) # (1, 12)
            preds.append(torch.tensor(pred.to_numpy()[0], dtype=torch.float32))
            targets.append(site["meth_lvl"])

    toc = time.perf_counter()
    # print(f"Filtering for {toc - tic:0.4f} seconds")    
    tic = time.perf_counter()
    t = torch.tensor(targets, dtype=torch.float32)
    p = torch.stack(preds) # (samples - 1, 12)
    g = torch.tensor(g.to_numpy(), dtype=torch.float32) 
    m = torch.tensor(m.to_numpy(),dtype=torch.float32).T 
    h = torch.tensor(h.to_numpy(),dtype=torch.float32)
    c = torch.tensor(c.to_numpy(),dtype=torch.float32)

    m = m.reshape((samples-1), 12 * meth_neighbours) # (samples -1, concatenated neighbours)
    
    # These are time-invariant
    g = g.reshape(3 * gene_neighbours)
    h = h.reshape(3 * hist_mod_neighbours)
    c = c.reshape(3* cs_neighbours)

    # surroundings 
    s = torch.cat([g, c, h])
    s = s.expand((samples -1), -1)

    x = torch.cat([m, p, s], dim=1)

    toc = time.perf_counter()
    # print(f"Reshaping in {toc - tic:0.4f} seconds")
    save_surroundings(site["chromosome"], site["location"], x, t, mode)

    return x, t


In [None]:
import time
import numpy as np



def create_all_samples():

    start = 0
    end = 1e6
    total = 557173708

    while start < total:
        next_samples = client.query_arrow(f"select * except (trinucleotide_context, pedigree, id) from methylome where generation = 0 and context = 'CG' and location between {start} and {end} order by chromosome, location, strand")
        df = pl.DataFrame(next_samples)

        print(f"Creating {df.height} samples in range {start} to {end}")
        i = 0
        for site in df.iter_rows(named=True):
            tic = time.perf_counter()
            mode = np.random.choice(["train", "test", "validation"], p=[0.8, 0.1, 0.1])
            print(f"Creating sample {i} in mode {mode} at location {site['location']} on chromosome {site['chromosome']}")
            get_neighbours(site, mode)
            toc = time.perf_counter()
            print(f"Created sample {i} in {toc - tic:0.4f} seconds")
            i += 1
        
        start = end
        end = end + 1e6


tic = time.perf_counter()
# create_all_samples()
toc = time.perf_counter()
print(f"Creating all samples in {toc - tic:0.4f} seconds")


Creating 254372 samples in range 0 to 1000000.0
Creating sample 0 in mode test at location 109 on chromosome 1
Created sample 0 in 0.0418 seconds
Creating sample 1 in mode train at location 110 on chromosome 1
Created sample 1 in 0.0423 seconds
Creating sample 2 in mode train at location 115 on chromosome 1
Created sample 2 in 0.0364 seconds
Creating sample 3 in mode train at location 116 on chromosome 1
Created sample 3 in 0.0385 seconds
Creating sample 4 in mode train at location 161 on chromosome 1
Created sample 4 in 0.0372 seconds
Creating sample 5 in mode train at location 162 on chromosome 1
Created sample 5 in 0.0389 seconds
Creating sample 6 in mode train at location 310 on chromosome 1
Created sample 6 in 0.0380 seconds
Creating sample 7 in mode train at location 311 on chromosome 1
Created sample 7 in 0.0389 seconds
Creating sample 8 in mode train at location 500 on chromosome 1
Created sample 8 in 0.0375 seconds
Creating sample 9 in mode train at location 501 on chromosome 

KeyboardInterrupt: 

In [None]:

def get_batch(mode):
    offset = 0. if mode == "train" else 0.9
    sample = 0.9 if mode == "train" else 0.1
    tic = time.perf_counter()
    sites =  client.query_arrow(f"select * except (trinucleotide_context, pedigree, id) from methylome sample {sample} offset {offset} where generation = 0 order by rand() limit {batch_size}")
    toc = time.perf_counter()

    #print(f"fetch batch in {toc - tic:0.4f} seconds")
    df = pl.DataFrame(sites)
    # display(df)

    xs = []
    targets = []
    for site in df.iter_rows(named=True):
        x, target = get_neighbours(site)
        xs.append(x)
        targets.append(target)
   
    data = torch.cat(xs).to(device)
    targets = torch.cat(targets).to(device)
    return data, targets
    
tic = time.perf_counter()
data, targets = get_batch("train")
toc = time.perf_counter()
print(f"Everything in {toc - tic:0.4f} seconds")
print(data, targets)

# Defining the model

In [12]:
from torch import nn
from torch.nn import functional as F

class MethylationMaster(nn.Module):
    def __init__(self):
        super().__init__()

        self.transformer = nn.Transformer(d_model=330, nhead=33, dtype=torch.float32)

    def forward(self, x, c, targets = None):
        logits =  self.transformer(c, x, tgt_is_causal=True)
        print(logits.shape) # Should be (gene_length * 6)

        if targets is None:
            return logits, None

        loss = F.mse_loss(logits, targets)

        return logits, loss
    
    @torch.no_grad()
    def generate(self, idx):
        logits = self.transformer(idx, idx) # (B * S, C)

        last = logits[-1]
        last_tar = targets[-1]

        probs = F.softmax(last, dim=-1) 
        # sample from the distribution
        guess = torch.multinomial(probs, num_samples=1) # (B, 1)


        print(f"Guess: {guess}, correct: {last_tar}")

        return guess == last_tar

# Training

In [13]:
from torch.utils.data import DataLoader

training_data = Sites(mode="train")
train_dataloader = DataLoader(training_data, batch_size=None)
conschti = MethylationMaster().to(device)

print(sum(p.numel() for p in conschti.parameters())/1e6, 'M parameters')
optimizer = torch.optim.Adam(conschti.parameters(), lr=1e-3)


24.134376 M parameters




In [15]:

trainings_steps = 1000

for i in range(trainings_steps):
    optimizer.zero_grad()

    x, c, t = next(iter(train_dataloader))


    logits, loss = conschti(x, c , t)
    loss.backward()
    optimizer.step()
    
    print(f"Step {i}, loss: {loss.item()}")


RuntimeError: shape '[2065, 6]' is invalid for input of size 0