In [1]:
import pandas as pd
import argparse
from ast import literal_eval
import numpy as np
import itertools
from io import BytesIO
import tqdm
import dask.dataframe as dd
from dask import delayed
import matplotlib.pyplot as plt
import dask

In [2]:
import torch

# Pick the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)


cpu


In [2]:
# Load the clustered junction table.
# Its Start coordinates are (start - 1) compared to the original GTEx matrix.
clusts = pd.read_hdf("/gpfs/commons/groups/knowles_lab/Karin/data/GTEx/clustered_junctions_minjunccounts.h5", key='df')

# Build a Name column in the GTEx junction ID format "chr<Chromosome>_<Start+1>_<End>",
# adding 1 to Start to undo the coordinate shift noted above.
clusts["Name"] = (
    "chr"
    + clusts["Chromosome"].astype(str)
    + "_"
    + clusts["Start"].add(1).astype(str)
    + "_"
    + clusts["End"].astype(str)
)

In [3]:
# Preview the first rows, including the newly built Name column.
clusts.head()

Unnamed: 0,Chromosome,Start,End,Strand,gene_id,junction_id,gene_name,Cluster,Count,Name
0,1,169795213,169798918,+,ENSG00000000460,1_169795214_169798918,C1orf112,1,1,chr1_169795214_169798918
1,1,169806088,169807790,+,ENSG00000000460,1_169806089_169807790,C1orf112,2,1,chr1_169806089_169807790
2,1,169807929,169821678,+,ENSG00000000460,1_169807930_169821678,C1orf112,3,1,chr1_169807930_169821678
3,1,169821759,169823407,+,ENSG00000000460,1_169821760_169823407,C1orf112,4,1,chr1_169821760_169823407
4,1,169823472,169827050,+,ENSG00000000460,1_169823473_169827050,C1orf112,5,1,chr1_169823473_169827050


In [None]:
# Drop singleton clusters (rows whose Count is 1), then report how many
# distinct junction names are left.
clusts = clusts.loc[clusts["Count"].gt(1)]
clusts["Name"].nunique(dropna=False)

In [None]:
# Order rows by Count, largest first.
clusts = clusts.sort_values("Count", ascending=False)

# Keep only clusters with at most 10 junctions (Count <= 10), then report how
# many distinct junction names are left.
clusts = clusts.loc[clusts["Count"].le(10)]
clusts["Name"].nunique(dropna=False)

In [None]:
# Total counts per junction: read the tab-separated table, name its two
# columns, and sort from the highest to the lowest count.
junc_counts = pd.read_table("/gpfs/commons/groups/knowles_lab/Karin/data/GTEx/GTEx_juncs_total_counts.txt")
junc_counts.columns = ["Name", "Junc_Counts"]
junc_counts = junc_counts.sort_values("Junc_Counts", ascending=False)

In [None]:
# GTEx sample annotations: keep only the sample ID (SAMPID), the tissue type
# (SMTS) and the SMTSD column, with exact duplicate rows removed.
samples = pd.read_table("/gpfs/commons/groups/knowles_lab/Karin/data/GTEx/GTEx_Analysis_v8_Annotations_SampleAttributesDS.txt")
samples = samples.loc[:, ["SAMPID", "SMTS", "SMTSD"]].drop_duplicates()
samples.head()

In [None]:
# Goal: for each tissue type in the SMTS column, build a dataframe that lists every sample ID of that tissue together with its junctions and their counts.

In [None]:
# One row per unique (junction name, cluster, gene) combination, re-indexed
# from 0 so the index no longer reflects the earlier sort/filter steps.
clusts_simple = (
    clusts.loc[:, ["Name", "Cluster", "gene_name"]]
    .drop_duplicates()
    .reset_index(drop=True)
)
clusts_simple.head()

In [None]:
# Number of distinct cluster IDs.
clusts_simple["Cluster"].nunique(dropna=False)

In [None]:
# Number of distinct junction names.
clusts_simple["Name"].nunique(dropna=False)

In [None]:
# Subsample 50 Cluster IDs for a test run and keep EVERY junction that belongs
# to a sampled cluster. Sampling rows (junctions) directly would split clusters
# apart, so per-cluster totals computed downstream would be based on
# incomplete clusters.
clusts_sample = clusts_simple[
    clusts_simple["Cluster"].isin(
        clusts_simple["Cluster"].drop_duplicates().sample(n=50, random_state=1)
    )
]
# Sanity check: number of sampled clusters, then number of junctions they contain.
print(len(clusts_sample.Cluster.unique()))
print(len(clusts_sample.Name.unique()))

In [None]:
# Preview the subsampled junction/cluster table.
clusts_sample.head()

In [None]:
# Scratch arithmetic: 10% of 300000.
# NOTE(review): presumably a candidate subsample size — confirm, or delete this cell.
0.1*300000

In [None]:
import dask.dataframe as dd
# Path to the GTEx v8 junction count matrix (.gct) that MeltedJunctions reads.
gtex_juncs = '/gpfs/commons/groups/knowles_lab/Karin/data/GTEx/GTEx_Analysis_2017-06-05_v8_STARv2.5.3a_junctions.gct'

class MeltedJunctions:
    """Split a wide junction count matrix by tissue and melt it to long format.

    Parameters
    ----------
    file_name : str
        Path to a tab-separated junction count matrix whose header row holds
        ``Name``, ``Description`` and then one column per sample ID. A GCT
        preamble (a ``#1.x`` version line followed by a dimensions line) is
        detected and skipped automatically.
    clusts_names : iterable of str
        Junction names (values of the ``Name`` column) to keep.
    clusts : pandas.DataFrame
        Junction-to-cluster table; must contain the columns ``Name``,
        ``Cluster`` and ``gene_name``. Any other columns are ignored.
    samples : pandas.DataFrame
        Sample annotations; must contain ``SAMPID`` (sample ID) and ``SMTS``
        (tissue used for grouping).
    """

    # Identifier columns that precede the per-sample count columns.
    _ID_COLUMNS = ["Name", "Description"]

    def __init__(self, file_name, clusts_names, clusts, samples):
        self.file_name = file_name
        self.clusts_names = clusts_names
        self.clusts = clusts
        self.samples = samples

    def _read_header(self):
        """Return ``(column_names, n_preamble_lines)`` for the count matrix file."""
        with open(self.file_name) as handle:
            first_line = handle.readline()
            if first_line.startswith("#1."):
                # GCT layout: version line, dimensions line, then the real header.
                handle.readline()
                header_line = handle.readline()
                n_preamble = 2
            else:
                header_line = first_line
                n_preamble = 0
        return header_line.strip().split("\t"), n_preamble

    def melt_junctions(self):
        """Build one long-format dataframe per tissue.

        Returns
        -------
        list of pandas.DataFrame
            One frame per tissue (in sorted ``SMTS`` order) with the columns
            ``Name, Description, Tissue, gene_name, Cluster, Sample, Count``.
            Only junctions listed in ``clusts_names`` and only rows with
            ``Count > 0`` are kept.

        Raises
        ------
        ValueError
            If the header row lacks the ``Name`` or ``Description`` column.
        """
        header, n_preamble = self._read_header()
        missing = [col for col in self._ID_COLUMNS if col not in header]
        if missing:
            raise ValueError(
                "Header of %s is missing expected column(s): %s"
                % (self.file_name, ", ".join(missing))
            )
        # The identifier columns are not samples, so leave them out of the count.
        print("Number of samples in the file: ", len(header) - len(self._ID_COLUMNS))

        # Lazily read the matrix, skipping the GCT preamble when there is one.
        dask_df = dd.read_csv(self.file_name, sample=1000000, sep="\t", skiprows=n_preamble)

        # Keep only annotated samples that actually occur in the matrix, grouped by tissue.
        samples_df = self.samples
        samples_df = samples_df[samples_df['SAMPID'].isin(header)]
        grouped_samples = samples_df.groupby('SMTS')['SAMPID'].apply(list)

        # Loop-invariant inputs, prepared once.
        wanted_names = list(self.clusts_names)
        # Restrict to the columns used below so that extra columns in `clusts`
        # can never be melted as if they were samples.
        cluster_info = self.clusts[["Name", "Cluster", "gene_name"]].drop_duplicates()

        melted_dfs = []
        print("Iterating over tissues...")
        for tissue, tissue_samples in grouped_samples.items():
            print("Processing tissue: ", tissue)
            # Guard against a sample ID being listed twice for the same tissue.
            tissue_samples = list(dict.fromkeys(tissue_samples))
            print("Number of samples in the current tissue: ", str(len(tissue_samples)))

            # Select this tissue's columns by label and keep only the requested junctions.
            tissue_ddf = dask_df[self._ID_COLUMNS + tissue_samples]
            tissue_ddf = tissue_ddf[tissue_ddf['Name'].isin(wanted_names)]
            # NOTE: compute() runs inside the loop, so the dask graph built on
            # the file is evaluated once per tissue.
            tissue_df = tissue_ddf.compute().assign(Tissue=tissue)

            # Attach cluster ID and gene name to every junction.
            tissue_df = tissue_df.merge(cluster_info, on="Name", how="left")
            # Wide -> long: one row per (junction, sample).
            tissue_df = tissue_df.melt(
                id_vars=['Name', 'Description', 'Tissue', 'gene_name', 'Cluster'],
                value_vars=tissue_samples,
                var_name='Sample',
                value_name='Count',
            )
            # Remove rows with zero counts.
            tissue_df = tissue_df[tissue_df['Count'] > 0]
            melted_dfs.append(tissue_df)

        print("Done: returning one melted dataframe per tissue.")
        return melted_dfs

In [None]:
# Set up the melter on the GTEx junction matrix, restricted to the junctions
# of the subsampled clusters.
melted_junctions = MeltedJunctions(
    file_name=gtex_juncs,
    clusts_names=clusts_sample["Name"],
    clusts=clusts_sample,
    samples=samples,
)

In [None]:
# Run the per-tissue melt. Despite its singular name, melted_df is a list of
# dataframes (one per tissue), not a single dataframe.
melted_df = melted_junctions.melt_junctions()

In [None]:
# Take the first 15 rows of the first tissue's table as a small test frame.
test = melted_df[0].iloc[:15]

In [None]:
# Display the 15-row test frame.
test

In [None]:
# Sum the junction counts of each cluster within each sample.
cluster_counts = (
    test.groupby(["Sample", "Cluster"])["Count"]
    .sum()
    .reset_index(name="Cluster_Counts")
)
cluster_counts

In [None]:
# Attach the per-(sample, cluster) totals to every junction row of the test frame.
pd.merge(test, cluster_counts, how="left", on=["Sample", "Cluster"])

In [None]:
# Build the table that is saved below and used as input for the LDA script.
# Without this cell `summarized_data` is never defined and the following cells
# raise NameError on a fresh kernel.
# NOTE(review): the earlier commented-out draft referred to `junc_count` and
# `SAMPID`; the melted frames call these `Count` and `Sample`, which are used
# here — confirm the column names the LDA script expects.

# Stack the per-tissue long tables into one frame.
summarized_data = pd.concat(melted_df, ignore_index=True)
# Total counts of each cluster within each sample, broadcast back to every row.
summarized_data["Cluster_Counts"] = summarized_data.groupby(["Sample", "Cluster"])["Count"].transform("sum")
# Fraction of the cluster's counts in that sample contributed by this junction.
summarized_data["junc_ratio"] = summarized_data["Count"] / summarized_data["Cluster_Counts"]
# Dense integer indices for samples and junctions.
summarized_data["sample_id_index"] = summarized_data.groupby("Sample").ngroup()
summarized_data["junction_id_index"] = summarized_data.groupby("Name").ngroup()

In [None]:
# Preview the table before it is written to HDF5.
summarized_data.head()

In [None]:
# Write the summarized table to HDF5 (overwriting any existing file) in the
# "table" format.
summarized_data.to_hdf(
    path_or_buf="/gpfs/commons/groups/knowles_lab/Karin/data/GTEx/GTEx_junction_cluster_counts" + ".h5",
    key='df',
    mode='w',
    format="table",
)