In [1]:
%cd ~/repo/protein-transfer

/home/t-fli/repo/protein-transfer


In [2]:
%load_ext blackcellmagic

In [3]:
from tables import *

In [4]:
import numpy as np

In [5]:
class Particle(IsDescription):
    """Row schema for a PyTables table: one attribute per column.

    No column is given an explicit ``pos``; the table repr shown further
    down in this notebook lists the columns in alphabetical order.
    """

    name = StringCol(itemsize=16)  # fixed-width string, 16 characters
    idnumber = Int64Col()  # signed 64-bit integer
    ADCcount = UInt16Col()  # unsigned 16-bit integer (unsigned short)
    TDCcount = UInt8Col()  # unsigned 8-bit integer (unsigned byte)
    grid_i = Int32Col()  # signed 32-bit integer
    grid_j = Int32Col()  # signed 32-bit integer
    pressure = Float32Col()  # single-precision float
    energy = Float64Col()  # double-precision float

In [6]:
# Create (or truncate) the HDF5 file, then add a group under the root node
# and an empty table inside it whose rows follow the Particle schema.
# The file is left open on purpose: later cells read from `h5file` and `table`.
h5file = open_file("tutorial1.h5", mode="w", title="Test file")
group = h5file.create_group(
    where="/",
    name="detector",
    title="Detector information",
)
table = h5file.create_table(
    where=group,
    name="readout",
    description=Particle,
    title="Readout example",
)

In [7]:
# Look the table node up again by its full path inside the open file.
table = h5file.get_node("/detector/readout")

In [8]:
# Show the table's repr (column description, byte order, chunk shape) as the cell output.
table

/detector/readout (Table(0,)) 'Readout example'
  description := {
  "ADCcount": UInt16Col(shape=(), dflt=0, pos=0),
  "TDCcount": UInt8Col(shape=(), dflt=0, pos=1),
  "energy": Float64Col(shape=(), dflt=0.0, pos=2),
  "grid_i": Int32Col(shape=(), dflt=0, pos=3),
  "grid_j": Int32Col(shape=(), dflt=0, pos=4),
  "idnumber": Int64Col(shape=(), dflt=0, pos=5),
  "name": StringCol(itemsize=16, shape=(), dflt=b'', pos=6),
  "pressure": Float32Col(shape=(), dflt=0.0, pos=7)}
  byteorder := 'little'
  chunkshape := (1394,)

In [3]:
from __future__ import annotations

import os
import tables

from collections import defaultdict

from scr.utils import get_folder_file_names, checkNgen_folder
from scr.encoding.encoding_classes import get_emb_info, OnehotEncoder
from scr.preprocess.data_process import ProtranDataset


class GenerateEmbeddings:
    """Generate embeddings for a dataset and save them to HDF5 files.

    All of the work happens in ``__init__``: for each of the "train", "val"
    and "test" subsets an ``embedding.h5`` file is (re)created under the
    dataset folder, one extendable array per embedding layer (``layer0``,
    ``layer1``, ...) is added to it, and the arrays are filled batch by batch
    from the encoder.
    """

    def __init__(
        self,
        dataset_path: str,
        encoder_name: str,
        reset_param: bool = False,
        resample_param: bool = False,
        embed_batch_size: int = 128,
        flatten_emb: bool | str = False,
        seq_start_idx: bool | int = False,
        seq_end_idx: bool | int = False,
        embed_folder: str = "embeddings",
        **encoder_params,
    ) -> None:

        """
        Args:
        - dataset_path: str, full path to the dataset, in pkl or panda readable format
            columns include: sequence, target, set, validation,
            mut_name (optional), mut_numb (optional)
        - encoder_name: str, the name of the encoder
        - reset_param: bool = False, if update the full model to xavier_uniform_
            ("-rand" is appended to embed_folder when set)
        - resample_param: bool = False, if update the full model to xavier_normal_
            ("-stat" is appended to embed_folder when set)
        - embed_batch_size: int, set to 0 to encode all in a single batch
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
        - seq_start_idx: bool | int = False, the index for the start of the sequence
        - seq_end_idx: bool | int = False, the index for the end of the sequence
        - embed_folder: str = "embeddings", the parent folder for all embeddings
        - encoder_params: kwarg, additional parameters for encoding

        Raises:
        - AssertionError if the encoder resolves to OnehotEncoder
        """

        self.dataset_path = dataset_path
        self.encoder_name = encoder_name
        self.reset_param = reset_param
        self.resample_param = resample_param
        self.flatten_emb = flatten_emb

        self.embed_folder = embed_folder

        # tag the output folder so embeddings from re-initialised models
        # are kept apart from the pretrained ones
        if self.reset_param and "-rand" not in self.embed_folder:
            self.embed_folder = f"{self.embed_folder}-rand"

        if self.resample_param and "-stat" not in self.embed_folder:
            self.embed_folder = f"{self.embed_folder}-stat"

        subset_list = ["train", "val", "test"]

        self.encoder_name, encoder_class, total_emb_layer = get_emb_info(
            self.encoder_name
        )

        print(encoder_class)

        assert encoder_class != OnehotEncoder, "Generate onehot on the fly instead"

        # get the encoder
        # NOTE(review): the raw `encoder_name` argument is passed on here, not the
        # value returned by get_emb_info (self.encoder_name) - confirm intended
        self._encoder = encoder_class(
            encoder_name=encoder_name,
            reset_param=reset_param,
            resample_param=resample_param,
            **encoder_params,
        )

        # get the dim of the array to be saved; the leading 0 marks the
        # extendable axis that grows with every appended batch
        # NOTE(review): a 2-D shape presumably requires flattened embeddings
        # (flatten_emb "mean" or "max") - confirm behaviour for flatten_emb=False
        earray_dim = (0, self._encoder.embed_dim)

        dataset_folder, _ = get_folder_file_names(
            parent_folder=self.embed_folder,
            dataset_path=self.dataset_path,
            encoder_name=self.encoder_name,
            embed_layer=0,
            flatten_emb=self.flatten_emb,
        )

        # Close all the open files
        # NOTE: this goes through a private PyTables registry and closes every
        # PyTables file currently open in this process, not only embedding files
        tables.file._open_files.close_all()

        for subset in subset_list:
            file_path = os.path.join(
                checkNgen_folder(os.path.join(dataset_folder, subset)), "embedding.h5"
            )

            # check all the embedding file h5 files
            # to remove old ones before generating new ones
            if os.path.isfile(file_path):
                print("Overwritting {0}".format(file_path))
                os.remove(file_path)

            # the context manager guarantees the file is flushed and closed,
            # even if loading the dataset or encoding a batch raises
            with tables.open_file(file_path, mode="a") as f:
                # one extendable array per embedding layer
                for emb_layer in range(total_emb_layer):
                    f.create_earray(
                        f.root,
                        "layer" + str(emb_layer),
                        tables.Float32Atom(),
                        earray_dim,
                    )

                # get the dataset to be encoded
                ds = ProtranDataset(
                    dataset_path=dataset_path,
                    subset=subset,
                    encoder_name=encoder_name,
                    reset_param=reset_param,
                    resample_param=resample_param,
                    embed_batch_size=embed_batch_size,
                    flatten_emb=flatten_emb,
                    embed_path=None,
                    seq_start_idx=seq_start_idx,
                    seq_end_idx=seq_end_idx,
                    if_encode_all=False,
                    **encoder_params,
                )

                # use the encoder generator for batch emb
                # assume no labels included
                for encoded_batch_dict in self._encoder.encode(
                    mut_seqs=ds.sequence,
                    batch_size=embed_batch_size,
                    flatten_emb=flatten_emb,
                ):

                    # each batch maps a layer key to that layer's embeddings
                    for emb_layer, emb in encoded_batch_dict.items():
                        getattr(f.root, "layer" + str(emb_layer)).append(emb)

In [6]:
# Generate mean-flattened embeddings for the thermo mixed split.
# No extra encoder_params are passed.
GenerateEmbeddings(
    dataset_path="data/proeng/thermo/mixed_split.csv",
    encoder_name="esm1_t6_43M_UR50S",
    reset_param=False,
    resample_param=False,
    embed_batch_size=128,
    flatten_emb="mean",
    seq_start_idx=False,
    seq_end_idx=False,
    embed_folder="embeddings",
)

<class 'scr.encoding.encoding_classes.ESMEncoder'>
Generating esm1_t6_43M_UR50S upto 6 layer embedding ...


Using cache found in /home/t-fli/.cache/torch/hub/facebookresearch_esm_main
Closing remaining open files:embeddings/proeng/gb1/low_vs_high/esm1_t6_43M_UR50S/mean/test/embedding.h5...doneembeddings/proeng/gb1/low_vs_high/esm1_t6_43M_UR50S/mean/train/embedding.h5...doneembeddings/proeng/gb1/low_vs_high/esm1_t6_43M_UR50S/mean/val/embedding.h5...done
  0%|          | 0/175 [00:00<?, ?it/s]

Making embeddings/proeng/thermo ...
Making embeddings/proeng/thermo/mixed_split ...
Making embeddings/proeng/thermo/mixed_split/esm1_t6_43M_UR50S ...
Making embeddings/proeng/thermo/mixed_split/esm1_t6_43M_UR50S/mean ...
Making embeddings/proeng/thermo/mixed_split/esm1_t6_43M_UR50S/mean/train ...


  1%|          | 1/175 [00:37<1:49:15, 37.67s/it]