In [None]:
import os
from copy import deepcopy

import numpy as np
from pandas import DataFrame, read_csv
# read_csv()

# Window length in loci, and the specimen count at or below which
# neglect_nodes drops a node by default.
BLOCK_SIZE, FILTER_THRESHOLD = 20, 4

In [None]:
from collections import namedtuple, defaultdict


# Point = namedtuple('Point',['window', 'snp', 'bp'])

class Point:
    """A position holder with a ``snp`` value and an optional ``bp`` value (default 0)."""

    def __init__(self, snp, bp=0):
        self.snp = snp
        self.bp = bp

    @property
    def window(self):
        """``snp // BLOCK_SIZE``: the window containing ``snp`` when ``snp`` is a locus index.

        NOTE(review): get_unique_signatures passes the window index as ``snp``,
        in which case this value is not a window number.
        """
        quotient, _remainder = divmod(self.snp, BLOCK_SIZE)
        return quotient

# class Specimen:
#     def __init__(self, ident, sequence)
#         ident, sequence
    
class Node:
    """A haplotype block: a span of windows plus the specimens that pass through it.

    Attributes:
        ident: label only; not unique (get_unique_signatures restarts the
            numbering at 0 in every window).
        start, end: Point-like objects; only their ``snp`` field is read here.
        upstream, downstream: neighbour node -> transition count. Missing keys
            read as 0, so counts can be incremented directly.
        specimens: set of specimen indices carried by this node.
    """

    def __init__(self, ident, start, end, upstream=None, downstream=None):
        self.ident = ident
        self.start = start  # Point()
        self.end = end  # Point()
        # e.g. {nothing_node: 501, <Node>: 38, <Node>: 201, <Node>: 3}
        self.upstream = upstream if upstream else defaultdict(int)
        # e.g. {<Node>: 38, <Node>: 201, <Node>: 3}
        self.downstream = downstream if downstream else defaultdict(int)
        self.specimens = set()

    def __len__(self):
        """Number of specimens; also makes a node with no specimens falsy."""
        return len(self.specimens)

    def __repr__(self):
        return "N%s(%s, %s)" % (self.ident, self.start.snp, self.end.snp)

    def __hash__(self):
        # Nodes compare by identity (no __eq__), so hash by identity as well.
        # The old hash summed ident/start.snp/end.snp, but simple_merge and
        # split_one_group reassign ``start``/``end`` while the node is a key in
        # its neighbours' upstream/downstream dicts. The hash then changed under
        # those dicts: lookups missed and the defaultdict silently inserted a
        # second entry for the same node.
        return object.__hash__(self)

    def details(self):
        """Multi-line text summary: span, transition counts and specimens."""
        return f"""Node{self.ident}: {self.start.snp} - {self.end.snp}
        upstream: { dict(self.upstream) }
        downstream: { dict(self.downstream) }
        specimens: {self.specimens}"""
        

# Quick smoke check that Points and Nodes render and copy.
a, b = Point(0), Point(14)
str(Node(57, a, b))

# Sentinel node used as the neighbour at path ends (see populate_transitions)
# and as the target of rewired edges (see delete_node).
nothing_node = Node(-1, Point(None), Point(None))
global_nodes = {0: nothing_node}


deepcopy(a)

In [None]:
def read_data(file_path = "../test_data/KE_chromo10.txt"):
    """Read a whitespace-separated integer matrix with one locus per line.

    Returns:
        (loci, individuals): ``loci`` is a list of int tuples, one per line;
        ``individuals`` is its transpose as nested lists, so each row is one
        individual (column of the file).
    """
    with open(file_path) as handle:
        loci = [tuple(map(int, row.split())) for row in handle]
    individuals = np.transpose(np.array(loci)).tolist()
    return loci, individuals
alleles, individuals = read_data()
# Expected dimensions of the KE chromosome-10 test file:
# 32767 loci (lines), each present for every one of the 501 individuals.
assert len(alleles) == 32767
assert len(individuals[1]) == 32767
assert len(individuals) == 501

In [None]:
def first(iterable):
    """Return the first element of ``iterable``; raises StopIteration if it is empty."""
    iterator = iter(iterable)
    return next(iterator)

In [None]:
def signature(individual, start_locus):
    """Return the BLOCK_SIZE alleles starting at ``start_locus`` as a hashable tuple."""
    end_locus = start_locus + BLOCK_SIZE
    return tuple(individual[start_locus:end_locus])

def get_unique_signatures(individuals, start_locus, block_size=None):
    """Map each distinct allele pattern in one window to a fresh Node.

    Args:
        individuals: per-specimen allele sequences.
        start_locus: first locus of the window.
        block_size: window length; defaults to BLOCK_SIZE. It is now used for the
            signature slice and the end position too. Previously those always used
            BLOCK_SIZE, so passing another value gave inconsistent nodes.

    Returns:
        dict: signature tuple -> Node, numbered 0, 1, ... in order of first appearance.
    """
    if block_size is None:
        block_size = BLOCK_SIZE
    window = start_locus // block_size
    unique_blocks = {}
    for individual in individuals:
        sig = tuple(individual[start_locus : start_locus + block_size])
        if sig not in unique_blocks:
            # NOTE(review): Point's first field is named ``snp`` but receives the
            # window index here; downstream code (split_groups) relies on that.
            unique_blocks[sig] = Node(len(unique_blocks), Point(window, start_locus),
                                      Point(window, start_locus + block_size))  # TODO: -1?
    return unique_blocks
# Window 0 of the test data should contain exactly four distinct allele patterns.
# (Each value is the Node created for that pattern, numbered 0-3 by first appearance.)
unique_blocks = get_unique_signatures(individuals, 0 )
    
assert len(unique_blocks) == 4
unique_blocks

In [None]:
def get_all_signatures(alleles, individuals):
    """Build one {signature: Node} dict per full window of BLOCK_SIZE loci.

    A trailing partial window is discarded.
    """
    # The bound is ``len - BLOCK_SIZE + 1`` so the last full window is kept when
    # len(alleles) is an exact multiple of BLOCK_SIZE (the old bound dropped it).
    last_start_bound = len(alleles) - BLOCK_SIZE + 1
    return [get_unique_signatures(individuals, locus_start, BLOCK_SIZE)
            for locus_start in range(0, last_start_bound, BLOCK_SIZE)]
unique_signatures = get_all_signatures(alleles, individuals)

In [None]:
# Inspect the {signature: Node} map of window 21.
unique_signatures[21]

In [None]:
def build_individuals(individuals, unique_signatures):
    """Rewrite every specimen as its path of nodes, one node per window.

    Args:
        individuals: per-specimen allele sequences.
        unique_signatures: one {signature: Node} dict per window, in window order.

    Returns:
        A list with one entry per specimen; each entry lists that specimen's Node
        in every window.
    """
    return [
        [window[signature(specimen, w * BLOCK_SIZE)] for w, window in enumerate(unique_signatures)]
        for specimen in individuals
    ]
simplified_individuals = build_individuals(individuals, unique_signatures)
# Preview the first 100 window nodes of specimen 500.
print(simplified_individuals[500][:100])
# (number of specimens, number of windows for one specimen)
len(simplified_individuals), len(simplified_individuals[60])

# Nodes: Populate upstream and downstream

In [None]:
def populate_transitions(simplified_individuals):
    """Record specimens on nodes and count node-to-node transitions.

    For every specimen path (one node per window), add the specimen index to
    each node and increment the count towards its previous and next node.
    Path ends are linked to ``nothing_node``. Counts only ever grow, so
    calling this twice on the same nodes doubles them.
    """
    for specimen_id, path in enumerate(simplified_individuals):
        last = len(path) - 1
        for position, node in enumerate(path):
            node.specimens.add(specimen_id)
            following = path[position + 1] if position < last else nothing_node
            preceding = path[position - 1] if position > 0 else nothing_node
            node.downstream[following] += 1
            node.upstream[preceding] += 1
            

In [None]:
# Rebuild the nodes from scratch before counting: populate_transitions only
# increments counts and adds specimens, so running it twice on the same nodes
# would double every count.
unique_signatures = get_all_signatures(alleles, individuals)
simplified_individuals = build_individuals(individuals, unique_signatures)
populate_transitions(simplified_individuals)

#### TODO: turn these into tests

In [None]:
# Downstream transition counts of specimen 50's window-0 node.
simplified_individuals[50][0].downstream

In [None]:
# Same for specimen 49.
simplified_individuals[49][0].downstream

In [None]:
# Same for specimen 500 (the last one).
simplified_individuals[500][0].downstream

In [None]:
# Same for specimen 91.
simplified_individuals[91][0].downstream

In [None]:
# Downstream counts of every node in window 1000.
[x.downstream.values() for x in unique_signatures[1000].values()]

In [None]:
# Upstream counts of every node in window 1000.
[x.upstream.values() for x in unique_signatures[1000].values()]

---------------

# Simple Merge

In [None]:
# TODO: add signature directly to node

In [None]:
from blist import blist
from copy import copy, deepcopy

In [None]:
def test_no_duplicate_nodes(global_nodes):
    unique_nodes = set()
    for node in global_nodes:
        if node in unique_nodes:
            print(node)
        else:
            unique_nodes.add(node)


In [None]:
def simple_merge(global_nodes):
    """Merge chains of nodes that no specimen leaves.

    A node is absorbed into its successor when it has exactly one downstream
    neighbour and both carry the same number of specimens. The successor takes
    over the node's ``start`` and ``upstream`` dict, the parents' downstream
    counts are repointed to it, and the node is removed from ``global_nodes``.

    Args:
        global_nodes: mutable, indexable sequence of Nodes; modified in place.

    Returns:
        ``global_nodes`` itself.
    """
    n = 0
    while n < len(global_nodes):  # the list shrinks as nodes are merged away
        node = global_nodes[n]
        if len(node.downstream) == 1:
            next_node = first(node.downstream.keys())
            # Never merge into the sentinel: that would overwrite
            # nothing_node.upstream/start for every later user of it.
            if next_node is not nothing_node and len(node.specimens) == len(next_node.specimens):
                # Delete `node` and extend next_node backwards over it.
                next_node.upstream = node.upstream
                next_node.start = node.start
                # Repoint the parents' transitions from `node` to `next_node`.
                for parent in node.upstream.keys():
                    if parent is not nothing_node:
                        parent.downstream[next_node] = parent.downstream.pop(node, 0)
                # `node` sits at index n, so delete by position instead of
                # list.remove(), which rescans the list from the front.
                del global_nodes[n]
                continue  # re-examine whatever moved into slot n
        n += 1
    return global_nodes

In [None]:
def test_simple_merge():
    """Flatten every window's nodes, run simple_merge and check the node counts.

    NOTE: simple_merge mutates the shared Node objects stored in unique_signatures.
    """
    nodes = blist([node for window in unique_signatures for node in window.values()])
    assert len(nodes) == 7180
    merged = simple_merge(nodes)
    assert len(merged) == 3690
    return merged
summary1 = test_simple_merge()

#### Neglect Nodes

In [None]:
def delete_node(node, cutoff):
    """Redirect every neighbour's transition to/from ``node`` onto ``nothing_node``.

    For each real neighbour, its count for ``node`` is added to its
    ``nothing_node`` entry and the ``node`` entry is removed. ``node``'s own
    dicts are left untouched. Does nothing when ``cutoff < 1``; this is how
    ``neglect_nodes(..., 0)`` drops empty nodes without rewiring.
    """
    if cutoff < 1:
        return  # if cutoff is 0, then don't touch upstream and downstream
    for parent in node.upstream:
        # nothing_node is a sentinel; rewiring its own dicts only adds junk
        # entries such as nothing_node.downstream[nothing_node] = 0.
        if parent is nothing_node:
            continue
        parent.downstream[nothing_node] += parent.downstream.pop(node, 0)
    for descendant in node.downstream:
        if descendant is nothing_node:
            continue
        descendant.upstream[nothing_node] += descendant.upstream.pop(node, 0)


def neglect_nodes(all_nodes, deletion_cutoff=FILTER_THRESHOLD):
    """Drop every node that carries at most ``deletion_cutoff`` specimens.

    Neighbours of dropped nodes are rewired to ``nothing_node`` via delete_node
    (skipped when the cutoff is 0).

    Returns:
        A new blist of the remaining nodes, in their original order.
        ``all_nodes`` itself is not modified.
    """
    nodes_to_delete = set()
    for node in all_nodes:
        if len(node.specimens) <= deletion_cutoff:
            delete_node(node, deletion_cutoff)  # TODO: check if this will orphan
            nodes_to_delete.add(node)
    # TODO: remove orphaned haplotypes in a node that transition to and from zero within a 10 window length
    return blist([x for x in all_nodes if x not in nodes_to_delete])


def test_neglect_nodes(all_nodes):
    """Check neglect_nodes' default filtering, and that cutoff 0 leaves no empty nodes."""
    filtered = neglect_nodes(all_nodes)
    assert len(filtered) == 2854
    rerun = neglect_nodes(filtered, 0)
    assert not any(len(n.specimens) == 0 for n in rerun)
    return filtered
summary2 = test_neglect_nodes(summary1)

In [None]:
# Full details of one node that survived neglect_nodes.
summary2[8].details()

#### Split Groups

In [None]:
def split_one_group(prev_node, anchor, next_node, verbose=False):
    """Carve the specimens running ``prev_node -> anchor -> next_node`` into a new node.

    Called by split_groups when the upstream and downstream specimen sets match.
    The new node spans ``prev_node.start`` .. ``next_node.end``. On a side where
    the neighbour is ``nothing_node``, it starts/ends at ``anchor`` instead. Its
    specimens are removed from ``prev_node``, ``anchor`` and ``next_node``, and
    the neighbours' transition counts are moved over to it.

    Args:
        prev_node: upstream neighbour of ``anchor``, or ``nothing_node``.
        anchor: the node being split.
        next_node: downstream neighbour of ``anchor``, or ``nothing_node``.
        verbose: print the new node before and after its transitions are
            recomputed (this debug output used to be unconditional).

    Returns:
        The new Node (ident 777). The caller is responsible for storing it.
    """
    new = Node(777, prev_node.start, next_node.end, prev_node.upstream, next_node.downstream)
    # TODO: what about case where more content is joining downstream?
    # Comment: That is actually the case we want to split up to obtain longer blocks later
    # Extension of full windows will take care of potential loss of information later

    if prev_node is not nothing_node:
        new.specimens = anchor.specimens.intersection(prev_node.specimens)
    elif next_node is not nothing_node:
        new.specimens = anchor.specimens.intersection(next_node.specimens)
    else:
        # Both sides open: keep the anchor's specimens that neither arrive from
        # nor leave to a real neighbour. Work on a copy: aliasing
        # anchor.specimens made the final ``anchor.specimens -= new.specimens``
        # empty both sets. ``set.remove(other_set)`` returned None and raised
        # KeyError. The neighbours to subtract are the anchor's, not nothing_node's.
        new.specimens = set(anchor.specimens)
        for neighbour in list(anchor.downstream) + list(anchor.upstream):
            if neighbour is not nothing_node:
                new.specimens -= neighbour.specimens

    if prev_node is nothing_node:  # Rare case: the new node starts at the anchor
        new.start = anchor.start
        new.upstream = anchor.upstream
    if next_node is nothing_node:  # ... and/or ends at the anchor
        new.end = anchor.end
        new.downstream = anchor.downstream

    # The node whose edges `new` takes over on each side. This is the anchor
    # itself when that side's neighbour is nothing_node. Previously the
    # sentinel's count was decremented instead.
    first_replaced = anchor if prev_node is nothing_node else prev_node
    last_replaced = anchor if next_node is nothing_node else next_node

    if verbose:
        _print_split_debug(new)

    # Rebuild the transitions from specimen overlaps instead of trusting the
    # inherited dicts (duplicate keys were observed in them). Move each
    # neighbour's count from the replaced node to `new`; zero overlaps are
    # skipped so no 0-count edges are created.
    parents = list(new.upstream)
    new.upstream = defaultdict(int)
    for parent in parents:
        if parent is nothing_node:
            continue
        overlap = len(new.specimens.intersection(parent.specimens))
        if overlap == 0:
            continue
        new.upstream[parent] = overlap
        parent.downstream[new] = overlap
        parent.downstream[first_replaced] -= overlap
        if parent.downstream[first_replaced] == 0:
            del parent.downstream[first_replaced]

    children = list(new.downstream)
    new.downstream = defaultdict(int)
    for child in children:
        if child is nothing_node:
            continue
        overlap = len(new.specimens.intersection(child.specimens))
        if overlap == 0:
            continue
        new.downstream[child] = overlap
        child.upstream[new] = overlap
        child.upstream[last_replaced] -= overlap
        if child.upstream[last_replaced] == 0:
            del child.upstream[last_replaced]

    if verbose:
        _print_split_debug(new)

    # Specimens not accounted for by a real neighbour enter/leave via nothing_node.
    accounted_upstream = sum(new.upstream.values()) - new.upstream[nothing_node]
    new.upstream[nothing_node] = len(new.specimens) - accounted_upstream
    accounted_downstream = sum(new.downstream.values()) - new.downstream[nothing_node]
    new.downstream[nothing_node] = len(new.specimens) - accounted_downstream

    assert all(count >= 0 for count in new.upstream.values()), new.details()
    assert all(count >= 0 for count in new.downstream.values()), new.details()

    # Remove the new node's specimens from the nodes it was carved out of.
    if prev_node is not nothing_node:
        prev_node.specimens -= new.specimens
    if next_node is not nothing_node:
        next_node.specimens -= new.specimens
    anchor.specimens -= new.specimens
    return new


def _print_split_debug(node):
    """Print a node's details plus its upstream keys, counts and total (debug aid)."""
    print(node.details())
    print(node.upstream.keys())
    print(node.upstream.values())
    print(sum(node.upstream.values()))

# Split one known node and compare the result with HaploBlocker's output.
test_graph = summary2  # NOTE: split_one_group mutates these nodes in place
example = test_graph[7]
# Snapshot the node as text before it is split. deepcopy(example) had to copy
# the whole node graph reachable through upstream/downstream, which is
# expensive and can exceed Python's recursion limit. The copy was only
# displayed anyway, and `original[7]` on a Node raised TypeError.
original_details = example.details()
print(original_details)

def test_split_one_group(prev_node, anchor, next_node):
    """Split ``anchor`` between the given neighbours and check the new node's specimens.

    The expected IDs are the HaploBlocker ones minus 1. They are presumably
    1-based there: the list contains 501, one past the last 0-based index of
    the 501 individuals.
    """
    x = split_one_group(prev_node, anchor, next_node)
    assert x
    answer = set(int(s)-1 for s in '14  16  19  20  28  56  59  69  88 133 140 155 159 160 175 193 199 201 224 249 252 258 260 267 268 283 292 295 318 322 325 332 341 344 346 351 354 357 362 364 367 373 374 375 381 386 392 393 394 402 403 417 421 424 426 431 434 435 438 442 445 447 452 455 457 462 463 464 467 471 473 475 476 477 478 480 483 484 494 497 501'.split())
    assert x.specimens == answer, 'Specimens set does not agree with HaploBlocker' + str(x.specimens.difference(answer))
    return x

x = test_split_one_group(first(example.upstream),  example, first(example.downstream) )

print(example.details())  # the anchor after the split
print(original_details)   # the anchor before the split

In [None]:
def split_groups(all_nodes, number_of_windows=None):
    """Split off groups of specimens that pass straight through a node.

    This is called crossmerge in the R code. For every node that is neither in
    the first nor the last window, each (upstream, downstream) neighbour pair
    is examined. If the pair's specimen sets are equal and non-empty, the pair
    is carved into a new node by split_one_group. For ``nothing_node`` the set
    used is the node's specimens not found in any real neighbour on that side.
    New nodes are appended to ``all_nodes`` and are not revisited in this pass.
    Nodes left empty are then dropped with ``neglect_nodes(all_nodes, 0)``.

    Args:
        all_nodes: mutable sequence of Nodes; appended to in place.
        number_of_windows: total number of windows. Defaults to the path length
            of the module-level ``simplified_individuals``.

    Returns:
        The filtered blist of nodes.
    """
    if number_of_windows is None:
        number_of_windows = len(first(simplified_individuals))
    last_window = number_of_windows - 1
    for node in tuple(all_nodes):  # snapshot: nodes appended below are not revisited
        # Skip the chromosome's first and last window. Point.snp holds the window
        # index (see get_unique_signatures), so compare end.snp with the last
        # window index. The old ``end.window != number_of_windows`` compared
        # window // BLOCK_SIZE with the window count and was always True.
        if node.start.snp == 0 or node.end.snp == last_window:
            continue
        if not node.specimens:
            continue
        # Match up upstream and downstream neighbours by specimen identity.
        for up in tuple(node.upstream):
            for down in tuple(node.downstream):
                if up is nothing_node:
                    set1 = _open_specimens(node, node.upstream)
                else:
                    set1 = up.specimens
                if down is nothing_node:
                    set2 = _open_specimens(node, node.downstream)
                else:
                    set2 = down.specimens
                if set1 == set2 and len(set1) > 0:
                    all_nodes.append(split_one_group(up, node, down))

    return neglect_nodes(all_nodes, 0)


def _open_specimens(node, neighbours):
    """Specimens of ``node`` absent from every neighbour, i.e. those linked to nothing_node.

    Builds a new set; the old code called ``set.intersection`` and discarded the
    result (a no-op), and aliasing ``node.specimens`` would have mutated the node.
    """
    remaining = set(node.specimens)
    for neighbour in neighbours:
        remaining -= neighbour.specimens
    return remaining
    

def test_split_groups(all_nodes):
    """Smoke test: split_groups must return a truthy (non-empty) node list."""
    result = split_groups(all_nodes)
    assert result
    return result
summary3 = test_split_groups(summary2)
    

In [None]:
len(summary3)  # Node count after split_groups; HaploBlocker reports 1887. The result depends on node order.

### Everything below currently fails: some operations on empty sets crash.
#### TODO: clean up `upstream`/`downstream` so that 0-transition entries are not displayed (a direct `del` did not work).

### Simple-merge / Cross-merge runs

In [None]:
# The cells above already ran simple_merge, neglect_nodes and split_groups on the
# Node objects stored in `unique_signatures`, which mutate those nodes in place.
# Rebuild a fresh graph so these runs do not start from stale, half-merged nodes.
unique_signatures = get_all_signatures(alleles, individuals)
simplified_individuals = build_individuals(individuals, unique_signatures)
populate_transitions(simplified_individuals)
window_cluster = blist([node for window in unique_signatures for node in window.values()])  # think about referencing and deletion
print(len(window_cluster))

for i in range(10):
    window_cluster = simple_merge(window_cluster)
    window_cluster = split_groups(window_cluster)
    print(len(window_cluster))


In [None]:
# Node count after the simple-merge / cross-merge rounds.
len(window_cluster)

### Neglect nodes runs

In [None]:
# Each round: drop small nodes, merge chains, split pass-through groups, merge again.
for i in range(10):
    for step in (neglect_nodes, simple_merge, split_groups, simple_merge):
        window_cluster = step(window_cluster)
    print(len(window_cluster))


In [None]:
# Final node count after the neglect / merge / split rounds.
len(window_cluster)