In [None]:
import sys, os
sys.path.append(r'D:\ChromatinTracingPipeline\CommonTools')
import MaxViewer as mv
import IOTools as io
import AlignmentTools as at

# Raw acquisition folders holding the DNA / RNA dax files
DNA_data_folder = r'D:\DNA_FISH\E20200918_D0043\DNA'
RNA_data_folder = r'D:\DNA_FISH\E20200918_D0043\RNA'
# Folder with the analysis results of chromatin tracing
Analysis_folder = r'D:\DNA_FISH\E20200918_D0043\Analysis_CK'
# Everything this notebook writes goes to <Analysis_folder>/RNA_DNA_Analysis
os.makedirs(os.path.join(Analysis_folder, "RNA_DNA_Analysis"), exist_ok=True)

In [None]:
import glob,os,pickle
import numpy as np
from collections import OrderedDict
from collections import namedtuple

#Get all folders in RNA data, excluding ones with an underscore (e.g. H0_postBleach)
folders = [folder for folder in glob.glob(os.path.join(RNA_data_folder, 'H*'))
           if '_' not in os.path.basename(folder)]
#Add the DAPI folder from the DNA data
folders += glob.glob(os.path.join(DNA_data_folder, 'H0*'))
#Get list of all FOVs, using 2nd folder in list. Assumes the same in all folders
# (os.path.join replaces the old '\*.dax' literal, which was an invalid escape sequence
#  and only worked on Windows)
# NOTE(review): glob order is not guaranteed to be sorted; FOV index i is later matched
# against class_ids, so the order must agree with the tracing pipeline - confirm.
fovs = [os.path.basename(fl) for fl in glob.glob(os.path.join(folders[1], '*.dax'))]

# NOTE(security): pickle.load executes arbitrary code from the file; only load files
# produced by your own pipeline.
with open(os.path.join(Analysis_folder, 'Selected_Spot.pkl'), 'rb') as spot_file:
    dic = pickle.load(spot_file)
# Parallel arrays, one entry per selected DNA spot
class_ids, coords, cell_ids = np.array(dic['class_ids']), np.array(dic['coords']), np.array(dic['cell_ids'])


In [None]:
from collections import namedtuple

# (x, y, z) triple; used below for crop centres and crop half-widths
Vec3D = namedtuple('Vec3D', 'x y z')
# One DNA FISH spot: the id of the cell it belongs to plus its (x, y, z) position
DNA_Spot = namedtuple("DNA_Spot", 'cell_id x y z')

def norm(im,sz=40):
    """Flatten the background of an image before alignment.

    The image is converted to float32, divided by a (sz x sz) box-blurred copy
    of itself, and the result is capped at 2.

    Parameters
    ----------
    im : array
        Image to normalise.
    sz : int
        Side length of the blur kernel passed to ``cv2.blur``.

    Returns
    -------
    float32 array with every value above 2 replaced by 2.
    """
    import cv2
    as_float = im.astype(np.float32)
    local_mean = cv2.blur(as_float, (sz, sz))
    flattened = as_float / local_mean
    # cap bright outliers so they do not dominate the correlation
    too_bright = flattened > 2
    flattened[too_bright] = 2
    return flattened

def align(im1, im2):
    """Calculates the drift between two images.

    Both images are background-normalised with ``norm`` and then handed to
    ``at.fftalign_2d`` with a maximum displacement of 500; that function's
    result is returned unchanged.
    """
    flat1 = norm(im1)
    flat2 = norm(im2)
    return at.fftalign_2d(flat1, flat2, max_disp=500)

def prepare_data(ifov):
    """Load and drift-correct all images and DNA spots of one field of view.

    Parameters
    ----------
    ifov : int
        Index into the module-level ``fovs`` list; the same number selects the
        spots of this FOV via ``class_ids == ifov``.

    Returns
    -------
    ims : OrderedDict
        Image stacks keyed ``<folder name><channel suffix>`` (e.g. ``'H1_GFP'``,
        ``'H2_sox2'``, ``'H0_dapi'``) plus ``'DNA1'`` / ``'DNA2'`` (mean DNA images).
        The bead channels are removed once the drift has been measured.
    dna_spots : list of DNA_Spot
        Spots of this FOV with the DAPI drift added to x and y.
    (xtDNA, ytDNA) : tuple
        Drift between the RNA-round DAPI image and the DNA-round DAPI image.
    """
    #Get list of dax files for a specific FOV (ifov) across all folders
    files_ = [fld+os.sep+fovs[ifov] for fld in folders]
    #Make sure the list is sorted with RNA first, in order, then DNA
    #(probably not necessary since folders is already in the correct order?)
    isort = np.argsort([int(os.path.basename(os.path.dirname(fl))[1:].split('R')[0].split('B')[0]) +  (-10 if 'RNA' in fl else 0)
                        for fl in files_])
    files_ = np.array(files_)[isort]

    # Channel layout per hybridisation round:
    # (marker searched in the path, interleaved channels per z-step, [(channel offset, name suffix)])
    layouts = (
        ('H0', 5, [(4, '_dapi')]),
        ('H1', 4, [(0, '_mCherry'), (1, '_GFP'), (2, '_beads')]),
        ('H2', 4, [(1, '_sox2'), (2, '_beads')]),
    )

    #Load the images for this FOV
    ims = OrderedDict() #As of python 3.7, regular dicts have this functionality, but this is used for compatibility
    for fl in files_:
        layout = None
        for marker, n_channels, channels in layouts:
            if marker in fl:  # the last matching marker wins, as with the former chain of ifs
                layout = (n_channels, channels)
        if layout is None:
            # Previously such a file silently reused the layout of the preceding file.
            print(f"prepare_data: no channel layout known for {fl}; skipping")
            continue
        ncols, cols = layout
        # Read the movie once and slice every channel out of it
        # (it used to be re-read from disk for each channel).
        stack = io.DaxReader(fl).loadAll()
        round_name = os.path.basename(os.path.dirname(fl))
        for i, tag in cols:
            ims[round_name + tag] = stack[i::ncols]

    #Calculate the drift, using the middle z-plane of each stack
    imed = len(ims['H0_dapi']) // 2
    xtDNA, ytDNA = align(ims['H0_dapi'][imed], ims['H0B,B,B_dapi'][imed])
    xtsox2, ytsox2 = align(ims['H1_beads'][imed], ims['H2_beads'][imed])

    #Delete the bead images now that drift is calculated
    ims.pop('H1_beads')
    ims.pop('H2_beads')

    #Apply drift correction to the sox2 and DNA dapi images
    ims['H2_sox2'] = at.translate(ims['H2_sox2'], [0,-xtsox2,-ytsox2])
    ims['H0B,B,B_dapi'] = at.translate(ims['H0B,B,B_dapi'], [0,-xtDNA,-ytDNA])

    #Load the median images for 750 and cy5 colors and apply drift correction
    #Note: These images are only used in visualization. May improve speed to skip
    fov_name = os.path.basename(files_[1])
    mean_dax1 = Analysis_folder + os.sep + fov_name.replace('.dax','_mean_0.dax')
    mean_dax2 = Analysis_folder + os.sep + fov_name.replace('.dax','_mean_1.dax')
    ims['DNA1'] = at.translate(io.DaxReader(mean_dax1).loadAll(),[0,-xtDNA,-ytDNA])
    ims['DNA2'] = at.translate(io.DaxReader(mean_dax2).loadAll(),[0,-xtDNA,-ytDNA])

    #Get the DNA spot data for this FOV and move it into the RNA image frame
    in_fov = class_ids==ifov
    spots_ = coords[in_fov]
    cell_ids_ = cell_ids[in_fov]
    x,y,z = spots_[:,0]+xtDNA, spots_[:,1]+ytDNA, spots_[:,2]
    dna_spots = [DNA_Spot(cid, x_, y_, z_) for cid, x_, y_, z_ in zip(cell_ids_, x, y, z)]

    return ims, dna_spots, (xtDNA,ytDNA)

In [None]:
import matplotlib.pylab as plt

def get_spot(image, pixel_size_nm=(200, 108, 108), min_ratio=3, max_dist_nm=1000):
    """Measure the brightest pixel of a small 3D crop and decide whether it is a spot.

    Parameters
    ----------
    image : 3D array
        Crop around the expected spot position.
    pixel_size_nm : sequence of 3 numbers
        Size of one voxel along each axis of ``image``, in nm (default 200, 108, 108).
        NOTE(review): other cells of this notebook convert pixels with 109 nm - confirm
        which value is intended.
    min_ratio : number
        A spot is kept only if max / median exceeds this value.
    max_dist_nm : number
        A spot is kept only if the brightest pixel is closer than this to the crop centre.

    Returns
    -------
    dict with keys
        'dist'      distance (nm) of the brightest pixel from the crop centre
        'h', 'bk'   maximum (signal) and median (background) of the crop
        'ratio'     h / bk
        'intensity' sum of the crop minus bk * number of voxels
        'keep'      1 if both thresholds are passed, else 0
        'im'        maximum projection of the crop along its first axis
    """
    data = {}

    #Get coordinates of brightest pixel relative to center
    pos = np.unravel_index(np.argmax(image),image.shape)-np.array(image.shape)/2.
    #then L2 normalize (i.e. get distance of brightest pixel from center in nm)
    data['dist'] = np.linalg.norm(pos*np.asarray(pixel_size_nm))

    #Max used as signal, median as noise
    data['h'], data['bk'] = np.max(image), np.median(image)
    data['ratio'] = data['h'] / data['bk']

    #Burst size (intensity) measured as sum of pixels minus background*pixel volume
    data['intensity'] = np.sum(image) - data['bk']*image.size

    data['keep'] = int((data['ratio'] > min_ratio) and (data['dist'] < max_dist_nm))

    data['im'] = np.max(image,0) #Will it break something if h is np.max(image,0)?
    return data

def subset(image, c, m):
    """Crop a box of half-width ``m`` around centre ``c`` out of a 3D stack.

    ``image`` is indexed as [z, x, y]; ``c`` and ``m`` expose ``.x``, ``.y``, ``.z``.
    Both slice bounds are clamped at 0: a negative stop would otherwise be read by
    numpy as "count from the end" and return an unrelated region for a centre far
    outside the stack. Such a centre now yields an empty crop.
    Note that a crop clipped at the low edge is smaller and no longer centred on ``c``.
    """
    def span(center, half):
        return slice(max(0, center - half), max(0, center + half))

    return image[span(c.z, m.z), span(c.x, m.x), span(c.y, m.y)]

def _plot_spot_panels(ims, ispot, dna_spot, center, margin, sox2, mCherry, GFP):
    """Show one diagnostic figure (DNA, RNA and DAPI crops) for a single DNA spot."""
    sx, sy = 30, 30  # half-width of the DAPI context window
    zc = max(0, center.z)
    xs = slice(max(0, center.x - sx), center.x + sx)
    ys = slice(max(0, center.y - sy), center.y + sy)
    panels = [
        (np.max(subset(ims['DNA1'], center, margin), 0), '750'),
        (np.max(subset(ims['DNA2'], center, margin), 0), 'Cy5'),
        (sox2['im'], 'Sox2'),
        (mCherry['im'], 'mCherry'),
        (GFP['im'], 'GFP'),
        (np.max(ims['H0_dapi'][zc:zc+1, xs, ys], 0), 'RDAPI'),
        (np.max(ims['H0B,B,B_dapi'][zc:zc+1, xs, ys], 0), 'DDAPI'),
    ]
    fig, axs = plt.subplots(1, len(panels), figsize=(10,3))
    for ax, (panel, title) in zip(axs, panels):
        ax.imshow(panel, cmap='gray')
        ax.set_title(title)

    def summary(tag, spot):
        # The former code referenced undefined H_*/dist_*/keep_* names; the
        # signal-to-noise ratio is shown here - TODO confirm that 'ratio' (not 'h') was meant.
        return tag+':'+str(spot['keep'])+' '+str([np.round(spot['ratio'],2),int(spot['dist'])])

    plt.suptitle('Spot '+str(ispot)+' cell:'+str(dna_spot.cell_id)+
                 '\n'+summary('Sox', sox2)+'  '+summary('mCherry', mCherry)+'  '+summary('GFP', GFP))
    plt.show()


def find_RNA_spots(ims, dna_spots, plot=False):
    """Measure the Sox2, mCherry and GFP signal at every DNA spot of one FOV.

    Parameters
    ----------
    ims : mapping
        Image stacks as returned by ``prepare_data``; needs 'H2_sox2', 'H1_mCherry'
        and 'H1_GFP' (plus 'DNA1', 'DNA2' and both DAPI stacks when ``plot`` is true).
    dna_spots : iterable of DNA_Spot
    plot : bool
        Show a diagnostic figure for every measured spot.

    Returns
    -------
    infos : list of [dna_spot, sox2, mCherry, GFP]
        The last three items are the dicts returned by ``get_spot``.
    keeps : list of [sox2_keep, mCherry_keep, GFP_keep]

    Spots whose position is not a finite number or whose crop is empty (outside the
    image) are skipped. Other errors, e.g. a missing channel, are no longer swallowed.
    """
    infos = []
    keeps = []
    margin = Vec3D(x=10, y=10, z=6)
    for ispot, dna_spot in enumerate(dna_spots):
        try:
            center = Vec3D(x=int(dna_spot.x), y=int(dna_spot.y), z=int(dna_spot.z))
            sox2 = get_spot(subset(ims['H2_sox2'], center, margin))
            mCherry = get_spot(subset(ims['H1_mCherry'], center, margin))
            GFP = get_spot(subset(ims['H1_GFP'], center, margin))
        except (ValueError, OverflowError, IndexError):
            # NaN/inf coordinates or an empty crop: nothing to measure for this spot
            continue

        infos.append([dna_spot, sox2, mCherry, GFP])
        keeps.append([sox2['keep'], mCherry['keep'], GFP['keep']])

        if plot:
            try:
                _plot_spot_panels(ims, ispot, dna_spot, center, margin, sox2, mCherry, GFP)
            except ValueError as err:
                print(f"Spot {ispot}: could not plot ({err})")

    return infos, keeps

In [None]:
from tqdm import tqdm_notebook as tqdm
# Accumulators over all fields of view (FOVs):
#   infos  - one [dna_spot, sox2, mCherry, GFP] entry per measured DNA spot
#   keeps  - the matching [sox2, mCherry, GFP] keep flags
#   drifts - one (xtDNA, ytDNA) tuple per FOV, in FOV order
infos = []
keeps = []
drifts = []
for ifov in tqdm(range(len(fovs))):
    # Load + drift-correct this FOV, then measure RNA signal at each of its DNA spots
    ims, dna_spots, driftxy = prepare_data(ifov)
    drifts.append(driftxy)
    infos_, keeps_ = find_RNA_spots(ims, dna_spots)
    infos.extend(infos_)
    keeps.extend(keeps_)


In [None]:
from collections import Counter

# Fraction of measured spots with an accepted signal, per channel (Sox2, mCherry, GFP)
keep_matrix = np.array(keeps)
print(np.mean(keep_matrix, axis=0))
# Number of spots for every (Sox2, mCherry, GFP) keep combination
print(Counter(tuple(flags) for flags in keeps))

In [None]:
# Spots positive for Sox2 and mCherry but negative for GFP.
# Every infos entry is [dna_spot, sox2, mCherry, GFP] (the order used in find_RNA_spots);
# the previous unpacking order (cid, sox2, gfp, mcherry) swapped the two reporters.
sox2_mcherry_pairs = [(sox2['ratio'], mcherry['ratio']) for _, sox2, mcherry, gfp in infos
                      if sox2['keep'] and mcherry['keep'] and not gfp['keep']]
sox2h = [s for s, _ in sox2_mcherry_pairs]
mcherry2h = [m for _, m in sox2_mcherry_pairs]
plt.figure(figsize=(6,6))
plt.title("{} spots".format(len(sox2h)))
plt.plot(sox2h,mcherry2h,'o')
plt.xlabel('Sox2 signal-to-noise')
plt.ylabel('mCherry signal-to-noise');

In [None]:
# Spots positive for Sox2 and GFP but negative for mCherry.
# Every infos entry is [dna_spot, sox2, mCherry, GFP] (the order used in find_RNA_spots);
# the previous unpacking order (cid, sox2, gfp, mcherry) swapped the two reporters.
sox2_gfp_pairs = [(sox2['ratio'], gfp['ratio']) for _, sox2, mcherry, gfp in infos
                  if sox2['keep'] and gfp['keep'] and not mcherry['keep']]
sox2h = [s for s, _ in sox2_gfp_pairs]
gfp2h = [g for _, g in sox2_gfp_pairs]
plt.figure(figsize=(6,6))
plt.title("{} spots".format(len(sox2h)))
plt.plot(sox2h,gfp2h,'o')
plt.xlabel('Sox2 signal-to-noise')
plt.ylabel('GFP signal-to-noise');

In [None]:
#Load single cell chromatin tracing data
# trace_data maps cell id (int, taken from the 7th CSV column) -> {column name: float value}
trace_data = {}
with open(Analysis_folder + os.sep + "Post_analysis" + os.sep + "single_cell_data.csv") as f:
    headers = f.readline().strip().split(',')
    for line in f:
        cols = line.strip().split(',')
        cell_id = int(cols[6])
        record = {}
        for tag, num in zip(headers, cols):
            record[tag] = float(num)
        trace_data[cell_id] = record

# Show the first record as a sanity check
print(list(trace_data.values())[0])

In [None]:
#cell_ids = [x[0].cell_id for x in infos]
#print(Counter(cell_ids))
#len(cell_ids)
#print(Counter(class_ids))
#ifov = 10
#spots_ = coords[class_ids==ifov]
#cell_ids_ = cell_ids[class_ids==ifov]
#x,y,z = spots_[:,0], spots_[:,1], spots_[:,2]
#dna_spots = [DNA_Spot(cid, x_, y_, z_) for cid, x_, y_, z_ in zip(cell_ids_, x, y, z)]
#print(cell_ids_)
#print(len(x))
#print(dna_spots)

In [None]:
from collections import defaultdict
# Group the measured spots by cell, keeping only cells that have tracing data
alldata = defaultdict(list)
for dna_spot, sox2, mcherry, gfp in infos:
    if dna_spot.cell_id not in trace_data:
        continue
    alldata[dna_spot.cell_id].append([dna_spot, sox2, mcherry, gfp])
#for cell in alldata:
 #   for x in alldata[cell]:
  #      print(cell, x[0])
# Number of traced cells with at least one measured spot
print(len(alldata))

In [None]:
# Link every traced cell's CAST and 129 allele to the nearest measured RNA/DNA spot.
# NOTE: this cell used to rebind the globals `fovs` and `coords`, which prepare_data and
# the per-FOV loop still need; it now uses its own names (`fov_points`, `rna_coords`).
cast_nums = []
noncast_nums = []
linked_data = []
cast_dists = []
noncast_dists = []
missing = 0      # cells with fewer than two candidate spots
ambiguous = 0    # cells where both alleles pick the same spot
# Per-FOV coordinate lists, only used by the commented-out diagnostic plot below
fov_points = defaultdict(lambda: defaultdict(list))
for dna_data in trace_data.values():
    fov = int(dna_data['cell_id'] // 1000)
    cast_coords = np.array([dna_data['CAST_y'], dna_data['CAST_x'], dna_data['CAST_z']])
    noncast_coords = np.array([dna_data['129_y'], dna_data['129_x'], dna_data['129_z']])
    fov_points[fov]['precast'].append(cast_coords)
    fov_points[fov]['prenoncast'].append(noncast_coords)
    # Shift the trace coordinates by the per-FOV drift (pixels * 109 -> nm).
    # NOTE(review): get_spot uses 108 nm per pixel - confirm which value is intended.
    cast_coords = np.array([dna_data['CAST_y']+drifts[fov][0]*109, dna_data['CAST_x']+drifts[fov][1]*109, dna_data['CAST_z']])
    noncast_coords = np.array([dna_data['129_y']+drifts[fov][0]*109, dna_data['129_x']+drifts[fov][1]*109, dna_data['129_z']])
    cast_rna = None
    cast_dist = 10000000
    noncast_rna = None
    noncast_dist = 10000000
    fov_points[fov]['cast'].append(cast_coords)
    fov_points[fov]['noncast'].append(noncast_coords)
    # .get() avoids inserting empty entries into the alldata defaultdict
    cell_spots = alldata.get(dna_data['cell_id'], [])
    for rna_data in cell_spots:
        rna_coords = np.array([rna_data[0].x*109, rna_data[0].y*109, rna_data[0].z*200])
        fov_points[fov]['rna'].append(rna_coords)
        castdist = np.linalg.norm(rna_coords - cast_coords)
        noncastdist = np.linalg.norm(rna_coords - noncast_coords)
        if castdist < cast_dist:
            cast_rna = rna_data
            cast_dist = castdist
        if noncastdist < noncast_dist:
            noncast_rna = rna_data
            noncast_dist = noncastdist
    # Identity check: "the same spot was picked for both alleles". Comparing with ==
    # would compare the spot dicts element-wise, including their numpy image arrays.
    if cast_rna is None or noncast_rna is None or cast_rna is noncast_rna:
        if len(cell_spots) < 2:
            missing += 1
        else:
            ambiguous += 1
        continue
    cast_dists.append(cast_dist)
    noncast_dists.append(noncast_dist)
    cast_nums.append(f"{cast_rna[1]['keep']},{cast_rna[2]['keep']},{cast_rna[3]['keep']}")
    noncast_nums.append(f"{noncast_rna[1]['keep']},{noncast_rna[2]['keep']},{noncast_rna[3]['keep']}")
    linked_data.append({'cast_rna': cast_rna, '129_rna': noncast_rna, 'dna': dna_data})

print(len(linked_data))
print(len(trace_data))
print(missing, ambiguous)
print(Counter(cast_nums))
print(Counter(noncast_nums))

# Median spot-to-allele distance over both alleles (nm)
print(np.median(cast_dists+noncast_dists))
#import seaborn as sns
#sns.histplot(cast_dists+noncast_dists)

#plt.figure(figsize=(10,10))
#ifov=10
#precastx = [p[0] for p in fov_points[ifov]['precast']]
#precasty = [p[1] for p in fov_points[ifov]['precast']]
#prenoncastx = [p[0] for p in fov_points[ifov]['prenoncast']]
#prenoncasty = [p[1] for p in fov_points[ifov]['prenoncast']]
#castx = [p[0] for p in fov_points[ifov]['cast']]
#casty = [p[1] for p in fov_points[ifov]['cast']]
#noncastx = [p[0] for p in fov_points[ifov]['noncast']]
#noncasty = [p[1] for p in fov_points[ifov]['noncast']]
#rnax = [p[0] for p in fov_points[ifov]['rna']]
#rnay = [p[1] for p in fov_points[ifov]['rna']]
#plt.scatter(precastx, precasty, color="#cc4444")
#plt.scatter(prenoncastx, prenoncasty, color="#44cc44")
#plt.scatter(castx, casty, color="#ff0000")
#plt.scatter(noncastx, noncasty, color="#00ff00")
#plt.scatter(rnax, rnay, color="#0000ff", alpha=0.5)

In [None]:
import upsetplot

def convert_nums(nums):
    """Count how often each combination of positive channels occurs.

    Parameters
    ----------
    nums : iterable of str
        Strings of the form ``"<sox2>,<mCherry>,<GFP>"``; a channel counts as
        positive when its field is exactly ``'1'``.

    Returns
    -------
    defaultdict(int) mapping a tuple of channel labels (in Sox2, mCherry, GFP order)
    to the number of strings with exactly those channels positive.
    """
    channel_labels = ('Sox2', ' mCherry', ' GFP')
    memberships = defaultdict(int)
    for num in nums:
        sox2, mch, gfp = num.split(',')
        positive = tuple(label for label, flag in zip(channel_labels, (sox2, mch, gfp))
                         if flag == '1')
        memberships[positive] += 1
    return memberships

castmem = convert_nums(cast_nums)
noncastmem = convert_nums(noncast_nums)
print(castmem.keys())
# Every possible channel combination, so that empty categories are plotted as zero
keys = [(), ('Sox2',), ('Sox2', ' mCherry'), ('Sox2', ' GFP'), (' mCherry',), (' GFP',), (' mCherry', ' GFP'), ('Sox2', ' mCherry', ' GFP')]
#keys = [('Sox2',), ('mCherry',), ('GFP',)]
dfc = upsetplot.from_memberships(keys, data=[castmem[k] for k in keys])
dfnc = upsetplot.from_memberships(keys, data=[noncastmem[k] for k in keys])
# Draw each allele on its own figure: without an explicit `fig`, the second
# upsetplot.plot call may draw onto the figure already used by the first.
for series, png_name in ((dfc, "cast_categories_plot_bogdan.png"),
                         (dfnc, "129_categories_plot_bogdan.png")):
    fig = plt.figure()
    upsetplot.plot(series, fig=fig, sort_by='cardinality', sort_categories_by=None, show_counts=True)
    fig.savefig(Analysis_folder+os.sep+"RNA_DNA_Analysis"+os.sep+png_name,dpi=300)
plt.show()


In [None]:
#Save results to csv: cell id, then CAST and 129 keep flags for (Sox2, mCherry, GFP)
with open(Analysis_folder+os.sep+"RNA_DNA_Analysis"+os.sep+"binary_celldata.csv", "w") as f:
    for data in linked_data:
        line = [int(data['dna']['cell_id'])]
        line.extend([data['cast_rna'][i]['keep'] for i in [1,2,3]])
        line.extend([data['129_rna'][i]['keep'] for i in [1,2,3]])
        print(','.join([str(x) for x in line]), file=f)

# Append the same flags as extra columns to an existing per-cell table.
# NOTE(review): machine-specific path; consider moving it next to Analysis_folder.
alleles_csv = r"C:\Users\ckern\Downloads\rna2alleles_d43.csv"
new_columns = ["CAST_Sox2", "CAST_mCherry", "CAST_GFP", "129_Sox2", "129_mCherry", "129_GFP"]

newdata = {}
with open(alleles_csv) as f:
    header = f.readline().strip()
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines
        cols = line.strip().split(',')
        newdata[int(cols[0])] = line.strip()

if new_columns[0] in header.split(','):
    # The file is rewritten in place, so running this cell twice used to append
    # the six columns a second time.
    print(f"{alleles_csv} already contains the RNA columns; file left unchanged")
else:
    unmatched = 0
    for data in linked_data:
        cell = int(data['dna']['cell_id'])
        if cell not in newdata:
            unmatched += 1  # previously a KeyError
            continue
        newdata[cell] += ','+','.join([str(data['cast_rna'][i]['keep']) for i in [1,2,3]])
        newdata[cell] += ','+','.join([str(data['129_rna'][i]['keep']) for i in [1,2,3]])

    header += "," + ",".join(new_columns)
    with open(alleles_csv, "w") as f:
        print(header,file=f)
        for cell, line in newdata.items():
            print(line, file=f)
    if unmatched:
        print(f"{unmatched} linked cells were not present in {alleles_csv} and were skipped")

In [None]:
import pickle
# Cache the linked RNA/DNA records; the with-block closes (and flushes) the file
with open(os.path.join(Analysis_folder, "RNA_DNA_Analysis", "pickled_data.pkl"), 'wb') as pkl_file:
    pickle.dump(linked_data, pkl_file)

In [None]:
import pickle
# Reload the cached records. NOTE(security): only unpickle files you created yourself.
with open(os.path.join(Analysis_folder, "RNA_DNA_Analysis", "pickled_data.pkl"), 'rb') as pkl_file:
    linked_data = pickle.load(pkl_file)

In [None]:
def histo_plot(datakey, xlabel):
    """Compare a per-allele DNA metric between CAST and 129, for bursting and for silent alleles.

    Parameters
    ----------
    datakey : str
        Suffix of the metric in each record's 'dna' dict; ``'CAST_' + datakey`` and
        ``'129_' + datakey`` are read.
    xlabel : str
        Label of the x axis.

    For each allele the three keep flags (indices 1, 2, 3 = Sox2, mCherry, GFP) classify it as
      bursting: Sox2 or the allele's own reporter kept, and the other reporter not kept
                (CAST: own = 3, other = 2; 129: own = 2, other = 3);
      silent:   none of the three kept.
    NaN metric values are dropped. Two histograms are saved to
    ``<Analysis_folder>/RNA_DNA_Analysis/rna_histogram_{burst_,silent_}<datakey>.png``.
    """
    import math  # imported locally: the notebook-level `import math` only happens in a later cell
    import seaborn
    from scipy.stats import ranksums, mannwhitneyu

    def split_allele(rna_key, prefix, own, other):
        """Return (bursting values, silent values) of one allele, NaN removed."""
        burst, quiet = [], []
        for data in linked_data:
            rna = data[rna_key]
            sox2_kept = rna[1]['keep']
            own_kept = rna[own]['keep']
            other_kept = rna[other]['keep']
            if (sox2_kept or own_kept) and not other_kept:
                target = burst
            elif not (sox2_kept or own_kept or other_kept):
                target = quiet
            else:
                continue
            value = data['dna'][prefix + datakey]
            if not math.isnan(value):
                target.append(value)
        return burst, quiet

    cast_burst, cast_silent = split_allele('cast_rna', 'CAST_', 3, 2)
    noncast_burst, noncast_silent = split_allele('129_rna', '129_', 2, 3)

    def make_plot(bst, slnt, fileadd=''):
        """Overlay the histograms of two groups and save the figure."""
        if not bst or not slnt:
            # the mean in the legend and both rank tests need non-empty groups
            print(f"histo_plot: skipping '{fileadd}{datakey}' (group sizes {len(bst)} and {len(slnt)})")
            return
        plt.figure(figsize=(6,5))
        seaborn.distplot(bst, label=f"CAST (n={len(bst)};mean={sum(bst)/len(bst):.0f}nm)", kde=False)
        seaborn.distplot(slnt, label=f"129 (n={len(slnt)};mean={sum(slnt)/len(slnt):.0f}nm)", kde=False)
        plt.xlabel(xlabel)
        plt.ylabel("Chromosomes")
        title = f"Wilcoxon p-value={ranksums(bst, slnt)[1]:0.5f}\n"
        title += f"Mann-Whitney p-value={mannwhitneyu(bst, slnt)[1]:0.5f}"
        plt.title(title)
        #plt.axvline(250, linestyle="dashed", color='red',label='Contact threshold: 250nm')
        plt.legend()
        plt.savefig(Analysis_folder+os.sep+"RNA_DNA_Analysis"+os.sep+"rna_histogram_"+fileadd+datakey+".png",dpi=300)

    # Alternative groupings (the legend labels above would then need changing):
    #make_plot(cast_burst + noncast_burst, cast_silent + noncast_silent)
    #make_plot(cast_burst, cast_silent, 'CAST_')
    #make_plot(noncast_burst, noncast_silent, '129_')
    make_plot(cast_burst, noncast_burst, 'burst_')
    make_plot(cast_silent, noncast_silent, 'silent_')

#histo_plot('ep_dist', 'E-P distance')
#histo_plot('ins', 'Insulation (10-25, 25-33)')#
#histo_plot('rgs_10_25', 'Radius of gyration (10-25)')
#histo_plot('rgs_10_33', 'Radius of gyration (10-33)')
#histo_plot('rgs_1_42', 'Radius of gyration (1-42)')
#histo_plot('rgs_10_39', 'Radius of gyration (10-39)')
#histo_plot('rgs_25_33', 'Radius of gyration (25-33)')

In [None]:
import math
from scipy.stats import chisquare
# E-P distances per allele, split into bursting and silent alleles.
# Keep flags are indexed 1, 2, 3 = Sox2, mCherry, GFP.
cast_burst = []
noncast_burst = []
cast_silent = []
noncast_silent = []
for data in linked_data:
    cast_flags = [data['cast_rna'][i]['keep'] for i in (1, 2, 3)]
    noncast_flags = [data['129_rna'][i]['keep'] for i in (1, 2, 3)]
    # CAST: silent when nothing is kept; bursting when something is kept and mCherry is not
    if not any(cast_flags):
        cast_silent.append(data['dna']['CAST_ep_dist'])
    elif not cast_flags[1]:
        cast_burst.append(data['dna']['CAST_ep_dist'])
    # 129: silent when nothing is kept; bursting when something is kept and GFP is not
    if not any(noncast_flags):
        noncast_silent.append(data['dna']['129_ep_dist'])
    elif not noncast_flags[2]:
        noncast_burst.append(data['dna']['129_ep_dist'])
# Drop alleles without a measured distance
cast_burst, cast_silent, noncast_burst, noncast_silent = (
    [dist for dist in values if not math.isnan(dist)]
    for values in (cast_burst, cast_silent, noncast_burst, noncast_silent))
# Pooled over both alleles, sorted for the cumulative plots
bursting = sorted(cast_burst + noncast_burst)
silent = sorted(cast_silent + noncast_silent)

from scipy.stats import chi2_contingency
def get_pvalue(bursting, silent, thresh):
    """Chi-square p-value for "value <= thresh" being equally frequent in two groups.

    Builds the 2x2 table [[<= thresh, > thresh] for bursting, same for silent] and
    returns the p-value of ``chi2_contingency``. Returns 1 when any cell of the
    table is zero.
    """
    def below_above(values):
        below = sum(1 for v in values if v <= thresh)
        above = sum(1 for v in values if v > thresh)
        return [below, above]

    table = [below_above(bursting), below_above(silent)]
    if any(count == 0 for row in table for count in row):
        return 1
    return chi2_contingency(table)[1]
    
def cumulative_graph(bursting, silent, name1, name2):
    """Plot the cumulative distributions of two groups of E-P distances.

    Parameters
    ----------
    bursting, silent : sequences of numbers
        Distances (nm) of the two groups; they are sorted here, so callers need not.
    name1, name2 : str
        Legend/title names of the two groups.

    Thresholds (0-1480 nm in 20 nm steps) at which ``get_pvalue`` gives p <= 0.01 /
    p <= 0.05 are shaded dark / light grey. The title reports the fraction of each
    group at or below 250 nm and the chi-square contingency p-value at that threshold.

    The former extra ``chisquare(bstchi, sltchi)`` figure was dropped: it is a
    goodness-of-fit test that treats the second group's counts as expected
    frequencies, and it raises in current SciPy when the two totals differ.
    """
    bursting, silent = sorted(bursting), sorted(silent)

    def contact_counts(values):
        return [sum(1 for v in values if v <= 250), sum(1 for v in values if v > 250)]

    def percent_in_contact(counts):
        total = counts[0] + counts[1]
        return 100*counts[0]/total if total else float('nan')

    pvals = [get_pvalue(bursting, silent, t) for t in range(0,1500,20)]
    bstchi = contact_counts(bursting)
    sltchi = contact_counts(silent)
    try:
        contact_pvalue = chi2_contingency([bstchi, sltchi])[1]
    except ValueError:
        # raised when a whole row or column of the table is zero
        contact_pvalue = float('nan')

    plt.figure(figsize=(6,5))
    for i, pval in enumerate(pvals):
        if pval <= 0.01:
            plt.axvspan((20*i)-10, (20*i)+10, color='#bbbbbb')
        elif pval <= 0.05:
            plt.axvspan((20*i)-10, (20*i)+10, color='#dddddd')
    plt.plot(bursting, [(x+1)/len(bursting) for x in range(len(bursting))], label=name1)
    plt.plot(silent, [(x+1)/len(silent) for x in range(len(silent))], label=name2)
    plt.xlabel("E-P distance (nm)")
    plt.ylabel("Fraction of chromosomes")
    title = f"{name1}: {percent_in_contact(bstchi):.1f}% in contact ({bstchi[0]}/{bstchi[1]})\n"
    title += f"{name2}: {percent_in_contact(sltchi):.1f}% in contact ({sltchi[0]}/{sltchi[1]})\n"
    title += f"p-value={contact_pvalue:0.5f}"
    plt.title(title)
    plt.axvline(250, linestyle="dashed", color='red',label='Contact threshold: 250nm')
    plt.legend()
    #plt.savefig(Analysis_folder+os.sep+"RNA_DNA_Analysis"+os.sep+f"cumulative_ep_dist_{name1}_{name2}_d43.png",dpi=300)
    #plt.savefig(Analysis_folder+os.sep+"RNA_DNA_Analysis"+os.sep+f"cumulative_ep_dist_{name1}_{name2}_d43.pdf",dpi=300)

# Cumulative comparisons: pooled, by allele, and bursting vs resting within each allele
cumulative_graph(bursting, silent, "Bursting", "Resting")
cumulative_graph(sorted(cast_burst+cast_silent), sorted(noncast_burst+noncast_silent), "CAST", "129")
cumulative_graph(sorted(cast_burst), sorted(cast_silent), "CAST bursting", "CAST resting")
cumulative_graph(sorted(noncast_burst), sorted(noncast_silent), "129 bursting", "129 resting")

In [None]:
from collections import defaultdict
import seaborn as sns
import math

def bursting(rna_data, i, j):
    """Truthy when entry 1 or entry ``i`` of ``rna_data`` is kept and entry ``j`` is not.

    ``rna_data`` is indexed by position; each inspected entry is a dict with a 'keep' flag.
    """
    expressed = rna_data[1]['keep'] or rna_data[i]['keep']
    if not expressed:
        return expressed
    return not rna_data[j]['keep']

def resting(rna_data):
    """True when none of entries 1, 2 and 3 of ``rna_data`` has a truthy 'keep' flag."""
    return not any(rna_data[k]['keep'] for k in (1, 2, 3))

def ratio(threshold, nums1, nums2):
    """Fraction of ``nums1`` at or below ``threshold`` divided by the same fraction of ``nums2``."""
    hits1 = sum(1 for n in nums1 if n <= threshold)
    hits2 = sum(1 for n in nums2 if n <= threshold)
    fraction1 = hits1 / len(nums1)
    fraction2 = hits2 / len(nums2)
    return fraction1 / fraction2

def threshold_plot(dists, group1, group2, name1, name2):
    """Plot log2 of the contact-frequency ratio group1/group2 against the contact threshold.

    Parameters
    ----------
    dists : sequence of numbers
        Contact thresholds (nm) to evaluate.
    group1, group2 : sequences of numbers
        E-P distances of the two groups, passed to ``ratio``.
    name1, name2 : str
        Group names used in the y label.

    Thresholds at which the ratio is undefined (an empty group, or no contact in
    either group) are plotted as NaN instead of aborting the whole figure.
    """
    def log2_ratio(thresh):
        try:
            return math.log2(ratio(thresh, group1, group2))
        except (ZeroDivisionError, ValueError):
            return float('nan')

    ratio_nums = [log2_ratio(thresh) for thresh in dists]
    plt.figure()
    plt.xlabel("E-P contact threshold (nm)")
    # the plotted quantity is the log2 ratio (note the symmetric y limits below)
    plt.ylabel(f"log2 ratio of contact frequency ({name1}/{name2})")
    # keyword arguments: positional x/y are rejected by newer seaborn releases
    sns.lineplot(x=dists, y=ratio_nums)
    plt.ylim(-1,1)
    #plt.yticks([0.5,.75,1.0,1.25,1.5])
    plt.xticks([150,250,350,450,550,650,750])
    plt.tight_layout()
    filename = Analysis_folder+os.sep+"RNA_DNA_Analysis"+os.sep+"thresholdplot_"+name1+"_"+name2
    #plt.savefig(filename+".png",dpi=300)
    
# Contact thresholds to scan: 150-725 nm in 25 nm steps
dists = list(range(150,750,25))
# E-P distances per allele and state. bursting(..., own, other): CAST uses reporter 3
# and excludes 2; 129 uses reporter 2 and excludes 3.
# NOTE(review): unlike the cumulative-distribution cell, NaN distances are not removed
# here; they count in the denominators inside ratio() - confirm this is intended.
bursting_cast = [data['dna']['CAST_ep_dist'] for data in linked_data if bursting(data['cast_rna'], 3, 2)]
resting_cast = [data['dna']['CAST_ep_dist'] for data in linked_data if resting(data['cast_rna'])]
bursting_129 = [data['dna']['129_ep_dist'] for data in linked_data if bursting(data['129_rna'], 2, 3)]
resting_129 = [data['dna']['129_ep_dist'] for data in linked_data if resting(data['129_rna'])]
# Pooled groups
bursting_all = bursting_cast + bursting_129
resting_all = resting_cast + resting_129
all_cast = bursting_cast + resting_cast
all_129 = bursting_129 + resting_129

# One figure per comparison (first group / second group)
threshold_plot(dists, all_cast, all_129, "CAST", "129")
threshold_plot(dists, resting_all, bursting_all, "Resting", "Bursting")
threshold_plot(dists, resting_cast, bursting_cast, "CAST Resting", "CAST Bursting")
threshold_plot(dists, resting_129, bursting_129, "129 Resting", "129 Bursting")
threshold_plot(dists, bursting_cast, bursting_129, "CAST Bursting", "129 Bursting")
threshold_plot(dists, resting_cast, resting_129, "CAST Resting", "129 Resting")



In [None]:
import pickle
# Reload the linked records saved by the last cell of this notebook.
# Raw string: the previous plain literal relied on invalid escape sequences (\D, \E, \A, \R, \l).
# NOTE(security): only unpickle files you created yourself.
with open(r'D:\DNA_FISH\E20200918_D0043\Analysis_CK\RNA_DNA_Analysis\linked_data.pkl', 'rb') as pkl_file:
    linked_data = pickle.load(pkl_file)

In [None]:
import matplotlib.pyplot as plt
import math
from scipy.stats import sem,ks_2samp

def bursting(rna_data, i, j):
    """Truthy when entry 1 or entry ``i`` of ``rna_data`` is kept while entry ``j`` is not."""
    sox2_kept, own_kept, other_kept = [rna_data[k]['keep'] for k in (1, i, j)]
    return (sox2_kept or own_kept) and not other_kept

def resting(rna_data):
    """True when entries 1, 2 and 3 of ``rna_data`` all have a falsy 'keep' flag."""
    return all(not rna_data[k]['keep'] for k in (1, 2, 3))

# NOTE: this definition is replaced by the second `ratio` defined further down in this
# cell before anything in the cell calls it.
def ratio(threshold, nums1, nums2):
    """Fraction of nums1 at or below threshold divided by the same fraction of nums2."""
    n1_ratio = len([n for n in nums1 if n <= threshold]) / len(nums1)
    n2_ratio = len([n for n in nums2 if n <= threshold]) / len(nums2)
    return n1_ratio / n2_ratio


# Thresholds 150-725 nm in 25 nm steps
dists = list(range(150,750,25))
# E-P distances per allele and state. bursting(..., own, other): CAST uses reporter 3
# and excludes 2; 129 uses reporter 2 and excludes 3.
bursting_cast = [data['dna']['CAST_ep_dist'] for data in linked_data if bursting(data['cast_rna'], 3, 2)]
resting_cast = [data['dna']['CAST_ep_dist'] for data in linked_data if resting(data['cast_rna'])]
bursting_129 = [data['dna']['129_ep_dist'] for data in linked_data if bursting(data['129_rna'], 2, 3)]
resting_129 = [data['dna']['129_ep_dist'] for data in linked_data if resting(data['129_rna'])]
# Bursting alleles pooled over CAST and 129
bursting_all = bursting_cast + bursting_129

def ratio(threshold, nums1, nums2):
    """Share of nums1 among all values at or below ``threshold``, with a 95% error margin.

    Parameters
    ----------
    threshold : number
    nums1, nums2 : iterables of numbers

    Returns
    -------
    (fraction, err)
        fraction = n1 / (n1 + n2), where n1 / n2 count the values of nums1 / nums2 that
        are <= threshold; err = 1.96 * sqrt(fraction * (1 - fraction) / (n1 + n2)).
        Both are NaN when no value at all is <= threshold (this used to raise
        ZeroDivisionError).
    """
    l1 = sum(1 for x in nums1 if x <= threshold)
    l2 = sum(1 for x in nums2 if x <= threshold)
    total = l1 + l2
    if total == 0:
        return float('nan'), float('nan')
    fraction = l1 / total
    err = 1.96 * math.sqrt((fraction*(1-fraction)) / total)
    return fraction, err


# Thresholds 150-750 nm in 25 nm steps (one more step than `dists` above)
xticks = range(150,751,25)
# (fraction, error) of bursting among bursting + resting alleles within each threshold
n1 = [ratio(x, bursting_cast, resting_cast) for x in xticks]
n2 = [ratio(x, bursting_129, resting_129) for x in xticks]
# Convert to percent
nums1 = [n[0]*100 for n in n1]
err1 = [n[1]*100 for n in n1]
nums2 = [n[0]*100 for n in n2]
err2 = [n[1]*100 for n in n2]
# NOTE(review): this KS test compares the two curves (one value per threshold), not the
# underlying per-allele distances - confirm that is the intended statistic.
print(ks_2samp(nums1, nums2))
plt.figure()
plt.xlabel("E-P contact threshold (nm)")
plt.ylabel(f"Fraction of alleles bursting (%)")
#plt.errorbar(xticks, nums1, yerr=err1, label="CAST")
#plt.errorbar(xticks, nums2, yerr=err2, label="129")
plt.plot(xticks, nums1, label="CAST")
plt.plot(xticks, nums2, label="129")
# Shaded band: fraction +/- error margin
plt.fill_between(xticks, [x-y for x,y in zip(nums1,err1)], [x+y for x,y in zip(nums1,err1)], alpha=0.5)
plt.fill_between(xticks, [x-y for x,y in zip(nums2,err2)], [x+y for x,y in zip(nums2,err2)], alpha=0.5)

#plt.fill_between(xticks, nums2-err2, nums2+err2, alpha=0.5)
#plt.ylim(0.5,1.5)
#plt.yticks([0.5,.75,1.0,1.25,1.5])
plt.xticks([150,300,450,600,750])
#plt.ylim(ymin=0.2, ymax=0.8)
#plt.yticks([0.2,0.3,0.4,0.5,0.6,0.7,0.8])
plt.ylim(ymin=20, ymax=80)
plt.yticks([20,30,40,50,60,70,80])
plt.legend()
plt.tight_layout()
# os and Analysis_folder are (re)defined here, which lets this cell run without the
# configuration cell at the top of the notebook
import os
Analysis_folder = r'D:\DNA_FISH\E20200918_D0043\Analysis_CK'
filename = Analysis_folder+os.sep+"RNA_DNA_Analysis"+os.sep+"bursting_fraction_d43"
plt.savefig(filename+".png",dpi=300)
plt.savefig(filename+".pdf",dpi=300)


In [None]:
import math
import seaborn
from scipy import stats
from matplotlib import pyplot as plt

def pearson(x, y):
    """Pearson correlation of x and y; returns the result of ``scipy.stats.pearsonr`` unchanged."""
    correlation = stats.pearsonr(x, y)
    return correlation

def scatter_plot(xdatakey, xlabel, ydatakey, ylabel, cast=True, noncast=True):
    """Joint regression plot of a Sox2 measurement against a DNA metric, saved as PNG.

    Parameters
    ----------
    xdatakey : str
        Key of the Sox2 spot dict (entry 1 of each allele's RNA record), e.g. 'ratio'.
    xlabel, ylabel : str
        Axis labels.
    ydatakey : str
        Suffix of the DNA metric; 'CAST_' / '129_' + ydatakey is read from the 'dna' dict.
    cast, noncast : bool
        Which alleles to include.

    Pairs with a NaN on either side are dropped before plotting. The Pearson
    correlation is written onto the joint axes directly, because the ``stat_func``
    argument of ``jointplot`` no longer exists in newer seaborn releases.
    The file is ``<Analysis_folder>/RNA_DNA_Analysis/scatterplot_<x>_v_<y>[_CAST|_129].png``.
    """
    x=[]
    y=[]
    for data in linked_data:
        if cast:
            x.append(data['cast_rna'][1][xdatakey])
            y.append(data['dna']['CAST_'+ydatakey])
        if noncast:
            x.append(data['129_rna'][1][xdatakey])
            y.append(data['dna']['129_'+ydatakey])

    pairs = [(float(a), float(b)) for a, b in zip(x, y)
             if not (math.isnan(a) or math.isnan(b))]
    xs = [a for a, _ in pairs]
    ys = [b for _, b in pairs]

    # keyword x/y: positional data arguments are rejected by newer seaborn releases
    grid = seaborn.jointplot(x=xs, y=ys, kind="reg")
    r_value, p_value = pearson(xs, ys)
    grid.ax_joint.annotate(f"pearsonr = {r_value:.2g}; p = {p_value:.2g}",
                           xy=(0.05, 0.95), xycoords='axes fraction', ha='left', va='top')
    # label the joint axes explicitly; plt.xlabel/ylabel act on whichever axes is current
    grid.set_axis_labels(xlabel, ylabel)
    plt.tight_layout()
    filename = Analysis_folder+os.sep+"RNA_DNA_Analysis"+os.sep+"scatterplot_"+xdatakey+"_v_"+ydatakey
    if cast and not noncast:
        filename += '_CAST'
    if noncast and not cast:
        filename += '_129'
    plt.savefig(filename+".png",dpi=300)
    
# (key, axis label) of the Sox2 measurements to put on the x axis
xaxes = [('h', 'Sox2 brightness'), ('ratio', 'Sox2 signal-to-noise'), ('intensity', 'Sox2 intensity')]
# (key suffix, axis label) of the DNA metrics to put on the y axis
yaxes = [('ep_dist', 'E-P distance'), ('ins', 'Insulation (10-25, 25-33)'), ('rgs_10_25', 'Radius of gyration (10-25)'),
         ('rgs_10_33', 'Radius of gyration (10-33)'), ('rgs_1_42', 'Radius of gyration (1-42)'),
         ('rgs_10_39', 'Radius of gyration (10-39)'), ('rgs_25_33', 'Radius of gyration (25-33)')]
# Allele selections per figure: both alleles, CAST only, 129 only
allele_selections = [(True, True), (True, False), (False, True)]
for xdatakey, xlabel in xaxes:
    for ydatakey, ylabel in yaxes:
        for use_cast, use_noncast in allele_selections:
            scatter_plot(xdatakey, xlabel, ydatakey, ylabel, cast=use_cast, noncast=use_noncast)
        

In [None]:
#Save the data for input to chromatin tracing pipeline
# (the with-block closes the file, so the pickle is fully flushed to disk)
with open(Analysis_folder+os.sep+"RNA_DNA_Analysis"+os.sep+"linked_data.pkl", 'wb') as pkl_file:
    pickle.dump(linked_data, pkl_file)