<a href="https://colab.research.google.com/github/liu-bioinfo-lab/general_AI_model/blob/main/runbook_obtaining_representations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Before running this notebook in Google Colab, select a TPU or GPU runtime: **Runtime -> Change runtime type**

In [1]:
!git clone https://github.com/liu-bioinfo-lab/general_AI_model.git
%cd general_AI_model

Cloning into 'general_AI_model'...
remote: Enumerating objects: 133, done.[K
remote: Counting objects: 100% (22/22), done.[K
remote: Compressing objects: 100% (17/17), done.[K
remote: Total 133 (delta 10), reused 11 (delta 5), pack-reused 111 (from 2)[K
Receiving objects: 100% (133/133), 68.62 MiB | 39.12 MiB/s, done.
Resolving deltas: 100% (11/11), done.
/content/general_AI_model


In [2]:
import os
import gdown
from src.model import build_model
import argparse
import torch
try:
  import torch_xla
  import torch_xla.core.xla_model as xm
  import torch_xla.distributed.xla_multiprocessing as xmp
except Exception as e:
  print(f" Error: {e}")
!pip install kipoiseq==0.5.2 --quiet > /dev/null
import kipoiseq
from kipoiseq import Interval
import pyfaidx
import pickle
import numpy as np
from src.tutorial_utils import FastaStringExtractor, prepare_input, extract_outputs, get_args

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
referencing 0.36.2 requires attrs>=22.2.0, but you have attrs 21.4.0 which is incompatible.
jsonschema 4.24.0 requires attrs>=22.2.0, but you have attrs 21.4.0 which is incompatible.[0m[31m
[0m



**Download Model**

In [3]:
# The pretrained checkpoint is stored under models/, relative to the repository root.
os.makedirs('models', exist_ok=True)
model_path = 'models/ckpt.pt'
# The checkpoint is a large file (~468 MB); reuse an existing copy so that
# re-running the notebook does not download it again.
if not os.path.exists(model_path):
    gdown.download('https://drive.google.com/uc?id=1aTpGvAUkvaxsDP_isA2n2Udbfqa9walW', model_path, quiet=False)
# Show where the checkpoint lives (this was the cell's displayed value before as well).
model_path

Downloading...
From (original): https://drive.google.com/uc?id=1aTpGvAUkvaxsDP_isA2n2Udbfqa9walW
From (redirected): https://drive.google.com/uc?id=1aTpGvAUkvaxsDP_isA2n2Udbfqa9walW&confirm=t&uuid=981e3956-c687-4f46-af4b-eb4410b7f210
To: /content/general_AI_model/models/ckpt.pt
100%|██████████| 468M/468M [00:02<00:00, 183MB/s]


'models/ckpt.pt'

In [4]:

### The following codes are copied from https://github.com/deepmind/deepmind-research/blob/master/enformer/enformer-usage.ipynb
fasta_file = '/root/data/genome.fa'
!mkdir -p /root/data
!wget -O - http://hgdownload.cse.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz | gunzip -c > {fasta_file}
pyfaidx.Faidx(fasta_file)
!ls /root/data

fasta_extractor = FastaStringExtractor(fasta_file)

--2025-06-22 05:43:27--  http://hgdownload.cse.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz
Resolving hgdownload.cse.ucsc.edu (hgdownload.cse.ucsc.edu)... 128.114.119.163
Connecting to hgdownload.cse.ucsc.edu (hgdownload.cse.ucsc.edu)|128.114.119.163|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 983659424 (938M) [application/x-gzip]
Saving to: ‘STDOUT’


2025-06-22 05:43:51 (38.9 MB/s) - written to stdout [983659424/983659424]

genome.fa  genome.fa.fai


In [5]:
# Download a downsampled GM12878 ATAC-seq track to use as the example input.
os.makedirs('tmp_save', exist_ok=True)
atac_path = 'tmp_save/GM12878_ATAC.pickle'
# The file is ~192 MB; skip the download when a copy is already on disk.
if not os.path.exists(atac_path):
    gdown.download('https://drive.google.com/uc?id=1ua-fQHYjPH658oEKEpIaDBHNFbzsO1m0', atac_path, quiet=False)
# NOTE(security): pickle.load can execute arbitrary code while deserialising.
# Only load this file from the trusted project source above.
with open(atac_path, 'rb') as f:
    atac_data = pickle.load(f)

Downloading...
From (original): https://drive.google.com/uc?id=1ua-fQHYjPH658oEKEpIaDBHNFbzsO1m0
From (redirected): https://drive.google.com/uc?id=1ua-fQHYjPH658oEKEpIaDBHNFbzsO1m0&confirm=t&uuid=900c0949-ef09-4503-a524-cfd6aa8681fa
To: /content/general_AI_model/tmp_save/GM12878_ATAC.pickle
100%|██████████| 192M/192M [00:00<00:00, 233MB/s]
  atac_data = pickle.load(f)


### Load model

In [6]:
# Arguments consumed by build_model(); get_args comes from src.tutorial_utils.
args = get_args()

# Prefer an XLA device. Any failure here, including `xm` being undefined because
# torch_xla could not be imported, falls back to the first CUDA GPU or the CPU.
try:
    device = xm.xla_device()
    print(f"XLA device detected: {device}")
except Exception as err:
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    print(f"No XLA device detected. Error: {err}")

model = build_model(args)
# Deserialise the weights onto the CPU first; strict=True makes any mismatch between
# checkpoint keys and model parameters an error instead of silently skipping them.
# NOTE(review): torch.load unpickles the file, so only use trusted checkpoints.
model.load_state_dict(
    torch.load(model_path, map_location='cpu'),
    strict=True,
)
# Switch to evaluation mode, then move the module to the selected device.
model.eval()
model.to(device)
# Display the device that was selected.
device

XLA device detected: xla:0


device(type='xla', index=0)

### Run model to get 1D and 2D representations over a 500kb region

In [7]:
# Specify a 500kb region (end - start = 500,000 bp).
chrom, start, end = 'chr1', 1500000, 2000000

# Build the model input for the region from the genome sequence and the ATAC-seq
# data, then move it to the same device as the model.
input_x = prepare_input(
    fasta_extractor,
    chrom, start, end,
    atac_data
).to(device)

# Inference only: disable autograd so no computation graph or intermediate
# activations are kept for a backward pass that never happens.
with torch.no_grad():
    outputs = model(input_x, return_rep=True)

# Split the raw model outputs into the 1D (per-bin) and 2D (bin-pair) representations.
rep1d, rep2d = extract_outputs(outputs)

### Embedding sizes

In [8]:
## Representation of each 1kb bin in the 500kb region:
## the first axis indexes the bins (500kb / 1kb = 500), the second is the embedding.

rep1d.shape

(500, 960)

In [9]:
## Representation of interactions among the 1kb bins in the 500kb region:
## the first two axes index a pair of bins, the last is the embedding.

rep2d.shape

(500, 500, 96)