# Utils

> Basic helper functions: log2 fold-change calculation from count tables and DNA sequence utilities

In [None]:
#| default_exp utils

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import re
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import pearsonr,spearmanr
import os
from collections import Counter, defaultdict
import re
from random import choice
from itertools import groupby, chain
from operator import itemgetter
import os

In [None]:
#| hide
from randseq.example_data import get_example_data_dir

In [None]:
#| export
def calculate_log2fc(df, reference_column='MFDpir', count_threshold=20, pseudocount=1):
    """
    Compute log2 fold changes of every sample column relative to a reference column.

    Processing steps, in order:
      1. Keep only the rows whose value in ``reference_column`` is strictly
         greater than ``count_threshold``.
      2. Add ``pseudocount`` to every remaining value (all columns, including
         the reference).
      3. Divide each column by its own sum.
      4. Divide every column row-wise by the normalized reference column and
         take log2.
      5. Drop the reference column, whose fold change against itself carries
         no information.

    Args:
        df (pd.DataFrame): Input DataFrame (rows=features, columns=samples).
        reference_column (str): Name of the reference sample column.
        count_threshold (int): A feature is kept only if its count in the
            reference column is strictly greater than this value.
        pseudocount (float): Value added to the filtered counts *before*
            the per-column normalization.

    Returns:
        pd.DataFrame: DataFrame with log2 fold change values, one column per
                      non-reference sample. None if reference column is
                      missing. Empty DataFrame if all rows are filtered out.
    """
    if reference_column not in df.columns:
        print(f"Error: Reference column '{reference_column}' not found in DataFrame.")
        return None

    # 1. Filter rows on the raw reference counts (boolean indexing already
    #    yields a new object, so no explicit copy is needed).
    kept = df[df[reference_column] > count_threshold]
    if kept.empty:
        print("DataFrame is empty after filtering.")
        return pd.DataFrame()

    # 2. Add the pseudocount before normalizing so that zero counts do not
    #    become log2(0) later on.
    shifted = kept + pseudocount

    # 3. Normalize each column by its sum; axis=1 aligns the per-column sums
    #    with the column labels. A column sum of zero would yield inf/NaN here.
    normalized = shifted.div(shifted.sum(axis=0), axis=1)

    # 4. Row-wise ratio against the reference (axis=0 aligns on the index),
    #    then log2. None of the steps above alters the column labels, so the
    #    reference column is guaranteed to still be present.
    log2fc_df = np.log2(normalized.div(normalized[reference_column], axis=0))

    # 5. Drop the reference column's fold change against itself.
    return log2fc_df.drop(columns=[reference_column])

In [None]:
#| hide
# Resolve the example-data directory and list what it contains.
# NOTE(review): get_example_data_dir is imported from randseq.example_data;
# presumably it returns a filesystem path to the bundled data — confirm.
data_path=get_example_data_dir()
os.listdir(data_path)

['__init__.py', 'countsTable.csv', '__pycache__']

In [None]:
# Load the example counts table; the first CSV column is used as the row index.
counts_file="countsTable.csv"
file_path=os.path.join(data_path,counts_file)
counts=pd.read_csv(file_path, index_col=0)
# Keep only the columns whose name contains "_T0" or "MFDpir" (the reference column used below).
counts = counts[[col for col in counts.columns if ("_T0" in col) or ("MFDpir" in col)]]

# Log2 fold change of each kept column against MFDpir, then the per-column minimum.
log2fc_df = calculate_log2fc(counts, reference_column='MFDpir', count_threshold=20, pseudocount=1)
log2fc_df.min(axis=0)

K12_T0       -2.844874
HS_T0        -3.320280
E1114_T0     -4.544421
E1167_T0     -3.645600
H120_T0      -3.572726
TA054_T0     -3.319211
TA447_T0     -3.566876
E101_T0      -8.108167
41-1Ti9_T0   -4.255918
TA280_T0     -4.800802
M114_T0      -4.143357
TA249_T0     -4.973009
ROAR8_T0     -4.621302
JJ1886_T0    -9.188692
CFT073_T0    -3.868714
APECO1_T0    -6.525634
UTI89_T0     -8.571068
S88_T0       -8.601767
MG1655_T0    -4.777041
dtype: float64

In [None]:
#| export
def revcomp(seq:str):
    '''Return the reverse complement of a DNA sequence.

    The input is upper-cased first. A/T and G/C are swapped, N stays N, and
    any other character is passed through unchanged.
    '''
    complement = str.maketrans({"A": "T", "T": "A", "G": "C", "C": "G", "N": "N"})
    reversed_upper = seq.upper()[::-1]
    # Complementing is per-character, so it commutes with the reversal above.
    return reversed_upper.translate(complement)

In [None]:
assert revcomp("ATGC")=="GCAT"
assert revcomp("ATGCN")=="NGCAT"
assert revcomp("atgNN")=="NNCAT"

In [None]:
#| export

bases=list("ATGC")
def allseqs(n,seqs=None):
    '''Generate all possible sequences of length n over `bases`.

    seqs is an optional list of seed sequences to extend; it defaults to a
    single empty sequence. Every seed is extended with each base, one position
    at a time, until the first sequence reaches length n. The seeds are
    expected to share the same length.

    If the seeds already have length >= n (including n <= 0 with the default
    seed) they are returned as a new list without being extended, so
    allseqs(0) == [""]. An empty seqs yields an empty list.

    The output order is lexicographic with respect to the order of `bases`,
    e.g. allseqs(2) starts with AA, AT, AG, AC, TA, ...
    '''
    # None instead of a mutable list as the default; copy so that the
    # caller's list is never returned or aliased.
    current = [""] if seqs is None else list(seqs)
    while current and len(current[0]) < n:
        current = [s+b for s in current for b in bases]
    return current

In [None]:
",".join(allseqs(2))

'AA,AT,AG,AC,TA,TT,TG,TC,GA,GT,GG,GC,CA,CT,CG,CC'

In [None]:
#| export
def flatten(l):
    '''Flatten one level of nesting.

    Concatenates the items of every sub-iterable of `l`, in order, into a
    single new list. Deeper nesting levels are left untouched.
    '''
    return list(chain.from_iterable(l))

In [None]:
assert flatten([[1,2],[3,4],[5,6]])==[1, 2, 3, 4, 5, 6]

In [None]:
#| export
def get_all_sites(pattern):
    '''Generate the list of all sites that match the pattern.

    pattern is indexed as (left length, gap length, right length): every
    sequence of length left+right produced by allseqs is split after the
    first `left` characters and `gap` N characters are inserted in between.
    For example (3,6,4) yields 'AAANNNNNNAAAA', 'AAANNNNNNAAAT', ...
    '''
    left_len = pattern[0]
    right_len = pattern[2]
    sites = []
    for core in allseqs(left_len + right_len):
        spacer = "N" * pattern[1]
        sites.append(core[:left_len] + spacer + core[left_len:])
    return sites


In [None]:
get_all_sites((3,6,4))[:5]

['AAANNNNNNAAAA',
 'AAANNNNNNAAAT',
 'AAANNNNNNAAAG',
 'AAANNNNNNAAAC',
 'AAANNNNNNAATA']

In [None]:
#| hide
# NOTE(review): presumably regenerates the library modules from the notebooks' exported cells — confirm against nbdev docs.
import nbdev; nbdev.nbdev_export()