# Task: PDB
## Modifying LongLLaMA: Focused Transformer Training for Context Scaling

**Original Notebook**: https://colab.research.google.com/github/CStanKonrad/long_llama/blob/main/long_llama_code_instruct_colab.ipynb


References:
* [LongLLaMA-Instruct-3Bv1.1](https://huggingface.co/syzymon/long_llama_3b_instruct)
* [FoT paper](https://arxiv.org/abs/2307.03170) and [GitHub repository](https://github.com/CStanKonrad/long_llama)

# Setup

In [None]:
!pip install --upgrade pip
!pip install transformers==4.30.0  sentencepiece accelerate -q

In [None]:
import numpy as np
import torch
from transformers import LlamaTokenizer, AutoModelForCausalLM, TextStreamer, PreTrainedModel, PreTrainedTokenizer
from typing import List, Optional
import os

In [None]:
os.listdir(os.getcwd())

In [None]:
# Hub id of the LongLLaMA instruct checkpoint; the tokenizer is loaded from
# the same location.
MODEL_PATH = "syzymon/long_llama_3b_instruct"
TOKENIZER_PATH = MODEL_PATH
# to fit into colab GPU we will use reduced precision
TORCH_DTYPE = torch.bfloat16

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
device

In [None]:
# Load the tokenizer from TOKENIZER_PATH (same location as the model).
tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)

# Load the causal LM in TORCH_DTYPE and place it on `device`.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=TORCH_DTYPE,
    device_map=device,
    # NOTE(review): trust_remote_code=True presumably lets transformers run the
    # modeling code shipped with the checkpoint — confirm the repository is trusted.
    trust_remote_code=True,
    # mem_attention_grouping is used
    # to trade speed for memory usage
    # for details, see the section Additional configuration
    # NOTE(review): this notebook has no such section; presumably it refers to
    # the upstream LongLLaMA notebook linked at the top — confirm.
    mem_attention_grouping=(1, 2048),
)
# Switch the model to evaluation mode; this notebook only runs inference.
model.eval()

# Load Input Documents (usually Papers)

In [None]:
import os

Mount Google Drive so that the input and output directories configured below are accessible.

In [None]:
# Colab-only: mount Google Drive at /content/drive, the root that the
# INPUT_RTDIR and DIRPATH settings of this notebook point into.
from google.colab import drive
drive.mount('/content/drive')

Specify the directory containing the input JSON files and the directory where the results should be saved.

In [None]:
# @title Input and Result Output Dirs
# INPUT_RTDIR: root folder; inputs are read from <INPUT_RTDIR>/<TASK_NAME>/inputs/.
# DIRPATH: root folder under which the per-experiment result folders are created.
INPUT_RTDIR = '/content/drive/My Drive/local_benchmark/' # @param {type:"string"}
DIRPATH = '/content/drive/My Drive/local_benchmark/results/' # @param {type:"string"}

In [None]:
os.listdir(INPUT_RTDIR)

In [None]:
#DIRPATH = 'inference/t_1/'

In [None]:
# @title Task Specific Config
TASK_NAME = "pdb" # @param {type:"string"}
PROMPT_NAME = "reconstruct_protein_amino_acid_sequence_0_shot" # @param {type:"string"}
# File name derived from the prompt name by adding the ".txt" extension.
PROMPT_PATH = f"{PROMPT_NAME}.txt"

In [None]:
INPUT_DIR= f'{INPUT_RTDIR}/{TASK_NAME}/inputs/'

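Modifications to the prompt and its ordering (the structure is loaded first, so the instructions come after it):
* The phrase "below are the tertiary structure" is replaced by "The PROTEIN TERTIARY STRUCTURE is provided above."
* The prefix "PROTEIN TERTIARY STRUCTURE: " is prepended to `input['text']`.
* The prompt ends with "AMINO ACID SEQUENCE:".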
In [None]:
# Instruction prompt for the task. In this notebook it is placed after the
# structure text (see how PREFIX + inputs['text'] and PMODIFIED are combined),
# which is why it refers to the structure as "provided above".
PMODIFIED = """You are a computational biologist and I want you to reconstruct a protein's amino acid sequence from its tertiary structure.
* The input is a PDB that is a textual format describing the three-dimensional structures of a protein.
* Return the amino acid sequence in the standard FASTA format, which starts with a definition line with the greater than (>) line,
  followed by the single-letter codes for all amino acids in the second line.
* Make sure the amino acid sequence is in the second line.
* If there is an unknown amino acid in the structure, put "X" in the sequence.
* Make sure you go through the whole structure and get all the amino acids.
* No extra explanation is needed.

The PROTEIN TERTIARY STRUCTURE is provided above.

AMINO ACID SEQUENCE:
"""

In [None]:
PREFIX = "PROTEIN TERTIARY STRUCTURE: "

In [None]:
import json
import os

In [None]:
os.listdir(INPUT_DIR), len(os.listdir(INPUT_DIR))

In [None]:
def get_paper_list(inputdir):
  """Return the names of the '.json' files in `inputdir` with the suffix removed.

  Entries that do not end in '.json' are ignored. The result follows the
  order in which os.listdir reports the entries.
  """
  suffix = '.json'
  return [
      entry[:-len(suffix)]
      for entry in os.listdir(inputdir)
      if entry.endswith(suffix)
  ]

In [None]:
# modified output dict metadata prep
def prepare_task_for_paper(paper: str, prompt_path: str, lm_id: str)-> dict[str, str]:
  """Build the result record for one input, with an empty response.

  Args:
    paper: base name (without '.json') of the input file inside INPUT_DIR.
    prompt_path: value stored under 'prompt_path' in the record.
    lm_id: value stored under 'model_id' in the record.

  Returns:
    A dict with the keys 'record_id', 'model_id', 'prompt_path',
    'prompt_text' (PREFIX + the input's 'text' + PMODIFIED) and
    'response_text' (initially the empty string).

  Raises:
    FileNotFoundError: if the input JSON does not exist.
    KeyError: if the input JSON lacks 'record_id' or 'text'.
  """
  paper_input = f'{INPUT_DIR}/{paper}.json'
  # Use a context manager so the file handle is closed right after parsing
  # instead of being left for the garbage collector.
  with open(paper_input, 'r') as fh:
    inputs = json.load(fh)

  return {'record_id': inputs['record_id'], 'model_id': lm_id, 'prompt_path': prompt_path,
          'prompt_text': PREFIX + inputs['text'] + PMODIFIED , 'response_text': ''}

## Run on all sequences

In [None]:
from io import StringIO
import sys

In [None]:
import os
#for the paper
@torch.no_grad()
def load_to_memory(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, text: str):
    """Run `text` through the model once and return the resulting key/value cache.

    The text is tokenized to PyTorch tensors, moved to the model's device and
    passed through a single forward call; the `past_key_values` of that call
    are returned so they can be handed to a later generation step.
    """
    encoded = tokenizer(text, return_tensors="pt")
    token_ids = encoded.input_ids.to(model.device)
    forward_out = model(input_ids=token_ids)
    return forward_out.past_key_values

In [None]:
@torch.no_grad()
def generate_with_memory_new(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    memory,
    prompt: str,
    temperature=1.0,
    max_new_tokens: Optional[int] = None,
):
    """Sample a continuation of `prompt` on top of `memory` and return it as text.

    Tokens are sampled one at a time from the temperature-scaled distribution
    over the last position's logits. Each sampled token is pushed through a
    TextStreamer whose printed output is captured and returned.

    Args:
        model: model called as `model(input_ids, past_key_values=...)`.
        tokenizer: tokenizer used for the prompt and by the streamer.
        memory: `past_key_values` passed to the first forward call
            (e.g. the value returned by `load_to_memory`).
        prompt: text fed to the model before sampling starts.
        temperature: divisor applied to the logits; must be positive.
        max_new_tokens: optional upper bound on the number of sampled tokens.
            The default `None` keeps the original behaviour of sampling until
            the EOS token is drawn.

    Returns:
        Everything the streamer printed while generating, as one string.

    Raises:
        ValueError: if `temperature` is not positive or `max_new_tokens` is
            given but smaller than 1.
    """
    # Local import keeps this cell self-contained.
    from contextlib import redirect_stdout

    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    if max_new_tokens is not None and max_new_tokens < 1:
        raise ValueError(f"max_new_tokens must be >= 1, got {max_new_tokens}")

    tokenized_data = tokenizer(prompt, return_tensors="pt")
    input_ids = tokenized_data.input_ids.to(model.device)

    streamer = TextStreamer(tokenizer, skip_prompt=False)

    new_memory = memory
    captured = StringIO()
    n_generated = 0

    # redirect_stdout restores whatever sys.stdout was before, even when the
    # loop raises (e.g. out-of-memory). Restoring sys.__stdout__ by hand would
    # both skip the restore on errors and, in a notebook kernel, replace the
    # kernel's output stream with the process-level stdout.
    with redirect_stdout(captured):
        while True:
            output = model(input_ids, past_key_values=new_memory)
            new_memory = output.past_key_values
            assert len(output.logits.shape) == 3
            assert output.logits.shape[0] == 1
            # Keep only the logits of the last position of the single sequence.
            last_logit = output.logits[[0], [-1], :]
            dist = torch.distributions.Categorical(logits=last_logit / temperature)
            next_token = dist.sample()
            n_generated += 1

            # The sampled token is streamed in every case, EOS included.
            streamer.put(next_token[None, :])

            reached_eos = next_token[0] == tokenizer.eos_token_id
            reached_limit = max_new_tokens is not None and n_generated >= max_new_tokens
            if reached_eos or reached_limit:
                # Flush whatever the streamer still buffers.
                streamer.end()
                break

            # Feed only the new token back; the context lives in new_memory.
            input_ids = next_token[None, :]
    return captured.getvalue()

In [None]:
import inspect

In [None]:
PROMPT_PATH

In [None]:
def run_eval_loop(paper_list: List[str], results_dir: str, temperature: float):
  """Generate and save a response for every input that has no result file yet.

  For each name in `paper_list` the input JSON is read from INPUT_DIR, its
  text (with PREFIX) is loaded into the model's memory, PMODIFIED is used as
  the prompt, and the result record is written to
  `<results_dir>/<name>.json`. Names whose result file already exists are
  skipped, so an interrupted run can be resumed.

  Args:
    paper_list: base names (without '.json') of the inputs to process.
    results_dir: directory receiving one result JSON per input.
    temperature: sampling temperature forwarded to generate_with_memory_new.
  """
  for paper in paper_list:
    print(paper)
    outpath = f'{results_dir}/{paper}.json'
    if os.path.exists(outpath):
      print(f'Skipping since result for {paper} already exists.')
      continue

    # Context managers close the handles deterministically; the input handle
    # is no longer leaked on every iteration.
    with open(f'{INPUT_DIR}/{paper}.json', 'r') as fh:
      inputs = json.load(fh)
    out_dict = prepare_task_for_paper(paper=paper, prompt_path=PROMPT_PATH, lm_id=MODEL_PATH)

    fot_memory = load_to_memory(model, tokenizer, PREFIX + inputs['text']) # loads the paper to memory
    answer = generate_with_memory_new(model, tokenizer, fot_memory, PMODIFIED, temperature) #asks the prompt after
    out_dict['response_text'] = answer

    # Closing the output file guarantees the JSON is flushed to disk before
    # the next (possibly failing) iteration starts.
    with open(outpath, 'w') as fh:
      json.dump(out_dict, fh)

`trial`: identifies which run (trial) this is, for when you run multiple trials of the same experiment.

In [None]:
# @title Specify Run_d here
trial = "run_0" # @param {type:"string"}
EXP_DIR = f"{DIRPATH}/{TASK_NAME}/{PROMPT_NAME}/longllama/{trial}/success/"

In [None]:
print(EXP_DIR)

In [None]:
os.makedirs(EXP_DIR, exist_ok=True)

In [None]:
PAPERS = get_paper_list(INPUT_DIR)
print(len(PAPERS))

Now run on all papers

In [None]:
print(EXP_DIR)

In [None]:
run_eval_loop(PAPERS, EXP_DIR, 1.0)

Aside: Handling failures if any

In [None]:
PAPERS_FAILED = ['18', '19', '20', '7', '14', '5', '21']

In [None]:
# Keep every paper that is not listed as failed, preserving the order of PAPERS.
# Filtering (instead of list.remove) does not raise ValueError when an id in
# PAPERS_FAILED is absent from PAPERS, e.g. after switching the input folder.
PAPERS_SUCCESS = [p for p in PAPERS if p not in PAPERS_FAILED]

In [None]:
len(PAPERS_SUCCESS)

In [None]:
torch.cuda.empty_cache()

In [None]:
run_eval_loop(PAPERS_SUCCESS, EXP_DIR, 1.0)

In [None]:
PAPERS_SUCCESS

In [None]:
os.listdir(EXP_DIR)

Render Outputs

In [None]:
test_paper = PAPERS[2]

In [None]:
# Load the saved result record of the selected paper; the context manager
# closes the file handle once the JSON has been parsed.
with open(f'{EXP_DIR}/{test_paper}.json', 'r') as fh:
  sd0 = json.load(fh)

In [None]:
sd0['response_text']