In [29]:
import pandas as pd

df = pd.read_csv("../data/domains-and-seqs-merged.csv")


In [30]:
df["homology_path"] = df["class"].astype(str) + "." + \
                  df["architecture"].astype(str) + "." + \
                  df["topology"].astype(str) + "." + \
                  df["homology"].astype(str)


In [31]:
import pandas as pd

MIN_DOMAINS_PER_HOMOLOGY = 10
MAX_DOMAINS_PER_HOMOLOGY = 200

HOMOLOGY_GROUPS = 100
SAMPLES_PER_GROUP = 10

# Define hierarchy columns — this full path defines a unique homology group
hierarchy = ['class', 'architecture', 'topology', 'homology']

# Drop duplicates by s35 within each homology group
# First, create a groupby object on hierarchy columns
grouped = df.groupby(hierarchy)

# Apply drop_duplicates on 's35' to each group and collect results
unique_s35_per_group = []
for name, group in grouped:
    # Drop duplicates by s35 within this specific homology group
    unique_s35 = group.drop_duplicates('s35')
    unique_s35_per_group.append(unique_s35)

# Combine all dataframes with unique s35 values per homology group
df_unique_s35 = pd.concat(unique_s35_per_group)

# Step 1 & 2: Filter groups where the number of domain_id entries is at least 10 and at most 200
filtered_df = df_unique_s35.groupby(hierarchy).filter(lambda x: MIN_DOMAINS_PER_HOMOLOGY <= len(x) <= MAX_DOMAINS_PER_HOMOLOGY)

# Step 3: Get unique full-path homology groups
unique_homology_paths = filtered_df[hierarchy].drop_duplicates()

# Randomly sample 100 unique homology groups (based on full path)
sampled_paths = unique_homology_paths.sample(n=min(HOMOLOGY_GROUPS, len(unique_homology_paths)), random_state=42)

# Step 4: Retain only rows that belong to the sampled groups
sampled_df = pd.merge(sampled_paths, filtered_df, on=hierarchy)

# Within each sampled group, randomly choose 10 domain_id entries
subset = sampled_df.groupby(hierarchy).apply(lambda x: x.sample(n=min(SAMPLES_PER_GROUP, len(x)), random_state=42)).reset_index(drop=True)

# Save to CSV
subset.to_csv("../data/subset.csv", index=False)

  subset = sampled_df.groupby(hierarchy).apply(lambda x: x.sample(n=min(SAMPLES_PER_GROUP, len(x)), random_state=42)).reset_index(drop=True)


In [14]:
import torch
from transformers import T5Tokenizer, T5EncoderModel

# Load ProtT5 model and tokenizer
tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
model = model.eval()

# Check if CUDA is available and set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [19]:
# Dictionary to store embeddings
all_embeddings = {}

# Process each sequence
for index, row in subset.head(50).iterrows():
    sequence = row["sequence"]
    sequence = sequence.replace('U', 'X').replace('Z', 'X').replace('O', 'X')
    ids = tokenizer.batch_encode_plus([sequence], add_special_tokens=True, padding=True, return_tensors="pt")
    input_ids = ids['input_ids'].to(device)
    attention_mask = ids['attention_mask'].to(device)

    with torch.no_grad():
        embedding = model(input_ids=input_ids, attention_mask=attention_mask)

    # Average over tokens to get a single vector per sequence
    sequence_embedding = embedding.last_hidden_state.mean(dim=1).squeeze().cpu()

    # Store in dictionary
    all_embeddings[index] = sequence_embedding

    print(f"Processed: {index}")

# Save all embeddings to one file
torch.save(all_embeddings, "../data/all_embeddings.pt")
print("All embeddings saved to all_embeddings.pt")


Processed: 0
Processed: 1
Processed: 2
Processed: 3
Processed: 4
Processed: 5
Processed: 6
Processed: 7
Processed: 8
Processed: 9
Processed: 10
Processed: 11
Processed: 12
Processed: 13
Processed: 14
Processed: 15
Processed: 16
Processed: 17
Processed: 18
Processed: 19
Processed: 20
Processed: 21
Processed: 22
Processed: 23
Processed: 24
Processed: 25
Processed: 26
Processed: 27
Processed: 28
Processed: 29
Processed: 30
Processed: 31
Processed: 32
Processed: 33
Processed: 34
Processed: 35
Processed: 36
Processed: 37
Processed: 38
Processed: 39
Processed: 40
Processed: 41
Processed: 42
Processed: 43
Processed: 44
Processed: 45
Processed: 46
Processed: 47
Processed: 48
Processed: 49
All embeddings saved to all_embeddings.pt


In [None]:
### Would you recommend to cluster by s35 to avoid overlaps
### Should we trim / pick equally sized seqs lengths / How should we pad
### 