In [None]:
import scanpy as sc

# Load the preprocessed AnnData object from an .h5ad file.
# The path is relative (it starts with '../'), so it resolves against the
# notebook's current working directory.
input_path = '../data/GSE214979_ADHC/celltype_split_h5ad/Excitatory_filtered_BA46_ADHC_cellsentenced.h5ad'
adata = sc.read_h5ad(input_path)

In [None]:
from epiagent_lora.dataset import CellDatasetForUFEWithLabel, collate_fn_ufe_with_label
from torch.utils.data import DataLoader

# 1) Map each cell's Status to an integer label.
#    Class ids are assigned in sorted order of the distinct Status strings.
status_series = adata.obs['Status'].astype(str)
ordered_statuses = sorted(status_series.unique())
classes = dict(zip(ordered_statuses, range(len(ordered_statuses))))
labels = status_series.map(classes).tolist()

# 2) Dataset & loader
# Per-cell sentences are read from the 'cell_sentences' column of adata.obs.
cell_sentences = adata.obs['cell_sentences'].tolist()

# Training dataset: pairs every cell sentence with its integer Status label.
dataset_options = dict(
    max_length=8192,
    alpha_for_CCA=1,
    num_cCRE=1355445,
    is_random=True,  # (optional) weak augmentation, per the original author's note
)
train_cell_dataset = CellDatasetForUFEWithLabel(
    adata=adata,
    cell_sentences=cell_sentences,
    labels=labels,
    **dataset_options,
)

# Training DataLoader: shuffled mini-batches collated by the project's collate function.
train_batch_size = 5
train_loader = DataLoader(
    dataset=train_cell_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=8,
    collate_fn=collate_fn_ufe_with_label,
)


In [None]:
from epiagent_lora.model import EpiAgent
import torch

# Specify the path to the pre-trained model
model_path = '../weights/pretrained_EpiAgent.pth'

# Set the device (GPU if available)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# Initialize the EpiAgent model with appropriate configurations
pretrained_model = EpiAgent(
    vocab_size=1355449,
    num_layers=18,
    embedding_dim=512,
    num_attention_heads=8,
    max_rank_embeddings=8192,
    use_flash_attn=True,
    pos_weight_for_RLM=torch.tensor(1.),
    pos_weight_for_CCA=torch.tensor(1.)
)

# Load the pre-trained weights into the model.
# map_location="cpu" deserializes the checkpoint onto the CPU regardless of the
# device it was saved from, so loading also works on a host without CUDA and
# does not place a second copy of the weights on the GPU; the model itself is
# moved to `device` further down.
pretrained_model.load_state_dict(torch.load(model_path, map_location="cpu"))

# Ensure the CCA loss uses a positive weight of 1
pretrained_model.criterion_CCA.pos_weight = torch.tensor(1.)

# Move the model to the specified device. Module.to() returns the module
# itself; assigning the result keeps the cell from echoing the full model repr.
pretrained_model = pretrained_model.to(device)

In [None]:
from peft import LoraConfig, get_peft_model, TaskType

# Module-name patterns that receive LoRA adapters.
# "Wqkv" appeared in an earlier, commented-out variant of this list
# (["Wqkv", "out_proj", "fc1", "fc2"]) and is left out here.
lora_target_modules = ["out_proj", "fc1", "fc2"]

lora_config = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=lora_target_modules,
)

# Wrap the model with PEFT; the same name now refers to the wrapped model.
pretrained_model = get_peft_model(pretrained_model, lora_config)

# Report how many parameters are trainable after wrapping.
pretrained_model.print_trainable_parameters()

In [None]:
# After PEFT injection, double-check the trainable set: a parameter is
# trainable exactly when its name contains 'lora_' (a LoRA adapter tensor);
# every other parameter is frozen.
for param_name, param in pretrained_model.named_parameters():
    param.requires_grad = 'lora_' in param_name

# Additionally freeze the CCA / RLM heads, the signal decoder and their layer
# norms explicitly (belt and braces on top of the loop above).
frozen_heads = (
    pretrained_model.fc1_for_CCA,
    pretrained_model.fc2_for_CCA,
    pretrained_model.fc1_for_RLM,
    pretrained_model.fc2_for_RLM,
    pretrained_model.signal_decoder,
    pretrained_model.layer_norm_for_CCA,
    pretrained_model.layer_norm_for_RLM,
)
for head in frozen_heads:
    # Sets requires_grad=False on every parameter of the module.
    head.requires_grad_(False)

In [None]:
from epiagent_lora.train import train_with_contrastive_cca_sr
from epiagent_lora.model import SupConLoss
# Contrastive criterion handed to the training routine.
criterion_con = SupConLoss(temperature=0.07)

# Optimizer / schedule settings.
# NOTE(review): warmup_steps presumably only matters when use_noam=True — confirm.
optim_settings = dict(
    lr=1e-4,
    weight_decay=0.01,
    use_noam=False,
    warmup_steps=10000,
    grad_clip=None,
)

# Weights of the three loss terms (contrastive, CCA, SR), all equal here.
loss_weights = dict(lambda_con=1.0, lambda_cca=1.0, lambda_sr=1.0)

# Run length, logging cadence and checkpoint location.
run_settings = dict(
    epochs=20,
    log_every=50,
    save_dir="./weights/exp1",
    save_every_steps=2000,
    enable_logging=True,
)

# Run training; the returned object is used by the later save / inference cells.
trained = train_with_contrastive_cca_sr(
    model=pretrained_model,
    train_loader=train_loader,
    device=device,
    criterion_con=criterion_con,
    **optim_settings,
    **loss_weights,
    **run_settings,
)

In [None]:
# Output directory for the final fine-tuned weights; it sits under the
# "./weights/exp1" directory that the training call also writes to.
# NOTE(review): this path starts with "./" while the pretrained checkpoint is
# read from "../weights/..." — confirm the two roots are intended to differ.
save_dir = "./weights/exp1/final_lora"

# Persist the trained model via its save_pretrained method
# (presumably only the LoRA adapter weights are written — confirm).
trained.save_pretrained(save_dir)

In [None]:
from epiagent_lora.inference import infer_cell_embeddings_from_trainloader

# Switch the model to evaluation mode before extracting embeddings.
trained.eval()

# Options for the embedding pass over the training loader.
# NOTE(review): rebuild_noshuffle=True presumably re-creates the loader without
# shuffling so rows line up with adata's cell order — confirm in the helper.
# NOTE(review): the underlying dataset was built with is_random=True; confirm
# that this does not make the inferred embeddings non-deterministic.
inference_options = dict(normalize=True, use_cls=True, rebuild_noshuffle=True)

cell_embeddings = infer_cell_embeddings_from_trainloader(
    trained,
    device,
    train_loader,
    **inference_options,
)

# Store the embeddings on the AnnData object, one row per cell: (N, D).
adata.obsm['cell_embeddings_fine_tuned'] = cell_embeddings


In [None]:
import scanpy as sc
# Build the neighbor graph on the fine-tuned embeddings (cosine metric,
# 15 neighbors), compute a UMAP with a fixed seed, and color it by Status.
embedding_key = 'cell_embeddings_fine_tuned'
sc.pp.neighbors(adata, n_neighbors=15, metric='cosine', use_rep=embedding_key)
sc.tl.umap(adata, random_state=0)
sc.pl.umap(adata, color=['Status'], frameon=False)
