# Dataset Preparation

### Imports

In [6]:
import h5py
import os
from pathlib import Path

import pandas as pd
from tqdm.notebook import tqdm
import pickle
import pandas as pd
from itertools import chain
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import hashlib
import plotly.express as px

from Bio import SeqIO
from Bio.Seq import Seq, back_transcribe
from Bio.SeqRecord import SeqRecord
from io import StringIO
from typing import List
import dnachisel
from dnachisel.biotools import reverse_translate

In [2]:
def parse_3line(path: str, label: str) -> List[SeqRecord]:
    """Parse a 3-line annotation file into nucleotide ``SeqRecord`` objects.

    The file is consumed in groups of three consecutive lines: a FASTA-style
    header line, an amino-acid sequence, and an annotation string. Each
    amino-acid sequence is reverse-translated, and the resulting base
    sequence becomes the sequence of the returned record.

    Args:
        path: Location of the 3-line file; anything ``open`` accepts
            (a ``str`` or a ``pathlib.Path``).
        label: Class label stored on every record as ``annotations["label"]``.

    Returns:
        One record per complete three-line group, in file order. Trailing
        lines that do not fill a whole group are dropped by ``zip``.

    Raises:
        AssertionError: If translating the generated base sequence does not
            give back the (``X``-substituted) amino-acid sequence.
    """
    def record_for(header: str, seq: str, anno: str) -> SeqRecord:
        # 'X' is swapped for '*' before reverse translation -- presumably
        # because reverse_translate has no codon for 'X' (TODO confirm).
        amino_seq = seq.strip().replace('X', '*')
        base_seq = reverse_translate(amino_seq)
        # `header` still carries its trailing newline, so header + base_seq is
        # a two-line FASTA entry that SeqIO can parse. The alternative would
        # be to construct the SeqRecord manually here.
        record = SeqIO.read(StringIO(header + base_seq), "fasta")
        record.annotations["tmh"] = anno.strip()
        record.annotations["label"] = label
        assert record.translate().seq == amino_seq  # sanity check: round trip
        return record

    # The context manager guarantees the handle is closed, even when a record
    # fails to parse. The list is built eagerly, so nothing reads the file
    # after the block exits.
    with open(path) as handle:
        # zip over the same iterator three times yields consecutive triples.
        return [
            record_for(header, seq, anno)
            for header, seq, anno in zip(handle, handle, handle)
        ]

def as_dict(record: SeqRecord, embeddings_file, embedding_dim: int = 1024) -> dict:
    """Flatten one nucleotide record into a row of the dataset table.

    The record is translated back to its amino-acid sequence; the MD5 hex
    digest of that sequence is the key used to look up the embedding.

    Args:
        record: Record carrying ``annotations["tmh"]`` and
            ``annotations["label"]`` (as produced by ``parse_3line``).
        embeddings_file: Object with a ``get(key)`` method returning the
            embedding for a sequence hash, or ``None`` when it is missing
            (an open ``h5py.File`` in this notebook).
        embedding_dim: Length of the all-zero vector used when no embedding is
            found. Defaults to 1024, the value that was previously hard-coded.

    Returns:
        A dict with the keys ``protein``, ``base_seq``, ``seq``, ``seq_anno``,
        ``seq_hash``, ``embedding`` and ``class``.
    """
    seq = str(record.translate().seq)
    # Named seq_hash so the built-in hash() is not shadowed.
    seq_hash = hashlib.md5(seq.encode("UTF-8")).hexdigest()
    embedding = embeddings_file.get(seq_hash)
    if embedding is None:
        # No stored embedding for this sequence: fall back to zeros so every
        # row still has an "embedding" entry.
        mean_embedding = np.zeros(embedding_dim)
    else:
        # Average over the first axis -- assumes it indexes residues, giving
        # one fixed-size vector per protein (TODO confirm).
        mean_embedding = np.mean(embedding, axis=0)
    return {
        "protein": record.name,
        "base_seq": str(record.seq),
        "seq": seq,
        "seq_anno": record.annotations["tmh"],
        "seq_hash": seq_hash,
        "embedding": mean_embedding,
        "class": record.annotations["label"]
    }


In [3]:
# Directory holding embeddings.h5 and labels/<label>.3line. Set the
# TMH_DATA_ROOT environment variable to point somewhere else instead of
# editing the notebook; the previous hard-coded location is the default.
data_root = Path(os.environ.get("TMH_DATA_ROOT", "/Users/fga/data/tmh"))
labels = ["Glob", "Glob_SP", "TM", "TM_SP"]

# The embeddings are only read, so open the file read-only. as_dict reduces
# each embedding to an in-memory array, so the handle can be closed as soon as
# all rows are built. `embeddings` stays bound afterwards but is closed.
with h5py.File(data_root / "embeddings.h5", "r") as embeddings:
    all_dicts = [
        as_dict(record, embeddings)
        for label in tqdm(labels)
        for record in tqdm(parse_3line(data_root / f"labels/{label}.3line", label=label))
    ]
df = pd.DataFrame(all_dicts)

  0%|          | 0/2927 [00:00<?, ?it/s]

  0%|          | 0/1314 [00:00<?, ?it/s]

  0%|          | 0/286 [00:00<?, ?it/s]

  0%|          | 0/627 [00:00<?, ?it/s]

In [4]:
# Preview the first rows to sanity-check the assembled columns.
df.head()

Unnamed: 0,protein,base_seq,seq,seq_anno,seq_hash,embedding,class
0,P38448,ATGAAAAATTGGAAAACTTCTGCTGAACAAATTTTAACTGCTGGTC...,MKNWKTSAEQILTAGPVVPVIVINKLEHAVPMAKALVAGGVRVLEL...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,54e684071346a58b0a3ba296aadbbbb1,"[0.06995, 0.078, 0.014824, 0.01726, -0.007504,...",Glob
1,Q00594,ATGTTAGGTCAAATGATGCGTAATCAATTAGTTATTGGTTCTTTAG...,MLGQMMRNQLVIGSLVEHAARYHGAREVVSVETSGEVTRSCWKEVE...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,1a386478aa9da27641954d39dae2137e,"[0.02861, 0.04712, 0.02577, -0.01444, -0.0017,...",Glob
2,P27888,ATGTTAAATATTAATTTTGTTAATGAAGAATCTTCTACTAATCAAG...,MLNINFVNEESSTNQGLIVFIDEQLKLNNNLIALDQQHYELISKTI...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,1086f33fa3767c01bc5cb3349293292b,"[0.03363, -0.02246, 0.00748, 0.02954, -0.04858...",Glob
3,P0CL03,ATGTTACCTGATAAAGGTTGGTTAGTTGAAGCTCGTCGTGTTCCTT...,MLPDKGWLVEARRVPSPHYDCRPDDEKPSLLVVHNISLPPGEFGGP...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,0fd7f42849b7643b61c81d2e1bab7551,"[0.02228, 0.05377, 0.002089, -0.010506, 0.0328...",Glob
4,P15034,ATGTCTGAAATTTCTCGTCAAGAATTTCAACGTCGTCGTCAAGCTT...,MSEISRQEFQRRRQALVEQMQPGSAALIFAAPEVTRSADSEYPYRQ...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,275512afb21914aefcc9a9a502193acd,"[0.08923, 0.1051, 0.012955, 0.02536, 0.00245, ...",Glob


In [5]:
# Persist the assembled table next to the raw data as processed.pkl.
# NOTE: pickle files can execute code when loaded; only load this file from a trusted location.
df.to_pickle(data_root / "processed.pkl")

In [8]:
# Class balance of the dataset. Passing the frame with x="class" labels the
# axis with the column name; a bare Series would give generic "value" and
# "variable" labels.
px.histogram(df, x="class", title="Number of proteins per class")