HiCFoundation Resolution Enhancement Pipeline for Google Colab
This notebook provides a complete pipeline for Hi-C resolution enhancement using HiCFoundation, optimized for Google Colab.
Prerequisites
Before starting, make sure to:

Enable GPU in Runtime → Change runtime type → Hardware accelerator → GPU (T4 or better)
Have your .hic files ready to upload
Have a Google Drive account with sufficient storage space

1. Environment Setup
Check GPU and Mount Google Drive

In [None]:
# Confirm a CUDA device is visible before any expensive work starts.
import os

import torch

if not torch.cuda.is_available():
    print("WARNING: No GPU detected! Please enable GPU in Runtime settings.")
else:
    device_name = torch.cuda.get_device_name(0)
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU is available: {device_name}")
    print(f"GPU Memory: {total_gb:.2f} GB")

# Inputs and outputs are stored on Google Drive.
from google.colab import drive

drive.mount('/content/drive')

# All later cells use paths relative to this working directory.
DRIVE_PATH = '/content/drive/MyDrive/HiCFoundation'
os.makedirs(DRIVE_PATH, exist_ok=True)
os.chdir(DRIVE_PATH)
print(f"Working directory: {os.getcwd()}")

Install Dependencies

In [None]:
# Install PyTorch with CUDA support
!pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html

# Install other required packages
!pip install easydict opencv-python simplejson lvis Pillow==9.5.0 pytorch_msssim
!pip install pandas hic-straw matplotlib scikit-image scipy einops tensorboard
!pip install cooler numba pyBigWig timm==0.3.2 scikit-learn

# Clone HiCFoundation repositories
!git clone https://github.com/Noble-Lab/HiCFoundation.git
!git clone https://github.com/Noble-Lab/HiCFoundation_paper.git

# Copy necessary files from HiCFoundation repo
!cp -r HiCFoundation/* .
!cp -r HiCFoundation_paper/utils/* utils/ 2>/dev/null || true

Create Directory Structure

In [None]:
import os

# Folder layout used by the scripts and commands in the cells below.
dirs_to_create = [
    'utils',
    'hic-raw',
    'input-dirs',
    'input-dirs/pre-train-dirs',
    'ft-inputs',
    'ft-inputs/train',
    'ft-inputs/val',
    'outputs',
    'models',
    'logs'
]


def _make_dir(path):
    """Create ``path`` (with parents) if it is missing, then report it."""
    os.makedirs(path, exist_ok=True)
    print(f"Created directory: {path}")


for dir_name in dirs_to_create:
    _make_dir(dir_name)

2. Data Upload
Download the .hic files from this shared Google Drive folder: https://drive.google.com/drive/folders/1D5MqwauHKRFixhRbGljSnouxWFNVfL1l?usp=sharing
Upload the .hic Files (the next cell moves them into hic-raw/)

In [None]:
from google.colab import files
import shutil

print("Upload your .hic files:")
uploaded = files.upload()

# Move uploaded files to hic-raw directory
for filename in uploaded.keys():
    shutil.move(filename, f'hic-raw/{filename}')
    print(f"Moved {filename} to hic-raw/")

# List files in hic-raw directory
!ls -la hic-raw/

3. Data Preprocessing
Create hic2array.py

In [None]:
%%writefile utils/hic2array.py
import numpy as np
from scipy.sparse import coo_matrix
import hicstraw
import os
import pickle

def write_pkl(data, path):
    """Pickle ``data`` to ``path`` (default protocol), replacing any existing file."""
    with open(path, 'wb') as out_stream:
        pickle.Pickler(out_stream).dump(data)

def read_chrom_array(chr1, chr2, normalization, hic_file, resolution):
    """
    Read observed contact records between two chromosomes into a sparse matrix.

    chr1, chr2: chromosome objects from hicstraw.HiCFile.getChromosomes(); only .name is used
    normalization: str, normalization name passed to hicstraw.straw (e.g. "NONE")
    hic_file: str, path to the .hic file
    resolution: int, bin size in bp

    Returns a float32 coo_matrix indexed by bin number (coordinate // resolution),
    or None when the returned counts sum to zero.
    """
    chr1_name, chr2_name = chr1.name, chr2.name
    infos = ['observed', normalization, hic_file, chr1_name, chr2_name, 'BP', resolution]
    print(infos)
    rets = hicstraw.straw(*infos)
    print('\tlen(rets): {:3e}'.format(len(rets)))
    # Convert genomic coordinates to bin indices.
    row = [int(rec.binX // resolution) for rec in rets]
    col = [int(rec.binY // resolution) for rec in rets]
    val = [rec.counts for rec in rets]
    total = sum(val)
    print('\tsum(val): {:3e}'.format(total))
    if total == 0:
        return None
    if chr1_name == chr2_name:
        # Intra-chromosomal: square, sized by the largest bin seen on either axis.
        side = max(max(row), max(col)) + 1
        shape = (side, side)
    else:
        shape = (max(row) + 1, max(col) + 1)
    # Only the records returned by straw are stored; nothing is mirrored below
    # the diagonal here.
    return coo_matrix((val, (row, col)), shape=shape, dtype=np.float32)


def hic2array(input_hic,output_pkl=None,
              resolution=25000,normalization="NONE",
              tondarray=0):
    """
    Convert a .hic file into a dict of chromosome(-pair) contact matrices.

    input_hic: str, input hic file path
    output_pkl: str or None, when given the dict is also pickled to this path
    resolution: int, resolution (bin size) to read; must be stored in the file
    normalization: str, normalization name passed through to hicstraw.straw
    tondarray: int, 0 = sparse, all pairs; 1 = dense, all pairs;
               2 = sparse, cis only; 3 = dense, cis only

    Returns dict: keys are "chrA_chrB" for modes 0/1 and "chrA" for modes 2/3;
    values are coo_matrix (modes 0/2) or numpy arrays (modes 1/3).
    Raises ValueError when ``resolution`` is not available in the file.
    """
    hic = hicstraw.HiCFile(input_hic)
    chrom_list = []
    for chrom in hic.getChromosomes():
        print(chrom.name, chrom.length)
        if "all" in chrom.name.lower():
            continue
        chrom_list.append(chrom)
    resolution_list = hic.getResolutions()
    if resolution not in resolution_list:
        print("Resolution not found in the hic file, please choose from the following list:")
        print(resolution_list)
        # Raise instead of exit(): exit() ends the process with status 0, so a
        # calling shell/notebook would treat the failed conversion as success.
        raise ValueError(
            "resolution {} is not available in {}".format(resolution, input_hic))

    # Unplaced ("Un"), random and alternate contigs are never paired. Filtering
    # them once here is equivalent to skipping every pair that involves them.
    def _is_excluded(name):
        lowered = name.lower()
        return 'Un' in name or "random" in lowered or "alt" in lowered

    kept = [chrom for chrom in chrom_list if not _is_excluded(chrom.name)]

    output_dict = {}
    for i in range(len(kept)):
        for j in range(i, len(kept)):
            if i != j and tondarray in [2, 3]:
                # cis-only modes skip inter-chromosome pairs
                continue
            chrom1_name = kept[i].name
            chrom2_name = kept[j].name
            read_array = read_chrom_array(kept[i], kept[j], normalization, input_hic, resolution)
            if read_array is None:
                print("No data found for", chrom1_name, chrom2_name)
                continue
            if tondarray in [1, 3]:
                read_array = read_array.toarray()
            if tondarray in [2, 3]:
                output_dict[chrom1_name] = read_array
            else:
                output_dict[chrom1_name + "_" + chrom2_name] = read_array
    if output_pkl is not None:
        output_dir = os.path.dirname(os.path.realpath(output_pkl))
        os.makedirs(output_dir, exist_ok=True)
        write_pkl(output_dict, output_pkl)

    return output_dict

if __name__ == '__main__':
    import os
    import sys

    norm_help = ("normalization type: 0: None normalization; 1: VC normalization; "
                 "2: VC_SQRT normalization; 3: KR normalization; 4: SCALE normalization")
    mode_help = ("mode: 0 for sparse matrix, 1 for dense matrix, "
                 "2 for sparce matrix (only cis-contact); 3 for dense matrix (only cis-contact).")

    if len(sys.argv) != 6:
        print('Usage: python3 hic2array.py [input.hic] [output.pkl] [resolution] [normalization_type] [mode]')
        print("This is the full hic2array script. ")
        print(norm_help)
        print(mode_help)
        sys.exit(1)

    resolution, norm_code, mode = (int(arg) for arg in sys.argv[3:6])
    normalization_dict = {0: "NONE", 1: "VC", 2: "VC_SQRT", 3: "KR", 4: "SCALE"}
    if norm_code not in normalization_dict:
        print('normalization type should be 0,1,2,3,4')
        print(norm_help)
        sys.exit(1)
    if mode not in (0, 1, 2, 3):
        print('mode should be in choice of 0/1/2/3')
        print(mode_help)
        sys.exit(1)

    input_hic_path, output_pkl_path = (os.path.abspath(arg) for arg in sys.argv[1:3])
    os.makedirs(os.path.dirname(output_pkl_path), exist_ok=True)
    hic2array(input_hic_path, output_pkl_path, resolution, normalization_dict[norm_code], mode)

Convert .hic Files to .pkl Format

In [None]:
# List available .hic files
import glob
hic_files = glob.glob('hic-raw/*.hic')
print("Available .hic files:")
for f in hic_files:
    print(f"  - {f}")

# Convert each file (update filenames as needed)
# Example conversions:
!python3 utils/hic2array.py hic-raw/Ft1-GSM6077013_at_hic_ndx1-4_r2.hic Ftr1.pkl 25000 0 0
!python3 utils/hic2array.py hic-raw/Pt1-GSM4705443_ddcc.hic Ptr1.pkl 25000 0 0
!python3 utils/hic2array.py hic-raw/Pt2-GSM6077012_at_hic_ndx1-4_r1.hic Ptr2.pkl 25000 0 0
!python3 utils/hic2array.py hic-raw/Pv1-GSM5091844_S_WT_2h1_DNB-15.allValidPairs.hic Pv1.pkl 25000 0 0

4. Submatrix Generation
Create scan_array.py

In [None]:
%%writefile utils/scan_array.py
import numpy as np
import pickle
from scipy.sparse import coo_matrix
import os

def write_pickle(output_dict,output_path):
    """
    Pickle ``output_dict`` to ``output_path`` with the default protocol.

    output_dict: dict, object to serialize
    output_path: str, destination file (overwritten if present)
    """
    with open(output_path, 'wb') as out_stream:
        pickle.Pickler(out_stream).dump(output_dict)

def scan_matrix(matrix, input_row_size,input_col_size, stride_row,
                stride_col,hic_count,output_dir,current_chrom,
                filter_threshold=0.05):
    """
    Tile a 2D contact matrix into fixed-size windows and pickle each kept window.

    matrix: 2D array
    input_row_size / input_col_size: int, window height / width
    stride_row / stride_col: int, step between window origins
    hic_count: int, total read count of the Hi-C experiment (stored as 'input_count')
    output_dir: str, directory receiving "<chrom>_<row>_<col>.pkl" files
    current_chrom: str, file-name prefix
    filter_threshold: float, minimum fraction of non-zero entries a window needs to be kept

    Window origins advance until less than half a window remains before the
    edge; windows that run past the bottom/right edge are zero-padded. Each
    saved dict holds 'input', 'input_count' and 'diag': row origin minus column
    origin when the window's row and column ranges overlap (it contains part of
    the main diagonal), otherwise None.
    """
    n_rows, n_cols = matrix.shape[0], matrix.shape[1]
    min_nonzero = input_row_size * input_col_size * filter_threshold
    saved = 0
    row_origins = range(0, n_rows - input_row_size // 2, stride_row)
    col_origins = range(0, n_cols - input_col_size // 2, stride_col)
    for r0 in row_origins:
        r1 = min(n_rows, r0 + input_row_size)
        for c0 in col_origins:
            c1 = min(n_cols, c0 + input_col_size)
            patch = np.zeros((input_row_size, input_col_size))
            patch[:r1 - r0, :c1 - c0] = matrix[r0:r1, c0:c1]
            # Drop windows that are mostly empty.
            if np.count_nonzero(patch) < min_nonzero:
                continue

            record = {}
            record['input'] = patch
            record['input_count'] = hic_count
            touches_diag = (c0 == r0) or (c0 < r0 < c1) or (r0 < c0 < r1)
            record['diag'] = r0 - c0 if touches_diag else None

            write_pickle(record, os.path.join(output_dir, f"{current_chrom}_{r0}_{c0}.pkl"))
            saved += 1
            if saved % 100 == 0:
                print('Processed %d submatrices' % saved, " for chromosome ", current_chrom)

    return

def scan_pickle(input_pkl_path, input_row_size,input_col_size, stride_row,
                stride_col,output_dir,filter_threshold):
    """
    Tile every matrix stored in a pickle (as written by hic2array.py) into window files.

    input_pkl_path: str, pickle holding {key: numpy 2D array or scipy coo_matrix}
    input_row_size / input_col_size: int, window height / width
    stride_row / stride_col: int, step between window origins
    output_dir: str, output directory (created if missing)
    filter_threshold: float, passed to scan_matrix

    Raises TypeError if a stored value is neither a numpy array nor a coo_matrix.
    """
    os.makedirs(output_dir, exist_ok=True)

    # NOTE: pickle.load can execute arbitrary code; only open trusted files.
    with open(input_pkl_path, 'rb') as f:
        data = pickle.load(f)

    total_count = 0
    for key, matrix in data.items():
        if isinstance(matrix, np.ndarray):
            cur_count = np.sum(matrix)
        elif isinstance(matrix, coo_matrix):
            cur_count = matrix.sum()
        else:
            print("Type not supported", type(matrix))
            # Previously exit(), which ends with status 0 and hides the failure.
            raise TypeError("Unsupported matrix type for key {}: {}".format(key, type(matrix)))
        total_count += cur_count
    print("Total read count of Hi-C: ", total_count)

    for key, matrix in data.items():
        if isinstance(matrix, coo_matrix):
            matrix = matrix.toarray()
        if matrix.shape[0] == matrix.shape[1]:
            # Square matrices are treated as intra-chromosomal. hic2array.py stores
            # only the records straw returns (no below-diagonal entries), and this
            # applies to its dense modes too, so mirror the strict upper triangle
            # below the diagonal for both sparse and dense inputs. An already
            # symmetric matrix is unchanged by this.
            # NOTE(review): an inter-chromosomal matrix that happens to be square
            # would also be mirrored here.
            upper_tri = np.triu(matrix, 1)
            matrix = np.triu(matrix) + upper_tri.T
        current_chrom = str(key)
        if "chr" not in current_chrom:
            current_chrom = "chr" + current_chrom

        scan_matrix(matrix, input_row_size,input_col_size, stride_row,
                stride_col,total_count,output_dir,current_chrom,filter_threshold)

# Command-line entry point
if __name__ == '__main__':
    import argparse

    cli = argparse.ArgumentParser()
    for flag, flag_type in (('--input_pkl_path', str), ('--input_row_size', int),
                            ('--input_col_size', int), ('--stride_row', int),
                            ('--stride_col', int), ('--output_dir', str)):
        cli.add_argument(flag, type=flag_type, required=True)
    cli.add_argument('--filter_threshold', type=float, default=0.05)
    opts = cli.parse_args()

    scan_pickle(os.path.abspath(opts.input_pkl_path), opts.input_row_size, opts.input_col_size,
                opts.stride_row, opts.stride_col, os.path.abspath(opts.output_dir),
                opts.filter_threshold)

Generate Submatrices for Pre-training

In [None]:
# Generate submatrices for pre-training
!python3 utils/scan_array.py --input_pkl_path Ptr1.pkl  --input_row_size 448 \
    --input_col_size 448 --stride_row 224 --stride_col 224 \
    --output_dir HiC-PTR1 --filter_threshold 0.01

!python3 utils/scan_array.py --input_pkl_path Ptr2.pkl  --input_row_size 448 \
    --input_col_size 448 --stride_row 224 --stride_col 224 \
    --output_dir HiC-PTR2 --filter_threshold 0.01

!python3 utils/scan_array.py --input_pkl_path Pv1.pkl  --input_row_size 448 \
    --input_col_size 448 --stride_row 224 --stride_col 224 \
    --output_dir HiC-PV1 --filter_threshold 0.01

Create Configuration Files

In [None]:
# Write the pre-training config files: one data folder name per line.
# NOTE(review): the fine-tuning stage keeps its data folders inside --data_path
# (ft-inputs/train, ft-inputs/val), but HiC-PTR1/2 and HiC-PV1 are created in the
# working directory, not in input-dirs/pre-train-dirs/. Confirm how pretrain.py
# resolves these names relative to --data_path.
PRETRAIN_CONFIG_DIR = 'input-dirs/pre-train-dirs'
config_contents = {
    'train.txt': ['HiC-PTR1', 'HiC-PTR2'],
    'val.txt': ['HiC-PV1'],
}
for config_name, folder_names in config_contents.items():
    with open(f'{PRETRAIN_CONFIG_DIR}/{config_name}', 'w') as f:
        f.write(''.join(name + '\n' for name in folder_names))

print("Configuration files created successfully!")

5. Pre-training
Run Pre-training

In [None]:
# Note: This will take considerable time
# Pretraining run over the window folders listed in train.txt / val.txt.
# --mask_ratio 0.75 suggests a masked-reconstruction objective (confirm in pretrain.py).
# --batch_size 1 with --accum_iter 4 presumably means 4 windows per optimizer update.
# Results go to hicfoundation_finetune/; the next cell renames that folder to
# hicfoundation_pretrain/ because the fine-tuning run also uses
# --output "hicfoundation_finetune".
!python3 pretrain.py --batch_size 1 --accum_iter 4 \
    --epochs 1 --warmup_epochs 1 --pin_mem \
    --mask_ratio 0.75 --sparsity_ratio 0.05 \
    --blr 1.5e-4 --min_lr 1e-7 --weight_decay 0.05 \
    --model "vit_large_patch16" --loss_alpha 1 --seed 888 \
    --data_path "input-dirs/pre-train-dirs/" --train_config "train.txt" \
    --valid_config "val.txt" --output "hicfoundation_finetune" \
    --tensorboard 1 --world_size 1 --dist_url "tcp://localhost:10001" --rank 0 \
    --input_row_size 448 --input_col_size 448 --patch_size 16 \
    --print_freq 1 --save_freq 1

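Rename Output Directory (pretraining writes to hicfoundation_finetune/; renaming it keeps the fine-tuning run, which uses the same output folder, from overwriting the pretrained model)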

In [None]:
import os
import shutil

# Give the pretraining output a stable name so the fine-tuning run (which also
# writes to hicfoundation_finetune/) cannot overwrite it.
# A plain `mv` into an existing hicfoundation_pretrain/ would nest the new run
# inside it (hicfoundation_pretrain/hicfoundation_finetune), and fine-tuning
# would silently keep loading the old hicfoundation_pretrain/model checkpoint.
if os.path.exists('hicfoundation_pretrain'):
    raise FileExistsError(
        "hicfoundation_pretrain already exists; rename or delete it before "
        "moving a new pretraining run into place.")
shutil.move('hicfoundation_finetune', 'hicfoundation_pretrain')

6. Fine-tuning Preparation
Create downsample_pkl.py

In [None]:
%%writefile utils/downsample_pkl.py
import sys
import os
from collections import defaultdict
import pickle
import numpy as np
from scipy.sparse import coo_matrix

def array_to_coo(array):
    """
    Build a scipy.sparse.coo_matrix holding the non-zero entries of a 2D array.

    Parameters:
    - array (numpy.ndarray): 2D input; zeros are not stored.

    Returns:
    - scipy.sparse.coo_matrix with the same shape as ``array``.
    """
    nz_rows, nz_cols = np.nonzero(array)
    nz_values = array[nz_rows, nz_cols]
    return coo_matrix((nz_values, (nz_rows, nz_cols)), shape=array.shape)

def sparse2tag(coo_mat):
    """
    Expand a sparse count matrix into one (row, col) "tag" per read.

    Each stored entry with count c contributes int(c) copies of its (row, col)
    coordinate (counts truncated toward zero, negatives treated as 0), in the
    order the entries are stored.

    Returns (tag_mat, tag_len): tag_mat is an int array of shape (tag_len, 2).
    """
    row = np.asarray(coo_mat.row)
    col = np.asarray(coo_mat.col)
    counts = np.clip(np.asarray(coo_mat.data).astype(np.int64), 0, None)
    coords = np.stack([row, col], axis=1).astype(int)
    # Vectorized replacement of the per-entry Python loop.
    tag_mat = np.repeat(coords, counts, axis=0)
    # Report the number of tags actually produced. int(coo_mat.sum()) overshoots
    # it for fractional counts, which used to leave trailing (0, 0) rows that the
    # downsampler then sampled as spurious contacts.
    tag_len = int(tag_mat.shape[0])
    return tag_mat, tag_len

def tag2sparse(tag, nsize):
    """
    Collapse an (n_tags, 2) array of (row, col) tags back into a count matrix.

    tag: int array of shape (n_tags, 2)
    nsize: int for a square (nsize, nsize) result, or a (n_rows, n_cols)
        tuple/list for rectangular (e.g. inter-chromosomal) matrices.
    Returns a scipy.sparse.coo_matrix whose values count identical tags.
    """
    shape = tuple(nsize) if isinstance(nsize, (tuple, list)) else (nsize, nsize)
    coo_data, data = np.unique(tag, axis=0, return_counts=True)
    row, col = coo_data[:, 0], coo_data[:, 1]
    sparse_mat = coo_matrix((data, (row, col)), shape=shape)
    return sparse_mat

def downsampling_sparce(matrix, down_ratio, verbose=False):
    """
    Randomly subsample the reads of a sparse count matrix.

    matrix: scipy.sparse.coo_matrix of read counts (square or rectangular)
    down_ratio: float, fraction of reads to draw
    verbose: print shape / sampling info

    Reads are drawn with replacement (np.random.choice default) from the global
    NumPy RNG, so results differ between runs unless np.random.seed is set.
    Returns a coo_matrix with the same shape as ``matrix``.
    """
    if verbose: print(f"[Downsampling] Matrix shape is {matrix.shape}")
    tag_mat, tag_len = sparse2tag(matrix)
    sample_idx = np.random.choice(tag_len, int(tag_len *down_ratio))
    sample_tag = tag_mat[sample_idx]
    if verbose: print(f'[Downsampling] Sampling {down_ratio} of {tag_len} reads')
    # Pass the full shape: inter-chromosomal matrices are not square, and using
    # shape[0] for both axes either rejects column indices >= n_rows or yields a
    # shape that no longer matches the original matrix.
    down_mat = tag2sparse(sample_tag, matrix.shape)
    return down_mat


def downsample_pkl(input_pkl, output_pkl, downsample_rate):
    """
    Downsample every matrix in a pickled dict and pickle the result.

    input_pkl: str, pickle of {key: numpy 2D array or scipy coo_matrix}
    output_pkl: str, destination pickle path
    downsample_rate: float, fraction of reads drawn (see downsampling_sparce)

    Matrices with 100 rows or fewer are skipped and therefore absent from the
    output dict, so the output keys can be a subset of the input keys.
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(input_pkl, 'rb') as f:
        data = pickle.load(f)
    return_dict = {}
    for chrom, current_data in data.items():
        if current_data.shape[0] <= 100:
            continue
        # Dense arrays are converted to sparse before expanding them into read tags.
        if isinstance(current_data, np.ndarray):
            current_data = array_to_coo(current_data)
        return_dict[chrom] = downsampling_sparce(current_data, downsample_rate, verbose=1)
    with open(output_pkl, "wb") as f:
        pickle.dump(return_dict, f)
    print("finish downsampling %s"%output_pkl)

if __name__ == '__main__':
    if len(sys.argv) != 4:
        for usage_line in (
            "Usage: python3 downsample_pkl.py [input.pkl] [output.pkl] [downsample_rate]",
            "This script is used to downsample the input pickle file.",
            "[input.pkl]: the input pickle file",
            "[output.pkl]: the output pickle file",
            "[downsample_rate]: the downsample rate [float].",
        ):
            print(usage_line)
        sys.exit(1)
    input_pkl, output_pkl = (os.path.abspath(arg) for arg in sys.argv[1:3])
    os.makedirs(os.path.dirname(output_pkl), exist_ok=True)
    downsample_rate = float(sys.argv[3])
    downsample_pkl(input_pkl, output_pkl, downsample_rate)

Downsample Data

In [None]:
# Simulate a low-coverage library: draw 10% of the reads of Ftr1.pkl at random
# (with replacement, no fixed seed, so results differ between runs).
# Matrices with <= 100 rows are not written to Ftr1_downsampled.pkl.
!python3 utils/downsample_pkl.py Ftr1.pkl Ftr1_downsampled.pkl 0.1

Create scan_array_diag.py

In [None]:
%%writefile utils/scan_array_diag.py
import numpy as np
import pickle
from scipy.sparse import coo_matrix
import os

def write_pickle(output_dict,output_path):
    """
    Pickle ``output_dict`` to ``output_path`` with the default protocol.

    output_dict: dict, object to serialize
    output_path: str, destination file (overwritten if present)
    """
    with open(output_path, 'wb') as out_stream:
        pickle.Pickler(out_stream).dump(output_dict)

def scan_matrix_paired(original_matrix, downsampled_matrix, input_row_size, input_col_size, stride,
                      hic_count, output_dir, current_chrom):
    """
    Cut aligned windows from an original / downsampled matrix pair and pickle each pair.

    original_matrix: 2D array, original high-quality Hi-C matrix (stored as '2d_target')
    downsampled_matrix: 2D array, downsampled low-quality Hi-C matrix (stored as 'input')
    input_row_size / input_col_size: int, window height / width
    stride: int, step between window origins (rows and columns)
    hic_count: int, total read count of the Hi-C experiment (stored as 'input_count')
    output_dir: str, directory receiving "<chrom>_<row>_<col>.pkl" files
    current_chrom: str, file-name prefix

    Square matrices are scanned only along the main diagonal (row origin ==
    column origin); rectangular matrices are scanned over the full grid. Only
    windows lying entirely inside the matrix are produced, so up to stride-1
    trailing bins at the bottom/right edge are never covered. Windows whose
    original submatrix is all zeros are skipped.
    Raises AssertionError if the two matrices differ in shape.
    """
    row_size = original_matrix.shape[0]
    col_size = original_matrix.shape[1]

    # Ensure both matrices have the same dimensions
    assert original_matrix.shape == downsampled_matrix.shape, \
        f"Matrix shapes don't match: {original_matrix.shape} vs {downsampled_matrix.shape}"

    print(f"Scanning matrix {current_chrom} with shape {original_matrix.shape}")
    print(f"Submatrix size: {input_row_size}x{input_col_size}, stride: {stride}")

    # One shared loop body for both scan patterns (previously duplicated).
    row_origins = range(0, row_size - input_row_size + 1, stride)
    if row_size == col_size:
        # Diagonal scanning; the column window must still fit when it is wider
        # than the row window.
        origins = ((i, i) for i in row_origins if i + input_col_size <= col_size)
    else:
        col_origins = range(0, col_size - input_col_size + 1, stride)
        origins = ((i, j) for i in row_origins for j in col_origins)

    count_save = 0
    for i, j in origins:
        original_submatrix = original_matrix[i:i+input_row_size, j:j+input_col_size]
        # Skip windows with no contacts at all in the target.
        if np.count_nonzero(original_submatrix) < 1:
            continue
        downsampled_submatrix = downsampled_matrix[i:i+input_row_size, j:j+input_col_size]

        # Create paired output dictionary
        output_dict = {}
        output_dict['input'] = downsampled_submatrix.copy()
        output_dict['2d_target'] = original_submatrix.copy()
        output_dict['input_count'] = hic_count

        output_path = os.path.join(output_dir, str(current_chrom) + '_' + str(i) + '_' + str(j) + '.pkl')
        write_pickle(output_dict, output_path)
        count_save += 1

        if count_save % 100 == 0:
            print('Processed %d paired submatrices' % count_save, " for chromosome ", current_chrom)

    print(f"Total submatrices saved for {current_chrom}: {count_save}")
    return

def _symmetrize_upper(matrix):
    """Return triu(matrix) plus the strict upper triangle mirrored below the diagonal."""
    return np.triu(matrix) + np.triu(matrix, 1).T


def scan_pickle_paired(original_pkl_path, downsampled_pkl_path, input_row_size, input_col_size, 
                      stride, output_dir):
    """
    Pair matrices from an original and a downsampled pickle and cut them into window pairs.

    original_pkl_path: str, path to original (high-quality) pickle file
    downsampled_pkl_path: str, path to downsampled (low-quality) pickle file
    input_row_size / input_col_size: int, window height / width
    stride: int, step between window origins
    output_dir: str, output directory (created if missing)

    Only keys present in both pickles are scanned; the others are reported.
    Raises ValueError if the pickles share no keys, and TypeError if an
    original value is neither a numpy array nor a coo_matrix.
    """
    os.makedirs(output_dir, exist_ok=True)

    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(original_pkl_path, 'rb') as f:
        original_data = pickle.load(f)
    with open(downsampled_pkl_path, 'rb') as f:
        downsampled_data = pickle.load(f)

    # downsample_pkl.py skips matrices with <= 100 rows, so the downsampled dict
    # can legitimately lack some keys; pair only the shared ones instead of
    # requiring identical key sets.
    shared_keys = [key for key in original_data if key in downsampled_data]
    unpaired = sorted(set(original_data) ^ set(downsampled_data), key=str)
    if unpaired:
        print(f"Warning: skipping keys present in only one input: {unpaired}")
    if not shared_keys:
        raise ValueError("Original and downsampled data share no chromosomes")

    # Total read count of the original library (all keys, paired or not).
    total_count = 0
    for key, matrix in original_data.items():
        if isinstance(matrix, np.ndarray):
            cur_count = np.sum(matrix)
        elif isinstance(matrix, coo_matrix):
            cur_count = matrix.sum()
        else:
            print("Type not supported", type(matrix))
            # Previously exit(), which ends with status 0 and hides the failure.
            raise TypeError(f"Unsupported matrix type for {key}: {type(matrix)}")
        total_count += cur_count
    print("Total read count of original Hi-C: ", total_count)

    for key in shared_keys:
        original_matrix = original_data[key]
        downsampled_matrix = downsampled_data[key]

        # Convert sparse matrices to dense arrays
        if isinstance(original_matrix, coo_matrix):
            original_matrix = original_matrix.toarray()
        if isinstance(downsampled_matrix, coo_matrix):
            downsampled_matrix = downsampled_matrix.toarray()

        current_chrom = str(key)
        if "chr" not in current_chrom:
            current_chrom = "chr" + current_chrom

        # Square matrices are treated as intra-chromosomal and mirrored below the
        # diagonal. A shape mismatch is left for scan_matrix_paired to report.
        if original_matrix.shape[0] == original_matrix.shape[1]:
            original_matrix = _symmetrize_upper(original_matrix)
            if downsampled_matrix.shape == original_matrix.shape:
                downsampled_matrix = _symmetrize_upper(downsampled_matrix)
        else:
            print(f"Warning: Matrix for {current_chrom} is not square ({original_matrix.shape}). Skipping symmetry operation.")

        print(f"Processing chromosome {current_chrom}")
        print(f"Original matrix shape: {original_matrix.shape}")
        print(f"Downsampled matrix shape: {downsampled_matrix.shape}")

        scan_matrix_paired(original_matrix, downsampled_matrix, input_row_size, input_col_size, 
                          stride, total_count, output_dir, current_chrom)

# Command-line entry point
if __name__ == '__main__':
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--original_pkl_path', type=str, required=True,
                     help='Path to original (high-quality) pickle file')
    cli.add_argument('--downsampled_pkl_path', type=str, required=True,
                     help='Path to downsampled (low-quality) pickle file')
    for int_flag in ('--input_row_size', '--input_col_size', '--stride'):
        cli.add_argument(int_flag, type=int, required=True)
    cli.add_argument('--output_dir', type=str, required=True)
    opts = cli.parse_args()

    scan_pickle_paired(
        os.path.abspath(opts.original_pkl_path),
        os.path.abspath(opts.downsampled_pkl_path),
        opts.input_row_size,
        opts.input_col_size,
        opts.stride,
        os.path.abspath(opts.output_dir),
    )

Generate Paired Submatrices

In [None]:
# Build (downsampled input, original target) window pairs for fine-tuning:
# 224x224 windows every 20 bins, written to Ftr1/ as <chrom>_<row>_<col>.pkl.
# NOTE(review): with stride 20 and 224-bin windows, neighbouring windows overlap
# by about 91%, so a random window-level train/val split puts nearly identical
# windows in both sets.
!python3 utils/scan_array_diag.py \
    --original_pkl_path Ftr1.pkl \
    --downsampled_pkl_path Ftr1_downsampled.pkl \
    --input_row_size 224 --input_col_size 224 --stride 20 \
    --output_dir Ftr1

Prepare Fine-tuning Data

In [None]:
import glob
import os
import random
import shutil

# The split is seeded so it can be reproduced; 888 matches the --seed passed to
# the pretraining and fine-tuning commands.
SPLIT_SEED = 888
TRAIN_FRACTION = 0.8

# Sort before shuffling: glob order depends on the filesystem, so sorting plus a
# seeded RNG gives the same 80/20 split on every run.
ftr1_files = sorted(glob.glob('Ftr1/*.pkl'))
random.Random(SPLIT_SEED).shuffle(ftr1_files)
split_idx = int(TRAIN_FRACTION * len(ftr1_files))

train_files = ftr1_files[:split_idx]
val_files = ftr1_files[split_idx:]
# NOTE(review): windows were cut with stride 20 and size 224, so neighbouring
# windows overlap heavily and a random window-level split lets near-duplicates
# land in both train and val. Splitting by chromosome (file-name prefix) would
# give a more independent validation set.

# Remove windows copied by a previous run first; otherwise files from an old
# split stay behind and the same window can end up in both directories.
for split_dir, split_files in (('ft-inputs/train', train_files), ('ft-inputs/val', val_files)):
    os.makedirs(split_dir, exist_ok=True)
    for stale_path in glob.glob(os.path.join(split_dir, '*.pkl')):
        os.remove(stale_path)
    for f in split_files:
        shutil.copy(f, split_dir + '/')

# Create configuration files (folder names presumably resolved relative to
# --data_path ft-inputs in the fine-tuning command)
with open('ft-inputs/train_config.txt', 'w') as f:
    f.write('train\n')

with open('ft-inputs/val_config.txt', 'w') as f:
    f.write('val\n')

print(f"Created fine-tuning dataset: {len(train_files)} train, {len(val_files)} validation samples")

7. Fine-tuning
Create Modified train_epoch.py

In [None]:
%%writefile finetune/train_epoch.py
import math
import sys
import numpy as np
from typing import Iterable
import torch
import torch.nn.functional as F
import time

from ops.Logger import MetricLogger,SmoothedValue
import model.lr_sched as lr_sched
from finetune.loss import configure_loss
from ops.train_utils import list_to_device, to_value, create_image, torch_to_nparray, convert_gray_rgbimage


def train_epoch(model, data_loader_train, optimizer, 
                loss_scaler, epoch, device,
                log_writer=None, args=None):
    """
    Run one fine-tuning epoch with gradient accumulation.

    model: called as model(input_matrix, total_count); must return
        (output_embedding, output_2d, output_1d).
    data_loader_train: sized iterable of batches; list_to_device(batch) must
        yield (input_matrix, total_count, target_matrix, embed_target,
        target_vector), where any target may be None.
    optimizer: torch optimizer; lr_sched adjusts its learning rate at the start
        of every accumulation window.
    loss_scaler: callable(loss, optimizer, parameters=..., update_grad=bool);
        presumably backpropagates and steps the optimizer when update_grad is True.
    epoch: int, current epoch index (LR schedule position and log x-axis).
    device: device the batch is moved to.
    log_writer: optional tensorboard writer.
    args: namespace providing print_freq and accum_iter (plus whatever
        configure_loss and lr_sched read).

    Returns a dict mapping each metric name to its epoch average.
    """
    model.train()
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

    header = 'Epoch: [{}]'.format(epoch)
    print_freq = args.print_freq
    accum_iter = args.accum_iter

    optimizer.zero_grad()
    if log_writer is not None:
        print('Tensorboard log dir: {}'.format(log_writer.log_dir))
    num_iter = len(data_loader_train)
    print("number of iterations: ", num_iter)
    criterion = configure_loss(args)

    for data_iter_step, train_data in enumerate(metric_logger.log_every(data_loader_train, print_freq, header)):
        # The optimizer steps once every accum_iter batches.
        is_update_step = (data_iter_step + 1) % accum_iter == 0
        if data_iter_step % accum_iter == 0:
            # Per-iteration LR schedule; the position is measured in fractional epochs.
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / num_iter + epoch, args)
        input_matrix, total_count, target_matrix, embed_target, target_vector = list_to_device(train_data,device=device)

        # Forward pass
        output_embedding, output_2d, output_1d = model(input_matrix, total_count)

        # A head without a target contributes 0 * mean(output): zero loss and zero
        # gradient, but the head stays in the autograd graph (presumably so that
        # DistributedDataParallel does not flag its parameters as unused).
        if embed_target is not None:
            embedding_loss = criterion(output_embedding, embed_target)
        else:
            embedding_loss = 0.0 * output_embedding.mean()

        if target_matrix is not None:
            # Compare 2D prediction and target as flattened per-sample vectors.
            output_2d_flatten = torch.flatten(output_2d, start_dim=1, end_dim=-1)
            target_matrix_flatten = torch.flatten(target_matrix, start_dim=1, end_dim=-1)
            output_2d_loss = criterion(output_2d_flatten, target_matrix_flatten)
        else:
            output_2d_loss = 0.0 * output_2d.mean()

        if target_vector is not None:
            output_1d_loss = criterion(output_1d, target_vector)
        else:
            output_1d_loss = 0.0 * output_1d.mean()

        loss = embedding_loss + output_2d_loss + output_1d_loss
        # Unscaled value, used for BOTH the metric logger and tensorboard so every
        # logged loss is on the same scale (previously tensorboard received the
        # value after division by accum_iter).
        loss_value = to_value(loss)

        # Update metrics
        metric_logger.update(loss=loss_value)
        metric_logger.update(embedding_loss=to_value(embedding_loss))
        metric_logger.update(output_2d_loss=to_value(output_2d_loss))
        metric_logger.update(output_1d_loss=to_value(output_1d_loss))

        if not math.isfinite(loss_value):
            # Skip the batch rather than exiting; zero_grad also discards gradients
            # accumulated earlier in the current accumulation window.
            print("Loss is {}, skipping this batch".format(loss_value))
            optimizer.zero_grad()
            continue

        loss_scaler(loss / accum_iter, optimizer, parameters=model.parameters(),
                    update_grad=is_update_step)
        if is_update_step:
            optimizer.zero_grad()

        if torch.cuda.is_available():
            # Make sure all gradients are finished computing before moving on;
            # skipped on CPU-only runtimes, where synchronize() raises.
            torch.cuda.synchronize()
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        if log_writer is not None and (is_update_step or data_iter_step == 0):
            # epoch_1000x is the tensorboard x-axis; it calibrates curves when the
            # batch size changes.
            epoch_1000x = int((data_iter_step / num_iter + epoch) * 1000)
            log_writer.add_scalars('Loss/loss', {'train_loss': loss_value}, epoch_1000x)
            log_writer.add_scalars('Loss/embedding_loss', {'train_loss': to_value(embedding_loss)}, epoch_1000x)
            log_writer.add_scalars('Loss/output_2d_loss', {'train_loss': to_value(output_2d_loss)}, epoch_1000x)
            log_writer.add_scalars('Loss/output_1d_loss', {'train_loss': to_value(output_1d_loss)}, epoch_1000x)
            log_writer.add_scalars('LR/lr', {'lr': lr}, epoch_1000x)
            if ((data_iter_step + 1) // accum_iter) % 50 == 0 or data_iter_step == 0:
                # Add visualization for the input and the 2D output.
                new_samples = create_image(input_matrix)
                select_num = min(8, len(new_samples))
                sample_image = torch_to_nparray(new_samples.clone().detach()[:select_num])
                log_writer.add_images('Input_%s' % "train", sample_image, epoch_1000x)
                output_2d_image = convert_gray_rgbimage(output_2d.clone().detach()[:select_num])
                output_2d_image = torch_to_nparray(output_2d_image)
                log_writer.add_images('Output_2d_%s' % "train", output_2d_image, epoch_1000x)
                # Per-parameter histograms are disabled: add_histogram raises errors
                # (see https://github.com/pytorch/pytorch/issues/91516). To use them,
                # install tensorboardX and import SummaryWriter from it in main_worker.py.
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

Run Fine-tuning

In [None]:
# Fine-tune the pretrained ViT-Large (weights from hicfoundation_pretrain/, produced
# by the rename cell above) on the (downsampled -> original) window pairs in
# ft-inputs/train and ft-inputs/val.
# NOTE(review): the pairs were cut at 224x224 (scan_array_diag.py --input_row_size 224),
# but this run passes --input_row_size/--input_col_size 448, while inference below
# uses 224. Confirm how finetune.py handles the size mismatch (pad, resize or error).
!python3 finetune.py --batch_size 1 --accum_iter 4 \
    --epochs 1 --warmup_epochs 0 --pin_mem \
    --blr 1e-3 --min_lr 1e-7 --weight_decay 0.05 \
    --layer_decay 0.75 --model vit_large_patch16 \
    --pretrain hicfoundation_pretrain/model/model_best.pth.tar \
    --finetune 1 --seed 888 \
    --loss_type 1 --data_path "ft-inputs" \
    --train_config "train_config.txt" \
    --valid_config "val_config.txt" \
    --output "hicfoundation_finetune" --tensorboard 1 \
    --world_size 1 --dist_url "tcp://localhost:10001" --rank 0 \
    --input_row_size 448 --input_col_size 448 --patch_size 16 \
    --print_freq 1 --save_freq 1

8. Inference
Run Inference

In [None]:
# Update the filename below to match your uploaded test file
# --task 3 presumably selects the resolution-enhancement task; confirm against
# inference.py's --task help.
# NOTE(review): the training pickles were built at 25000 bp (hic2array ... 25000),
# but this call uses --resolution 10000; confirm that is intended.
# NOTE(review): --model and --model_path point to the same checkpoint; check
# inference.py's argparse options to see which one it actually reads.
!python inference.py --batch_size 1 \
    --input hic-raw/B1-GSM4705442_cmt2cmt3.hic \
    --resolution 10000 \
    --task 3 \
    --input_row_size 224 --input_col_size 224 \
    --stride 32 --bound 0 \
    --num_workers 1 \
    --model hicfoundation_finetune/model/model_best.pth.tar \
    --model_path hicfoundation_finetune/model/model_best.pth.tar \
    --output outputs/B1_enhanced

Download Results

In [None]:
# Zip everything under outputs/ and download the archive
import os
import zipfile

# Import here so `files` is the Colab helper module even on a fresh kernel or
# after another cell has rebound the name.
from google.colab import files

zip_filename = 'hicfoundation_results.zip'
n_written = 0
with zipfile.ZipFile(zip_filename, 'w') as zipf:
    # The loop variable must not be called `files`: that would shadow the Colab
    # module and make files.download() below fail.
    for root, _subdirs, filenames in os.walk('outputs'):
        for filename in filenames:
            file_path = os.path.join(root, filename)
            zipf.write(file_path, os.path.relpath(file_path, '.'))
            n_written += 1

if n_written == 0:
    print("WARNING: no files found under outputs/; the zip archive is empty.")

# Download the zip file
files.download(zip_filename)
print(f"Results downloaded as {zip_filename}")

Done! I apologize if there are any errors; this is my first time using Jupyter Notebook. However, all of the commands and steps are exactly the ones I used to obtain the results in the final paper and presentation.