In [1]:
import time
# Reference timestamp taken before any other work; the last cell subtracts it
# from a second perf_counter() reading to report the notebook's total runtime.
notebook_start_time = time.perf_counter()

# NetOGlyc4 data embedding augmentation with labels

Important notes:
- Make a backup copy of each embeddings file before running this notebook, in case something goes wrong while labels are being appended to it.
- H5 datasets are initialized to 0, so make sure that you either never read uninitialized values or that zeros are harmless for your use case.
- String/object labels with variable lengths will probably not convert correctly; copy them manually instead.
- Labels should have the shape (n_seqs, seq_length, ...); copy any other kind of label manually.

## Imports

### Built-in imports

In [2]:
import math
import gzip
import pickle
from pathlib import Path
import warnings
import re

### Shared library imports

### External imports

In [3]:
import numpy as np
import pandas as pd
import h5py
from tqdm.auto import tqdm

## Paths & Constants

In [4]:
# Root of the data tree. An alternative mount point is kept here for reference:
#BASE_DIR = Path("/mnt/g/My Drive/CloudVault/Masters/Data")
BASE_DIR = Path("/home/jakob/Cloudvault_new/Data")

# Directory that holds both the label files and the embedding files used below.
_EMBEDDING_DIR = BASE_DIR.joinpath('NetOGlyc5 data', 'GalNAc data', '05-embedding')

# Label files whose 'labels' datasets are copied into the embedding files.
# A label whose name already exists in an embedding file is never overwritten;
# give it a new name through LABEL_NAME_MAPPINGS instead.
LABEL_FILES = [
    _EMBEDDING_DIR / label_file_name
    for label_file_name in (
        'netoglyc4_protein_glyc_labels_max.h5',
        'netoglyc4_protein_netsurfp_output.h5',
    )
]

# Embedding files that receive the labels (opened in append mode and modified in place).
EMBEDDING_FILES = [
    _EMBEDDING_DIR / 'netoglyc4_protein_embeddings_netsurfp_output_glyc_labels_max.h5',
]

# Optional renames applied while copying: {name in label file: name in embedding file}.
# Needed when a label name is already taken in an embedding file, e.g.
#    'gly': 'gly_new',
LABEL_NAME_MAPPINGS = {}

# When True, proteins that appear only in a label file are skipped (and reported)
# instead of aborting the run.
IGNORE_NONEMBEDDED_PROTEINS = True

# When True, labels longer than the embedding sequence length are cut down to it.
ALLOW_TRUNCATED_LABELS = True

## Add labels to protein embeddings

In [5]:
# cast_type - numpy dtypes in pytorch tensors:
#     'u1': torch.uint8
#     'i1': torch.int8
#     'i2': torch.int16
#     'i4': torch.int32
#     'i8': torch.int64
#     'f2': torch.float16
#     'f4': torch.float32
#     'f8': torch.float64
#     'F': torch.complex64
#     'D': torch.complex128
#     '?': torch.bool

# Copy every dataset under 'labels' in each LABEL_FILES entry into the 'labels'
# group of each EMBEDDING_FILES entry. Rows are reordered so that row i of a copied
# label belongs to protein i of the embedding file, and the sequence axis is cut to
# the embedding file's longest sequence. For each (embedding file, label file) pair
# all validation below runs before the first dataset is created.
for embedding_file_path in EMBEDDING_FILES:
    with h5py.File(embedding_file_path, 'a') as embedding_file:

        embedding_identifiers_list = embedding_file['identifiers'].asstr()[:].tolist()
        embedding_identifiers_set = set(embedding_identifiers_list)
        embedding_sequences_list = embedding_file['sequences'].asstr()[:].tolist()
        embedding_n_seqs = len(embedding_identifiers_list)
        embedding_max_seq_length = max(len(seq) for seq in embedding_sequences_list)

        for label_file_path in LABEL_FILES:
            with h5py.File(label_file_path, 'r') as label_file:
                print(f"Adding labels to '{embedding_file_path}' from '{label_file_path}'")

                label_identifiers_list = label_file['identifiers'].asstr()[:].tolist()
                label_identifiers_set = set(label_identifiers_list)
                label_sequences_list = label_file['sequences'].asstr()[:].tolist()
                label_n_seqs = len(label_identifiers_list)
                label_max_seq_length = max(len(seq) for seq in label_sequences_list)

                # --- Validation: nothing is written to the embedding file in this section ---

                # Every embedded protein needs labels; surplus label proteins are only
                # tolerated when IGNORE_NONEMBEDDED_PROTEINS is set.
                missing_identifiers = embedding_identifiers_set - label_identifiers_set
                ignored_identifiers = label_identifiers_set - embedding_identifiers_set
                if missing_identifiers:
                    raise ValueError(f"{len(missing_identifiers)} embedding proteins from '{embedding_file_path}' not found in '{label_file_path}': {missing_identifiers}")
                if ignored_identifiers:
                    if not IGNORE_NONEMBEDDED_PROTEINS:
                        raise ValueError(f"{len(ignored_identifiers)} label proteins from '{label_file_path}' not found in '{embedding_file_path}': {ignored_identifiers}")
                    print(f"Ignored {len(ignored_identifiers)} label proteins from '{label_file_path}' not found in '{embedding_file_path}': {ignored_identifiers}")

                # Labels may be longer than the embeddings (they are then truncated, if
                # allowed) but never shorter.
                if embedding_max_seq_length < label_max_seq_length:
                    if ALLOW_TRUNCATED_LABELS:
                        print(f"Truncating labels from '{label_file_path}' to {embedding_max_seq_length} in '{embedding_file_path}'")
                    else:
                        raise ValueError(f"Labels from '{label_file_path}' (max length {label_max_seq_length}) are longer than embeddings in '{embedding_file_path}' (max length {embedding_max_seq_length}) and ALLOW_TRUNCATED_LABELS is False")
                elif label_max_seq_length < embedding_max_seq_length:
                    raise ValueError(f"Labels from '{label_file_path}' (max length {label_max_seq_length}) are shorter than embeddings in '{embedding_file_path}' (max length {embedding_max_seq_length})")

                # Row of the label file that belongs to each embedding row. setdefault
                # keeps the first occurrence of an identifier, which is what list.index
                # returns, but needs one pass instead of one scan per protein.
                label_index_by_identifier = {}
                for label_idx, identifier in enumerate(label_identifiers_list):
                    label_index_by_identifier.setdefault(identifier, label_idx)
                embedding_to_label_indices_mapping = [label_index_by_identifier[identifier] for identifier in embedding_identifiers_list]

                for embedding_idx, label_idx in enumerate(embedding_to_label_indices_mapping):
                    if embedding_sequences_list[embedding_idx] != label_sequences_list[label_idx][:embedding_max_seq_length]:
                        raise ValueError(f"Label sequence of '{embedding_identifiers_list[embedding_idx]}' from '{label_file_path}' did not match Embedding sequence from '{embedding_file_path}'")

                # Re-read the existing label names for each label file: the previous
                # label file may have added some.
                if 'labels' in embedding_file:
                    taken_label_names = set(embedding_file['labels'])
                else:
                    taken_label_names = set()
                label_labels_group = label_file['labels']
                label_label_names = list(label_labels_group)
                new_label_names = set()

                for label in label_label_names:
                    mapped_label = LABEL_NAME_MAPPINGS.get(label, label)
                    label_shape = label_labels_group[label].shape
                    if mapped_label in taken_label_names:
                        raise ValueError(f"Label '{label}' from '{label_file_path}' was already in '{embedding_file_path}' as label '{mapped_label}'")
                    if mapped_label in new_label_names:
                        raise ValueError(f"Label '{label}' from '{label_file_path}' maps to '{mapped_label}', which another label from the same file already maps to")
                    new_label_names.add(mapped_label)
                    # The label dataset must have one row per protein of the *label* file
                    # (rows are selected per embedding protein below) and must cover at
                    # least the embedding sequence length. Comparing against the embedding
                    # dimensions here would reject every file that needs the ignore or
                    # truncate options.
                    if len(label_shape) < 2 or label_shape[0] != label_n_seqs or label_shape[1] < embedding_max_seq_length:
                        raise ValueError(f"Shape of label '{label}' from '{label_file_path}' does not match embeddings from '{embedding_file_path}'")
                    if label_shape[1] > embedding_max_seq_length and not ALLOW_TRUNCATED_LABELS:
                        raise ValueError(f"Label '{label}' from '{label_file_path}' is wider ({label_shape[1]}) than embeddings in '{embedding_file_path}' ({embedding_max_seq_length}) and ALLOW_TRUNCATED_LABELS is False")

                # --- Embedding augmentation with labels ---

                if 'labels' in embedding_file:
                    embedding_labels_group = embedding_file['labels']
                else:
                    embedding_labels_group = embedding_file.create_group('labels')

                # h5py only accepts an index list in increasing order, so the
                # distinct label rows are read in sorted order and the embedding order is
                # restored in memory with the inverse permutation.
                unique_label_indices, restore_order = np.unique(np.asarray(embedding_to_label_indices_mapping, dtype=np.intp), return_inverse=True)
                unique_label_indices = unique_label_indices.tolist()

                for label in label_label_names:
                    mapped_label = LABEL_NAME_MAPPINGS.get(label, label)
                    label_label_dataset = label_labels_group[label]
                    label_shape = label_label_dataset.shape
                    label_dtype = label_label_dataset.dtype
                    # Read before creating the target so a failed read does not leave an
                    # empty (zero-filled) dataset behind.
                    label_rows = label_label_dataset[unique_label_indices, :embedding_max_seq_length][restore_order]
                    embedding_label_dataset = embedding_labels_group.create_dataset(mapped_label, (embedding_n_seqs, embedding_max_seq_length, *label_shape[2:]), dtype=label_dtype, maxshape=(None, None, *label_shape[2:]))
                    for attr_key, attr_value in label_label_dataset.attrs.items():
                        embedding_label_dataset.attrs[attr_key] = attr_value
                    embedding_label_dataset[:] = label_rows


Adding labels to '/home/jakob/Cloudvault_new/Data/NetOGlyc5 data/GalNAc data/05-embedding/netoglyc4_protein_embeddings_netsurfp_output_glyc_labels_max.h5' from '/home/jakob/Cloudvault_new/Data/NetOGlyc5 data/GalNAc data/05-embedding/netoglyc4_protein_glyc_labels_max.h5'
Adding labels to '/home/jakob/Cloudvault_new/Data/NetOGlyc5 data/GalNAc data/05-embedding/netoglyc4_protein_embeddings_netsurfp_output_glyc_labels_max.h5' from '/home/jakob/Cloudvault_new/Data/NetOGlyc5 data/GalNAc data/05-embedding/netoglyc4_protein_netsurfp_output.h5'


In [6]:
# Second perf_counter() reading; the difference to notebook_start_time (set in the
# first cell) is the elapsed time in seconds.
notebook_end_time = time.perf_counter()
print(f"Notebook took {notebook_end_time-notebook_start_time} seconds to run")

Notebook took 2.5119755630148575 seconds to run
