In [193]:
import torch
import torch.nn.functional as F


def pragmatic_loss(
    logprobs: torch.FloatTensor, 
    max_iter: int = 100,
    epsilon: float = 1e-5,
) -> torch.FloatTensor:
    """Compute the pragmatic loss for a batch of response log probabilities.
    
    Args:
        logprobs: The log probabilities of the responses. Shape: (constitution_batch_size, constitution_batch_size * response_batch_size).
        max_iter: The maximum number of iterations for the pragmatic recursion.
        epsilon: The convergence threshold for the pragmatic recursion.
        
    Returns:
        pragmatic_loss: The pragmatic loss for the batch of responses. 
    """        
    # have constitutions compete for responses
    probs = torch.softmax(logprobs, dim=0) 

    for _ in range(max_iter):

        # row normalization
        probs = probs / probs.sum(dim=1, keepdim=True)
        
        # check convergence
        col_sums = probs.sum(dim=0)

        if torch.max(col_sums) - torch.min(col_sums) < epsilon:
            break
        
        # column normalization
        probs = probs / probs.sum(dim=0, keepdim=True)

    # use probs as class probabilities to compute the loss
    loss = F.cross_entropy(logprobs, probs, reduction="mean")
 
    return loss 


In [194]:
logprobs = torch.FloatTensor([[-1, -30000, -2], [-2, -1, -2], [-2, -2, -1]])

a = torch.FloatTensor([0.692, 0.000, 0.3038])
b = torch.FloatTensor([-1, -30000, -2])

F.cross_entropy(b, a, reduction="none")

tensor(0.6157)

In [195]:
pragmatic_loss(logprobs, max_iter=100, epsilon=1e-10)

tensor(0.8308)

In [143]:
F.cross_entropy(torch.Float)

0.9998999999999999

In [138]:
import numpy as np

def sinkhorn_knopp(matrix, tol=1e-5, max_iters=1000):
    """
    Scale a non-negative 3x3 matrix into a doubly stochastic matrix using the Sinkhorn-Knopp algorithm.

    Args:
        matrix (np.ndarray): A 3x3 non-negative matrix.
        tol (float): Tolerance for the convergence criterion.
        max_iters (int): Maximum number of iterations.

    Returns:
        np.ndarray: A doubly stochastic matrix.
    """
    if not isinstance(matrix, np.ndarray):
        raise ValueError("Input must be a numpy ndarray.")
    if matrix.shape != (3, 3):
        raise ValueError("Input matrix must be 3x3.")
    
    # Initial normalization of the matrix
    # matrix = matrix / np.sum(matrix)
    
    for _ in range(max_iters):
        print(_)
        # Scale rows to sum to 1
        matrix = matrix / matrix.sum(axis=1, keepdims=True)
        # Scale columns to sum to 1
        matrix = matrix / matrix.sum(axis=0, keepdims=True)
        
        # Check for convergence
        row_sums = matrix.sum(axis=1)
        col_sums = matrix.sum(axis=0)
        if np.all(np.abs(row_sums - 1) < tol) and np.all(np.abs(col_sums - 1) < tol):
            break

    return matrix

# Example usage
matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=float)
doubly_stochastic_matrix = sinkhorn_knopp(matrix)
print(doubly_stochastic_matrix)


0
1
2
3
[[0.24241937 0.34887671 0.40870344]
 [0.3646397  0.32798095 0.30737952]
 [0.39294093 0.32314234 0.28391704]]


In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" # the device to load the model onto

# model = AutoModelForCausalLM.from_pretrained(cache_dir="/Users/iphilipp/Documents/research/scai-tuning/pragmatics/conf/model/local_cache")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

messages = [
    {"role": "user", "content": "What is your favourite condiment?"},
    {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
    {"role": "user", "content": "Do you have mayonnaise recipes?"}
]

encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")



print(tokenizer.batch_decode(encodeds))

["<s> [INST] What is your favourite condiment? [/INST]Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!</s>  [INST] Do you have mayonnaise recipes? [/INST]"]
