In [1]:
!pip install transformers datasets wandb scikit-learn matplotlib seaborn textstat huggingface_hub



In [2]:
class Config:
    model_name = 'roberta-base'
    hidden_dim = 768
    pred_depth = 3
    batch_size = 16
    max_length = 128
    seed = 11
    pred_dim = 16

In [3]:
import torch
import torch.nn as nn
from transformers import AutoModel

class JEPAEncoder(nn.Module):
    def __init__(self, model_name='roberta-base', hidden_dim=768):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        self.projection = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
        )
    def forward(self, input_ids, attention_mask):
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        return self.projection(outputs.last_hidden_state[:, 0, :])

class JEPAPredictor(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=16, output_dim=768, depth=3):
        super().__init__()
        layers = [nn.Linear(input_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU()]
        for _ in range(depth-2):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU()])
        layers.append(nn.Linear(hidden_dim, output_dim))
        self.layers = nn.Sequential(*layers)
    def forward(self, x):
        return self.layers(x)


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import torch
import torch.nn as nn
from JEPA_Models import JEPAEncoder, JEPAPredictor
import copy

class ParaJEPA(nn.Module):
    def __init__(self, model_name='roberta-base', hidden_dim=768, pred_depth=3):
        super().__init__()
        self.context_encoder = JEPAEncoder(model_name, hidden_dim)
        self.target_encoder = copy.deepcopy(self.context_encoder)
        self.predictor = JEPAPredictor(input_dim=hidden_dim, hidden_dim=128, output_dim=hidden_dim, depth=pred_depth)

    def forward(self, style_input, content_input):
        # We only need encoding for inference/probing
        context_embeddings = self.context_encoder(style_input['input_ids'], style_input['attention_mask'])
        return context_embeddings

In [5]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import textstat
import numpy as np

class ProbingDataset(Dataset):
    def __init__(self, tokenizer, max_length=128, split="validation", max_samples=1000):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Load a subset for quick probing
        self.dataset = load_dataset("GEM/wiki_auto_asset_turk", split=split)
        if max_samples:
            self.dataset = self.dataset.select(range(min(len(self.dataset), max_samples)))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # We use the 'source' (Complex) text for probing features
        text = item['source']

        # Calculate Style Features (Ground Truth)
        # 1. Length (Word count)
        length_score = len(text.split())
        # 2. Complexity (Flesch-Kincaid Grade)
        readability_score = textstat.flesch_kincaid_grade(text)

        enc = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        return {
            'input_ids': enc['input_ids'].squeeze(0),
            'attention_mask': enc['attention_mask'].squeeze(0),
            'length_score': torch.tensor(length_score, dtype=torch.float),
            'readability_score': torch.tensor(readability_score, dtype=torch.float)
        }

In [6]:
import torch
from transformers import AutoTokenizer, AutoModel
from para_jepa_train import ParaJEPA
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.metrics import r2_score
from huggingface_hub import hf_hub_download
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from para_jepa_conditional import ParaJEPAConditional

def get_embeddings_and_features(model, loader, device):
    model.eval()
    embeddings = []
    lengths = []
    readability = []

    print("Extracting embeddings and style features...")
    with torch.no_grad():
        for batch in tqdm(loader):
            input_ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)

            # Extract Content Embedding (Using context_encoder as proxy for learned rep)
            # In ParaJEPA, context_encoder is the main trainable encoder
            emb = model.context_encoder(input_ids, mask)

            embeddings.append(emb.cpu().numpy())
            lengths.append(batch['length_score'].numpy())
            readability.append(batch['readability_score'].numpy())

    return np.vstack(embeddings), np.concatenate(lengths), np.concatenate(readability)

def run_probe(X, y, feature_name):
    # Split train/test (80/20) for the probe
    split = int(len(X) * 0.8)
    X_train, X_test = X[:split], X[split:]
    y_train, y_test = y[:split], y[split:]

    # Train Linear Probe
    reg = Ridge(alpha=1.0)
    reg.fit(X_train, y_train)
    preds = reg.predict(X_test)

    # Calculate R2 Score (1.0 is perfect prediction, 0.0 is random guessing)
    score = r2_score(y_test, preds)
    return score

# --- MAIN EXECUTION ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_name = 'roberta-base'

# 1. Load Model
print("Downloading and Loading Model...")
try:
    #path = hf_hub_download(repo_id="ege-dgny/ParaJEPA-03-25-2025", filename="para_jepa_best_model.pt")
    path = "para_jepa_best_model.pt"
    model = ParaJEPA(model_name=model_name, pred_hidden_dim=128, funnel=True).to(device)
    model.load_state_dict(torch.load(path, map_location=device))
    print("Model Loaded Successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    # Initialize blank for debugging if download fails
    model = ParaJEPA(model_name=model_name).to(device)

# 2. Prepare Data
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset = ProbingDataset(tokenizer, max_samples=2000) # 2000 samples for robust stats
loader = DataLoader(dataset, batch_size=32, shuffle=True) # Shuffle to mix styles

# 3. Extract
X, y_len, y_read = get_embeddings_and_features(model, loader, device)

# 4. Run Probes
print("\n" + "="*50)
print("STYLE INVARIANCE STRESS TEST")
print("="*50)
print("Hypothesis: A perfect content embedding should NOT correlate with style features.")
print("Goal: Low R² scores (close to 0). High scores indicate 'leaking' style info.")
print("-" * 50)

# Length Probe
len_score = run_probe(X, y_len, "Sentence Length")
print(f"Probe Task: Predict Sentence Length")
print(f"R² Score: {len_score:.4f}  [{'PASSED' if len_score < 0.3 else 'FAILED'}]")
if len_score < 0.3: print("   -> Great! The model largely ignores length.")
else: print("   -> Warning: The model still encodes length information.")

print("-" * 20)

# Complexity Probe
read_score = run_probe(X, y_read, "Readability (Complexity)")
print(f"Probe Task: Predict Readability (Flesch-Kincaid)")
print(f"R² Score: {read_score:.4f}  [{'PASSED' if read_score < 0.3 else 'FAILED'}]")
if read_score < 0.3: print("   -> Great! The model abstracts away complexity.")
else: print("   -> Warning: The model still encodes complexity.")

print("="*50)

# 5. Visual Baseline Comparison (Optional)
# Compare against raw RoBERTa (untrained) to see improvement
print("\nRunning Baseline (Untrained RoBERTa) for comparison...")
raw_model = AutoModel.from_pretrained('roberta-base').to(device)
raw_model.eval()
raw_embeddings = []
with torch.no_grad():
    for batch in tqdm(loader):
        out = raw_model(batch['input_ids'].to(device), batch['attention_mask'].to(device))
        raw_embeddings.append(out.last_hidden_state[:, 0, :].cpu().numpy())
X_raw = np.vstack(raw_embeddings)

raw_len_score = run_probe(X_raw, y_len, "Length")
raw_read_score = run_probe(X_raw, y_read, "Readability")

print("\n--- COMPARISON ---")
print(f"Length Prediction R²:      ParaJEPA {len_score:.2f} vs RoBERTa {raw_len_score:.2f}")
print(f"Readability Prediction R²: ParaJEPA {read_score:.2f} vs RoBERTa {raw_read_score:.2f}")
print("Lower score for ParaJEPA compared to RoBERTa indicates successful disentanglement.")

Downloading and Loading Model...


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model Loaded Successfully.
Extracting embeddings and style features...


100%|██████████| 63/63 [00:04<00:00, 14.85it/s]



STYLE INVARIANCE STRESS TEST
Hypothesis: A perfect content embedding should NOT correlate with style features.
Goal: Low R² scores (close to 0). High scores indicate 'leaking' style info.
--------------------------------------------------
Probe Task: Predict Sentence Length
R² Score: -0.0019  [PASSED]
   -> Great! The model largely ignores length.
--------------------
Probe Task: Predict Readability (Flesch-Kincaid)
R² Score: -0.0011  [PASSED]
   -> Great! The model abstracts away complexity.

Running Baseline (Untrained RoBERTa) for comparison...


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 63/63 [00:03<00:00, 18.65it/s]



--- COMPARISON ---
Length Prediction R²:      ParaJEPA -0.00 vs RoBERTa -0.12
Readability Prediction R²: ParaJEPA -0.00 vs RoBERTa -0.06
Lower score for ParaJEPA compared to RoBERTa indicates successful disentanglement.


In [7]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from scipy.stats import spearmanr
from tqdm import tqdm
import numpy as np
from transformers import AutoTokenizer
from para_jepa_train import ParaJEPA
from huggingface_hub import hf_hub_download
from para_jepa_conditional import ParaJEPAConditional

# --- HELPER FUNCTIONS ---

def get_batch_embeddings(model, texts, tokenizer, device):
    """Encodes a list of texts into embeddings."""
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=128).to(device)
    with torch.no_grad():
        # Using context_encoder as the main representation
        embeddings = model.context_encoder(inputs['input_ids'], inputs['attention_mask'])
    return embeddings

def evaluate_stsb(model, tokenizer, device):
    """
    Evaluates Semantic Textual Similarity (STS-B).
    Computes Spearman correlation between cosine similarity of embeddings and human labels.
    """
    print("\n--- 1. Semantic Utility: STS-Benchmark (Validation) ---")
    stsb = load_dataset("glue", "stsb", split="validation")

    similarities = []
    labels = []

    batch_size = 32
    dataloader = DataLoader(stsb, batch_size=batch_size)

    print("Computing STS-B correlations...")
    for batch in tqdm(dataloader, desc="STS-B"):
        emb1 = get_batch_embeddings(model, batch['sentence1'], tokenizer, device)
        emb2 = get_batch_embeddings(model, batch['sentence2'], tokenizer, device)

        # Cosine Similarity
        cos_sim = F.cosine_similarity(emb1, emb2)

        similarities.extend(cos_sim.cpu().numpy())
        labels.extend(batch['label'].numpy())

    # Spearman Correlation
    spearman_corr, _ = spearmanr(similarities, labels)
    return spearman_corr

def evaluate_retrieval(model, tokenizer, device, num_samples=1000):
    """
    Evaluates Content Preservation via Retrieval.
    Given a Complex sentence, can we find its Simple pair in a batch of `num_samples`?
    """
    print(f"\n--- 2. Advanced Disentanglement: Content Retrieval (N={num_samples}) ---")
    # Load WikiAutoAsset (Complex -> Simple pairs)
    dataset = load_dataset("GEM/wiki_auto_asset_turk", split="validation")
    dataset = dataset.select(range(min(len(dataset), num_samples)))

    complex_texts = [item['source'] for item in dataset]
    # Use first reference or target as the "Simple" pair
    simple_texts = []
    for item in dataset:
        if 'references' in item and len(item['references']) > 0:
            simple_texts.append(item['references'][0])
        elif 'target' in item:
            simple_texts.append(item['target'])
        else:
            simple_texts.append(item['source']) # Fallback

    # Batch Encode
    batch_size = 32
    complex_embs = []
    simple_embs = []

    print("Encoding Retrieval Corpus...")
    for i in tqdm(range(0, len(complex_texts), batch_size)):
        c_batch = complex_texts[i:i+batch_size]
        s_batch = simple_texts[i:i+batch_size]

        complex_embs.append(get_batch_embeddings(model, c_batch, tokenizer, device).cpu())
        simple_embs.append(get_batch_embeddings(model, s_batch, tokenizer, device).cpu())

    complex_embs = torch.cat(complex_embs) # [N, 768]
    simple_embs = torch.cat(simple_embs)   # [N, 768]

    # Normalize for Cosine Similarity
    complex_embs = F.normalize(complex_embs, p=2, dim=1)
    simple_embs = F.normalize(simple_embs, p=2, dim=1)

    # Similarity Matrix: [N, N] (Complex vs All Simples)
    sim_matrix = torch.mm(complex_embs, simple_embs.t())

    # Calculate Recall@K
    recall_k = {1: 0, 5: 0, 10: 0}
    num_correct = 0

    # Explicit loop to check rank of correct pair (diagonal)
    for i in range(len(sim_matrix)):
        scores = sim_matrix[i]
        target_score = scores[i].item()

        # Count how many distractors have a higher score than the target
        # (Rank = count + 1)
        rank = (scores > target_score).sum().item() + 1

        if rank <= 1: recall_k[1] += 1
        if rank <= 5: recall_k[5] += 1
        if rank <= 10: recall_k[10] += 1

    for k in recall_k:
        recall_k[k] /= num_samples

    # Geometric Metrics (Alignment & Uniformity)
    # Alignment: Expected distance between positive pairs
    alignment_loss = (complex_embs - simple_embs).pow(2).sum(dim=1).mean().item()

    return recall_k, alignment_loss

# --- MAIN EXECUTION ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_name = 'roberta-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 1. Load Model
print("Loading Model...")
try:
    #path = hf_hub_download(repo_id="ege-dgny/ParaJEPA-03-25-2025", filename="para_jepa_best_model.pt")
    path = "para_jepa_best_model.pt"
    model = ParaJEPA(model_name=model_name, pred_hidden_dim=128, funnel=True).to(device)
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    print("Model Loaded.")
except Exception as e:
    print(f"Error: {e}")
    model = ParaJEPA(model_name=model_name).to(device)

# 2. Run Evaluations
sts_score = evaluate_stsb(model, tokenizer, device)
recall_results, alignment = evaluate_retrieval(model, tokenizer, device, num_samples=1000)

print("\n" + "="*60)
print("PHASE 2 EVALUATION REPORT")
print("="*60)
print(f"1. Semantic Utility (STS-B Spearman): {sts_score:.4f}")
print(f"   (Reference: RoBERTa-base untrained ~0.2-0.4, trained ~0.85)")
print("-" * 60)
print(f"2. Content Retrieval (Recall@K with 1000 distractors)")
print(f"   R@1:  {recall_results[1]:.4f}  (Perfect match found first?)")
print(f"   R@5:  {recall_results[5]:.4f}")
print(f"   R@10: {recall_results[10]:.4f}")
print("-" * 60)
print(f"3. Geometric Alignment: {alignment:.4f}")
print(f"   (Lower is better. Measures raw distance between style pairs.)")
print("="*60)

Loading Model...


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model Loaded.

--- 1. Semantic Utility: STS-Benchmark (Validation) ---
Computing STS-B correlations...


STS-B: 100%|██████████| 47/47 [00:01<00:00, 27.90it/s]



--- 2. Advanced Disentanglement: Content Retrieval (N=1000) ---
Encoding Retrieval Corpus...


100%|██████████| 32/32 [00:01<00:00, 25.60it/s]



PHASE 2 EVALUATION REPORT
1. Semantic Utility (STS-B Spearman): 0.0147
   (Reference: RoBERTa-base untrained ~0.2-0.4, trained ~0.85)
------------------------------------------------------------
2. Content Retrieval (Recall@K with 1000 distractors)
   R@1:  0.2770  (Perfect match found first?)
   R@5:  0.3450
   R@10: 0.3800
------------------------------------------------------------
3. Geometric Alignment: 0.0000
   (Lower is better. Measures raw distance between style pairs.)
