In [None]:
! pip install matplotlib

In [None]:
! pip install seaborn

In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import fbeta_score, accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc
from sentence_transformers import SentenceTransformer, InputExample, losses, util
from torch.utils.data import DataLoader
import torch
import gc
import os

import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Tuple, List, Optional

RANDOM_SEED = 42
N_FOLDS = 5
EPOCHS=3
THRESHOLDS = [0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]  # Thresholds to test

# Udostępnij obie karty
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

# Model powinien automatycznie rozłożyć się na DataParallel
# model = SentenceTransformer('sentence-transformers/LaBSE')

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Clean 
objects_to_delete = ['model', 'train_loss', 'evaluator', 'train_dataloader', 'trainer']

for obj in objects_to_delete:
    if obj in globals():
        del globals()[obj]

# 2. Wymuś Garbage Collection w Pythonie (czyści RAM CPU)
gc.collect()

# 3. Wyczyść cache alokatora PyTorch (czyści RAM GPU)
torch.cuda.empty_cache()

# 4. Sprawdź dostępną pamięć
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"Reserved:  {torch.cuda.memory_reserved() / 1024**3:.2f} GB")


# --- 1. PRZYGOTOWANIE DANYCH ---

In [2]:
df_train_orig = pd.read_parquet("train_df.parquet")
df_test_orig = pd.read_parquet("real_test_df.parquet")

In [3]:
df_train_orig.info()

<class 'pandas.core.frame.DataFrame'>
Index: 1000 entries, 9901 to 8380
Data columns (total 22 columns):
 #   Column           Non-Null Count  Dtype  
---  ------           --------------  -----  
 0   match_count_m0   1000 non-null   float64
 1   match_count_m1   1000 non-null   float64
 2   match_count_m2   1000 non-null   float64
 3   match_count_m3   1000 non-null   float64
 4   match_count_m4   1000 non-null   float64
 5   match_count_m5   1000 non-null   float64
 6   match_count_m6   1000 non-null   float64
 7   match_count_m7   1000 non-null   float64
 8   match_count_m8   1000 non-null   float64
 9   match_count_m9   1000 non-null   float64
 10  match_count_m10  1000 non-null   float64
 11  match_count_m11  1000 non-null   float64
 12  match_count_m12  1000 non-null   float64
 13  index            1000 non-null   float64
 14  label_name_list  1000 non-null   object 
 15  person           1000 non-null   object 
 16  plLabel          1000 non-null   object 
 17  alias           

In [4]:
df_train_orig

Unnamed: 0,match_count_m0,match_count_m1,match_count_m2,match_count_m3,match_count_m4,match_count_m5,match_count_m6,match_count_m7,match_count_m8,match_count_m9,...,match_count_m12,index,label_name_list,person,plLabel,alias,alias_name_list,truth,len_label,len_alias
9901,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,...,0.666667,1518.0,"[Andrzej, Gawroński]",Q9152047,Andrzej Gawroński,Andrzej Antoni Ignacy Gawroński,"[Andrzej, Antoni, Ignacy, Gawroński]",True,2,4
22138,0.600000,0.600000,0.600000,0.600000,0.600000,0.600000,0.600000,0.600000,0.600000,0.600000,...,0.600000,4089.0,"[Michał, Hieronim, Czacki]",Q16576655,Michał Hieronim Czacki,Michał Hieronim Czacki z Czacza h. Świnka,"[Michał, Hieronim, Czacki, z, Czacza, h., Świnka]",True,3,7
5665,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,...,0.666667,850.0,"[Marcin, Bielski]",Q516773,Marcin Bielski,Marcin Bielski h. Prawdzic,"[Marcin, Bielski, h., Prawdzic]",True,2,4
16889,0.666667,0.833333,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,...,0.666667,3047.0,"[Jan, Kazimierz, de, Alten, Bokum]",Q11717974,Jan Kazimierz de Alten Bokum,Jan Kazimierz Bokum ab Alten h. Kuszaba,"[Jan, Kazimierz, Bokum, ab, Alten, h., Kuszaba]",True,5,7
53169,0.857143,0.857143,0.857143,0.857143,0.857143,0.857143,0.857143,0.857143,0.857143,0.857143,...,0.857143,11060.0,"[Fryderyk, Karol, Hauke]",Q8861642,Fryderyk Karol Hauke,Fryderyk Karol Emanuel Hauke,"[Fryderyk, Karol, Emanuel, Hauke]",True,3,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
28139,0.400000,0.666667,0.400000,0.400000,0.400000,0.400000,0.400000,0.400000,0.400000,0.400000,...,0.400000,5944.0,"[Jan, Pac]",Q14527751,Jan Pac,Jan Vermeer van Delft,"[Jan, Vermeer, van, Delft]",False,2,4
4153,0.400000,0.400000,0.400000,0.400000,0.666667,0.666667,0.400000,0.400000,0.400000,0.400000,...,0.400000,616.0,"[Adam, Malczewski]",Q9139955,Adam Malczewski,Adam Piotr Celestyn Kaliszewski,"[Adam, Piotr, Celestyn, Kaliszewski]",False,2,4
2834,0.400000,0.666667,0.400000,0.400000,0.666667,0.666667,0.666667,0.666667,0.666667,0.400000,...,0.666667,416.0,"[Jean, Baptiste, Pillement]",Q374713,Jean-Baptiste Pillement,Jan Baptysta Gisleni,"[Jan, Baptysta, Gisleni]",False,3,3
63521,0.400000,0.400000,0.400000,0.400000,0.400000,0.400000,0.400000,0.666667,0.400000,0.400000,...,0.400000,13661.0,"[Jan, Žižka]",Q215968,Jan Žižka,Janicz,[Janicz],False,2,1


In [5]:
df_test_orig

Unnamed: 0,match_count_m0,match_count_m1,match_count_m2,match_count_m3,match_count_m4,match_count_m5,match_count_m6,match_count_m7,match_count_m8,match_count_m10,match_count_m12,index,label_name_list,person,plLabel,alias,alias_name_list,truth,len_label,len_alias
19942,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,3490.0,"[Aleksander, Cetner]",Q17095539,Aleksander Cetner,Aleksander Cetner h. Przerowa,"[Aleksander, Cetner, h., Przerowa]",True,2,4
9752,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,1487.0,"[Piotr, Gamrat]",Q4108928,Piotr Gamrat,Piotr Gamrat z Sowoklęsk,"[Piotr, Gamrat, z, Sowoklęsk]",True,2,4
22999,0.400000,0.400000,0.400000,0.400000,0.400000,0.400000,0.400000,0.400000,0.400000,0.400000,0.800000,4451.0,"[Michał, Bartłomiej, Tarło]",Q11778607,Michał Bartłomiej Tarło,Mikołaj Tarło,"[Mikołaj, Tarło]",True,3,2
76710,0.888889,0.888889,0.400000,0.888889,0.888889,0.888889,0.888889,0.888889,0.666667,0.666667,0.888889,17039.0,"[Johann, Hartwig, Ernst, von, Bernstorff]",Q70887,Johann Hartwig Ernst von Bernstorff,Johan Hartvig Ernst Bernstorff,"[Johan, Hartvig, Ernst, Bernstorff]",True,5,4
70588,0.800000,0.800000,0.800000,0.800000,0.800000,0.800000,0.800000,0.800000,0.800000,0.800000,0.800000,15371.0,"[Nicolas, Appert]",Q39637,Nicolas Appert,Nicolas François Appert,"[Nicolas, François, Appert]",True,2,3
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
34411,1.000000,1.000000,0.400000,0.400000,1.000000,1.000000,0.400000,1.000000,1.000000,1.000000,1.000000,8282.0,"[Rene, Andegaweński]",Q170353,Rene Andegaweński,René Andegaweński,"[René, Andegaweński]",False,2,2
17017,0.800000,0.800000,0.800000,0.800000,0.800000,0.800000,0.800000,0.800000,0.800000,0.800000,0.800000,3054.0,"[Jan, Kościelecki]",Q11718120,Jan Kościelecki,Jan II Kościelecki,"[Jan, II, Kościelecki]",False,2,3
35263,0.666667,0.666667,0.400000,0.666667,0.666667,0.666667,0.666667,0.666667,0.400000,0.400000,0.666667,8413.0,"[Maria, Ludwika, Burbon]",Q229826,Maria Ludwika Burbon,Ludwik de Burbon,"[Ludwik, de, Burbon]",False,3,3
67326,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,0.666667,14698.0,"[Jean, Baptiste, Kléber]",Q319302,Jean-Baptiste Kléber,Jean Baptiste Pigalle,"[Jean, Baptiste, Pigalle]",False,3,3


## Opcja B (Zalecana, jeśli chcesz więcej danych): Złącz wszystko i wydziel nowy OOS


In [6]:
df_full = pd.concat([df_train_orig, df_test_orig]).reset_index(drop=True)
print(f"Liczba wierszy przed czyszczeniem: {len(df_full)}")

Liczba wierszy przed czyszczeniem: 80400


In [7]:
# 2. Definicja kolumn, które decydują o byciu duplikatem
# Dla modelu LaBSE kluczowa jest unikalność pary tekstów wejściowych.
duplicate_subset = ['plLabel', 'alias']

# 3. Usuwanie duplikatów
# keep='first' zachowuje pierwsze wystąpienie, a usuwa kolejne powtórzenia.
# inplace=False (domyślnie) zwraca nowy DataFrame.
df_full_clean = df_full.drop_duplicates(subset=duplicate_subset, keep='first')

# 4. Resetowanie indeksu (BARDZO WAŻNE dla KFold)
# drop=True usuwa starą kolumnę 'index', żeby nie śmieciła.
df_full_clean = df_full_clean.reset_index(drop=True)

print(f"Liczba wierszy po usunięciu duplikatów: {len(df_full_clean)}")
print(f"Usunięto {len(df_full) - len(df_full_clean)} duplikatów.")

Liczba wierszy po usunięciu duplikatów: 48403
Usunięto 31997 duplikatów.


In [8]:
# 1. Główny podział (Twoje standardowe 85% / 15%)
df_cv_pool_full, df_oos_original = train_test_split(
    df_full_clean, 
    test_size=0.15, 
    random_state=RANDOM_SEED, 
    stratify=df_full_clean['truth']
)

print(f"Dane do Cross-Validation: {len(df_cv_pool_full)}")
print(f"Dane do OOS Evaluation: {len(df_oos_original)}")

# 2. Podział df_cv_pool_full na "mały wycinek do treningu" i "resztę"
# train_size=0.05 -> to będzie Twoje 5% do pętli CV
# reszta -> to będzie Twój gigantyczny zbiór walidacyjny
df_cv_pool_test, df_rest = train_test_split(
    df_cv_pool_full,
    train_size=0.0304, 
    random_state=RANDOM_SEED,
    stratify=df_cv_pool_full['truth']
)

print(f"Dane do SZYBKIEGO TESTU Cross-Validation: {len(df_cv_pool_test)}")

Dane do Cross-Validation: 41142
Dane do OOS Evaluation: 7261
Dane do SZYBKIEGO TESTU Cross-Validation: 1250


In [9]:
def compute_metrics(model, df, thresholds):
    """
    Liczy metryki dla danego modelu i DataFrame'u dla listy progów.
    """
    # Generowanie embeddingów
    sentences1 = df['plLabel'].tolist()
    sentences2 = df['alias'].tolist()
    gt_labels = [1 if x else 0 for x in df['truth'].tolist()] # Zakładamy truth jako bool lub 0/1

    # Encode (batchowanie dla szybkości)
    embs1 = model.encode(sentences1, convert_to_tensor=True, show_progress_bar=False)
    embs2 = model.encode(sentences2, convert_to_tensor=True, show_progress_bar=False)

    # Cosine similarity
    cosine_scores = util.cos_sim(embs1, embs2).diagonal().cpu().numpy()
    
    results = {}
    for t in thresholds:
        pred_labels = (cosine_scores >= t).astype(int)
        
        # Obliczamy metryki
        f2 = fbeta_score(gt_labels, pred_labels, beta=2, zero_division=0)
        acc = accuracy_score(gt_labels, pred_labels)
        prec = precision_score(gt_labels, pred_labels, zero_division=0)
        rec = recall_score(gt_labels, pred_labels, zero_division=0)
        
        results[t] = {
            'f2': f2, 
            'acc': acc, 
            'prec': prec, 
            'rec': rec
        }
    return results

# --- 3. GŁÓWNA PĘTLA CROSS-VALIDATION ---

In [10]:
kf = KFold(n_splits=N_FOLDS, shuffle=True, random_state=42)

# Struktura do przechowywania wyników z każdego folda
# fold_results = [
#    { 0.35: {'val_f2': ..., 'oos_f2': ...}, 0.40: {...} },  <-- Fold 1
#    { ... }                                                 <-- Fold 2
# ]
all_folds_data = []

# SMALL
final_df_cv_pool = df_cv_pool_test; df_oos = df_rest.reset_index(drop=True)

# BIG
# final_df_cv_pool = df_cv_pool_full; df_oos = df_cv_pool_full


for fold_idx, (train_idx, val_idx) in enumerate(kf.split(final_df_cv_pool)):
    print(f"\n--- FOLD {fold_idx+1}/{N_FOLDS} ---")
    
    # Podział danych
    train_df = final_df_cv_pool.iloc[train_idx]
    val_df = final_df_cv_pool.iloc[val_idx]
    
    # Przygotowanie treningu
    train_examples = []
    for _, row in train_df.iterrows():
        label = 0.9 if row['truth'] else 0.1 # Soft labels dla LaBSE
        train_examples.append(InputExample(texts=[row['plLabel'], row['alias']], label=label))
    
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)
    
    # Inicjalizacja modelu
    model = SentenceTransformer('sentence-transformers/LaBSE')
    train_loss = losses.CosineSimilarityLoss(model=model)
    
    # Trening (krótki, np. 1-3 epoki)
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=EPOCHS,
        warmup_steps=100,
        show_progress_bar=True
    )
    
    # --- EWALUACJA ---
    print("Ewaluacja Validation Subset...")
    val_metrics = compute_metrics(model, val_df, THRESHOLDS)
    
    print("Ewaluacja OOS Pool...")
    oos_metrics = compute_metrics(model, df_oos, THRESHOLDS)
    
    # Zbieranie wyników dla tego folda
    fold_data = {}
    for t in THRESHOLDS:
        fold_data[t] = {
            # Validation Subset (wg Twojego obrazka: F2, Accuracy)
            'val_f2': val_metrics[t]['f2'],
            'val_acc': val_metrics[t]['acc'],
            
            # OOS Pool (wg obrazka: Precision, Recall, F2)
            'oos_prec': oos_metrics[t]['prec'],
            'oos_rec': oos_metrics[t]['rec'],
            'oos_f2': oos_metrics[t]['f2']
        }
    all_folds_data.append(fold_data)


--- FOLD 1/5 ---


Currently using DataParallel (DP) for multi-gpu training, while DistributedDataParallel (DDP) is recommended for faster training. See https://sbert.net/docs/sentence_transformer/training/distributed.html for more information.
                                                                     

Step,Training Loss


Ewaluacja Validation Subset...
Ewaluacja OOS Pool...

--- FOLD 2/5 ---


Currently using DataParallel (DP) for multi-gpu training, while DistributedDataParallel (DDP) is recommended for faster training. See https://sbert.net/docs/sentence_transformer/training/distributed.html for more information.
                                                                     

Step,Training Loss


Ewaluacja Validation Subset...
Ewaluacja OOS Pool...

--- FOLD 3/5 ---


Currently using DataParallel (DP) for multi-gpu training, while DistributedDataParallel (DDP) is recommended for faster training. See https://sbert.net/docs/sentence_transformer/training/distributed.html for more information.
                                                                     

Step,Training Loss


Ewaluacja Validation Subset...
Ewaluacja OOS Pool...

--- FOLD 4/5 ---


Currently using DataParallel (DP) for multi-gpu training, while DistributedDataParallel (DDP) is recommended for faster training. See https://sbert.net/docs/sentence_transformer/training/distributed.html for more information.
                                                                     

Step,Training Loss


Ewaluacja Validation Subset...
Ewaluacja OOS Pool...

--- FOLD 5/5 ---


Currently using DataParallel (DP) for multi-gpu training, while DistributedDataParallel (DDP) is recommended for faster training. See https://sbert.net/docs/sentence_transformer/training/distributed.html for more information.
                                                                     

Step,Training Loss


Ewaluacja Validation Subset...
Ewaluacja OOS Pool...


# --- 4. AGREGACJA WYNIKÓW I TWORZENIE TABELI ---

In [11]:
final_summary = []

for t in THRESHOLDS:
    # Wyciągnij wyniki dla danego progu ze wszystkich foldów
    metrics_at_t = [fold[t] for fold in all_folds_data]
    
    # Konwersja na DataFrame dla łatwego liczenia średniej/std
    df_metrics = pd.DataFrame(metrics_at_t)
    
    row = {'Threshold': t}
    
    # Dla każdej kolumny policz "mean ± std"
    for col in df_metrics.columns:
        mean = df_metrics[col].mean()
        std = df_metrics[col].std()
        row[col] = f"{mean:.3f} ± {std:.3f}"
        
    final_summary.append(row)

df_final_table = pd.DataFrame(final_summary)
df_final_table.set_index('Threshold', inplace=True)


In [12]:
# Wyświetlanie (transponowane, żeby przypominało obrazek, jeśli wolisz)
print("\n=== FINAL CROSS-VALIDATION TABLE ===")
display(df_final_table.T) # T da Ci metryki w wierszach, progi w kolumnach


=== FINAL CROSS-VALIDATION TABLE ===


Threshold,0.35,0.40,0.45,0.50,0.55,0.60,0.65,0.70,0.75,0.80,0.85,0.90,0.95
val_f2,0.732 ± 0.021,0.769 ± 0.021,0.804 ± 0.008,0.812 ± 0.021,0.820 ± 0.038,0.813 ± 0.046,0.774 ± 0.044,0.710 ± 0.060,0.639 ± 0.098,0.540 ± 0.162,0.408 ± 0.133,0.318 ± 0.081,0.099 ± 0.045
val_acc,0.763 ± 0.018,0.814 ± 0.024,0.855 ± 0.012,0.886 ± 0.007,0.906 ± 0.011,0.920 ± 0.012,0.923 ± 0.013,0.922 ± 0.016,0.921 ± 0.017,0.915 ± 0.028,0.902 ± 0.026,0.892 ± 0.021,0.868 ± 0.017
oos_prec,0.383 ± 0.004,0.448 ± 0.005,0.517 ± 0.005,0.587 ± 0.005,0.653 ± 0.004,0.710 ± 0.005,0.752 ± 0.007,0.789 ± 0.007,0.822 ± 0.010,0.853 ± 0.008,0.877 ± 0.009,0.898 ± 0.007,0.895 ± 0.003
oos_rec,0.982 ± 0.001,0.972 ± 0.002,0.958 ± 0.002,0.938 ± 0.003,0.912 ± 0.004,0.875 ± 0.006,0.820 ± 0.010,0.740 ± 0.012,0.638 ± 0.011,0.519 ± 0.011,0.393 ± 0.010,0.248 ± 0.011,0.101 ± 0.005
oos_f2,0.748 ± 0.003,0.787 ± 0.003,0.818 ± 0.002,0.838 ± 0.003,0.845 ± 0.002,0.836 ± 0.004,0.805 ± 0.007,0.749 ± 0.009,0.668 ± 0.009,0.563 ± 0.010,0.441 ± 0.009,0.290 ± 0.011,0.122 ± 0.006


In [13]:
filename_base = "labse_smaller_finetune_results"

# 1. Zapis do CSV (dla łatwego podglądu)
csv_path = f"output/{filename_base}.csv"
# Upewnij się, że katalog output istnieje
os.makedirs("output", exist_ok=True)

df_final_table.to_csv(csv_path)
print(f"Zapisano tabelę (CSV) do: {csv_path}")

# 2. Zapis do Parquet (dla zachowania precyzji i typów)
parquet_path = f"output/{filename_base}.parquet"
df_final_table.to_parquet(parquet_path)
print(f"Zapisano tabelę (Parquet) do: {parquet_path}")

Zapisano tabelę (CSV) do: output/labse_smaller_finetune_results.csv
Zapisano tabelę (Parquet) do: output/labse_smaller_finetune_results.parquet


# Export LATEX results

In [14]:
def export_latex_table(
    df_results: pd.DataFrame, 
    num_folds: int, 
    model_name: str, 
    selected_thresholds: Optional[List[float]] = None,
    label_suffix: str = "model"
) -> None:
    """
    Generates and prints a LaTeX table code snippet from a results DataFrame, 
    formatted specifically for academic papers. Allows filtering by specific thresholds.

    Args:
        df_results (pd.DataFrame): DataFrame containing result strings (e.g., "0.850 ± 0.010").
                                   Expected index: Thresholds (float).
        num_folds (int): Number of folds used in cross-validation.
        model_name (str): Name of the model.
        selected_thresholds (List[float], optional): List of thresholds to include in the table.
                                                     If None, all thresholds from df_results are used.
        label_suffix (str): Suffix for the LaTeX label.

    Raises:
        ValueError: If any of the selected_thresholds are not present in df_results index.
    """
    
    # 1. Filter Thresholds if specific list provided
    if selected_thresholds is not None:
        # Check for missing thresholds
        missing = [t for t in selected_thresholds if t not in df_results.index]
        if missing:
            available = df_results.index.tolist()
            raise ValueError(
                f"Thresholds {missing} not found in results.\n"
                f"Available thresholds are: {available}"
            )
        
        # Filter the DataFrame ensuring the order matches selected_thresholds
        df_filtered = df_results.loc[selected_thresholds]
    else:
        df_filtered = df_results

    # 2. Transpose: Thresholds become columns, Metrics become rows
    df_T = df_filtered.T
    
    # 3. Define Metric Mapping
    metric_map = {
        'val_f2': 'F$_2$ score',
        'val_acc': 'Accuracy',
        'oos_prec': 'Precision',
        'oos_rec': 'Recall',
        'oos_f2': 'F$_2$ score'
    }
    
    # 4. Helper function to format a row
    def format_row_content(metric_key, row_series):
        latex_cells = []
        
        # Parse values to find max for bolding (only among selected thresholds)
        values_dict = {}
        for col_idx, val_str in row_series.items():
            clean_val = str(val_str).replace('±', '').replace('$\pm$', '')
            try:
                mean_val = float(clean_val.split()[0])
                values_dict[col_idx] = mean_val
            except (ValueError, IndexError):
                values_dict[col_idx] = -1.0

        max_val = max(values_dict.values()) if values_dict else 0
        
        for col_idx, val_str in row_series.items():
            formatted_str = str(val_str).replace('±', r'$\pm$')
            
            # Bold if it's the max value in this specific view
            if values_dict.get(col_idx) == max_val and max_val > 0:
                formatted_str = f"\\textbf{{{formatted_str}}}"
            
            latex_cells.append(formatted_str)
            
        return f"{metric_map.get(metric_key, metric_key)} & " + " & ".join(latex_cells) + " \\\\"

    # 5. Prepare Headers
    thresholds = df_T.columns.tolist()
    header_cols = " & ".join([f"{t:.2f}" for t in thresholds])
    col_def = "c" * len(thresholds)
    
    # 6. Build LaTeX
    latex_code = [
        r"\begin{table*}[!ht]",
        r"    \centering",
        f"    \\caption{{{model_name}: cross-validated performance \\\\for different decision thresholds (mean $\\pm$ standard deviation over {num_folds} folds).}}",
        f"    \\label{{tab:{label_suffix}_thresholds}}",
        f"    \\begin{{tabular}}{{l{col_def}}}",
        r"        \hline",
        f"        & \\multicolumn{{{len(thresholds)}}}{{c}}{{Decision threshold $\\tau$}} \\\\",
        f"        \\cline{{2-{len(thresholds)+1}}}",
        f"        Metric & {header_cols} \\\\",
        r"        \hline",
        r"        \multicolumn{" + str(len(thresholds)+1) + r"}{l}{\textit{Validation subset}} \\"
    ]

    for metric in ['val_f2', 'val_acc']:
        if metric in df_T.index:
            latex_code.append("        " + format_row_content(metric, df_T.loc[metric]))

    latex_code.append(r"        \hline")
    latex_code.append(r"        \multicolumn{" + str(len(thresholds)+1) + r"}{l}{\textit{Out-of-sample evaluation pool}} \\")

    for metric in ['oos_prec', 'oos_rec', 'oos_f2']:
        if metric in df_T.index:
            latex_code.append("        " + format_row_content(metric, df_T.loc[metric]))

    latex_code.extend([
        r"        \hline",
        r"    \end{tabular}",
        r"\end{table*}"
    ])

    print("\n".join(latex_code))

def export_latex_table(
    df_results: pd.DataFrame, 
    num_folds: int, 
    model_name: str, 
    selected_thresholds: Optional[List[float]] = None,
    label_suffix: str = "model"
) -> None:
    """
    Generates and prints a LaTeX table code snippet from a results DataFrame, 
    formatted specifically for academic papers. Allows filtering by specific thresholds.

    Args:
        df_results (pd.DataFrame): DataFrame containing result strings (e.g., "0.850 ± 0.010").
                                   Expected index: Thresholds (float).
        num_folds (int): Number of folds used in cross-validation.
        model_name (str): Name of the model.
        selected_thresholds (List[float], optional): List of thresholds to include in the table.
                                                     If None, all thresholds from df_results are used.
        label_suffix (str): Suffix for the LaTeX label.

    Raises:
        ValueError: If any of the selected_thresholds are not present in df_results index.
    """
    
    # 1. Filter Thresholds if specific list provided
    if selected_thresholds is not None:
        # Check for missing thresholds
        missing = [t for t in selected_thresholds if t not in df_results.index]
        if missing:
            available = df_results.index.tolist()
            raise ValueError(
                f"Thresholds {missing} not found in results.\n"
                f"Available thresholds are: {available}"
            )
        
        # Filter the DataFrame ensuring the order matches selected_thresholds
        df_filtered = df_results.loc[selected_thresholds]
    else:
        df_filtered = df_results

    # 2. Transpose: Thresholds become columns, Metrics become rows
    df_T = df_filtered.T
    
    # 3. Define Metric Mapping
    metric_map = {
        'val_f2': 'F$_2$ score',
        'val_acc': 'Accuracy',
        'oos_prec': 'Precision',
        'oos_rec': 'Recall',
        'oos_f2': 'F$_2$ score'
    }
    
    # 4. Helper function to format a row
    def format_row_content(metric_key, row_series):
        latex_cells = []
        
        # Parse values to find max for bolding (only among selected thresholds)
        values_dict = {}
        for col_idx, val_str in row_series.items():
            clean_val = str(val_str).replace('±', '').replace('$\pm$', '')
            try:
                mean_val = float(clean_val.split()[0])
                values_dict[col_idx] = mean_val
            except (ValueError, IndexError):
                values_dict[col_idx] = -1.0

        max_val = max(values_dict.values()) if values_dict else 0
        
        for col_idx, val_str in row_series.items():
            formatted_str = str(val_str).replace('±', r'$\pm$')
            
            # Bold if it's the max value in this specific view
            if values_dict.get(col_idx) == max_val and max_val > 0:
                formatted_str = f"\\textbf{{{formatted_str}}}"
            
            latex_cells.append(formatted_str)
            
        return f"{metric_map.get(metric_key, metric_key)} & " + " & ".join(latex_cells) + " \\\\"

    # 5. Prepare Headers
    thresholds = df_T.columns.tolist()
    header_cols = " & ".join([f"{t:.2f}" for t in thresholds])
    col_def = "c" * len(thresholds)
    
    # 6. Build LaTeX
    latex_code = [
        r"\begin{table*}[!ht]",
        r"    \centering",
        f"    \\caption{{{model_name}: cross-validated performance \\\\for different decision thresholds (mean $\\pm$ standard deviation over {num_folds} folds).}}",
        f"    \\label{{tab:{label_suffix}_thresholds}}",
        f"    \\begin{{tabular}}{{l{col_def}}}",
        r"        \hline",
        f"        & \\multicolumn{{{len(thresholds)}}}{{c}}{{Decision threshold $\\tau$}} \\\\",
        f"        \\cline{{2-{len(thresholds)+1}}}",
        f"        Metric & {header_cols} \\\\",
        r"        \hline",
        r"        \multicolumn{" + str(len(thresholds)+1) + r"}{l}{\textit{Validation subset}} \\"
    ]

    for metric in ['val_f2', 'val_acc']:
        if metric in df_T.index:
            latex_code.append("        " + format_row_content(metric, df_T.loc[metric]))

    latex_code.append(r"        \hline")
    latex_code.append(r"        \multicolumn{" + str(len(thresholds)+1) + r"}{l}{\textit{Out-of-sample evaluation pool}} \\")

    for metric in ['oos_prec', 'oos_rec', 'oos_f2']:
        if metric in df_T.index:
            latex_code.append("        " + format_row_content(metric, df_T.loc[metric]))

    latex_code.extend([
        r"        \hline",
        r"    \end{tabular}",
        r"\end{table*}"
    ])

    print("\n".join(latex_code))

In [15]:
export_latex_table(
    df_results=df_final_table, 
    num_folds=N_FOLDS, 
    model_name="LaBSE", 
    label_suffix="labse_smaller_finetune",
    selected_thresholds=[0.35, 0.4, 0.45, 0.5, ] 
)

\begin{table*}[!ht]
    \centering
    \caption{LaBSE: cross-validated performance \\for different decision thresholds (mean $\pm$ standard deviation over 5 folds).}
    \label{tab:labse_smaller_finetune_thresholds}
    \begin{tabular}{lcccc}
        \hline
        & \multicolumn{4}{c}{Decision threshold $\tau$} \\
        \cline{2-5}
        Metric & 0.35 & 0.40 & 0.45 & 0.50 \\
        \hline
        \multicolumn{5}{l}{\textit{Validation subset}} \\
        F$_2$ score & 0.732 $\pm$ 0.021 & 0.769 $\pm$ 0.021 & 0.804 $\pm$ 0.008 & \textbf{0.812 $\pm$ 0.021} \\
        Accuracy & 0.763 $\pm$ 0.018 & 0.814 $\pm$ 0.024 & 0.855 $\pm$ 0.012 & \textbf{0.886 $\pm$ 0.007} \\
        \hline
        \multicolumn{5}{l}{\textit{Out-of-sample evaluation pool}} \\
        Precision & 0.383 $\pm$ 0.004 & 0.448 $\pm$ 0.005 & 0.517 $\pm$ 0.005 & \textbf{0.587 $\pm$ 0.005} \\
        Recall & \textbf{0.982 $\pm$ 0.001} & 0.972 $\pm$ 0.002 & 0.958 $\pm$ 0.002 & 0.938 $\pm$ 0.003 \\
        F$_2$ score &

# ROC curve

In [16]:
# 1. Przygotuj dane (np. ze zbioru OOS Evaluation Pool)
df_eval = df_oos.copy() # Używamy Twojego zbioru "Evaluation Pool"

# 2. Wygeneruj embeddingi
embs1 = model.encode(df_eval['plLabel'].tolist(), convert_to_tensor=True, show_progress_bar=True)
embs2 = model.encode(df_eval['alias'].tolist(), convert_to_tensor=True, show_progress_bar=True)

# 3. Oblicz podobieństwo (to są Twoje y_scores!)
# util.cos_sim zwraca macierz, bierzemy przekątną (pary odpowiadające sobie wierszami)
cosine_scores = util.cos_sim(embs1, embs2).diagonal().cpu().numpy()

# 4. Przygotuj etykiety (ground truth)
y_true = df_eval['truth'].astype(int).to_numpy() # Upewnij się, że to 0 i 1


Batches:   0%|          | 0/1247 [00:00<?, ?it/s]

Batches: 100%|██████████| 1247/1247 [00:11<00:00, 112.22it/s]
Batches: 100%|██████████| 1247/1247 [00:09<00:00, 131.50it/s]


In [18]:
def plot_roc_curve_scores(
        y_true: np.ndarray,
        y_scores: np.ndarray,
        title: str = "ROC curve",
        save_path: Optional[Path] = None,
        mark_threshold: Optional[float] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
    """
    Plot ROC curve based on pre-calculated scores (e.g., cosine similarity) 
    and optionally save to file. Maintains exact visual consistency with previous plots.
    Plot ROC curve using Seaborn's 'darkgrid' style to exactly match the reference chart.

    Args:
        y_true: Ground truth binary labels (0 or 1).
        y_scores: Continuous scores (e.g., probabilities or cosine similarity).
        title: Chart title.
        save_path: Path to save the PDF/PNG.
        mark_threshold: Specific threshold value to mark on the curve with a dot.

    Returns:
        fpr, tpr, thresholds, auc_value
    """
    fpr, tpr, thresholds = roc_curve(y_true, y_scores)
    auc_value = auc(fpr, tpr)

    # 1. Styl siatki (ustalony w poprzednim kroku - zostawiamy 0.5 width)
    custom_dotted_style = (0, (1, 1.5))
    sns.set_theme(
        style="darkgrid", 
        rc={
            "grid.linestyle": custom_dotted_style, 
            "grid.linewidth": 0.5,   # Twoje ustalone  
            "grid.color": "white",   
            "axes.edgecolor": "white"
        }
    )

    fig, ax = plt.subplots(figsize=(4.0, 4.0), dpi=300)

    # 2. Rysowanie danych
    ax.plot(fpr, tpr, color='C0', lw=2, label=f"AUC = {auc_value:.3f}")
    ax.plot([0, 1], [0, 1], color='C1', linestyle='--', lw=1, label="Random")

    if mark_threshold is not None:
        idx = np.argmin(np.abs(thresholds - mark_threshold))
        ax.scatter(
            fpr[idx],
            tpr[idx],
            s=25,
            color='C0',
            marker='o',
            zorder=5,
            label=f"Threshold = {mark_threshold:.2f}"
        )

    # --- 3. KOSMETYKA CZCIONEK (Zmniejszanie) ---
    
    # Tytuł: nieco mniejszy (zazwyczaj jest 12-14, dajemy 12)
    ax.set_title(title, fontsize=11.5) 
    
    # Etykiety osi: wyraźnie mniejsze niż teraz (dajemy 11)
    ax.set_xlabel("False positive rate", fontsize=11.5)
    ax.set_ylabel("True positive rate", fontsize=11.5)
    
    # Ticki (liczby 0.0, 0.2...): mniejsze (dajemy 10)
    ax.tick_params(axis='both', which='major', labelsize=11)
    
    # Legenda: też mniejsza (dajemy 10)
    ax.legend(loc="lower right", fontsize=10.5)
    
    # ---------------------------------------------

    ax.set_xlim(-0.0, 1.0)
    ax.set_ylim(-0.0, 1.02)
    ax.set_aspect("equal", "box")

    fig.tight_layout()

    if save_path is not None:
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, bbox_inches="tight")
        print(f"Wykres zapisany do: {save_path}")
        plt.close(fig)
    else:
        plt.show()

    return fpr, tpr, thresholds, auc_value

In [19]:
# 5. Rysuj wykres
fig_path_labse = Path("figures/roc_labse_similarity_smaller_tuning.pdf")
best_threshold = 0.50 # Twój wybrany próg (np. z tabeli LaTeX)

fpr, tpr, ths, auc_val = plot_roc_curve_scores(
    y_true=y_true,
    y_scores=cosine_scores,
    title="LaBSE - ROC curve",
    save_path=fig_path_labse,
    mark_threshold=best_threshold
)

print(f"LaBSE AUC: {auc_val:.3f}")

Wykres zapisany do: figures/roc_labse_similarity_smaller_tuning.pdf
LaBSE AUC: 0.966


# END