In [8]:
%load_ext rpy2.ipython

from cap2.capalyzer.pangea import PangeaFileSource
from cap2.capalyzer.pangea.utils import get_pangea_group
from cap2.capalyzer.table_builder import CAPTableBuilder

from plotnine import *
import pandas as pd
import warnings
from glob import glob
from os.path import isfile

# Silence every warning category for the rest of the kernel session.
# NOTE(review): this also hides pandas warnings (e.g. chained-assignment or
# deprecation warnings) raised by later cells.
warnings.filterwarnings('ignore')

# One file source and one table builder per sample group.
# NOTE(review): presumably get_pangea_group looks the group up on Pangea using
# (organization, group name, account email) -- confirm against cap2. The account
# email is hard-coded here; consider reading it from the environment instead.
twins_group = get_pangea_group('Mason Lab', 'NASA Twins', 'dcdanko@gmail.com', )
twins_source = PangeaFileSource(twins_group)
twins = CAPTableBuilder('twins', twins_source)
iss_group = get_pangea_group('Mason Lab', 'NASA ISS', 'dcdanko@gmail.com', )
iss_source = PangeaFileSource(iss_group)
iss = CAPTableBuilder('iss', iss_source)


The rpy2.ipython extension is already loaded. To reload it, use:
  %reload_ext rpy2.ipython


In [18]:


import networkx as nx
from functools import lru_cache

class SNPBarcode:
    """A group of SNPs observed together on one sequence in one sample.

    Built from a table exposing ``sample_name``, ``seq``, ``coord`` and
    ``changed`` columns. The first unique ``sample_name`` and ``seq`` values
    are used, and every ``(coord, changed)`` pair becomes one SNP.

    Attributes:
        tbl: the table the barcode was built from.
        sample_name: first unique value of the ``sample_name`` column.
        seq: first unique value of the ``seq`` column.
        snps: set of ``(coord, changed)`` tuples.
        min_pos / max_pos: smallest / largest coordinate among the SNPs,
            bounded by the sentinels below.

    No ``__eq__`` is defined, so two barcodes compare equal only when they are
    the same object, even if their hashes match.
    """

    # Sentinels used when there are no SNPs. They also bound the result:
    # min_pos is never above _MIN_POS_SENTINEL and max_pos never below
    # _MAX_POS_SENTINEL.
    _MIN_POS_SENTINEL = 1000 * 1000 * 1000
    _MAX_POS_SENTINEL = -1

    # Class-level default so instances created without __init__ still hash.
    _hash = None

    def __init__(self, tbl):
        self.tbl = tbl
        self.sample_name = tbl['sample_name'].unique()[0]
        self.seq = tbl['seq'].unique()[0]
        self.snps = set(zip(tbl['coord'], tbl['changed']))
        positions = [pos for pos, _ in self.snps]
        # The sentinel goes first so that it wins ties and so that values which
        # never compare smaller/larger (e.g. NaN) cannot be selected.
        self.min_pos = min([self._MIN_POS_SENTINEL, *positions])
        self.max_pos = max([self._MAX_POS_SENTINEL, *positions])
        self._hash = None

    def __len__(self):
        """Return the number of distinct SNPs in the barcode."""
        return len(self.snps)

    def __str__(self):
        """Return ``sample;seq;coord:changed,...`` with SNPs in sorted order."""
        snp_part = ','.join(f'{pos}:{base}' for pos, base in sorted(self.snps))
        return self.sample_name + ';' + self.seq + ';' + snp_part

    def __hash__(self):
        """Return the hash of the string form, computed once and then cached.

        Building the string sorts every SNP, and hashing happens on every
        dict/set/graph lookup, so the value is memoised on first use.
        """
        if self._hash is None:
            self._hash = hash(str(self))
        return self._hash


class SNPBarcodeSet:
    """Accumulates barcodes from a single sequence into one pooled SNP set.

    ``snps`` is the union of the SNPs of every barcode added so far, and
    ``min_pos`` / ``max_pos`` track the widest coordinate range seen.
    """

    def __init__(self):
        self.seq = None
        self.snps = set()
        self.barcodes = []
        self.min_pos = 1000 * 1000 * 1000
        self.max_pos = -1

    def __len__(self):
        """Return the number of distinct SNPs pooled in the set."""
        return len(self.snps)

    def add_barcode(self, bc):
        """Merge ``bc`` into the set and return the set itself for chaining.

        The first barcode added fixes ``seq``; later barcodes are asserted to
        share it.
        """
        if not self.seq:
            self.seq = bc.seq
        else:
            assert bc.seq == self.seq
        self.snps |= bc.snps
        self.min_pos = min(self.min_pos, bc.min_pos)
        self.max_pos = max(self.max_pos, bc.max_pos)
        self.barcodes.append(bc)
        return self


def barcode_barcode_similarity(b1, b2):
    """Return the SNP overlap between two barcodes.

    The result is 0 when the barcodes lie on different sequences or when their
    coordinate ranges do not touch. Otherwise it is the number of shared SNPs
    divided by the size of the smaller SNP set (an overlap coefficient rather
    than a true Jaccard index).
    """
    if b1.seq != b2.seq or b1.min_pos > b2.max_pos or b1.max_pos < b2.min_pos:
        return 0
    shared = b1.snps & b2.snps
    smaller = min(len(b1.snps), len(b2.snps))
    return len(shared) / smaller


def barcode_barcode_set_similarity(bc, bc_set):
    """Return the SNP overlap between a barcode and a barcode set.

    The result is 0 when the two are on different sequences or their coordinate
    ranges are disjoint. Otherwise it is the count of SNPs they share divided
    by the size of whichever SNP collection is smaller (an overlap coefficient,
    not a true Jaccard index).
    """
    cannot_overlap = (
        bc.seq != bc_set.seq
        or bc.min_pos > bc_set.max_pos
        or bc.max_pos < bc_set.min_pos
    )
    if cannot_overlap:
        return 0
    n_shared = len(bc.snps & bc_set.snps)
    return n_shared / min(len(bc.snps), len(bc_set.snps))


def build_barcode_sets(barcodes, sim_thresh=0.5):
    """Greedily group barcodes into SNPBarcodeSets.

    Barcodes are processed in order. Each one is added to every existing set
    whose similarity to it (measured at the time it is processed) is at least
    ``sim_thresh``; a barcode that matches no set starts a new one. As a
    result:

     - every barcode ends up in one or more barcode sets
     - a barcode may belong to several sets

    Progress (index and current number of sets) is printed every 25,000
    barcodes.
    """
    barcode_sets = []
    for index, barcode in enumerate(barcodes):
        if index % 25_000 == 0:
            print(index, len(barcode_sets))
        n_joined = 0
        for candidate in barcode_sets:
            if barcode_barcode_set_similarity(barcode, candidate) >= sim_thresh:
                candidate.add_barcode(barcode)
                n_joined += 1
        if n_joined == 0:
            # Nothing was similar enough: seed a fresh set with this barcode.
            barcode_sets.append(SNPBarcodeSet().add_barcode(barcode))
    return barcode_sets


def barcode_barcode_similarity_graph(barcodes, external_bcs=(), sim_thresh=0.5):
    """Return a Graph with edges between similar barcodes.

    - Barcodes with no similar barcodes are not included
    - weight of each edge is the similarity

    Args:
        barcodes: barcodes that are clustered with ``build_barcode_sets`` and
            then linked pairwise within each set.
        external_bcs: additional barcodes that are attached to the connected
            components built from ``barcodes``. They are only iterated, never
            modified; the default is an empty tuple so no mutable default is
            shared between calls.
        sim_thresh: minimum similarity for an edge (also passed to
            ``build_barcode_sets``).

    Returns:
        A ``networkx.Graph`` whose nodes are barcodes.
    """
    barcode_sets = build_barcode_sets(barcodes, sim_thresh=sim_thresh)
    G = nx.Graph()
    for bc_set in barcode_sets:
        for bc1 in bc_set.barcodes:
            for bc2 in bc_set.barcodes:
                # Only visit the members that precede bc1 in the set, so each
                # unordered pair is scored once. For SNPBarcode (which defines
                # no __eq__) this comparison is an identity check.
                if bc1 == bc2:
                    break
                s = barcode_barcode_similarity(bc1, bc2)
                if s >= sim_thresh:
                    G.add_edge(bc1, bc2, weight=s)
    
    # Components are snapshotted before any external barcode is attached.
    comps = list(nx.connected_components(G))
    print(f'finished building clusters. attaching externals to {len(comps)} clusters.')
    # Number of external edges attached to each component so far.
    comp_count = {}
    for i, bc1 in enumerate(external_bcs):
        if i % (50 * 1000) == 0:
            print(f'Processed {i} external bcs')
        for comp_ind, comp in enumerate(comps):
            # Each component receives at most two external edges in total.
            if comp_count.get(comp_ind, 0) >= 2:
                continue
            for bc2 in comp:
                s = barcode_barcode_similarity(bc1, bc2)
                if s >= sim_thresh:
                    comp_count[comp_ind] = comp_count.get(comp_ind, 0) + 1
                    G.add_edge(bc1, bc2, weight=s)
                    # One edge per (external barcode, component) is enough.
                    break
                
    return G

def parse_snp_clusters(sample_name, filepath, min_weight=10, min_snps=5):
    """Read one sample's SNP-cluster table and return its barcodes.

    Args:
        sample_name: name stored in a ``sample_name`` column on every row.
        filepath: path to a gzip-compressed CSV whose first column is the
            index and which has ``weight`` and ``cluster`` columns, plus the
            columns ``SNPBarcode`` reads.
        min_weight: rows with ``weight`` below this value are dropped.
        min_snps: barcodes with fewer distinct SNPs than this are dropped.

    Returns:
        A list with one ``SNPBarcode`` per remaining ``cluster`` value; the
        list is empty when no rows survive the weight filter.
    """
    tbl = pd.read_csv(filepath, compression='gzip', index_col=0)
    # .assign returns a new frame, so the column is never written onto a
    # filtered view of the original table.
    tbl = tbl.query('weight >= @min_weight').assign(sample_name=sample_name)
    # Iterate the groups directly: this yields one sub-table per cluster (and
    # nothing for an empty table) without depending on how groupby.apply
    # assembles the objects returned by the constructor.
    barcodes = (SNPBarcode(cluster_tbl) for _, cluster_tbl in tbl.groupby('cluster'))
    return [bc for bc in barcodes if len(bc) >= min_snps]

In [10]:

def _collect_barcodes(sample_files, label):
    """Parse every (sample_name, filepath) pair, skipping files that fail."""
    barcodes = []
    failures = []
    for sample_name, filepath in sample_files:
        try:
            barcodes += parse_snp_clusters(sample_name, filepath)
        except Exception as exc:
            # Best effort: one unreadable file must not abort the whole run,
            # but KeyboardInterrupt/SystemExit are no longer swallowed.
            failures.append((sample_name, exc))
    if failures:
        first_name, first_exc = failures[0]
        print(
            f'{label}: skipped {len(failures)} file(s) that could not be parsed; '
            f'first was {first_name}: {first_exc!r}'
        )
    return barcodes


def get_graph(organism):
    """Build the barcode similarity graph for one organism.

    Looks up the ``snp_clusters__<organism>`` results of
    ``cap2::experimental::make_snp_clusters`` in both ``twins_source`` and
    ``iss_source``, parses them into barcodes, and builds the graph from the
    twins barcodes with the ISS barcodes attached as external barcodes.

    Files that fail to parse are skipped; a one-line summary is printed for
    each source that had failures. The number of twins and ISS barcodes is
    printed before the graph is built.
    """
    field_name = f'snp_clusters__{organism}'
    twins_filepaths = list(twins_source('cap2::experimental::make_snp_clusters', field_name))
    iss_filepaths = list(iss_source('cap2::experimental::make_snp_clusters', field_name))
    all_barcodes = _collect_barcodes(twins_filepaths, 'twins')
    iss_barcodes = _collect_barcodes(iss_filepaths, 'iss')
    print(len(all_barcodes), len(iss_barcodes))

    G = barcode_barcode_similarity_graph(all_barcodes, external_bcs=iss_barcodes)
    return G


def get_component_table(G):
    """Tabulate which samples appear in each connected component of ``G``.

    Returns a DataFrame with one column per connected component (labelled by
    its enumeration index) and one row per sample name. A cell holds 1 when
    some barcode of that sample belongs to the component. Components made of
    exactly one node are left out.
    """
    membership = {
        comp_index: {bc.sample_name: 1 for bc in component}
        for comp_index, component in enumerate(nx.connected_components(G))
        if len(component) != 1
    }
    return pd.DataFrame.from_dict(membership, orient='columns')

In [None]:
# Organisms to analyse. Only the first three entries are used (see the slice
# at the end of the list).
strain_list = [
    'Bifidobacterium_pseudocatenulatum',
    'Fusobacterium_necrophorum',
    # A comma was missing after the next entry, which silently concatenated it
    # with 'Brevibacterium_siliguriense' into a single bogus organism name.
    'Serratia_proteamaculans',
    'Brevibacterium_siliguriense',
    'Gordonibacter_urolithinfaciens',
    'Bacillus_albus',
    'Gluconobacter_albidus',
    'Geobacillus_stearothermophilus',
    'Bifidobacterium_catenulatum',
    'Streptococcus_viridans', # 10
    'Bacteroides_caccae',
    'Vibrio_alginolyticus',
    'Staphylococcus_sciuri',
    'Pectobacterium_parmentieri',
    'Cronobacter_condimenti',
    'Campylobacter_lari',
    'Atlantibacter_hermannii',
    'Bacillus_tequilensis',
    'Achromobacter_ruhlandii',
    # NOTE(review): duplicate of the third entry -- confirm whether a different
    # organism was intended here.
    'Serratia_proteamaculans', #20
    'Leptotrichia_hongkongensis',
    'Exiguobacterium_antarcticum',
    'Brenneria_rubrifaciens',
    'Staphylococcus_simiae',
    'Anoxybacillus_amylolyticus',
    'Kosakonia_sacchari',
    'Yersinia_canariae',
    'Providencia_heimbachae',
    'Spirochaeta_perfilievii', # 29
][:3]

# Build the similarity graph and the component-by-sample table for every
# selected organism; results are keyed by organism name as (graph, table).
organism_components = {}
for organism in strain_list:
    print(organism)
    G = get_graph(organism)
    print('made graph')
    tbl = get_component_table(G)
    print('made table')
    organism_components[organism] = (G, tbl)
    
# Last expression of the cell: displays the whole mapping as the cell output.
organism_components

Bifidobacterium_pseudocatenulatum
