In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from functools import partial
from saveAndLoad import *
from functools import partial
import math
import torch.nn.functional as F
import torch.nn as nn

import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, CPUOffload
from torch.distributed.fsdp import (
    StateDictType,
    FullStateDictConfig
)
from torch.utils.data import DistributedSampler, DataLoader
import os
from torch.amp import autocast, GradScaler
from torch.distributions import Beta

import time as time_
from contextlib import nullcontext
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# from peft import get_peft_model, LoraConfig, PrefixTuningConfig, TaskType
import torch, torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class MLP(nn.Module):
    """Feed-forward block: Linear -> GELU -> Linear, optionally followed by dropout.

    The hidden layer is four times as wide as the input/output width.

    Args:
        config: Object providing ``bias`` and ``dropout``, plus whichever of
            ``emb_dim`` / ``input_dim`` is selected by ``d``.
        use_dropout: When true, dropout is applied to the projected output.
        d: ``'emb'`` sizes the block from ``config.emb_dim``; ``'input'``
            sizes it from ``config.input_dim``.
    """

    def __init__(self, config, use_dropout=True, d='emb'):
        super().__init__()
        assert d in ['emb', 'input'], "d must be either 'emb' or 'input'"
        # Resolve the symbolic selector to the actual feature width.
        if d == 'emb':
            width = config.emb_dim
        else:
            width = config.input_dim
        hidden_width = 4 * width
        self.c_fc = nn.Linear(width, hidden_width, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden_width, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
        self.use_dropout = use_dropout

    def forward(self, x):
        """Apply the expand / activate / project pipeline to ``x``."""
        projected = self.c_proj(self.gelu(self.c_fc(x)))
        return self.dropout(projected) if self.use_dropout else projected

class Block_CrossAttentionResampler(nn.Module):
    """One resampler block: queries cross-attend to a sequence, then pass through an MLP.

    Args:
        config: Object providing ``input_dim``, ``num_heads``, ``dropout``,
            ``bias`` and ``norm_fn``. ``norm_fn`` is called with
            ``config.input_dim`` to build each normalisation layer.
    """

    def __init__(self, config):
        super().__init__()
        self.norm_q = config.norm_fn(config.input_dim)
        self.norm_k = config.norm_fn(config.input_dim)
        self.norm_mlp = config.norm_fn(config.input_dim)
        self.lin = nn.Linear(config.input_dim, config.input_dim)
        self.mlp = MLP(config, d='input')

        self.resampler = nn.MultiheadAttention(embed_dim=config.input_dim,
                                               num_heads=config.num_heads,
                                               dropout=config.dropout,
                                               batch_first=True)

    def forward(self, x):
        """Update the queries from the sequence.

        Args:
            x: Tuple ``(seq, q, key_padding_mask)``. ``seq`` supplies the
                attention keys/values and ``q`` the queries; both are
                batch-first because the attention layer is built with
                ``batch_first=True``. ``key_padding_mask`` is forwarded
                unchanged to ``nn.MultiheadAttention`` (may be ``None``).
                NOTE(review): PyTorch treats True entries of a boolean
                padding mask as positions to ignore — confirm callers
                follow that convention.

        Returns:
            The updated queries, same shape as ``q``.
        """
        seq, q, key_padding_mask = x
        q_norm = self.norm_q(q)
        seq_norm = self.norm_k(seq)
        # nn.MultiheadAttention takes query/key/value as separate arguments and
        # returns (attn_output, attn_weights); the weights are not needed here.
        resampled, _ = self.resampler(
            q_norm,
            seq_norm,
            seq_norm,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        resampled = self.lin(resampled)
        q = q + resampled
        # The MLP residual is taken from the normalised queries.
        q = self.norm_mlp(q)
        q = q + self.mlp(q)
        return q

class CrossAttentionResampler(nn.Module):
    """Resample a sequence into a fixed number of learned query tokens.

    A learned embedding table holds ``config.n_resampled_tokens`` queries of
    width ``config.input_dim``. They are refined by a stack of
    ``config.n_cross_attention_resampler_blocks`` cross-attention blocks.
    """

    def __init__(self, config):
        super().__init__()
        self.query = nn.Embedding(config.n_resampled_tokens, config.input_dim)
        stack = []
        for _ in range(config.n_cross_attention_resampler_blocks):
            stack.append(Block_CrossAttentionResampler(config))
        self.blocks = nn.ModuleList(stack)

    def forward(self, x):
        """Run the query tokens through every block.

        Args:
            x: Tuple ``(seq, key_padding_mask)``; ``seq`` must be 3-D with
                the batch on the first axis.

        Returns:
            Tensor with one row of resampled tokens per batch element.
        """
        seq, key_padding_mask = x
        # Three-way unpacking keeps the requirement that `seq` is rank 3.
        batch_size, _, _ = seq.shape
        token_ids = torch.arange(self.query.num_embeddings, device=seq.device)
        latents = self.query(token_ids).unsqueeze(0).expand(batch_size, -1, -1)
        for layer in self.blocks:
            latents = layer((seq, latents, key_padding_mask))
        return latents


class MutantResampler(nn.Module):
    """Cross-attention resampler whose output tokens are fed to a causal LM as embeddings.

    Args:
        config: Object providing the resampler settings plus ``lm_name``,
            ``torch_dtype`` and ``device_map`` for loading the language model.
    """

    def __init__(self, config):
        super().__init__()
        self.resampler = CrossAttentionResampler(config)
        self.lm = AutoModelForCausalLM.from_pretrained(
            config.lm_name,
            torch_dtype=config.torch_dtype,
            device_map=config.device_map,
        )
        # Gradient checkpointing is switched on through the model API;
        # `from_pretrained` has no supported `gradient_checkpointing` argument.
        self.lm.gradient_checkpointing_enable()

    def forward(self, x):
        """Resample ``x`` and run the language model on the result.

        Args:
            x: Tuple ``(seq, key_padding_mask)`` passed straight to the resampler.

        Returns:
            Whatever the wrapped language model returns.
        """
        mut_repr = self.resampler(x)
        # The resampler yields continuous vectors, not token ids, so they must
        # bypass the LM's embedding lookup via `inputs_embeds`.
        # NOTE(review): assumes config.input_dim equals the LM hidden size and
        # that dtypes are compatible (e.g. under autocast) — confirm.
        lm_output = self.lm(inputs_embeds=mut_repr)
        return lm_output

In [4]:
# Checkpoint to load and the single GPU that hosts it.
model_name = "aaditya/OpenBioLLM-Llama3-8B"
LM_DEVICE = "cuda:2"

# `device_map` places the weights on LM_DEVICE while loading, so no separate
# `.to(...)` call is needed afterwards.
base_lm = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=LM_DEVICE,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

# Last expression of the cell: displays the module tree.
base_lm

Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.21s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (n

In [None]:
def create_message(q, a=None):
    """Build the chat message list for a single user question.

    Args:
        q: The user question; it is interpolated into the user message as text.
        a: Example assistant answer. Currently unused — the assistant turn is
            left out so the model generates it — and kept only so existing
            two-argument calls continue to work.

    Returns:
        A list with a system message followed by the user message.
    """
    return [
        {'role': 'system', 'content': ('You are an expert in genomics and proteomics. Respond clearly and concisely.')},
        {'role': 'user', 'content': f'{q}'},
        # {'role': 'assistant', 'content': f'{a}'}
    ]


q = create_message("Describe all protein domains of Tp53.", "Example Response")

# Jinja chat template. It emits the BOS token, a system header carrying the
# "Cutting Knowledge Date" / "Today Date" lines (date_string defaults to
# "26 Jul 2024" when not supplied), optional tool-calling instructions, then
# every message wrapped in <|start_header_id|>role<|end_header_id|> ... <|eot_id|>,
# and finally an assistant header when add_generation_prompt is set.
llama3_template = "{{- bos_token }}\n{%- if custom_tools is defined %}\n    {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n    {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n    {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n    {%- set system_message = messages[0]['content']|trim %}\n    {%- set messages = messages[1:] %}\n{%- else %}\n    {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n    {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n    {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n    {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n    {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n    {{- \"Do not use variables.\\n\\n\" }}\n    {%- for t in tools %}\n        {{- t | tojson(indent=4) }}\n        {{- \"\\n\\n\" }}\n    {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n    {#- Extract the first user message so we can plug it in here #}\n    {%- if messages | length != 0 %}\n        {%- set first_user_message = messages[0]['content']|trim %}\n        {%- set messages = messages[1:] %}\n    {%- else %}\n        {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n    {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n    {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n    {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n    {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n    {{- \"Do not use variables.\\n\\n\" }}\n    {%- for t in tools %}\n        {{- t | tojson(indent=4) }}\n        {{- \"\\n\\n\" }}\n    {%- endfor %}\n    {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n    {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n    {%- elif 'tool_calls' in message %}\n        {%- if not message.tool_calls|length == 1 %}\n            {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n        {%- endif %}\n        {%- set tool_call = message.tool_calls[0].function %}\n        {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n            {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n            {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n            {%- for arg_name, arg_val in tool_call.arguments | items %}\n                {{- arg_name + '=\"' + arg_val + '\"' }}\n                {%- if not loop.last %}\n                    {{- \", \" }}\n                {%- endif %}\n                {%- endfor %}\n            {{- \")\" }}\n        {%- else  %}\n            {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n            {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n            {{- '\"parameters\": ' }}\n            {{- tool_call.arguments | tojson }}\n            {{- \"}\" }}\n        {%- endif %}\n        {%- if builtin_tools is defined %}\n            {#- This means we're in ipython mode #}\n            {{- \"<|eom_id|>\" }}\n        {%- else %}\n            {{- \"<|eot_id|>\" }}\n        {%- endif %}\n    {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n        {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n        {%- if message.content is mapping or message.content is iterable %}\n     
       {{- message.content | tojson }}\n        {%- else %}\n            {{- message.content }}\n        {%- endif %}\n        {{- \"<|eot_id|>\" }}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"

# llama3_template = """<|begin_of_text|>{% for message in messages %}
# <|start_header_id|>{{ message['role'] }}<|end_header_id|>
# {{ message['content'] }}<|eot_id|>
# {% endfor %}{% if add_generation_prompt %}
# <|start_header_id|>assistant<|end_header_id|>
# {% endif %}"""

# Render the chat as plain text. The template itself emits the BOS token and,
# with add_generation_prompt=True, ends with an assistant header.
tokenizer.chat_template = llama3_template
prompt = tokenizer.apply_chat_template(q, tokenize=False, add_generation_prompt=True)
print('--------prompt--------')
print(prompt)
print('--------output--------')

# add_special_tokens=False: the rendered prompt already starts with BOS, so the
# tokenizer must not prepend another one. Inputs go to whichever device holds the model.
tokenized = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(base_lm.device)

# Stop on the regular EOS token or on the end-of-turn marker.
terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]

# Fall back to EOS for padding when the tokenizer defines no pad token.
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

with torch.no_grad():
    outputs = base_lm.generate(**tokenized,
                               num_beams=1,
                               do_sample=False,
                               repetition_penalty=1.2,
                               eos_token_id=terminators,
                               pad_token_id=pad_token_id,
                               max_new_tokens=1000)
output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
print(output_text)

print('\n--------response only--------')
# Decode only the newly generated tokens. Splitting the full text on '\n'
# would keep just the last line of a multi-line answer.
prompt_length = tokenized["input_ids"].shape[1]
response_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
print(response_text)


--------prompt--------
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

You are an expert in genomics and proteomics. Respond clearly and concisely.<|eot_id|><|start_header_id|>user<|end_header_id|>

Describe all protein domains of Tp53.<|eot_id|><|start_header_id|>assistant<|end_header_id|>


--------output--------
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

You are an expert in genomics and proteomics. Respond clearly and concisely.<|eot_id|><|start_header_id|>user<|end_header_id|>

Describe all protein domains of Tp53.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The p53 gene encodes for the tumor suppressor protein TP53, which is involved in regulating cell division and preventing cancer formation by inducing DNA repair or apoptosis when necessary. The structure of TP53 includes several functional domains that contribu

In [69]:
# Scratch cell: a bare string literal, so running it only echoes the text back
# as the cell result; nothing is assigned.
# NOTE(review): looks like a saved prompt rendering without the
# knowledge-cutoff/date lines — confirm whether this cell is still needed.
'''--------prompt--------
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an expert in genomics and proteomics. Respond clearly and concisely.<|eot_id|>
<|start_header_id|>user<|end_header_id|>
Describe all protein domains of Tp53.<|eot_id|>
'''



'--------prompt--------\n<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are an expert in genomics and proteomics. Respond clearly and concisely.<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\nDescribe all protein domains of Tp53.<|eot_id|>\n'