In [None]:
# ライブラリのインポート

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.linear_model import LinearRegression
from tqdm.auto import tqdm
from transformers import AutoTokenizer, EsmForMaskedLM

In [None]:
# Device selection: use CUDA when PyTorch can see a GPU, otherwise run on the CPU.

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the table of amino-acid genotypes and their measured brightness.

DATA_PATH = "../data/processed/amino-acid-genotypes-to-brightness.csv"
df = pd.read_csv(DATA_PATH)

# Fail fast if the columns used later are missing or have gaps; otherwise the error
# only surfaces much later inside the tokenizer or the regression.
missing_columns = {"sequence", "brightness"} - set(df.columns)
assert not missing_columns, f"missing expected columns: {sorted(missing_columns)}"
assert df[["sequence", "brightness"]].notna().all().all(), "sequence/brightness contain missing values"

In [None]:
# Load the ESM-2 masked-language-model head and the matching tokenizer from one checkpoint.

model_name_or_path = "facebook/esm2_t12_35M_UR50D"

model = EsmForMaskedLM.from_pretrained(model_name_or_path)
model.to(device)  # nn.Module.to moves the parameters in place
model.eval()      # inference only: switch off dropout
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

In [None]:
# Pseudo-log-likelihood computation

@torch.no_grad()
def log_likelihood(sequences, batch_size=8, subbatch_size=512):
    """Mean per-residue pseudo-log-likelihood (PLL) of each sequence under ESM.

    For every residue, a copy of the tokenized sequence with only that position
    replaced by ``tokenizer.mask_token_id`` is run through the model. The
    log-probability the model assigns to the original token at the masked position
    is recorded, and the result for a sequence is the mean over its residues. The
    first and last non-padding tokens (the special tokens the tokenizer adds) are
    not scored. Assumes the tokenizer pads on the right (its default) — TODO confirm
    if the tokenizer configuration changes.

    Uses the notebook-level ``model``, ``tokenizer`` and ``device``.

    Args:
        sequences: list of amino-acid strings.
        batch_size: number of sequences tokenized (and padded) together.
        subbatch_size: maximum number of masked copies sent through the model at once.

    Returns:
        List of floats with one entry per input sequence, in input order. A sequence
        with no scorable residue (e.g. the empty string) yields NaN, because a mean
        over zero tokens is undefined.
    """
    lls = []

    for start in tqdm(range(0, len(sequences), batch_size)):
        batch_sequences = sequences[start : start + batch_size]
        inputs = tokenizer(
            batch_sequences,
            return_tensors="pt",
            padding=True,
            # NOTE: sequences longer than the tokenizer's max length are silently cut.
            truncation=True,
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        n_seqs = input_ids.size(0)

        # (sequence index, position) for every residue to mask, skipping the
        # special token at position 0 and the one at length - 1.
        lengths = attention_mask.sum(dim=1).tolist()
        pairs = [(j, pos) for j, length in enumerate(lengths) for pos in range(1, length - 1)]

        if not pairs:
            # Nothing to score in this batch; still emit one value per sequence so the
            # output stays aligned with `sequences`.
            lls.extend([float("nan")] * n_seqs)
            continue

        seq_idx = torch.tensor([j for j, _ in pairs])
        positions = torch.tensor([pos for _, pos in pairs])
        rows = torch.arange(len(pairs))

        # One copy of the source sequence per masked position (advanced indexing
        # already returns a copy, so the originals stay intact as targets).
        masked_ids = input_ids[seq_idx]
        masked_ids[rows, positions] = tokenizer.mask_token_id
        # Each copy keeps its source padding mask so pad tokens are not attended to;
        # without it, scores in mixed-length batches depend on the batch's padding.
        masked_attention = attention_mask[seq_idx]
        targets = input_ids[seq_idx, positions]

        token_log_probs = []
        for s in range(0, len(pairs), subbatch_size):
            chunk = slice(s, s + subbatch_size)
            logits = model(
                input_ids=masked_ids[chunk].to(device),
                attention_mask=masked_attention[chunk].to(device),
            ).logits
            # Only the logits at each copy's masked position are needed; selecting them
            # first avoids materializing (copies x length x vocab) log-probs on the CPU.
            chunk_rows = torch.arange(logits.size(0), device=logits.device)
            chunk_positions = positions[chunk].to(logits.device)
            log_probs = F.log_softmax(logits[chunk_rows, chunk_positions], dim=-1).cpu()
            token_log_probs.append(log_probs.gather(1, targets[chunk].unsqueeze(1)).squeeze(1))
        token_log_probs = torch.cat(token_log_probs).double()

        # Per-sequence mean over its masked positions.
        sums = torch.zeros(n_seqs, dtype=torch.float64).index_add_(0, seq_idx, token_log_probs)
        counts = torch.bincount(seq_idx, minlength=n_seqs)
        for k in range(n_seqs):
            n_tokens = counts[k].item()
            lls.append(sums[k].item() / n_tokens if n_tokens > 0 else float("nan"))

    return lls


# Score every sequence, then derive perplexity and the shift relative to the first row.
# NOTE(review): `delta` is measured against df.iloc[0], presumably the wild-type /
# reference sequence — confirm the CSV keeps it on the first row.
df = df.assign(log_likelihood=log_likelihood(df["sequence"].tolist()))
df = df.assign(
    perplexity=lambda frame: np.exp(-frame["log_likelihood"]),
    delta=lambda frame: frame["log_likelihood"] - frame["log_likelihood"].iloc[0],
)

In [None]:
# Results: how well does the PLL shift (`delta`) track measured brightness?

X = df["delta"].values
y = df["brightness"].values
design = X.reshape(-1, 1)  # scikit-learn expects a 2-D feature matrix

# Pearson correlation = population covariance / (population std of X * of y)
cov = np.cov(X, y, ddof=0)[0, 1]
r = cov / (np.std(X) * np.std(y))
print(f"R: {r:.3f}")

# Ordinary least squares of brightness on delta; score() is the in-sample R^2
lr = LinearRegression().fit(design, y)
r2 = lr.score(design, y)
print(f"R^2: {r2:.3f}")

y_pred = lr.predict(design)

In [None]:
# Plot: brightness vs. PLL shift, with the least-squares fit overlaid.

fig, ax = plt.subplots(figsize=(12, 8), dpi=100)
ax.scatter(X, y, s=10)
order = np.argsort(X)  # sort so the fitted line is drawn left to right
ax.plot(
    X[order],
    y_pred[order],
    linewidth=2,
    label=f"R^2 = {r2:.3f}",
    color="red",
)
# X is `delta` (mean PLL minus that of the first row), not a raw likelihood.
ax.set_xlabel(r"$\Delta$ log-likelihood (vs. first sequence)")
ax.set_ylabel("Brightness")
ax.set_title("ESM-2 pseudo-log-likelihood shift vs. brightness")
ax.legend()
fig.tight_layout()

# savefig does not create missing directories, so make sure the target exists.
figure_path = Path("../figures/eda2/likelihood_vs_brightness.png")
figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(figure_path)
plt.show()