# Code to process VCF files and make the epistasis plot

In [None]:
import pandas as pd
import vcf
import numpy as np
import os
import seaborn as sns
import matplotlib.pyplot as plt
import statsmodels.formula.api as smf
from adjustText import adjust_text
from collections import Counter
import statsmodels.api as sm
import forestplot as fp
from random import sample

In [None]:
def count_comments(filename: str):
    """
    Count the ``##`` meta-information lines at the top of a VCF.

    Counting stops at the first line that does not start with ``##``. Paths
    ending in ``.gz`` are read through :mod:`gzip`; all others as plain text.

    Args:
        filename(:obj:`filepath`): file path for VCF (``str`` or path-like)
    Raises:
        :obj:`ValueError`: if file provided does not have comments

    Returns:
        :obj:`int`: number of comments in VCF
    """
    # gzip is not imported at the top of the notebook, so bring it in here.
    import gzip

    path = os.fspath(filename)
    # One loop for both formats: gzip is read as bytes, plain files as text,
    # so the prefix we compare against must match the line type.
    if path.endswith('.gz'):
        handle, marker = gzip.open(path, 'rb'), b'##'
    else:
        handle, marker = open(path, 'r'), '##'

    comments = 0
    with handle as fh:
        for line in fh:
            if not line.startswith(marker):
                break
            comments += 1
    if comments == 0:
        raise ValueError('VCF files must have comments')
    return comments


#########

# DataFrame manipulation method

def get_vcf_sample_ids(vcf_file_path, standard_columns):
    """
    Return the sample IDs listed in a VCF's column-header line.

    Parameters:
    - vcf_file_path: Path to an uncompressed VCF file.
    - standard_columns: Header names to exclude; every other header column is
      returned as a sample ID.

    Returns:
    - List of sample IDs in header order.

    Raises:
    - ValueError: if the file contains no line after its ``##`` lines.
    """
    excluded = set(standard_columns)
    # Previously this opened the global `vcf_path` instead of the argument.
    with open(vcf_file_path, 'r') as file:
        # Skip however many '##' meta lines there are instead of assuming six;
        # the first other line is the tab-separated column header.
        for line in file:
            if not line.startswith('##'):
                header = line.strip().split('\t')
                break
        else:
            raise ValueError(f'No header line found in {vcf_file_path}')

    return [column for column in header if column not in excluded]


def iterate_chunks(input_list, chunk_size):
    """
    Yield consecutive slices of ``input_list``, each ``chunk_size`` items long.

    The last slice holds whatever remains and may be shorter.

    Parameters:
    - input_list: Any sliceable sequence (list, tuple, str, ...).
    - chunk_size: Number of items per slice.

    Returns:
    - A generator producing slices of ``input_list``.
    """
    starts = range(0, len(input_list), chunk_size)
    yield from (input_list[start:start + chunk_size] for start in starts)
        

def pivot_vcf(vcf_file_path, standard_columns, chunk_size):
    """
    Load the sample columns of a VCF as a sample-by-variant DataFrame.

    Sample columns are read ``chunk_size`` at a time (one pass over the file
    per chunk) and transposed so that each sample becomes a row.

    Parameters:
    - vcf_file_path: Path to an uncompressed VCF file.
    - standard_columns: Header names that are not samples.
    - chunk_size: Number of sample columns read per pass.

    Returns:
    - DataFrame with one row per sample: a leading ``person_id`` column holding
      the sample ID, then one column per variant row, labelled by its 0-based
      position among the data rows. Empty (only ``person_id``) if the file has
      no sample columns.
    """
    # tqdm only drives a progress bar and was referenced without an import;
    # load it lazily and fall back to a plain iterator if it is unavailable.
    try:
        from tqdm import tqdm as progress
    except ImportError:
        def progress(iterable, **_kwargs):
            return iterable

    sample_ids = get_vcf_sample_ids(vcf_file_path, standard_columns)
    comments = count_comments(vcf_file_path)

    # Determine the total number of chunks (ceiling division) for the progress bar
    total_chunks = (len(sample_ids) + chunk_size - 1) // chunk_size

    if not sample_ids:
        # pd.concat([]) would raise "No objects to concatenate"
        return pd.DataFrame(columns=['person_id'])

    chunk_frames = []
    for chunk in progress(iterate_chunks(sample_ids, chunk_size),
                          total=total_chunks, desc='Processing chunks'):
        # skiprows drops the '##' lines, so the next line is used as the header
        frame = pd.read_table(vcf_file_path, skiprows=comments, usecols=chunk)
        chunk_frames.append(frame.T)

    vcf_frame = pd.concat(chunk_frames)
    vcf_frame.reset_index(inplace=True)
    vcf_frame.rename(columns={'index': 'person_id'}, inplace=True)

    return vcf_frame


#########

# String manipulation method

def read_lines(filename):
    """
    Return the lines of ``filename`` that are neither blank nor ``##`` meta lines.

    Lines keep their trailing newline and are returned in file order.
    """
    with open(filename, 'r') as handle:
        return [row for row in handle if row.strip() and not row.startswith('##')]


def lines_to_dataframe(lines):
    """
    Build a DataFrame with one row per line and one column per tab-separated field.

    Leading/trailing whitespace (including the newline) is stripped from each
    line before splitting; columns get the default integer labels.
    """
    def _fields(text):
        return text.strip().split('\t')

    return pd.DataFrame(list(map(_fields, lines)))


def drop_rows_by_column_value(df, column_name, values_to_drop):
    """
    Return the rows of ``df`` whose ``column_name`` value is not in ``values_to_drop``.

    Original index labels are kept; ``df`` itself is not modified.
    """
    should_drop = df[column_name].isin(values_to_drop)
    return df.loc[~should_drop]


def clean_frame(df):
    """
    Promote the first row of ``df`` to column labels and drop that row.

    The first column is always labelled ``participant_id``; the remaining
    labels are taken from the first row's values. ``df`` is not modified.

    Returns:
        A new DataFrame without the header row and with a fresh 0..n-1 index.
    """
    header = df.iloc[0, 1:].to_list()
    # Drop the header row by position and work on a copy. The old version
    # overwrote the caller's columns in place, and drop(df.index[0]) removed
    # every row sharing the first index label.
    body = df.iloc[1:].copy()
    body.columns = ['participant_id'] + header

    return body.reset_index(drop=True)


# Alternate-allele dosage for biallelic GT calls, unphased ('/') and phased ('|').
_GENOTYPE_DOSAGE = {
    '0/0': 0, '0|0': 0,
    '0/1': 1, '1/0': 1, '0|1': 1, '1|0': 1,
    '1/1': 2, '1|1': 2,
}


def _genotype_to_dosage(value):
    """Map one genotype field to 0/1/2, or NaN when it is missing or unrecognised."""
    if not isinstance(value, str):
        return np.nan
    # The VCF spec puts GT first in FORMAT, so ignore any ':'-separated extras.
    return _GENOTYPE_DOSAGE.get(value.split(':', 1)[0], np.nan)


def vcf_to_df(vcf_file_path, standard_columns):
    """
    Parse a VCF into a sample-by-variant table of alternate-allele dosages.

    Parameters:
    - vcf_file_path: Path to an uncompressed VCF file.
    - standard_columns: Currently unused. The 9 fixed columns are removed by
      position; the parameter is kept for interface compatibility.

    Returns:
    - DataFrame indexed by sample ID, with one column per variant labelled by
      the VCF ``ID`` field. Values are 0/1/2 for 0/0, het and 1/1 calls
      (phased or unphased), and NaN for anything else.
    """
    rows = read_lines(vcf_file_path)
    temp_df = lines_to_dataframe(rows)

    # The first remaining line is the column header; use it as column labels
    temp_df.columns = temp_df.iloc[0]
    temp_df = temp_df[1:]
    final_df = temp_df.T

    # After transposing, the first 9 rows are the fixed columns (#CHROM..FORMAT)
    # and the remaining rows are samples.
    final_df = final_df[9:]
    final_df.columns = temp_df['ID']

    # DataFrame.applymap is deprecated since pandas 2.1 in favour of DataFrame.map.
    if hasattr(pd.DataFrame, 'map'):
        return final_df.map(_genotype_to_dosage)
    return final_df.applymap(_genotype_to_dosage)

def get_dataframe_vcf(vcf_path):
    """
    Load a VCF as a person-by-variant dosage table.

    Sample IDs are cut at the first '_' and converted to int. They are returned
    in a leading column labelled ``0`` (the label existing callers index by),
    followed by one dosage column per variant.
    """
    standard_columns = ['#CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO', 'FORMAT']
    variants = vcf_to_df(vcf_path, [x for x in standard_columns if x != 'ID'])

    person_ids = variants.index.str.split('_').str[0]
    # reset_index names the new column after the index. Set that name to 0
    # explicitly instead of relying on whatever name vcf_to_df's index carries.
    variants.index = pd.Index(person_ids, name=0)
    variants = variants.reset_index()
    variants[0] = variants[0].astype(int)
    return variants

In [None]:
def _summarize_by_genotype(df, snp1, snp2, phenotype):
    """Per (snp1, snp2) pair: mean, std and count of ``phenotype``, plus its SE and 95% CI half-width."""
    grouped_df = df.groupby([snp1, snp2]).agg({
        phenotype: ['mean', 'std', 'count']
    }).reset_index()
    grouped_df.columns = [snp1, snp2, f'{phenotype}_mean', f'{phenotype}_std', 'count']
    grouped_df['se'] = grouped_df[f'{phenotype}_std'] / np.sqrt(grouped_df['count'])
    # Normal-approximation 95% confidence half-width
    grouped_df['ci'] = 1.96 * grouped_df['se']
    return grouped_df


def _plot_snp_interaction(grouped_df, snp1, snp2, phenotype, gene1, gene2, filename):
    """Plot group means (95% CI bars, group sizes) against snp2, one series per snp1 genotype."""
    mean_col = f'{phenotype}_mean'
    # Series and legend share one ordered list of snp1 levels taken from the
    # grouped data, so colours and labels always match. Levels from the raw df
    # could include NaN and could differ from the plotted groups.
    snp1_levels = sorted(grouped_df[snp1].unique())
    # Palette sized to the data so more than three snp1 levels cannot raise IndexError
    color_palette = sns.color_palette("colorblind", n_colors=max(3, len(snp1_levels)))

    _, ax = plt.subplots(figsize=(8, 6))

    for color, genotype in zip(color_palette, snp1_levels):
        genotype_data = grouped_df[grouped_df[snp1] == genotype]
        ax.scatter(genotype_data[snp2], genotype_data[mean_col],
                   color=color, edgecolor='black', s=60, zorder=3)
        ax.plot(genotype_data[snp2], genotype_data[mean_col],
                linestyle='--', color=color, alpha=0.5, zorder=2)

    # Error bars and sample-size labels for every (snp1, snp2) group
    for _, row in grouped_df.iterrows():
        ax.errorbar(row[snp2], row[mean_col], yerr=row['ci'],
                    fmt='none', c='black', capsize=5, capthick=1, elinewidth=1, zorder=1)
        ax.annotate(f"{int(row['count'])}", (row[snp2], row[mean_col]),
                    xytext=(5, 5), textcoords='offset points', fontsize=8)

    ax.set_xticks([0, 1, 2])
    ax.set_xticklabels(['0', '1', '2'])
    ax.set_xlim(-0.5, 2.5)

    ax.set_xlabel(f'{gene2} {snp2}', fontweight='bold')
    ax.set_ylabel(f'{phenotype.capitalize()}', fontweight='bold')

    legend_elements = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color,
                                  markeredgecolor='black', markersize=8, label=genotype)
                       for color, genotype in zip(color_palette, snp1_levels)]
    ax.legend(handles=legend_elements, title=f'{gene1} {snp1}',
              loc='best', bbox_to_anchor=(1, 1))

    plt.tight_layout()
    if filename:
        plt.savefig(filename, format='svg', dpi=300, bbox_inches='tight')
    plt.show()


def _fit_snp_interaction_model(df, snp1, snp2, phenotype, covariates):
    """Fit OLS ``phenotype ~ covariates + snp1 + snp2 + snp1:snp2``; return the coefficient table sorted by |t|."""
    terms = list(covariates or []) + [snp1, snp2, f'{snp1}:{snp2}']
    formula = f'{phenotype} ~ {" + ".join(terms)}'
    model = smf.ols(formula=formula, data=df).fit()

    conf_int = model.conf_int()
    summary_table = pd.DataFrame({
        'Coefficient': model.params,
        'Std Error': model.bse,
        't-value': model.tvalues,
        'p-value': model.pvalues,
        'Lower CI': conf_int[0],
        'Upper CI': conf_int[1]
    })
    by_abs_t = summary_table['t-value'].abs().sort_values(ascending=False).index
    return summary_table.reindex(by_abs_t)


def analyze_snp_interaction(df, snp1, snp2, phenotype, gene1, gene2, filename=None, covariates=None):
    """
    Plot and test the interaction of two SNPs on a phenotype.

    Draws the mean ``phenotype`` for each ``snp2`` genotype, with 95% CI error
    bars and group sizes, as one coloured series per ``snp1`` genotype. Then fits
    an OLS model with any covariates, both SNPs and their interaction term.

    Parameters:
    - df: DataFrame holding the ``snp1``, ``snp2``, ``phenotype`` and covariate columns.
    - snp1, snp2: Genotype column names. The x axis (snp2) is drawn with ticks 0, 1, 2.
    - phenotype: Outcome column name.
    - gene1, gene2: Gene names used in the legend title and the x-axis label.
    - filename: If given, the figure is also saved to this path in SVG format.
    - covariates: Optional iterable of extra column names added to the formula.

    Returns:
    - DataFrame with coefficient, standard error, t-value, p-value and CI bounds
      per model term, sorted by absolute t-value (largest first).

    Note: the matplotlib style and rcParams set here stay in effect after the call.
    """
    plt.style.use('seaborn-v0_8-white')
    plt.rcParams.update({
        'font.size': 10,
        'axes.labelsize': 12,
        'axes.titlesize': 14,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 10,
        'figure.dpi': 300,
    })

    grouped_df = _summarize_by_genotype(df, snp1, snp2, phenotype)
    _plot_snp_interaction(grouped_df, snp1, snp2, phenotype, gene1, gene2, filename)
    return _fit_snp_interaction_model(df, snp1, snp2, phenotype, covariates)

# Forest Plot in R

In [None]:
library(data.table)
library(forestplot)
library(grid)

# Read the per-genotype results. Columns used below: label, n, r (plotted as
# the odds ratio), ll / hl (CI bounds) and `p-val`.
df_results <- fread('hfe_cirrhosis_dataframe.tsv')

# Format p-values, showing very small ones as "<0.001"
df_results$formatted_pval <- ifelse(df_results$`p-val` < 0.001, "<0.001",
                                    sprintf("%.3f", df_results$`p-val`))

# Estimate with its CI as text, e.g. "1.23 (0.90-1.70)"
df_results$est_ci <- sprintf("%.2f (%.2f-%.2f)", df_results$r, df_results$ll, df_results$hl)

# Table rows for forestplot. Row 1 holds the column headers. Header rows have
# NA estimates and are flagged as summary rows.
tabletext <- list(c("Genotype", "N", "OR (95% CI)", "P-value"))
mean <- NA_real_
lower <- NA_real_
upper <- NA_real_
is_summary <- TRUE
is_shaded <- FALSE
group_header_rows <- integer()

for (cpv in c("CPV 0", "CPV 1", "CPV 2")) {
  # fixed = TRUE matches the group name literally rather than as a regex
  subset <- df_results[grep(cpv, label, fixed = TRUE)]

  # Group header row; record its position for the separator lines below
  tabletext <- c(tabletext, list(c(cpv, rep("", 3))))
  group_header_rows <- c(group_header_rows, length(tabletext))
  mean <- c(mean, NA)
  lower <- c(lower, NA)
  upper <- c(upper, NA)
  is_summary <- c(is_summary, TRUE)
  is_shaded <- c(is_shaded, FALSE)

  # Data rows. seq_len() yields no iterations for an empty group, whereas
  # 1:nrow() would yield c(1, 0) and append rows of NA.
  for (i in seq_len(nrow(subset))) {
    tabletext <- c(tabletext, list(c(
      sub("CPV \\d, ", "", subset$label[i]),
      as.character(subset$n[i]),
      subset$est_ci[i],
      subset$formatted_pval[i]
    )))
    mean <- c(mean, subset$r[i])
    lower <- c(lower, subset$ll[i])
    upper <- c(upper, subset$hl[i])
    is_summary <- c(is_summary, FALSE)
    is_shaded <- c(is_shaded, i %% 2 == 0)  # Alternate shading
  }
}

# Convert tabletext to a character matrix (one row per table row)
tabletext <- do.call(rbind, tabletext)

# Separator lines above the column headers and above each group header.
# Positions come from the actual group sizes rather than hard-coded row numbers.
separator_gp <- gpar(lwd = 1, col = "#444444")
separator_lines <- setNames(
  rep(list(separator_gp), length(group_header_rows) + 1),
  as.character(c(1, group_header_rows))
)

# Create the forest plot
fp <- forestplot(
  tabletext,
  mean = mean,
  lower = lower,
  upper = upper,
  is.summary = is_summary,
  zero = 1,
  boxsize = 0.1,
  lineheight = unit(10, "mm"),
  colgap = unit(4, "mm"),
  graphwidth = unit(120, "mm"),
  xlab = "Odds Ratio",
  xticks = c(0, 1, 2, 3, 4, 5, 6, 7),
  clip = c(0.1, 7),
  col = fpColors(box = "black",
                 lines = "black",
                 zero = "gray"),
  txt_gp = fpTxtGp(label = gpar(cex = 0.9),
                   ticks = gpar(cex = 0.8),
                   xlab = gpar(cex = 1),
                   title = gpar(cex = 1.2)),
  grid = structure(c(1), gp = gpar(lty = 2, col = "#CCCCFF")),
  graph.pos = 4,
  hrzl_lines = separator_lines,
  title = "Forest Plot of HFE Cirrhosis Odds Ratios",
  align = c("l", "c", "c", "c"),
  cex = 0.9,
  shaded = is_shaded
)

# Save the plot as an SVG
svg("forest_plot_hfe_cirrhosis_grouped.svg", width = 15, height = 12)
plot(fp)
dev.off()