In [None]:
from higashi.Higashi_wrapper import *
import pickle
import numpy as np
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist, squareform

In [None]:
import torch

# Sanity-check the CUDA setup; the Higashi configs below request "gpu_num": 1.
for cuda_fact in (
    torch.cuda.is_available(),
    torch.cuda.device_count(),
    torch.version.cuda,
):
    print(cuda_fact)

In [None]:
# Root of the project tree. Set the MLCB_BASE_PATH environment variable to run
# the notebook elsewhere; the default keeps the original cluster location.
base_path = os.environ.get("MLCB_BASE_PATH", "/home/unix/jiahao/wanglab/jiahao/test/mlcb/")
data_path = os.path.join(base_path, "datasets")      # per-dataset inputs
metadata_path = os.path.join(base_path, "metadata")  # shared reference/metadata files

## snm3C-seq

In [None]:
# Build the list of per-cell snm3C contact files: the shipped list stores
# absolute paths from another location, so keep path components 7 and 8
# (of the '/'-split path) and re-root them under data_path/snm3c.
file_list = os.path.join(metadata_path, "snm3c_filelist.txt")
with open(file_list, "r") as handle:
    raw_paths = handle.read().splitlines()

input_files = []
for raw_path in raw_paths:
    parts = raw_path.split('/')
    input_files.append(f"{data_path}/snm3c/{parts[7]}/{parts[8]}")
input_files

In [None]:
# Persist the re-rooted paths as filelist.txt in the snm3c data directory:
# one path per line, no header row and no index column.
output_file = os.path.join(data_path, "snm3c", 'filelist.txt')
df = pd.DataFrame(input_files)
df.to_csv(output_file, sep='\t', index=False, header=False)

In [None]:
# Check the dataset's label info
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open(os.path.join(data_path, "snm3c", 'label_info.pickle'), 'rb') as f:
    # Load the pickled per-cell label information
    label_info = pickle.load(f)

print(label_info)

In [None]:
# Set the training configuration for the snm3C-only model
config = os.path.join(data_path, "snm3c", "config.JSON")

# Chromosomes used for training; the same set is imputed.
snm3c_chroms = ['chr1', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15',
                'chr17', 'chr18', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9',
                'chr16', 'chr19', 'chr2', 'chr8']

config_info = {
    # Input directory (trailing "/" kept). With input_format 'higashi_v2'
    # Higashi reads the per-cell file list from data_dir/filelist.txt (written
    # above); the single data.txt layout is the 'higashi_v1' format.
    "data_dir": os.path.join(data_path, "snm3c", ""),
    "label_path": os.path.join(data_path, "snm3c", "label_info.pickle"),
    "structured": True,
    "input_format": 'higashi_v2',
    "temp_dir": os.path.join(base_path, "output", "snm3c"),  # where to store model temp files
    "genome_reference_path": os.path.join(metadata_path, "chromInfo.txt"),
    "cytoband_path": os.path.join(metadata_path, "cytoBand.txt"),
    "chrom_list": snm3c_chroms,
    # The contact files have no header row; these are their column names.
    "header_included": False,
    "contact_header" : ["strand1", "chrom1", "pos1", "fragment1", "strand2", "chrom2", "pos2", "fragment2"],
    "resolution": 500000,
    "resolution_cell": 500000,
    "resolution_fh": [500000],
    "embedding_name": "snm3c",
    "minimum_distance": 500000,
    "maximum_distance": -1,
    "local_transfer_range": 0,
    "loss_mode": "zinb",
    "dimensions": 96, # can be adjusted later
    "impute_list": list(snm3c_chroms),
    "neighbor_num": 5,
    "batch_id": "batch_id",
    "cpu_num": 8,
    "gpu_num": 1,
    "embedding_epoch": 30, # this can be adjusted
    "correct_be_impute": True,
}

# save the config file next to the data
import json
with open(config,"w") as f:
    json.dump(config_info, f, indent = 6)

In [None]:
# Initialize the Higashi instance from the config JSON written above
higashi_model = Higashi(config)

# Data processing (only needs to be run once per dataset). Judging by the
# method names: chromosome start/end setup, contact-table extraction from the
# listed files, then matrix construction. Outputs presumably go to temp_dir.
higashi_model.generate_chrom_start_end()
higashi_model.extract_table()
higashi_model.create_matrix()

In [None]:
 # prep the model
higashi_model.prep_model()

# Train the model; the resulting cell embeddings are read back below with
# higashi_model.fetch_cell_embeddings().
higashi_model.train_for_embeddings()

In [None]:
import re

cell_groups = higashi_model.label_info['cell_group']

# Keep the leading run of non-underscore characters of every cell-group label.
pattern = r"^[^_]+"
prefix_regex = re.compile(pattern)
extracted = [prefix_regex.match(label).group() for label in cell_groups]
extracted

In [None]:
# Load the CEMBA mC cell metadata and tag each row with the prefix of its
# CellGroup label (text before the first underscore), the same key extracted
# from the snm3C labels above.
import re
pattern = r"^[^_]+"
mc_meta = pd.read_csv(os.path.join(metadata_path, 'CEMBA.mC.Metadata.csv'))
mc_meta["cellgrp"] = [re.match(pattern, cell_group).group() for cell_group in mc_meta.CellGroup]
mc_meta

In [None]:
# Map each CellGroup prefix to the array of distinct SubClass labels seen for
# it in the mC metadata (groupby(...).unique() yields one array per prefix).
class_to_cellgrp = mc_meta.groupby('cellgrp')['SubClass'].unique().to_dict()
class_to_cellgrp

In [None]:
def classify_cell_group(group):
    """Map a SubClass label to a coarse neuronal category.

    Substring tests are applied in order: "Gaba", then "Glut", then "NN".
    Returns None when none of them occurs in ``group``.
    """
    for marker, category in (
        ("Gaba", "Inhibitory Neuron"),
        ("Glut", "Excitatory Neuron"),
        ("NN", "Non-Neuronal"),
    ):
        if marker in group:
            return category
    return None


# For every cell, take the first SubClass listed for its group prefix.
tmp = [class_to_cellgrp[prefix][0] for prefix in extracted]
cell_group_mapping = {subclass: classify_cell_group(subclass) for subclass in pd.Series(tmp).unique()}

plotting_grps = [cell_group_mapping[subclass] for subclass in tmp]
plotting_grps

In [None]:
# Visualize the snm3C embeddings, colored by the SubClass assigned to each cell (tmp).
cell_embeddings = higashi_model.fetch_cell_embeddings()

from umap import UMAP
palette = sns.color_palette("tab20") + sns.color_palette("Set3", 7)  # 27 colors

# 2-D UMAP of the cell embeddings; random_state makes the layout reproducible.
vec = UMAP(n_components=2, n_neighbors=5, random_state=99).fit_transform(cell_embeddings)

fig, ax1 = plt.subplots(figsize=(8, 9))
sns.scatterplot(x=vec[:, 0], y=vec[:, 1], hue=tmp, ax=ax1, s=5, alpha=0.8, linewidth=0, palette=palette)
ax1.set(xlabel="UMAP 1", ylabel="UMAP 2")
handles, labels = ax1.get_legend_handles_labels()
ax1.legend(
    handles=handles,
    labels=labels,
    bbox_to_anchor=(0.5, -0.2),  # center the legend below the plot
    loc='upper center',          # legend anchor relative to bbox_to_anchor
    ncol=3,                      # number of legend columns
    title="Cell Types"
)

plt.tight_layout()
plt.show()

In [None]:

# Same UMAP layout, colored by coarse neurotransmitter class (plotting_grps).
# UMAP with a fixed random_state is deterministic, so reuse `vec` from the
# previous cell instead of recomputing an identical projection.
assert len(vec) == len(plotting_grps), "vec is stale: re-run the previous cell"
palette = sns.color_palette("Set2")

fig, ax1 = plt.subplots(figsize=(7, 7))
sns.scatterplot(x=vec[:, 0], y=vec[:, 1], hue=plotting_grps, ax=ax1, s=5, alpha=0.8, linewidth=0, palette=palette)
ax1.set(xlabel="UMAP 1", ylabel="UMAP 2")
handles, labels = ax1.get_legend_handles_labels()
ax1.legend(
    handles=handles,
    labels=labels,
    bbox_to_anchor=(0.5, -0.2),  # center the legend below the plot
    loc='upper center',          # legend anchor relative to bbox_to_anchor
    ncol=2,                      # number of legend columns
    title="Cell Types - Neurotransmitters"
)

plt.tight_layout()
plt.show()

## MERFISH

In [None]:
# Per-cell MERFISH distance files: entries of datasets/merfish whose names start with "dist".
merfish_dir_entries = os.listdir(os.path.join(data_path, "merfish"))
contact_files = [name for name in merfish_dir_entries if name.startswith("dist")]
len(contact_files)

In [None]:
# Cell id = second "_"-separated field of the file name with "cell" removed.
cell_ids = []
for fname in contact_files:
    cell_ids.append(fname.split("_")[1].replace('cell', ''))
cell_ids

In [None]:
# Check the dataset's label info
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open(os.path.join(data_path, "merfish", 'label_info.pickle'), 'rb') as f:
    # Load the pickled per-cell label information
    label_info = pickle.load(f)

print(label_info)

In [None]:
# Keep only cells that have an entry in label_info['Cell_ID'].
# A set is built once: `x in ndarray` scans the whole array on every lookup.
# (Assumes Cell_ID is an array/list of ids, not a pandas Series, where `in`
# would test the index -- TODO confirm.)
known_cell_ids = set(label_info['Cell_ID'])
valid_cells = [cid for cid in cell_ids if cid in known_cell_ids]
len(valid_cells)

In [None]:
# Contact files whose cell id survived the label filter (contact_files order kept).
# Set membership avoids an O(n^2) scan of the valid_cells list.
valid_cell_set = set(valid_cells)
valid_files = [f for f in contact_files if f.split("_")[1].replace('cell', '') in valid_cell_set]
len(valid_files)

In [None]:
# Full paths of the retained MERFISH cell files.
merfish_root = f"{data_path}/merfish"
input_files = [f"{merfish_root}/{fname}" for fname in valid_files]
input_files

In [None]:
# Persist the MERFISH paths as filelist.txt in the merfish data directory:
# one path per line, no header row and no index column.
output_file = os.path.join(data_path, "merfish", 'filelist.txt')
df = pd.DataFrame(input_files)
df.to_csv(output_file, sep='\t', index=False, header=False)

In [None]:
# Set the training configuration for the MERFISH-only model
config = os.path.join(data_path, "merfish", "config.JSON")
# config = "/content/drive/Shareddrives/MLCB_project_dataset/Results/Embeddings_snm3C-seq-full/config.JSON"

# Chromosomes used for training; the same set is imputed.
merfish_chroms = ['chr1', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15',
                  'chr17', 'chr18', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9',
                  'chr16', 'chr19', 'chr2', 'chr8']

config_info = {
    # Input directory (trailing "/" kept). With input_format 'higashi_v2'
    # Higashi reads the per-cell file list from data_dir/filelist.txt (written
    # above); the single data.txt layout is the 'higashi_v1' format.
    "data_dir": os.path.join(data_path, "merfish", ""),
    "label_path": os.path.join(data_path, "merfish", "label_info.pickle"),
    "structured": True,
    "input_format": 'higashi_v2',
    "temp_dir": os.path.join(base_path, "output", "merfish"),  # where to store model temp files
    "genome_reference_path": os.path.join(metadata_path, "chromInfo.txt"),
    "cytoband_path": os.path.join(metadata_path, "cytoBand.txt"),
    "chrom_list": merfish_chroms,
    # Single entry: the dict previously also listed "header_included": False,
    # which Python silently overrode with this True value.
    "header_included": True,
    "resolution": 2500000,
    "resolution_cell": 2500000,
    "resolution_fh": [2500000],
    "embedding_name": "merfish",
    "minimum_distance": 2500000,
    "maximum_distance": -1,
    "local_transfer_range": 0,
    "loss_mode": "zinb",
    "dimensions": 96, # can be adjusted later
    "impute_list": list(merfish_chroms),
    "neighbor_num": 5,
    "cpu_num": 8,
    "gpu_num": 1,
    "embedding_epoch": 30, # this can be adjusted
    "correct_be_impute": True,
}

# save the config file next to the data
import json
with open(config,"w") as f:
    json.dump(config_info, f, indent = 6)

In [None]:
# Initialize the Higashi instance from the MERFISH config JSON written above
higashi_model = Higashi(config)

# Data processing (only needs to be run once per dataset). Judging by the
# method names: chromosome start/end setup, contact-table extraction from the
# listed files, then matrix construction.

higashi_model.generate_chrom_start_end()
higashi_model.extract_table()
higashi_model.create_matrix()

In [None]:
 # prep the model
higashi_model.prep_model()

# Train the model; the resulting cell embeddings are read back below with
# higashi_model.fetch_cell_embeddings().
higashi_model.train_for_embeddings()

In [None]:
# Row of the model's label info for each valid cell. First occurrence wins,
# as with np.argwhere(...)[0][0]; one dict pass replaces a full-array scan
# per cell (a missing id raises KeyError).
cell_id_to_row = {}
for row, cid in enumerate(higashi_model.label_info['Cell_ID']):
    cell_id_to_row.setdefault(cid, row)
valid_index = [cell_id_to_row[cid] for cid in valid_cells]
valid_index

In [None]:
# Visualize the MERFISH embeddings, colored by cluster class.
cell_embeddings = higashi_model.fetch_cell_embeddings()

from umap import UMAP
palette = sns.color_palette("tab20") + sns.color_palette("Set3", 7)  # 27 colors

# 2-D UMAP of the cell embeddings; random_state makes the layout reproducible.
vec = UMAP(n_components=2, n_neighbors=5, random_state=99).fit_transform(cell_embeddings)

# label_info lists every annotated cell; keep the rows of the cells given to
# the model (valid_cells order, via valid_index).
cell_groups = higashi_model.label_info['cluster_class'][valid_index]
assert len(cell_groups) == len(vec), "cluster labels and embeddings are misaligned"

fig, ax1 = plt.subplots(figsize=(8, 9))
sns.scatterplot(x=vec[:, 0], y=vec[:, 1], hue=cell_groups, ax=ax1, s=5, alpha=0.8, linewidth=0, palette=palette)
ax1.set(xlabel="UMAP 1", ylabel="UMAP 2")
handles, labels = ax1.get_legend_handles_labels()
ax1.legend(
    handles=handles,
    labels=labels,
    bbox_to_anchor=(0.5, -0.2),  # center the legend below the plot
    loc='upper center',          # legend anchor relative to bbox_to_anchor
    ncol=3,                      # number of legend columns
    title="Cell Types"
)

plt.tight_layout()
plt.show()

## Combined (snm3C-seq + MERFISH)

### Modify the snm3C-seq data

In [None]:
# # create new input file list
# file_list = os.path.join(metadata_path, "snm3c_filelist.txt")
# with open(file_list, "r") as f:
#     input_files = f.read().splitlines()

# input_files = [f"{data_path}/snm3c/{f.split('/')[7]}/{f.split('/')[8]}" for f in input_files]
# input_files

In [None]:
# # create new input file list
# file_list = os.path.join(metadata_path, "snm3c_filelist.txt")
# with open(file_list, "r") as f:
#     input_files = f.read().splitlines()

# input_files = [f"{data_path}/backup/{f.split('/')[8]}" for f in input_files]
# input_files

In [None]:
# # create new snm3c files with header
# common_header =  ["strand1", "chrom1", "pos1", "fragment1", "strand2", "chrom2", "pos2", "fragment2"]
# new_data_path = '/home/unix/jiahao/wanglab/jiahao/test/mlcb/datasets/test/'

# for current_file in tqdm(input_files):
#     df = pd.read_table(current_file, header=None)
#     df.columns = common_header
#     fname = os.path.basename(current_file)  
#     fname = fname.replace('.gz', '')
#     new_fname = os.path.join(new_data_path, fname)
#     df.to_csv(new_fname, sep='\t', index=False, header=True)

In [None]:
# # create new snm3c files with header
# common_header =  ["strand1", "chrom1", "pos1", "fragment1", "strand2", "chrom2", "pos2", "fragment2"]
# new_data_path = '/home/unix/jiahao/wanglab/jiahao/test/mlcb/datasets/combined/'

# for current_file in tqdm(input_files):
#     df = pd.read_table(current_file, header=None)
#     df.columns = common_header
#     fname = os.path.basename(current_file)  
#     new_fname = os.path.join(new_data_path, fname)
#     df.to_csv(new_fname, sep='\t', index=False, header=True)

In [None]:
# new_data_path = '/home/unix/jiahao/wanglab/jiahao/test/mlcb/datasets/test/'
# test_files = [f for f in os.listdir(new_data_path) if f.endswith('.gz')]
# df = pd.read_table(os.path.join(new_data_path, test_files[0]), header=None)
# df

In [None]:
# test_files = [f for f in os.listdir(new_data_path) if f.endswith('.gz')]
# for current_file in tqdm(test_files):
#     try:
#         df = pd.read_table(os.path.join(new_data_path, current_file), header=None)
#     except:
#         print(current_file)
#         continue


In [None]:
# Build the combined (snm3C + MERFISH) input lists from the files staged in
# datasets/mlcb-combined: snm3C cells are the *tsv files, MERFISH cells the
# dist* files. The trailing "/" matters: file paths are later built by string
# concatenation onto new_data_path.
new_data_path = os.path.join(data_path, 'mlcb-combined', '')
combined_entries = os.listdir(new_data_path)
snm3c_list = [f for f in combined_entries if f.endswith('tsv')]
merfish_list = [f for f in combined_entries if f.startswith('dist')]


def _load_pickle(path):
    """Load one pickle file. pickle can execute arbitrary code: trusted files only."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


snm3c_label = _load_pickle(os.path.join(data_path, "snm3c", 'label_info.pickle'))
snm3c_label['cell_group'] = np.array(snm3c_label['cell_group'])
snm3c_label['batch_id'] = np.array(snm3c_label['batch_id'])

merfish_label = _load_pickle(os.path.join(data_path, "merfish", 'label_info.pickle'))

print(snm3c_label.keys())
print(merfish_label.keys())

In [None]:
# Collapse the snm3C cell-group labels to three coarse classes via the mC
# metadata (same logic as the snm3C-only section, repeated so this section
# can run on its own).
import re
pattern = r"^[^_]+"

# Keep the original fine-grained labels under 'cell_group_raw' and always read
# from there: this cell overwrites 'cell_group' below, and re-parsing the
# collapsed labels on a re-run would most likely fail the class_to_cellgrp lookup.
if 'cell_group_raw' not in snm3c_label:
    snm3c_label['cell_group_raw'] = snm3c_label['cell_group']
cell_groups = snm3c_label['cell_group_raw']

# Extract the label prefix (text before the first underscore)
extracted = [re.match(pattern, cell_group).group() for cell_group in cell_groups]

mc_meta = pd.read_csv(os.path.join(metadata_path, 'CEMBA.mC.Metadata.csv'))
mc_meta["cellgrp"] = [re.match(pattern, cell_group).group() for cell_group in mc_meta.CellGroup]
class_to_cellgrp = mc_meta.groupby('cellgrp')['SubClass'].unique().to_dict()


def classify_cell_group(group):
    """Map a SubClass label to a coarse class; None if no marker substring matches."""
    if "Gaba" in group:
        return "Inhibitory Neuron"
    elif "Glut" in group:
        return "Excitatory Neuron"
    elif "NN" in group:
        return "Non-Neuronal"
    return None


# First SubClass listed for each cell's prefix, then its coarse class.
tmp = [class_to_cellgrp[prefix][0] for prefix in extracted]
cell_group_mapping = {group: classify_cell_group(group) for group in pd.Series(tmp).unique()}
plotting_grps = [cell_group_mapping[subclass] for subclass in tmp]

snm3c_label['cell_group'] = np.array(plotting_grps)

In [None]:
# Align every staged snm3C file with its row in the original snm3C file list
# (presumably the row order of snm3c_label -- TODO confirm), then build the
# per-library tables and stack them.
org_snm3c_list = pd.read_table(os.path.join(data_path, "snm3c", 'filelist.txt'), header=None)
all_cells = np.array([os.path.basename(f).replace('.gz', '') for f in org_snm3c_list[0].values])


def _first_positions(values):
    """Map each value to the index of its first occurrence in ``values``.

    Same result as ``np.argwhere(value == values)[0][0]`` per lookup, but built
    in one pass instead of one full-array scan per lookup.
    """
    positions = {}
    for pos, value in enumerate(values):
        positions.setdefault(value, pos)
    return positions


snm3c_positions = _first_positions(all_cells)
# dtype=int keeps this a valid index array even when snm3c_list is empty.
valid_index = np.array([snm3c_positions[name] for name in snm3c_list], dtype=int)
snm3c_cell_groups = snm3c_label['cell_group'][valid_index]
snm3c_batch_id = snm3c_label['batch_id'][valid_index]

merfish_cell_ids = [f.split("_")[1].replace('cell', '') for f in merfish_list]
known_merfish_ids = set(merfish_label['Cell_ID'])
valid_cells = [f for f in merfish_cell_ids if f in known_merfish_ids]
valid_cell_set = set(valid_cells)
merfish_list_valid = [f for f in merfish_list if f.split("_")[1].replace('cell', '') in valid_cell_set]
merfish_positions = _first_positions(merfish_label['Cell_ID'])
valid_index = [merfish_positions[cid] for cid in valid_cells]
merfish_cluster_class = merfish_label['cluster_class'][valid_index]

snm3c_df = pd.DataFrame(snm3c_list, columns=['file_name'])
snm3c_df['library'] = 'snm3c'
snm3c_df['batch_id'] = snm3c_batch_id
snm3c_df['cell_group'] = snm3c_cell_groups

merfish_df = pd.DataFrame(merfish_list_valid, columns=['file_name'])
merfish_df['library'] = 'merfish'
merfish_df['batch_id'] = 1
merfish_df['cell_group'] = merfish_cluster_class
# MERFISH cluster classes -> the coarse classes used for snm3C.
merfish_label_dict = {
    'Gluta': 'Excitatory Neuron',
    'Endo': 'Non-Neuronal',
    'VLMC': 'Non-Neuronal',
    'Astro': 'Non-Neuronal',
    'Oligo': 'Non-Neuronal',
    'GABA': 'Inhibitory Neuron',
    'Micro': 'Non-Neuronal',
    'SMC': 'Non-Neuronal',
    'Peri': 'Non-Neuronal',
    np.nan: 'Non-Neuronal',
    'other': 'Non-Neuronal'
}

# Classes missing from the dict become NaN.
merfish_df['cell_group'] = merfish_df['cell_group'].map(merfish_label_dict)

# ignore_index=True gives a unique 0..N-1 index (row i == line i of the
# combined filelist written below) instead of two overlapping 0-based ranges.
total_df = pd.concat([snm3c_df, merfish_df], axis=0, ignore_index=True)
total_df

In [None]:
# Save the combined per-cell metadata table (one row per staged cell file).
total_df.to_csv(os.path.join(metadata_path, 'mlcb-combined-metadata.csv'), index=False, header=True)

In [None]:
# Absolute path of every cell file (new_data_path ends with "/"), written as
# the combined dataset's filelist.txt in total_df row order.
total_df['file_path'] = new_data_path + total_df['file_name']
filelist_path = os.path.join(new_data_path, 'filelist.txt')
total_df['file_path'].to_csv(filelist_path, sep='\t', index=False, header=False)

# Per-cell labels for Higashi, in the same row order as the file list.
total_label = {column: total_df[column].values for column in ('library', 'batch_id', 'cell_group')}

with open(os.path.join(new_data_path, 'label_info.pickle'), 'wb') as handle:
    pickle.dump(total_label, handle)

In [None]:
# Set the training configuration for the combined snm3C + MERFISH model
config = os.path.join(data_path, "mlcb-combined", "config.JSON")

# Chromosomes used for training; the same set is imputed.
combined_chroms = ['chr1', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15',
                   'chr17', 'chr18', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9',
                   'chr16', 'chr19', 'chr2', 'chr8']

config_info = {
    # Input directory (trailing "/" kept). With input_format 'higashi_v2'
    # Higashi reads the per-cell file list from data_dir/filelist.txt (written
    # above); the single data.txt layout is the 'higashi_v1' format.
    "data_dir": os.path.join(data_path, "mlcb-combined", ""),
    "label_path": os.path.join(data_path, "mlcb-combined", "label_info.pickle"),
    "structured": True,
    "input_format": 'higashi_v2',
    # Model temp files live outside the project tree for this run.
    "temp_dir": "/home/unix/jiahao/projects/mlcb_run_500",
    "genome_reference_path": os.path.join(metadata_path, "chromInfo.txt"),
    "cytoband_path": os.path.join(metadata_path, "cytoBand.txt"),
    "chrom_list": combined_chroms,
    "header_included": True,
    "resolution": 500000, # other resolution tried: 1000000
    "resolution_cell": 500000,
    "resolution_fh": [500000],
    "embedding_name": "test",
    "minimum_distance": 500000,
    "maximum_distance": -1,
    "local_transfer_range": 0,
    "loss_mode": "zinb",
    "dimensions": 96, # can be adjusted later
    "impute_list": list(combined_chroms),
    "neighbor_num": 5,
    "library_id": "library",   # label_info key distinguishing snm3c / merfish
    "batch_id": "batch_id",
    "cpu_num": 8,
    "gpu_num": 1,
    "embedding_epoch": 30, # this can be adjusted
    "correct_be_impute": True,
}

# save the config file next to the data
import json
with open(config,"w") as f:
    json.dump(config_info, f, indent = 6)

In [None]:
# Initialize the Higashi instance for the combined dataset. `config` is
# re-derived here so this cell works even if the config-writing cell was not
# run in this session (the JSON presumably must already exist on disk).
config = os.path.join(data_path, "mlcb-combined", "config.JSON")
higashi_model = Higashi(config)

In [None]:
# Data processing step 1/3 (only needs to be run once per dataset)
higashi_model.generate_chrom_start_end()

In [None]:
# Data processing step 2/3 (only needs to be run once per dataset)
higashi_model.extract_table()

In [None]:
# Data processing step 3/3 (only needs to be run once per dataset)
higashi_model.create_matrix()

In [None]:
 # prep the model
higashi_model.prep_model()

# Train the model; the resulting cell embeddings are read back below with
# higashi_model.fetch_cell_embeddings().
higashi_model.train_for_embeddings()

In [None]:
# Visualize the combined embeddings colored by library (snm3c vs merfish) to
# see whether cells separate by assay.
cell_embeddings = higashi_model.fetch_cell_embeddings()

from umap import UMAP
palette = sns.color_palette("Set1", 2)  # one color per library

# 2-D UMAP of the cell embeddings (random_state fixes the layout); the
# following plotting cells reuse `vec`.
vec = UMAP(n_components=2, n_neighbors=5, random_state=99).fit_transform(cell_embeddings)
libraries = higashi_model.label_info['library']

fig, ax1 = plt.subplots(figsize=(8, 9))
sns.scatterplot(x=vec[:, 0], y=vec[:, 1], hue=libraries, ax=ax1, s=5, alpha=0.8, linewidth=0, palette=palette)
ax1.set(xlabel="UMAP 1", ylabel="UMAP 2")
handles, labels = ax1.get_legend_handles_labels()
ax1.legend(
    handles=handles,
    labels=labels,
    bbox_to_anchor=(0.5, -0.2),  # center the legend below the plot
    loc='upper center',          # legend anchor relative to bbox_to_anchor
    ncol=2,                      # one legend column per library
    title="Library"              # hue here is the library, not the cell type
)

plt.tight_layout()
plt.show()

In [None]:
# Same UMAP layout, colored by coarse cell group.
cell_groups = higashi_model.label_info['cell_group']

fig, ax1 = plt.subplots(figsize=(8, 9))
sns.scatterplot(x=vec[:, 0], y=vec[:, 1], hue=cell_groups, ax=ax1, s=5, alpha=0.8, linewidth=0, palette='Set2')
legend_handles, legend_labels = ax1.get_legend_handles_labels()
ax1.legend(
    handles=legend_handles,
    labels=legend_labels,
    title="Cell Types",
    ncol=3,                      # number of legend columns
    loc='upper center',          # anchor point of the legend box ...
    bbox_to_anchor=(0.5, -0.2),  # ... placed centered below the axes
)

plt.tight_layout()
plt.show()

In [None]:
# Side by side on the same layout: library (left) vs coarse cell group (right).
fig, axs = plt.subplots(figsize=(20, 8), nrows=1, ncols=2)

for panel_ax, label_key, panel_palette in zip(axs, ('library', 'cell_group'), ('Set1', 'Set2')):
    sns.scatterplot(x=vec[:, 0], y=vec[:, 1], hue=higashi_model.label_info[label_key], s=5, alpha=0.8, linewidth=0, palette=panel_palette, ax=panel_ax)

plt.tight_layout()
plt.show()

In [None]:
# Correct the library (snm3c vs merfish) effect in the Higashi embeddings with
# Harmony. Rows of cell_embeddings must line up with rows of total_df; the
# combined filelist.txt was written from total_df in row order.
import harmonypy as hm
assert cell_embeddings.shape[0] == len(total_df), "embeddings and metadata are misaligned"
ho = hm.run_harmony(cell_embeddings, total_df, 'library')

# Keep the Harmony-corrected embedding in memory (nothing is written to disk);
# harmonypy's Z_corr has one column per cell, hence the .T used later.
res = pd.DataFrame(ho.Z_corr)
res.shape

In [None]:
# 2-D UMAP of the Harmony-corrected embedding; res is transposed so rows are cells.
integrated_vec = UMAP(n_components=2, n_neighbors=5, random_state=99).fit_transform(res.T)

In [None]:
# Library labels before (left) vs after (right) Harmony correction.
fig, axs = plt.subplots(figsize=(20, 8), nrows=1, ncols=2)

panels = (
    (vec, 'Before Harmony Integration'),
    (integrated_vec, 'After Harmony Integration'),
)
for panel_ax, (coords, panel_title) in zip(axs, panels):
    sns.scatterplot(x=coords[:, 0], y=coords[:, 1], hue=higashi_model.label_info['library'], s=5, alpha=0.8, linewidth=0, palette='Set1', ax=panel_ax)
    panel_ax.title.set_text(panel_title)

plt.tight_layout()
plt.show()

In [None]:
# 2x2 grid: rows = coloring (library, coarse cell group); columns = embedding
# before / after Harmony. Every panel gets a title (the bottom row had none).
fig, axs = plt.subplots(figsize=(20, 16), nrows=2, ncols=2)

stages = (('Before Harmony Integration', vec), ('After Harmony Integration', integrated_vec))
colorings = (('library', 'Set1'), ('cell_group', 'Set2'))
for row, (label_key, row_palette) in enumerate(colorings):
    for col, (stage_title, coords) in enumerate(stages):
        panel_ax = axs[row, col]
        sns.scatterplot(x=coords[:, 0], y=coords[:, 1], hue=higashi_model.label_info[label_key], s=5, alpha=0.8, linewidth=0, palette=row_palette, ax=panel_ax)
        panel_ax.set_title(f"{stage_title} ({label_key})")

plt.tight_layout()
plt.show()

In [None]:
# Re-create and prepare the combined model (e.g. in a fresh kernel) before
# the imputation steps below.
config = os.path.join(data_path, "mlcb-combined", "config.JSON")
higashi_model = Higashi(config)
higashi_model.prep_model()

In [None]:
# Imputation without neighbor information (shown as "k=0" in the plots below).
higashi_model.train_for_imputation_nbr_0()
higashi_model.impute_no_nbr()

In [None]:
# Imputation using neighboring cells (config neighbor_num=5; "k=5" in the plots below).
higashi_model.train_for_imputation_with_nbr()
higashi_model.impute_with_nbr()

In [None]:
# Give total_df a 0..N-1 index so .loc[id_] below addresses row positions,
# presumably the model's cell ids (cells were listed in total_df order).
total_df = total_df.reset_index(drop=True)
total_df

In [None]:
# Example snm3C cell ids; show the total_df rows whose file name contains the third one.
test_ids = ['CEMBA3C_2A3C_R1_P2-4-I7-E7', 'CEMBA3C_MOp5D_R1_P1-6-L11-M23', 'CEMBA3C_4A3C_R2_P7-5-I2-F9']
query_id = test_ids[2]
total_df.loc[total_df['file_name'].str.contains(query_id), :]

In [None]:
# File names of the rows whose chr1 maps are plotted below (row indices into total_df).
for i in total_df.loc[[5029, 5631, 5704, 5962, 7977], 'file_name']:
    print(i)

In [None]:
# Raw vs Higashi-imputed chr4 contact maps for hand-picked rows of total_df:
# one figure row per cell; columns are raw, imputed without neighbors (k=0)
# and imputed with neighbors (k=5).
# Named example_ids so the `merfish_list` file list built earlier is not
# overwritten with integers.
example_ids = [802, 2862, 2989]
chrom = "chr4"
panel_titles = ("raw", "higashi, k=0", "higashi, k=5")
n_rows = len(example_ids)

fig = plt.figure(figsize=(12, 4 * n_rows))
for row, id_ in enumerate(example_ids):
    print(total_df.loc[id_, 'library'], id_)
    # fetch the raw and imputed contact maps for this cell
    ori, nbr0, nbr5 = higashi_model.fetch_map(chrom, id_)
    for col, (contact_map, panel_title) in enumerate(zip((ori, nbr0, nbr5), panel_titles)):
        ax = plt.subplot(n_rows, 3, row * 3 + col + 1)
        # Cap the color scale at the 95th percentile of the stored sparse
        # entries; autoscale instead when the map stores none (np.quantile
        # raises on an empty array).
        vmax = np.quantile(contact_map.data, 0.95) if contact_map.data.size else None
        ax.imshow(contact_map.toarray(), cmap='Reds', vmin=0.0, vmax=vmax)
        ax.set_xticks([])
        ax.set_yticks([])
        if row == 0:
            ax.set_title(panel_title)

plt.tight_layout()

In [None]:
# Raw vs Higashi-imputed chr1 contact maps for hand-picked rows of total_df:
# one figure row per cell; columns are raw, imputed without neighbors (k=0)
# and imputed with neighbors (k=5).
# Named example_ids so the `merfish_list` file list built earlier is not
# overwritten with integers.
example_ids = [5029, 5631, 5704, 5962, 7977]
chrom = "chr1"
panel_titles = ("raw", "higashi, k=0", "higashi, k=5")
n_rows = len(example_ids)

fig = plt.figure(figsize=(12, 4 * n_rows))
for row, id_ in enumerate(example_ids):
    print(total_df.loc[id_, 'library'], id_)
    # fetch the raw and imputed contact maps for this cell
    ori, nbr0, nbr5 = higashi_model.fetch_map(chrom, id_)
    for col, (contact_map, panel_title) in enumerate(zip((ori, nbr0, nbr5), panel_titles)):
        ax = plt.subplot(n_rows, 3, row * 3 + col + 1)
        # Cap the color scale at the 95th percentile of the stored sparse
        # entries; autoscale instead when the map stores none (np.quantile
        # raises on an empty array).
        vmax = np.quantile(contact_map.data, 0.95) if contact_map.data.size else None
        ax.imshow(contact_map.toarray(), cmap='Reds', vmin=0.0, vmax=vmax)
        ax.set_xticks([])
        ax.set_yticks([])
        if row == 0:
            ax.set_title(panel_title)

plt.tight_layout()