In [None]:
!pip3 install transformers
!pip3 install pyfaidx

import json
from pathlib import Path
from typing import Dict, Tuple, List

import h5py
import torch
import tqdm
import tqdm.notebook
from google.colab import files
from pyfaidx import Fasta
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel

Collecting transformers
  Downloading transformers-4.19.1-py3-none-any.whl (4.2 MB)
[K     |████████████████████████████████| 4.2 MB 10.4 MB/s 
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 53.3 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.6.0-py3-none-any.whl (84 kB)
[K     |████████████████████████████████| 84 kB 3.9 MB/s 
[?25hCollecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 40.9 MB/s 
Installing collected packages: pyyaml, tokenizers, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Found existing installation: PyYAML 3.13
    Uninstalling PyYAML-3.13:
      Successfully uninstalled PyYAML-3.13
Successfully installed huggingface

In [None]:
class TokenSeqDataset(Dataset):
    """Dataset yielding one pooled ProtBert embedding per FASTA record.

    On construction the FASTA file is read fully into memory and the
    'Rostlab/prot_bert_bfd' model and tokenizer are loaded (the model is moved
    to ``device``). Indexing the dataset tokenizes the selected sequence, runs
    it through the model and returns ``(record_id, embedding)``.
    """

    def __init__(self, fasta_path: Path, labels_dict: Dict[str, int], max_num_residues: int, protbert_cache: Path, device: torch.device):
        """Read the sequences and load model + tokenizer.

        Args:
            fasta_path: FASTA file with the sequences to embed.
            labels_dict: Mapping from record id to label; stored on the
                instance as ``self.labels`` but not needed to build embeddings.
            max_num_residues: Sequences are cut to this many residues before
                tokenization.
            protbert_cache: Directory passed as ``cache_dir`` when loading the
                pretrained weights and vocabulary.
            device: Device the model and the tokenized inputs are moved to.
        """
        self.device = device
        self.max_num_residues = max_num_residues
        self.labels = labels_dict
        # parse_fasta_input() also fills self.keys (record ids in file order).
        self.data = self.parse_fasta_input(fasta_path)
        self.protbert_cache = protbert_cache
        self.protbert = self.load_model()
        self.tokenizer = self.load_tokenizer()

    def load_tokenizer(self) -> "BertTokenizer":
        """Load the 'Rostlab/prot_bert_bfd' tokenizer (case preserved)."""
        return BertTokenizer.from_pretrained('Rostlab/prot_bert_bfd', do_lower_case=False, cache_dir=self.protbert_cache)

    def load_model(self) -> "BertModel":
        """Load the 'Rostlab/prot_bert_bfd' encoder and move it to ``self.device``."""
        model = BertModel.from_pretrained('Rostlab/prot_bert_bfd', cache_dir=self.protbert_cache)
        return model.to(self.device)

    def __getitem__(self, index: int) -> Tuple[str, torch.Tensor]:
        """Return ``(record_id, embedding)`` for the ``index``-th FASTA record.

        Raises:
            IndexError: If ``index`` is past the last record (this is also what
                ends a plain ``for item in dataset`` loop).
        """
        key = self.keys[index]
        seq = self.data[key][:self.max_num_residues]

        tokens, attention_mask = self.tokenize(seq)
        tokens = tokens.to(self.device)
        # Move the *mask* to the device; the model must not receive the token
        # ids in place of the attention mask.
        attention_mask = attention_mask.to(self.device)

        return (key, self.embedd(tokens, attention_mask))

    def __len__(self) -> int:
        """Number of sequences read from the FASTA file."""
        return len(self.data)

    def tokenize(self, seq: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize one sequence, padded to ``max_num_residues + 2`` tokens.

        Residues are separated by single spaces before tokenization.

        NOTE(review): the Rostlab model card maps the rare residues U/Z/O/B to
        X before tokenizing; this is not done here - confirm whether the input
        data needs it.

        Returns:
            ``(input_ids, attention_mask)`` as returned by the tokenizer with
            ``return_tensors='pt'`` for a batch holding this single sequence.
        """
        spaced = [" ".join(seq)]
        tokenized = self.tokenizer(
            text=spaced,
            padding='max_length',
            # Guarantees the fixed length even when called with a sequence
            # longer than max_num_residues.
            truncation=True,
            max_length=self.max_num_residues + 2,  # +2 leaves room for the special tokens
            add_special_tokens=True,
            return_tensors='pt',
        )

        return tokenized['input_ids'], tokenized['attention_mask']

    def embedd(self, tokens: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Run the model and mean-pool its per-token output over real tokens.

        Positions whose ``attention_mask`` entry is 0 (padding) are excluded
        from the average, so the result does not depend on how much padding
        the sequence received.

        Args:
            tokens: Token ids, one row per sequence.
            attention_mask: Same shape as ``tokens``; 1 for real tokens, 0 for
                padding.

        Returns:
            The pooled embedding with all size-1 dimensions squeezed away
            (a 1-D vector when a single sequence is passed).
        """
        with torch.no_grad():
            hidden = self.protbert(input_ids=tokens, attention_mask=attention_mask)[0]

        weights = attention_mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * weights).sum(dim=1)
        # clamp guards against division by zero for an all-zero mask.
        counts = weights.sum(dim=1).clamp(min=1)

        return torch.squeeze(summed / counts)

    def parse_fasta_input(self, input_file: Path) -> Dict[str, str]:
        """Read ``input_file`` into a ``{record_id: sequence}`` dict.

        Side effect: sets ``self.keys`` to the record ids in file order, which
        ``__getitem__`` uses to map an integer index to a record.
        """
        fasta = Fasta(str(input_file))
        self.keys = list(fasta.keys())
        return {key: str(fasta[key]) for key in self.keys}


In [None]:
# Input locations; these paths assume Google Drive is already mounted at
# /content/drive - adjust them to point at your own copy of the data.
DATA_DIR = Path("/content/drive/MyDrive/data_exe2")
PROTBERT_CACHE = Path("/content/drive/MyDrive/protbert_weights")
MAX_NUM_RESIDUES = 1024  # sequences are cut to this many residues before embedding

# Prefer the GPU, but fall back to the CPU so the cell still runs without an accelerator.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

with open(DATA_DIR / "train_lbl.json") as f:
    labels = json.load(f)

dataset = TokenSeqDataset(
    fasta_path=DATA_DIR / "train_seqs.fasta",
    labels_dict=labels,
    max_num_residues=MAX_NUM_RESIDUES,
    protbert_cache=PROTBERT_CACHE,
    device=device,
)


Some weights of the model checkpoint at Rostlab/prot_bert_bfd were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
EMBEDDINGS_FILE = 'embeddings.h5'

# Write one HDF5 dataset per sequence, named by the record id the dataset yields.
with h5py.File(EMBEDDINGS_FILE, 'w') as hf:
    for seq_id, embedding in tqdm.notebook.tqdm(dataset, leave=True, ascii=True):
        # h5py stores array data; hand it a NumPy array on the host explicitly.
        hf.create_dataset(seq_id, data=embedding.detach().cpu().numpy())

# Trigger a browser download of the finished file (Colab only).
files.download(EMBEDDINGS_FILE)

  0%|          | 0/18650 [00:00<?, ?it/s]

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>