# File for testing stuff

# Precompute the embeddings

In [None]:
# Note: DistilBERT is used to precompute the token embeddings.
# Its output embeddings are contextual, i.e. each token's vector depends on the surrounding sentence.

from transformers import DistilBertTokenizer, DistilBertModel
import torch

# Example input. TODO: replace with a sentence taken from the OntoNotes dataset.
sentence = "John lives in Paris."

# Tokenizer and encoder are loaded from the same pretrained checkpoint.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')

# Encode the sentence as PyTorch tensors.
inputs = tokenizer(
    sentence,
    padding=True,
    truncation=True,
    return_tensors='pt',
)

# Forward pass without gradient tracking: the embeddings are only read here.
with torch.no_grad():
    outputs = model(**inputs)

# The last hidden state holds one embedding per token.
token_embeddings = outputs.last_hidden_state  # Shape: (batch_size, seq_len, embedding_dim)

# Sanity check: print the shape and the type of the result, one per line.
print(token_embeddings.shape, type(token_embeddings), sep='\n')



# Confidence Model

In [None]:
# Our custom loss function

def custom_loss(p, ner_losses, lambda_weight):
    """Combine the gated NER loss with a penalty for invoking the NER model.

    Args:
        p: Tensor of shape (seq_len,) - output of a sigmoid, i.e. the
            per-token probability of invoking NER.
        ner_losses: Tensor of shape (seq_len,) - NER loss of each token.
        lambda_weight: Scalar weight applied to the invocation penalty.

    Returns:
        The scalar ``sum(p * ner_losses) + lambda_weight * sum(p)``.
    """
    # Cost of invoking NER at all: grows with the total invocation probability.
    invocation_penalty = lambda_weight * p.sum()

    # Each token's NER loss counts in proportion to its invocation probability.
    gated_ner_loss = p * ner_losses
    gated_ner_loss = gated_ner_loss.sum()

    return gated_ner_loss + invocation_penalty

In [None]:
# ConfidenceScoreModel comes from the project-local `model` package and is
# instantiated here with its default constructor arguments.
# NOTE(review): presumably maps token embeddings to a per-token probability of
# invoking NER, as the training loop uses it - confirm against the class definition.
from model.run_or_wait_classifier import ConfidenceScoreModel
confidence_model = ConfidenceScoreModel()

In [None]:
# Train the confidence model only: the optimizer holds just its parameters, and
# DistilBERT (`model`) is run under no_grad as a fixed feature extractor.
#
# Names that must be defined before this cell runs:
#   dataloader  - iterable yielding (sentence, gold_labels) pairs
#   ner_loss_fn - callable (inputs, gold_labels) -> Tensor of per-token NER
#                 losses with the same shape as `p`, i.e. one entry per
#                 position of the tokenized input
optimizer = torch.optim.Adam(confidence_model.parameters(), lr=1e-4)

for batch in dataloader:
    sentence, gold_labels = batch  # batch of one sequence (or padded batch)

    # Token embeddings from DistilBERT
    inputs = tokenizer(sentence, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
        token_embeddings = outputs.last_hidden_state  # Shape: (batch_size, seq_len, embedding_dim)

    # Get the confidence scores for each token.
    # NOTE: squeeze(0) only drops the batch dimension when batch_size == 1.
    p = confidence_model(token_embeddings.squeeze(0))  # Shape: (seq_len,)

    # Per-token NER losses. `custom_loss` cannot be used for this: it requires
    # (p, ner_losses, lambda_weight) and combines losses instead of producing them.
    # The losses are detached so they act as constants and no gradient flows
    # back into whatever produced them.
    ner_losses = ner_loss_fn(inputs, gold_labels).detach()

    # Guard against silent broadcasting (e.g. (seq_len, 1) * (seq_len,)), which
    # would sum over a (seq_len, seq_len) matrix instead of per-token products.
    if ner_losses.shape != p.shape:
        raise ValueError(
            f"ner_losses shape {tuple(ner_losses.shape)} does not match "
            f"p shape {tuple(p.shape)}"
        )

    # Compute the custom loss (NER loss + cost of invoking NER)
    loss = custom_loss(p, ner_losses, lambda_weight=0.1)

    # Backprop and update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
