In [23]:
from pathlib import Path

import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer, AutoModel

# Configuration: model checkpoint and data directory. Edit these here rather
# than repeating the literals in every cell.
MODEL_NAME = "dmis-lab/biobert-v1.1"
DATA_DIR = Path("/opt/scratch/labs/wuc/Drug-Repurposing/data")

# Tokenizer and encoder come from the same checkpoint so the vocabularies match.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

# One row per disease; the text columns are concatenated in the next cell.
data = pd.read_csv(DATA_DIR / "disease_features.csv")

In [34]:
# Text columns concatenated into one free-text description per disease.
# Change the columns depending on the data.
TEXT_COLUMNS = [
    'mondo_name', 'group_id_bert', 'group_name_bert',
    'mondo_definition', 'umls_description',
    'orphanet_definition', 'orphanet_prevalence', 'orphanet_epidemiology',
    'orphanet_clinical_description', 'orphanet_management_and_treatment',
    'mayo_symptoms', 'mayo_causes', 'mayo_risk_factors',
    'mayo_complications', 'mayo_prevention', 'mayo_see_doc',
]

# Rebind instead of inplace=True; later cells still see the NaN-free frame,
# and re-running this cell is idempotent.
data = data.fillna("")

# astype(str): str.join raises TypeError on any non-string cell (e.g. a numeric
# column that still holds numbers after fillna). Empty fields are skipped so
# missing sections do not leave runs of blank separators in the text.
data["agg"] = [
    " ".join(field for field in row if field)
    for row in data[TEXT_COLUMNS].astype(str).itertuples(index=False, name=None)
]

In [37]:
# Tokenize every description in the 'agg' column with one batched tokenizer
# call instead of one call per row. With padding='max_length' every row is
# padded/truncated to exactly 512 tokens, so the tensors are identical to
# tokenizing rows individually and stacking them.
tokenized_data = tokenizer(
    data['agg'].tolist(),
    padding='max_length',
    truncation=True,
    max_length=512,
    return_tensors='pt',
)

# Already stacked by return_tensors='pt': both tensors are (num_rows, 512).
input_ids = tokenized_data['input_ids']
attention_mask = tokenized_data['attention_mask']

In [38]:
# Ensure inputs are on the same device as the model
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# from_pretrained already returns the model in eval mode; set it explicitly so
# dropout stays off even if another cell switched the model to train mode.
model.eval()

# How token states are reduced to one vector per row:
#   "pooler" - outputs.pooler_output (original behavior of this notebook).
#   "mean"   - attention-mask-weighted mean of outputs.last_hidden_state; the
#              Hugging Face BertModel docs note pooler_output is usually not a
#              good summary of the input's semantic content.
POOLING = "pooler"
if POOLING not in ("pooler", "mean"):
    raise ValueError(f"Unknown POOLING={POOLING!r}; expected 'pooler' or 'mean'")

# Create a TensorDataset and DataLoader
batch_size = 128  # Adjust batch size based on memory capacity
dataset = TensorDataset(input_ids, attention_mask)
# DataLoader's default shuffle=False keeps output rows in the same order as `data`.
dataloader = DataLoader(dataset, batch_size=batch_size)

# Per-batch embeddings, moved to CPU so GPU memory is freed batch by batch.
sentence_embeddings = []

# Process each batch separately to save memory; no autograd graph is needed.
with torch.no_grad():
    for batch in dataloader:
        batch_input_ids, batch_attention_mask = (b.to(device) for b in batch)
        outputs = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask)

        if POOLING == "pooler":
            batch_embeddings = outputs.pooler_output
        else:
            # Zero out padding positions, then divide by the real token count
            # (clamped to 1 to avoid division by zero on an all-padding row).
            mask = batch_attention_mask.unsqueeze(-1).to(outputs.last_hidden_state.dtype)
            summed = (outputs.last_hidden_state * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1)
            batch_embeddings = summed / counts
        sentence_embeddings.append(batch_embeddings.cpu())

# torch.cat([]) fails with an opaque error; report the real cause instead.
if not sentence_embeddings:
    raise ValueError("No inputs to embed: input_ids is empty")

# Concatenate all embeddings back into a single tensor
sentence_embeddings = torch.cat(sentence_embeddings)

In [39]:
# Invariant: exactly one embedding vector per disease row, in a 2-D tensor.
assert sentence_embeddings.ndim == 2, sentence_embeddings.shape
assert sentence_embeddings.shape[0] == len(data), (sentence_embeddings.shape, len(data))
sentence_embeddings.shape

torch.Size([17080, 768])

In [40]:
# Persist the (num_rows, hidden) tensor; row i corresponds to row i of `data`
# because the DataLoader above iterates without shuffling.
embeddings_path = "/opt/scratch/labs/wuc/Drug-Repurposing/data/disease_embeddings.pt"
torch.save(sentence_embeddings, embeddings_path)