In [1]:
%load_ext rpy2.ipython

from cap2.capalyzer.pangea import PangeaFileSource
from cap2.capalyzer.pangea.utils import get_pangea_group
from cap2.capalyzer.table_builder import CAPTableBuilder

from plotnine import *
import pandas as pd
import warnings
from glob import glob
from os.path import isfile

warnings.filterwarnings('ignore')

twins_group = get_pangea_group('Mason Lab', 'NASA Twins', 'dcdanko@gmail.com', )
twins_source = PangeaFileSource(twins_group)
twins = CAPTableBuilder('twins', twins_source)
iss_group = get_pangea_group('Mason Lab', 'NASA ISS', 'dcdanko@gmail.com', )
iss_source = PangeaFileSource(iss_group)
iss = CAPTableBuilder('iss', iss_source)


In [2]:

import networkx as nx
from functools import lru_cache

class SNPBarcode:

    def __init__(self, tbl):
        self.tbl = tbl
        self.sample_name = tbl['sample_name'].unique()[0]
        self.seq = tbl['seq'].unique()[0]
        self.snps = set(zip(tbl['coord'], tbl['changed']))
        self.min_pos = 1000 * 1000 * 1000
        self.max_pos = -1
        for pos, _ in self.snps:
            if pos < self.min_pos:
                self.min_pos = pos
            if pos > self.max_pos:
                self.max_pos = pos

    def __len__(self):
        return len(self.snps)
    
    def __str__(self):
        return self.sample_name + ';' + self.seq + ';' + ','.join([f'{a}:{b}' for a, b in sorted(list(self.snps))])
    
    def __hash__(self):
        return hash(str(self))


class SNPBarcodeSet:

    def __init__(self):
        self.seq = None
        self.snps = set([])
        self.barcodes = []
        self.min_pos = 1000 * 1000 * 1000
        self.max_pos = -1

    def __len__(self):
        return len(self.snps)

    def add_barcode(self, bc):
        if self.seq:
            assert bc.seq == self.seq
        else:
            self.seq = bc.seq
        self.snps |= bc.snps
        if bc.min_pos < self.min_pos:
            self.min_pos = bc.min_pos
        if bc.max_pos > self.max_pos:
            self.max_pos = bc.max_pos
        self.barcodes.append(bc)
        return self


def barcode_barcode_similarity(b1, b2):
    if b1.seq != b2.seq:
        return 0
    if b1.min_pos > b2.max_pos:
        return 0
    if b1.max_pos < b2.min_pos:
        return 0
    jaccard = len(b1.snps & b2.snps) / min(len(b1.snps), len(b2.snps))
    return jaccard


def barcode_barcode_set_similarity(bc, bc_set):
    if bc.seq != bc_set.seq:
        return 0
    if bc.min_pos > bc_set.max_pos:
        return 0
    if bc.max_pos < bc_set.min_pos:
        return 0
    jaccard = len(bc.snps & bc_set.snps) / min(len(bc.snps), len(bc_set.snps))
    return jaccard


def build_barcode_sets(barcodes, sim_thresh=0.5):
    """Return a list of SNPBarcodeSets that fulfill sevreal reqs.

     - all barcodes are in one or more barcode sets
     - each barcode in a set has similarity of at least sim_thresh to that set
    """
    barcode_sets = []
    for i, bc in enumerate(barcodes):
        if i % (25 * 1000) == 0:
            print(i, len(barcode_sets))
        added_to_bc_set = False
        for bc_set in barcode_sets:
            s = barcode_barcode_set_similarity(bc, bc_set)
            if s < sim_thresh:
                continue
            added_to_bc_set = True
            bc_set.add_barcode(bc)
        if not added_to_bc_set:
            new_bc_set = SNPBarcodeSet().add_barcode(bc)
            barcode_sets.append(new_bc_set)
    return barcode_sets


def barcode_barcode_similarity_graph(barcodes, external_bcs=[], sim_thresh=0.5):
    """Return a Graph with edges between similar barcodes.

    - Barcodes with no similar barcdoes are not included
    - weight of each edge is the similarity
    """
    barcode_sets = build_barcode_sets(barcodes, sim_thresh=sim_thresh)
    G = nx.Graph()
    for bc_set in barcode_sets:
        for bc1 in bc_set.barcodes:
            for bc2 in bc_set.barcodes:
                if bc1 == bc2:
                    break
                s = barcode_barcode_similarity(bc1, bc2)
                if s >= sim_thresh:
                    G.add_edge(bc1, bc2, weight=s)
    
    comps = list(nx.connected_components(G))
    print(f'finished building clusters. attaching externals to {len(comps)} clusters.')
    comp_count = {}
    for i, bc1 in enumerate(external_bcs):
        if i % (100 * 1000) == 0:
            print(f'Processed {i} external bcs')
        for comp_ind, comp in enumerate(comps):
            if comp_count.get(comp_ind, 0) >= 2:
                continue
            for bc2 in comp:
                s = barcode_barcode_similarity(bc1, bc2)
                if s >= sim_thresh:
                    comp_count[comp_ind] = comp_count.get(comp_ind, 0) + 1
                    G.add_edge(bc1, bc2, weight=s)
                    break
                
    return G

def parse_snp_clusters(sample_name, filepath):
    tbl = pd.read_csv(filepath, compression='gzip', index_col=0)
    tbl = tbl.query('weight >= 10')
    tbl['sample_name'] = sample_name
    barcodes = [bc for bc in tbl.groupby('cluster').apply(SNPBarcode) if len(bc) >= 5]
    return barcodes

In [7]:
def barcode_barcode_similarity_graph(barcodes, external_bcs=[], sim_thresh=0.5):
    """Return a Graph with edges between similar barcodes.

    - Barcodes with no similar barcdoes are not included
    - weight of each edge is the similarity
    """
    seq_tbl = {}
    for bc in barcodes:
        coord_tbl = seq_tbl.get(bc.seq, {})
        bc_list = coord_tbl.get(bc.min_pos // 100, [])
        bc_list.append(bc)
        coord_tbl[bc.min_pos // 100] = bc_list
        seq_tbl[bc.seq] = coord_tbl

    G = nx.Graph()
    for seq, coord_tbl in seq_tbl.items():
        positions = sorted([pos for pos in coord_tbl.keys()])
        for p1 in positions:
            for bc1 in coord_tbl[p1]:
                for p2 in positions:
                    if (p2 * 100) > bc1.max_pos:
                        break
                    for bc2 in coord_tbl[p2]:
                        if bc1 == bc2:
                            continue
                        s = barcode_barcode_similarity(bc1, bc2)
                        if s >= sim_thresh:
                            G.add_edge(bc1, bc2, weight=s)  
        print('finished', seq)
        
    comps = list(nx.connected_components(G))
    print(f'finished building clusters. attaching externals to {len(comps)} clusters.')
    comp_count = {}
    for i, bc1 in enumerate(external_bcs):
        if i % (100 * 1000) == 0:
            print(f'Processed {i} external bcs')
        for comp_ind, comp in enumerate(comps):
            if comp_count.get(comp_ind, 0) >= 2:
                continue
            for bc2 in comp:
                s = barcode_barcode_similarity(bc1, bc2)
                if s >= sim_thresh:
                    comp_count[comp_ind] = comp_count.get(comp_ind, 0) + 1
                    G.add_edge(bc1, bc2, weight=s)
                    break

    return G

In [8]:

def get_graph(organism):
    twins_filepaths = list(twins_source('cap2::experimental::make_snp_clusters', f'snp_clusters__{organism}'))
    iss_filepaths = list(iss_source('cap2::experimental::make_snp_clusters', f'snp_clusters__{organism}'))
    all_barcodes = []
    for sample_name, filepath in twins_filepaths:
        try:
            all_barcodes += parse_snp_clusters(sample_name, filepath)
        except:
            pass
    iss_barcodes = []
    for sample_name, filepath in iss_filepaths:
        try:
            iss_barcodes += parse_snp_clusters(sample_name, filepath)
        except:
            pass
    print(len(all_barcodes), len(iss_barcodes))

    G = barcode_barcode_similarity_graph(all_barcodes, external_bcs=iss_barcodes)
    return G


def get_component_table(G):
    tbl = {}
    for i, c in enumerate(nx.connected_components(G)):
        if len(c) == 1:
            continue
        tbl[i] = {}
        for bc in c:
            tbl[i][bc.sample_name] = 1
    tbl = pd.DataFrame.from_dict(tbl, orient='columns')
    return tbl

In [None]:
from os import makedirs

def get_long_component_table(G):
    tbl = []
    for i, c in enumerate(nx.connected_components(G)):
        if len(c) == 1:
            continue
        for bc in c:
            row = {
                'sample_name': bc.sample_name,
                'cluster': i,
                'seq': bc.seq,
                'min_pos': min([a for a, b in bc.snps]),# bc.min_pos,
                'max_pos': max([a for a, b in bc.snps]),# bc.max_pos,
                'snps': ','.join([f'{a}:{b}' for a, b in bc.snps])
            }
            tbl.append(row)
    tbl = pd.DataFrame(tbl)
    return tbl

def save_table(organism, G, tbl):
    fileroot = f'graphs/v1/{organism}'
    makedirs(fileroot, exist_ok=True)
    tbl.to_csv(fileroot + f'/{organism}.wide_table.csv')
    long_tbl = get_long_component_table(G)
    long_tbl.to_csv(fileroot + f'/{organism}.long_table.csv')

In [None]:
from os.path import isfile
from os import makedirs

strain_list = [
    'Geobacillus_stearothermophilus',
    'Bifidobacterium_catenulatum',
    'Streptococcus_viridans', # 10
]

def save_table(organism, G, tbl):
    fileroot = f'graphs/v1/{organism}'
    makedirs(fileroot, exist_ok=True)
    tbl.to_csv(fileroot + f'/{organism}.wide_table.csv')
    long_tbl = get_long_component_table(G)
    long_tbl.to_csv(fileroot + f'/{organism}.long_table.csv')
    
def tables_exist(organism):
    fileroot = f'graphs/v1/{organism}'
    wide = fileroot + f'/{organism}.wide_table.csv'
    long = fileroot + f'/{organism}.long_table.csv'
    if isfile(wide) and isfile(long):
        return wide, long
    return None, None

organism_components = {}
for organism in strain_list:
    if tables_exist(organism)[0]:
        print('complete:', organism)
        continue
    print('processing:', organism)
    filepath = f'graphs/v1/{organism}.gml.gz'
    G = get_graph(organism)
    print('made graph')
    tbl = get_component_table(G)
    print('made table')
    organism_components[organism] = (G, tbl)
    save_table(organism, G, tbl)
    
len(organism_components)

processing: Geobacillus_stearothermophilus
106038 119069
finished NZ_RCTK01000106.1
finished NZ_LUCR01000030.1
finished NZ_LUCR01000168.1
finished NZ_LUCR01000012.1
finished NZ_LUCR01000015.1
finished NZ_CP034952.1
finished NZ_CP016552.1
finished NZ_LQYV01000079.1
finished NZ_LQYV01000166.1
finished NZ_JYNW01000089.1
finished NZ_JYNW01000025.1
finished NZ_LDNT01000158.1
finished NZ_LDNT01000310.1
finished NZ_LDNT01000311.1
finished NZ_LDNT01000323.1
finished NZ_LDNT01000330.1
finished NZ_LDNT01000341.1
finished NZ_LDNT01000354.1
finished NZ_LDNT01000435.1
finished NZ_LQYY01000015.1
finished NZ_LQYY01000153.1
finished NZ_LQYY01000158.1
finished NZ_LQYY01000041.1
finished NZ_LQYY01000159.1
finished NZ_LDNS01000140.1
finished NZ_LDNS01000267.1
finished NZ_LDNS01000298.1
finished NZ_LDNS01000309.1
finished NZ_LDNS01000356.1
finished NZ_LDNS01000378.1
finished NZ_LDNS01000406.1
finished NZ_LDNS01000416.1
finished NZ_LDNS01000422.1
finished NZ_LDNS01000043.1
finished NZ_LDNS01000096.1
finish

In [10]:

    
    
    
save_table('Serratia_proteamaculans', G, tbl)

In [21]:
G, tbl = organism_components['Serratia_proteamaculans']
tbl

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,10353,10354,10355,10356,10357,10358,10359,10360,10361,10362
011515_TW_B,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,,,,,,,,,,
MHV-twin-3_S41658396,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,,,,,,,,,1.0,1.0
030116_TW_B,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,,,,,,,,,,
082916_HR_S,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,,,,,,,,,,
MHV-TW2_S41668201,1.0,1.0,1.0,,1.0,1.0,1.0,,,1.0,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
IF4SW_P,,,,,,,,,,,...,,,,,,,,,,
IIF1SW,,,,,,,,,,,...,,,,,,,,,,
IF8SW_P,,,,,,,,,,,...,,,,,,,,,,
IIF7SW,,,,,,,,,,,...,,,,,,,,,,


In [24]:


def get_long_component_table(G):
    tbl = []
    for i, c in enumerate(nx.connected_components(G)):
        if len(c) == 1:
            continue
        for bc in c:
            row = {
                'sample_name': bc.sample_name,
                'cluster': i,
                'seq': bc.seq,
                'min_pos': min([a for a, b in bc.snps]),# bc.min_pos,
                'max_pos': max([a for a, b in bc.snps]),# bc.max_pos,
                'snps': ','.join([f'{a}:{b}' for a, b in bc.snps])
            }
            tbl.append(row)
    tbl = pd.DataFrame(tbl)
    return tbl

get_long_component_table(G)

Unnamed: 0,sample_name,cluster,seq,min_pos,max_pos,snps
0,011515_TW_B,0,NZ_SDFS01000013.1,496,653,"646:T,568:G,581:A,520:C,500:A,628:A,499:C,550:..."
1,MHV-twin-3_S41658396,0,NZ_SDFS01000013.1,496,653,"646:T,568:G,581:A,520:C,500:A,628:A,499:C,550:..."
2,030116_TW_B,0,NZ_SDFS01000013.1,496,653,"646:T,568:G,581:A,520:C,500:A,628:A,499:C,550:..."
3,082916_HR_S,0,NZ_SDFS01000013.1,496,653,"646:T,568:G,581:A,520:C,500:A,628:A,499:C,550:..."
4,MHV-TW2_S41668201,0,NZ_SDFS01000013.1,550,610,"550:T,568:G,610:G,581:A,583:A,552:C,588:C,562:T"
...,...,...,...,...,...,...
70348,IF2SW,10360,NZ_MQMT01000001.1,81687,81924,"81720:G,81817:G,81805:A,81790:C,81831:C,81897:..."
70349,MHV-twin-4_S41677925,10361,NZ_SWDE01000001.1,157356,157485,"157410:A,157422:C,157407:T,157401:G,157413:A,1..."
70350,MHV-twin-3_S41658396,10361,NZ_SWDE01000001.1,157356,157485,"157410:A,157422:C,157407:T,157401:G,157413:A,1..."
70351,MHV-twin-4_S41677925,10362,NZ_SWDF01000003.1,106844,106982,"106844:T,106886:T,106943:T,106949:A,106922:T,1..."
