In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
import gc

# ==========================================
# CONFIGURATION
# ==========================================
# Global settings shared by data loading, model construction and training.
CFG = dict(
    input_dim=1280,      # width of each embedding row fed to the models
    num_classes=1500,    # how many of the most frequent training terms become targets
    batch_size=32,       # used for both the training and the inference DataLoader
    epochs=10,           # full passes over the training set per architecture
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
print(f"Using device: {CFG['device']}")

# ==========================================
# 1. LOAD DATA
# ==========================================
print(">>> Loading Data...")
train_emb = np.load("/kaggle/input/emb-models-ttt/train_embeds.npy").astype(np.float32)
train_ids = np.load("/kaggle/input/emb-models-ttt/train_ids.npy")
test_emb = np.load("/kaggle/input/emb-models-ttt/test_embeds.npy").astype(np.float32)
test_ids = np.load("/kaggle/input/emb-models-ttt/test_ids.npy")

# Label space: the CFG['num_classes'] most frequent terms in train_terms.tsv.
train_terms = pd.read_csv("/kaggle/input/cafa-6-protein-function-prediction/Train/train_terms.tsv", sep="\t")
term_counts = train_terms['term'].value_counts()
top_terms = term_counts.index[:CFG['num_classes']].tolist()
term_to_idx = {term: i for i, term in enumerate(top_terms)}

print(">>> Building Target Matrix...")
train_id_set = set(train_ids)
# Keep only (protein, term) pairs whose protein has an embedding and whose term is a target.
filtered_terms = train_terms[train_terms['term'].isin(top_terms) & train_terms['EntryID'].isin(train_id_set)]
id_to_index = {pid: i for i, pid in enumerate(train_ids)}
labels = np.zeros((len(train_ids), CFG['num_classes']), dtype=np.float32)

# Vectorised scatter instead of a per-row Python loop. The isin() filter above
# guarantees every EntryID/term maps to a valid row/column; a missing key would
# surface as NaN and make the int64 conversion fail loudly rather than silently.
rows = filtered_terms['EntryID'].map(id_to_index).to_numpy(dtype=np.int64)
cols = filtered_terms['term'].map(term_to_idx).to_numpy(dtype=np.int64)
labels[rows, cols] = 1.0

del train_terms, filtered_terms, term_counts, rows, cols
gc.collect()

# ==========================================
# 2. DEFINE ADVANCED ARCHITECTURES
# ==========================================

# --- MODEL A: SE-ResNet (Signal Processing Logic) ---
# It recalibrates channel importance dynamically
class SEBlock(nn.Module):
    """Squeeze-and-excitation gate over the channel axis of a [B, C, L] tensor.

    The input is average-pooled over its length axis ("squeeze"), passed
    through a two-layer bottleneck MLP ending in a sigmoid ("excite"), and the
    resulting per-channel weights in (0, 1) rescale the input.

    Args:
        channel: number of channels C.
        reduction: bottleneck factor; the hidden layer has
            ``channel // reduction`` units, clamped to at least 1.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        # Clamp so channel < reduction cannot create a zero-width (dead) layer.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` rescaled channel-wise; ``x`` is [B, C, L] for any L >= 1."""
        b, c, _ = x.size()
        # Squeeze: mean over the length axis (exactly the identity when L == 1).
        y = self.avg_pool(x).view(b, c)
        # Excite: per-channel weights, broadcast back over the length axis.
        y = self.fc(y).view(b, c, 1)
        return x * y.expand_as(x)

class SEResNet1D(nn.Module):
    """Residual MLP classifier whose hidden layers are gated by squeeze-and-excitation.

    Pipeline: linear projection to ``hidden_dim`` -> two residual steps, each
    computing ``res(se(x)) + x`` -> linear classifier producing raw logits.

    Args:
        input_dim: width of the input feature vector.
        num_classes: number of output logits.
        hidden_dim: width of the projected representation (default 1280, the
            value that used to be hard-coded).
        dropout: dropout probability inside each residual branch (default 0.2).
    """

    def __init__(self, input_dim, num_classes, hidden_dim=1280, dropout=0.2):
        super().__init__()
        # The hidden vector is treated as `hidden_dim` SE "channels" of length 1.
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.se1 = SEBlock(hidden_dim)
        self.res1 = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout))

        self.se2 = SEBlock(hidden_dim)
        self.res2 = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout))

        self.classifier = nn.Linear(hidden_dim, num_classes)

    @staticmethod
    def _se_residual(x, se, res):
        """One residual step on [B, H] features: SE-gate, run ``res``, add the input back."""
        gated = se(x.unsqueeze(-1)).squeeze(-1)  # [B, H] -> [B, H, 1] -> [B, H]
        return res(gated) + x

    def forward(self, x):
        """Map a [B, input_dim] batch to [B, num_classes] logits."""
        x = self.input_proj(x)
        x = self._se_residual(x, self.se1, self.res1)
        x = self._se_residual(x, self.se2, self.res2)
        return self.classifier(x)

# --- MODEL B: Bottleneck Net (Compression Logic) ---
# Forces the model to learn the "Essence" by squeezing data
class BottleneckNet(nn.Module):
    """MLP classifier that squeezes features through a narrow 256-unit layer.

    Layout (every hidden Linear is followed by BatchNorm1d and ReLU):
    input_dim -> 1024 -> dropout -> 256 (bottleneck) -> 1024 -> dropout -> num_classes.
    The last layer is a plain Linear, so the outputs are raw logits.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        wide, narrow, drop_p = 1024, 256, 0.2

        def dense(n_in, n_out):
            # Linear -> BatchNorm -> ReLU, instantiated in exactly this order.
            return [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.ReLU()]

        # Encoder indices: 0-2 first dense unit, 3 dropout, 4-6 bottleneck unit.
        self.encoder = nn.Sequential(
            *dense(input_dim, wide),
            nn.Dropout(drop_p),
            *dense(wide, narrow),
        )
        # Decoder indices: 0-2 expansion unit, 3 dropout, 4 output layer.
        self.decoder_classifier = nn.Sequential(
            *dense(narrow, wide),
            nn.Dropout(drop_p),
            nn.Linear(wide, num_classes),
        )

    def forward(self, x):
        """Return [B, num_classes] logits for a [B, input_dim] batch."""
        return self.decoder_classifier(self.encoder(x))

# ==========================================
# 3. TRAINING LOOP
# ==========================================
class ProteinDataset(Dataset):
    """Map-style dataset over a row-indexable embedding array with optional targets.

    Each item is the embedding row as a float32 tensor; when ``targets`` is
    given, the item is instead a ``(features, target)`` pair of float32 tensors.
    """

    def __init__(self, embeddings, targets=None):
        self.embeddings = embeddings
        self.targets = targets

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        features = torch.tensor(self.embeddings[idx], dtype=torch.float32)
        if self.targets is None:
            # Inference mode: no labels available.
            return features
        return features, torch.tensor(self.targets[idx], dtype=torch.float32)

# Graph Parsing for Propagation (Load once)
def _parse_go_parents(obo_path, extra_relations=("part_of",)):
    """Return ``{term_id: set(direct parent ids)}`` for every [Term] stanza of an OBO file.

    Parents come from ``is_a:`` lines plus ``relationship: <rel> <id>`` lines
    whose relation name is listed in ``extra_relations``. File-header lines and
    [Typedef] stanzas are skipped so relation definitions do not leak into the
    term graph.
    """
    graph = {}
    in_term = False
    current = None
    with open(obo_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            tokens = raw_line.split()
            if not tokens:
                continue
            tag = tokens[0]
            if tag.startswith('['):
                # A new stanza begins; only [Term] stanzas contribute edges.
                in_term = tag == '[Term]'
                current = None
            elif not in_term:
                continue
            elif tag == 'id:' and len(tokens) > 1:
                current = tokens[1]
            elif current is None:
                continue
            elif tag == 'is_a:' and len(tokens) > 1:
                graph.setdefault(current, set()).add(tokens[1])
            elif tag == 'relationship:' and len(tokens) > 2 and tokens[1] in extra_relations:
                graph.setdefault(current, set()).add(tokens[2])
    return graph


def _go_ancestors(term, graph, memo):
    """Return the set of all terms reachable from ``term`` via ``graph`` edges (memoised)."""
    if term not in memo:
        memo[term] = set()  # provisional entry: stops infinite recursion on a cycle
        found = set()
        for parent in graph.get(term, ()):
            found.add(parent)
            found |= _go_ancestors(parent, graph, memo)
        found.discard(term)
        memo[term] = found
    return memo[term]


print(">>> Parsing Graph for Propagation...")
parents = _parse_go_parents("/kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo")

# NOTE: despite its name, idx_to_parents maps each top-term column to the columns
# of ALL its ancestors (not just direct parents), so a child's score reaches every
# ancestor in top_terms even when the terms in between are outside top_terms.
_ancestor_memo = {}
idx_to_parents = {}
for i, term in enumerate(top_terms):
    if term in parents:
        ancestors = _go_ancestors(term, parents, _ancestor_memo)
        idx_to_parents[i] = sorted(term_to_idx[a] for a in ancestors if a in term_to_idx)

# --- TRAIN AND SAVE SEQUENTIALLY ---
def _propagate_to_ancestors(scores, child_to_parents, chunk_size=8192):
    """Raise each ancestor column to the max score of itself and its descendants, in place.

    After this call no protein scores a term lower than any of that term's
    descendants. The transitive closure of ``child_to_parents`` is taken here,
    so the result is the same whether it lists direct parents or all ancestors,
    and it does not depend on dict iteration order.

    Args:
        scores: 2-D array (rows = proteins, columns = term indices); modified in place.
        child_to_parents: dict mapping a column index to an iterable of parent
            column indices.
        chunk_size: rows processed at once; bounds the temporary gathered arrays.

    Returns:
        ``scores`` (the same array object).
    """
    memo = {}

    def ancestors_of(node):
        if node not in memo:
            memo[node] = set()  # provisional entry: stops recursion on a cycle
            found = set()
            for parent in child_to_parents.get(node, ()):
                found.add(parent)
                found |= ancestors_of(parent)
            found.discard(node)
            memo[node] = found
        return memo[node]

    descendants = {}
    for child in child_to_parents:
        for anc in ancestors_of(child):
            descendants.setdefault(anc, []).append(child)
    # One gather-and-max per ancestor column. In-place updates are safe in any
    # order: a descendant's updated value only ever includes scores of its own
    # descendants, which are also descendants of the ancestor.
    plan = [(anc, np.array([anc, *kids])) for anc, kids in descendants.items()]

    for start in range(0, len(scores), chunk_size):
        chunk = scores[start:start + chunk_size]  # basic slice -> view into scores
        for anc, cols in plan:
            chunk[:, anc] = chunk[:, cols].max(axis=1)
    return scores


# Constructors, not instances: a list of live models would keep every trained
# network (and its GPU memory) referenced even after `del model` below.
models_to_run = [
    ("SE_ResNet", SEResNet1D),
    ("Bottleneck", BottleneckNet),
]

THRESHOLD = 0.01  # minimum propagated score written to a submission file

for name, model_cls in models_to_run:
    print(f"\n>>> ====================================")
    print(f">>> Processing Architecture: {name}")
    print(f">>> ====================================")

    # drop_last avoids a trailing batch of size 1, which BatchNorm1d rejects in training mode.
    train_loader = DataLoader(ProteinDataset(train_emb, labels), batch_size=CFG['batch_size'],
                              shuffle=True, num_workers=2, drop_last=True)
    model = model_cls(CFG['input_dim'], CFG['num_classes']).to(CFG['device'])
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    criterion = nn.BCEWithLogitsLoss()

    # Train
    for epoch in range(CFG['epochs']):
        model.train()
        for batch_x, batch_y in tqdm(train_loader, desc=f"Ep {epoch+1}", leave=False):
            batch_x, batch_y = batch_x.to(CFG['device']), batch_y.to(CFG['device'])
            optimizer.zero_grad()
            loss = criterion(model(batch_x), batch_y)
            loss.backward()
            optimizer.step()

    # Predict
    print(f">>> Inference for {name}...")
    test_loader = DataLoader(ProteinDataset(test_emb), batch_size=CFG['batch_size'], shuffle=False, num_workers=2)
    model.eval()
    preds = []
    with torch.no_grad():
        for batch_x in tqdm(test_loader, desc="Predicting"):
            probs = torch.sigmoid(model(batch_x.to(CFG['device']))).cpu().numpy()
            preds.append(probs)
    preds = np.concatenate(preds)

    # Propagate (Biology Rules): vectorised over rows instead of a per-protein Python loop.
    print(f">>> Propagating Scores for {name}...")
    _propagate_to_ancestors(preds, idx_to_parents)

    # Save: stream lines to disk rather than buffering the whole submission in memory.
    filename = f"submission_{name.lower()}.tsv"
    print(f">>> Saving {filename}...")
    with open(filename, 'w') as out:
        for pid, row in zip(tqdm(test_ids, desc="Writing"), preds):
            hits = np.where(row > THRESHOLD)[0]
            out.writelines(f"{pid}\t{top_terms[j]}\t{row[j]:.3f}\n" for j in hits)

    # Cleanup
    del model, optimizer, preds
    torch.cuda.empty_cache()
    gc.collect()

print("\n>>> ALL JOBS COMPLETE. You have 2 new files ready for the reset!")

Using device: cuda
>>> Loading Data...
>>> Building Target Matrix...


  0%|          | 0/342098 [00:00<?, ?it/s]

>>> Parsing Graph for Propagation...

>>> Processing Architecture: SE_ResNet


Ep 1:   0%|          | 0/2576 [00:00<?, ?it/s]

Ep 2:   0%|          | 0/2576 [00:00<?, ?it/s]

Ep 3:   0%|          | 0/2576 [00:00<?, ?it/s]

Ep 4:   0%|          | 0/2576 [00:00<?, ?it/s]

Ep 5:   0%|          | 0/2576 [00:00<?, ?it/s]

Ep 6:   0%|          | 0/2576 [00:00<?, ?it/s]

Ep 7:   0%|          | 0/2576 [00:00<?, ?it/s]

Ep 8:   0%|          | 0/2576 [00:00<?, ?it/s]

Ep 9:   0%|          | 0/2576 [00:00<?, ?it/s]

Ep 10:   0%|          | 0/2576 [00:00<?, ?it/s]

>>> Inference for SE_ResNet...


Predicting:   0%|          | 0/7010 [00:00<?, ?it/s]

>>> Propagating Scores for SE_ResNet...


Logic Check:   0%|          | 0/224309 [00:00<?, ?it/s]

>>> Saving submission_se_resnet.tsv...


Writing:   0%|          | 0/224309 [00:00<?, ?it/s]


>>> Processing Architecture: Bottleneck


Ep 1:   0%|          | 0/2576 [00:00<?, ?it/s]

Ep 2:   0%|          | 0/2576 [00:00<?, ?it/s]

Ep 3:   0%|          | 0/2576 [00:00<?, ?it/s]

Ep 4:   0%|          | 0/2576 [00:00<?, ?it/s]

Ep 5:   0%|          | 0/2576 [00:00<?, ?it/s]

Ep 6:   0%|          | 0/2576 [00:00<?, ?it/s]

Ep 7:   0%|          | 0/2576 [00:00<?, ?it/s]

Ep 8:   0%|          | 0/2576 [00:00<?, ?it/s]

Ep 9:   0%|          | 0/2576 [00:00<?, ?it/s]

Ep 10:   0%|          | 0/2576 [00:00<?, ?it/s]

>>> Inference for Bottleneck...


Predicting:   0%|          | 0/7010 [00:00<?, ?it/s]

>>> Propagating Scores for Bottleneck...


Logic Check:   0%|          | 0/224309 [00:00<?, ?it/s]

>>> Saving submission_bottleneck.tsv...


Writing:   0%|          | 0/224309 [00:00<?, ?it/s]


>>> ALL JOBS COMPLETE. You have 2 new files ready for the reset!
