In [None]:
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "N0PCNp9HSruI",
   "metadata": {
    "id": "N0PCNp9HSruI"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "import os\n",
    "import json\n",
    "import pandas as pd\n",
    "import pyfaidx\n",
    "import kipoiseq\n",
    "import functools\n",
    "from kipoiseq import Interval\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "82a83ca5-2c8f-46c0-a540-295fada652ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-05-05 02:26:30.636126: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "64692245-be28-49ce-9fd9-9c104336f5e0",
   "metadata": {
    "id": "64692245-be28-49ce-9fd9-9c104336f5e0"
   },
   "outputs": [],
   "source": [
    "# Data paths. Set the ENFORMER_DIR environment variable to relocate them;\n",
    "# the default reproduces the original cluster layout.\n",
    "ENFORMER_DIR = os.environ.get('ENFORMER_DIR', '/work/magroup/4DN/Enformer')\n",
    "human_fasta_path = os.path.join(ENFORMER_DIR, 'hg38.ml.fa')\n",
    "mouse_fasta_path = os.path.join(ENFORMER_DIR, 'mm38.ml.fa')\n",
    "data_path = os.path.join(ENFORMER_DIR, 'data')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "tUshpGJOSvJ6",
   "metadata": {
    "id": "tUshpGJOSvJ6"
   },
   "outputs": [],
   "source": [
    "# Model geometry: one-hot input length and number of target rows per example.\n",
    "BIN_SIZE = 128\n",
    "SEQUENCE_LENGTH = 1_536 * BIN_SIZE  # == 196_608 bp of one-hot input sequence\n",
    "TARGET_LENGTH = 896  # target rows (bins) per example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "64b61c32-a307-48da-994c-d56c2e9ad6a7",
   "metadata": {
    "id": "64b61c32-a307-48da-994c-d56c2e9ad6a7",
    "outputId": "5176bad3-9792-4b48-f944-acdd7c199b3c"
   },
   "outputs": [],
   "source": [
    "class FastaStringExtractor:\n",
    "    \"\"\"Read upper-cased sequence strings from a FASTA file via pyfaidx.\"\"\"\n",
    "\n",
    "    def __init__(self, fasta_file):\n",
    "        self.fasta = pyfaidx.Fasta(fasta_file)\n",
    "        # Record lengths, used by extract() to clip intervals to the chromosome.\n",
    "        self._chromosome_sizes = {k: len(v) for k, v in self.fasta.items()}\n",
    "\n",
    "    def extract(self, interval: Interval, **kwargs) -> str:\n",
    "        \"\"\"Return the upper-cased bases of `interval` (0-based start).\n",
    "\n",
    "        Any part of `interval` before position 0 or past the chromosome end is\n",
    "        not read from the FASTA; it is filled with 'N' instead. `kwargs` are\n",
    "        accepted and ignored.\n",
    "        \"\"\"\n",
    "        # Truncate interval if it extends beyond the chromosome lengths.\n",
    "        chromosome_length = self._chromosome_sizes[interval.chrom]\n",
    "        trimmed_interval = Interval(interval.chrom,\n",
    "                                    max(interval.start, 0),\n",
    "                                    min(interval.end, chromosome_length),\n",
    "                                    )\n",
    "        # pyfaidx wants a 1-based interval, hence start + 1.\n",
    "        sequence = str(self.fasta.get_seq(trimmed_interval.chrom,\n",
    "                                          trimmed_interval.start + 1,\n",
    "                                          trimmed_interval.stop).seq).upper()\n",
    "        # Fill truncated values with N's.\n",
    "        pad_upstream = 'N' * max(-interval.start, 0)\n",
    "        pad_downstream = 'N' * max(interval.end - chromosome_length, 0)\n",
    "        return pad_upstream + sequence + pad_downstream\n",
    "\n",
    "    def close(self):\n",
    "        \"\"\"Close the underlying pyfaidx handle.\"\"\"\n",
    "        return self.fasta.close()\n",
    "\n",
    "\n",
    "class BasenjiDataSet(torch.utils.data.IterableDataset):\n",
    "  \"\"\"Iterable of Basenji TFRecord targets paired with re-extracted FASTA sequences.\n",
    "\n",
    "  Targets come from `<data_path>/<organism>/tfrecords/<subset>-*.tfr`; the input\n",
    "  sequence for each example is re-read from `fasta_path` around the matching\n",
    "  `sequences.bed` region, resized to `seq_len`. If `n_to_test >= 0`, the first\n",
    "  `n_to_test` examples are checked against the sequence stored in the TFRecord.\n",
    "  \"\"\"\n",
    "  @staticmethod\n",
    "  def get_organism_path(organism):\n",
    "    \"\"\"Return the per-organism data directory under `data_path`.\"\"\"\n",
    "    return os.path.join(data_path, organism)\n",
    "\n",
    "  @classmethod\n",
    "  def get_metadata(cls, organism):\n",
    "    \"\"\"Load the organism's `statistics.json`.\n",
    "\n",
    "    Keys: num_targets, train_seqs, valid_seqs, test_seqs, seq_length,\n",
    "    pool_width, crop_bp, target_length.\n",
    "    \"\"\"\n",
    "    path = os.path.join(cls.get_organism_path(organism), 'statistics.json')\n",
    "    with tf.io.gfile.GFile(path, 'r') as f:\n",
    "      return json.load(f)\n",
    "\n",
    "  @staticmethod\n",
    "  def one_hot_encode(sequence):\n",
    "    \"\"\"One-hot encode a DNA string with kipoiseq, as float32.\"\"\"\n",
    "    return kipoiseq.transforms.functional.one_hot_dna(sequence).astype(np.float32)\n",
    "\n",
    "  @classmethod\n",
    "  def get_tfrecord_files(cls, organism, subset):\n",
    "    \"\"\"List `<subset>-*.tfr` shards, sorted by their trailing integer index.\"\"\"\n",
    "    # Numeric (not lexicographic) sort, so e.g. shard 10 follows shard 9.\n",
    "    return sorted(tf.io.gfile.glob(os.path.join(\n",
    "        cls.get_organism_path(organism), 'tfrecords', f'{subset}-*.tfr'\n",
    "      )), key=lambda x: int(x.split('-')[-1].split('.')[0]))\n",
    "\n",
    "  @property\n",
    "  def num_channels(self):\n",
    "    \"\"\"Number of target tracks (`num_targets` in statistics.json).\"\"\"\n",
    "    metadata = self.get_metadata(self.organism)\n",
    "    return metadata['num_targets']\n",
    "\n",
    "  @staticmethod\n",
    "  def deserialize(serialized_example, metadata):\n",
    "    \"\"\"Deserialize bytes stored in TFRecordFile.\n",
    "\n",
    "    Returns float32 tensors: 'sequence_old' of shape (seq_length, 4) and\n",
    "    'target' of shape (target_length, num_targets).\n",
    "    \"\"\"\n",
    "    feature_map = {\n",
    "        'sequence': tf.io.FixedLenFeature([], tf.string),  # Ignore this, resize our own bigger one\n",
    "        'target': tf.io.FixedLenFeature([], tf.string),\n",
    "    }\n",
    "    example = tf.io.parse_example(serialized_example, feature_map)\n",
    "    sequence = tf.io.decode_raw(example['sequence'], tf.bool)\n",
    "    sequence = tf.reshape(sequence, (metadata['seq_length'], 4))\n",
    "    sequence = tf.cast(sequence, tf.float32)\n",
    "\n",
    "    target = tf.io.decode_raw(example['target'], tf.float16)\n",
    "    target = tf.reshape(target,\n",
    "                        (metadata['target_length'], metadata['num_targets']))\n",
    "    target = tf.cast(target, tf.float32)\n",
    "\n",
    "    return {'sequence_old': sequence,\n",
    "            'target': target}\n",
    "\n",
    "  @classmethod\n",
    "  def get_dataset(cls, organism, subset, num_threads=8):\n",
    "    \"\"\"Build a tf.data pipeline of deserialized examples from ZLIB-compressed TFRecords.\"\"\"\n",
    "    metadata = cls.get_metadata(organism)\n",
    "    dataset = tf.data.TFRecordDataset(cls.get_tfrecord_files(organism, subset),\n",
    "                                      compression_type='ZLIB',\n",
    "                                      num_parallel_reads=num_threads).map(\n",
    "                                          functools.partial(cls.deserialize, metadata=metadata)\n",
    "                                      )\n",
    "    return dataset\n",
    "\n",
    "  def __init__(self, organism:str, subset:str, seq_len:int, fasta_path:str, n_to_test:int = -1):\n",
    "    assert subset in {\"train\", \"valid\", \"test\"}\n",
    "    assert organism in {\"human\", \"mouse\"}\n",
    "    self.organism = organism\n",
    "    self.subset = subset\n",
    "    self.base_dir = self.get_organism_path(organism)\n",
    "    self.seq_len = seq_len\n",
    "    self.fasta_reader = FastaStringExtractor(fasta_path)\n",
    "    self.n_to_test = n_to_test\n",
    "    with tf.io.gfile.GFile(f\"{self.base_dir}/sequences.bed\", 'r') as f:\n",
    "      region_df = pd.read_csv(f, sep=\"\\t\", header=None)\n",
    "      region_df.columns = ['chrom', 'start', 'end', 'subset']\n",
    "      # `@subset` resolves to the `subset` argument of this method.\n",
    "      self.region_df = region_df.query('subset==@subset').reset_index(drop=True)\n",
    "\n",
    "  @staticmethod\n",
    "  def _center_crop(array, length):\n",
    "    \"\"\"Return the central `length` rows of `array` (an odd surplus drops one extra row at the end).\"\"\"\n",
    "    start = (array.shape[0] - length) // 2\n",
    "    return array[start:start + length]\n",
    "\n",
    "  def __iter__(self):\n",
    "    \"\"\"Yield {'sequence', 'target'} dicts, one per example of this subset.\n",
    "\n",
    "    Example i is paired with row i of `region_df`, so the TFRecords must be\n",
    "    read in file order.\n",
    "    \"\"\"\n",
    "    worker_info = torch.utils.data.get_worker_info()\n",
    "    assert worker_info is None, \"Only support single process loading\"\n",
    "    # num_threads=1 keeps records in file order; parallel reads would interleave\n",
    "    # shards and break the pairing with region_df (the n_to_test check catches this).\n",
    "    basenji_iterator = self.get_dataset(self.organism, self.subset, num_threads=1).as_numpy_iterator()\n",
    "    for i, records in enumerate(basenji_iterator):\n",
    "      loc_row = self.region_df.iloc[i]\n",
    "      target_interval = Interval(loc_row['chrom'], loc_row['start'], loc_row['end'])\n",
    "      sequence_one_hot = self.one_hot_encode(self.fasta_reader.extract(target_interval.resize(self.seq_len)))\n",
    "      if self.n_to_test >= 0 and i < self.n_to_test:\n",
    "        old_sequence_onehot = records[\"sequence_old\"]\n",
    "        n_old = old_sequence_onehot.shape[0]\n",
    "        n_new = sequence_one_hot.shape[0]\n",
    "        # Compare the shared central window. An explicit end index avoids the\n",
    "        # empty `[0:-0]` slice (length difference of 1) and the off-by-one\n",
    "        # length mismatch of `[trim:-trim]` when the difference is odd.\n",
    "        if n_old > n_new:\n",
    "          np.testing.assert_equal(self._center_crop(old_sequence_onehot, n_new), sequence_one_hot)\n",
    "        elif n_new > n_old:\n",
    "          np.testing.assert_equal(old_sequence_onehot, self._center_crop(sequence_one_hot, n_old))\n",
    "        else:\n",
    "          np.testing.assert_equal(old_sequence_onehot, sequence_one_hot)\n",
    "      yield {\n",
    "          \"sequence\": sequence_one_hot,\n",
    "          # Writable copy: the as-is array is presumably the read-only one behind\n",
    "          # torch's 'NumPy array is not writable' collate warning.\n",
    "          \"target\": np.array(records[\"target\"]),\n",
    "      }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a3edeba1-fd93-4725-a34d-605ee17c59f8",
   "metadata": {
    "id": "a3edeba1-fd93-4725-a34d-605ee17c59f8"
   },
   "outputs": [],
   "source": [
    "# Train human\n",
    "organism = \"human\"\n",
    "subset = \"train\"\n",
    "\n",
    "max_steps = -1  # NOTE(review): not read by any cell in this notebook\n",
    "if organism == \"human\":\n",
    "    fasta_path = human_fasta_path\n",
    "else:\n",
    "    fasta_path = mouse_fasta_path\n",
    "\n",
    "ds = BasenjiDataSet(organism, subset, SEQUENCE_LENGTH, fasta_path)\n",
    "total = len(ds.region_df)  # number of regions (records) in this split\n",
    "\n",
    "# BasenjiDataSet.__iter__ asserts single-process loading, so keep num_workers=0.\n",
    "train_human_loader = torch.utils.data.DataLoader(ds, batch_size=1, num_workers=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "28dd2cce-2293-47b7-b2a6-4a9901aba67d",
   "metadata": {
    "id": "28dd2cce-2293-47b7-b2a6-4a9901aba67d"
   },
   "outputs": [],
   "source": [
    "# losses and metrics\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "def poisson_loss(pred, target, eps = 1e-20):\n",
    "    \"\"\"Mean Poisson negative log-likelihood, omitting the target-only log(target!) term.\n",
    "\n",
    "    `pred` is clamped to at least `eps` before the log so that zero\n",
    "    predictions give a finite loss rather than inf/nan.\n",
    "    \"\"\"\n",
    "    return (pred - target * torch.log(pred.clamp(min = eps))).mean()\n",
    "\n",
    "\n",
    "def pearson_corr_coef(x, y, dim = 1, reduce_dims = (-1,)):\n",
    "    \"\"\"Pearson correlation of `x` and `y` along `dim`, averaged over `reduce_dims`.\n",
    "\n",
    "    Centering both inputs along `dim` makes their cosine similarity equal\n",
    "    to Pearson's r.\n",
    "    \"\"\"\n",
    "    x_centered = x - x.mean(dim = dim, keepdim = True)\n",
    "    y_centered = y - y.mean(dim = dim, keepdim = True)\n",
    "    return F.cosine_similarity(x_centered, y_centered, dim = dim).mean(dim = reduce_dims)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "76d4ab83-f175-49ee-94bc-b5890dba4d59",
   "metadata": {
    "id": "76d4ab83-f175-49ee-94bc-b5890dba4d59",
    "outputId": "d20bd00e-e288-460b-d061-95dc1180f322"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/wenduoc/mambaforge/envs/enformer/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py:171: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)\n",
      "  return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 196608, 4])\n",
      "torch.Size([1, 896, 5313])\n"
     ]
    }
   ],
   "source": [
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "# Smoke test: move the first batch to `device` and print the tensor shapes.\n",
    "for i, batch in enumerate(train_human_loader):\n",
    "    batch_gpu = {key: tensor.to(device) for key, tensor in batch.items()}\n",
    "    seq, target = batch_gpu['sequence'], batch_gpu['target']\n",
    "    print(seq.shape)\n",
    "    print(target.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a783f0f7-cfa1-4c7c-adc4-7931177a461e",
   "metadata": {
    "id": "a783f0f7-cfa1-4c7c-adc4-7931177a461e"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc291d63-0d0e-47e4-bb3f-c262f8b06cb8",
   "metadata": {
    "id": "cc291d63-0d0e-47e4-bb3f-c262f8b06cb8"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
