# Hyperbolic Embedding of CATH Labels and Protein Sequences

Hierarchy:
[CATH Labels Hierarchy](http://www.cathdb.info/browse/tree)

Embedding Method:
[Poincaré Embeddings for Learning Hierarchical Representations](https://papers.nips.cc/paper/2017/file/59dfa2df42d9e3d41f5b02bfc32229dd-Paper.pdf) & 
[Joint Learning of Hyperbolic Label Embeddings for Hierarchical Multi-label Classification: hierarchy embedding method](https://github.com/geoopt/geoopt/blob/master/examples/hyperbolic_multiclass_classification.ipynb)

Loss function for embedding a known hierarchy:
[Unit Ball Model for Embedding Hierarchical Structures in the Complex Hyperbolic Space](https://arxiv.org/pdf/2105.03966v3.pdf)

Plots and manifold computations:
[Geoopt library and usage example](https://github.com/geoopt/geoopt/blob/master/examples/hyperbolic_multiclass_classification.ipynb)

Protein sequence embedding:
[T5 model for protein sequences](https://huggingface.co/Rostlab/prot_t5_xl_half_uniref50-enc)

## Setup

In [1]:
import torch
# Seed PyTorch's random number generators (the CUDA one explicitly as well)
# with a fixed value -- presumably so repeated runs of the notebook match.
torch.manual_seed(1)
torch.cuda.manual_seed(1)

from tqdm import tqdm
from datetime import datetime
import logging

import os, sys
sys.path.append(os.path.join(os.getcwd(), '../..'))

import numpy as np
import matplotlib.pyplot as plt

from transformers import T5Tokenizer, T5EncoderModel

# from src.cath.datasets import LabelDataset
# from src.utils import add_geodesic_grid
# from src.utils import choice

from src.datasets import SequenceLabelDataset, SequenceDataset
from src.cath import CathLabelDataset
from src.losses import SequenceLabelLoss
from src.models import LabelEmbedModel

# Use the first CUDA device when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print("Using {}".format(device))

# logging.basicConfig(filename='../../logs/embedding.txt', level=logging.DEBUG)

# timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
# logger = torch.utils.tensorboard.SummaryWriter('{}'.format(timestamp))

%matplotlib inline

  from .autonotebook import tqdm as notebook_tqdm


Using cuda:0


## Data

In [30]:
# Label hierarchy and training sequences, both read from CSV files.
# NOTE(review): the meaning of the second CathLabelDataset argument (10) is
# not visible here -- confirm against src.cath.
label_dataset = CathLabelDataset('../../data/cath/labels.csv', 10)
seq_dataset = SequenceDataset('../../data/cath/train.csv', label_dataset.no_labels())

# Combine both datasets and serve them one shuffled sample at a time.
train_pairs = SequenceLabelDataset(seq_dataset, label_dataset)
trainloader = torch.utils.data.DataLoader(
    train_pairs, batch_size=1, shuffle=True, num_workers=8, pin_memory=True
)

## Models

In [27]:
# Encoder and tokenizer are loaded from the same pretrained checkpoint.
PROT_T5_CHECKPOINT = "Rostlab/prot_t5_xl_half_uniref50-enc"

seq_encoder = T5EncoderModel.from_pretrained(PROT_T5_CHECKPOINT).to(device)
tokenizer = T5Tokenizer.from_pretrained(PROT_T5_CHECKPOINT, do_lower_case=False)

In [28]:
# Label embedding model, sized by label_dataset.no_labels(), wrapped in
# DataParallel and moved to the selected device.
# NOTE(review): the effect of eye=True is defined in src.models -- not visible here.
label_model = LabelEmbedModel(label_dataset.no_labels(), eye=True) #.to(device)
label_model = torch.nn.DataParallel(label_model).to(device) # TODO: Should i do this?

# Indices 0 .. no_labels() - 1 on the device; train_epoch feeds them to
# label_model to obtain the embedding of every label at once.
all_labels = torch.arange(label_dataset.no_labels()).to(device)

## Utils

In [5]:
def seq_embd_means_no_inplace_operation(seqs, seq_embds):
    """Mean-pool token embeddings over each sequence's kept positions.

    A 0/1 mask selects the kept positions, so no tensor that takes part in
    autograd is written in place.

    Args:
        seqs: One string per batch row. Row ``i`` keeps token positions
            ``0 .. len(seqs[i]) // 2 + 1`` (inclusive).
            NOTE(review): presumably residues are space-separated, so
            ``len(seq) // 2 + 1`` is the residue count -- confirm against
            the dataset. ``seq_embd_means_inplace_operation`` slices
            ``[:len(seq) // 2 + 1]`` and therefore keeps one position fewer.
        seq_embds: Tensor indexed as ``(batch, token, hidden)``; any hidden
            size is accepted.

    Returns:
        Tensor of shape ``(batch, hidden)`` on the same device as
        ``seq_embds``. Each row is the sum of the kept token embeddings
        divided by the number of kept tokens, so padding positions do not
        dilute the average.
    """
    batch_size, n_tokens = seq_embds.shape[0], seq_embds.shape[1]
    target = seq_embds.device

    positions = torch.arange(n_tokens, device=target)
    last_kept = torch.tensor(
        [len(seq) // 2 + 1 for seq in seqs[:batch_size]],
        device=target,
    )

    # (batch, token) mask in the default float dtype, then a trailing axis so
    # it broadcasts over the hidden dimension whatever its size.
    mask = positions.unsqueeze(0) <= last_kept.unsqueeze(1)
    mask = mask.to(torch.get_default_dtype()).unsqueeze(2)

    # Divide by the number of kept tokens, not by the padded length; the clamp
    # only guards against a 0/0 when there are no token positions at all.
    kept = mask.sum(dim=1).clamp(min=1)
    return (seq_embds * mask).sum(dim=1) / kept

In [6]:
def seq_embd_means_inplace_operation(seqs, seq_embds):
    """Mean-pool token embeddings over each sequence's leading positions.

    Args:
        seqs: One string per batch row. Row ``i`` averages the first
            ``len(seqs[i]) // 2 + 1`` token positions.
            NOTE(review): presumably residues are space-separated, so this
            is the residue count -- confirm against the dataset.
        seq_embds: Tensor indexed as ``(batch, token, hidden)``.

    Returns:
        Tensor of shape ``(len(seqs), hidden)`` in the default float dtype,
        on the same device as ``seq_embds``.
    """
    hidden = seq_embds.shape[-1]
    out_dtype = torch.get_default_dtype()

    # Build the rows directly on the input's device and stack them, rather
    # than copying each row into a preallocated CPU buffer and moving it back.
    row_means = [
        seq_embds[i, :len(seq) // 2 + 1].mean(dim=0)
        for i, seq in enumerate(seqs)
    ]
    if not row_means:
        return torch.empty((0, hidden), dtype=out_dtype, device=seq_embds.device)

    return torch.stack(row_means).to(out_dtype)

## Train

In [31]:
loss_fn = SequenceLabelLoss(_lambda=0.1)

# One Adam parameter group per model, each with the same learning rate.
param_groups = [
    {'params': model.parameters(), 'lr': 0.001}
    for model in (label_model, seq_encoder)
]
optimizer = torch.optim.Adam(param_groups)

In [36]:
def train_epoch():
    """Run one pass over ``trainloader`` and return the mean batch loss.

    Uses the notebook-level ``trainloader``, ``tokenizer``, ``seq_encoder``,
    ``label_model``, ``all_labels``, ``loss_fn``, ``optimizer`` and
    ``device``.

    Returns:
        float: sum of the per-batch loss values divided by
        ``len(trainloader)``.
    """
    total_loss = 0
    n_batches = len(trainloader)

    # Passing total lets the progress bar show completion, which it cannot
    # infer from an enumerate object.
    pbar = tqdm(enumerate(trainloader, 1), total=n_batches)
    for idx, batch in pbar:
        seqs, labels, edges = batch
        labels, edges = labels.to(device), edges.to(device)

        # TODO: can i tokenize on cuda?
        ids = tokenizer.batch_encode_plus(seqs, add_special_tokens=True, padding="longest")

        # Token ids and the attention mask come from plain Python lists and
        # never track gradients, so no no_grad guard is needed here.
        input_ids = torch.tensor(ids['input_ids']).to(device)
        attention_mask = torch.tensor(ids['attention_mask']).to(device)

        optimizer.zero_grad()

        label_embds = label_model(all_labels)

        seq_embd_repr = seq_encoder(input_ids=input_ids, attention_mask=attention_mask)
        seq_embds = seq_embd_repr.last_hidden_state

        # Pool with autograd enabled. Pooling under torch.no_grad() would
        # detach the result from the encoder, so the encoder parameters
        # registered in the optimizer would never receive a gradient.
        # TODO: which one to use?
        seq_embds = seq_embd_means_no_inplace_operation(seqs, seq_embds)
        # seq_embds =  seq_embd_means_inplace_operation(seqs, seq_embds)

        dot = seq_embds @ label_embds.T
        loss = loss_fn(dot, labels, label_model(edges))

        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss

        pbar.set_description(f"Batch {idx} / {n_batches}, loss {batch_loss}")

    return total_loss / n_batches

In [None]:
EPOCHS = 1

# Iterate over the tqdm wrapper itself so the bar advances once per epoch.
pbar = tqdm(range(EPOCHS))
for epoch in pbar:

    # logging.info(f"Epoch {epoch+1}/{EPOCHS}")
    label_model.train(True)
    seq_encoder.train(True)

    loss = train_epoch()

    # train_epoch() returns a plain Python float (it has no .item()), so it
    # is formatted directly.
    pbar.set_description(f"Epoch {epoch}, loss {loss}")

    # label_model.eval()
    # seq_encoder.eval()

    # logging.info(f"\ttrain Loss {loss:.6f}")

## Playground

In [10]:
# Toy mask: flag positions below 1 as float64, shape it (1, 3, 1), then copy
# the flag across two columns -> shape (1, 3, 2).
a = (torch.arange(3) < 1).to(float).view(1, 3, 1).repeat(1, 1, 2)
a

tensor([[[1., 1.],
         [0., 0.],
         [0., 0.]]], dtype=torch.float64)