In [1]:
import pandas as pd
import numpy as np

from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import GroupKFold
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [None]:
train = pd.read_parquet(
    'avito-merge-pairs-and-products/train_df.parquet', 
    columns=['attr_keys_1', 'attr_keys_2', 'attr_vals_1', 'attr_vals_2', 'group_id', 'action_date', 'is_double']
)
test = pd.read_parquet(
    'avito-merge-pairs-and-products/test_df.parquet',
    columns=['attr_keys_1', 'attr_keys_2', 'attr_vals_1', 'attr_vals_2']
)

In [3]:
def has_empty(row):
    return (
        len(row['attr_keys_1']) == 0 or
        len(row['attr_vals_1']) == 0 or
        len(row['attr_keys_2']) == 0 or
        len(row['attr_vals_2']) == 0
    )

# на инференсе none кидать если []
train = train[~train.apply(has_empty, axis=1)]
test = test[~test.apply(has_empty, axis=1)]

In [4]:
class AttrDataset(Dataset):
    def __init__(
        self, keys1, vals1, keys2, vals2, targets, 
        key_vocab=None, val_vocab=None, max_key_len=16, max_val_len=16
    ):
        self.max_key_len = max_key_len
        self.max_val_len = max_val_len
        
        if key_vocab is None:
            unique_keys = set()
            for k_list in keys1 + keys2:
                unique_keys.update(k_list)
            self.key_vocab = {'<pad>': 0}
            self.key_vocab.update({k: i + 1 for i, k in enumerate(unique_keys)})
        else:
            self.key_vocab = key_vocab
            
        if val_vocab is None:
            unique_vals = set()
            for v_list in vals1 + vals2:
                unique_vals.update(v_list)
            self.val_vocab = {'<pad>': 0}
            self.val_vocab.update({v: i + 1 for i, v in enumerate(unique_vals)})
        else:
            self.val_vocab = val_vocab
        
        self.keys1 = keys1
        self.vals1 = vals1
        self.keys2 = keys2
        self.vals2 = vals2
        self.targets = targets
        
    def __len__(self):
        return len(self.targets)
    
    def __getitem__(self, idx):
        k1 = [self.key_vocab.get(k, 0) for k in self.keys1[idx]]
        v1 = [self.val_vocab.get(v, 0) for v in self.vals1[idx]]
        k2 = [self.key_vocab.get(k, 0) for k in self.keys2[idx]]
        v2 = [self.val_vocab.get(v, 0) for v in self.vals2[idx]]
        
        k1_padded = k1[:self.max_key_len] + [0] * max(0, self.max_key_len - len(k1))
        v1_padded = v1[:self.max_val_len] + [0] * max(0, self.max_val_len - len(v1))
        k2_padded = k2[:self.max_key_len] + [0] * max(0, self.max_key_len - len(k2))
        v2_padded = v2[:self.max_val_len] + [0] * max(0, self.max_val_len - len(v2))
        
        k1_mask = torch.BoolTensor([i >= len(k1) for i in range(self.max_key_len)])
        v1_mask = torch.BoolTensor([i >= len(v1) for i in range(self.max_val_len)])
        k2_mask = torch.BoolTensor([i >= len(k2) for i in range(self.max_key_len)])
        v2_mask = torch.BoolTensor([i >= len(v2) for i in range(self.max_val_len)])
        
        return {
            'keys1': torch.LongTensor(k1_padded),
            'vals1': torch.LongTensor(v1_padded),
            'keys2': torch.LongTensor(k2_padded),
            'vals2': torch.LongTensor(v2_padded),
            'k1_mask': k1_mask,
            'v1_mask': v1_mask,
            'k2_mask': k2_mask,
            'v2_mask': v2_mask,
            'target': torch.FloatTensor([self.targets[idx]])
        }

In [5]:
class CrossAttention(nn.Module):
    def __init__(self, emb_dim, n_heads):
        super(CrossAttention, self).__init__()
        self.mh_attn = nn.MultiheadAttention(emb_dim, n_heads, batch_first=True)
        
    def forward(self, q, k, v, key_padding_mask=None):
        attn_output, _ = self.mh_attn(q, k, v, key_padding_mask=key_padding_mask)
        return attn_output

In [6]:
class AttrCrossEncoder(nn.Module):
    def __init__(self, key_vocab_sz, val_vocab_sz, emb_dim=64, n_heads=4, dropout=0.1):
        super(AttrCrossEncoder, self).__init__()
        self.key_emb = nn.Embedding(key_vocab_sz, emb_dim, padding_idx=0)
        self.val_emb = nn.Embedding(val_vocab_sz, emb_dim, padding_idx=0)
        
        self.cross_attn = CrossAttention(emb_dim, n_heads)
        
        self.dropout = nn.Dropout(dropout)
        
        self.fc = nn.Sequential(
            nn.Linear(emb_dim * 4, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(64, 1),
            # nn.Sigmoid()
        )
        
    def forward(self, keys1, vals1, keys2, vals2, k1_mask=None, v1_mask=None, k2_mask=None, v2_mask=None):
        k1_attn_mask = k1_mask if k1_mask is not None else None
        v1_attn_mask = v1_mask if v1_mask is not None else None
        k2_attn_mask = k2_mask if k2_mask is not None else None
        v2_attn_mask = v2_mask if v2_mask is not None else None
        
        emb_keys1 = self.key_emb(keys1)
        emb_vals1 = self.val_emb(vals1)
        emb_keys2 = self.key_emb(keys2)
        emb_vals2 = self.val_emb(vals2)
        
        attn_k1_k2 = self.cross_attn(emb_keys1, emb_keys2, emb_keys2, k2_attn_mask)
        attn_k1_v2 = self.cross_attn(emb_keys1, emb_vals2, emb_vals2, v2_attn_mask)
        attn_v1_k2 = self.cross_attn(emb_vals1, emb_keys2, emb_keys2, k2_attn_mask)
        attn_v1_v2 = self.cross_attn(emb_vals1, emb_vals2, emb_vals2, v2_attn_mask)
        
        if k1_mask is not None:
            mask1 = (~k1_mask).float().unsqueeze(-1)
            pooled_k1_k2 = (attn_k1_k2 * mask1).sum(dim=1) / mask1.sum(dim=1).clamp(min=1.0)
        else:
            pooled_k1_k2 = attn_k1_k2.mean(dim=1)
            
        if k1_mask is not None:
            mask1 = (~k1_mask).float().unsqueeze(-1)
            pooled_k1_v2 = (attn_k1_v2 * mask1).sum(dim=1) / mask1.sum(dim=1).clamp(min=1.0)
        else:
            pooled_k1_v2 = attn_k1_v2.mean(dim=1)
            
        if v1_mask is not None:
            mask1 = (~v1_mask).float().unsqueeze(-1)
            pooled_v1_k2 = (attn_v1_k2 * mask1).sum(dim=1) / mask1.sum(dim=1).clamp(min=1.0)
        else:
            pooled_v1_k2 = attn_v1_k2.mean(dim=1)
            
        if v1_mask is not None:
            mask1 = (~v1_mask).float().unsqueeze(-1)
            pooled_v1_v2 = (attn_v1_v2 * mask1).sum(dim=1) / mask1.sum(dim=1).clamp(min=1.0)
        else:
            pooled_v1_v2 = attn_v1_v2.mean(dim=1)
        
        concat = torch.cat([pooled_k1_k2, pooled_k1_v2, pooled_v1_k2, pooled_v1_v2], dim=1)
        concat = self.dropout(concat)
        
        out = self.fc(concat)
        
        return out.squeeze(1)


class BabyAttrCrossEncoder(nn.Module):
    def __init__(self, key_vocab_sz, val_vocab_sz, emb_dim=64, n_heads=4, dropout=0.1):
        super(BabyAttrCrossEncoder, self).__init__()
        self.key_emb = nn.Embedding(key_vocab_sz, emb_dim, padding_idx=0)
        self.val_emb = nn.Embedding(val_vocab_sz, emb_dim, padding_idx=0)
        
        self.cross_attn = CrossAttention(emb_dim, n_heads)
        
        self.dropout = nn.Dropout(dropout)
        
        self.fc = nn.Sequential(
            nn.Linear(emb_dim, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(64, 1)
        )
        
    def forward(self, keys1, vals1, keys2, vals2, k1_mask=None, v1_mask=None, k2_mask=None, v2_mask=None):
        emb_keys1 = self.key_emb(keys1)
        emb_vals1 = self.val_emb(vals1)
        emb_vals2 = self.val_emb(vals2)
        
        group1 = torch.cat([emb_keys1, emb_vals1], dim=1)
        group2 = torch.cat([emb_vals1, emb_vals2], dim=1)
        
        if k1_mask is not None and v1_mask is not None:
            g1_mask = torch.cat([k1_mask, v1_mask], dim=1)
        else:
            g1_mask = None
            
        if v1_mask is not None and v2_mask is not None:
            g2_mask = torch.cat([v1_mask, v2_mask], dim=1)
        else:
            g2_mask = None
        
        attn_output = self.cross_attn(group1, group2, group2, g2_mask)
        
        if g1_mask is not None:
            mask = (~g1_mask).float().unsqueeze(-1)
            pooled_output = (attn_output * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        else:
            pooled_output = attn_output.mean(dim=1)
        
        pooled_output = self.dropout(pooled_output)
        out = self.fc(pooled_output)
        
        return out.squeeze(1)

In [7]:
# from torch.utils.data import Sampler

# class LengthGroupedSampler(Sampler):
#     def __init__(self, data_source, lengths, batch_size, mega_batch_mult=100):
#         self.data_source = data_source
#         self.lengths = lengths
#         self.batch_size = batch_size
#         self.mega_batch_mult = mega_batch_mult
#         self.indices = self._create_length_grouped_indices()

#     def _create_length_grouped_indices(self):
#         indices = np.random.permutation(len(self.lengths))
        
#         mega_batch_size = self.batch_size * self.mega_batch_mult
#         mega_batches = [indices[i:i + mega_batch_size] for i in range(0, len(indices), mega_batch_size)]

#         sorted_indices = []
#         for mega_batch in mega_batches:
#             sorted_mega_batch = sorted(mega_batch, key=lambda i: self.lengths[i], reverse=True)
#             sorted_indices.extend(sorted_mega_batch)

#         return sorted_indices

#     def __iter__(self):
#         for i in range(0, len(self.indices), self.batch_size):
#             yield self.indices[i:i + self.batch_size]

#     def __len__(self):
#         return (len(self.indices) + self.batch_size - 1) // self.batch_size

In [8]:
def train_with_groupkfold(
    df, n_splits=5, max_key_len=16, max_val_len=16, batch_size=4096, 
    num_epochs=10, lr=0.001, device='cuda'
):
    groups = df['group_id'].values
    targets = df['is_double'].values

    gkf = GroupKFold(n_splits=n_splits)

    oof_preds = np.zeros(len(df))
    oof_targets = targets.copy()

    fold_metrics_prauc = []
    fold_models = []

    for fold, (train_idx, val_idx) in enumerate(gkf.split(df, targets, groups)):
        print(f'fold {fold+1}/{n_splits}')
        train_df = df.iloc[train_idx]
        val_df = df.iloc[val_idx]

        train_dataset = AttrDataset(
            keys1=train_df['attr_keys_1'].tolist(),
            vals1=train_df['attr_vals_1'].tolist(),
            keys2=train_df['attr_keys_2'].tolist(),
            vals2=train_df['attr_vals_2'].tolist(),
            targets=train_df['is_double'].tolist(),
            max_key_len=max_key_len,
            max_val_len=max_val_len
        )

        val_dataset = AttrDataset(
            keys1=val_df['attr_keys_1'].tolist(),
            vals1=val_df['attr_vals_1'].tolist(),
            keys2=val_df['attr_keys_2'].tolist(),
            vals2=val_df['attr_vals_2'].tolist(),
            targets=val_df['is_double'].tolist(),
            key_vocab=train_dataset.key_vocab,
            val_vocab=train_dataset.val_vocab,
            max_key_len=max_key_len,
            max_val_len=max_val_len
        )

        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size)

        model_params = {
            'key_vocab_sz': len(train_dataset.key_vocab),
            'val_vocab_sz': len(train_dataset.val_vocab),
            'emb_dim': 64,
            'n_heads': 4,
            'dropout': 0.2,
            'key_vocab': train_dataset.key_vocab,
            'val_vocab': train_dataset.val_vocab
        }

        model = BabyAttrCrossEncoder(
            key_vocab_sz=model_params['key_vocab_sz'],
            val_vocab_sz=model_params['val_vocab_sz'],
            emb_dim=model_params['emb_dim'],
            n_heads=model_params['n_heads'],
            dropout=model_params['dropout']
        )

        device = torch.device(device if torch.cuda.is_available() else 'cpu')
        model = model.to(device)

        pos_weight = torch.tensor([2574371 / 110662]).to(device)
        criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

        best_val_prauc = 0.0
        best_model_state = None

        for epoch in range(num_epochs):
            model.train()
            train_loss = 0.0
            train_preds = []
            train_targets = []

            for batch in tqdm(train_loader, desc=f'train epoch {epoch+1}/{num_epochs}'):
                keys1 = batch['keys1'].to(device)
                vals1 = batch['vals1'].to(device)
                keys2 = batch['keys2'].to(device)
                vals2 = batch['vals2'].to(device)
                k1_mask = batch['k1_mask'].to(device)
                v1_mask = batch['v1_mask'].to(device)
                k2_mask = batch['k2_mask'].to(device)
                v2_mask = batch['v2_mask'].to(device)
                targets_batch = batch['target'].squeeze(1).to(device)

                optimizer.zero_grad()

                outputs = model(keys1, vals1, keys2, vals2, k1_mask, v1_mask, k2_mask, v2_mask)
                loss = criterion(outputs, targets_batch)

                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                train_preds.extend(outputs.detach().cpu().numpy())
                train_targets.extend(targets_batch.cpu().numpy())

            train_prauc = average_precision_score(train_targets, train_preds)

            model.eval()
            val_loss = 0.0
            val_preds = []
            val_targets = []

            with torch.no_grad():
                for batch in tqdm(val_loader, desc=f'val epoch {epoch+1}/{num_epochs}'):
                    keys1 = batch['keys1'].to(device)
                    vals1 = batch['vals1'].to(device)
                    keys2 = batch['keys2'].to(device)
                    vals2 = batch['vals2'].to(device)
                    k1_mask = batch['k1_mask'].to(device)
                    v1_mask = batch['v1_mask'].to(device)
                    k2_mask = batch['k2_mask'].to(device)
                    v2_mask = batch['v2_mask'].to(device)
                    targets_batch = batch['target'].squeeze(1).to(device)

                    outputs = model(keys1, vals1, keys2, vals2, k1_mask, v1_mask, k2_mask, v2_mask)
                    loss = criterion(outputs, targets_batch)

                    val_loss += loss.item()
                    val_preds.extend(outputs.cpu().numpy())
                    val_targets.extend(targets_batch.cpu().numpy())

            val_prauc = average_precision_score(val_targets, val_preds)
            
            scheduler.step(val_prauc)

            if val_prauc > best_val_prauc:
                best_val_prauc = val_prauc
                best_model_state = model.state_dict().copy()

            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss/len(train_loader):.4f}, "
                  f"Train PR-AUC: {train_prauc:.4f}, "
                  f"Val Loss: {val_loss/len(val_loader):.4f}, Val PR-AUC: {val_prauc:.4f}")

        torch.save({
            'model_state_dict': best_model_state,
            'key_vocab': train_dataset.key_vocab,
            'val_vocab': train_dataset.val_vocab,
            'model_params': model_params,
            'fold': fold,
            'val_prauc': best_val_prauc,
        }, f'attrcrossencoder_fold{fold+1}.pt')

        model.load_state_dict(best_model_state)
        model.eval()
        fold_models.append({
            'model': model,
            'key_vocab': train_dataset.key_vocab,
            'val_vocab': train_dataset.val_vocab
        })

        val_fold_preds = []
        with torch.no_grad():
            for batch in val_loader:
                keys1 = batch['keys1'].to(device)
                vals1 = batch['vals1'].to(device)
                keys2 = batch['keys2'].to(device)
                vals2 = batch['vals2'].to(device)
                k1_mask = batch['k1_mask'].to(device)
                v1_mask = batch['v1_mask'].to(device)
                k2_mask = batch['k2_mask'].to(device)
                v2_mask = batch['v2_mask'].to(device)

                outputs = model(keys1, vals1, keys2, vals2, k1_mask, v1_mask, k2_mask, v2_mask)
                val_fold_preds.extend(outputs.cpu().numpy())

        oof_preds[val_idx] = val_fold_preds
        fold_metrics_prauc.append(best_val_prauc)
        
        print(f"fold {fold+1} finished, best val prauc: {best_val_prauc:.4f}")

    oof_prauc = average_precision_score(oof_targets, oof_preds)
    oof_binary_preds = (oof_preds >= 0.5).astype(int)
    oof_accuracy = accuracy_score(oof_targets, oof_binary_preds)
    oof_precision, oof_recall, oof_f1, _ = precision_recall_fscore_support(
        oof_targets, oof_binary_preds, average='binary'
    )
    
    print(f"{oof_prauc=}")
    print(f"{oof_accuracy=}")
    print(f"{oof_precision=}")
    print(f"{oof_recall=}")
    print(f"{oof_f1=}")
    
    print(f"\nfolds mean prauc{np.mean(fold_metrics_prauc):.4f}, std: {np.std(fold_metrics_prauc):.4f}")
    
    oof_results = {
        'oof_preds': oof_preds,
        'oof_targets': oof_targets,
        'oof_metrics': {
            'prauc': oof_prauc,
            'accuracy': oof_accuracy,
            'precision': oof_precision,
            'recall': oof_recall,
            'f1': oof_f1
        },
        'fold_metrics_prauc': fold_metrics_prauc,
    }
    
    return oof_results, fold_models

In [9]:
oof_results, fold_models = train_with_groupkfold(train, n_splits=5, max_key_len=16, max_val_len=16)

fold 1/5


train epoch 1/10: 100%|██████████| 525/525 [04:09<00:00,  2.11it/s]
val epoch 1/10: 100%|██████████| 132/132 [00:56<00:00,  2.32it/s]


Epoch 1/10, Train Loss: 0.9282, Train PR-AUC: 0.2533, Val Loss: 0.7976, Val PR-AUC: 0.1726


train epoch 2/10: 100%|██████████| 525/525 [04:10<00:00,  2.09it/s]
val epoch 2/10: 100%|██████████| 132/132 [00:56<00:00,  2.32it/s]


Epoch 2/10, Train Loss: 0.7396, Train PR-AUC: 0.3733, Val Loss: 0.7712, Val PR-AUC: 0.1718


train epoch 3/10: 100%|██████████| 525/525 [04:10<00:00,  2.10it/s]
val epoch 3/10: 100%|██████████| 132/132 [00:56<00:00,  2.32it/s]


Epoch 3/10, Train Loss: 0.6347, Train PR-AUC: 0.4490, Val Loss: 0.8114, Val PR-AUC: 0.1643


train epoch 4/10: 100%|██████████| 525/525 [04:10<00:00,  2.10it/s]
val epoch 4/10: 100%|██████████| 132/132 [00:56<00:00,  2.33it/s]


Epoch 4/10, Train Loss: 0.5496, Train PR-AUC: 0.5104, Val Loss: 0.9265, Val PR-AUC: 0.1590


train epoch 5/10: 100%|██████████| 525/525 [04:10<00:00,  2.09it/s]
val epoch 5/10: 100%|██████████| 132/132 [00:56<00:00,  2.33it/s]


Epoch 5/10, Train Loss: 0.4747, Train PR-AUC: 0.5685, Val Loss: 1.0228, Val PR-AUC: 0.1420


train epoch 6/10: 100%|██████████| 525/525 [04:10<00:00,  2.10it/s]
val epoch 6/10: 100%|██████████| 132/132 [00:56<00:00,  2.33it/s]


Epoch 6/10, Train Loss: 0.4397, Train PR-AUC: 0.5953, Val Loss: 1.1358, Val PR-AUC: 0.1530


train epoch 7/10: 100%|██████████| 525/525 [04:09<00:00,  2.10it/s]
val epoch 7/10: 100%|██████████| 132/132 [00:56<00:00,  2.33it/s]


Epoch 7/10, Train Loss: 0.4101, Train PR-AUC: 0.6182, Val Loss: 1.2620, Val PR-AUC: 0.1389


train epoch 8/10: 100%|██████████| 525/525 [04:09<00:00,  2.10it/s]
val epoch 8/10: 100%|██████████| 132/132 [00:56<00:00,  2.33it/s]


Epoch 8/10, Train Loss: 0.3795, Train PR-AUC: 0.6420, Val Loss: 1.3532, Val PR-AUC: 0.1510


train epoch 9/10: 100%|██████████| 525/525 [04:10<00:00,  2.10it/s]
val epoch 9/10: 100%|██████████| 132/132 [00:55<00:00,  2.36it/s]


Epoch 9/10, Train Loss: 0.3667, Train PR-AUC: 0.6512, Val Loss: 1.4462, Val PR-AUC: 0.1430


train epoch 10/10: 100%|██████████| 525/525 [04:09<00:00,  2.11it/s]
val epoch 10/10: 100%|██████████| 132/132 [00:56<00:00,  2.35it/s]


Epoch 10/10, Train Loss: 0.3549, Train PR-AUC: 0.6610, Val Loss: 1.4921, Val PR-AUC: 0.1466
fold 1 finished, best val prauc: 0.1726
fold 2/5


train epoch 1/10: 100%|██████████| 525/525 [04:09<00:00,  2.10it/s]
val epoch 1/10: 100%|██████████| 132/132 [00:57<00:00,  2.32it/s]


Epoch 1/10, Train Loss: 0.8696, Train PR-AUC: 0.2487, Val Loss: 1.0699, Val PR-AUC: 0.2050


train epoch 2/10: 100%|██████████| 525/525 [04:08<00:00,  2.11it/s]
val epoch 2/10: 100%|██████████| 132/132 [00:57<00:00,  2.30it/s]


Epoch 2/10, Train Loss: 0.6879, Train PR-AUC: 0.3621, Val Loss: 1.1200, Val PR-AUC: 0.2087


train epoch 3/10: 100%|██████████| 525/525 [04:08<00:00,  2.11it/s]
val epoch 3/10: 100%|██████████| 132/132 [00:56<00:00,  2.32it/s]


Epoch 3/10, Train Loss: 0.5851, Train PR-AUC: 0.4368, Val Loss: 1.2180, Val PR-AUC: 0.2056


train epoch 4/10: 100%|██████████| 525/525 [04:08<00:00,  2.12it/s]
val epoch 4/10: 100%|██████████| 132/132 [00:56<00:00,  2.33it/s]


Epoch 4/10, Train Loss: 0.5015, Train PR-AUC: 0.5011, Val Loss: 1.5097, Val PR-AUC: 0.1778


train epoch 5/10: 100%|██████████| 525/525 [04:08<00:00,  2.11it/s]
val epoch 5/10: 100%|██████████| 132/132 [00:56<00:00,  2.32it/s]


Epoch 5/10, Train Loss: 0.4310, Train PR-AUC: 0.5567, Val Loss: 1.7862, Val PR-AUC: 0.1598


train epoch 6/10: 100%|██████████| 525/525 [04:09<00:00,  2.11it/s]
val epoch 6/10: 100%|██████████| 132/132 [00:56<00:00,  2.32it/s]


Epoch 6/10, Train Loss: 0.3658, Train PR-AUC: 0.6103, Val Loss: 1.9401, Val PR-AUC: 0.1589


train epoch 7/10: 100%|██████████| 525/525 [04:08<00:00,  2.11it/s]
val epoch 7/10: 100%|██████████| 132/132 [00:56<00:00,  2.34it/s]


Epoch 7/10, Train Loss: 0.3391, Train PR-AUC: 0.6320, Val Loss: 2.1969, Val PR-AUC: 0.1409


train epoch 8/10: 100%|██████████| 525/525 [04:06<00:00,  2.13it/s]
val epoch 8/10: 100%|██████████| 132/132 [00:56<00:00,  2.34it/s]


Epoch 8/10, Train Loss: 0.3163, Train PR-AUC: 0.6528, Val Loss: 2.3189, Val PR-AUC: 0.1331


train epoch 9/10: 100%|██████████| 525/525 [04:07<00:00,  2.12it/s]
val epoch 9/10: 100%|██████████| 132/132 [00:56<00:00,  2.34it/s]


Epoch 9/10, Train Loss: 0.2911, Train PR-AUC: 0.6746, Val Loss: 2.4569, Val PR-AUC: 0.1323


train epoch 10/10: 100%|██████████| 525/525 [04:09<00:00,  2.11it/s]
val epoch 10/10: 100%|██████████| 132/132 [00:57<00:00,  2.32it/s]


Epoch 10/10, Train Loss: 0.2806, Train PR-AUC: 0.6833, Val Loss: 2.5652, Val PR-AUC: 0.1334
fold 2 finished, best val prauc: 0.2087
fold 3/5


train epoch 1/10: 100%|██████████| 525/525 [04:07<00:00,  2.12it/s]
val epoch 1/10: 100%|██████████| 132/132 [00:57<00:00,  2.29it/s]


Epoch 1/10, Train Loss: 0.8579, Train PR-AUC: 0.2493, Val Loss: 1.1551, Val PR-AUC: 0.1784


train epoch 2/10: 100%|██████████| 525/525 [04:08<00:00,  2.11it/s]
val epoch 2/10: 100%|██████████| 132/132 [00:57<00:00,  2.30it/s]


Epoch 2/10, Train Loss: 0.6889, Train PR-AUC: 0.3659, Val Loss: 1.1623, Val PR-AUC: 0.1755


train epoch 3/10: 100%|██████████| 525/525 [04:07<00:00,  2.12it/s]
val epoch 3/10: 100%|██████████| 132/132 [00:57<00:00,  2.29it/s]


Epoch 3/10, Train Loss: 0.5888, Train PR-AUC: 0.4463, Val Loss: 1.2572, Val PR-AUC: 0.1611


train epoch 4/10: 100%|██████████| 525/525 [04:08<00:00,  2.11it/s]
val epoch 4/10: 100%|██████████| 132/132 [00:57<00:00,  2.29it/s]


Epoch 4/10, Train Loss: 0.5057, Train PR-AUC: 0.5142, Val Loss: 1.2944, Val PR-AUC: 0.1619


train epoch 5/10: 100%|██████████| 525/525 [04:08<00:00,  2.11it/s]
val epoch 5/10: 100%|██████████| 132/132 [00:57<00:00,  2.31it/s]


Epoch 5/10, Train Loss: 0.4321, Train PR-AUC: 0.5782, Val Loss: 1.5759, Val PR-AUC: 0.1528


train epoch 6/10: 100%|██████████| 525/525 [04:07<00:00,  2.12it/s]
val epoch 6/10: 100%|██████████| 132/132 [00:57<00:00,  2.29it/s]


Epoch 6/10, Train Loss: 0.3984, Train PR-AUC: 0.6040, Val Loss: 1.6444, Val PR-AUC: 0.1556


train epoch 7/10: 100%|██████████| 525/525 [04:08<00:00,  2.11it/s]
val epoch 7/10: 100%|██████████| 132/132 [00:57<00:00,  2.31it/s]


Epoch 7/10, Train Loss: 0.3698, Train PR-AUC: 0.6276, Val Loss: 1.7158, Val PR-AUC: 0.1479


train epoch 8/10: 100%|██████████| 525/525 [04:06<00:00,  2.13it/s]
val epoch 8/10: 100%|██████████| 132/132 [00:57<00:00,  2.31it/s]


Epoch 8/10, Train Loss: 0.3402, Train PR-AUC: 0.6521, Val Loss: 1.9502, Val PR-AUC: 0.1469


train epoch 9/10: 100%|██████████| 525/525 [04:06<00:00,  2.13it/s]
val epoch 9/10: 100%|██████████| 132/132 [00:57<00:00,  2.29it/s]


Epoch 9/10, Train Loss: 0.3277, Train PR-AUC: 0.6624, Val Loss: 1.9582, Val PR-AUC: 0.1472


train epoch 10/10: 100%|██████████| 525/525 [04:06<00:00,  2.13it/s]
val epoch 10/10: 100%|██████████| 132/132 [00:57<00:00,  2.31it/s]


Epoch 10/10, Train Loss: 0.3167, Train PR-AUC: 0.6717, Val Loss: 1.9704, Val PR-AUC: 0.1446
fold 3 finished, best val prauc: 0.1784
fold 4/5


train epoch 1/10: 100%|██████████| 525/525 [04:07<00:00,  2.12it/s]
val epoch 1/10: 100%|██████████| 132/132 [00:57<00:00,  2.31it/s]


Epoch 1/10, Train Loss: 0.8679, Train PR-AUC: 0.2445, Val Loss: 1.0661, Val PR-AUC: 0.1671


train epoch 2/10: 100%|██████████| 525/525 [04:07<00:00,  2.12it/s]
val epoch 2/10: 100%|██████████| 132/132 [00:56<00:00,  2.32it/s]


Epoch 2/10, Train Loss: 0.6997, Train PR-AUC: 0.3655, Val Loss: 1.1464, Val PR-AUC: 0.1586


train epoch 3/10: 100%|██████████| 525/525 [04:06<00:00,  2.13it/s]
val epoch 3/10: 100%|██████████| 132/132 [00:57<00:00,  2.31it/s]


Epoch 3/10, Train Loss: 0.6024, Train PR-AUC: 0.4432, Val Loss: 1.2465, Val PR-AUC: 0.1504


train epoch 4/10: 100%|██████████| 525/525 [04:06<00:00,  2.13it/s]
val epoch 4/10: 100%|██████████| 132/132 [00:57<00:00,  2.31it/s]


Epoch 4/10, Train Loss: 0.5197, Train PR-AUC: 0.5092, Val Loss: 1.3453, Val PR-AUC: 0.1394


train epoch 5/10: 100%|██████████| 525/525 [04:06<00:00,  2.13it/s]
val epoch 5/10: 100%|██████████| 132/132 [00:57<00:00,  2.30it/s]


Epoch 5/10, Train Loss: 0.4467, Train PR-AUC: 0.5680, Val Loss: 1.4880, Val PR-AUC: 0.1346


train epoch 6/10: 100%|██████████| 525/525 [04:07<00:00,  2.13it/s]
val epoch 6/10: 100%|██████████| 132/132 [00:56<00:00,  2.33it/s]


Epoch 6/10, Train Loss: 0.4130, Train PR-AUC: 0.5948, Val Loss: 1.5663, Val PR-AUC: 0.1336


train epoch 7/10: 100%|██████████| 525/525 [04:06<00:00,  2.13it/s]
val epoch 7/10: 100%|██████████| 132/132 [00:56<00:00,  2.32it/s]


Epoch 7/10, Train Loss: 0.3842, Train PR-AUC: 0.6191, Val Loss: 1.6516, Val PR-AUC: 0.1321


train epoch 8/10: 100%|██████████| 525/525 [04:06<00:00,  2.13it/s]
val epoch 8/10: 100%|██████████| 132/132 [00:57<00:00,  2.31it/s]


Epoch 8/10, Train Loss: 0.3556, Train PR-AUC: 0.6412, Val Loss: 1.8258, Val PR-AUC: 0.1246


train epoch 9/10: 100%|██████████| 525/525 [04:06<00:00,  2.13it/s]
val epoch 9/10: 100%|██████████| 132/132 [00:57<00:00,  2.31it/s]


Epoch 9/10, Train Loss: 0.3431, Train PR-AUC: 0.6526, Val Loss: 1.8555, Val PR-AUC: 0.1263


train epoch 10/10: 100%|██████████| 525/525 [04:06<00:00,  2.13it/s]
val epoch 10/10: 100%|██████████| 132/132 [00:58<00:00,  2.28it/s]


Epoch 10/10, Train Loss: 0.3315, Train PR-AUC: 0.6615, Val Loss: 1.9665, Val PR-AUC: 0.1212
fold 4 finished, best val prauc: 0.1671
fold 5/5


train epoch 1/10: 100%|██████████| 525/525 [04:06<00:00,  2.13it/s]
val epoch 1/10: 100%|██████████| 132/132 [00:56<00:00,  2.32it/s]


Epoch 1/10, Train Loss: 0.8600, Train PR-AUC: 0.2392, Val Loss: 1.0499, Val PR-AUC: 0.2026


train epoch 2/10: 100%|██████████| 525/525 [04:06<00:00,  2.13it/s]
val epoch 2/10: 100%|██████████| 132/132 [00:56<00:00,  2.32it/s]


Epoch 2/10, Train Loss: 0.6840, Train PR-AUC: 0.3585, Val Loss: 1.1495, Val PR-AUC: 0.1951


train epoch 3/10: 100%|██████████| 525/525 [04:07<00:00,  2.12it/s]
val epoch 3/10: 100%|██████████| 132/132 [00:57<00:00,  2.30it/s]


Epoch 3/10, Train Loss: 0.5886, Train PR-AUC: 0.4333, Val Loss: 1.2715, Val PR-AUC: 0.1945


train epoch 4/10: 100%|██████████| 525/525 [04:07<00:00,  2.12it/s]
val epoch 4/10: 100%|██████████| 132/132 [00:57<00:00,  2.30it/s]


Epoch 4/10, Train Loss: 0.5088, Train PR-AUC: 0.5007, Val Loss: 1.4137, Val PR-AUC: 0.1726


train epoch 5/10: 100%|██████████| 525/525 [04:07<00:00,  2.12it/s]
val epoch 5/10: 100%|██████████| 132/132 [00:57<00:00,  2.30it/s]


Epoch 5/10, Train Loss: 0.4367, Train PR-AUC: 0.5619, Val Loss: 1.6004, Val PR-AUC: 0.1720


train epoch 6/10: 100%|██████████| 525/525 [04:08<00:00,  2.11it/s]
val epoch 6/10: 100%|██████████| 132/132 [00:57<00:00,  2.30it/s]


Epoch 6/10, Train Loss: 0.4039, Train PR-AUC: 0.5894, Val Loss: 1.8323, Val PR-AUC: 0.1617


train epoch 7/10: 100%|██████████| 525/525 [04:08<00:00,  2.11it/s]
val epoch 7/10: 100%|██████████| 132/132 [00:57<00:00,  2.29it/s]


Epoch 7/10, Train Loss: 0.3757, Train PR-AUC: 0.6122, Val Loss: 1.9504, Val PR-AUC: 0.1633


train epoch 8/10: 100%|██████████| 525/525 [04:07<00:00,  2.12it/s]
val epoch 8/10: 100%|██████████| 132/132 [00:57<00:00,  2.30it/s]


Epoch 8/10, Train Loss: 0.3463, Train PR-AUC: 0.6368, Val Loss: 2.0853, Val PR-AUC: 0.1534


train epoch 9/10: 100%|██████████| 525/525 [04:08<00:00,  2.11it/s]
val epoch 9/10: 100%|██████████| 132/132 [00:58<00:00,  2.27it/s]


Epoch 9/10, Train Loss: 0.3347, Train PR-AUC: 0.6472, Val Loss: 2.1216, Val PR-AUC: 0.1520


train epoch 10/10: 100%|██████████| 525/525 [04:10<00:00,  2.10it/s]
val epoch 10/10: 100%|██████████| 132/132 [00:57<00:00,  2.28it/s]


Epoch 10/10, Train Loss: 0.3228, Train PR-AUC: 0.6572, Val Loss: 2.2841, Val PR-AUC: 0.1445
fold 5 finished, best val prauc: 0.2026
oof_prauc=0.13450917686914585
oof_accuracy=0.7555326135656433
oof_precision=0.09671350261669452
oof_recall=0.591332164609351
oof_f1=0.16623839833146825

folds mean prauc0.1859, std: 0.0166


In [10]:
def ensemble_predict(
    fold_models, keys1, vals1, keys2, vals2, 
    max_key_len=16, max_val_len=16, device='cuda'
):
    all_preds = []
    
    for model_data in fold_models:
        model = model_data['model']
        key_vocab = model_data['key_vocab']
        val_vocab = model_data['val_vocab']
        
        k1 = [key_vocab.get(k, 0) for k in keys1]
        v1 = [val_vocab.get(v, 0) for v in vals1]
        k2 = [key_vocab.get(k, 0) for k in keys2]
        v2 = [val_vocab.get(v, 0) for v in vals2]
        
        k1_padded = k1[:max_key_len] + [0] * max(0, max_key_len - len(k1))
        v1_padded = v1[:max_val_len] + [0] * max(0, max_val_len - len(v1))
        k2_padded = k2[:max_key_len] + [0] * max(0, max_key_len - len(k2))
        v2_padded = v2[:max_val_len] + [0] * max(0, max_val_len - len(v2))
        
        k1_mask = torch.BoolTensor([i >= len(k1) for i in range(max_key_len)])
        v1_mask = torch.BoolTensor([i >= len(v1) for i in range(max_val_len)])
        k2_mask = torch.BoolTensor([i >= len(k2) for i in range(max_key_len)])
        v2_mask = torch.BoolTensor([i >= len(v2) for i in range(max_val_len)])
        
        device = torch.device(device if torch.cuda.is_available() else 'cpu')
        
        keys1_tensor = torch.LongTensor([k1_padded]).to(device)
        vals1_tensor = torch.LongTensor([v1_padded]).to(device)
        keys2_tensor = torch.LongTensor([k2_padded]).to(device)
        vals2_tensor = torch.LongTensor([v2_padded]).to(device)
        
        k1_mask = k1_mask.to(device)
        v1_mask = v1_mask.to(device)
        k2_mask = k2_mask.to(device)
        v2_mask = v2_mask.to(device)
        
        model.eval()
        with torch.no_grad():
            pred = model(
                keys1_tensor, vals1_tensor, keys2_tensor, vals2_tensor, 
                k1_mask, v1_mask, k2_mask, v2_mask
            )
            all_preds.append(pred.item())
    
    ensemble_pred = np.mean(all_preds)
    
    return ensemble_pred