<a href="https://colab.research.google.com/github/juihuichung/simcse-ablations/blob/main/simcse_ablations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ---- core training hyperparameters ----
batch_size, seed, num_epochs = 64, 49, 1
learning_rate = 3e-5
temperature = 0.05   # divisor applied to cosine similarities in contrastive_loss
max_length = 32      # tokenizer truncation / padding length
num_samples = 10000  # how many lines of the Wikipedia corpus to read
log_every = 100      # run STS evaluation every `log_every` batches

# ---- batch-aware dropout: annealed linearly over training steps ----
p_start, p_end = 0.3, 0.1

# ---- layer-aware dropout: interpolated linearly across encoder layers ----
dropout_start, dropout_end = 0.1, 0.3

# ---- embedding head / temperature options ----
pooling_method = 'self_attention'  # one of the methods in forward_with_pooling
learn_temp = False        # wrap the temperature in an nn.Parameter
use_dynamic_temp = False  # derive the temperature from the logits' spread
use_dynamic_temp = False

In [None]:
!pip install transformers > /dev/null 2>&1
!pip install transformers datasets > /dev/null 2>&1


import datetime
from zoneinfo import ZoneInfo

# One timestamp (US Eastern time) names every artifact this run writes, so the
# checkpoint, evaluation results and log of a single run share the same suffix.
now = datetime.datetime.now(ZoneInfo("America/New_York"))
check_file, eval_file, log_file = map(
    now.strftime,
    (
        "checkpoint_%m%d_%H%M.pth",
        "evaluation_results_%m%d_%H%M.txt",
        "log_simcse_%m%d_%H%M.txt",
    ),
)

with open(log_file, "a") as f:
    print("...loading packages", file=f, flush=True)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, BertConfig
from datasets import load_dataset
import random
import numpy as np

# Marker in the run log: the imports above finished without error.
with open(log_file, "a") as log_fh:
    print("...finish loading packages", file=log_fh, flush=True)

In [None]:
# Record every tunable setting at the top of the run log so the log is self-describing.
hyperparams = dict(
    batch_size=batch_size,
    seed=seed,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    temperature=temperature,
    max_length=max_length,
    num_samples=num_samples,
    log_every=log_every,
    dropout_start=dropout_start,
    dropout_end=dropout_end,
    pooling_method=pooling_method,
    learn_temp=learn_temp,
    use_dynamic_temp=use_dynamic_temp,
    p_start=p_start,
    p_end=p_end,
)

lines = ["=== Hyperparameters ===\n"]
lines += [f"{name}: {val}\n" for name, val in hyperparams.items()]
lines.append("\n")
with open(log_file, "a") as f:
    f.writelines(lines)
    f.flush()


In [None]:
# Download the text file containing Wikipedia samples for SimCSE
# (saved as wiki1m_for_simcse.txt; WikiTextDataset below reads it line by line).
!wget -O wiki1m_for_simcse.txt https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt > /dev/null 2>&1


class WikiTextDataset(Dataset):
    """Map-style dataset of raw sentences read from a one-sentence-per-line text file.

    Sentences are tokenized lazily in ``__getitem__`` and padded/truncated to
    ``max_length`` so the default DataLoader collate can stack them.

    Args:
        file_path: UTF-8 text file; every non-empty (after stripping) line is a sentence.
        num_samples: maximum number of non-empty sentences to keep. Blank lines
            are skipped and do not count toward this limit.
        max_length: tokenizer truncation / padding length.
        tokenizer: optional ready-made tokenizer; when omitted the
            ``bert-base-uncased`` BertTokenizer is loaded.
    """

    def __init__(self, file_path, num_samples=5000, max_length=32, tokenizer=None):
        self.sentences = []
        self.max_length = max_length
        self.tokenizer = (
            tokenizer
            if tokenizer is not None
            else BertTokenizer.from_pretrained("bert-base-uncased")
        )

        # Keep the first `num_samples` non-empty lines. Counting kept sentences
        # (rather than raw line numbers) means blank lines no longer shrink the
        # dataset below the requested size.
        with open(file_path, encoding="utf-8") as f:
            for line in f:
                if len(self.sentences) >= num_samples:
                    break
                line = line.strip()
                if line:
                    self.sentences.append(line)

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        """Tokenize sentence ``idx`` and return the tokenizer's outputs as 1-D tensors."""
        encoded = self.tokenizer(
            self.sentences[idx],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch dimension of size 1; drop it
        # so the DataLoader stacks items into (batch, max_length).
        return {k: v.squeeze(0) for k, v in encoded.items()}

# Dedicated, seeded generator for the DataLoader so the shuffle order is reproducible
# (Generator.manual_seed returns the generator itself).
g = torch.Generator().manual_seed(seed)

print("...loading data", flush=True)
dataset = WikiTextDataset(
    "wiki1m_for_simcse.txt",
    num_samples=num_samples,
    max_length=max_length,
)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=g,
)

In [None]:
print("...number of batches is %d" % len(dataloader))  # batches per epoch

In [None]:
"""
Keys from BERT tokenizer:

input_ids: The token IDs representing the input sentence.
attention_mask: A mask indicating which tokens are actual words and which are padding.
"""

# Every example is padded/truncated to exactly `max_length` tokens
# (padding="max_length" in WikiTextDataset), not to the longest sentence.
# Shorter sentences are filled with the [PAD] token id (0 for
# bert-base-uncased) and their attention_mask entries are 0.

sample = dataset[0]  # tokenize once and reuse for both prints
print("sentence: ", dataset.sentences[0])
print("input_ids: ", sample["input_ids"])
print("mask: ", sample["attention_mask"])

In [None]:
# Seed PyTorch, NumPy and Python's `random` so model initialisation and any
# other random draws below are reproducible.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(seed)


class MLPLayer(nn.Module):
    """Projection head: one hidden_size -> hidden_size linear layer followed by tanh."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.activation = nn.Tanh()

    def forward(self, x):
        projected = self.dense(x)
        return self.activation(projected)


class SelfAttentionPooling(nn.Module):
    """Attention-weighted average of token states.

    A small scorer (Linear hidden_size->128, tanh, Linear 128->1) produces one
    score per token. Positions where ``attention_mask == 0`` are excluded, the
    scores are softmax-normalised over the sequence, and the token states are
    summed with those weights.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )

    def forward(self, hidden_states, attention_mask):
        """Pool ``hidden_states`` [batch, seq, hidden] into [batch, hidden]."""
        scores = self.attention(hidden_states).squeeze(-1)  # [batch, seq]
        # Use the most negative value representable in the scores' dtype. A
        # hard-coded -1e9 does not fit in float16 (max ~6.5e4), so masked_fill
        # raises under mixed precision; in float32 the softmax result is the same.
        fill_value = torch.finfo(scores.dtype).min
        scores = scores.masked_fill(attention_mask == 0, fill_value)
        weights = torch.softmax(scores, dim=1).unsqueeze(-1)  # [batch, seq, 1]
        return torch.sum(hidden_states * weights, dim=1)  # [batch, hidden]


class BertForContrastive(nn.Module):
    """BERT encoder with a configurable pooling head for SimCSE-style training.

    Args:
        model_name: checkpoint passed to ``BertModel.from_pretrained`` (BERT's
            own pooler layer is not created).
        temp: temperature; wrapped in an ``nn.Parameter`` when ``learn_temp`` is
            True, otherwise stored as a plain float in ``self.temp``.
        learn_temp: make the temperature trainable.
        use_dynamic_temp: stored as ``self.use_dynamic_temp``; nothing in this
            class reads it.
        dropout_start, dropout_end: first/last values of the linear per-layer
            dropout schedule applied by ``_set_layerwise_dropout``.
        pooling_method: pooling used by ``forward_contrastive``; the
            ``'self_attention'`` method also creates ``self.attn_pooling``.
        p_start: accepted for backward compatibility; it has no effect.

    NOTE(review): depending on the installed transformers version, BERT may use
    an SDPA attention path whose attention-probability dropout is not read from
    ``layer.attention.self.dropout.p``; verify that the rates assigned here (and
    by any later dropout scheduler) actually reach the attention computation.
    """

    def __init__(
        self,
        model_name: str = "bert-base-uncased",
        temp: float = 0.05,
        learn_temp: bool = False,
        use_dynamic_temp: bool = False,
        dropout_start: float = 0.1,
        dropout_end: float = 0.1,
        pooling_method: str = 'cls_with_mlp',
        p_start = 0.1
    ):
        super().__init__()
        # load the backbone (without BERT's pooler; pooling is done in forward_with_pooling)
        self.bert = BertModel.from_pretrained(model_name, add_pooling_layer=False)
        self.mlp = MLPLayer(self.bert.config)

        self.learn_temp = learn_temp
        self.use_dynamic_temp = use_dynamic_temp
        if learn_temp:
            self.temp = nn.Parameter(torch.tensor(temp))
        else:
            self.temp = temp

        self.pooling_method = pooling_method
        hidden_size = self.bert.config.hidden_size
        if pooling_method == 'self_attention':
            self.attn_pooling = SelfAttentionPooling(hidden_size)

        # `p_start` used to be written into a separately loaded BertConfig that
        # was never attached to the model (an extra download with no effect), so
        # that code is gone; the parameter remains only for existing call sites.

        # apply the layer-wise dropout schedule
        self._set_layerwise_dropout(dropout_start, dropout_end)

    def _set_layerwise_dropout(self, d_start: float, d_end: float):
        """Assign dropout rates that vary linearly from ``d_start`` to ``d_end`` across encoder layers.

        Layer i of L gets ``d_start + (d_end - d_start) * i / (L - 1)`` and the
        embedding dropout gets the first layer's rate. A single-layer encoder
        gets ``d_start`` instead of dividing by zero.

        NOTE(review): a scheduler that later sets one shared ``p`` on every
        nn.Dropout in the model will overwrite this per-layer schedule.
        """
        layers = self.bert.encoder.layer
        L = len(layers)
        print("number of layers: ", L)
        span = max(L - 1, 1)  # avoid 0/0 when there is a single layer
        rates = [d_start + (d_end - d_start) * (i / span) for i in range(L)]
        # embedding dropout uses the first layer's rate
        self.bert.embeddings.dropout.p = rates[0] if rates else d_start

        for layer, p in zip(layers, rates):
            layer.attention.self.dropout.p = p    # attention-probability dropout
            layer.attention.output.dropout.p = p  # after the attention output projection
            layer.output.dropout.p = p            # after the feed-forward block

    def forward_with_pooling(self, input_ids, attention_mask, pooling_method='cls_with_mlp'):
        """Encode a batch and reduce token states to one vector per sentence.

        Methods:
            cls_with_mlp    -- final-layer [CLS] state passed through ``self.mlp``.
            cls_without_mlp -- final-layer [CLS] state.
            first_last_avg  -- mean of the [CLS] state in ``hidden_states[1]``
                               and ``hidden_states[-1]``.
            mean_pooling    -- attention-mask-aware mean of final-layer states.
            max_pooling     -- attention-mask-aware max of final-layer states.
            self_attention  -- ``self.attn_pooling`` over final-layer states
                               (only exists if built with this method).

        Raises:
            ValueError: for an unknown ``pooling_method``.
        """
        # Only first_last_avg needs the per-layer states; requesting them for
        # every method keeps all layers' activations in the output for nothing.
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask,
                            output_hidden_states=(pooling_method == 'first_last_avg'),
                            return_dict=True)
        last = outputs.last_hidden_state

        if pooling_method == 'cls_with_mlp':
            return self.mlp(last[:, 0])

        if pooling_method == 'cls_without_mlp':
            return last[:, 0]

        if pooling_method == 'first_last_avg':
            # NOTE(review): SimCSE's reference "avg_first_last" averages the
            # mask-weighted mean of all tokens, not only [CLS]; confirm intent.
            first_cls = outputs.hidden_states[1][:, 0]
            last_cls = outputs.hidden_states[-1][:, 0]
            return (first_cls + last_cls) / 2

        if pooling_method == 'mean_pooling':
            # Average token states, ignoring padding via the attention mask.
            mask = attention_mask.unsqueeze(-1).expand(last.size()).float()
            summed = torch.sum(last * mask, 1)
            counts = torch.clamp(mask.sum(1), min=1e-9)
            return summed / counts

        if pooling_method == 'max_pooling':
            # Padding positions are pushed to -1e9 so they never win the max.
            masked = last.masked_fill(attention_mask.unsqueeze(-1) == 0, -1e9)
            return torch.max(masked, dim=1)[0]

        if pooling_method == 'self_attention':
            return self.attn_pooling(last, attention_mask)

        raise ValueError("Invalid pooling method")

    def forward_contrastive(self, input_ids, attention_mask):
        """Run the same batch through the model twice and return both embeddings.

        In train mode, dropout makes the two passes differ, giving the two views
        of each sentence used as a positive pair.
        """
        emb1 = self.forward_with_pooling(input_ids, attention_mask, self.pooling_method)
        emb2 = self.forward_with_pooling(input_ids, attention_mask, self.pooling_method)
        return emb1, emb2

def contrastive_loss(emb1, emb2, temperature, dynamic_temp=None):
    """InfoNCE loss where row i of ``emb1`` is the positive for row i of ``emb2``.

    Args:
        emb1, emb2: [batch, dim] embeddings of the two views.
        temperature: float or 0-dim tensor that divides the cosine similarities
            (unless the dynamic temperature is used).
        dynamic_temp: if True, use the standard deviation of the off-diagonal
            (negative-pair) similarities, clamped to >= 1e-2, as the temperature.
            ``None`` (default) falls back to the notebook-level
            ``use_dynamic_temp`` flag, as before.

    Returns:
        (loss, temp): scalar cross-entropy loss and the temperature actually used.
    """
    if dynamic_temp is None:
        dynamic_temp = use_dynamic_temp  # notebook-level flag (original behaviour)

    # Normalize so the dot products below are cosine similarities.
    emb1 = F.normalize(emb1, dim=1)
    emb2 = F.normalize(emb2, dim=1)
    logits = emb1 @ emb2.T  # [batch, batch]; positives on the diagonal

    n = logits.size(0)
    if dynamic_temp and n > 1:
        # The diagonal holds the positive pairs; take the spread of negatives only.
        off_diag = ~torch.eye(n, dtype=torch.bool, device=logits.device)
        temp = logits[off_diag].std().clamp(min=1e-2)
    else:
        # Also used for a single-row batch: there are no negatives, std() would
        # be NaN, and the NaN would propagate into the loss and the gradients.
        temp = temperature

    logits = logits / temp
    labels = torch.arange(n, device=logits.device)
    loss = F.cross_entropy(logits, labels)
    return loss, temp


In [None]:
from torch.optim import Adam
from tqdm.auto import tqdm

class DropoutScheduler:
    """Linearly anneal the rate of every ``torch.nn.Dropout`` in a model over training.

    Args:
        model: module whose ``nn.Dropout`` submodules are updated in place
            (e.g. the whole BertForContrastive).
        p_start: rate at step 0, e.g. 0.3.
        p_end: rate at and after ``total_steps``, e.g. 0.1.
        total_steps: number of update steps to interpolate over
            (epochs * iters_per_epoch). A value <= 0 means "no schedule":
            every step uses ``p_end``.

    NOTE(review): one shared rate is written to *every* dropout module, so any
    per-layer rates configured earlier are overwritten on the first ``step``.
    """

    def __init__(self, model, p_start, p_end, total_steps):
        self.model = model
        self.p_start = p_start
        self.p_end = p_end
        self.T = total_steps

    def step(self, step):
        """Apply and return the rate for ``step`` (clamped to ``[0, total_steps]``)."""
        if self.T <= 0:
            # Avoid dividing by zero when there are no steps to schedule over.
            p_t = self.p_end
        else:
            t = min(max(step, 0), self.T)
            p_t = self.p_start + (self.p_end - self.p_start) * (t / self.T)
        for m in self.model.modules():
            if isinstance(m, torch.nn.Dropout):
                m.p = p_t
        return p_t

In [None]:
print("...initializing models", flush=True)
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = BertForContrastive(
    model_name="bert-base-uncased",
    temp=temperature,
    learn_temp=learn_temp,
    use_dynamic_temp=use_dynamic_temp,
    dropout_start=dropout_start,
    dropout_end=dropout_end,
    pooling_method=pooling_method,
)
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)



In [None]:
# Download the SentEval evaluation archive (senteval.tar) using wget; output is silenced
!wget https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/senteval.tar > /dev/null 2>&1

# Extract the tar file using the tar command.
# NOTE(review): presumably this yields the STS/ and SICK/ directories that
# evaluate_all_sts reads relative to data_path — confirm after extraction.
!tar -xvf senteval.tar > /dev/null 2>&1

In [None]:
import os
import torch
import torch.nn.functional as F
import numpy as np
from scipy.stats import spearmanr
import io
import tarfile
import shutil

def load_sts_dataset(dataset_path, datasets):
    """Load SemEval STS (2012-2016 layout) subsets.

    For each name in ``datasets`` reads ``STS.input.<name>.txt`` (two
    tab-separated sentences per line) and ``STS.gs.<name>.txt`` (one gold score
    per line). Pairs whose gold-score line is empty are dropped, and a subset
    with no scored pairs is left out of the result.

    Returns:
        dict mapping subset name -> (sent1 list, sent2 list, float score list).

    Raises:
        ValueError: if an input line does not contain exactly two tab-separated fields.
    """
    results = {}
    for name in datasets:
        sent1, sent2 = [], []
        with io.open(os.path.join(dataset_path, f"STS.input.{name}.txt"), encoding='utf8') as f:
            for line in f:
                # Strip only the line terminator before splitting: a full
                # .strip() also removes a leading/trailing tab and breaks the
                # two-column split when one side of the pair is empty.
                s1, s2 = line.rstrip('\r\n').split('\t')
                sent1.append(s1.strip())
                sent2.append(s2.strip())

        with io.open(os.path.join(dataset_path, f"STS.gs.{name}.txt"), encoding='utf8') as f:
            raw_scores = [line.strip() for line in f]

        # Keep only pairs that have a gold score (zip stops at the shorter list).
        kept = [
            (s1, s2, float(score))
            for s1, s2, score in zip(sent1, sent2, raw_scores)
            if score != ''
        ]
        if kept:
            filtered_sent1, filtered_sent2, filtered_scores = (list(col) for col in zip(*kept))
            results[name] = (filtered_sent1, filtered_sent2, filtered_scores)

    return results

def load_sts_benchmark(benchmark_path):
    """Read STS-B test pairs from ``<benchmark_path>/sts-test.csv`` (tab-separated).

    Column 4 holds the gold score and columns 5 and 6 the two sentences.
    Returns (sent1, sent2, scores) as three parallel lists.
    """
    sent1, sent2, scores = [], [], []
    csv_path = os.path.join(benchmark_path, "sts-test.csv")
    with io.open(csv_path, encoding='utf8') as f:
        for row in (raw.strip().split('\t') for raw in f):
            sent1.append(row[5].strip())
            sent2.append(row[6].strip())
            scores.append(float(row[4]))
    return sent1, sent2, scores

def load_sick_r(sick_path):
    """Read SICK-R test pairs from ``<sick_path>/SICK_test_annotated.txt``.

    The first (header) row is skipped. Columns 1 and 2 are the sentences and
    column 3 the relatedness score. Returns (sent1, sent2, scores) lists.
    """
    sent1, sent2, scores = [], [], []
    with io.open(os.path.join(sick_path, "SICK_test_annotated.txt"), encoding='utf8') as f:
        next(f, None)  # discard the header row
        for raw in f:
            cols = raw.strip().split('\t')
            sent1.append(cols[1].strip())
            sent2.append(cols[2].strip())
            scores.append(float(cols[3]))
    return sent1, sent2, scores

In [None]:
data_path = "."  # SentEval data is read from ./STS/... and ./SICK under this root


def get_embeddings(model, tokenizer, sentences, device, batch_size=32, pooling_method=None):
    """Encode ``sentences`` in mini-batches and return L2-normalised embeddings.

    The model is put in eval mode (and left there) and run under ``torch.no_grad``.

    Args:
        model: object with ``eval()`` and ``forward_with_pooling``.
        tokenizer: HF-style tokenizer whose output supports ``.to(device)`` and
            provides ``input_ids`` / ``attention_mask``.
        sentences: list of strings.
        device: device the tokenized batch is moved to.
        batch_size: sentences per forward pass.
        pooling_method: pooling to use; defaults to ``model.pooling_method`` so
            evaluation matches how the model was built, instead of silently
            depending on a notebook-level ``pooling_method`` global.

    Returns:
        CPU tensor with one row per sentence.
    """
    if pooling_method is None:
        pooling_method = model.pooling_method
    model.eval()
    embeddings = []

    for i in range(0, len(sentences), batch_size):
        batch = sentences[i:i + batch_size]
        inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(device)

        with torch.no_grad():
            batch_embeddings = model.forward_with_pooling(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                pooling_method=pooling_method,
            )

        embeddings.append(F.normalize(batch_embeddings, p=2, dim=1).cpu())

    return torch.cat(embeddings, dim=0)


def evaluate_sts(model, tokenizer, sent1, sent2, scores, device):
    """Spearman correlation (x100) between gold ``scores`` and embedding cosine similarities."""
    embeddings1 = get_embeddings(model, tokenizer, sent1, device)
    embeddings2 = get_embeddings(model, tokenizer, sent2, device)
    cos_sim = F.cosine_similarity(embeddings1, embeddings2).numpy()
    correlation, _ = spearmanr(scores, cos_sim)
    return correlation * 100  # percentage

# Main evaluation function
# SemEval STS subsets per year; each year lives in <data_path>/STS/<year>-en-test.
_STS_YEAR_SUBSETS = {
    'STS12': ['MSRpar', 'MSRvid', 'SMTeuroparl', 'surprise.OnWN', 'surprise.SMTnews'],
    'STS13': ['FNWN', 'headlines', 'OnWN'],
    'STS14': ['deft-forum', 'deft-news', 'headlines', 'images', 'OnWN', 'tweet-news'],
    'STS15': ['answers-forums', 'answers-students', 'belief', 'headlines', 'images'],
    'STS16': ['answer-answer', 'headlines', 'plagiarism', 'postediting', 'question-question'],
}

# Column order of the one-line summary printed by evaluate_all_sts.
_STS_REPORT_KEYS = ('STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STS-B', 'SICK-R', 'Avg')


def _pool_sts_year(data_path, year):
    """Load every subset of one STS year and concatenate them into (sent1, sent2, scores)."""
    subsets = load_sts_dataset(
        os.path.join(data_path, f"STS/{year}-en-test"), _STS_YEAR_SUBSETS[year]
    )
    sent1, sent2, scores = [], [], []
    for s1, s2, sc in subsets.values():
        sent1.extend(s1)
        sent2.extend(s2)
        scores.extend(sc)
    return sent1, sent2, scores


def evaluate_all_sts(model, tokenizer, data_path, device='cuda'):
    """Evaluate on STS12-16, STS-B and SICK-R; print and return Spearman x100 scores.

    Each STS year is scored on the concatenation of all its subsets. Prints one
    line with the eight values in ``_STS_REPORT_KEYS`` order and returns a dict
    with keys STS12..STS16, STS-B, SICK-R and Avg (mean of the seven tasks).
    """
    results = {}

    for year in _STS_YEAR_SUBSETS:
        sent1, sent2, scores = _pool_sts_year(data_path, year)
        results[year] = evaluate_sts(model, tokenizer, sent1, sent2, scores, device)

    sent1, sent2, scores = load_sts_benchmark(os.path.join(data_path, "STS/STSBenchmark"))
    results['STS-B'] = evaluate_sts(model, tokenizer, sent1, sent2, scores, device)

    sent1, sent2, scores = load_sick_r(os.path.join(data_path, "SICK"))
    results['SICK-R'] = evaluate_sts(model, tokenizer, sent1, sent2, scores, device)

    # Mean over the seven tasks (computed before 'Avg' itself is inserted).
    results['Avg'] = sum(results.values()) / len(results)

    print(" ".join(f"{results[key]:.2f}" for key in _STS_REPORT_KEYS))

    return results

In [None]:
import time
import torch
import torch.nn.functional as F
from scipy.stats import spearmanr
import os
import sys


start_time = time.time()
with open(log_file, "a") as f:
    print("Starting training...\n", file=f, flush=True)


global_step = 0
# NOTE(review): DropoutScheduler writes one rate to every nn.Dropout in the
# model each step, so the per-layer schedule configured in BertForContrastive
# does not survive training. Confirm this interaction is intended.
drop_sched = DropoutScheduler(model, p_start=p_start, p_end=p_end, total_steps=num_epochs * len(dataloader))

# Column order of each line appended to eval_file.
eval_keys = ('STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STS-B', 'SICK-R', 'Avg')

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for batch_idx, batch in enumerate(dataloader, 1):
        global_step += 1
        # Batch-aware dropout: anneal linearly from p_start to p_end over all steps.
        current_dropout_rate = drop_sched.step(global_step)

        # Move inputs to device
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        optimizer.zero_grad()
        # Get two views of the sentence by running the model twice (dropout introduces variation)
        emb1, emb2 = model.forward_contrastive(input_ids, attention_mask)
        loss, cur_temp = contrastive_loss(emb1, emb2, model.temp)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        total_loss += loss_value
        # cur_temp is a tensor for learned / dynamic temperatures, a float otherwise.
        temp_value = cur_temp.item() if isinstance(cur_temp, torch.Tensor) else cur_temp
        with open(log_file, "a") as f:
            print(f"Epoch {epoch+1}, Time {time.time() - start_time:.2f}, Batch {batch_idx}/{len(dataloader)}, Loss: {loss_value:.4f}, Temperature: {temp_value:.8f}", file=f, flush=True)

        # Every `log_every` batches, evaluate on all STS tasks and append one
        # line of scores to eval_file. Switch to eval mode only when actually
        # evaluating, then return to train mode (dropout on) for the next step.
        if batch_idx % log_every == 0:
            model.eval()
            with torch.no_grad():
                eval_results = evaluate_all_sts(model, dataset.tokenizer, data_path, device=device)
            output_line = " ".join(f"{eval_results[key]:.2f}" for key in eval_keys)
            with open(eval_file, "a") as f:
                f.write(output_line + "\n")
            model.train()

    avg_loss = total_loss / len(dataloader)
    with open(log_file, "a") as f:
        print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}\n", file=f, flush=True)

with open(log_file, "a") as f:
    print("\nTraining completed.", file=f, flush=True)


In [None]:
import torch

# Settings needed to rebuild an equivalent model when this checkpoint is reloaded.
hyperparameters = {
    "batch_size": batch_size,
    "num_epochs": num_epochs,
    "learning_rate": learning_rate,
    "temperature": temperature,
    "max_length": max_length,
    "num_samples": num_samples,
    "pooling_method": pooling_method,
    "device": device,
    # Model-structure flags: with learn_temp=True the state_dict contains a
    # "temp" parameter, so the reloaded model must be constructed the same way.
    "learn_temp": learn_temp,
    "use_dynamic_temp": use_dynamic_temp,
    "dropout_start": dropout_start,
    "dropout_end": dropout_end,
}
checkpoint = {
    "epoch": epoch,  # last epoch index reached by the training loop
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss.item(),  # loss of the final training batch
    "hyperparameters": hyperparameters,
}
torch.save(checkpoint, check_file)
with open(log_file, "a") as f:
    print(f"Checkpoint saved to {check_file}", file=f, flush=True)


In [None]:
checkpoint = torch.load(check_file)
hyperparameters = checkpoint["hyperparameters"]
state_dict = checkpoint["model_state_dict"]
# Rebuild the model with the structure it was trained with. If learn_temp was
# on, the state_dict holds a "temp" parameter, and a model built without it
# fails the strict load_state_dict below. Checkpoints saved without the flag
# fall back to detecting that key; the other settings default as before.
model = BertForContrastive(
    "bert-base-uncased",
    temp=hyperparameters["temperature"],
    learn_temp=hyperparameters.get("learn_temp", "temp" in state_dict),
    use_dynamic_temp=hyperparameters.get("use_dynamic_temp", False),
    dropout_start=hyperparameters.get("dropout_start", 0.1),
    dropout_end=hyperparameters.get("dropout_end", 0.1),
    pooling_method=hyperparameters["pooling_method"],
).to(hyperparameters["device"])
optimizer = optim.Adam(model.parameters(), lr=hyperparameters["learning_rate"])

model.load_state_dict(state_dict)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

In [None]:
# Set model to evaluation mode (dropout is disabled)
model.eval()

texts = [
    "There's a kid on a skateboard.",
    "A kid is skateboarding.",
    "A kid is inside the house."
]

# Tokenize exactly as during training: truncate/pad to the dataset's max_length.
inputs = dataset.tokenizer(
    texts,
    truncation=True,
    max_length=dataset.max_length,
    padding="max_length",
    return_tensors="pt"
).to(device)


# Embed with the pooling method the model was built with (the `pooling_method`
# hyperparameter: e.g. CLS+MLP, mean/max pooling or self-attention pooling).
with torch.no_grad():
    embeddings = model.forward_with_pooling(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        pooling_method=model.pooling_method,
    )

# Compare the first sentence with each of the others. Slicing (rather than
# indexing) keeps the leading batch dimension that F.cosine_similarity expects.
cosine_sim_0_1 = F.cosine_similarity(embeddings[0:1], embeddings[1:2]).item()
cosine_sim_0_2 = F.cosine_similarity(embeddings[0:1], embeddings[2:3]).item()

with open(log_file, "a") as f:
    print("Cosine similarity between \"%s\" and \"%s\" is: %.3f" % (texts[0], texts[1], cosine_sim_0_1), file=f, flush=True)
    print("Cosine similarity between \"%s\" and \"%s\" is: %.3f" % (texts[0], texts[2], cosine_sim_0_2), file=f, flush=True)


In [None]:
# # Use the authors' released checkpoint (supervised SimCSE, sup-simcse-bert-base-uncased) to compare against their results

# import torch
# from scipy.spatial.distance import cosine
# from transformers import AutoModel, AutoTokenizer

# # Import our models. The package will take care of downloading the models automatically
# author_tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/sup-simcse-bert-base-uncased")
# author_model = AutoModel.from_pretrained("princeton-nlp/sup-simcse-bert-base-uncased")

# # Tokenize input texts
# texts = [
#     "There's a kid on a skateboard.",
#     "A kid is skateboarding.",
#     "A kid is inside the house."
# ]
# inputs = author_tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

# # Get the embeddings
# with torch.no_grad():
#     embeddings = author_model(**inputs, output_hidden_states=True, return_dict=True).pooler_output

# # Calculate cosine similarities
# # Cosine similarities are in [-1, 1]. Higher means more similar
# cosine_sim_0_1 = 1 - cosine(embeddings[0], embeddings[1])
# cosine_sim_0_2 = 1 - cosine(embeddings[0], embeddings[2])

# print("Cosine similarity between \"%s\" and \"%s\" is: %.3f" % (texts[0], texts[1], cosine_sim_0_1))
# print("Cosine similarity between \"%s\" and \"%s\" is: %.3f" % (texts[0], texts[2], cosine_sim_0_2))