In [None]:
import pandas as pd
import glob
import os
import logging
import time

In [None]:
import logging
import colorlog

# Notebook-wide logger that writes colorized records to a stream handler.
logger = logging.getLogger("color_logger")

# Re-running this cell must not stack a second handler on the same logger,
# otherwise every message would be printed more than once.
if len(logger.handlers) == 0:
    handler = colorlog.StreamHandler()
    handler.setFormatter(
        colorlog.ColoredFormatter(
            "%(asctime)s - %(log_color)s%(levelname)s:%(reset)s %(message)s",
            # One color per log level.
            log_colors=dict(
                DEBUG='blue',
                INFO='green',
                WARNING='yellow',
                ERROR='red',
                CRITICAL='bold_red',
            ),
            # Timestamp layout used for %(asctime)s.
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    logger.addHandler(handler)
    # Messages below INFO (i.e. DEBUG) are dropped.
    logger.setLevel(logging.INFO)

# Test logging
#logger.debug("This is a debug message (won't show)")
#logger.info("This is an info message")
#logger.warning("This is a warning message")
#logger.error("This is an error message")
#logger.critical("This is a critical message")


In [None]:
def timing(f):
    """
    Decorator that logs how long each call to ``f`` takes.

    Parameters
    ----------
    f : function
        The callable to time.

    Returns
    -------
    function
        Wrapper that forwards every positional and keyword argument to ``f``,
        logs the elapsed time at INFO level through the module-level
        ``logger`` and returns the result of ``f`` unchanged. Nothing is
        logged when ``f`` raises; the exception propagates to the caller.
    """
    # Imported here so the cell works without touching the notebook's import cell.
    import functools

    @functools.wraps(f)  # keep f's __name__ / __doc__ visible on the wrapper
    def wrap(*args, **kwargs):
        # perf_counter is the clock intended for measuring intervals; unlike
        # time.time it is not affected by system clock adjustments.
        started = time.perf_counter()
        ret = f(*args, **kwargs)
        elapsed = time.perf_counter() - started
        logger.info('{:s} function took {:.10f} s'.format(f.__name__, elapsed))
        return ret

    return wrap

In [None]:
def mean(L):
    """Return the arithmetic mean of ``L``: the sum of its items divided by its length."""
    total = sum(L)
    n_items = len(L)
    return total / n_items

In [None]:
def assign_tails(tail):
    """
    Classify a tail sequence by its base composition.

    Parameters
    ----------
    tail : str
        Tail sequence made of the upper-case letters A, C, G and T.

    Returns
    -------
    str
        ``"A"``, ``"G"`` or ``"C"`` when the tail consists only of that base
        (any length); ``"T"`` for a single T; ``"TT"`` for two or more T's;
        ``"other"`` for everything else. "Everything else" covers mixed tails,
        the empty string, and tails containing any character outside A/C/G/T
        (for example ``"N"`` or lower-case letters).
    """
    bases = set(tail)

    # A homopolymer has exactly one distinct character. Zero (empty tail) or
    # several distinct characters can never be a single-base tail.
    if len(bases) != 1:
        return "other"

    (base,) = bases
    if base == "T":
        # T tails are split into a single T and runs of two or more.
        return "TT" if len(tail) >= 2 else "T"
    if base in ("A", "G", "C"):
        return base

    # One repeated character that is not a recognised base (e.g. "NNN").
    return "other"

In [None]:

@timing
def process_tailing_piRNA(files, conditions):
    """
    Summarise tails on piRNA reads per condition.

    Parameters
    ----------
    files : iterable of str
        Tab-separated read tables. The sample name is the file's base name up
        to the first ".". Lines starting with "gene" are treated as headers.
        Columns are read by position: gene, start, end, seq, count, strand,
        feature, rawtail, tail, count_total_norm.
    conditions : iterable of str
        Comma-separated files whose first column is the sample name and whose
        second column is the condition. Lines starting with "simple_name" are
        treated as headers.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        ``df_grouped_condition`` : one row per (condition, tail_group) with
        ``mean_tailing_percent`` (mean over samples of the tailed share of the
        sample's total piRNA count, in percent) and ``mean_tailing_rpm``.
        ``df_pass`` : the individual tailed rows that passed all filters, with
        columns gene, seq, tail, count, feature, sample, condition, tail_group.

    Raises
    ------
    ValueError
        If ``files`` is empty.
    KeyError
        If a sample name derived from a file is missing from the condition files.
    """
    files = list(files)
    if not files:
        # An empty glob result is almost always a wrong path; fail loudly
        # instead of producing empty tables.
        raise ValueError("process_tailing_piRNA: no input files were given")

    # sample name -> condition
    condition_dict = {}
    for condition_file in conditions:
        with open(condition_file, 'r') as handle:
            for line in handle:
                # skip the header and blank lines
                if line.startswith("simple_name") or not line.strip():
                    continue
                info = line.strip().split(",")
                condition_dict[info[0]] = info[1]

    # positional layout of the read tables
    index_dict = {
        'gene': 0,
        'start': 1,
        'end': 2,
        'seq': 3,
        'count': 4,
        'strand': 5,
        'feature': 6,
        'rawtail': 7,
        'tail': 8,
        'count_total_norm': 9
    }

    # (condition, seq) -> number of accepted tailed rows, over all files
    counter = {}
    # sample -> summed normalized count of every piRNA row with start <= 2,
    # tailed or not; used as the denominator of the tailing percentage
    total_level = {}

    # One list per output column, filled across all files so the frame is
    # built once instead of being re-concatenated for every file.
    genes, seqs, tails, counts, features, samples, row_conditions = [], [], [], [], [], [], []

    for file in files:
        name = os.path.basename(file).split(".")[0]
        condition = condition_dict[name]
        logger.info(f"Processing {name}...")

        # Samples with "parn" in their name are exempt from the length cut-off below.
        is_parn = "parn" in name

        with open(file, 'r') as handle:
            for line in handle:
                # Skip the header and blank lines.
                # NOTE(review): any data row whose first column starts with
                # "gene" is skipped as well — confirm that cannot occur.
                if line.startswith("gene") or not line.strip():
                    continue

                info = line.strip().split("\t")
                gene = info[index_dict['gene']]
                start = int(info[index_dict['start']])
                seq = info[index_dict['seq']]
                tail = info[index_dict['rawtail']]
                count = float(info[index_dict['count_total_norm']])
                feature = info[index_dict['feature']]

                if feature != "piRNA_ref" or start > 2:
                    continue

                total_level[name] = total_level.get(name, 0.0) + count

                # "*" in the rawtail column marks a row without a tail.
                if tail == "*":
                    continue

                # Length of the sequence without its tail; non-parn samples
                # only keep rows where this is below 21.
                seqlen = len(seq) - len(tail)
                if not (is_parn or seqlen < 21):
                    continue

                genes.append(gene)
                seqs.append(seq)
                tails.append(tail)
                counts.append(count)
                features.append(feature)
                samples.append(name)
                row_conditions.append(condition)

                counter_key = (condition, seq)
                counter[counter_key] = counter.get(counter_key, 0) + 1

    df = pd.DataFrame({
        'gene': genes,
        'seq': seqs,
        'tail': tails,
        # explicit dtype so the column stays numeric even when no row was kept
        'count': pd.Series(counts, dtype='float64'),
        'feature': features,
        'sample': samples,
        'condition': row_conditions
    })

    # Keep a row only if its sequence was accepted at least twice within the
    # same condition.
    # NOTE(review): this counts rows, not distinct samples — confirm that two
    # rows from the same file are meant to satisfy the filter.
    keep = pd.Series(
        [counter[key] >= 2 for key in zip(row_conditions, seqs)],
        index=df.index,
        dtype=bool,
    )
    df_pass = df.loc[keep].reset_index(drop=True)

    df_pass['tail_group'] = df_pass['tail'].map(assign_tails)

    df_grouped = df_pass.groupby(
        ['sample', 'condition', 'tail_group']
    ).agg(
        tailed_count=('count', 'sum')
    ).reset_index()

    # Every sample present here has an entry in total_level, because rows are
    # only collected after the sample's total has been updated.
    df_grouped['total_piRNA_count'] = df_grouped['sample'].map(total_level).astype('float64')
    df_grouped['tailing_percent'] = 100 * (df_grouped['tailed_count'] / df_grouped['total_piRNA_count'])

    # NOTE(review): despite its name, mean_tailing_rpm is the SUM of
    # tailed_count over the samples of a condition, not their mean.
    df_grouped_condition = df_grouped.groupby(
        ['condition', 'tail_group']
    ).agg(
        mean_tailing_percent=('tailing_percent', 'mean'),
        mean_tailing_rpm=('tailed_count', 'sum')
    ).reset_index()

    return df_grouped_condition, df_pass
        
    
    

In [None]:
# glob returns paths in arbitrary, filesystem-dependent order; sorting them
# makes the row order of the raw-data table reproducible between runs.
piRNA_tailing = process_tailing_piRNA(
    sorted(glob.glob("./smRNA_seq_mutants_remove_rRNA/20250115/tailor_transcripts/files/*.bed.tsv")),
    ["./smRNA_seq_mutants_remove_rRNA/20250115/samples/replicates.csv"],
)

In [None]:
# Write both results of process_tailing_piRNA as tab-separated files without
# the index: element 0 is the per-condition summary, element 1 the table of
# individual rows that passed the filters.
piRNA_tailing[0].to_csv("./piRNA_tailing.tsv", sep = "\t", header = True, index = False)
piRNA_tailing[1].to_csv("./piRNA_tailing_raw_data.tsv", sep = "\t", header = True, index = False)

In [None]:
# Summed count per tail group for the filtered rows whose condition is "TP2".
# NOTE(review): the sort only reorders rows before grouping; it does not
# change which rows contribute to each sum.
piRNA_tailing[1].query('condition == "TP2"').sort_values(['count'], ascending = False).groupby(['tail_group']).agg(SUM = ('count', 'sum'))

# Other RNAs

In [None]:

def process_tailing(files : list, conditions : list, my_feature : str):
    """
    Summarise tails per condition for the reads of one feature class.

    Every tailed row of the requested feature is used; no length or
    recurrence filter is applied.

    Parameters
    ----------
    files : list of str
        Tab-separated read tables. The sample name is the file's base name up
        to the first ".". Lines starting with "gene" are treated as headers.
        Columns are read by position: gene, start, end, seq, count, strand,
        feature, rawtail, tail, count_total_norm.
    conditions : list of str
        Comma-separated files whose first column is the sample name and whose
        second column is the condition. Lines starting with "simple_name" are
        treated as headers.
    my_feature : str
        Value of the feature column to select, e.g. "pre_miRNA_ref".

    Returns
    -------
    pandas.DataFrame
        One row per (condition, tail_group) with ``mean_tailing_percent``
        (mean over samples of the tailed share of the sample's total count for
        this feature, in percent), ``mean_tailing_rpm`` and a constant
        ``feature`` column holding ``my_feature``.

    Raises
    ------
    ValueError
        If ``files`` is empty.
    KeyError
        If a sample name derived from a file is missing from the condition files.
    """
    files = list(files)
    if not files:
        # An empty glob result is almost always a wrong path; fail loudly
        # instead of producing an empty table.
        raise ValueError("process_tailing: no input files were given")

    # sample name -> condition
    condition_dict = {}
    for condition_file in conditions:
        with open(condition_file, 'r') as handle:
            for line in handle:
                # skip the header and blank lines
                if line.startswith("simple_name") or not line.strip():
                    continue
                info = line.strip().split(",")
                condition_dict[info[0]] = info[1]

    # positional layout of the read tables
    index_dict = {
        'gene': 0,
        'start': 1,
        'end': 2,
        'seq': 3,
        'count': 4,
        'strand': 5,
        'feature': 6,
        'rawtail': 7,
        'tail': 8,
        'count_total_norm': 9
    }

    # sample -> summed normalized count of every row of my_feature, tailed or
    # not; used as the denominator of the tailing percentage
    total_level = {}

    # One list per column, filled across all files so the frame is built once
    # instead of being re-concatenated for every file.
    genes, seqs, tails, counts, features, samples, row_conditions = [], [], [], [], [], [], []

    for file in files:
        name = os.path.basename(file).split(".")[0]
        condition = condition_dict[name]
        logger.info(f"Processing {name}...")

        with open(file, 'r') as handle:
            for line in handle:
                # Skip the header and blank lines.
                # NOTE(review): any data row whose first column starts with
                # "gene" is skipped as well — confirm that cannot occur.
                if line.startswith("gene") or not line.strip():
                    continue

                info = line.strip().split("\t")
                gene = info[index_dict['gene']]
                seq = info[index_dict['seq']]
                tail = info[index_dict['rawtail']]
                count = float(info[index_dict['count_total_norm']])
                feature = info[index_dict['feature']]

                if feature != my_feature:
                    continue

                total_level[name] = total_level.get(name, 0.0) + count

                # "*" in the rawtail column marks a row without a tail.
                if tail == "*":
                    continue

                genes.append(gene)
                seqs.append(seq)
                tails.append(tail)
                counts.append(count)
                features.append(feature)
                samples.append(name)
                row_conditions.append(condition)

    df_pass = pd.DataFrame({
        'gene': genes,
        'seq': seqs,
        'tail': tails,
        # explicit dtype so the column stays numeric even when no row was kept
        'count': pd.Series(counts, dtype='float64'),
        'feature': features,
        'sample': samples,
        'condition': row_conditions
    })

    df_pass['tail_group'] = df_pass['tail'].map(assign_tails)

    df_grouped = df_pass.groupby(
        ['sample', 'condition', 'tail_group']
    ).agg(
        tailed_count=('count', 'sum')
    ).reset_index()

    # Every sample present here has an entry in total_level, because rows are
    # only collected after the sample's total has been updated.
    df_grouped['total_count'] = df_grouped['sample'].map(total_level).astype('float64')
    df_grouped['tailing_percent'] = 100 * (df_grouped['tailed_count'] / df_grouped['total_count'])

    # NOTE(review): despite its name, mean_tailing_rpm is the SUM of
    # tailed_count over the samples of a condition, not their mean.
    df_grouped_condition = df_grouped.groupby(
        ['condition', 'tail_group']
    ).agg(
        mean_tailing_percent=('tailing_percent', 'mean'),
        mean_tailing_rpm=('tailed_count', 'sum')
    ).reset_index()

    df_grouped_condition['feature'] = my_feature

    return df_grouped_condition
        
    
    

In [None]:
# glob returns paths in arbitrary, filesystem-dependent order; sorting them
# keeps the order in which files are read (and counts are summed) the same
# between runs.
miRNA_tailing = process_tailing(
    files = sorted(glob.glob("./smRNA_seq_mutants_remove_rRNA/20250115/tailor_transcripts/files/*.bed.tsv")),
    conditions = ["./smRNA_seq_mutants_remove_rRNA/20250115/samples/replicates.csv"],
    my_feature = "pre_miRNA_ref",
)
# Write the per-condition summary as a tab-separated file without the index.
miRNA_tailing.to_csv("./miRNA_tailing.tsv", sep = "\t", header = True, index = False)