# In [None]:
import gc
import os
import re
import random
from datetime import datetime
import yaml

import numpy as np
import pandas as pd
import torch
from datasets import load_from_disk
from peft import (
    LoraConfig,
    get_peft_model,
)
from peft.utils.constants import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
from transformers import (
    AutoConfig,
    AutoTokenizer,
    T5Tokenizer,
    TrainingArguments,
)

from src.model.configuration_protein_clip import ProtT5CLIPConfig
from src.model.data_collator_multi_input import DataCollatorForProtT5CLIP
from src.model.modeling_protein_clip import ProtT5CLIP
from src.model.trainer_protein_subset import ProteinSampleSubsetTrainer
from src.model.metrics import metrics_factory
import src.model.utils as utils


def process_sequence(sequence):
    """Return *sequence* with a single space between consecutive residues.

    The protein tokenizer expects space-separated residues, e.g. ``"MKV"`` -> ``"M K V"``.
    """
    separator = " "
    return separator.join(residue for residue in sequence)


def _masked_mean(embeddings, attention_mask):
    """Average *embeddings* over axis 1, ignoring positions where *attention_mask* is 0.

    Assumes ``embeddings`` is (batch, seq, dim) aligned with a (batch, seq) mask.
    If the shapes do not line up, fall back to a plain ``torch.mean(..., dim=1)``,
    which is the previous behaviour.
    """
    if embeddings.dim() != 3 or tuple(attention_mask.shape) != tuple(embeddings.shape[:2]):
        return torch.mean(embeddings, dim=1)
    mask = attention_mask.unsqueeze(-1).to(device=embeddings.device, dtype=embeddings.dtype)
    summed = (embeddings * mask).sum(dim=1)
    # clamp guards against division by zero for a row that is entirely padding
    counts = mask.sum(dim=1).clamp(min=1.0)
    return summed / counts


def get_embeddings(model, tokenizer_plm, tokenizer_llm, sequences=None, texts=None, device="cuda", mean_pooling=False):
    """Extract projected embeddings for protein sequences and/or texts.

    Args:
        model: Model called with ``input_ids_sequence``/``attention_mask_sequence``
            (proteins) or ``input_ids_text``/``attention_mask_text`` (texts). It must
            return a mapping with ``'proj_protein_embeds'`` / ``'proj_text_embeds'``.
        tokenizer_plm: Tokenizer for protein sequences.
        tokenizer_llm: Tokenizer for free text.
        sequences: Optional list of amino-acid strings.
        texts: Optional list of strings.
        device: Device the tokenized inputs are moved to.
        mean_pooling: If True, average over axis 1 of the model output. Padding
            positions (attention mask 0) are excluded, so batched results do not
            depend on how much padding the other batch items needed.

    Returns:
        ``(protein_embedding, text_embedding)``. Each is ``None`` when the
        corresponding input was not given.
    """
    with torch.no_grad():
        protein_embedding = None
        text_embedding = None

        if sequences is not None:
            # Map rare/ambiguous residues U, Z, O, B to X, then space-separate residues.
            processed_seq = [" ".join(list(re.sub(r"[UZOB]", "X", seq))) for seq in sequences]
            inputs = tokenizer_plm(
                processed_seq,
                return_tensors="pt",
                padding=True,
            ).to(device)

            outputs = model(
                input_ids_sequence=inputs["input_ids"],
                attention_mask_sequence=inputs["attention_mask"]
            )
            protein_embedding = outputs['proj_protein_embeds']
            if mean_pooling:
                protein_embedding = _masked_mean(protein_embedding, inputs["attention_mask"])

        if texts is not None:
            inputs = tokenizer_llm(
                texts,
                return_tensors="pt",
                padding=True,
            ).to(device)

            outputs = model(
                input_ids_text=inputs["input_ids"],
                attention_mask_text=inputs["attention_mask"]
            )
            text_embedding = outputs['proj_text_embeds']
            if mean_pooling:
                text_embedding = _masked_mean(text_embedding, inputs["attention_mask"])
    return protein_embedding, text_embedding


DEFAULT_MODEL_PATH = "../tmp/models/protT5-CLIP-2025-01-07-12-49-43-ddp"
DEFAULT_CONFIG_PATH = "../configs/model.yaml"


def _release_accelerator_memory():
    """Run the garbage collector and empty the CUDA/MPS allocator caches when available."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()


def main(model_path=DEFAULT_MODEL_PATH, config_path=DEFAULT_CONFIG_PATH):
    """Load ProtT5-CLIP with a trained adapter and print embeddings for a demo pair.

    Args:
        model_path: Adapter directory passed to ``model.load_adapter``. A relative
            path is resolved against the current working directory.
        config_path: YAML config whose ``model`` section provides the encoder names,
            ``text_projection_dim`` and ``logit_scale_init_value``.
    """
    # Prefer CUDA, then Apple MPS, then CPU.
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

    with open(config_path, "r", encoding="utf-8") as f:
        train_config = yaml.safe_load(f)
    model_cfg = train_config["model"]

    plm_name = model_cfg["protein_encoder_name"]
    llm_name = model_cfg["text_encoder_name"]

    plm_config = AutoConfig.from_pretrained(plm_name)
    llm_config = AutoConfig.from_pretrained(llm_name, trust_remote_code=True)

    # Don't change
    model_config = ProtT5CLIPConfig(
        name_or_path_plm=plm_name,
        name_or_path_llm=llm_name,
        plm_config=plm_config,
        llm_config=llm_config,
        output_hidden_states=True,
        output_attentions=True,
        return_dict=True,
        projection_dim=model_cfg["text_projection_dim"],
        logit_scale_init_value=model_cfg["logit_scale_init_value"],
        device=device,
    )

    model = ProtT5CLIP(model_config)
    model.load_adapter(model_path)
    model.to(device)
    model.eval()

    tokenizer_plm = T5Tokenizer.from_pretrained(model_config.name_or_path_plm)
    # Use the same trust_remote_code setting as the text-encoder config above, so
    # repositories that ship a custom tokenizer class load as well.
    tokenizer_llm = AutoTokenizer.from_pretrained(model_config.name_or_path_llm, trust_remote_code=True)
    # get_embeddings tokenizes texts with padding=True, and transformers refuses to pad
    # when no pad token is defined (common for decoder-only LLMs). Fall back to EOS.
    if tokenizer_llm.pad_token is None and tokenizer_llm.eos_token is not None:
        tokenizer_llm.pad_token = tokenizer_llm.eos_token

    sequences = ["MLKPAFKGMASKV"]
    texts = ["This protein is involved in DNA binding"]

    protein_emb, text_emb = get_embeddings(
        model, tokenizer_plm, tokenizer_llm,
        sequences=sequences,
        texts=texts,
        device=device,
        mean_pooling=True
    )

    _release_accelerator_memory()

    print("Protein embedding shape:", protein_emb.shape)
    print("Protein embedding:", protein_emb)
    print("Text embedding shape:", text_emb.shape)
    print("Text embedding:", text_emb)


# Entry point: run the embedding demo; skipped when this module is imported.
if __name__ == "__main__":
    main()