In [None]:
# deepgo_pipeline_with_val_loss.py
import os
import time
import csv
import numpy as np
from scipy.sparse import load_npz
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# ---------------------------
# CONFIG
# ---------------------------
# Seed shared by NumPy, PyTorch and the train/validation split below.
SEED = 42
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16       # reduce if running out of memory (OOM)
EPOCHS = 40
# Learning rate; used both as the AdamW lr and as the OneCycleLR peak (max_lr).
LR = 5e-4
WEIGHT_DECAY = 1e-4
# Maximum gradient norm passed to clip_grad_norm_ during training.
CLIP_NORM = 5.0
# Fraction of input features zeroed when the training dataset applies feature dropout.
FEATURE_DROPOUT = 0.2
# Default smoothing factor (eps) of bce_label_smooth.
LABEL_SMOOTH = 0.03
# Maximum number of GO terms written per test protein.
TOP_K = 150
# NOTE(review): ALPHA and THRESH_GRID are not referenced anywhere in the visible code.
ALPHA = 0.2
THRESH_GRID = np.arange(0.0, 0.201, 0.005)

# Checkpoint / submission pair for the epoch with the best validation F1.
OUT_MODEL = "/kaggle/working/best_model.pt"
OUT_SUBMIT = "/kaggle/working/submission.tsv"

# Checkpoint / submission pair for the epoch with the lowest validation loss.
OUT_MODEL2 = "/kaggle/working/best_model2.pt"
OUT_SUBMIT2 = "/kaggle/working/submission2.tsv"

# reproducibility: seed the NumPy and PyTorch (CPU + all CUDA devices) generators
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# ---------------------------
# LOAD DATA
# ---------------------------
# Feature matrices, test identifiers, the sparse label matrix and the label
# names are all read from the Kaggle input directory.
# assumes row i of X_train corresponds to row i of Y_sparse, and column j of
# Y_sparse corresponds to GO_terms[j] — TODO confirm against the files' producer.


X_train = np.load('/kaggle/input/cafa56-end/650_taxon_features_X_INPUT.npy')
X_test = np.load('/kaggle/input/cafa56-end/X_test.npy')
ids_test = np.load('/kaggle/input/cafa56-end/protein_ids_test.npy')
Y_sparse = load_npz("/kaggle/input/cafa56-end/Y_full.npz")  # CSR sparse
# NOTE(review): allow_pickle=True unpickles arbitrary objects; only safe for trusted input files.
GO_terms = np.load("/kaggle/input/cafa56-end/GO_terms_full.npy", allow_pickle=True)

# normalization: per-column mean/std; the 1e-6 keeps constant columns from dividing by zero.
# NOTE(review): the statistics use every row of X_train, i.e. they are computed
# before the train/validation split below and therefore include validation rows.
global_mean = X_train.mean(axis=0).astype(np.float32)
global_std  = X_train.std(axis=0).astype(np.float32) + 1e-6

# train/val split: keep only rows with at least one stored label entry, then hold out 10%.
Y_sparse = Y_sparse.tocsr()
row_nnz = np.diff(Y_sparse.indptr)  # number of stored entries per row
valid_idx = np.where(row_nnz > 0)[0]
train_idx, val_idx = train_test_split(valid_idx, test_size=0.1, random_state=SEED, shuffle=True)

# pos_weight: per-label negative/positive ratio on the training rows,
# clipped to [1, 5] so rare labels do not dominate the loss.
train_sparse = Y_sparse[train_idx]
label_freq = np.array(train_sparse.sum(axis=0)).squeeze()  # positives per label column
N_train = len(train_idx)
pos_weight_arr = (N_train - label_freq) / (label_freq + 1e-8)
pos_weight_arr = np.clip(pos_weight_arr, 1.0, 5.0)
pos_weight = torch.tensor(pos_weight_arr, dtype=torch.float32).to(DEVICE)

# IA vector: one weight per entry of GO_terms, read from a two-column,
# tab-separated file (GO id, value). Terms missing from the file get 0.0.
IA_dict = {}
with open("/kaggle/input/cafa56-end/IA.tsv") as f:
    for line in f:
        line = line.strip()
        if not line:
            # Tolerate blank / trailing empty lines instead of failing on unpacking.
            continue
        go, value = line.split("\t")
        IA_dict[go] = float(value)
IA_vec = np.array([IA_dict.get(go, 0.0) for go in GO_terms], dtype=np.float32)

# ontology mapping: GO id -> aspect code, read from a file with one header row.
# NOTE(review): the file has a .tsv extension but is parsed as comma-separated;
# confirm the delimiter against the actual file.
go2asp = {}
with open("/kaggle/input/mapping-wf1/go_to_aspect.tsv") as f:
    next(f, None)  # skip the header; the default avoids StopIteration on an empty file
    for line in f:
        line = line.strip()
        if not line:
            continue
        go, asp = line.split(",")
        go2asp[go] = asp

# Column indices of GO_terms per aspect code ("F", "C", "P"); terms without a
# mapping entry end up in none of the three lists.
idx_MF = [i for i, go in enumerate(GO_terms) if go2asp.get(go) == "F"]
idx_CC = [i for i, go in enumerate(GO_terms) if go2asp.get(go) == "C"]
idx_BP = [i for i, go in enumerate(GO_terms) if go2asp.get(go) == "P"]

IA_vec_torch = torch.tensor(IA_vec, dtype=torch.float32).to(DEVICE)

# ---------------------------
# DATASET
# ---------------------------
class ProteinDataset(Dataset):
    """Map-style dataset over rows of a dense feature matrix and optional sparse labels.

    Args:
        X: 2-D array-like whose rows are indexed with ``X[i]`` and support ``astype``.
        Y_sparse: optional sparse label matrix whose rows support ``Y[i].toarray()``;
            when ``None`` only features are returned.
        indices: optional row indices to expose; defaults to every row of ``X``.
        mean, std: optional per-feature statistics; normalization ``(x - mean) / std``
            is applied only when both are given.
        feature_dropout: probability of zeroing each individual feature when the
            augmentation is applied.
        train: enables the feature-dropout augmentation.

    Items are ``(x, y)`` float32 tensors when labels are present, otherwise ``x``.
    """

    def __init__(self, X, Y_sparse=None, indices=None, mean=None, std=None, feature_dropout=0.0, train=True):
        self.X = X
        self.Y = Y_sparse
        self.indices = np.array(indices) if indices is not None else np.arange(X.shape[0])
        self.mean = mean
        self.std = std
        self.feature_dropout = feature_dropout
        self.train = train

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        i = int(self.indices[idx])
        # astype copies, so the augmentation below never mutates the shared matrix.
        x = self.X[i].astype(np.float32)
        if self.mean is not None and self.std is not None:
            # Cast back to float32 so float64 statistics cannot silently produce
            # double-precision tensors.
            x = ((x - self.mean) / self.std).astype(np.float32, copy=False)
        # Augmentation: on roughly half of the training samples, zero each feature
        # independently with probability `feature_dropout` (no rescaling).
        # NOTE(review): this draws from NumPy's global RNG; with num_workers > 0 the
        # per-worker seeding of NumPy depends on the PyTorch version — confirm.
        if self.train and self.feature_dropout > 0.0 and np.random.rand() < 0.5:
            mask = (np.random.rand(x.shape[0]) >= self.feature_dropout).astype(np.float32)
            x = x * mask
        x = torch.from_numpy(x)
        if self.Y is None:
            return x
        # ravel() (rather than squeeze()) keeps a 1-D label vector even when the
        # label matrix has a single column.
        y = self.Y[i].toarray().ravel().astype(np.float32)
        return x, torch.from_numpy(y)

# Training split: normalized features with the feature-dropout augmentation enabled.
train_ds = ProteinDataset(
    X_train,
    Y_sparse,
    indices=train_idx,
    mean=global_mean,
    std=global_std,
    feature_dropout=FEATURE_DROPOUT,
    train=True,
)
# Validation split: same normalization, no augmentation.
val_ds = ProteinDataset(
    X_train,
    Y_sparse,
    indices=val_idx,
    mean=global_mean,
    std=global_std,
    feature_dropout=0.0,
    train=False,
)
# Settings shared by both loaders; only the training loader shuffles.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

# ---------------------------
# MODEL
# ---------------------------
class MLP(nn.Module):
    """Feed-forward network producing one raw logit per output label.

    Each hidden layer is Linear -> LayerNorm -> GELU -> Dropout, followed by a
    final Linear projection to ``output_dim``.

    Args:
        input_dim: number of input features.
        output_dim: number of output logits.
        hidden: iterable of hidden layer widths (an empty iterable yields a
            single Linear layer).
        dropout: dropout probability applied after every hidden activation.
    """

    # The default is a tuple rather than a list so that it cannot be mutated
    # and shared between instances.
    def __init__(self, input_dim, output_dim, hidden=(1024, 512), dropout=0.3):
        super().__init__()
        layers = []
        in_dim = input_dim
        for h in hidden:
            layers.extend([
                nn.Linear(in_dim, h),
                nn.LayerNorm(h),
                nn.GELU(),
                nn.Dropout(dropout),
            ])
            in_dim = h
        layers.append(nn.Linear(in_dim, output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits for input batch ``x``."""
        return self.net(x)

# Size the network from the feature matrix width and the number of label columns,
# then move it to the selected device.
model = MLP(input_dim=X_train.shape[1], output_dim=Y_sparse.shape[1])
model = model.to(DEVICE)

# ---------------------------
# LOSS & OPTIMIZER
# ---------------------------
def bce_label_smooth(logits, targets, pos_weight=None, eps=LABEL_SMOOTH):
    """Mean binary cross-entropy on logits against label-smoothed targets.

    Positive targets (1) are softened to ``1 - eps`` and negative targets (0)
    are raised to ``eps / 2`` before the loss is computed.

    Args:
        logits: raw model outputs.
        targets: tensor of 0/1 labels with the same shape as ``logits``.
        pos_weight: optional per-label weight for the positive term.
        eps: smoothing strength.

    Returns:
        Scalar tensor: the mean of the element-wise loss.
    """
    pos_target = 1.0 - eps
    neg_target = 0.5 * eps
    soft_targets = targets * pos_target + (1 - targets) * neg_target
    per_element = nn.functional.binary_cross_entropy_with_logits(
        logits,
        soft_targets,
        pos_weight=pos_weight,
        reduction="none",
    )
    return per_element.mean()

# AdamW over all model parameters.
optimizer = optim.AdamW(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)
# One-cycle schedule sized for EPOCHS * len(train_loader) steps, i.e. one
# scheduler step per training batch.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=LR,
    epochs=EPOCHS,
    steps_per_epoch=len(train_loader),
)
# Gradient scaler used by the mixed-precision training loop.
scaler = torch.cuda.amp.GradScaler()

# ---------------------------
# METRICS
# ---------------------------
def f1_weighted_batch(y_true, y_pred_bin, idx_labels, IA_vec):
    """Weighted F1 over one batch, restricted to the label columns in ``idx_labels``.

    For every row, weighted precision is the weight of correctly predicted
    labels divided by the weight of all predicted labels (rows with no
    prediction are skipped), and weighted recall is the same numerator divided
    by the weight of all true labels (rows whose true labels carry no weight
    are skipped). Precision and recall are averaged over their contributing
    rows and combined into a harmonic mean.

    Args:
        y_true: 2-D array of 0/1 ground-truth labels.
        y_pred_bin: 2-D array of 0/1 predictions, same shape as ``y_true``.
        idx_labels: column indices to evaluate.
        IA_vec: 1-D array of per-label weights indexed like the label columns.

    Returns:
        The weighted F1, or 0.0 when both averaged precision and recall are zero.
    """
    truth = y_true[:, idx_labels] == 1
    preds = y_pred_bin[:, idx_labels] == 1
    term_weights = IA_vec[idx_labels]

    precisions = []
    recalls = []
    for true_row, pred_row in zip(truth, preds):
        # Weight of the true positives, shared by precision and recall.
        hit_weight = term_weights[pred_row & true_row].sum()
        if pred_row.sum() > 0:
            predicted_weight = term_weights[pred_row].sum()
            precisions.append(hit_weight / (predicted_weight + 1e-9))
        true_weight = term_weights[true_row].sum()
        if true_weight > 0:
            recalls.append(hit_weight / (true_weight + 1e-9))

    wpr = np.mean(precisions) if precisions else 0.0
    wrc = np.mean(recalls) if recalls else 0.0
    if (wpr + wrc) > 0:
        return 2 * wpr * wrc / (wpr + wrc + 1e-9)
    return 0.0

@torch.no_grad()
def eval_model(model, loader, threshold=0.5, pos_weight=None):
    """Evaluate ``model`` on ``loader`` and report loss plus per-aspect weighted F1.

    The loss is the sample-weighted mean of the per-batch label-smoothed BCE.
    Predictions are binarized with ``sigmoid(logits) >= threshold``; the
    weighted F1 is computed per batch for each of the MF, CC and BP column
    groups and then averaged over batches with equal weight per batch.

    Args:
        model: network returning logits; switched to eval mode here.
        loader: iterable yielding ``(features, labels)`` batches.
        threshold: probability cut-off used to binarize predictions.
        pos_weight: optional positive-class weights forwarded to the loss.

    Returns:
        Tuple ``(val_loss, F1_MF, F1_CC, F1_BP, F1_avg)`` where ``F1_avg`` is the
        unweighted mean of the three aspect scores.
    """
    model.eval()
    aspect_columns = {"MF": idx_MF, "CC": idx_CC, "BP": idx_BP}
    batch_scores = {aspect: [] for aspect in aspect_columns}
    loss_sum = 0.0
    seen = 0

    for features, labels in loader:
        features = features.to(DEVICE)
        labels = labels.to(DEVICE)
        logits = model(features)

        batch_size = features.size(0)
        batch_loss = bce_label_smooth(logits, labels, pos_weight)
        loss_sum += float(batch_loss.item()) * batch_size
        seen += batch_size

        predicted = (torch.sigmoid(logits).cpu().numpy() >= threshold).astype(np.float32)
        actual = labels.cpu().numpy().astype(np.float32)
        for aspect, columns in aspect_columns.items():
            batch_scores[aspect].append(f1_weighted_batch(actual, predicted, columns, IA_vec))

    val_loss = loss_sum / seen
    F1_MF = np.mean(batch_scores["MF"])
    F1_CC = np.mean(batch_scores["CC"])
    F1_BP = np.mean(batch_scores["BP"])
    F1_avg = (F1_MF + F1_CC + F1_BP) / 3
    return val_loss, F1_MF, F1_CC, F1_BP, F1_avg

# ---------------------------
# TRAIN LOOP
# ---------------------------
# Trains for EPOCHS epochs with mixed precision, evaluates on the validation
# loader after every epoch, and keeps two checkpoints: the epoch with the
# highest validation F1 (OUT_MODEL) and the one with the lowest validation
# loss (OUT_MODEL2).
best_val_f1 = -1
best_val_loss = float('inf')
# Binarization threshold used for the validation F1; it is never updated
# inside this loop, so every epoch is evaluated at 0.2.
best_threshold = 0.2

# Per-epoch histories, consumed by the plotting code after training.
train_loss_history = []
val_loss_history = [] 
val_f1_history = []

for epoch in range(1, EPOCHS+1):
    t0 = time.time()
    model.train()
    total_loss = 0
    n_samples = 0
    for xb, yb in train_loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        optimizer.zero_grad()
        # Forward pass and loss under autocast (mixed precision).
        with torch.cuda.amp.autocast():
            logits = model(xb)
            loss = bce_label_smooth(logits, yb, pos_weight)
        scaler.scale(loss).backward()
        # Unscale first so that clipping is applied to the true gradient norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
        scaler.step(optimizer)
        scaler.update()
        # The scheduler is stepped once per batch, not once per epoch.
        scheduler.step()
        # Accumulate the sample-weighted loss so the epoch mean is per sample.
        total_loss += float(loss.item()) * xb.size(0)
        n_samples += xb.size(0)
    train_loss = total_loss / n_samples

    val_loss, F1_MF, F1_CC, F1_BP, F1_val = eval_model(model, val_loader, threshold=best_threshold, pos_weight=pos_weight)
    train_loss_history.append(train_loss)
    val_loss_history.append(val_loss)
    val_f1_history.append(F1_val)
    print(f"[Epoch {epoch}] train_loss={train_loss:.6f} | val_loss={val_loss:.6f} | val_F1={F1_val:.6f} (MF={F1_MF:.6f} CC={F1_CC:.6f} BP={F1_BP:.6f}) | time={time.time()-t0:.1f}s")

    # save best (by validation F1)
    if F1_val > best_val_f1:
        best_val_f1 = F1_val
        torch.save({
            "model_state": model.state_dict(),
            "best_threshold": best_threshold
        }, OUT_MODEL)
        print(" -> saved best model")

    # save best (by validation loss) to a separate checkpoint
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            "model_state": model.state_dict(),
            "best_threshold": best_threshold
        }, OUT_MODEL2)
        print(" -> saved best model (Val Loss)")



# ---------------------------
# Training curves: train/validation loss on the left, validation F1 on the right.
epochs = range(1, EPOCHS+1)
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(12,5))

# Loss
loss_ax = fig.add_subplot(1, 2, 1)
loss_ax.plot(epochs, train_loss_history, label='Train Loss')
loss_ax.plot(epochs, val_loss_history, label='Val Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Train vs Val Loss')
loss_ax.legend()
loss_ax.grid(True)

# F1
f1_ax = fig.add_subplot(1, 2, 2)
f1_ax.plot(epochs, val_f1_history, label='Val F1 Weighted', color='green')
f1_ax.set_xlabel('Epoch')
f1_ax.set_ylabel('F1 Weighted')
f1_ax.set_title('Validation F1 Weighted')
f1_ax.legend()
f1_ax.grid(True)

fig.tight_layout()
plt.show()
# ---------------------------
# PREDICTION
# ---------------------------
def _write_submission(checkpoint_path, submit_path, batch_size=128):
    """Restore a checkpoint into ``model`` and write its test predictions as TSV.

    For every test protein the ``TOP_K`` highest-probability GO terms are
    written, one ``(ID, GO_ID, score)`` row each, using the raw sigmoid
    probabilities as scores; terms whose probability is exactly 0 are skipped.

    Args:
        checkpoint_path: file written by ``torch.save`` containing
            ``"model_state"`` and ``"best_threshold"``.
        submit_path: destination of the tab-separated submission file.
        batch_size: inference batch size.

    Returns:
        The loaded checkpoint dictionary.
    """
    ckpt = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    test_loader = DataLoader(
        ProteinDataset(X_test, None, mean=global_mean, std=global_std, train=False),
        batch_size=batch_size,
        shuffle=False,
    )

    with open(submit_path, "w", newline="") as f:
        writer = csv.writer(f, delimiter="\t")
        writer.writerow(["ID", "GO_ID", "score"])

        # Running row offset into ids_test; tracking it explicitly keeps the
        # id lookup correct for any batch size instead of assuming a fixed one.
        offset = 0
        for xb in test_loader:
            with torch.no_grad():
                probs = torch.sigmoid(model(xb.to(DEVICE))).cpu().numpy()

            for j, row in enumerate(probs):
                pid = ids_test[offset + j]
                # Ranking uses the original probabilities only (no rescoring).
                topk_idx = np.argsort(row)[::-1][:TOP_K]
                for idx in topk_idx:
                    score = float(row[idx])
                    if score > 0.0:
                        writer.writerow([pid, GO_terms[idx], score])
            offset += probs.shape[0]
    return ckpt


# Submission 1: checkpoint with the best validation F1.
print("Predicting test set no 1...")
ckpt = _write_submission(OUT_MODEL, OUT_SUBMIT)
best_threshold = ckpt["best_threshold"]

# Submission 2: checkpoint with the lowest validation loss.
print("Predicting test set no 2...")
ckpt = _write_submission(OUT_MODEL2, OUT_SUBMIT2)
best_threshold = ckpt["best_threshold"]

print("All done.")