In [None]:
import torch
import torch.nn.functional as F
from datasets import load_dataset, load_from_disk
from dotenv import load_dotenv
from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor
import matplotlib.pyplot as plt
import seaborn as sns

load_dotenv()  # load variables from a .env file (if one is found) into the process environment

## Generate Teacher Logits

In [None]:
# --- Models: the teacher produces the logits; the student card supplies the processor.
TEACHER_MODEL_CARD = "thewalnutaisg/florence2-large-doclaynet-70k"
STUDENT_MODEL_CARD = "microsoft/Florence-2-base-ft"

# --- Data
DATASET_CARD = "katphlab/doclaynet-table"
DATA_SPLIT = "val"
PROMPT = "<OD>"  # task prompt sent to the processor alongside every image

# --- Tokenization
MAX_LENGTH = 512
IGNORE_ID = -100  # PyTorch's default ignore_index for cross-entropy loss

# --- Run bookkeeping
RUN_NAME = "distillv2"
OUTPUT_DIR = "./runs"

In [None]:
# Initialize the processor and the frozen teacher model.
# NOTE(review): the processor comes from the *student* card; this assumes teacher and
# student share the same tokenizer / image processor (both are Florence-2) — confirm.
processor = AutoProcessor.from_pretrained(STUDENT_MODEL_CARD, trust_remote_code=True)

config = AutoConfig.from_pretrained(TEACHER_MODEL_CARD, trust_remote_code=True)
# Override the vision tower type; presumably the fine-tuned checkpoint's config carries
# a value the remote Florence-2 code does not accept as-is — TODO confirm.
config.vision_config.model_type = "davit"
teacher_model = AutoModelForCausalLM.from_pretrained(
    TEACHER_MODEL_CARD, trust_remote_code=True, config=config
)

# The teacher is used for inference only: switch to eval mode once and freeze every weight.
teacher_model.eval()
teacher_model.requires_grad_(False)

# Move the teacher to GPU if available (`Module.to` returns the same module).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher_model = teacher_model.to(device)

In [None]:
# Load a reproducible random subset of the DocLayNet table dataset.
# NOTE(review): DATA_SPLIT is passed both as the config name (2nd positional argument)
# and as `split`; presumably the dataset defines one config per split — confirm.
SHUFFLE_SEED = 42
N_SAMPLES = 10  # small subset for quick inspection; raise for a full run

dataset_full = load_dataset(DATASET_CARD, DATA_SPLIT, split=DATA_SPLIT, num_proc=12)
# Cap the subset size so `select` never indexes past the end of a small split.
dataset = dataset_full.shuffle(seed=SHUFFLE_SEED).select(
    range(min(N_SAMPLES, len(dataset_full)))
)

In [None]:
# Batched preprocessing for `Dataset.map`: encodes each batch and attaches the frozen
# teacher's logits to it.
def preprocess_function(examples):
    """Compute teacher logits for a batch of dataset rows.

    Meant for ``Dataset.map(..., batched=True)``: ``examples`` maps column names to
    lists of values and must provide ``"image"`` and ``"bbox_str"``.

    Each image is encoded together with ``PROMPT``. The ``bbox_str`` targets are
    tokenized (padded to the longest target in the batch, truncated to
    ``MAX_LENGTH`` tokens) and passed to the teacher as ``labels``, with padding
    replaced by ``IGNORE_ID``.
    NOTE(review): this assumes the teacher derives its decoder inputs from ``labels``
    (teacher forcing), so the logits line up with the target tokens — confirm against
    the Florence-2 modeling code.

    Args:
        examples (dict): A batch of dataset rows.

    Returns:
        dict: ``examples`` with an added ``"teacher_logits"`` column holding the
        teacher's output logits for the batch, moved to CPU. Positions beyond a
        shorter target presumably correspond to padded (ignored) label slots.
    """
    prompt_texts = [PROMPT] * len(examples["image"])

    inputs = processor(
        images=examples["image"],
        text=prompt_texts,
        return_tensors="pt",
        padding="longest",
        max_length=MAX_LENGTH,
    )

    # `max_length` only takes effect when truncation is enabled; with
    # padding="longest" alone, targets longer than MAX_LENGTH would pass through uncut.
    labels = processor.tokenizer(
        examples["bbox_str"],
        return_tensors="pt",
        padding="longest",
        truncation=True,
        max_length=MAX_LENGTH,
        return_token_type_ids=False,
    )["input_ids"]

    # Mask padding so it is ignored by the loss.
    labels[labels == processor.tokenizer.pad_token_id] = IGNORE_ID
    inputs["labels"] = labels

    # Move every model input (labels included) to the teacher's device.
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        teacher_model.eval()
        teacher_logits = teacher_model(**inputs).logits

    examples["teacher_logits"] = teacher_logits.cpu()
    return examples

In [None]:
# Run the teacher over the subset and persist the result.
# `new_fingerprint` spares `datasets` from hashing `preprocess_function`, which may
# serialize the objects it references (including the teacher model) just to build a
# cache key. Because that fingerprint is fixed, cache reuse is disabled so logits from
# an earlier teacher/config are never silently reloaded.
TEACHER_BATCH_SIZE = 2  # examples per teacher forward pass
logit_dataset = dataset.map(
    preprocess_function,
    batched=True,
    batch_size=TEACHER_BATCH_SIZE,
    new_fingerprint=f"{RUN_NAME}-{DATA_SPLIT}-teacher-logits",
    load_from_cache_file=False,
    desc="Computing teacher logits",
)
logit_dataset.save_to_disk("./data/teacher_logits")

## Visualize Top-k Logit Coverage

In [None]:
def compute_cpm_for_k(logits, k_values):
    """
    Compute the mean Cumulative Probability Mass (CPM) captured by the top-k tokens.

    For each k, the k largest softmax probabilities at every position are summed,
    and the result is averaged over all positions.

    Args:
        logits (torch.Tensor): Logits of shape ``(..., vocab_size)`` — e.g.
            ``(seq_len, vocab_size)`` for a single example or
            ``(batch, seq_len, vocab_size)`` for a batch.
        k_values (Iterable[int]): k values to evaluate.

    Returns:
        dict: Mapping ``k -> mean CPM`` as a Python float.
    """
    probs = F.softmax(logits, dim=-1)
    sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)

    cpm_dict = {}
    for k in k_values:
        # `...` keeps the slice valid for any number of leading dimensions
        # (the previous `[:, :, :k]` only accepted 3D input).
        cpm_dict[k] = sorted_probs[..., :k].sum(dim=-1).mean().item()

    return cpm_dict


def select_optimal_k(cpm_dict, threshold=0.95):
    """
    Pick the smallest k whose mean CPM reaches ``threshold``.

    Args:
        cpm_dict (dict): Mapping ``k -> mean CPM`` (e.g. from ``compute_cpm_for_k``).
        threshold (float): Minimum CPM the chosen k must reach.

    Returns:
        int: The smallest qualifying k, or the largest k evaluated when none qualifies.
    """
    candidates = sorted(cpm_dict)
    for candidate in candidates:
        if cpm_dict[candidate] >= threshold:
            return candidate
    # Nothing reached the threshold: fall back to the widest k that was evaluated.
    return max(candidates)

In [None]:
def reconstruct_logits_from_topk(top_k_logits, top_k_indices, vocab_size, fill_value=None):
    """Scatter top-k logits back into a dense ``(..., vocab_size)`` logit tensor.

    Tokens outside the top-k receive ``fill_value``. By default this is the most
    negative finite value of the logits' dtype, so after a softmax those tokens get
    (numerically) zero probability and the top-k tokens keep their relative weights.
    A fill of ``0.0`` instead gives each of the ``vocab_size - k`` missing tokens a
    logit of 0, which can hand them most of the probability mass. A finite value
    (rather than ``-inf``) keeps ``log_softmax``-based KL computations free of ``nan``.

    Args:
        top_k_logits: Tensor or nested list of shape ``(..., k)``.
        top_k_indices: Tensor or nested list of shape ``(..., k)`` holding the
            vocabulary ids of ``top_k_logits``.
        vocab_size (int): Size of the last dimension of the result.
        fill_value (float, optional): Logit assigned to tokens outside the top-k.

    Returns:
        torch.Tensor: float32 tensor of shape ``(..., vocab_size)`` on the device of
        ``top_k_indices``.
    """
    # `as_tensor` accepts both tensors and plain lists (e.g. rows read from a dataset).
    top_k_logits = torch.as_tensor(top_k_logits).float()
    top_k_indices = torch.as_tensor(top_k_indices).long()

    if fill_value is None:
        fill_value = torch.finfo(top_k_logits.dtype).min

    reconstructed_logits = torch.full(
        (*top_k_indices.shape[:-1], vocab_size),
        fill_value,
        dtype=top_k_logits.dtype,
        device=top_k_indices.device,
    )

    # Place each top-k logit at its vocabulary id along the last dimension.
    reconstructed_logits.scatter_(
        -1, top_k_indices, top_k_logits.to(reconstructed_logits.device)
    )

    return reconstructed_logits

In [None]:
def compute_cpm(logits):
    """Return the cumulative probability mass curve at every position.

    Softmax probabilities are sorted in descending order along the last axis and
    cumulatively summed, so entry ``[..., i]`` is the mass covered by the ``i + 1``
    most likely tokens.
    """
    ranked = F.softmax(logits, dim=-1).sort(dim=-1, descending=True).values
    return ranked.cumsum(dim=-1)


def plot_cpm_comparison(original_cpm, reconstructed_cpm, k, log_x=True):
    """Plot position-averaged CPM curves for the original and reconstructed logits.

    Args:
        original_cpm (torch.Tensor): CPM curves of shape ``(..., vocab_size)``
            (e.g. from ``compute_cpm``); all leading dimensions are averaged away.
        reconstructed_cpm (torch.Tensor): Same, for the reconstructed logits.
        k (int): Number of retained logits, drawn as a vertical reference line.
        log_x (bool): Use a logarithmic rank axis. With a large vocabulary, a linear
            axis squeezes the region around small k against the left edge. Pass
            ``False`` for a linear axis.
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    for cpm, label in ((original_cpm, "Original"), (reconstructed_cpm, "Reconstructed")):
        # Average over every leading dim so 1D/2D/3D inputs all yield a single curve.
        curve = cpm.reshape(-1, cpm.size(-1)).mean(dim=0).cpu().numpy()
        # Entry i is the mass of the top-(i + 1) tokens: plot against 1-based ranks
        # so the value for the top-k tokens sits exactly on the x = k line.
        ax.plot(range(1, len(curve) + 1), curve, label=label)

    ax.axvline(x=k, color="r", linestyle="--", label=f"k={k}")
    if log_x:
        ax.set_xscale("log")
    ax.set_xlabel("Top-k")
    ax.set_ylabel("Cumulative Probability Mass")
    ax.set_title("Comparison of Original and Reconstructed CPM")
    ax.legend()
    ax.grid(True)
    plt.show()


def _plot_topk_heatmap(ax, logits, k, title):
    """Draw the k largest logits of each row of ``logits`` as a heatmap on ``ax``."""
    # Record the input rank *before* promoting to 2D so the y-label reflects it.
    was_1d = logits.dim() == 1
    if was_1d:
        logits = logits.unsqueeze(0)

    top_k = torch.topk(logits, min(k, logits.size(-1)), dim=-1).values
    sns.heatmap(top_k.cpu().numpy(), ax=ax, cmap="viridis")
    ax.set_title(title)
    # Columns are ranks (largest logit first), not vocabulary token ids.
    ax.set_xlabel("Rank")
    ax.set_ylabel("Top-k Values" if was_1d else "Sequence Position")


def plot_logits_heatmap(original_logits, reconstructed_logits, k):
    """Show the k largest logits of the original and reconstructed tensors side by side.

    Args:
        original_logits (torch.Tensor): 1D ``(vocab_size,)`` or 2D
            ``(rows, vocab_size)`` logits.
        reconstructed_logits (torch.Tensor): Same, for the reconstruction.
        k (int): Number of largest logits per row to display (capped at the
            last dimension's size).
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

    _plot_topk_heatmap(ax1, original_logits, k, "Original Logits (Top-k)")
    _plot_topk_heatmap(ax2, reconstructed_logits, k, "Reconstructed Logits (Top-k)")

    plt.tight_layout()
    plt.show()


def visualize_cpm_and_logits(original_logits, top_k_indices, top_k_logits, k):
    """Compare a full logit tensor with its top-k reconstruction.

    Rebuilds dense logits from ``(top_k_logits, top_k_indices)``, plots the CPM
    curves of both, shows top-k heatmaps for the first row (the first sequence
    position for a ``(seq_len, vocab_size)`` input), and prints the ``F.kl_div``
    of the two log-softmax distributions.

    Args:
        original_logits (torch.Tensor): Full logits; the last dim is the vocabulary.
        top_k_indices: Vocabulary ids of the retained logits.
        top_k_logits: Retained logit values.
        k (int): Number of retained logits per position.
    """
    reconstructed = reconstruct_logits_from_topk(
        top_k_logits, top_k_indices, original_logits.size(-1)
    )

    # CPM curves are computed for the original first, then the reconstruction.
    plot_cpm_comparison(compute_cpm(original_logits), compute_cpm(reconstructed), k)

    plot_logits_heatmap(original_logits[0], reconstructed[0], k)

    # With log_target=True, kl_div sums exp(target) * (target - input), i.e. the KL
    # divergence of the reconstructed distribution from the original one;
    # "batchmean" divides that sum by the size of the first dimension.
    original_log_probs = F.log_softmax(original_logits, dim=-1)
    reconstructed_log_probs = F.log_softmax(reconstructed, dim=-1)
    loss = F.kl_div(
        input=original_log_probs,
        target=reconstructed_log_probs,
        reduction="batchmean",
        log_target=True,
    )
    print(f"Loss: {loss:.4f}")

In [None]:
# Reload the saved teacher logits so this section can run without recomputing them.
LOGITS_PATH = "./data/teacher_logits"
logit_dataset = load_from_disk(LOGITS_PATH)

In [None]:
# Keep the k most likely tokens at every position of the first example.
k = 10  # number of logits retained per position; adjust to explore coverage
teacher_logits = torch.tensor(logit_dataset[0]["teacher_logits"])
top_k_logits, top_k_indices = teacher_logits.topk(k, dim=-1)

In [None]:
# Compare the first example's teacher distribution with its top-k reconstruction.
visualize_cpm_and_logits(
    original_logits=teacher_logits,
    top_k_indices=top_k_indices,
    top_k_logits=top_k_logits,
    k=k,
)