In [1]:
# Automatically reload imported project modules (e.g. `src.model`, `src.utils`) before each cell
# executes, so edits to those helpers are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import gc
import os

import pandas as pd
import torch
from torch import mps
from transformers import (
    T5EncoderModel,
    )
import umap
import plotly.express as px
from tqdm import tqdm

import src.config as config
from src.model import (
    get_prottrans_tokenizer_model,
    df_to_dataset,
    inject_linear_layer,
    compute_metrics_fast
    )
from src.utils import get_project_root_path

  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'


In [3]:
# Report the run configuration: base checkpoint, MPS availability, project root, and the compute
# device (CUDA preferred, then Apple MPS, otherwise CPU).
base_model_name = config.base_model_name
print("Base Model:\t", base_model_name)
print("MPS:\t\t", torch.backends.mps.is_available())

ROOT = get_project_root_path()
print("Path:\t\t", ROOT)

if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device:\t {device}")

Base Model:	 Rostlab/prot_t5_xl_uniref50
MPS:		 True
Path:		 /Users/finnlueth/Developer/gits/prottrans-t5-signalpeptide-prediction
Using device:	 mps


In [4]:
# Load the processed SignalP 5.0 table; later cells select the test partition via the `Split` column.
df_data = pd.read_parquet(f"{ROOT}/data/processed/5.0_train_full.parquet.gzip")

In [6]:
# df_data[df_data.Split.isin([4])]

In [None]:
# Load the ProtT5 tokenizer together with its encoder-only stack (T5EncoderModel); only the encoder's
# hidden states are consumed downstream.
base_model_name, model_architecture = config.base_model_name, T5EncoderModel
t5_tokenizer, t5_base_model = get_prottrans_tokenizer_model(
    base_model_name,
    model_architecture,
)

In [None]:
# Tokenize the test partition (Split == 4) into a dataset of model inputs.
df_split = df_data[df_data['Split'].isin([4])]
ds_test = df_to_dataset(
    t5_tokenizer,
    df_split['Sequence'].to_list(),
    df_split['Label'].to_list(),
)

In [None]:
# Build one tensor of token ids (all sequences must be equal length) directly on the compute device.
test_tensor = torch.tensor(ds_test['input_ids'], device=device)

In [7]:
# test_tensor.shape

In [None]:
test_tensor_0 = test_tensor  # whole test split; slice it (e.g. test_tensor[:100]) for a quick dry run

In [None]:
# test_tensor_0.shape

In [None]:
# Number of forward passes needed to cover every sequence: ceiling division, so the last batch may be short.
batch_size = 100
n_batches = -(-test_tensor_0.size(0) // batch_size)
print(n_batches)

In [None]:
# Run the encoder over the test split in batches of `batch_size` and collect the last hidden state
# of every batch on the CPU, then concatenate once into `extracted_embeddings`.
# NOTE(review): assumes `t5_base_model` already lives on `device` (it is not moved here) — confirm in
# `get_prottrans_tokenizer_model`.
batch_embeddings = []
for i in tqdm(range(n_batches), desc="Processing Batches"):
    batch = test_tensor_0[i * batch_size:(i + 1) * batch_size]
    # Without an explicit mask the encoder attends to every position, including pad tokens;
    # mask them out so padded sequences get the same embeddings as unpadded ones.
    attention_mask = (batch != t5_tokenizer.pad_token_id).long()

    with torch.no_grad():
        batch_predictions = t5_base_model(input_ids=batch, attention_mask=attention_mask)

    # Copy the result off the device first, then drop the device-side output so the cache
    # flush below can actually release its memory.
    batch_embeddings.append(batch_predictions.last_hidden_state.to('cpu'))
    del batch_predictions
    gc.collect()
    if device.type == 'mps':  # the MPS cache only exists on the MPS backend
        mps.empty_cache()

# A single concatenation avoids re-copying the growing tensor on every iteration (quadratic).
extracted_embeddings = torch.cat(batch_embeddings, dim=0)
del batch_embeddings

In [None]:
# Cache the embeddings so the expensive encoder pass need not be repeated.
# NOTE: despite the `train_full` file name, this tensor holds only the Split 4 (test) sequences.
torch.save(extracted_embeddings, f"{ROOT}/data/processed/5.0_train_full_embeddings.pt")

---

In [8]:
# Reload the cached Split 4 embeddings so the analysis below can run without re-encoding.
extracted_embeddings = torch.load(f"{ROOT}/data/processed/5.0_train_full_embeddings.pt")

In [9]:
extracted_embeddings.size()  # (sequences, token positions, hidden size)

torch.Size([4147, 71, 1024])

In [57]:
# Collapse each sequence's per-token embeddings into one vector: (N, L, H) -> (N, L * H).
# NOTE(review): this keeps position-specific features (including any EOS/pad positions) instead of
# pooling over residues — confirm this is the intended UMAP input.
flattened_output = extracted_embeddings.flatten(start_dim=1)

In [58]:
# One NumPy row vector per sequence, ready to be stored as a DataFrame column.
split_outputs = list(flattened_output.cpu().numpy())

In [59]:
# Build the test-split table (same `Split == 4` filter used to build `ds_test`) with a fresh
# 0..n-1 index, and attach each sequence's flattened embedding as the first column.
df_data_test = df_data[df_data.Split.isin([4])].reset_index(drop=True)
# NOTE(review): rows are paired with embeddings purely by position, which assumes `df_to_dataset`
# preserves input order. The check below only catches a count mismatch (e.g. a stale cached file).
assert len(df_data_test) == len(split_outputs), (
    f"{len(split_outputs)} embeddings for {len(df_data_test)} test rows"
)
df_data_test.insert(0, 'Split_Output', split_outputs)

In [60]:
df_data_test.iloc[:5]  # first five rows

Unnamed: 0,Split_Output,Uniprot_AC,Kingdom,Type,Partition_No,Sequence,Label,Split
0,"[0.046009257, -0.28403857, -0.3852475, 0.23229...",P55317,EUKARYA,NO_SP,4,M L G T V K M E G H E T S D W N S Y Y A D T Q ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",4
1,"[0.03449983, -0.25233316, -0.31544554, 0.22147...",P35583,EUKARYA,NO_SP,4,M L G A V K M E G H E P S D W S S Y Y A E P E ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",4
2,"[0.35288915, -0.2297022, 0.24450038, 0.2898977...",Q8UVD9,EUKARYA,NO_SP,4,M E I S T P D F G F G T E D S S A Q Q S A N R ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",4
3,"[0.26283765, -0.22823627, 0.014691218, 0.10528...",Q99PF5,EUKARYA,NO_SP,4,M S D Y S T G G P P P G P P P P A G G G G G A ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",4
4,"[0.023035612, -0.21926892, -0.03637588, 0.0277...",Q9URU9,EUKARYA,NO_SP,4,M N F R P E Q Q Y I L E K P G I L L S F E Q L ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",4


In [61]:
df_data_test.at[0, 'Split_Output'].shape  # length of one flattened embedding vector

(72704,)

In [62]:
extracted_embeddings.reshape(extracted_embeddings.shape[0], -1).shape  # (sequences, L * H)

torch.Size([4147, 72704])

In [63]:
# Project the flattened embeddings to 2-D. A fixed random_state makes the layout reproducible
# (UMAP then runs single-threaded, as its warning reports).
umap_2d = umap.UMAP(random_state=42, n_components=2)
umap_2d_embeddings = umap_2d.fit_transform(list(df_data_test['Split_Output']))


n_jobs value -1 overridden to 1 by setting random_state. Use no seed for parallelism.



In [64]:
# Store the 2-D UMAP coordinates as columns. Assigning by name (instead of concatenating a new frame)
# keeps this cell idempotent: re-running it overwrites '2d_x'/'2d_y' rather than appending duplicates.
df_data_test = df_data_test.assign(**{
    '2d_x': umap_2d_embeddings[:, 0],
    '2d_y': umap_2d_embeddings[:, 1],
})

In [65]:
df_data_test.iloc[:5]  # first five rows, now including the 2-D coordinates

Unnamed: 0,Split_Output,Uniprot_AC,Kingdom,Type,Partition_No,Sequence,Label,Split,2d_x,2d_y
0,"[0.046009257, -0.28403857, -0.3852475, 0.23229...",P55317,EUKARYA,NO_SP,4,M L G T V K M E G H E T S D W N S Y Y A D T Q ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",4,15.535522,0.998961
1,"[0.03449983, -0.25233316, -0.31544554, 0.22147...",P35583,EUKARYA,NO_SP,4,M L G A V K M E G H E P S D W S S Y Y A E P E ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",4,15.521823,0.992337
2,"[0.35288915, -0.2297022, 0.24450038, 0.2898977...",Q8UVD9,EUKARYA,NO_SP,4,M E I S T P D F G F G T E D S S A Q Q S A N R ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",4,13.42875,1.388388
3,"[0.26283765, -0.22823627, 0.014691218, 0.10528...",Q99PF5,EUKARYA,NO_SP,4,M S D Y S T G G P P P G P P P P A G G G G G A ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",4,14.318105,1.394176
4,"[0.023035612, -0.21926892, -0.03637588, 0.0277...",Q9URU9,EUKARYA,NO_SP,4,M N F R P E Q Q Y I L E K P G I L L S F E Q L ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",4,13.863209,2.725915


In [78]:
# 2-D UMAP scatter of the test split, coloured by the `Type` column; saved as PNG and shown inline.
fig = px.scatter(
    df_data_test,
    x='2d_x',
    y='2d_y',
    title='UMAP on ProtTransT5 Embeddings SignalP5.0 Dataset Split 4',
    color='Type',
    labels={'2d_x': 'UMAP 1', '2d_y': 'UMAP 2'},
    hover_data=['Uniprot_AC', 'Sequence', 'Kingdom', 'Type'],
    )

# write_image raises if the output directory does not exist, so create it first.
os.makedirs("./plots", exist_ok=True)
fig.write_image("./plots/umap_1_2d.png")

fig.show()

---

In [67]:
# Project the flattened embeddings to 3-D with the same fixed seed as the 2-D projection.
umap_3d = umap.UMAP(random_state=42, n_components=3)
umap_3d_embeddings = umap_3d.fit_transform(list(df_data_test['Split_Output']))


n_jobs value -1 overridden to 1 by setting random_state. Use no seed for parallelism.



In [70]:
# Store the 3-D UMAP coordinates as columns. Assigning by name (instead of concatenating a new frame)
# keeps this cell idempotent: re-running it overwrites the '3d_*' columns rather than appending duplicates.
df_data_test = df_data_test.assign(**{
    '3d_x': umap_3d_embeddings[:, 0],
    '3d_y': umap_3d_embeddings[:, 1],
    '3d_z': umap_3d_embeddings[:, 2],
})

In [71]:
df_data_test.iloc[:5]  # first five rows, now including the 3-D coordinates

Unnamed: 0,Split_Output,Uniprot_AC,Kingdom,Type,Partition_No,Sequence,Label,Split,2d_x,2d_y,3d_x,3d_y,3d_z
0,"[0.046009257, -0.28403857, -0.3852475, 0.23229...",P55317,EUKARYA,NO_SP,4,M L G T V K M E G H E T S D W N S Y Y A D T Q ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",4,15.535522,0.998961,-1.527074,10.921324,12.058817
1,"[0.03449983, -0.25233316, -0.31544554, 0.22147...",P35583,EUKARYA,NO_SP,4,M L G A V K M E G H E P S D W S S Y Y A E P E ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",4,15.521823,0.992337,-1.522945,10.927733,12.018054
2,"[0.35288915, -0.2297022, 0.24450038, 0.2898977...",Q8UVD9,EUKARYA,NO_SP,4,M E I S T P D F G F G T E D S S A Q Q S A N R ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",4,13.42875,1.388388,-1.88047,11.255307,13.102851
3,"[0.26283765, -0.22823627, 0.014691218, 0.10528...",Q99PF5,EUKARYA,NO_SP,4,M S D Y S T G G P P P G P P P P A G G G G G A ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",4,14.318105,1.394176,-1.192153,11.118048,13.23115
4,"[0.023035612, -0.21926892, -0.03637588, 0.0277...",Q9URU9,EUKARYA,NO_SP,4,M N F R P E Q Q Y I L E K P G I L L S F E Q L ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",4,13.863209,2.725915,-0.932008,11.361708,11.746972


In [75]:
# 3-D UMAP scatter of the test split, coloured by the `Type` column; saved as PNG and shown inline.
fig = px.scatter_3d(
    df_data_test,
    x='3d_x',
    y='3d_y',
    z='3d_z',
    title='UMAP on ProtTransT5 Embeddings SignalP5.0 Dataset Split 4',
    color='Type',
    labels={'3d_x': 'UMAP 1', '3d_y': 'UMAP 2', '3d_z': 'UMAP 3'},
    # Same hover fields as the 2-D plot so each point can be traced back to its UniProt entry.
    hover_data=['Uniprot_AC', 'Sequence', 'Kingdom', 'Type'],
    )

# write_image raises if the output directory does not exist, so create it first.
os.makedirs("./plots", exist_ok=True)
fig.write_image("./plots/umap_1_3d.png")

fig.show()