# SBERT Unfrozen Inference

Lade das beste gespeicherte Modell und teste es gegen verschiedene Datensätze. Ändere dazu einfach `DATA_PATH` pro Durchlauf.


In [11]:
import os
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer, InputExample
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f'Using device: {DEVICE}')


Using device: cuda


In [12]:
# Parameters: everything that is meant to be tuned between runs lives here.
BATCH_SIZE = 32
MODEL_DIR = Path('../models/sbert_unfrozen_best_http')
DATA_PATH = Path('../data/ls_translations_labeled.csv')  # <- change the dataset path here

# Stop right away when the saved model directory cannot be found.
if not MODEL_DIR.exists():
    raise FileNotFoundError(f'Model dir not found: {MODEL_DIR}')


In [13]:
# Load the saved SentenceTransformer encoder and its linear classification head.
model = SentenceTransformer(str(MODEL_DIR), device=DEVICE)
classifier_path = MODEL_DIR / 'classifier.pt'
if not classifier_path.exists():
    raise FileNotFoundError(f'Missing classifier at {classifier_path}')

# The head maps one sentence embedding to two logits.
embed_dim = model.get_sentence_embedding_dimension()
classifier = torch.nn.Linear(embed_dim, 2).to(DEVICE)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while it is loaded.
classifier.load_state_dict(
    torch.load(classifier_path, map_location=DEVICE, weights_only=True)
)
# Switch both modules to inference mode (e.g. disables dropout); the final
# expression is kept last so the cell still displays the encoder architecture.
classifier.eval()
model.eval()


SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False, 'architecture': 'BertModel'})
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)

In [14]:
# Read the labelled CSV and normalise the two columns used for evaluation.
if not DATA_PATH.exists():
    raise FileNotFoundError(f'Missing dataset at {DATA_PATH}')

df = pd.read_csv(DATA_PATH)
# Cast both columns in a single step: text as str, labels as int.
df = df.assign(
    sentence=df['sentence'].astype(str),
    label=df['label'].astype(int),
)
# Relative class frequencies, to put the accuracy figures into context.
print(df['label'].value_counts(normalize=True))

sentences, labels = df['sentence'].tolist(), df['label'].tolist()


label
1    0.798978
0    0.201022
Name: proportion, dtype: float64


In [15]:
def collate_fn(batch):
    """Convert a list of examples into tokenized features and a label tensor.

    For every example, the first entry of ``texts`` is used when ``texts`` is a
    list or tuple; otherwise ``texts`` itself is taken as the sentence.

    Returns:
        A ``(features, labels)`` pair: ``features`` is the result of
        ``model.tokenize`` on the batch sentences and ``labels`` is a
        ``torch.long`` tensor holding one label per example.
    """
    batch_texts = []
    for example in batch:
        raw_texts = example.texts
        if isinstance(raw_texts, (list, tuple)):
            batch_texts.append(raw_texts[0])
        else:
            batch_texts.append(raw_texts)

    targets = torch.tensor([int(example.label) for example in batch], dtype=torch.long)
    tokenized = model.tokenize(batch_texts)
    return tokenized, targets

def forward_batch(features):
    """Return classifier logits for one tokenized batch.

    Every value in ``features`` is moved to ``DEVICE``, the encoder is run on
    the result, and its ``'sentence_embedding'`` output is passed through the
    classifier.
    """
    on_device = {}
    for name, tensor in features.items():
        on_device[name] = tensor.to(DEVICE)

    sentence_embeddings = model(on_device)['sentence_embedding']
    return classifier(sentence_embeddings)

def eval_dataset(sentences, labels):
    """Evaluate the encoder + classifier on labelled sentences and print metrics.

    Args:
        sentences: Iterable of input texts.
        labels: Iterable of integer-convertible class labels, one per sentence.

    Returns:
        None. Accuracy, balanced accuracy and the classification report are
        printed.

    Raises:
        ValueError: If the inputs differ in length or contain no examples.
    """
    # Materialise first so generators work and both lengths can be compared.
    sentences = list(sentences)
    labels = list(labels)
    if len(sentences) != len(labels):
        # zip() would silently drop the surplus items and skew the metrics.
        raise ValueError(
            f'sentences and labels differ in length: {len(sentences)} vs {len(labels)}'
        )
    if not sentences:
        # torch.cat() on an empty list fails with a much less helpful message.
        raise ValueError('Cannot evaluate an empty dataset')

    examples = [InputExample(texts=[s], label=int(l)) for s, l in zip(sentences, labels)]
    # shuffle=False keeps predictions aligned with the input order.
    loader = DataLoader(examples, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

    batch_logits = []
    batch_targets = []
    with torch.no_grad():
        for features, batch_labels in loader:
            # Move logits to the CPU right away so results do not pile up on the device.
            batch_logits.append(forward_batch(features).cpu())
            batch_targets.append(batch_labels)

    # Predicted class = index of the largest logit; convert to NumPy once.
    y_pred = torch.cat(batch_logits).argmax(dim=1).numpy()
    y_true = torch.cat(batch_targets).numpy()

    acc = accuracy_score(y_true, y_pred)
    bal = balanced_accuracy_score(y_true, y_pred)

    print(f'Accuracy: {acc:.3f}')
    print(f'Balanced accuracy: {bal:.3f}')
    print('Classification report:')
    print(classification_report(y_true, y_pred, digits=3))

# Score the dataset loaded above; the metrics are printed by eval_dataset.
eval_dataset(sentences=sentences, labels=labels)


Accuracy: 0.810
Balanced accuracy: 0.703
Classification report:
              precision    recall  f1-score   support

           0      0.529     0.524     0.526      2753
           1      0.881     0.882     0.882     10942

    accuracy                          0.810     13695
   macro avg      0.705     0.703     0.704     13695
weighted avg      0.810     0.810     0.810     13695

