In [2]:
%reload_ext autoreload
%autoreload 2

from _header import *

from open_clip.loss import ClipLoss, SigLipLoss
from transformers import AutoModelForCausalLM, AutoTokenizer
from pytorch_lightning import LightningModule
from omegaconf import DictConfig, OmegaConf
from torch import Tensor
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping



  from .autonotebook import tqdm as notebook_tqdm


Using device:	 cuda:0


In [None]:
# find metric to optimize
# hydra
# dataloader +: tokenizer padding False, do padding in dataloader
# do plots
# implement GO-terms metric into val/test step

# implement lora
# cosine / lamdba scheduler
# add test_step in Lightning TrainingModule 
# add batch enums to header
# normalize during forward of protclip
# add mlp / linear projection layer
# implement test step for lightning train module
# sigliploss logit_bias logit_scale
# coca

---
## Load Models and apply LoRA

In [3]:
data = [
    {"uid": "A001", "seq": "MLEVPVWIPILAFAVGLGLGLLIPHLQKPFQRF", "text": "This protein is involved in membrane transport."},
    {"uid": "A002", "seq": "MSLEQKKGADIISKILQIQNSIGKTTSPSTLKT", "text": "This enzyme catalyzes the hydrolysis of ATP."},
    {"uid": "A003", "seq": "MKMKQQGLVADLLPNIRVMKTFGHFVFNYYNDN", "text": "This transcription factor regulates gene expression."}
]

ex1_text = data[0]["text"]
ex2_seq = data[0]["seq"]
print(ex1_text)
print(ex2_seq)

This protein is involved in membrane transport.
MLEVPVWIPILAFAVGLGLGLLIPHLQKPFQRF


In [40]:
class PLMEncoder(nn.Module):
    """wraps PLM encoders"""
    def __init__(
        self,
        model_name: Literal["protT5, prostT5, esm2"] | str,
        device: torch.device,
        lora: bool = False,
        **kwargs: Any,
    ):
        super().__init__()
        
        if model_name == "protT5":
            self.mod_type = "pt"
            self.model = T5EncoderModel.from_pretrained(
            BASE_MODEL_PLM,
            device_map=device,
            torch_dtype='auto',
            cache_dir="/mnt/volume/mathias/pretrained_models"
            )
            self.tokenizer = T5Tokenizer.from_pretrained(BASE_MODEL_PLM)
            
        if lora:
            print("IF LORA DOES COOL THINGS")
            
    def freeze(self):
        """Freeze model params"""
        for param in self.model.parameters():
            param.requires_grad = False
        
    def forward(self, x, emb_type):
        inputs = self.tokenizer(
            x,
            return_tensors = "pt",
            max_length=10_000,
            truncation=True,
            padding=True,
            add_special_tokens=True
        )
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        #outputs = self.model(**inputs).last_hidden_state.cpu()
        outputs = self.model(**inputs).last_hidden_state
        # if emb_type == "per_res":
        #     if self.mod_type in ("pt", "ank"):
        #         outputs = outputs[:-1, :]
        #     elif self.mod_type == "esm":
        #         output = np.squeeze(outputs, axis=0)[:-1, :]
        #     return outputs
        
        if emb_type == "per_prot":
            #return outputs.mean(axis=1).flatten()
            return outputs.mean(axis=1)
        else:
            raise ValueError("Input valid embedding type")            

In [26]:
class LLMEncoder(nn.Module):
    """wraps LLM encoders"""
    def __init__(
        self,
        model_name: Literal["phi3.5"] | str,
        device: torch.device,
        lora: bool = False,
        **kwargs: Any,
    ):
        super().__init__()
        self.device = device
        if model_name == "phi3.5":
            self.model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL_LLM,
                device_map=device,
                torch_dtype="auto",
                trust_remote_code=True,
                cache_dir="/mnt/volume/mathias/pretrained_models"
            )
            self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_LLM)
         
        if lora:
            print("IF LORA: DO COOL THINGS")
            
        
    def freeze(self):
        """Freeze model params"""
        for param in self.model.parameters():
            param.requires_grad = False


    def forward(self, x, sentence_level=True):
        """Forward pass, extract token or sentence embeddings"""
        inputs = self.tokenizer(x, return_tensors="pt", padding=True, truncation=True, max_length=512)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        
        outputs = self.model(**inputs, output_hidden_states=True)
        last_hidden_state = outputs.hidden_states[-1]
        
        if sentence_level:
            # Average over tokens to get sentence embedding
            embeddings = last_hidden_state.mean(dim=1)
        else:
            # Keep per-token embeddings
            embeddings = last_hidden_state.squeeze(0)
        
        #return embeddings.detach().cpu().float().numpy()   # detach to numpy, remove detach() and numpy() if needed for further computation
        return embeddings.float()

In [27]:
class ProtCLIP(nn.Module):
    def __init__(
        self,
        plm_name: Literal["protT5"],
        llm_name: Literal["phi3.5"],
        loss: Literal["CLIP", "SIGLIP"],
        temperature: float,
        device: torch.device,
        lora: bool = False,
        
    ):
        super().__init__()
        # REMOVE FREEZE
        self.plm_encoder = PLMEncoder(model_name=plm_name, device=device, lora=lora)
        self.llm_encoder = LLMEncoder(model_name=llm_name, device=device, lora=lora)
        if loss == "CLIP":  
            self.loss = ClipLoss()
            
        self.temperature = temperature
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / temperature)) # temporary
    
    def forward(self, batch: dict[str, Tensor]):
        prot_embed = self.plm_encoder(batch["seq"], "per_prot")
        txt_embed = self.llm_encoder(batch["text"], sentence_level=True) # sentence level true correct?
        # add normalization
        return prot_embed, txt_embed
    
    def compute_loss(self, prot_embed, txt_embed):
        return self.loss(prot_embed, txt_embed, logit_scale=self.logit_scale.exp()) # logit_scale temperature in init?   

In [50]:
class TrainingModule(LightningModule):
    def __init__(
        self,
        plm_name: Literal["protT5"], 
        llm_name: Literal["phi3.5"],
        loss: Literal["CLIP", "SIGLIP"],
        temperature: float,
        device: torch.device,
        lora: bool,
        lr: float,
        weight_decay: float,
        cfg: DictConfig
    ):
        super().__init__()
        self.cfg = cfg
        self.save_hyperparameters(self.cfg)
        self.model = ProtCLIP(plm_name=plm_name, llm_name=llm_name, loss=loss, temperature=temperature, device=device)
        self.lr = lr
        self.weight_decay = weight_decay
        
        # eval
        self.eval_outputs = {"val": []}
        # add more 
        
        
    def configure_optimizers(self):
        params = []
        for p in self.model.parameters():
            if p.requires_grad:
                params.append(p)
        optimizer = torch.optim.AdamW(params, lr=self.lr, weight_decay=self.weight_decay) # configure betas
        # also add scheduler here
        return optimizer
    
    def training_step(self, batch, batch_idx):
        loss, metrics = self.shared_step(batch, "train")
        self.log_dict(metrics, sync_dist=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, metrics = self.shared_step(batch, "val")
        self.log_dict(metrics,sync_dist=True)
        return loss

    def shared_step(self, batch, set: str):
        prot_embed, txt_embed = self.model(batch)
        
        print(f"prot_embed_dim: {prot_embed.shape}")
        print(f"txt_embed_dim: {txt_embed.shape}")

        loss = self.model.compute_loss(prot_embed.T, txt_embed.T)        
        metrics = {f"{set}_loss": loss}
        return loss, metrics
    
    # def predict_step(self, batch, batch_idx, dataloader_idx=0):
    #     return self.model(batch)


In [None]:

class ExampleDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        prot = self.data[idx]
        return prot

def train(cfg: DictConfig):
    data_train = [
    {"uid": "A001", "seq": "MLEVPVWIPILAFAVGLGLGLLIPHLQKPFQRF", "text": "This protein is involved in membrane transport."},
    {"uid": "A002", "seq": "MSLEQKKGADIISKILQIQNSIGKTTSPSTLKT", "text": "This enzyme catalyzes the hydrolysis of ATP."}
    ]
    data_val = [
    {"uid": "A003", "seq": "MKMKQQGLVADLLPNIRVMKTFGHFVFNYYNDN", "text": "This transcription factor regulates gene expression."}
    ]
    
    train_set = ExampleDataset(data_train)
    val_set = ExampleDataset(data_val)
    train_loader = DataLoader(train_set, batch_size=cfg.batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=cfg.batch_size, shuffle=False)
    
    model = TrainingModule(
        plm_name=cfg.plm_name,
        llm_name=cfg.llm_name,
        loss=cfg.loss,
        temperature=cfg.temperature,
        device=device,
        lora=cfg.lora,
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
        cfg=cfg
    )
    
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        save_last=True,
        save_top_k=2,
        mode="min",
        dirpath="/mnt/volumne/mathias/checkpoints/",
        auto_insert_metric_name=True,
    )
    lr_monitor = LearningRateMonitor(logging_interval="step")
    es = EarlyStopping(min_delta=1e-3, patience=5, mode="min", monitor="val_loss")
    # add Logger (wandb/tensorboard)
    
    trainer = Trainer(
        max_epochs=cfg.epochs,
        accelerator="gpu",
        enable_progress_bar=True,
        strategy="auto",
        precision="bf16-mixed",
        log_every_n_steps=1,
        check_val_every_n_epoch=1,
        gradient_clip_val=5,
        callbacks=[checkpoint_callback, lr_monitor, es],
        #logger=logger,
          # Mixed precision for faster training (optional)
    )
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    # add result plots
    

cfg = DictConfig(
    {
    "plm_name": "protT5",
    "llm_name": "phi3.5",
    "loss": "CLIP",
    "temperature": 0.07,
    "batch_size": 16,
    "lr": 1e-4,
    "weight_decay": 1e-5,
    "epochs": 10,
    "lora": False
    }
)

train(cfg)
    

                                                                   

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.43s/it]
Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type     | Params | Mode 
-------------------------------------------
0 | model | ProtCLIP | 5.0 B  | train
-------------------------------------------
5.0 B     Trainable params
0         Non-trainable params
5.0 B     Total params
20,116.886Total estimated model params size (MB)
4         Modules in train mode
862       Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

/home/mathias/.conda/envs/protclip/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


prot_embed_dim: torch.Size([1, 1024])
txt_embed_dim: torch.Size([1, 3072])


ValueError: Expected input batch_size (3072) to match target batch_size (1024).

In [4]:
#llm1 = LLMEncoder(model_name="phi3.5", device=device, lora=False)
#out1 = llm1(ex1_text, True)

plm1 = PLMEncoder(model_name="protT5", device=device, lora=False)
#out2 = plm1(ex2_seq, "per_prot")

#pc1 = ProtCLIP("protT5", "phi3.5", "CLIP", device, False)
#out3 = pc1(data[0])

#print(llm1)
print(plm1)
#print(out2)
#print(out3)


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


PLMEncoder(
  (model): T5EncoderModel(
    (shared): Embedding(128, 1024)
    (encoder): T5Stack(
      (embed_tokens): Embedding(128, 1024)
      (block): ModuleList(
        (0): T5Block(
          (layer): ModuleList(
            (0): T5LayerSelfAttention(
              (SelfAttention): T5Attention(
                (q): Linear(in_features=1024, out_features=4096, bias=False)
                (k): Linear(in_features=1024, out_features=4096, bias=False)
                (v): Linear(in_features=1024, out_features=4096, bias=False)
                (o): Linear(in_features=4096, out_features=1024, bias=False)
                (relative_attention_bias): Embedding(32, 32)
              )
              (layer_norm): T5LayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (1): T5LayerFF(
              (DenseReluDense): T5DenseActDense(
                (wi): Linear(in_features=1024, out_features=16384, bias=False)
                (wo): Linear(in_features=163

In [None]:
# PLM

# plm_tokenizer = T5Tokenizer.from_pretrained(
#     pretrained_model_name_or_path=BASE_MODEL_PLM,
#     do_lower_case=False,
#     use_fast=True,
#     legacy=False,
# )

# plm_model, plm_loading_info = T5EncoderModel.from_pretrained(
#     pretrained_model_name_or_path=BASE_MODEL_PLM,
#     output_loading_info=True,
#     device_map=device,
#     # load_in_8bit=False,
#     # custom_dropout_rate=0.1,
# )

plm_lora_config = LoraConfig(
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q", "k", "v", "o"],
    bias="none",
)

plm_model = peft.get_peft_model(plm_model, plm_lora_config)
plm_model.print_trainable_parameters()


# LLM

# llm_tokenizer = AutoTokenizer.from_pretrained(
#     pretrained_model_name_or_path=BASE_MODEL_LLM
# )

# llm_model, llm_loading_info = AutoModelForCausalLM.from_pretrained(
#     BASE_MODEL_LLM,
#     device_map=device,
#     torch_dtype="auto",
#     trust_remote_code=True,
#     output_loading_info=True,
# )

llm_lora_config = LoraConfig(
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
    bias="none",
)

llm_model = peft.get_peft_model(llm_model, llm_lora_config)
llm_model.print_trainable_parameters()

In [None]:
# class ClipLoss(nn.Module):
#     def __init__(self, strategy: str):
#         self.strategy = strategy
#         super().__init__()

#     def forward(self, prot_features, txt_features, logit_scale):
#         logits_prot = logit_scale * prot_features @ txt_features.T
#         logits_txt = logit_scale * txt_features @ prot_features.T

#         labels = torch.arange(logits_prot.shape[0], dtype=torch.long, device=logits_prot.device)

#         cl_prot = F.cross_entropy(logits_prot, labels)
#         cl_txt = F.cross_entropy(logits_txt, labels)

#         total_loss = (cl_prot + cl_txt) / 2
#         return total_loss.mean(), cl_prot.mean(), cl_txt.mean()

---
## Train

In [None]:
data = [
    {"uid": "A001", "seq": "MLEVPVWIPILAFAVGLGLGLLIPHLQKPFQRF", "text": "This protein is involved in membrane transport."},
    {"uid": "A002", "seq": "MSLEQKKGADIISKILQIQNSIGKTTSPSTLKT", "text": "This enzyme catalyzes the hydrolysis of ATP."},
    {"uid": "A003", "seq": "MKMKQQGLVADLLPNIRVMKTFGHFVFNYYNDN", "text": "This transcription factor regulates gene expression."}
]

---
## Inference

In [None]:
with torch.no_grad():
    protein_features = plm_base_model()['last_hidden_state']
    language_features = llm_base_model()['last_hidden_state']