In [None]:
import sys

# Put the project's module directory first on the import path so the local
# `model` module imported below can be resolved. The path is relative, so it
# depends on the notebook's working directory.
sys.path.insert(0, '../src/SynBIoModules')

In [None]:
from model import ResLinear

import numpy as np
import pandas as pd
import torch
from Bio import SeqIO
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer

In [None]:
# Input/output locations (relative to the notebook's working directory).
MODEL_WEIGHTS_PATH = '../data/model_weights/regression_model.weights'  # state dict loaded into the brightness model
OUTPUT_PATH = '../data/predicted_brightness.txt'  # text file that receives the predicted values
SEQS_PATH = '../data/mutations_data/res_mutations.csv'  # CSV read with two columns named 'mut' and 'seq'

In [None]:
# Pick the compute device. `torch.cuda.is_available` must be *called*: the bare
# function object is always truthy, which would select 'cuda' even on machines
# without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Regression head that turns a sequence embedding into a brightness prediction.
# NOTE(review): 1280 presumably matches the hidden size of the embedding model
# used later in this notebook; the meaning of 10 is defined by ResLinear —
# confirm against the `model` module.
brightness_model = ResLinear(1280, 10)

In [None]:
# Deserialize the checkpoint onto the CPU so weights saved on a GPU machine
# still load on a CPU-only one; `weights_only=True` restricts unpickling to
# tensors and plain containers.
brightness_model.load_state_dict(
    torch.load(MODEL_WEIGHTS_PATH, map_location='cpu', weights_only=True)
)
# The embeddings passed to this model are produced on `device`, so the model
# has to live there as well; otherwise the forward pass raises a device
# mismatch error when running on a GPU.
brightness_model.to(device)
# Inference only: switch to evaluation mode.
brightness_model.eval()

In [None]:
# Hugging Face checkpoint used to embed the protein sequences.
model_checkpoint = "facebook/esm2_t33_650M_UR50D"
# Bare checkpoint name, i.e. everything after the last "/".
model_name = model_checkpoint.rsplit("/", 1)[-1]

# Tokenizer and encoder are loaded from the same checkpoint; the encoder is
# then moved to the selected compute device.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
embedding_model = AutoModel.from_pretrained(model_checkpoint)
embedding_model = embedding_model.to(device)

In [None]:
# Number of CSV rows read per chunk. Each chunk is tokenized and embedded as a
# single batch, so this is also the batch size of the model forward pass.
SEQS_BATCH_SIZE = 600

In [None]:
# Predict brightness for every sequence in the CSV, one chunk at a time.
# The output file is opened once, outside the loop, and each batch is written
# to the open handle. Passing the path to `np.savetxt` on every iteration would
# truncate the file each time and keep only the last batch.
with pd.read_csv(SEQS_PATH, chunksize=SEQS_BATCH_SIZE, names=['mut', 'seq']) as reader:
    with open(OUTPUT_PATH, 'w') as output_file:
        for chunk in tqdm(reader, desc='Predicting brightness', unit='batch'):
            seq_batch = chunk['seq'].to_list()
            with torch.no_grad():
                tokens = tokenizer(seq_batch, return_tensors="pt", padding=True).to(device)
                outputs = embedding_model(**tokens)
                # NOTE(review): this averages over every token position,
                # including padding added to shorter sequences in the batch.
                # Kept as-is because it presumably matches how the regression
                # model was trained — confirm, and consider attention-mask
                # weighted pooling if sequence lengths differ.
                embeddings = outputs.last_hidden_state.mean(dim=1)
                brightness = brightness_model(embeddings)
            # One predicted value per line, in the same order as the input rows.
            np.savetxt(output_file, brightness.cpu().numpy().flatten())
            