# Hyperbolic Embedding of CATH Labels and Protein Sequences

Hierarchy:
[CATH Labels Hierarchy](http://www.cathdb.info/browse/tree)

Embedding Method:
[Poincaré Embeddings for Learning Hierarchical Representations](https://papers.nips.cc/paper/2017/file/59dfa2df42d9e3d41f5b02bfc32229dd-Paper.pdf) & 
[Joint Learning of Hyperbolic Label Embeddings for Hierarchical Multi-label Classification: hierarchy embedding method](https://github.com/geoopt/geoopt/blob/master/examples/hyperbolic_multiclass_classification.ipynb)

Loss function for embedding a known hierarchy:
[Unit Ball Model for Embedding Hierarchical Structures in the Complex Hyperbolic Space](https://arxiv.org/pdf/2105.03966v3.pdf)

Plots and manifold computations:
[Geoopt library and usage example](https://github.com/geoopt/geoopt/blob/master/examples/hyperbolic_multiclass_classification.ipynb)

Protein sequence embedding:
[T5 model for protein sequences](https://huggingface.co/Rostlab/prot_t5_xl_half_uniref50-enc)

In [1]:
import torch; torch.manual_seed(1)
from tqdm import tqdm
from datetime import datetime
import logging

import os, sys
sys.path.append(os.path.join(os.getcwd(), '../..'))

import numpy as np
import matplotlib.pyplot as plt

from transformers import T5Tokenizer, T5EncoderModel

# from src.cath.datasets import LabelDataset
# from src.utils import add_geodesic_grid
# from src.utils import choice

from src.datasets import SequenceLabelDataset, SequenceDataset
from src.cath import CathLabelDataset
from src.losses import SequenceLabelLoss
from src.models import LabelEmbedModel

# Prefer the first CUDA device when one is visible; otherwise fall back to CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using {device}")

# logging.basicConfig(filename='../../logs/embedding.txt', level=logging.DEBUG)

%matplotlib inline

Using cpu


In [2]:
# Tokenizer matching the ProtT5 encoder checkpoint; case is preserved.
tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)

# Encoder-only ProtT5 model, placed on whichever device was selected above
# (GPU when available, otherwise CPU).
seq_encoder = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(device)

In [3]:
# CATH label hierarchy and the training sequences.
# NOTE(review): the meaning of the second CathLabelDataset argument (10) is not
# visible from this notebook — confirm against src.cath.
label_dataset = CathLabelDataset('../../data/cath/labels.csv', 10)
seq_dataset = SequenceDataset('../../data/cath/train.csv', label_dataset.no_labels())

# Batches are unpacked as (seqs, labels, edges) by the training loop.
trainloader = torch.utils.data.DataLoader(
    SequenceLabelDataset(seq_dataset, label_dataset),
    batch_size=512,
    shuffle=True,
    num_workers=1,
    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # CPU-only machine just triggers a warning, so enable it only with CUDA.
    pin_memory=torch.cuda.is_available(),
)

In [4]:
# Label-embedding model with one entry per CATH label. It is created before the
# optimizer so that its parameters can be registered for training, and it is
# placed on the same device as the sequence encoder.
label_model = LabelEmbedModel(label_dataset.no_labels()).to(device)

# Indices of every label; passing them to label_model yields all label embeddings.
all_labels = torch.arange(label_dataset.no_labels(), device=device)

# Train the sequence encoder and the label embeddings jointly. Previously only
# the encoder's parameters were handed to the optimizer, so optimizer.step()
# could never update the label embeddings.
# NOTE(review): if LabelEmbedModel stores manifold (hyperbolic) parameters, a
# Riemannian optimizer may be required instead of plain Adam — confirm.
optimizer = torch.optim.Adam(
    list(seq_encoder.parameters()) + list(label_model.parameters()),
    lr=0.001,
)
loss_fn = SequenceLabelLoss(_lambda=0.1)

# timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
# logger = torch.utils.tensorboard.SummaryWriter('{}'.format(timestamp))

In [7]:
def train_epoch():
    """Run one pass over ``trainloader`` and return the mean batch loss.

    Relies on the notebook-level ``trainloader``, ``tokenizer``, ``seq_encoder``,
    ``label_model``, ``all_labels``, ``optimizer`` and ``device``.

    NOTE: the loss / backward / optimizer step is still commented out and the
    loop stops after the first batch, so at the moment this only exercises the
    tokenize -> encode -> pool pipeline and the returned value is 0.

    Returns:
        The accumulated loss divided by the number of batches in
        ``trainloader`` (0 when the loader is empty).
    """
    total_loss = 0

    # Wrap the loader itself (not an enumerate object) so tqdm can read its
    # length and display real progress instead of "0it".
    for batch in tqdm(trainloader):
        seqs, labels, edges = batch
        # seqs, labels, edges = seqs.cuda(), labels.cuda(), edges.cuda()

        optimizer.zero_grad()

        ids = tokenizer.batch_encode_plus(seqs, add_special_tokens=True, padding="longest")

        # The inputs must live on the same device as the encoder, otherwise the
        # forward pass fails with a device mismatch when running on GPU.
        input_ids = torch.tensor(ids['input_ids'], device=device)
        attention_mask = torch.tensor(ids['attention_mask'], device=device)

        print('tokenized')

        seq_emb_repr = seq_encoder(input_ids=input_ids, attention_mask=attention_mask)

        print('encoded')

        label_embs = label_model(all_labels)

        print('labels embeds')

        seq_embs = seq_emb_repr.last_hidden_state

        print(seq_embs.shape, label_embs.shape)

        # Mean-pool the leading ``len(seq) // 2 + 1`` token embeddings of each
        # sequence. NOTE(review): this length formula presumably assumes
        # space-separated residues ("A B C" -> 3 tokens), which would drop
        # padding and the trailing special token — confirm against the dataset.
        # torch.stack keeps the result on the encoder's device and dtype.
        seq_embs_mean = torch.stack([
            seq_embs[i, :len(seq) // 2 + 1].mean(dim=0)
            for i, seq in enumerate(seqs)
        ])

        # TODO: enable the training step once the loss inputs are finalized.
        # dot = seq_emb @ label_emb.T
        # loss = loss_fn(dot, labels, label_model(edges))

        # print('batch loss:', loss.item())

        # loss.backward()

        # print('loss done')

        # optimizer.step()

        # print('optimizer done')

        # total_loss += loss.item()

        # Only the first batch is processed while the pipeline is being debugged.
        break

    # Guard against an empty loader so the average does not divide by zero.
    return total_loss / max(len(trainloader), 1)

In [8]:
EPOCHS = 1

for epoch in range(EPOCHS):
    # Both networks are switched to training mode before every pass.
    # logging.info(f"Epoch {epoch+1}/{EPOCHS}")
    label_model.train()
    seq_encoder.train()

    # Value returned by train_epoch() for this pass (its mean batch loss).
    loss = train_epoch()

    # Evaluation mode and per-epoch logging are currently disabled.
    # label_model.eval()
    # seq_encoder.eval()

    # logging.info(f"\ttrain Loss {loss:.6f}")

0it [00:00, ?it/s]

tokenized


0it [00:19, ?it/s]


KeyboardInterrupt: 