In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
pip install sympy==1.12

Collecting sympy==1.12
  Downloading sympy-1.12-py3-none-any.whl.metadata (12 kB)
Downloading sympy-1.12-py3-none-any.whl (5.7 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/5.7 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m5.7/5.7 MB[0m [31m216.4 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.7/5.7 MB[0m [31m118.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sympy
  Attempting uninstall: sympy
    Found existing installation: sympy 1.13.3
    Uninstalling sympy-1.13.3:
      Successfully uninstalled sympy-1.13.3
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch 2.8.0+cu126 requires sympy>=1.13.3, but you have sympy 1.12 which is incompatible.[0m[31m
[0mSuccessfully installed sympy-1.12


In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, EsmForSequenceClassification
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pickle

class TokenizedDataset(Dataset):
    def __init__(self, csv_file, tokenizer, label_mapper, max_length=1024):
        self.data = pd.read_csv(csv_file)
        self.tokenizer = tokenizer
        self.label_mapper = label_mapper
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sequence = self.data.iloc[idx]['sequence']
        label = self.label_mapper[self.data.iloc[idx]['cycle']]
        inputs = self.tokenizer(sequence, return_tensors="pt", padding='max_length',
                                truncation=True, max_length=self.max_length)
        return {
            'input_ids': inputs['input_ids'].squeeze(0),
            'label': torch.tensor(label, dtype=torch.long)
        }

class ModelWithTemperature(nn.Module):
    def __init__(self, model):
        super(ModelWithTemperature, self).__init__()
        self.model = model
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def forward(self, input_ids):
        logits = self.model(input_ids).logits
        return self.temperature_scale(logits)

    def temperature_scale(self, logits):
        return logits / self.temperature

    def set_temperature(self, valid_loader, save_path, device="cuda:0" if torch.cuda.is_available() else "cpu", asym_lambda=50, output_dir=None):
        self.to(device)
        self.model.eval()
        nll_criterion = nn.CrossEntropyLoss().to(device)

        logits_list, labels_list = [], []

        print("Collecting logits and labels...")
        with torch.no_grad():
            for batch in tqdm(valid_loader, desc="Processing validation set", unit="batch"):
                inputs = batch['input_ids'].to(device)
                labels = batch['label'].to(device)
                logits = self.model(inputs).logits
                logits_list.append(logits)
                labels_list.append(labels)

        logits = torch.cat(logits_list, dim=0)
        labels = torch.cat(labels_list, dim=0)

        before_nll = nll_criterion(logits, labels).item()
        before_ece = calculate_ece(logits, labels)
        print(f"Before temperature - NLL: {before_nll:.3f}, ECE: {before_ece:.3f}")

        if output_dir:
            make_reliability_diagram(logits, labels, title=f"Before Temperature Scaling\nECE: {before_ece:.3f}",
                                     save_path=os.path.join(output_dir, "before_scaling.png"))
        else:
            make_reliability_diagram(logits, labels, title=f"Before Temperature Scaling\nECE: {before_ece:.3f}")

        optimizer = optim.LBFGS([self.temperature], lr=0.01, max_iter=50)

        def closure():
            optimizer.zero_grad()
            nll_loss = nll_criterion(self.temperature_scale(logits), labels)
            asym_loss = asymmetric_confidence_penalty(self.temperature_scale(logits), labels)
            loss = nll_loss + asym_lambda * asym_loss
            loss.backward()
            return loss

        print(f"Optimizing temperature with asymmetric penalty (lambda={asym_lambda})...")
        optimizer.step(closure)

        after_logits = self.temperature_scale(logits)
        after_nll = nll_criterion(after_logits, labels).item()
        after_ece = calculate_ece(after_logits, labels)

        print(f"Optimal temperature: {self.temperature.item():.3f}")
        print(f"After temperature - NLL: {after_nll:.3f}, ECE: {after_ece:.3f}")

        if output_dir:
            make_reliability_diagram(after_logits, labels, title=f"After Temperature Scaling\nECE: {after_ece:.3f}",
                                     save_path=os.path.join(output_dir, "after_scaling.png"))
            with open(os.path.join(output_dir, "metrics.txt"), "w") as f:
                f.write(f"Lambda: {asym_lambda}\n")
                f.write(f"Before Temp - NLL: {before_nll:.3f}, ECE: {before_ece:.3f}\n")
                f.write(f"After Temp - NLL: {after_nll:.3f}, ECE: {after_ece:.3f}\n")

        torch.save(self.temperature.item(), save_path)
        print(f"Saved optimal temperature to: {save_path}")

        return self

def asymmetric_confidence_penalty(logits, labels, threshold=0.7, weight=2.0):
    softmaxes = F.softmax(logits, dim=1)
    confidences, predictions = torch.max(softmaxes, 1)
    is_correct = predictions.eq(labels)
    overconfident = confidences > threshold
    penalty = torch.where(
        is_correct,
        torch.zeros_like(confidences),
        (confidences - threshold) * overconfident.float() * weight
    )
    return penalty.mean()

def calculate_ece(logits, labels, n_bins=15):
    softmaxes = F.softmax(logits, dim=1)
    confidences, predictions = torch.max(softmaxes, 1)
    accuracies = predictions.eq(labels)

    bins = torch.linspace(0, 1, n_bins + 1)
    bin_lowers = bins[:-1]
    bin_uppers = bins[1:]

    ece = 0.0
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        in_bin = confidences.ge(bin_lower) * confidences.lt(bin_upper)
        prop_in_bin = in_bin.float().mean()
        if prop_in_bin.item() > 0:
            accuracy_in_bin = accuracies[in_bin].float().mean()
            avg_confidence_in_bin = confidences[in_bin].mean()
            ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

    return ece.item()

def make_reliability_diagram(logits, labels, n_bins=15, title="Reliability Diagram", save_path=None):
    softmaxes = F.softmax(logits, dim=1)
    confidences, predictions = softmaxes.max(1)
    accuracies = predictions.eq(labels)

    bins = torch.linspace(0, 1, n_bins + 1)
    width = 1.0 / n_bins
    bin_centers = np.linspace(0, 1.0 - width, n_bins) + width / 2
    bin_indices = [confidences.ge(bin_lower) * confidences.lt(bin_upper) for bin_lower, bin_upper in zip(bins[:-1], bins[1:])]

    bin_corrects = np.array([torch.mean(accuracies[bin_idx].float()).item() if bin_idx.any() else 0 for bin_idx in bin_indices])
    bin_scores = np.array([torch.mean(confidences[bin_idx].float()).item() if bin_idx.any() else 0 for bin_idx in bin_indices])

    plt.figure(figsize=(8, 8))
    gap = (bin_scores - bin_corrects)
    plt.bar(bin_centers, bin_corrects, width=width, alpha=0.5, ec='black', label='Outputs')
    plt.bar(bin_centers, gap, bottom=bin_corrects, color='red', alpha=0.5, width=width, hatch='//', edgecolor='r', label='Gap')
    plt.plot([0, 1], [0, 1], '--', color='gray')
    plt.legend(loc='best', fontsize='small')

    ece = calculate_ece(logits, labels, n_bins)
    bbox_props = dict(boxstyle="round", fc="lightgrey", ec="brown", lw=2)
    plt.text(0.2, 0.85, f"ECE: {ece:.2f}", ha="center", va="center", size=20, weight='bold', bbox=bbox_props)

    plt.title(title, size=20)
    plt.ylabel("Accuracy (P[y])", size=18)
    plt.xlabel("Confidence", size=18)
    plt.xlim(0, 1)
    plt.ylim(0, 1)

    if save_path:
        plt.savefig(save_path, dpi=600, bbox_inches="tight")
    else:
        plt.show()
    plt.close()

if __name__ == "__main__":
    model_path = "/content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/models/cyc_50"
    validation_csv = "/content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/data/validation_files/final_selected_val_50.csv"
    pickle_path = "/content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/models/cyc_id_maps/cyc_label_id_map_50.pickle"
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8m_UR50D")
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    with open(pickle_path, 'rb') as f:
        label_mapper = pickle.load(f)
    label_mapper = {v: k for k, v in label_mapper.items()}

    model = EsmForSequenceClassification.from_pretrained(model_path, num_labels=len(label_mapper))
    model_with_temp = ModelWithTemperature(model)

    val_dataset = TokenizedDataset(validation_csv, tokenizer, label_mapper)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)


    base_temp_dir = "/content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50"
    lambda_list = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]

    for asym_lambda in lambda_list:
      output_dir = os.path.join(base_temp_dir, f"lambda_{asym_lambda}")
      os.makedirs(output_dir, exist_ok=True)
      temp_save_path = os.path.join(output_dir, "optimal_temperature.pt")

      print(f"Saving results to: {output_dir}")

      model_with_temp.set_temperature(
          val_loader,
          save_path=temp_save_path,
          device=device,
          asym_lambda=asym_lambda,
          output_dir=output_dir
      )



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Saving results to: /content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50/lambda_0
Collecting logits and labels...


Processing validation set:   0%|          | 0/15139 [00:00<?, ?batch/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
Processing validation set: 100%|██████████| 15139/15139 [20:45<00:00, 12.15batch/s]


Before temperature - NLL: 4.979, ECE: 0.441
Optimizing temperature with asymmetric penalty (lambda=0)...


Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:835.)
  loss = float(closure())


Optimal temperature: 1.758
After temperature - NLL: 3.612, ECE: 0.209
Saved optimal temperature to: /content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50/lambda_0/optimal_temperature.pt
Saving results to: /content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50/lambda_10
Collecting logits and labels...


Processing validation set: 100%|██████████| 15139/15139 [20:45<00:00, 12.15batch/s]


Before temperature - NLL: 4.979, ECE: 0.441
Optimizing temperature with asymmetric penalty (lambda=10)...
Optimal temperature: 1.972
After temperature - NLL: 3.480, ECE: 0.170
Saved optimal temperature to: /content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50/lambda_10/optimal_temperature.pt
Saving results to: /content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50/lambda_20
Collecting logits and labels...


Processing validation set: 100%|██████████| 15139/15139 [20:45<00:00, 12.16batch/s]


Before temperature - NLL: 4.979, ECE: 0.441
Optimizing temperature with asymmetric penalty (lambda=20)...
Optimal temperature: 2.192
After temperature - NLL: 3.390, ECE: 0.141
Saved optimal temperature to: /content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50/lambda_20/optimal_temperature.pt
Saving results to: /content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50/lambda_30
Collecting logits and labels...


Processing validation set: 100%|██████████| 15139/15139 [20:45<00:00, 12.15batch/s]


Before temperature - NLL: 4.979, ECE: 0.441
Optimizing temperature with asymmetric penalty (lambda=30)...
Optimal temperature: 2.363
After temperature - NLL: 3.343, ECE: 0.126
Saved optimal temperature to: /content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50/lambda_30/optimal_temperature.pt
Saving results to: /content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50/lambda_40
Collecting logits and labels...


Processing validation set: 100%|██████████| 15139/15139 [20:45<00:00, 12.16batch/s]


Before temperature - NLL: 4.979, ECE: 0.441
Optimizing temperature with asymmetric penalty (lambda=40)...
Optimal temperature: 2.576
After temperature - NLL: 3.303, ECE: 0.113
Saved optimal temperature to: /content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50/lambda_40/optimal_temperature.pt
Saving results to: /content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50/lambda_50
Collecting logits and labels...


Processing validation set: 100%|██████████| 15139/15139 [20:45<00:00, 12.16batch/s]


Before temperature - NLL: 4.979, ECE: 0.441
Optimizing temperature with asymmetric penalty (lambda=50)...
Optimal temperature: 2.745
After temperature - NLL: 3.283, ECE: 0.108
Saved optimal temperature to: /content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50/lambda_50/optimal_temperature.pt
Saving results to: /content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50/lambda_60
Collecting logits and labels...


Processing validation set: 100%|██████████| 15139/15139 [20:46<00:00, 12.15batch/s]


Before temperature - NLL: 4.979, ECE: 0.441
Optimizing temperature with asymmetric penalty (lambda=60)...
Optimal temperature: 2.905
After temperature - NLL: 3.269, ECE: 0.104
Saved optimal temperature to: /content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50/lambda_60/optimal_temperature.pt
Saving results to: /content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50/lambda_70
Collecting logits and labels...


Processing validation set: 100%|██████████| 15139/15139 [20:45<00:00, 12.15batch/s]


Before temperature - NLL: 4.979, ECE: 0.441
Optimizing temperature with asymmetric penalty (lambda=70)...
Optimal temperature: 3.070
After temperature - NLL: 3.260, ECE: 0.103
Saved optimal temperature to: /content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50/lambda_70/optimal_temperature.pt
Saving results to: /content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50/lambda_80
Collecting logits and labels...


Processing validation set: 100%|██████████| 15139/15139 [20:45<00:00, 12.15batch/s]


Before temperature - NLL: 4.979, ECE: 0.441
Optimizing temperature with asymmetric penalty (lambda=80)...
Optimal temperature: 3.200
After temperature - NLL: 3.256, ECE: 0.100
Saved optimal temperature to: /content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50/lambda_80/optimal_temperature.pt
Saving results to: /content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50/lambda_90
Collecting logits and labels...


Processing validation set: 100%|██████████| 15139/15139 [20:46<00:00, 12.15batch/s]


Before temperature - NLL: 4.979, ECE: 0.441
Optimizing temperature with asymmetric penalty (lambda=90)...
Optimal temperature: 3.297
After temperature - NLL: 3.254, ECE: 0.098
Saved optimal temperature to: /content/drive/MyDrive/cycformer_run/cycformer_sep2/cycformer/temperature_scaling/cyc_50/lambda_90/optimal_temperature.pt
