# Hyperbolic Embedding of CATH Labels and Protein Sequences

Hierarchy:
[CATH Labels Hierarchy](http://www.cathdb.info/browse/tree)

Embedding Method:
[Poincaré Embeddings for Learning Hierarchical Representations](https://papers.nips.cc/paper/2017/file/59dfa2df42d9e3d41f5b02bfc32229dd-Paper.pdf) & 
[Joint Learning of Hyperbolic Label Embeddings for Hierarchical Multi-label Classification: hierarchy embedding method](https://github.com/geoopt/geoopt/blob/master/examples/hyperbolic_multiclass_classification.ipynb)

Loss function for embedding a known hierarchy:
[Unit Ball Model for Embedding Hierarchical Structures in the Complex Hyperbolic Space](https://arxiv.org/pdf/2105.03966v3.pdf)

Plots and manifold computations:
[Geoopt library and usage example](https://github.com/geoopt/geoopt/blob/master/examples/hyperbolic_multiclass_classification.ipynb)

Protein sequence embedding:
[T5 model for protein sequences](https://huggingface.co/Rostlab/prot_t5_xl_half_uniref50-enc)
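
For orientation on the embedding method referenced above: the Poincaré embeddings paper places items inside the open unit ball and compares them with the hyperbolic distance arcosh(1 + 2 * |u - v|^2 / ((1 - |u|^2) * (1 - |v|^2))). The block below is a minimal standard-library sketch of that formula only. It is not the code used in this notebook, which relies on Geoopt for manifold computations, and the name `poincare_distance` is hypothetical.

```python
import math

def poincare_distance(u, v):
    """Hyperbolic distance between two points inside the open unit ball."""
    sq_norm_u = sum(x * x for x in u)
    sq_norm_v = sum(x * x for x in v)
    if sq_norm_u >= 1.0 or sq_norm_v >= 1.0:
        raise ValueError("both points must lie strictly inside the unit ball")
    sq_dist = sum((a - b) ** 2 for a, b in zip(u, v))
    return math.acosh(1.0 + 2.0 * sq_dist / ((1.0 - sq_norm_u) * (1.0 - sq_norm_v)))

origin = [0.0, 0.0]
near_center = [0.5, 0.0]
near_boundary = [0.95, 0.0]

# A shorter Euclidean step near the boundary covers more hyperbolic distance
# than a longer step near the center.
print(poincare_distance(origin, near_center))
print(poincare_distance(near_center, near_boundary))
```

Distances grow quickly toward the boundary of the ball, which is why deep levels of a hierarchy such as CATH can be spread out there while the root stays near the center.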

## Setup

In [1]:
import torch
torch.manual_seed(1)
torch.cuda.manual_seed(1)

from tqdm import tqdm
from datetime import datetime
import logging

import os, sys
sys.path.append(os.path.join(os.getcwd(), '../..'))

import numpy as np
import matplotlib.pyplot as plt

from transformers import T5Tokenizer, T5EncoderModel

# from src.cath.datasets import LabelDataset
# from src.utils import add_geodesic_grid
# from src.utils import choice

from src.datasets import SequenceLabelDataset, SequenceDataset
from src.cath import CathLabelDataset
from src.losses import SequenceLabelLoss, SequenceLoss
from src.models import LabelEmbedModel

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("Using {}".format(device))

# logging.basicConfig(filename='../../logs/embedding.txt', level=logging.DEBUG)

# logger = torch.utils.tensorboard.SummaryWriter('{}'.format(timestamp))

# timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
# logger = torch.utils.tensorboard.SummaryWriter('logs/embedding_{}'.format(timestamp))

%matplotlib inline

Using cuda:0


## Data

In [2]:
label_dataset = CathLabelDataset('../../data/cath/labels.csv', 10)
seq_train_dataset = SequenceDataset('../../data/cath/train.csv', label_dataset.no_labels())

trainloader = torch.utils.data.DataLoader(
    SequenceLabelDataset(seq_train_dataset, label_dataset),
    batch_size=1,
    shuffle=True,
    num_workers=8,
    pin_memory=True
)

In [3]:
seq_val_dataset = SequenceDataset('../../data/cath/val.csv', label_dataset.no_labels())

valloader = torch.utils.data.DataLoader(
    seq_val_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=8,
    pin_memory=True
)

## Models

In [4]:
seq_encoder = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
seq_encoder = seq_encoder.to(device)

tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)

In [5]:
label_model = LabelEmbedModel(label_dataset.no_labels(), eye=True) #.to(device)
label_model = torch.nn.DataParallel(label_model).to(device) # TODO: Should i do this?

all_labels = torch.arange(label_dataset.no_labels()).to(device)

## Utils

In [6]:
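def seq_embd_means_no_inplace_operation(seqs, seq_embds):
    # Build the mask directly on the embeddings' device instead of copying it over.
    positions = torch.arange(seq_embds.shape[1], device=seq_embds.device)
    lengths = torch.tensor([len(seq) // 2 + 1 for seq in seqs], device=seq_embds.device)

    # 1 for positions 0..len(seq) // 2 + 1 of each sequence, 0 for the padding after them.
    mask = (positions.unsqueeze(0) <= lengths.unsqueeze(1)).to(seq_embds.dtype).unsqueeze(2)

    # Divide by the number of kept positions rather than the padded length,
    # so that padding does not shrink the mean of shorter sequences.
    return (seq_embds * mask).sum(dim=1) / mask.sum(dim=1)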

In [7]:
def seq_embd_means_inplace_operation(seqs, seq_embds):
    seq_embd_means = torch.empty((len(seq_embds), seq_embds.shape[-1]))

    for i in range(len(seqs)):
        seq_embd_means[i] = seq_embds[i, :len(seqs[i]) // 2 + 1].mean(axis=0)
        
    return seq_embd_means.to(device)
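
When sequences of different lengths are padded to a common length, the denominator of the mean matters: zeroing out the padding is not enough if the sum is still divided by the padded length. The toy sketch below illustrates this with plain Python lists. It assumes sequences are space-separated residues (the format the ProtT5 tokenizer expects), so that `len(seq) // 2 + 1` gives the residue count; the helper names `residue_count`, `masked_mean` and `padded_mean` are hypothetical and not part of this notebook.

```python
def residue_count(spaced_seq):
    """Number of residues in a space-separated sequence such as "M K T"."""
    return len(spaced_seq) // 2 + 1

def masked_mean(token_vectors, n_tokens):
    """Average the first n_tokens vectors, ignoring the padding after them."""
    dim = len(token_vectors[0])
    kept = token_vectors[:n_tokens]
    return [sum(vec[d] for vec in kept) / n_tokens for d in range(dim)]

def padded_mean(token_vectors, n_tokens):
    """Drop the padding from the sum but divide by the padded length (the pitfall)."""
    dim = len(token_vectors[0])
    kept = token_vectors[:n_tokens]
    return [sum(vec[d] for vec in kept) / len(token_vectors) for d in range(dim)]

seq = "M K T"
# Three residue vectors followed by two padding positions.
tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [9.0, 9.0], [9.0, 9.0]]

n = residue_count(seq)
print(n)                       # residues in the sequence
print(masked_mean(tokens, n))  # mean over real residues only
print(padded_mean(tokens, n))  # biased toward zero by the padding
```

With `batch_size=1`, as used in this notebook, no padding is added and both denominators coincide; the difference only appears with larger batches.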

## Train

In [8]:
train_loss = SequenceLabelLoss(_lambda=0.1)
eval_loss = SequenceLoss()

optimizer = torch.optim.Adam([
        {'params': label_model.parameters(), 'lr': 0.001},
        {'params': seq_encoder.parameters(), 'lr': 0.001}
    ])

In [9]:
def compute_logits(seqs):
    # Tokenization runs on the CPU; only the resulting id tensors are moved to the device.
    ids = tokenizer.batch_encode_plus(seqs, add_special_tokens=True, padding="longest")

    # Integer id tensors carry no gradients, so no torch.no_grad() is needed here.
    input_ids = torch.tensor(ids['input_ids']).to(device)
    attention_mask = torch.tensor(ids['attention_mask']).to(device)

    label_embds = label_model(all_labels)

    seq_embd_repr = seq_encoder(input_ids=input_ids, attention_mask=attention_mask)
    seq_embds = seq_embd_repr.last_hidden_state

    # NOTE: pooling under no_grad cuts the autograd graph, so no gradient
    # reaches seq_encoder; only label_model is updated by the optimizer.
    with torch.no_grad():
        # TODO: which one to use?
        seq_embds = seq_embd_means_no_inplace_operation(seqs, seq_embds)
        # seq_embds = seq_embd_means_inplace_operation(seqs, seq_embds)

    return seq_embds @ label_embds.T
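
`compute_logits` ends with `seq_embds @ label_embds.T`, so each logit is the dot product between one pooled sequence embedding and one label embedding. The toy sketch below reproduces that scoring step with plain Python lists and made-up 3-dimensional vectors (the notebook's embeddings have 1024 dimensions); `dot` and `compute_toy_logits` are hypothetical names used for illustration only.

```python
def dot(a, b):
    """Dot product of two equally long vectors."""
    return sum(x * y for x, y in zip(a, b))

def compute_toy_logits(seq_embds, label_embds):
    """Return a (num_sequences x num_labels) table of dot-product scores."""
    return [[dot(seq, label) for label in label_embds] for seq in seq_embds]

# Two pooled sequence embeddings and three label embeddings (made-up values).
seq_embds = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]]
label_embds = [[1.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.0, 0.0, 1.0]]

logits = compute_toy_logits(seq_embds, label_embds)
print(logits)
```

Row `i` of the result holds the scores of sequence `i` against every label, which is the shape the loss functions receive.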

In [10]:
def evaluate():
    running_loss = 0

    label_model.train(False)
    seq_encoder.train(False)

    # No gradients are needed for validation, so skip building the autograd graph.
    with torch.no_grad():
        for seqs, labels in valloader:
            labels = labels.to(device)

            logits = compute_logits(seqs)
            loss = eval_loss(logits, labels)

            running_loss += loss.item()

    return running_loss / len(valloader)

In [14]:
def train():
    running_loss = 0
    last_loss = 0
    
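    # enumerate() hides the loader's length, so pass it explicitly for a proper progress bar.
    pbar = tqdm(enumerate(trainloader), total=len(trainloader))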
    
    label_model.train(True)
    seq_encoder.train(True)
        
    for idx, batch in pbar:
        seqs, labels, edges = batch
        labels, edges = labels.to(device), edges.to(device)
        
        optimizer.zero_grad()
        
        logits = compute_logits(seqs)
        loss =  train_loss(logits, labels, label_model(edges))
        
        running_loss += loss.item()
        loss.backward()
        optimizer.step()
        
        # REPORT
        
        if idx % 500 == 499:
            print(f'validation loss: {evaluate()}')
            label_model.train(True)
            seq_encoder.train(True)

        if idx % 50 == 49:
            last_loss = running_loss / 50
            pbar.set_description(f"Batch {idx} / {len(trainloader)}, loss {loss.item()}")
            running_loss = 0.

    return last_loss

In [None]:
EPOCHS = 1

pbar = tqdm(range(EPOCHS))
for epoch in pbar:
    loss = train()

    # train() returns a plain float (the last running average), so no .item() is needed.
    pbar.set_description(f"Epoch {epoch}, loss {loss}")


  0%|          | 0/1 [00:11<?, ?it/s]
Batch 499 / 16292, loss 0.2808671318900209: : 500it [00:44,  1.63s/it] 

validation loss: 0.047186862360913795


Batch 999 / 16292, loss 0.23617966498120938: : 1004it [01:27,  1.12s/it]

validation loss: 0.03915620674522501


Batch 1499 / 16292, loss 0.33049736534960683: : 1502it [02:10,  1.17s/it]

validation loss: 0.033846691709725804


Batch 1999 / 16292, loss 0.2844398075113364: : 2001it [02:52,  1.41s/it] 

validation loss: 0.02978396832131858


Batch 2499 / 16292, loss 0.26878058582655956: : 2501it [03:36,  1.84s/it]

validation loss: 0.026668220270902668


Batch 2999 / 16292, loss 0.2883907044419228: : 3004it [04:22,  1.12s/it] 

validation loss: 0.0240701770636582


Batch 3499 / 16292, loss 0.16909000094042623: : 3502it [05:05,  1.27s/it]

validation loss: 0.021944884514876705


Batch 3999 / 16292, loss 0.32136328722181384: : 4003it [05:48,  1.12s/it]

validation loss: 0.02015962071898317


Batch 4499 / 16292, loss 0.22879372132330397: : 4502it [06:31,  1.22s/it]

validation loss: 0.018648089219694925


Batch 4999 / 16292, loss 0.16397730413033265: : 5003it [07:14,  1.08s/it]

validation loss: 0.017303401942273672


Batch 5499 / 16292, loss 0.25279350309019466: : 5504it [07:57,  1.04s/it]

validation loss: 0.016127039247421617


Batch 5999 / 16292, loss 0.16687490567297456: : 6002it [08:41,  1.22s/it]

validation loss: 0.015100147737112816


Batch 6499 / 16292, loss 0.2405601055158807: : 6501it [09:24,  1.52s/it] 

validation loss: 0.014162905533842071


Batch 6999 / 16292, loss 0.17428385072900263: : 7002it [10:06,  1.14s/it]

validation loss: 0.013303213168855976


Batch 7249 / 16292, loss 0.17409555257177567: : 7261it [10:22, 14.77it/s]

## Playground