In [None]:
import torch
from torch import nn, optim
from torch import cuda
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
# Directory (or Hub id) of the fine-tuned multi-label checkpoint to calibrate.
# TODO: fill this in before running the notebook.
path = ''
if not path:
    # Fail fast with an actionable message instead of an opaque `from_pretrained` error.
    raise ValueError("`path` is empty: set it to the fine-tuned checkpoint to calibrate.")

tokenizer = AutoTokenizer.from_pretrained(path)
config = AutoConfig.from_pretrained(path)
model = AutoModelForSequenceClassification.from_pretrained(path, config=config).to(device)

In [None]:
from datasets import load_dataset
# GoEmotions, "simplified" configuration; this notebook calibrates on its validation split.
ds = load_dataset(path="go_emotions", name="simplified")
valid_ds = ds["validation"]

In [None]:
def one_hot_encode(example, num_labels=28):
    """Convert an example's list of label indices into a multi-hot vector.

    Args:
        example: A dataset row (mapping) whose ``"labels"`` entry is an iterable
            of integer class indices.
        num_labels: Length of the output vector. Defaults to 28, the label-set
            size this notebook uses.

    Returns:
        A new dict containing every field of ``example``, with ``"labels"``
        replaced by a list of ``num_labels`` ints: 1 at each listed index, 0
        elsewhere. The input mapping is left unmodified.

    Raises:
        IndexError: If a label index falls outside ``[-num_labels, num_labels)``.
    """
    multi_hot = [0] * num_labels
    for label_index in example["labels"]:
        multi_hot[label_index] = 1
    # Build a fresh row rather than mutating the caller's mapping in place.
    return {**example, "labels": multi_hot}

In [None]:
# Turn each row's list of label ids into a fixed-length multi-hot vector.
valid_ds = valid_ds.map(function=one_hot_encode)

In [None]:
def tokenize_func(examples, max_length=50):
    """Tokenize a batch of texts into fixed-length model inputs.

    Uses the module-level ``tokenizer``. Every sequence is truncated or padded to
    exactly ``max_length`` tokens, so all rows share one length and can be stacked
    by the DataLoader's default collation.

    Args:
        examples: A batch mapping (as passed by ``Dataset.map(..., batched=True)``)
            whose ``"text"`` entry holds the strings to encode.
        max_length: Target sequence length in tokens (default 50).

    Returns:
        The tokenizer's output for the batch (consumed later as ``input_ids`` and
        ``attention_mask``).
    """
    return tokenizer(examples["text"], truncation=True, padding='max_length', max_length=max_length)

In [None]:
# Tokenize in batches, drop the raw `text` and `id` columns, and expose the
# multi-hot targets under the singular name `label`.
valid_ds = (
    valid_ds
    .map(tokenize_func, batched=True, remove_columns=["text", "id"])
    .rename_column("labels", "label")
)

In [None]:
def to_float_labels(example):
    """Add a float32 copy of an example's multi-hot label vector.

    ``BCEWithLogitsLoss`` needs floating-point targets, while the multi-hot labels
    are built as integers.

    Args:
        example: A dataset row (mapping) whose ``"label"`` entry is a tensor or a
            sequence of numbers.

    Returns:
        A new dict with every field of ``example`` plus ``"float_label"``, a
        ``torch.float`` tensor holding the same values. The input is not modified.
    """
    # as_tensor accepts both tensors (torch-formatted rows) and plain lists.
    float_labels = torch.as_tensor(example["label"], dtype=torch.float)
    return {**example, "float_label": float_labels}

In [None]:
# Serve columns as torch tensors, then replace the integer `label` column with
# its float copy (`float_label`) under the original name.
valid_ds.set_format("torch")
valid_ds = (
    valid_ds.map(to_float_labels)
    .remove_columns("label")
    .rename_column("float_label", "label")
)
dataloader = torch.utils.data.DataLoader(dataset=valid_ds, batch_size=16)

In [None]:
# Switch off training-mode behaviour (e.g. dropout); same as `model.eval()`.
model.train(False)

In [None]:
def temperature_scale(logits, temperature):
    """Divide logits by a temperature.

    Args:
        logits: Tensor of logits of any shape, e.g. ``(batch, num_labels)``.
        temperature: A Python number, a 0-d tensor, or any tensor broadcastable
            against ``logits`` (such as a per-label vector of shape
            ``(num_labels,)``).

    Returns:
        ``logits / temperature``. Gradients flow to ``temperature`` when it
        requires grad.
    """
    # Plain broadcasting handles scalar and per-label temperatures for logits of
    # any rank; an explicit unsqueeze/expand would only support 2-D logits.
    return logits / temperature

In [None]:
def _collect_logits_and_labels(model, valid_loader, device):
    """Run ``model`` once over ``valid_loader`` and return all logits and labels.

    Both tensors are concatenated along the batch dimension, placed on ``device``
    and cast to float32, so the temperature fit runs in full precision with
    matching dtypes.

    Raises:
        ValueError: If ``valid_loader`` yields no batches.
    """
    logits_list = []
    labels_list = []
    with torch.no_grad():
        for batch in valid_loader:
            enc = {
                'input_ids': batch['input_ids'].to(device),
                'attention_mask': batch['attention_mask'].to(device),
            }
            logits_list.append(model(**enc).logits)
            labels_list.append(batch['label'])
    if not logits_list:
        raise ValueError("valid_loader yielded no batches; cannot fit a temperature.")
    logits = torch.cat(logits_list).to(device=device, dtype=torch.float)
    labels = torch.cat(labels_list).to(device=device, dtype=torch.float)
    return logits, labels


def find_optimal_temperature(model, valid_loader, initial_temp=1.5, max_iter=10000, lr=0.01):
    """Fit a single scalar temperature on held-out data by minimising BCE-with-logits.

    Collects the model's logits for every batch of ``valid_loader``. It then uses
    L-BFGS to find the temperature ``T`` minimising
    ``BCEWithLogitsLoss(logits / T, labels)`` and prints the loss before and after
    scaling.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)`` and
            returning an object with a ``.logits`` tensor. It is switched to eval
            mode by this function.
        valid_loader: Iterable of batches (mappings) with ``input_ids``,
            ``attention_mask`` and ``label`` tensors. ``label`` holds multi-hot
            targets with the same shape as the logits.
        initial_temp: Starting value of ``T``.
        max_iter: Maximum L-BFGS iterations for the single ``step`` call.
        lr: L-BFGS learning rate.

    Returns:
        The fitted temperature as a Python float.

    Raises:
        ValueError: If ``valid_loader`` yields no batches.
    """
    # Work on the device the model actually lives on (first parameter's device).
    device = next(model.parameters()).device
    # Dropout must be disabled, otherwise the collected logits are stochastic.
    model.eval()

    logits, labels = _collect_logits_and_labels(model, valid_loader, device)

    # BCE-with-logits is the negative log-likelihood of independent per-label
    # Bernoulli outputs, hence the "NLL" naming.
    nll_criterion = nn.BCEWithLogitsLoss()

    with torch.no_grad():
        before_temperature_nll = nll_criterion(logits, labels).item()
    print('Before temperature - NLL: %.3f' % (before_temperature_nll))

    temp = nn.Parameter(torch.tensor(initial_temp, dtype=torch.float, device=device))
    optimizer = optim.LBFGS([temp], lr=lr, max_iter=max_iter)

    def closure():
        # L-BFGS may re-evaluate the loss several times per step.
        optimizer.zero_grad()
        loss = nll_criterion(temperature_scale(logits, temp), labels)
        loss.backward()
        return loss

    optimizer.step(closure)

    with torch.no_grad():
        after_temperature_nll = nll_criterion(temperature_scale(logits, temp), labels).item()
    print('Optimal temperature: %.3f' % temp.item())
    print('After temperature - NLL: %.3f' % (after_temperature_nll))

    return temp.item()

In [None]:
# Fit the temperature on the validation loader; the returned float is the cell's output.
find_optimal_temperature(model=model, valid_loader=dataloader)