<a href="https://colab.research.google.com/github/liu-bioinfo-lab/general_AI_model/blob/main/epcotv2_basic_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Open this notebook in Google Colab with a TPU or GPU runtime selected: **Runtime -> Change runtime type**.

In [None]:
!git clone https://github.com/liu-bioinfo-lab/general_AI_model.git
%cd general_AI_model

Cloning into 'general_AI_model'...
remote: Enumerating objects: 141, done.
remote: Counting objects: 100% (30/30), done.
remote: Compressing objects: 100% (24/24), done.
remote: Total 141 (delta 15), reused 13 (delta 6), pack-reused 111 (from 2)
Receiving objects: 100% (141/141), 68.62 MiB | 11.44 MiB/s, done.
Resolving deltas: 100% (16/16), done.
/content/general_AI_model


In [None]:
import os
import gdown
from src.model import build_model
import argparse
import torch
try:
  import torch_xla
  import torch_xla.core.xla_model as xm
  import torch_xla.distributed.xla_multiprocessing as xmp
except Exception as e:
  print(f" Not using torch_xla")
!pip install kipoiseq==0.5.2 --quiet > /dev/null
import kipoiseq
from kipoiseq import Interval
import pyfaidx
import pickle
import numpy as np
from src.tutorial_utils import FastaStringExtractor, prepare_input, extract_outputs, get_args

 Not using torch_xla


**Download Model**

In [None]:
os.makedirs('models', exist_ok=True)
model_path = 'models/ckpt.pt'
gdown.download('https://drive.google.com/uc?id=1aTpGvAUkvaxsDP_isA2n2Udbfqa9walW', model_path, quiet=False)

Downloading...
From (original): https://drive.google.com/uc?id=1aTpGvAUkvaxsDP_isA2n2Udbfqa9walW
From (redirected): https://drive.google.com/uc?id=1aTpGvAUkvaxsDP_isA2n2Udbfqa9walW&confirm=t&uuid=fe309e31-648d-4103-95b2-2b04de8dfa02
To: /content/general_AI_model/models/ckpt.pt
100%|██████████| 468M/468M [00:07<00:00, 58.8MB/s]


'models/ckpt.pt'
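
A large download can be interrupted and leave a truncated file behind. The sketch below is one way to check the checkpoint before loading it. The helper name `looks_downloaded` and the 400 MB threshold are assumptions made for illustration (the progress bar above reports about 468 MB); neither comes from the repository.

```python
import os

def looks_downloaded(path, min_bytes):
    """Return True if `path` is a regular file of at least `min_bytes` bytes."""
    return os.path.isfile(path) and os.path.getsize(path) >= min_bytes

# Hypothetical threshold, chosen a little below the ~468 MB reported above.
if not looks_downloaded('models/ckpt.pt', 400 * 1000 * 1000):
    print('Checkpoint missing or truncated; rerun the download cell.')
```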

In [None]:

### The following code is copied from https://github.com/deepmind/deepmind-research/blob/master/enformer/enformer-usage.ipynb
fasta_file = '/root/data/genome.fa'
!mkdir -p /root/data
!wget -O - http://hgdownload.cse.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz | gunzip -c > {fasta_file}
pyfaidx.Faidx(fasta_file)
!ls /root/data

fasta_extractor = FastaStringExtractor(fasta_file)

--2025-09-14 05:10:26--  http://hgdownload.cse.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz
Resolving hgdownload.cse.ucsc.edu (hgdownload.cse.ucsc.edu)... 128.114.119.163
Connecting to hgdownload.cse.ucsc.edu (hgdownload.cse.ucsc.edu)|128.114.119.163|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 983659424 (938M) [application/x-gzip]
Saving to: ‘STDOUT’


2025-09-14 05:11:26 (15.9 MB/s) - written to stdout [983659424/983659424]

genome.fa  genome.fa.fai
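
`pyfaidx.Faidx` writes `genome.fa.fai`, a tab-separated index whose first two columns are the sequence name and its length. The sketch below reads that index with the standard library so a region can be checked against its chromosome length before it is passed to the model. The helper names `read_chrom_sizes` and `region_in_bounds` are hypothetical and not part of the repository.

```python
def read_chrom_sizes(fai_path):
    """Parse a .fai index into a {sequence_name: length} dictionary."""
    sizes = {}
    with open(fai_path) as handle:
        for line in handle:
            fields = line.rstrip('\n').split('\t')
            if len(fields) >= 2:
                sizes[fields[0]] = int(fields[1])
    return sizes

def region_in_bounds(sizes, chrom, start, end):
    """Return True if 0 <= start < end <= length of `chrom`."""
    return chrom in sizes and 0 <= start < end <= sizes[chrom]

# Usage in the notebook (index path taken from the cell above):
# sizes = read_chrom_sizes('/root/data/genome.fa.fai')
# region_in_bounds(sizes, 'chr1', 1500000, 2000000)
```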


In [None]:
# Download a downsampled GM12878 ATAC-seq for example
os.makedirs('tmp_save', exist_ok=True)
atac_path = 'tmp_save/GM12878_ATAC.pickle'
gdown.download('https://drive.google.com/uc?id=1ua-fQHYjPH658oEKEpIaDBHNFbzsO1m0', atac_path, quiet=False)
with open(atac_path, 'rb') as f:
    atac_data = pickle.load(f)

Downloading...
From (original): https://drive.google.com/uc?id=1ua-fQHYjPH658oEKEpIaDBHNFbzsO1m0
From (redirected): https://drive.google.com/uc?id=1ua-fQHYjPH658oEKEpIaDBHNFbzsO1m0&confirm=t&uuid=3b4b888c-4fd6-4057-9d2b-99e5661747a6
To: /content/general_AI_model/tmp_save/GM12878_ATAC.pickle
100%|██████████| 192M/192M [00:02<00:00, 68.3MB/s]


### Load model

In [None]:
args = get_args()
try:
    device = xm.xla_device()
    print(f"XLA device detected: {device}")
except Exception as e:
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f"No XLA device detected. Error: {e}")

model = build_model(args)
model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=True)
model.eval()
model.to(device)
device

No XLA device detected. Error: name 'xm' is not defined


device(type='cuda', index=0)
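
The cell above falls back to CUDA or CPU only because referencing the undefined name `xm` raises an exception, which is why the printed message reads "name 'xm' is not defined". A more explicit selection could look like the sketch below. It assumes that an importable `torch_xla` package is a reasonable proxy for an XLA runtime, which may not hold in every environment; `pick_device_name` is a hypothetical helper.

```python
import importlib.util

def pick_device_name():
    """Return 'xla', 'cuda:0' or 'cpu' depending on what is installed."""
    # Assumption: torch_xla being importable means an XLA device is usable.
    if importlib.util.find_spec('torch_xla') is not None:
        return 'xla'
    if importlib.util.find_spec('torch') is not None:
        import torch
        if torch.cuda.is_available():
            return 'cuda:0'
    return 'cpu'

name = pick_device_name()
print(f'Selected device: {name}')
# In the notebook, 'xla' would map to xm.xla_device() and the
# other two names to torch.device(name).
```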

### Run model to get 1D and 2D representations over a 500kb region

In [None]:
### specify the coordinates of a 500kb genomic region

chrom, start, end = ['chr1', 1500000, 2000000]

input_x = prepare_input(
    fasta_extractor,
    chrom, start, end,
    atac_data
).to(device)

# Inference only: disable gradient tracking to save memory
with torch.no_grad():
    outputs = model(input_x, return_rep=True)

rep1d, rep2d = extract_outputs(outputs)
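
The representations returned here have 500 bins of 1 kb, so the input window presumably has to span exactly 500 kb. The sketch below checks that before any computation is spent on the region. `check_region` and its `expected_len` default are assumptions for illustration, not part of `src.tutorial_utils`.

```python
def check_region(start, end, expected_len=500_000):
    """Raise ValueError unless [start, end) spans exactly `expected_len` bp."""
    if start < 0 or end <= start:
        raise ValueError(f'invalid interval: {start}-{end}')
    if end - start != expected_len:
        raise ValueError(f'region is {end - start} bp, expected {expected_len} bp')
    return end - start

check_region(1500000, 2000000)  # the region used above passes the check
```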

### Embedding size

In [None]:
## Representation of each 1 kb bin in the 500 kb region: 500 bins x 960 features

rep1d.shape

(500, 960)
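
To relate a row of `rep1d` to genomic coordinates, the sketch below converts between bin indices and positions. It assumes that rows are ordered from the region start and that bins are contiguous and non-overlapping, which the notebook does not state explicitly. Both helper names are hypothetical.

```python
def bin_to_interval(chrom, region_start, bin_index, bin_size=1000, n_bins=500):
    """Return (chrom, start, end) covered by row `bin_index` of rep1d."""
    if not 0 <= bin_index < n_bins:
        raise IndexError(f'bin_index {bin_index} outside 0..{n_bins - 1}')
    start = region_start + bin_index * bin_size
    return chrom, start, start + bin_size

def position_to_bin(region_start, position, bin_size=1000, n_bins=500):
    """Return the row of rep1d whose bin contains `position`."""
    index = (position - region_start) // bin_size
    if not 0 <= index < n_bins:
        raise IndexError(f'position {position} is outside the region')
    return index

print(bin_to_interval('chr1', 1500000, 0))   # ('chr1', 1500000, 1501000)
print(position_to_bin(1500000, 1750250))     # 250
```

With these, `rep1d[position_to_bin(start, pos)]` would give the 960-dimensional vector for the bin containing `pos`.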

In [None]:
## Representation of the interaction between each pair of 1 kb bins in the 500 kb region: 500 x 500 pairs x 96 features

rep2d.shape

(500, 500, 96)
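
Entry `rep2d[i, j]` holds the 96 features for the pair of bins `i` and `j`. When only short-range interactions are of interest, the pairs can be enumerated by genomic distance, as in the sketch below. It assumes 1 kb bins indexed from the region start; `pairs_within` is a hypothetical helper.

```python
def pairs_within(max_distance_bp, n_bins=500, bin_size=1000):
    """Yield (i, j), i <= j, whose bin starts are at most `max_distance_bp` apart."""
    max_offset = max_distance_bp // bin_size
    for i in range(n_bins):
        last = min(i + max_offset, n_bins - 1)
        for j in range(i, last + 1):
            yield i, j

# Small example with 5 bins and a 2 kb cutoff
pairs = list(pairs_within(2000, n_bins=5))
print(len(pairs))  # 12
```

In the notebook, `[rep2d[i, j] for i, j in pairs_within(10000)]` would collect the pair representations within 10 kb.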

### Prediction

In [None]:
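# Names of the output heads, in the order they are paired with the model outputs
modalities = ['epi', 'rna', 'bru', 'microc', 'hic', 'intacthic', 'rna_strand', 'external_tf', 'tt', 'groseq', 'grocap', 'proseq', 'netcage', 'starr']
_, _, output, external_output = outputs
# Detach each prediction from the graph and move it to the CPU as a NumPy array
mix_output = [out.detach().cpu().numpy() for out in (output + external_output)]
out_dic = dict(zip(modalities, mix_output))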

In [None]:
# The last dimension indexes the individual tracks predicted within each modality
for m in modalities:
    print(m,out_dic[m].shape)

epi (1, 500, 247)
rna (1, 500, 3)
bru (1, 500, 3)
microc (1, 500, 500, 2)
hic (1, 100, 100, 3)
intacthic (1, 500, 500, 2)
rna_strand (1, 500, 2)
external_tf (1, 500, 708)
tt (1, 500, 2)
groseq (1, 500, 2)
grocap (1, 500, 4)
proseq (1, 500, 3)
netcage (1, 500, 2)
starr (1, 500, 1)
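
Most outputs have 500 positions, but `hic` has shape `(1, 100, 100, 3)`. If every output covers the same 500 kb window (an assumption, since the notebook does not say so), the bin size of each modality follows from its shape. `bin_size_bp` is a hypothetical helper.

```python
REGION_BP = 500_000  # length of the region used above

def bin_size_bp(shape, region_bp=REGION_BP):
    """Infer the bin size from an output shape of the form (batch, n_bins, ...)."""
    n_bins = shape[1]
    if region_bp % n_bins:
        raise ValueError(f'{n_bins} bins do not divide {region_bp} bp evenly')
    return region_bp // n_bins

print(bin_size_bp((1, 500, 247)))     # 1000
print(bin_size_bp((1, 100, 100, 3)))  # 5000
```

Under that assumption, `hic` would be predicted at 5 kb resolution and the other modalities at 1 kb.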


### Explanation of each modality that can be predicted

* __Epigenomic features (epi).__ The list of epigenomic features can be found in "data/epi_list".

* __RNA-seq (rna).__
  - CAGE-seq
  - Total RNA-seq
  - PolyA+ RNA-seq

* __Bru-seq (bru).__
  - Bru-seq
  - BruUV-seq
  - BruChase-seq

* __Micro-C (microc).__
  - O/E normalized Micro-C
  - KR normalized Micro-C

* __Hi-C (hic).__
  - CTCF ChIA-PET
  - RNApol2 ChIA-PET
  - Hi-C

* __Intact Hi-C (intacthic).__
  - O/E normalized intact Hi-C
  - KR normalized intact Hi-C

* __RNA Strand (rna_strand).__
  - Total RNA-seq (forward)
  - Total RNA-seq (reverse)

* __Additional TFs (external_tf).__ The list of additional TFs can be found on GitHub in the file unseen_tf.txt.

* __TT-seq (tt).__
  - TT-seq (forward)
  - TT-seq (reverse)

* __GRO-seq (groseq).__
  - GRO-seq (forward)
  - GRO-seq (reverse)

* __GRO-cap (grocap).__
  - GRO-cap (forward)
  - GRO-cap (reverse)
  - GRO-cap_wTAP (forward)
  - GRO-cap_wTAP (reverse)

* __PRO-seq (proseq).__
  - PRO-seq (forward)
  - PRO-seq (reverse)
  - PRO-cap

* __NET-CAGE (netcage).__
  - NET-CAGE (forward)
  - NET-CAGE (reverse)

* __STARR-seq (starr).__ STARR-seq
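
The list above can be turned into a lookup table so that a single track is selected by name instead of by a bare index. The sketch below copies the track names from the list and assumes that their order matches the order of the last dimension of each output, which the notebook does not confirm. `TRACKS`, `N_CHANNELS` and `track_index` are hypothetical names; `epi` and `external_tf` are omitted because their track lists live in separate files.

```python
# Track names copied from the list above. ASSUMPTION: list order == channel order.
TRACKS = {
    'rna': ['CAGE-seq', 'Total RNA-seq', 'PolyA+ RNA-seq'],
    'bru': ['Bru-seq', 'BruUV-seq', 'BruChase-seq'],
    'microc': ['O/E normalized Micro-C', 'KR normalized Micro-C'],
    'hic': ['CTCF ChIA-PET', 'RNApol2 ChIA-PET', 'Hi-C'],
    'intacthic': ['O/E normalized intact Hi-C', 'KR normalized intact Hi-C'],
    'rna_strand': ['Total RNA-seq (forward)', 'Total RNA-seq (reverse)'],
    'tt': ['TT-seq (forward)', 'TT-seq (reverse)'],
    'groseq': ['GRO-seq (forward)', 'GRO-seq (reverse)'],
    'grocap': ['GRO-cap (forward)', 'GRO-cap (reverse)',
               'GRO-cap_wTAP (forward)', 'GRO-cap_wTAP (reverse)'],
    'proseq': ['PRO-seq (forward)', 'PRO-seq (reverse)', 'PRO-cap'],
    'netcage': ['NET-CAGE (forward)', 'NET-CAGE (reverse)'],
    'starr': ['STARR-seq'],
}

# Size of the last dimension of each output, taken from the shapes printed earlier.
N_CHANNELS = {'rna': 3, 'bru': 3, 'microc': 2, 'hic': 3, 'intacthic': 2,
              'rna_strand': 2, 'tt': 2, 'groseq': 2, 'grocap': 4,
              'proseq': 3, 'netcage': 2, 'starr': 1}

def track_index(modality, name):
    """Return the assumed channel index of `name` within `modality`."""
    return TRACKS[modality].index(name)

# Sanity check: the number of names matches the number of predicted channels.
for modality, n_channels in N_CHANNELS.items():
    assert len(TRACKS[modality]) == n_channels, modality

print(track_index('rna', 'CAGE-seq'))  # 0
```

In the notebook, `out_dic['rna'][0, :, track_index('rna', 'CAGE-seq')]` would then return the predicted CAGE-seq signal over the 500 bins, provided the ordering assumption holds.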