In [1]:
import os
import re
import pickle

import numpy as np
import pandas as pd

In [2]:
import torch
print(torch.__version__)

2.5.1+cu118


In [3]:
print("Is CUDA available:", torch.cuda.is_available())
print("CUDA device count:", torch.cuda.device_count())
print("CUDA version:", torch.version.cuda)
print("CUDA devices:", [torch.cuda.device(i) for i in range(torch.cuda.device_count())])
print("CUDA_VISIBLE_DEVICES:", os.environ.get('CUDA_VISIBLE_DEVICES'))

# Try to allocate a tensor on GPU
try:
    device = torch.device('cuda:0')
    torch.tensor([1.0], device=device)
    print("Successfully allocated tensor on GPU.")
except Exception as e:
    print("Failed to allocate tensor on GPU:", e)

Is CUDA available: True
CUDA device count: 1
CUDA version: 11.8
CUDA devices: [<torch.cuda.device object at 0x14770d4c2200>]
CUDA_VISIBLE_DEVICES: MIG-288afbf5-3444-56dc-81fb-30f6073ebc03
Successfully allocated tensor on GPU.


In [4]:
import sys
sys.path.insert(0, '/n/home09/pren/3DGeno_tools/src/gt3d/modules')
# sys.path.append('/n/home09/pren/Higashi')
# sys.path = ['/n/home09/pren/3DGeno_tools/src/gt3d/modules'] + sys.path
sys.path


['/n/home09/pren/3DGeno_tools/src/gt3d/modules',
 '/n/home09/pren/.conda/envs/gt3d/lib/python310.zip',
 '/n/home09/pren/.conda/envs/gt3d/lib/python3.10',
 '/n/home09/pren/.conda/envs/gt3d/lib/python3.10/lib-dynload',
 '',
 '/n/home09/pren/.conda/envs/gt3d/lib/python3.10/site-packages',
 '/n/home09/pren/3DGeno_tools/src']

In [5]:
import higashi
print(higashi.__file__)

/n/home09/pren/3DGeno_tools/src/gt3d/modules/higashi/__init__.py


In [6]:
from higashi.Higashi_wrapper import *

In [7]:
fish_path = '/n/netscratch/zhuang_lab/Lab/Peter/higashi_dnamerfish/multiplexed_fish/t6_filelist'

In [8]:
cell_df = pd.read_csv(os.path.join(fish_path, '4DNESMTNNB3N/4DNFIA7FUW8Y.csv'), skiprows=22)
cell_df = cell_df.rename(columns={'##columns=(Cell_ID': 'Cell_ID', 'Zfp804b)': 'Zfp804b'})
cell_df


FileNotFoundError: [Errno 2] No such file or directory: '/n/netscratch/zhuang_lab/Lab/Peter/higashi_dnamerfish/multiplexed_fish/t6_filelist/4DNESMTNNB3N/4DNFIA7FUW8Y.csv'

In [None]:
cell_ids = [re.search('cell\d+', filename).group()[4:] for filename in os.listdir(os.path.join(fish_path, 'merfish_contacts_thresh1000nm'))]
len(cell_ids)

In [None]:
training_data_meta = pd.DataFrame(cell_ids, columns=['Cell_ID']).merge(cell_df, on='Cell_ID', how='left')
training_data_meta

In [None]:
label_info = {k:np.asarray(training_data_meta[k]) for k in training_data_meta.columns}
label_info_path = '/n/home09/pren/higfiles/merfish_metadata/label_info_1000nmthresh.pickle'
with open(label_info_path, 'wb') as handle:
    pickle.dump(label_info, handle)

In [None]:
config = "/n/home09/pren/higfiles/configs/config_liu_zhuang_t6_1000nmthresh.JSON"
config_info = {
    "data_dir": fish_path, # where the data is, the data has to be named as data.txt (v1) or filelist.txt (v2) or the code can't find it
    "label_path": label_info_path,
    "structured": True,
    "input_format": 'higashi_v2',
    "temp_dir": "/n/netscratch/zhuang_lab/Lab/Peter/higashi_dnamerfish/embeddings_liu_zhuang_temp6_1000nmthresh/", # where to store model temp files
    "genome_reference_path": "/n/home09/pren/higfiles/params/chromInfo.txt",
    "cytoband_path": "/n/home09/pren/higfiles/params/cytoBand.txt",
    "chrom_list": ['chr1', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15',
       'chr17', 'chr18', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9',
       'chr16', 'chr19', 'chr2', 'chr8', "chrX"],
    "resolution": 2500000, # 2.5 Mbp
    "resolution_cell": 2500000,
    "resolution_fh": [2500000],
    "embedding_name": "test_dnamerfish_1000nmthresh",
    "minimum_distance": 2500000,
    "maximum_distance": -1,
    "local_transfer_range": 0,
    "loss_mode": "zinb",
    "dimensions": 100, # can be adjusted later
    "impute_list":['chr1', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15',
       'chr17', 'chr18', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9',
       'chr16', 'chr19', 'chr2', 'chr8', "chrX"],
    "neighbor_num": 5,
    "cpu_num": 10,
    "gpu_num": 1,
    "embedding_epoch":65, # this can be adjusted
    "correct_be_impute": True,
    "header_included": True,
    "reprocess": True,
    # "contact_header": ['cell_id', 'chrom1', 'pos1', 'chrom2', 'pos2', 'count']
}

# save the config file to the current directory
import json
with open(config, "w") as f:
    json.dump(config_info, f, indent = 6)

In [None]:
# Initialize the Higashi instance
higashi_model = Higashi(config)


In [None]:
# Data processing (only needs to be run for once)
higashi_model.process_data()



In [None]:
higashi_model.prep_model()

In [None]:
# Stage 1 training
higashi_model.train_for_embeddings()

In [None]:
training_data_meta['cluster_subclass'].unique()

In [None]:
training_data_meta['cluster_class'].unique()

In [None]:
# Visualize initial embedding results
cell_embeddings = higashi_model.fetch_cell_embeddings()
print (cell_embeddings.shape)

from umap import UMAP
import seaborn as sns
import matplotlib.pyplot as plt

vec = UMAP(n_components=2, n_neighbors=5, random_state=0).fit_transform(cell_embeddings)
# cell_type = higashi_model.label_info['cluster_subclass']
cell_type = higashi_model.label_info['cluster_class']
batch = higashi_model.label_info['Sample_ID']
fig = plt.figure(figsize=(14, 5))
ax = plt.subplot(1, 2, 1)
# sns.scatterplot(x=vec[:, 0], y=vec[:, 1], hue=cell_type, ax=ax, s=5, alpha=0.8, linewidth=0,
#                 hue_order=['L2/3 IT', 'Endo', 'VLMC', 'Astro', 'Oligo', 'Sst', 'OPC', 'L5 IT',
#                            'L4/5 IT', 'Vip', 'L5 ET', 'Micro', 'Pvalb', 'L6 IT', 'L6 CT',
#                            'SMC', 'Peri', 'Lamp5', 'L6b', 'L5/6 NP', 'other', 'Sncg'])
sns.scatterplot(x=vec[:, 0], y=vec[:, 1], hue=cell_type, ax=ax, s=5, alpha=0.8, linewidth=0,
                hue_order=['Gluta', 'Endo', 'VLMC', 'Astro', 'Oligo', 'GABA', 'Micro', 'SMC',
                           'Peri', 'other'])
handles, labels = ax.get_legend_handles_labels()
labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
ax.legend(handles=handles, labels=labels, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., ncol=1)
ax = plt.subplot(1, 2, 2)
sns.scatterplot(x=vec[:, 0], y=vec[:, 1], hue=batch, ax=ax, s=5, alpha=0.8, linewidth=0)
handles, labels = ax.get_legend_handles_labels()
labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
ax.legend(handles=handles, labels=labels, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., ncol=1)
plt.tight_layout()
plt.show()


In [None]:
# np.save('/n/netscratch/zhuang_lab/Lab/Peter/higashi_dnamerfish/test2/state/embeddings_dnamerfish_stage1_r2.npy', cell_embeddings)

In [None]:
higashi_model.current_device

In [None]:
# higashi_model.higashi_model = torch.load(higashi_model.save_path + "_stage1_model", map_location=higashi_model.current_device)
# higashi_model.node_embedding_init = None


In [None]:
higashi_model.train_for_imputation_nbr_0()
higashi_model.impute_no_nbr()

In [None]:
higashi_model.train_for_imputation_with_nbr()
higashi_model.impute_with_nbr()

In [None]:
# Visualize final embedding results
cell_embeddings = higashi_model.fetch_cell_embeddings()
print (cell_embeddings.shape)

from sklearn.decomposition import PCA

cell_type = higashi_model.label_info['cluster_class']
fig = plt.figure(figsize=(14, 5))
ax = plt.subplot(1, 2, 1)
vec = PCA(n_components=2).fit_transform(cell_embeddings)
sns.scatterplot(x=vec[:, 0], y=vec[:, 1], hue=cell_type, ax=ax, s=6, linewidth=0)
handles, labels = ax.get_legend_handles_labels()
labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
ax.legend(handles=handles, labels=labels, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., ncol=1)
ax = plt.subplot(1, 2, 2)
vec = UMAP(n_components=2).fit_transform(cell_embeddings)
sns.scatterplot(x=vec[:, 0], y=vec[:, 1], hue=cell_type, ax=ax, s=6, linewidth=0)
handles, labels = ax.get_legend_handles_labels()
labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
ax.legend(handles=handles, labels=labels, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., ncol=1)
plt.tight_layout()
plt.savefig('/n/home09/pren/figures/dnamerfish_pca_umapp_1000nmthresh.png')
plt.show()

In [None]:
# Visualize final embedding results
cell_embeddings = higashi_model.fetch_cell_embeddings()
print (cell_embeddings.shape)

from sklearn.decomposition import PCA

cell_type = higashi_model.label_info['neuron_identity']
fig = plt.figure(figsize=(14, 5))
ax = plt.subplot(1, 2, 1)
vec = PCA(n_components=2).fit_transform(cell_embeddings)
sns.scatterplot(x=vec[:, 0], y=vec[:, 1], hue=cell_type, ax=ax, s=6, linewidth=0)
handles, labels = ax.get_legend_handles_labels()
labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
ax.legend(handles=handles, labels=labels, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., ncol=1)
ax = plt.subplot(1, 2, 2)
vec = UMAP(n_components=2).fit_transform(cell_embeddings)
sns.scatterplot(x=vec[:, 0], y=vec[:, 1], hue=cell_type, ax=ax, s=6, linewidth=0)
handles, labels = ax.get_legend_handles_labels()
labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
ax.legend(handles=handles, labels=labels, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., ncol=1)
plt.tight_layout()
plt.savefig('/n/home09/pren/figures/dnamerfish_pca_umapp_neuronlabels_1000nmthresh.png')
plt.show()

In [None]:
# ori.data

In [None]:
count = 0
fig = plt.figure(figsize=(6, 2*5))
i = 0
for id_ in np.random.randint(0, 620, 620):
    if i == 5:
        break
    ori, nbr0, nbr5 = higashi_model.fetch_map("chr3", id_)
    # print('ori.data.shape: ', ori.data.shape)
    if ori.data.shape[0] == 0:
        continue
    else:
        i += 1
    count += 1
    ax = plt.subplot(5, 3, count * 3 - 2)
    ax.imshow(ori.toarray(), cmap='Reds', vmin=0.0, vmax=np.quantile(ori.data, 0.6))
    ax.set_xticks([], [])
    ax.set_yticks([], [])
    if count == 1:
        ax.set_title("raw")
    ax.set_ylabel(f'Cell {id_}')

    
    ax = plt.subplot(5, 3, count * 3 - 1)
    ax.imshow(nbr0.toarray(), cmap='Reds', vmin=0.0, vmax=np.quantile(nbr0.data, 0.95))
    ax.set_xticks([], [])
    ax.set_yticks([], [])
    if count == 1:
        ax.set_title("higashi, k=0")
    
    ax = plt.subplot(5, 3, count * 3)
    ax.imshow(nbr5.toarray(), cmap='Reds', vmin=0.0, vmax=np.quantile(nbr5.data, 0.95))
    ax.set_xticks([], [])
    ax.set_yticks([], [])
    if count == 1:
        ax.set_title("higashi, k=5")

plt.tight_layout()
# plt.suptitle('Chromosome 3')

plt.savefig('/n/home09/pren/figures/dnamerfish_imputation_1000nmthresh.png')