In [None]:
# import torch
# from models.Text2ProteinGenModel import Text2ProteinGenModel

# # 1. 指定权重文件路径
# checkpoint_path = "./weights/text2protein_model/text2protein_complete.pt"

# print(f"正在从 {checkpoint_path} 加载模型...")

# # 2. 实例化并加载 (不需要任何额外的 config 配置，全在 pt 文件里)
# model = Text2ProteinGenModel(checkpoint_path)

In [None]:
import torch
from models.Text2ProteinGenModel import Text2ProteinGenModel
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from Bio import SeqIO
from tqdm import tqdm
import torch.optim as optim

# ================= Configuration & initialization =================
# Replace these with your actual locations.
CHECKPOINT_PATH = "./weights/text2protein_model/text2protein_complete.pt"
SWISS_PROT_PATH = "./data/uniprot_sprot.fasta"  # FASTA file used for training
T5_PATH = "./weights/pinal-official-t5-large"
PROGEN_TOK_PATH = "./models/progen3_module/tokenizer.json"

# 1. Tokenizers: a Hugging Face tokenizer for the conditioning text, and a raw
#    `tokenizers.Tokenizer` loaded from tokenizer.json for protein sequences.
from transformers import AutoTokenizer
from tokenizers import Tokenizer

text_tokenizer = AutoTokenizer.from_pretrained(T5_PATH)
progen_tokenizer = Tokenizer.from_file(PROGEN_TOK_PATH)

# Special-token ids used when wrapping and padding protein sequences.
# NOTE(review): hard-coded; confirm they match the vocabulary in PROGEN_TOK_PATH.
PAD_ID, BOS_ID, EOS_ID = 0, 1, 2

# 2. Build the model from its self-contained checkpoint, then move it to the
#    GPU when one is available (CPU otherwise).
model = Text2ProteinGenModel(CHECKPOINT_PATH)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# ================= Parameter freezing =================
print("\n=== Freezing Parameters ===")

# 1. The T5 text encoder receives no gradients at all.
for param in model.lm.parameters():
    param.requires_grad_(False)

# 2. Inside ProGen3, only parameters whose name contains "cross_attn" are
#    trainable; every other ProGen3 weight is frozen.
for name, param in model.plm.named_parameters():
    param.requires_grad_("cross_attn" in name)
    # To list what is being trained, uncomment:
    # if param.requires_grad: print(f"Training: {name}")

# Report how much of the model will actually be updated.
all_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable params: {trainable_params} / {all_params} ({trainable_params/all_params:.2%})")

# ================= Dataset definition =================
class SwissProtDataset(Dataset):
    """(text, protein) training pairs read from a FASTA file.

    Each record parsed by ``Bio.SeqIO`` contributes its description (the
    FASTA header text) as the conditioning text and its sequence as the target
    protein. Tokenization happens lazily in ``__getitem__`` using the
    module-level ``text_tokenizer`` / ``progen_tokenizer`` and ``BOS_ID`` /
    ``EOS_ID``.

    Args:
        fasta_file: anything ``Bio.SeqIO.parse`` accepts, e.g. a file path.
        limit: if not None, keep at most this many records (``0`` -> empty).
        max_text_len: text tokens beyond this length are truncated.
        max_protein_len: maximum protein length in tokens, *including* the
            BOS and EOS tokens wrapped around the sequence (must be >= 2).

    Raises:
        ValueError: if ``max_protein_len`` is smaller than 2.
    """

    def __init__(self, fasta_file, limit=None, max_text_len=512, max_protein_len=1024):
        if max_protein_len < 2:
            raise ValueError(
                f"max_protein_len must be >= 2 to fit BOS and EOS, got {max_protein_len}"
            )
        self.max_text_len = max_text_len
        self.max_protein_len = max_protein_len
        self.data = []
        print(f"Loading {fasta_file}...")
        for i, record in enumerate(SeqIO.parse(fasta_file, "fasta")):
            # Compare against None (not truthiness) so that limit=0 really
            # yields no records instead of silently loading the whole file.
            if limit is not None and i >= limit:
                break
            self.data.append({"text": record.description, "protein": str(record.seq)})

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        # Text: truncated to max_text_len tokens, returned as 1-D tensors.
        text_enc = text_tokenizer(
            item["text"], max_length=self.max_text_len, truncation=True, return_tensors="pt"
        )
        # Protein: reserve two positions for BOS/EOS so the final length never
        # exceeds max_protein_len (1022 + 2 = 1024 with the defaults).
        body_ids = progen_tokenizer.encode(item["protein"]).ids[: self.max_protein_len - 2]
        prot_ids = [BOS_ID] + body_ids + [EOS_ID]

        return {
            "text_ids": text_enc.input_ids.squeeze(0),
            "text_mask": text_enc.attention_mask.squeeze(0),
            "protein_ids": torch.tensor(prot_ids, dtype=torch.long),
        }

def collate_fn(batch):
    """Pad a list of dataset samples into a single batch dict.

    Each field is padded to the longest sample in the batch (batch-first):
    text ids with the text tokenizer's pad id, attention masks with 0, and
    protein ids with ``PAD_ID``. ``labels`` is the *same* padded protein
    tensor object as ``protein_ids``, not a copy.

    NOTE(review): padded label positions hold PAD_ID rather than -100; confirm
    the model's loss ignores PAD_ID, otherwise padding contributes to the loss.
    """
    def stack(key, fill):
        # Collect one field from every sample and pad it into a 2-D tensor.
        return pad_sequence([sample[key] for sample in batch], batch_first=True, padding_value=fill)

    proteins = stack("protein_ids", PAD_ID)
    return {
        "text_ids": stack("text_ids", text_tokenizer.pad_token_id),
        "text_masks": stack("text_mask", 0),
        "protein_ids": proteins,
        "labels": proteins,
    }

# Build the training DataLoader.
# NOTE: no `limit` is passed, so the entire FASTA file is loaded. For a quick
# smoke test use e.g. SwissProtDataset(SWISS_PROT_PATH, limit=1000).
dataset = SwissProtDataset(SWISS_PROT_PATH)
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
)

# ================= Training loop =================
# Collect the trainable parameters once, so the optimizer and gradient clipping
# see exactly the same set. Clipping over *all* model parameters would also
# fold any stale .grad left on frozen parameters (e.g. from an earlier notebook
# run) into the norm, because optimizer.zero_grad() never clears those.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable, lr=1e-4)
epochs = 20

model.train()
# requires_grad=False does not disable dropout: model.train() also put the
# frozen T5 encoder into training mode. Return it to eval mode so the text
# conditioning it produces is deterministic.
model.lm.eval()
print("\n=== Starting Training ===")
for epoch in range(epochs):
    loop = tqdm(dataloader, desc=f"Epoch {epoch+1}")
    total_loss = 0.0
    for batch in loop:
        batch = {k: v.to(device) for k, v in batch.items()}

        optimizer.zero_grad()
        outputs = model(batch)  # forward pass; the result is indexed by 'loss'
        loss = outputs["loss"]

        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable, 1.0)
        optimizer.step()

        step_loss = loss.item()  # single host sync per step
        total_loss += step_loss
        loop.set_postfix(loss=step_loss)

    # Guard against an empty DataLoader instead of raising ZeroDivisionError.
    num_batches = len(dataloader)
    avg_loss = total_loss / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch+1} done. Avg Loss: {avg_loss:.4f}")