<a href="https://colab.research.google.com/github/mikhail-ram/VIOLET/blob/main/VIOLET.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Vectorized Invariance Optimization for Language Embeddings using Twins



## Installing dependencies

In [None]:
%%capture
!pip install datasets nlpaug sacremoses optuna kaleido
!python -m nltk.downloader averaged_perceptron_tagger_eng

## Implementing Augmentations

In [None]:
from transformers import MarianMTModel, MarianTokenizer

class BackTranslationAug:
    """Paraphrase English text by translating it to French and back with MarianMT.

    Exposes ``name``, ``action`` and ``aug_p`` attributes plus an ``augment``
    method, which are the members the trainer in this notebook reads from each
    of its augmenters.
    """

    def __init__(self, model_name_en_to_fr, model_name_fr_to_en, device='cuda'):
        """Load both translation directions and move the models to ``device``.

        Args:
            model_name_en_to_fr: Checkpoint name/path given to ``from_pretrained``
                for the English -> French direction.
            model_name_fr_to_en: Checkpoint name/path given to ``from_pretrained``
                for the French -> English direction.
            device: Device the models and the tokenized inputs are moved to.
        """
        self.device = device
        # Descriptive attributes; this class does not read them itself.
        self.name = "BackTranslation_Aug"
        self.action = "insert"
        self.aug_p = 0.3

        # Load English to French model and tokenizer
        self.tokenizer_en_to_fr = MarianTokenizer.from_pretrained(model_name_en_to_fr)
        self.model_en_to_fr = MarianMTModel.from_pretrained(model_name_en_to_fr).to(self.device)

        # Load French to English model and tokenizer
        self.tokenizer_fr_to_en = MarianTokenizer.from_pretrained(model_name_fr_to_en)
        self.model_fr_to_en = MarianMTModel.from_pretrained(model_name_fr_to_en).to(self.device)

    def augment(self, texts, num_beams=1):
        """Return back-translated versions of ``texts``.

        Args:
            texts: A single string or a sequence of strings.
            num_beams: Beam width forwarded to ``generate`` for both directions.

        Returns:
            A list with one back-translated string per input (empty for empty input).
        """
        # Imported here because this cell never imports torch; without it the
        # method only works once a later cell has put `torch` into the namespace.
        import torch

        texts = [texts] if isinstance(texts, str) else list(texts)
        if not texts:
            # Nothing to translate; skip the tokenizer/model round trip entirely.
            return []

        # autocast expects a bare device type ("cuda"/"cpu"), whereas self.device
        # may carry an index ("cuda:0") or be a torch.device object.
        device_type = torch.device(self.device).type

        stages = (
            (self.tokenizer_en_to_fr, self.model_en_to_fr),  # English -> French
            (self.tokenizer_fr_to_en, self.model_fr_to_en),  # French -> English
        )
        with torch.no_grad():
            for tokenizer, model in stages:
                batch = tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to(self.device)
                with torch.amp.autocast(device_type):  # Enable mixed precision
                    generated = model.generate(**batch, num_beams=num_beams)
                # The decoded output of one stage is the input of the next.
                texts = tokenizer.batch_decode(generated, skip_special_tokens=True)
        return texts

## Training Loop (with dev-sentence shuffling and MixUp regularization)

In [None]:
import textwrap
import os
import torch
from datetime import datetime
from sentence_transformers import SentenceTransformer, models, util
from datasets import load_dataset
from torch.utils.data import DataLoader
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
import copy
import gc
import math
import random
import nlpaug.augmenter.word as naw
from IPython.display import display, clear_output
from tqdm import tqdm
import numpy as np
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from torch import GradScaler
from torch.amp import autocast

class BarlowTwinsNCSE:
    """Barlow Twins trainer for sentence embeddings with MixUp regularization.

    Each training batch of Wikipedia sentences is augmented twice, both views are
    passed through the online model (transformer -> pooling -> projection head),
    and the loss drives the cross-correlation matrix of the two projections toward
    the identity. A projection-free copy of the model (``self.encoder``) is scored
    on the STS-B validation and test splits, and all metrics are plotted live.

    Config keys read by this class: ``model_name``, ``batch_size``,
    ``projection_depth``, ``projection_size``, ``epochs``, ``max_seq_length``,
    ``aug_p``, ``learning_rate``, ``model_save_path``, ``num_workers``,
    ``weight_decay``, ``lambda_bt``, ``lambda_mixup``, ``patience`` and the
    optional ``use_amp`` (defaults to True).
    """

    # The test split is scored on the first dev evaluation and then on every
    # N-th dev evaluation after it.
    _TEST_EVAL_INTERVAL = 5

    def __init__(self, config):
        """Build datasets, models, optimizer, AMP scaler, augmenters and the live plot.

        Args:
            config: Dict of hyperparameters (see the class docstring for the keys read).
        """
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        torch.backends.cudnn.benchmark = True  # Enable cuDNN benchmarking
        self._prepare_datasets()
        self._initialize_models()
        self._initialize_optimizer_scheduler()
        self.scaler = GradScaler("cuda", enabled=self.config.get("use_amp", True))  # GradScaler for AMP
        self.best_spearman = -float("inf")
        self.best_pearson = -float("inf")
        # Counted in training iterations, not in number of evaluations (see fit()).
        self.patience_counter = 0
        # One of these is drawn at random for each augmented view.
        self.augmenters = [
            naw.SynonymAug(aug_src='wordnet', aug_p=self.config["aug_p"]),
            naw.RandomWordAug(action="swap", aug_p=self.config["aug_p"]),
            naw.RandomWordAug(aug_p=self.config["aug_p"]),
        ]
        self.test_sts_pearson_cosine_values = []
        self.test_sts_spearman_cosine_values = []
        self.test_iterations = []
        # Must run after self.augmenters is set: the plot footer lists them.
        self._create_plot()

    def _create_plot(self):
        """Reset the per-iteration metric histories and build the 3x2 metrics figure."""
        self.loss_values = []
        self.sts_pearson_cosine_values = []
        self.sts_spearman_cosine_values = []
        self.mean_grad_norm_values = []
        self.variance_values = []
        self.learning_rate_values = []
        self.iterations = []
        self.epochs = []

        # 3 rows x 2 columns grid with 6 subplots
        self.fig = make_subplots(
            rows=3, cols=2,
            subplot_titles=(
                "Loss vs Iterations",
                "Mean Gradient Norm vs Iterations",
                "Variance vs Iterations",
                "Learning Rate vs Iterations",
                "Dev STS Cosine (Pearson & Spearman) vs Iterations",
                "Test STS Cosine (Pearson & Spearman) vs Iterations"
            )
        )

        # (trace name, row, col). The order fixes the indices into self.fig.data
        # that _update_traces relies on.
        trace_layout = [
            ('Loss', 1, 1),
            ('Mean Gradient Norm', 1, 2),
            ('Variance', 2, 1),
            ('Learning Rate', 2, 2),
            ('Dev STS Pearson Cosine', 3, 1),
            ('Dev STS Spearman Cosine', 3, 1),
            ('Test STS Pearson Cosine', 3, 2),
            ('Test STS Spearman Cosine', 3, 2),
        ]
        for trace_name, row, col in trace_layout:
            self.fig.add_trace(
                go.Scatter(x=[], y=[], mode='lines+markers', name=trace_name),
                row=row, col=col
            )

        # Prepare footer text
        footer_text = ", ".join([f"{key}={value}" for key, value in self.config.items()])
        augmenters_text = ", Augmenters: " + ", ".join([f"{aug.name}[{aug.action}:{aug.aug_p}]" for aug in self.augmenters])
        footer_text += augmenters_text
        wrapped_footer = "<br>".join(textwrap.wrap(footer_text, width=160))

        # Configure the download button
        self.plot_config = {
            'toImageButtonOptions': {
                'filename': self.config["model_save_path"],
                'format': 'png',
                'width': 1200,
                'height': 800,
                'scale': 1
            }
        }

        self.fig.update_layout(
            width=1200,
            height=800,
            title_text='Training Metrics',
            showlegend=True,
            margin=dict(l=50, r=50, t=100, b=150)
        )

        # The footer is appended with add_annotation. Subplot titles are stored as
        # layout annotations, so passing annotations=[...] to update_layout would
        # clobber the six titles created by make_subplots above.
        self.fig.add_annotation(
            text=wrapped_footer,
            showarrow=False,
            xref="paper",
            yref="paper",
            x=0.5,
            y=-0.15,
            xanchor='center',
            yanchor='top',
            align="center",
            font=dict(size=10, color="gray")
        )

        # Annotation for best metrics; its index is remembered so _update_plot
        # can rewrite the text in place.
        self.best_metrics_annotation_index = len(self.fig.layout.annotations)
        self.fig.add_annotation(
            text="Best Spearman: N/A<br>Best Pearson: N/A",
            showarrow=False,
            xref="paper",
            yref="paper",
            x=1.0,
            y=0.0,
            xanchor='right',
            yanchor='bottom',
            align="right",
            font=dict(size=12, color="blue")
        )

        self.fig.show(config=self.plot_config)

    def _update_traces(self):
        """Push the current metric histories into the eight traces and re-enable autorange."""
        with self.fig.batch_update():
            self.fig.data[0].x = self.iterations
            self.fig.data[0].y = self.loss_values
            self.fig.data[1].x = self.iterations
            self.fig.data[1].y = self.mean_grad_norm_values
            self.fig.data[2].x = self.iterations
            self.fig.data[2].y = self.variance_values
            self.fig.data[3].x = self.iterations
            self.fig.data[3].y = self.learning_rate_values
            self.fig.data[4].x = self.iterations
            self.fig.data[4].y = self.sts_pearson_cosine_values
            self.fig.data[5].x = self.iterations
            self.fig.data[5].y = self.sts_spearman_cosine_values
            # Test metrics are recorded less often, so they have their own x values.
            self.fig.data[6].x = self.test_iterations
            self.fig.data[6].y = self.test_sts_pearson_cosine_values
            self.fig.data[7].x = self.test_iterations
            self.fig.data[7].y = self.test_sts_spearman_cosine_values

            for i in range(1, 4):
                for j in range(1, 3):
                    self.fig.update_yaxes(autorange=True, row=i, col=j)
                    self.fig.update_xaxes(autorange=True, row=i, col=j)

    def _update_plot(self):
        """Refresh traces, per-epoch frames, the epoch slider and the best-metrics label, then redraw."""
        # Also re-enables autorange on every subplot axis, so it is not repeated here.
        self._update_traces()

        # Series shown in the per-epoch frames, in the same order as the first six traces.
        dev_series = (
            self.loss_values,
            self.mean_grad_norm_values,
            self.variance_values,
            self.learning_rate_values,
            self.sts_pearson_cosine_values,
            self.sts_spearman_cosine_values,
        )

        unique_epochs = sorted(set(self.epochs))
        frames = []
        for ep in unique_epochs:
            indices = [i for i, e in enumerate(self.epochs) if e == ep]
            frames.append(go.Frame(
                data=[
                    go.Scatter(
                        x=[self.iterations[i] for i in indices],
                        y=[series[i] for i in indices]
                    )
                    for series in dev_series
                ],
                name=str(ep)
            ))
        self.fig.frames = frames

        slider_steps = [
            {"args": [[str(ep)], {"frame": {"duration": 0, "redraw": True},
                                   "mode": "immediate", "transition": {"duration": 0}}],
             "label": str(ep), "method": "animate"} for ep in unique_epochs
        ]

        self.fig.update_layout(
            sliders=[{
                "active": len(unique_epochs) - 1 if unique_epochs else 0,
                "currentvalue": {"prefix": "Epoch: "},
                "pad": {"t": 50},
                "steps": slider_steps
            }]
        )

        self.fig.layout.annotations[self.best_metrics_annotation_index].text = (
            f"Best Spearman: {self.best_spearman:.4f}<br>Best Pearson: {self.best_pearson:.4f}"
        )

        clear_output(wait=True)
        self.fig.show(config=self.plot_config)

    def _prepare_datasets(self):
        """Load the Wikipedia training sentences and the STS-B splits; build the loader and test evaluator."""
        wikipedia_url = "https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt"
        wikipedia_dataset_path = "data/wiki1m_for_simcse.txt"

        # Download once; later runs reuse the local copy.
        if not os.path.exists(wikipedia_dataset_path):
            util.http_get(wikipedia_url, wikipedia_dataset_path)
        train_sentences = []
        with open(wikipedia_dataset_path, encoding="utf8") as f:
            for line in f:
                line = line.strip()
                # Lines shorter than 10 characters are dropped.
                if len(line) >= 10:
                    train_sentences.append(line)
        self.train_sentences = train_sentences

        # The STS-B train split is loaded but not read anywhere else in this class.
        self.train_dataset = load_dataset("sentence-transformers/stsb", split="train")
        self.eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
        self.test_dataset = load_dataset("sentence-transformers/stsb", split="test")

        self.train_data_loader = DataLoader(
            SentenceDataset(self.train_sentences),
            batch_size=self.config["batch_size"],
            shuffle=True,
            num_workers=self.config["num_workers"],
            pin_memory=True
        )
        self.test_evaluator = EmbeddingSimilarityEvaluator(
            sentences1=self.test_dataset["sentence1"],
            sentences2=self.test_dataset["sentence2"],
            scores=self.test_dataset["score"]
        )

        # Evaluate about 100 times per epoch, and at least every step.
        self.evaluate_steps = max(len(self.train_data_loader) // 100, 1)

    def _get_random_augmentation(self):
        """Return one of ``self.augmenters`` chosen uniformly at random."""
        return random.choice(self.augmenters)

    def _apply_augmentation(self, sentences, aug):
        """Return ``aug.augment(sentences)``."""
        return aug.augment(sentences)

    def _initialize_models(self):
        """Create the online model (transformer + pooling + projection head) and its projection-free encoder copy."""
        word_embedding_model = models.Transformer(
            self.config["model_name"],
            max_seq_length=self.config["max_seq_length"],
            # The same probability is used for text augmentation and for both dropout settings.
            config_args={"attention_dropout": self.config["aug_p"], "dropout": self.config["aug_p"]}
        )
        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())

        # Size the projection input from the pooled embedding instead of hard-coding
        # 768, so backbones with another hidden size work as well.
        projection_input_dim = pooling_model.get_sentence_embedding_dimension()
        projection_layers = [torch.nn.Linear(projection_input_dim, self.config["projection_size"])]
        for _ in range(self.config["projection_depth"] - 1):
            projection_layers.append(torch.nn.BatchNorm1d(self.config["projection_size"]))
            projection_layers.append(torch.nn.ReLU())
            projection_layers.append(torch.nn.Linear(self.config["projection_size"], self.config["projection_size"]))
        projection_head = torch.nn.Sequential(*projection_layers)
        self.online_model = SentenceTransformer(modules=[word_embedding_model, pooling_model, projection_head]).to(self.device)
        # The evaluation encoder keeps only the first two modules (no projection head).
        encoder_modules = [copy.deepcopy(self.online_model[i]) for i in range(2)]
        self.encoder = SentenceTransformer(modules=encoder_modules).to(self.device)

    def _update_encoder(self):
        """Copy the online model's weights into the matching modules of ``self.encoder``."""
        for i in range(len(self.encoder)):
            self.encoder[i].load_state_dict(self.online_model[i].state_dict())

    def _initialize_optimizer_scheduler(self):
        """Create the Adam optimizer and a ReduceLROnPlateau scheduler (factor 0.5, patience 200)."""
        self.optimizer = torch.optim.Adam(
            self.online_model.parameters(),
            lr=self.config["learning_rate"],
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=self.config["weight_decay"]
        )
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.5, patience=200, verbose=True)

    def _forward_pass(self, model, sentences, train):
        """Encode ``sentences`` with ``model``.

        Args:
            model: Module sequence; index 0 is the transformer, 1 the pooling
                layer and, when present, 2 the projection head.
            sentences: Batch of raw strings to tokenize.
            train: If true the model is put in train mode, otherwise in eval mode.

        Returns:
            The projection-head output when the model has more than two modules,
            otherwise the pooled sentence embedding.
        """
        if train:
            model.train()
        else:
            model.eval()
        features = model.tokenize(sentences)
        features = {k: v.to(self.device, non_blocking=True) for k, v in features.items()}
        with autocast("cuda", enabled=self.config.get("use_amp", True)):
            embeddings = model[0](features)["token_embeddings"]
            pooled = model[1]({"token_embeddings": embeddings})["sentence_embedding"]
            if len(model) > 2:
                return model[2](pooled)
            else:
                return pooled

    def _mixed_barlow_twins_loss(self, z_a, z_b):
        """Barlow Twins loss plus a MixUp regularizer computed on the embeddings.

        Args:
            z_a: Projections of the first view, shape (N, D).
            z_b: Projections of the second view, shape (N, D).

        Returns:
            Scalar tensor ``loss_bt + loss_mix``.
        """
        N, D = z_a.size()

        # Standardize every feature over the batch.
        z_a_norm = (z_a - z_a.mean(dim=0)) / (z_a.std(dim=0) + 1e-6)
        z_b_norm = (z_b - z_b.mean(dim=0)) / (z_b.std(dim=0) + 1e-6)

        # Cross-correlation matrix; squared distance to identity, with the
        # off-diagonal terms down-weighted by lambda_bt.
        c = torch.matmul(z_a_norm.T, z_b_norm) / N
        I = torch.eye(D, device=z_a.device)
        c_diff = (c - I).pow(2)
        off_diag_mask = ~torch.eye(D, dtype=torch.bool, device=z_a.device)
        c_diff[off_diag_mask] *= self.config["lambda_bt"]
        loss_bt = c_diff.sum()

        # MixUp Regularization
        # The permutation is created on the embeddings' device so indexing does
        # not need a host-side index tensor.
        idx = torch.randperm(N, device=z_a.device)
        alpha = torch.tensor(np.random.beta(1.0, 1.0), device=z_a.device, dtype=z_a.dtype)
        # Instead of mixing raw inputs (not applicable for text), mix the embeddings
        z_m = alpha * z_a + (1 - alpha) * z_b[idx, :]
        z_m_norm = (z_m - z_m.mean(dim=0)) / (z_m.std(dim=0) + 1e-6)
        cc_m_a = torch.matmul(z_m_norm.T, z_a_norm) / N
        cc_m_b = torch.matmul(z_m_norm.T, z_b_norm) / N
        # Targets: the same alpha-weighted combination of the unmixed cross-correlations.
        cc_m_a_gt = alpha * torch.matmul(z_a_norm.T, z_a_norm) / N + (1 - alpha) * torch.matmul(z_b_norm[idx, :].T, z_a_norm) / N
        cc_m_b_gt = alpha * torch.matmul(z_a_norm.T, z_b_norm) / N + (1 - alpha) * torch.matmul(z_b_norm[idx, :].T, z_b_norm) / N
        loss_mix = self.config["lambda_mixup"] * self.config["lambda_bt"] * (
            (cc_m_a - cc_m_a_gt).pow(2).sum() + (cc_m_b - cc_m_b_gt).pow(2).sum()
        )
        return loss_bt + loss_mix

    def _evaluate_without_heads(self):
        """Sync the encoder with the online model and score it on the shuffled STS-B validation split.

        Returns:
            Whatever ``EmbeddingSimilarityEvaluator`` returns for ``self.encoder``.
        """
        self._update_encoder()
        self.encoder.eval()
        # Read each column once; indexing the dataset by column name inside the
        # comprehensions would re-read the whole column for every element.
        sentence1_column = self.eval_dataset["sentence1"]
        sentence2_column = self.eval_dataset["sentence2"]
        score_column = self.eval_dataset["score"]
        indices = list(range(len(sentence1_column)))
        random.shuffle(indices)
        sentences1 = [sentence1_column[i] for i in indices]
        sentences2 = [sentence2_column[i] for i in indices]
        scores = [score_column[i] for i in indices]
        evaluator = EmbeddingSimilarityEvaluator(sentences1, sentences2, scores)
        return evaluator(self.encoder)

    def _mean_grad_norm(self, num_params):
        """Return the unscaled global L2 gradient norm divided by ``num_params``.

        Must be called after ``backward()`` and before ``scaler.step()``, while
        the gradients still carry the AMP loss scale.
        """
        scale = self.scaler.get_scale()
        scale = scale if scale != 0 else 1e-8
        squared_norm_sum = 0.0
        for param in self.online_model.parameters():
            if param.grad is not None:
                squared_norm_sum += param.grad.data.norm(2).item() ** 2
        return math.sqrt(squared_norm_sum) / scale / num_params

    def fit(self):
        """Run the training loop with periodic dev/test evaluation, live plotting and early stopping.

        The dev split is scored every ``self.evaluate_steps`` iterations and on the
        last iteration of an epoch. Training stops once ``self.patience_counter``
        (which grows by ``evaluate_steps`` on each non-improving evaluation)
        reaches ``config["patience"]``.
        """
        latest_eval_metrics = {}
        eval_count = 0
        num_batches = len(self.train_data_loader)
        # The parameter count does not change during training, so compute it once.
        num_params = len(list(self.online_model.parameters()))
        use_amp = self.config.get("use_amp", True)

        for epoch in range(self.config["epochs"]):
            early_stop = False
            epoch_loss = 0
            pbar = tqdm(self.train_data_loader, desc=f"Epoch {epoch+1}/{self.config['epochs']}:")
            for idx, sentences in enumerate(pbar):
                # Two independently augmented views of the same batch.
                s1 = self._apply_augmentation(sentences, self._get_random_augmentation())
                s2 = self._apply_augmentation(sentences, self._get_random_augmentation())

                with autocast("cuda", enabled=use_amp):
                    z_a = self._forward_pass(self.online_model, s1, train=True)
                    z_b = self._forward_pass(self.online_model, s2, train=True)
                    loss = self._mixed_barlow_twins_loss(z_a, z_b)
                # Read the scalar once and reuse it for logging and the scheduler.
                loss_value = loss.item()
                epoch_loss += loss_value

                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()

                mean_grad_norm = self._mean_grad_norm(num_params)

                self.scaler.step(self.optimizer)
                self.scaler.update()
                # Pass a plain float so the scheduler never holds the graph-attached tensor.
                self.scheduler.step(loss_value)

                is_eval_step = idx % self.evaluate_steps == 0 or idx == num_batches - 1

                if is_eval_step:
                    latest_eval_metrics = self._evaluate_without_heads()
                    last_spearman = latest_eval_metrics.get('spearman_cosine', -float('inf'))
                    last_pearson = latest_eval_metrics.get('pearson_cosine', -float('inf'))
                    if last_spearman > self.best_spearman:
                        self.best_spearman = last_spearman
                        self.best_pearson = last_pearson
                        self.patience_counter = 0
                    else:
                        self.patience_counter += self.evaluate_steps

                    # Score the test split on the first dev evaluation and on every
                    # _TEST_EVAL_INTERVAL-th one after it. The counter is advanced
                    # after this check so that the first evaluation is included.
                    # The encoder was already synced and put in eval mode by
                    # _evaluate_without_heads just above.
                    if eval_count % self._TEST_EVAL_INTERVAL == 0:
                        test_metrics = self.test_evaluator(self.encoder)
                        self.test_sts_pearson_cosine_values.append(test_metrics.get('pearson_cosine', np.nan))
                        self.test_sts_spearman_cosine_values.append(test_metrics.get('spearman_cosine', np.nan))
                        self.test_iterations.append(idx)
                    eval_count += 1

                current_lr = self.optimizer.param_groups[0]['lr']
                pbar.set_postfix({
                    "loss": loss_value,
                    **latest_eval_metrics,
                    "mean_grad_norm": mean_grad_norm,
                    "learning_rate": current_lr
                })

                # Between evaluations the most recent dev metrics are repeated
                # (NaN before the first evaluation).
                self.loss_values.append(loss_value)
                self.sts_pearson_cosine_values.append(latest_eval_metrics.get('pearson_cosine', np.nan))
                self.sts_spearman_cosine_values.append(latest_eval_metrics.get('spearman_cosine', np.nan))
                self.mean_grad_norm_values.append(mean_grad_norm)
                self.variance_values.append(torch.var(z_a, dim=0).mean().item())
                self.learning_rate_values.append(current_lr)
                self.iterations.append(idx)
                self.epochs.append(epoch)

                if is_eval_step:
                    self._update_plot()
                    if self.patience_counter >= self.config["patience"]:
                        early_stop = True
                        print(f"Early stopping triggered at epoch {epoch+1}, iteration {idx}.")
                        print(f"Best Spearman Correlation: {self.best_spearman}")
                        print(f"Best Pearson Correlation: {self.best_pearson}")
                        break

            if early_stop:
                break
            avg_loss = epoch_loss / num_batches
            pbar.set_description(f"Epoch {epoch+1} Loss: {avg_loss}")

    def cleanup(self):
        """Drop the online model, optimizer, scheduler and scaler, then release cached GPU memory.

        ``self.encoder`` is left in place.
        """
        del self.online_model, self.optimizer, self.scheduler, self.scaler
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

class SentenceDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves raw sentences by position."""

    def __init__(self, sentences):
        # Keep a reference to the caller's sequence; nothing is copied or tokenized here.
        self.sentences = sentences

    def __len__(self):
        """Number of sentences available."""
        corpus = self.sentences
        return len(corpus)

    def __getitem__(self, idx):
        """Return the sentence stored at position ``idx``."""
        corpus = self.sentences
        return corpus[idx]

# Seed every RNG this cell's pipeline draws from: `random` (augmenter choice and
# dev shuffling), NumPy (MixUp coefficient) and PyTorch (weight init, dropout,
# DataLoader shuffling). This narrows run-to-run variance; bitwise-identical runs
# are still not guaranteed because cuDNN benchmarking and AMP are enabled.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

config = {
    "model_name": "distilbert-base-uncased",
    "batch_size": 256,
    "projection_depth": 4,
    "projection_size": 6144,
    "epochs": 1,
    "warmup_proportion": 0.0,  # not read by BarlowTwinsNCSE as defined in this cell
    "max_seq_length": 64,
    "aug_p": 0.3,  # used for text augmentation and for both dropout settings
    "learning_rate": 1e-4,
    # Timestamped so each run gets its own plot/weights file name.
    "model_save_path": f"train_stsb_bt-distilbert-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}",
    "num_workers": 2,
    "weight_decay": 1e-5,
    "lambda_bt": 0.0051,
    "lambda_mixup": 0.8,
    "use_amp": True,
    "patience": 500,  # counted in training iterations
    "seed": SEED,  # recorded so the plot footer documents the seed used
}

trainer = BarlowTwinsNCSE(config)
trainer.fit()

Epoch 1/1::   4%|▍         | 1232/30804 [09:17<3:43:10,  2.21it/s, loss=7.05e+3, pearson_cosine=0.704, spearman_cosine=0.727, mean_grad_norm=60.5, learning_rate=1.25e-5]

Early stopping triggered at epoch 1, iteration 1232.
Best Spearman Correlation: 0.7417269994158159
Best Pearson Correlation: 0.726118031360284





## Testing Model

In [None]:
# Copy the latest online-model weights into the projection-free encoder,
# switch it to inference mode, and score it on the STS-B test split.
trainer._update_encoder()
trainer.encoder.eval()
test_metrics = trainer.test_evaluator(trainer.encoder)
test_metrics

{'pearson_cosine': 0.6668615350342543, 'spearman_cosine': 0.66859453666296}

## Saving Model

In [None]:
# Put the encoder in evaluation mode, then snapshot its weights.
trainer.encoder.eval()
encoder_weights = trainer.encoder.state_dict()

# Only the state dictionary is written; the target is the run's timestamped path.
torch.save(encoder_weights, trainer.config["model_save_path"])
print(f"Encoder weights saved to {trainer.config['model_save_path']}")

Encoder weights saved to train_stsb_bt-distilbert-2025-02-12_16-13-57


## Hyperparameter Tuning using Optuna

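Initial testing: the cell below imports Optuna and redefines `BarlowTwinsNCSE` so that the plot annotation also reports the best test-set metrics alongside the best validation metrics.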

In [None]:
import textwrap
import os
import torch
from datetime import datetime
from sentence_transformers import SentenceTransformer, models, util
from datasets import load_dataset
from torch.utils.data import DataLoader
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
import copy
import gc
import math
import random
import nlpaug.augmenter.word as naw
from IPython.display import display, clear_output
from tqdm import tqdm
import numpy as np
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from torch import GradScaler
from torch.amp import autocast
import optuna

class BarlowTwinsNCSE:
    def __init__(self, config):
        """Set up data, models, optimizer, AMP scaler, augmenters and the live plot.

        Args:
            config: Dict of hyperparameters; ``aug_p`` and the optional
                ``use_amp`` (default True) are read directly here.
        """
        self.config = config
        target = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(target)
        torch.backends.cudnn.benchmark = True  # Enable cuDNN benchmarking

        # Setup stages, run in this order.
        for setup_stage in (
            self._prepare_datasets,
            self._initialize_models,
            self._initialize_optimizer_scheduler,
        ):
            setup_stage()

        amp_enabled = self.config.get("use_amp", True)
        self.scaler = GradScaler("cuda", enabled=amp_enabled)  # GradScaler for AMP

        # Early-stopping bookkeeping.
        self.best_spearman = -float("inf")
        self.best_pearson = -float("inf")
        self.patience_counter = 0

        # All three word-level augmenters share the same augmentation probability.
        word_aug_p = self.config["aug_p"]
        self.augmenters = [
            naw.SynonymAug(aug_src='wordnet', aug_p=word_aug_p),
            naw.RandomWordAug(action="swap", aug_p=word_aug_p),
            naw.RandomWordAug(aug_p=word_aug_p),
        ]

        # Test-set metric history.
        self.test_sts_pearson_cosine_values = []
        self.test_sts_spearman_cosine_values = []
        self.test_iterations = []

        # Built last: the plot footer lists the augmenters defined above.
        self._create_plot()

    def _create_plot(self):
        self.loss_values = []
        self.sts_pearson_cosine_values = []
        self.sts_spearman_cosine_values = []
        self.mean_grad_norm_values = []
        self.variance_values = []
        self.learning_rate_values = []
        self.iterations = []
        self.epochs = []

        # 3 rows x 2 columns grid with 6 subplots
        self.fig = make_subplots(
            rows=3, cols=2,
            subplot_titles=(
                "Loss vs Iterations",
                "Mean Gradient Norm vs Iterations",
                "Variance vs Iterations",
                "Learning Rate vs Iterations",
                "Dev STS Cosine (Pearson & Spearman) vs Iterations",
                "Test STS Cosine (Pearson & Spearman) vs Iterations"
            )
        )
        # Dev metrics
        self.fig.add_trace(go.Scatter(x=[], y=[], mode='lines+markers', name='Loss'), row=1, col=1)
        self.fig.add_trace(go.Scatter(x=[], y=[], mode='lines+markers', name='Mean Gradient Norm'), row=1, col=2)
        self.fig.add_trace(go.Scatter(x=[], y=[], mode='lines+markers', name='Variance'), row=2, col=1)
        self.fig.add_trace(go.Scatter(x=[], y=[], mode='lines+markers', name='Learning Rate'), row=2, col=2)
        self.fig.add_trace(go.Scatter(x=[], y=[], mode='lines+markers', name='Dev STS Pearson Cosine'), row=3, col=1)
        self.fig.add_trace(go.Scatter(x=[], y=[], mode='lines+markers', name='Dev STS Spearman Cosine'), row=3, col=1)
        # Test metrics
        self.fig.add_trace(go.Scatter(x=[], y=[], mode='lines+markers', name='Test STS Pearson Cosine'), row=3, col=2)
        self.fig.add_trace(go.Scatter(x=[], y=[], mode='lines+markers', name='Test STS Spearman Cosine'), row=3, col=2)

        # Prepare footer text
        footer_text = ", ".join([f"{key}={value}" for key, value in self.config.items()])
        augmenters_text = ", Augmenters: " + ", ".join([f"{aug.name}[{aug.action}:{aug.aug_p}]" for aug in self.augmenters])
        footer_text += augmenters_text
        wrapped_footer = "<br>".join(textwrap.wrap(footer_text, width=160))

        # Configure the download button
        self.plot_config = {
            'toImageButtonOptions': {
                'filename': self.config["model_save_path"],
                'format': 'png',
                'width': 1200,
                'height': 800,
                'scale': 1
            }
        }

        self.fig.update_layout(
            width=1200,
            height=800,
            title_text='Training Metrics',
            showlegend=True,
            margin=dict(l=50, r=50, t=100, b=150),
            annotations=[
                dict(
                    text=wrapped_footer,
                    showarrow=False,
                    xref="paper",
                    yref="paper",
                    x=0.5,
                    y=-0.15,
                    xanchor='center',
                    yanchor='top',
                    align="center",
                    font=dict(size=10, color="gray")
                )
            ]
        )

        # Annotation for best metrics
        self.best_metrics_annotation_index = len(self.fig.layout.annotations)
        self.fig.add_annotation(
            text="Best Spearman (Test): N/A<br>Best Pearson (Test): N/A<br>Best Spearman (Val): N/A<br>Best Pearson (Val): N/A",
            showarrow=False,
            xref="paper",
            yref="paper",
            x=1.0,
            y=0.0,
            xanchor='right',
            yanchor='bottom',
            align="right",
            font=dict(size=12, color="blue")
        )

        self.fig.show(config=self.plot_config)

    def _update_traces(self):
        with self.fig.batch_update():
            self.fig.data[0].x = self.iterations
            self.fig.data[0].y = self.loss_values
            self.fig.data[1].x = self.iterations
            self.fig.data[1].y = self.mean_grad_norm_values
            self.fig.data[2].x = self.iterations
            self.fig.data[2].y = self.variance_values
            self.fig.data[3].x = self.iterations
            self.fig.data[3].y = self.learning_rate_values
            self.fig.data[4].x = self.iterations
            self.fig.data[4].y = self.sts_pearson_cosine_values
            self.fig.data[5].x = self.iterations
            self.fig.data[5].y = self.sts_spearman_cosine_values
            self.fig.data[6].x = self.test_iterations
            self.fig.data[6].y = self.test_sts_pearson_cosine_values
            self.fig.data[7].x = self.test_iterations
            self.fig.data[7].y = self.test_sts_spearman_cosine_values

            for i in range(1, 4):
                for j in range(1, 3):
                    self.fig.update_yaxes(autorange=True, row=i, col=j)
                    self.fig.update_xaxes(autorange=True, row=i, col=j)

    def _update_plot(self):
        self._update_traces()
        unique_epochs = sorted(set(self.epochs))
        frames = []
        for ep in unique_epochs:
            indices = [i for i, e in enumerate(self.epochs) if e == ep]
            frame = go.Frame(
                data=[
                    go.Scatter(x=[self.iterations[i] for i in indices], y=[self.loss_values[i] for i in indices]),
                    go.Scatter(x=[self.iterations[i] for i in indices], y=[self.mean_grad_norm_values[i] for i in indices]),
                    go.Scatter(x=[self.iterations[i] for i in indices], y=[self.variance_values[i] for i in indices]),
                    go.Scatter(x=[self.iterations[i] for i in indices], y=[self.learning_rate_values[i] for i in indices]),
                    go.Scatter(x=[self.iterations[i] for i in indices], y=[self.sts_pearson_cosine_values[i] for i in indices]),
                    go.Scatter(x=[self.iterations[i] for i in indices], y=[self.sts_spearman_cosine_values[i] for i in indices])
                ],
                name=str(ep)
            )
            frames.append(frame)
        self.fig.frames = frames

        slider_steps = [
            {"args": [[str(ep)], {"frame": {"duration": 0, "redraw": True},
                                   "mode": "immediate", "transition": {"duration": 0}}],
             "label": str(ep), "method": "animate"} for ep in unique_epochs
        ]

        self.fig.update_layout(
            sliders=[{
                "active": len(unique_epochs) - 1 if unique_epochs else 0,
                "currentvalue": {"prefix": "Epoch: "},
                "pad": {"t": 50},
                "steps": slider_steps
            }]
        )

        # Compute best test metrics if available
        if self.test_sts_spearman_cosine_values:
            best_test_spearman = max(self.test_sts_spearman_cosine_values)
        else:
            best_test_spearman = float('nan')
        if self.test_sts_pearson_cosine_values:
            best_test_pearson = max(self.test_sts_pearson_cosine_values)
        else:
            best_test_pearson = float('nan')

        self.fig.layout.annotations[self.best_metrics_annotation_index].text = (
            f"Best Spearman (Test): {best_test_spearman:.4f}<br>"
            f"Best Pearson (Test): {best_test_pearson:.4f}<br>"
            f"Best Spearman (Val): {self.best_spearman:.4f}<br>"
            f"Best Pearson (Val): {self.best_pearson:.4f}"
        )

        for i in range(1, 4):
            for j in range(1, 3):
                self.fig.update_yaxes(autorange=True, row=i, col=j)
                self.fig.update_xaxes(autorange=True, row=i, col=j)

        clear_output(wait=True)
        self.fig.show(config=self.plot_config)

    def _prepare_datasets(self):
        """Load the training corpus and the STS-B splits, and derive the evaluation cadence.

        Downloads the SimCSE Wikipedia sentence file if it is not on disk, keeps
        every stripped line of at least 10 characters as a training sentence,
        loads the STS-B train/validation/test splits, and builds the training
        ``DataLoader`` and the test-set evaluator.

        Side effects:
            Sets ``train_sentences``, ``train_dataset``, ``eval_dataset``,
            ``test_dataset``, ``train_data_loader``, ``test_evaluator`` and
            ``evaluate_steps``. May raise ``self.config["patience"]`` in place.
        """
        wikipedia_url = "https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt"
        wikipedia_dataset_path = "data/wiki1m_for_simcse.txt"
        if not os.path.exists(wikipedia_dataset_path):
            util.http_get(wikipedia_url, wikipedia_dataset_path)
        with open(wikipedia_dataset_path, encoding="utf8") as f:
            # Strip every line once and drop fragments shorter than 10 characters.
            self.train_sentences = [line for line in map(str.strip, f) if len(line) >= 10]

        self.train_dataset = load_dataset("sentence-transformers/stsb", split="train")
        self.eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
        self.test_dataset = load_dataset("sentence-transformers/stsb", split="test")

        self.train_data_loader = DataLoader(
            SentenceDataset(self.train_sentences),
            batch_size=self.config["batch_size"],
            shuffle=True,
            num_workers=self.config["num_workers"],
            pin_memory=True
        )
        self.test_evaluator = EmbeddingSimilarityEvaluator(
            sentences1=self.test_dataset["sentence1"],
            sentences2=self.test_dataset["sentence2"],
            scores=self.test_dataset["score"]
        )

        # Evaluate about 50 times per epoch, and at least every step for tiny loaders.
        self.evaluate_steps = max(len(self.train_data_loader) // 50, 1)

        # The patience counter advances by `evaluate_steps` per non-improving
        # evaluation, so require at least five evaluation intervals of patience.
        # Read the value once with a default so a config without "patience"
        # is adjusted instead of raising KeyError while formatting the warning.
        min_patience = self.evaluate_steps * 5
        patience = self.config.get("patience", 0)
        if patience < min_patience:
            print(f"Warning: Patience ({patience}) is less than evaluation steps x5 ({min_patience}). Adjusting patience to {min_patience}.")
            self.config["patience"] = min_patience


    def _get_random_augmentation(self):
        return random.choice(self.augmenters)

    def _apply_augmentation(self, sentences, aug):
        return aug.augment(sentences)

    def _initialize_models(self):
        """Build the trainable model and a head-less copy used for evaluation.

        ``online_model`` is transformer -> pooling -> projection head. The head
        is one ``Linear`` followed by ``projection_depth - 1`` blocks of
        ``BatchNorm1d -> ReLU -> Linear``. ``encoder`` is a deep copy of the
        first two modules only (no projection head).

        Side effects:
            Sets ``self.online_model`` and ``self.encoder``, both on ``self.device``.
        """
        word_embedding_model = models.Transformer(
            self.config["model_name"],
            max_seq_length=self.config["max_seq_length"],
            config_args={"attention_dropout": self.config["aug_p"], "dropout": self.config["aug_p"]}
        )
        # Take the width from the encoder instead of hard-coding 768, so a
        # different `model_name` still yields a matching projection head.
        embedding_dim = word_embedding_model.get_word_embedding_dimension()
        pooling_model = models.Pooling(embedding_dim)

        projection_size = self.config["projection_size"]
        projection_layers = [torch.nn.Linear(embedding_dim, projection_size)]
        for _ in range(self.config["projection_depth"] - 1):
            projection_layers.append(torch.nn.BatchNorm1d(projection_size))
            projection_layers.append(torch.nn.ReLU())
            projection_layers.append(torch.nn.Linear(projection_size, projection_size))
        projection_head = torch.nn.Sequential(*projection_layers)

        self.online_model = SentenceTransformer(modules=[word_embedding_model, pooling_model, projection_head]).to(self.device)
        # The evaluation encoder keeps only transformer + pooling.
        encoder_modules = [copy.deepcopy(self.online_model[i]) for i in range(2)]
        self.encoder = SentenceTransformer(modules=encoder_modules).to(self.device)

    def _update_encoder(self):
        for i in range(len(self.encoder)):
            self.encoder[i].load_state_dict(self.online_model[i].state_dict())

    def _initialize_optimizer_scheduler(self):
        self.optimizer = torch.optim.Adam(
            self.online_model.parameters(),
            lr=self.config["learning_rate"],
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=self.config["weight_decay"]
        )
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.5, patience=200, verbose=True)

    def _forward_pass(self, model, sentences, train):
        if train:
            model.train()
        else:
            model.eval()
        features = model.tokenize(sentences)
        features = {k: v.to(self.device, non_blocking=True) for k, v in features.items()}
        with autocast("cuda", enabled=self.config.get("use_amp", True)):
            embeddings = model[0](features)["token_embeddings"]
            pooled = model[1]({"token_embeddings": embeddings})["sentence_embedding"]
            if len(model) > 2:
                return model[2](pooled)
            else:
                return pooled

    def _mixed_barlow_twins_loss(self, z_a, z_b):
        N, D = z_a.size()
        z_a_norm = (z_a - z_a.mean(dim=0)) / (z_a.std(dim=0) + 1e-6)
        z_b_norm = (z_b - z_b.mean(dim=0)) / (z_b.std(dim=0) + 1e-6)
        c = torch.matmul(z_a_norm.T, z_b_norm) / N
        I = torch.eye(D, device=z_a.device)
        c_diff = (c - I).pow(2)
        off_diag_mask = ~torch.eye(D, dtype=torch.bool, device=z_a.device)
        c_diff[off_diag_mask] *= self.config["lambda_bt"]
        loss_bt = c_diff.sum()
        # MixUp Regularization
        idx = torch.randperm(N)
        alpha = torch.tensor(np.random.beta(1.0, 1.0), device=z_a.device, dtype=z_a.dtype)
        z_m = alpha * z_a + (1 - alpha) * z_b[idx, :]
        z_m_norm = (z_m - z_m.mean(dim=0)) / (z_m.std(dim=0) + 1e-6)
        cc_m_a = torch.matmul(z_m_norm.T, z_a_norm) / N
        cc_m_b = torch.matmul(z_m_norm.T, z_b_norm) / N
        cc_m_a_gt = alpha * torch.matmul(z_a_norm.T, z_a_norm) / N + (1 - alpha) * torch.matmul(z_b_norm[idx, :].T, z_a_norm) / N
        cc_m_b_gt = alpha * torch.matmul(z_a_norm.T, z_b_norm) / N + (1 - alpha) * torch.matmul(z_b_norm[idx, :].T, z_b_norm) / N
        loss_mix = self.config["lambda_mixup"] * self.config["lambda_bt"] * (
            (cc_m_a - cc_m_a_gt).pow(2).sum() + (cc_m_b - cc_m_b_gt).pow(2).sum()
        )
        return loss_bt + loss_mix

    def _evaluate_without_heads(self):
        """Score the head-less encoder on the STS-B validation split.

        Syncs ``self.encoder`` with the online model, switches it to eval mode
        and runs an ``EmbeddingSimilarityEvaluator`` over the validation pairs
        in a freshly shuffled order.

        Returns:
            Whatever the evaluator returns for ``self.encoder``; callers read
            the ``spearman_cosine`` and ``pearson_cosine`` entries.
        """
        self._update_encoder()
        self.encoder.eval()
        # Fetch each column once. The original looked the column up again for
        # every row, which for a Hugging Face dataset may rebuild the whole
        # column on each access.
        all_sentences1 = self.eval_dataset["sentence1"]
        all_sentences2 = self.eval_dataset["sentence2"]
        all_scores = self.eval_dataset["score"]
        indices = list(range(len(all_sentences1)))
        random.shuffle(indices)
        sentences1 = [all_sentences1[i] for i in indices]
        sentences2 = [all_sentences2[i] for i in indices]
        scores = [all_scores[i] for i in indices]
        evaluator = EmbeddingSimilarityEvaluator(sentences1, sentences2, scores)
        return evaluator(self.encoder)

    def fit(self):
        """Train the online model and track validation/test STS-B scores.

        For every batch two randomly chosen augmentations produce two views,
        the mixed Barlow Twins loss is minimised under optional AMP, and the
        per-step metrics are appended to the history lists. Every
        ``evaluate_steps`` batches, and on the last batch of an epoch, the
        head-less encoder is evaluated on the validation and test splits, the
        best validation scores and the patience counter are updated, and the
        live plot is refreshed. Training stops early once the patience counter
        reaches ``config["patience"]``.
        """
        latest_eval_metrics = {}
        use_amp = self.config.get("use_amp", True)
        # The parameter list does not change while training; count it once.
        num_param_tensors = len(list(self.online_model.parameters()))
        last_batch_idx = len(self.train_data_loader) - 1
        for epoch in range(self.config["epochs"]):
            early_stop = False
            epoch_loss = 0
            pbar = tqdm(self.train_data_loader, desc=f"Epoch {epoch+1}/{self.config['epochs']}:")
            for idx, sentences in enumerate(pbar):
                s1 = self._apply_augmentation(sentences, self._get_random_augmentation())
                s2 = self._apply_augmentation(sentences, self._get_random_augmentation())
                with autocast("cuda", enabled=use_amp):
                    z_a = self._forward_pass(self.online_model, s1, train=True)
                    z_b = self._forward_pass(self.online_model, s2, train=True)
                    loss = self._mixed_barlow_twins_loss(z_a, z_b)
                # One device-to-host sync per step instead of one per use.
                loss_value = loss.item()
                epoch_loss += loss_value
                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()

                # Gradients are still multiplied by the AMP loss scale here;
                # divide the global L2 norm by the scale to report the true norm.
                scale = self.scaler.get_scale()
                scale = scale if scale != 0 else 1e-8
                total_norm_scaled = 0.0
                for param in self.online_model.parameters():
                    if param.grad is not None:
                        param_norm = param.grad.data.norm(2).item()
                        total_norm_scaled += param_norm ** 2
                total_norm = math.sqrt(total_norm_scaled) / scale
                mean_grad_norm = total_norm / num_param_tensors

                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.scheduler.step(loss_value)

                # idx == 0 is already covered by the modulo test.
                evaluated = idx % self.evaluate_steps == 0 or idx == last_batch_idx
                if evaluated:
                    latest_eval_metrics = self._evaluate_without_heads()
                    last_spearman = latest_eval_metrics.get('spearman_cosine', -float('inf'))
                    last_pearson = latest_eval_metrics.get('pearson_cosine', -float('inf'))
                    # _evaluate_without_heads() has just synced the encoder and
                    # put it in eval mode, so it can be reused for the test split.
                    test_metrics = self.test_evaluator(self.encoder)
                    self.test_sts_pearson_cosine_values.append(test_metrics.get('pearson_cosine', np.nan))
                    self.test_sts_spearman_cosine_values.append(test_metrics.get('spearman_cosine', np.nan))
                    self.test_iterations.append(idx)

                    if last_spearman > self.best_spearman:
                        self.best_spearman = last_spearman
                        self.best_pearson = last_pearson
                        self.patience_counter = 0
                    else:
                        self.patience_counter += self.evaluate_steps

                current_lr = self.optimizer.param_groups[0]['lr']
                pbar.set_postfix({
                    "loss": loss_value,
                    **latest_eval_metrics,
                    "mean_grad_norm": mean_grad_norm,
                    "learning_rate": current_lr
                })
                self.loss_values.append(loss_value)
                self.sts_pearson_cosine_values.append(latest_eval_metrics.get('pearson_cosine', np.nan))
                self.sts_spearman_cosine_values.append(latest_eval_metrics.get('spearman_cosine', np.nan))
                self.mean_grad_norm_values.append(mean_grad_norm)
                self.variance_values.append(torch.var(z_a, dim=0).mean().item())
                self.learning_rate_values.append(current_lr)
                self.iterations.append(idx)
                self.epochs.append(epoch)

                if evaluated:
                    # Refresh the plot only after the best scores and this step's
                    # metrics are recorded. Plotting earlier showed stale best
                    # values ("-inf" on the first refresh) and curves one step behind.
                    self._update_plot()
                    if self.patience_counter >= self.config["patience"]:
                        early_stop = True
                        print(f"Early stopping triggered at epoch {epoch+1}, iteration {idx}.")
                        print(f"Best Spearman Correlation: {self.best_spearman}")
                        print(f"Best Pearson Correlation: {self.best_pearson}")
                        break
            if early_stop:
                break
            avg_loss = epoch_loss / len(self.train_data_loader)
            pbar.set_description(f"Epoch {epoch+1} Loss: {avg_loss}")

    def cleanup(self):
        del self.online_model, self.optimizer, self.scheduler, self.scaler
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

class SentenceDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves raw sentences by position."""

    def __init__(self, sentences):
        # Keep a reference to the caller's sequence; no copy is made.
        self.sentences = sentences

    def __len__(self):
        """Return how many sentences are stored."""
        sentence_count = len(self.sentences)
        return sentence_count

    def __getitem__(self, idx):
        """Return the sentence stored at ``idx``."""
        sentence = self.sentences[idx]
        return sentence

# Checkpoints and the Optuna study database are kept on Google Drive.
from google.colab import drive

drive.mount('/content/drive')

CHECKPOINT_DIR = "/content/drive/MyDrive/violet_checkpoints"
STUDY_DB_PATH = os.path.join(CHECKPOINT_DIR, "llm_finetuning_study.db")
# Create the checkpoint folder up front; a no-op when it already exists.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

def save_checkpoint(trial_number, trainer):
    """Persist a trainer's state and its metrics figure for one Optuna trial.

    Writes ``checkpoint_trial_<n>.pt`` into ``CHECKPOINT_DIR`` and tries to
    export the Plotly figure to ``graphs/checkpoint_trial_<n>.png``.

    Args:
        trial_number: Optuna trial number used in both file names.
        trainer: Trainer exposing the models, optimizer, scheduler, scaler,
            best scores, metric histories and ``fig``.

    Returns:
        Path of the written checkpoint file.
    """
    def _as_floats(values):
        # Metric values may be NumPy scalars; store plain floats so the file
        # contains only basic Python types and tensors.
        return [float(v) for v in values]

    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"checkpoint_trial_{trial_number}.pt")
    torch.save({
        'online_model_state_dict': trainer.online_model.state_dict(),
        'optimizer_state_dict': trainer.optimizer.state_dict(),
        'scheduler_state_dict': trainer.scheduler.state_dict(),
        'scaler_state_dict': trainer.scaler.state_dict(),
        'best_spearman': float(trainer.best_spearman),
        'best_pearson': float(trainer.best_pearson),
        'epochs': trainer.epochs,
        'iterations': trainer.iterations,
        # Per-step series that the plot indexes in lockstep with `epochs` and
        # `iterations`. Saving only those two leaves a resumed trainer with
        # histories of different lengths.
        'loss_values': _as_floats(trainer.loss_values),
        'sts_pearson_cosine_values': _as_floats(trainer.sts_pearson_cosine_values),
        'sts_spearman_cosine_values': _as_floats(trainer.sts_spearman_cosine_values),
        'mean_grad_norm_values': _as_floats(trainer.mean_grad_norm_values),
        'variance_values': _as_floats(trainer.variance_values),
        'learning_rate_values': _as_floats(trainer.learning_rate_values),
        # Per-evaluation series for the test-set subplot.
        'test_iterations': trainer.test_iterations,
        'test_sts_pearson_cosine_values': _as_floats(trainer.test_sts_pearson_cosine_values),
        'test_sts_spearman_cosine_values': _as_floats(trainer.test_sts_spearman_cosine_values),
    }, checkpoint_path)

    # Also save the graph to a subfolder called "graphs"
    graphs_dir = os.path.join(CHECKPOINT_DIR, "graphs")
    os.makedirs(graphs_dir, exist_ok=True)
    graph_path = os.path.join(graphs_dir, f"checkpoint_trial_{trial_number}.png")
    try:
        trainer.fig.write_image(graph_path)
    except Exception as exc:
        # The image export is a convenience. A failing renderer must not
        # discard a finished trial whose checkpoint is already on disk.
        print(f"Checkpoint saved to {checkpoint_path}; graph export to {graph_path} failed: {exc}")
    else:
        print(f"Checkpoint and graph saved to {checkpoint_path} and {graph_path}")

    return checkpoint_path

def load_checkpoint(trainer, checkpoint_path):
    """Restore a trainer from ``checkpoint_path`` if that file exists.

    Loads the model, optimizer and scheduler state, the AMP scaler state when
    present, and the best validation scores. The plotting histories are
    restored only when the checkpoint holds every per-step series with equal
    lengths; otherwise the trainer's histories are left untouched.

    Args:
        trainer: Trainer whose state is overwritten in place.
        checkpoint_path: Path of a checkpoint written by ``save_checkpoint``.
    """
    if not os.path.exists(checkpoint_path):
        return

    # map_location lets a checkpoint written on GPU be loaded on a CPU runtime.
    checkpoint = torch.load(checkpoint_path, map_location=trainer.device)
    trainer.online_model.load_state_dict(checkpoint['online_model_state_dict'])
    trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    if 'scaler_state_dict' in checkpoint and hasattr(trainer, 'scaler'):
        trainer.scaler.load_state_dict(checkpoint['scaler_state_dict'])
    trainer.best_spearman = checkpoint.get('best_spearman', -float("inf"))
    trainer.best_pearson = checkpoint.get('best_pearson', -float("inf"))

    per_step_keys = (
        'epochs', 'iterations', 'loss_values',
        'sts_pearson_cosine_values', 'sts_spearman_cosine_values',
        'mean_grad_norm_values', 'variance_values', 'learning_rate_values',
    )
    per_eval_keys = (
        'test_iterations',
        'test_sts_pearson_cosine_values', 'test_sts_spearman_cosine_values',
    )

    def _aligned(keys):
        # Return the stored series only if all are present with equal lengths.
        series = [checkpoint.get(key) for key in keys]
        if any(s is None for s in series) or len({len(s) for s in series}) != 1:
            return None
        return series

    # The plot indexes every per-step list with positions derived from
    # `epochs`. Restoring `epochs`/`iterations` alone, while the other lists
    # stay empty, raises IndexError on the first plot refresh after resuming.
    per_step = _aligned(per_step_keys)
    if per_step is None:
        print("Checkpoint has no complete metric history; plot history starts empty.")
    else:
        for key, values in zip(per_step_keys, per_step):
            setattr(trainer, key, list(values))
        per_eval = _aligned(per_eval_keys)
        if per_eval is not None:
            for key, values in zip(per_eval_keys, per_eval):
                setattr(trainer, key, list(values))
    print(f"Loaded checkpoint from {checkpoint_path}")

def objective(trial):
    """Optuna objective: train one configuration and return its validation score.

    Samples the hyperparameters, trains a ``BarlowTwinsNCSE`` (resuming from
    this trial's checkpoint if one exists), saves a checkpoint and returns the
    STS-B validation Spearman cosine correlation of the head-less encoder.

    Args:
        trial: Optuna trial used to sample hyperparameters.

    Returns:
        The ``spearman_cosine`` validation metric, or ``-inf`` when missing.

    Raises:
        optuna.TrialPruned: If training is interrupted from the keyboard; a
            checkpoint is saved first.
    """
    # Experiment with various hyperparameters.
    # suggest_float replaces the deprecated suggest_uniform / suggest_loguniform
    # and samples from the same ranges (log=True for the log-uniform one).
    config = {
        "model_name": "distilbert-base-uncased",
        "batch_size": trial.suggest_categorical("batch_size", [64, 128, 256, 512]),
        "projection_depth": trial.suggest_int("projection_depth", 2, 6),
        "projection_size": trial.suggest_categorical("projection_size", [2048, 4096, 6144, 8192]),
        "epochs": 1,  # fixed at 1
        "warmup_proportion": 0.0,
        "max_seq_length": trial.suggest_categorical("max_seq_length", [32, 64, 75]),
        "aug_p": trial.suggest_float("aug_p", 0.2, 0.4),
        "learning_rate": trial.suggest_float("learning_rate", 3e-5, 1e-3, log=True),
        "model_save_path": os.path.join(CHECKPOINT_DIR, f"train_stsb_bt-distilbert-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_trial{trial.number}"),
        "num_workers": 10,
        "weight_decay": trial.suggest_float("weight_decay", 0.1, 0.2),
        "lambda_bt": trial.suggest_float("lambda_bt", 0.001, 0.2),
        "lambda_mixup": trial.suggest_float("lambda_mixup", 0.6, 1.5),
        "use_amp": True,
        "patience": 500
    }

    trainer = BarlowTwinsNCSE(config)
    try:
        # Resume from checkpoint if available
        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"checkpoint_trial_{trial.number}.pt")
        if os.path.exists(checkpoint_path):
            load_checkpoint(trainer, checkpoint_path)

        try:
            trainer.fit()
        except KeyboardInterrupt:
            save_checkpoint(trial.number, trainer)
            raise optuna.TrialPruned("Trial interrupted and checkpoint saved.")

        # Save checkpoint at end of trial
        save_checkpoint(trial.number, trainer)

        # Use the STS-B validation metric (spearman cosine) as objective.
        val_metrics = trainer._evaluate_without_heads()
        return val_metrics.get("spearman_cosine", -float("inf"))
    finally:
        # Release this trial's model, optimizer and cached GPU memory before
        # the next trial builds its own, whether the trial succeeded or not.
        trainer.cleanup()

# Open the Drive-backed Optuna study, resuming it when it already exists,
# then run the hyperparameter search.
study = optuna.create_study(
    study_name="llm_finetuning_study",
    storage=f"sqlite:///{STUDY_DB_PATH}",
    direction="maximize",
    load_if_exists=True,
)
study.optimize(objective, n_trials=20)

In [None]:
import textwrap
import os
import torch
from datetime import datetime
from sentence_transformers import SentenceTransformer, models, util
from datasets import load_dataset
from torch.utils.data import DataLoader
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
import copy
import gc
import math
import random
import nlpaug.augmenter.word as naw
from IPython.display import display, clear_output
from tqdm import tqdm
import numpy as np
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from torch import GradScaler
from torch.amp import autocast
import optuna

# ----- BarlowTwinsNCSE trainer and SentenceDataset definitions -----
# (Includes the live-plotting helpers used during training.)

class BarlowTwinsNCSE:
    def __init__(self, config):
        """Set up data, models, optimisation state, augmenters and the live plot.

        Args:
            config: Dict of hyperparameters and settings; stored as
                ``self.config`` and read by the setup helpers.
        """
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        torch.backends.cudnn.benchmark = True  # Enable cuDNN benchmarking

        # Early-stopping bookkeeping.
        self.best_spearman = -float("inf")
        self.best_pearson = -float("inf")
        self.patience_counter = 0

        # Test-set metric history, appended to at every evaluation.
        self.test_sts_pearson_cosine_values = []
        self.test_sts_spearman_cosine_values = []
        self.test_iterations = []

        # Heavy setup, in dependency order: data, then models, then optimizer.
        for setup_step in (
            self._prepare_datasets,
            self._initialize_models,
            self._initialize_optimizer_scheduler,
        ):
            setup_step()

        amp_enabled = self.config.get("use_amp", True)
        self.scaler = GradScaler("cuda", enabled=amp_enabled)  # GradScaler for AMP

        aug_p = self.config["aug_p"]
        self.augmenters = [
            naw.SynonymAug(aug_src='wordnet', aug_p=aug_p),
            naw.RandomWordAug(action="swap", aug_p=aug_p),
            naw.RandomWordAug(aug_p=aug_p),
        ]

        # The plot footer lists the augmenters, so build it last.
        self._create_plot()

    def _create_plot(self):
        """Create the 3x2 training dashboard and reset the per-step metric histories.

        Side effects:
            Resets the eight per-step history lists, sets ``self.fig``,
            ``self.plot_config`` and ``self.best_metrics_annotation_index``,
            and displays the empty figure.
        """
        self.loss_values = []
        self.sts_pearson_cosine_values = []
        self.sts_spearman_cosine_values = []
        self.mean_grad_norm_values = []
        self.variance_values = []
        self.learning_rate_values = []
        self.iterations = []
        self.epochs = []

        # 3 rows x 2 columns grid with 6 subplots
        self.fig = make_subplots(
            rows=3, cols=2,
            subplot_titles=(
                "Loss vs Iterations",
                "Mean Gradient Norm vs Iterations",
                "Variance vs Iterations",
                "Learning Rate vs Iterations",
                "Dev STS Cosine (Pearson & Spearman) vs Iterations",
                "Test STS Cosine (Pearson & Spearman) vs Iterations"
            )
        )

        # (name, row, col) per trace. The order matters: _update_traces()
        # addresses these traces by position 0-7.
        trace_layout = (
            ('Loss', 1, 1),
            ('Mean Gradient Norm', 1, 2),
            ('Variance', 2, 1),
            ('Learning Rate', 2, 2),
            ('Dev STS Pearson Cosine', 3, 1),
            ('Dev STS Spearman Cosine', 3, 1),
            ('Test STS Pearson Cosine', 3, 2),
            ('Test STS Spearman Cosine', 3, 2),
        )
        for trace_name, row, col in trace_layout:
            self.fig.add_trace(
                go.Scatter(x=[], y=[], mode='lines+markers', name=trace_name),
                row=row, col=col
            )

        # Prepare footer text
        footer_text = ", ".join([f"{key}={value}" for key, value in self.config.items()])
        augmenters_text = ", Augmenters: " + ", ".join([f"{aug.name}[{aug.action}:{aug.aug_p}]" for aug in self.augmenters])
        footer_text += augmenters_text
        wrapped_footer = "<br>".join(textwrap.wrap(footer_text, width=160))

        # Configure the download button
        self.plot_config = {
            'toImageButtonOptions': {
                'filename': self.config["model_save_path"],
                'format': 'png',
                'width': 1200,
                'height': 800,
                'scale': 1
            }
        }

        self.fig.update_layout(
            width=1200,
            height=800,
            title_text='Training Metrics',
            showlegend=True,
            margin=dict(l=50, r=50, t=100, b=150)
        )

        # make_subplots stores the subplot titles as layout annotations.
        # Passing `annotations=[...]` to update_layout clobbers them, so the
        # footer is appended with add_annotation instead.
        self.fig.add_annotation(
            text=wrapped_footer,
            showarrow=False,
            xref="paper",
            yref="paper",
            x=0.5,
            y=-0.15,
            xanchor='center',
            yanchor='top',
            align="center",
            font=dict(size=10, color="gray")
        )

        # Annotation for best metrics; remember its position so _update_plot()
        # can rewrite its text later.
        self.best_metrics_annotation_index = len(self.fig.layout.annotations)
        self.fig.add_annotation(
            text="Best Spearman (Test): N/A<br>Best Pearson (Test): N/A<br>Best Spearman (Val): N/A<br>Best Pearson (Val): N/A",
            showarrow=False,
            xref="paper",
            yref="paper",
            x=1.0,
            y=0.0,
            xanchor='right',
            yanchor='bottom',
            align="right",
            font=dict(size=12, color="blue")
        )

        self.fig.show(config=self.plot_config)

    def _update_traces(self):
        with self.fig.batch_update():
            self.fig.data[0].x = self.iterations
            self.fig.data[0].y = self.loss_values
            self.fig.data[1].x = self.iterations
            self.fig.data[1].y = self.mean_grad_norm_values
            self.fig.data[2].x = self.iterations
            self.fig.data[2].y = self.variance_values
            self.fig.data[3].x = self.iterations
            self.fig.data[3].y = self.learning_rate_values
            self.fig.data[4].x = self.iterations
            self.fig.data[4].y = self.sts_pearson_cosine_values
            self.fig.data[5].x = self.iterations
            self.fig.data[5].y = self.sts_spearman_cosine_values
            self.fig.data[6].x = self.test_iterations
            self.fig.data[6].y = self.test_sts_pearson_cosine_values
            self.fig.data[7].x = self.test_iterations
            self.fig.data[7].y = self.test_sts_spearman_cosine_values

            for i in range(1, 4):
                for j in range(1, 3):
                    self.fig.update_yaxes(autorange=True, row=i, col=j)
                    self.fig.update_xaxes(autorange=True, row=i, col=j)

    def _update_plot(self):
        """Refresh the live figure: traces, per-epoch frames, slider and best-score box.

        Rebuilds one animation frame per recorded epoch from the per-step
        histories, points the slider at the latest epoch, rewrites the
        best-metrics annotation and re-displays the figure in place of the
        previous cell output.
        """
        # Also resets every axis to autorange.
        self._update_traces()

        # Group sample positions by epoch in a single pass instead of
        # rescanning `self.epochs` once per epoch.
        positions_by_epoch = {}
        for position, ep in enumerate(self.epochs):
            positions_by_epoch.setdefault(ep, []).append(position)
        unique_epochs = sorted(positions_by_epoch)

        # Frames cover the six per-step traces, in trace order.
        per_step_series = (
            self.loss_values,
            self.mean_grad_norm_values,
            self.variance_values,
            self.learning_rate_values,
            self.sts_pearson_cosine_values,
            self.sts_spearman_cosine_values,
        )
        frames = []
        for ep in unique_epochs:
            positions = positions_by_epoch[ep]
            x_values = [self.iterations[i] for i in positions]
            frames.append(go.Frame(
                data=[
                    go.Scatter(x=x_values, y=[series[i] for i in positions])
                    for series in per_step_series
                ],
                name=str(ep)
            ))
        self.fig.frames = frames

        slider_steps = [
            {"args": [[str(ep)], {"frame": {"duration": 0, "redraw": True},
                                   "mode": "immediate", "transition": {"duration": 0}}],
             "label": str(ep), "method": "animate"} for ep in unique_epochs
        ]

        self.fig.update_layout(
            sliders=[{
                "active": len(unique_epochs) - 1 if unique_epochs else 0,
                "currentvalue": {"prefix": "Epoch: "},
                "pad": {"t": 50},
                "steps": slider_steps
            }]
        )

        def _best(values):
            # max() is order-dependent when NaN is present; ignore NaN entries.
            known = [v for v in values if not math.isnan(v)]
            return max(known) if known else float('nan')

        def _fmt(value):
            # Same placeholder as the initial annotation for "no score yet".
            return f"{value:.4f}" if math.isfinite(value) else "N/A"

        best_test_spearman = _best(self.test_sts_spearman_cosine_values)
        best_test_pearson = _best(self.test_sts_pearson_cosine_values)

        self.fig.layout.annotations[self.best_metrics_annotation_index].text = (
            f"Best Spearman (Test): {_fmt(best_test_spearman)}<br>"
            f"Best Pearson (Test): {_fmt(best_test_pearson)}<br>"
            f"Best Spearman (Val): {_fmt(self.best_spearman)}<br>"
            f"Best Pearson (Val): {_fmt(self.best_pearson)}"
        )

        clear_output(wait=True)
        self.fig.show(config=self.plot_config)

    def _prepare_datasets(self):
        """Load the training corpus and the STS-B splits, and derive the evaluation cadence.

        Downloads the SimCSE Wikipedia sentence file if it is not on disk, keeps
        every stripped line of at least 10 characters as a training sentence,
        loads the STS-B train/validation/test splits, and builds the training
        ``DataLoader`` and the test-set evaluator.

        Side effects:
            Sets ``train_sentences``, ``train_dataset``, ``eval_dataset``,
            ``test_dataset``, ``train_data_loader``, ``test_evaluator`` and
            ``evaluate_steps``. May raise ``self.config["patience"]`` in place.
        """
        wikipedia_url = "https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt"
        wikipedia_dataset_path = "data/wiki1m_for_simcse.txt"
        if not os.path.exists(wikipedia_dataset_path):
            util.http_get(wikipedia_url, wikipedia_dataset_path)
        with open(wikipedia_dataset_path, encoding="utf8") as f:
            # Strip every line once and drop fragments shorter than 10 characters.
            self.train_sentences = [line for line in map(str.strip, f) if len(line) >= 10]

        self.train_dataset = load_dataset("sentence-transformers/stsb", split="train")
        self.eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
        self.test_dataset = load_dataset("sentence-transformers/stsb", split="test")

        self.train_data_loader = DataLoader(
            SentenceDataset(self.train_sentences),
            batch_size=self.config["batch_size"],
            shuffle=True,
            num_workers=self.config["num_workers"],
            pin_memory=True
        )
        self.test_evaluator = EmbeddingSimilarityEvaluator(
            sentences1=self.test_dataset["sentence1"],
            sentences2=self.test_dataset["sentence2"],
            scores=self.test_dataset["score"]
        )

        # Evaluate about 50 times per epoch, and at least every step for tiny loaders.
        self.evaluate_steps = max(len(self.train_data_loader) // 50, 1)

        # The patience counter advances by `evaluate_steps` per non-improving
        # evaluation, so require at least five evaluation intervals of patience.
        # Read the value once with a default so a config without "patience"
        # is adjusted instead of raising KeyError while formatting the warning.
        min_patience = self.evaluate_steps * 5
        patience = self.config.get("patience", 0)
        if patience < min_patience:
            print(f"Warning: Patience ({patience}) is less than evaluation steps x5 ({min_patience}). Adjusting patience to {min_patience}.")
            self.config["patience"] = min_patience


    def _get_random_augmentation(self):
        return random.choice(self.augmenters)

    def _apply_augmentation(self, sentences, aug):
        return aug.augment(sentences)

    def _initialize_models(self):
        """Build the trainable model and a head-less copy used for evaluation.

        ``online_model`` is transformer -> pooling -> projection head. The head
        is one ``Linear`` followed by ``projection_depth - 1`` blocks of
        ``BatchNorm1d -> ReLU -> Linear``. ``encoder`` is a deep copy of the
        first two modules only (no projection head).

        Side effects:
            Sets ``self.online_model`` and ``self.encoder``, both on ``self.device``.
        """
        word_embedding_model = models.Transformer(
            self.config["model_name"],
            max_seq_length=self.config["max_seq_length"],
            config_args={"attention_dropout": self.config["aug_p"], "dropout": self.config["aug_p"]}
        )
        # Take the width from the encoder instead of hard-coding 768, so a
        # different `model_name` still yields a matching projection head.
        embedding_dim = word_embedding_model.get_word_embedding_dimension()
        pooling_model = models.Pooling(embedding_dim)

        projection_size = self.config["projection_size"]
        projection_layers = [torch.nn.Linear(embedding_dim, projection_size)]
        for _ in range(self.config["projection_depth"] - 1):
            projection_layers.append(torch.nn.BatchNorm1d(projection_size))
            projection_layers.append(torch.nn.ReLU())
            projection_layers.append(torch.nn.Linear(projection_size, projection_size))
        projection_head = torch.nn.Sequential(*projection_layers)

        self.online_model = SentenceTransformer(modules=[word_embedding_model, pooling_model, projection_head]).to(self.device)
        # The evaluation encoder keeps only transformer + pooling.
        encoder_modules = [copy.deepcopy(self.online_model[i]) for i in range(2)]
        self.encoder = SentenceTransformer(modules=encoder_modules).to(self.device)

    def _update_encoder(self):
        for i in range(len(self.encoder)):
            self.encoder[i].load_state_dict(self.online_model[i].state_dict())

    def _initialize_optimizer_scheduler(self):
        self.optimizer = torch.optim.Adam(
            self.online_model.parameters(),
            lr=self.config["learning_rate"],
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=self.config["weight_decay"]
        )
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.5, patience=200, verbose=True)

    def _forward_pass(self, model, sentences, train):
        if train:
            model.train()
        else:
            model.eval()
        features = model.tokenize(sentences)
        features = {k: v.to(self.device, non_blocking=True) for k, v in features.items()}
        with autocast("cuda", enabled=self.config.get("use_amp", True)):
            embeddings = model[0](features)["token_embeddings"]
            pooled = model[1]({"token_embeddings": embeddings})["sentence_embedding"]
            if len(model) > 2:
                return model[2](pooled)
            else:
                return pooled

    def _mixed_barlow_twins_loss(self, z_a, z_b):
        N, D = z_a.size()
        z_a_norm = (z_a - z_a.mean(dim=0)) / (z_a.std(dim=0) + 1e-6)
        z_b_norm = (z_b - z_b.mean(dim=0)) / (z_b.std(dim=0) + 1e-6)
        c = torch.matmul(z_a_norm.T, z_b_norm) / N
        I = torch.eye(D, device=z_a.device)
        c_diff = (c - I).pow(2)
        off_diag_mask = ~torch.eye(D, dtype=torch.bool, device=z_a.device)
        c_diff[off_diag_mask] *= self.config["lambda_bt"]
        loss_bt = c_diff.sum()
        # MixUp Regularization
        idx = torch.randperm(N)
        alpha = torch.tensor(np.random.beta(1.0, 1.0), device=z_a.device, dtype=z_a.dtype)
        z_m = alpha * z_a + (1 - alpha) * z_b[idx, :]
        z_m_norm = (z_m - z_m.mean(dim=0)) / (z_m.std(dim=0) + 1e-6)
        cc_m_a = torch.matmul(z_m_norm.T, z_a_norm) / N
        cc_m_b = torch.matmul(z_m_norm.T, z_b_norm) / N
        cc_m_a_gt = alpha * torch.matmul(z_a_norm.T, z_a_norm) / N + (1 - alpha) * torch.matmul(z_b_norm[idx, :].T, z_a_norm) / N
        cc_m_b_gt = alpha * torch.matmul(z_a_norm.T, z_b_norm) / N + (1 - alpha) * torch.matmul(z_b_norm[idx, :].T, z_b_norm) / N
        loss_mix = self.config["lambda_mixup"] * self.config["lambda_bt"] * (
            (cc_m_a - cc_m_a_gt).pow(2).sum() + (cc_m_b - cc_m_b_gt).pow(2).sum()
        )
        return loss_bt + loss_mix

    def _evaluate_without_heads(self):
        """Score the head-less encoder on the STS-B validation split.

        Syncs ``self.encoder`` with the online model, switches it to eval mode
        and runs an ``EmbeddingSimilarityEvaluator`` over the validation pairs
        in a freshly shuffled order.

        Returns:
            Whatever the evaluator returns for ``self.encoder``; callers read
            the ``spearman_cosine`` and ``pearson_cosine`` entries.
        """
        self._update_encoder()
        self.encoder.eval()
        # Fetch each column once. The original looked the column up again for
        # every row, which for a Hugging Face dataset may rebuild the whole
        # column on each access.
        all_sentences1 = self.eval_dataset["sentence1"]
        all_sentences2 = self.eval_dataset["sentence2"]
        all_scores = self.eval_dataset["score"]
        indices = list(range(len(all_sentences1)))
        random.shuffle(indices)
        sentences1 = [all_sentences1[i] for i in indices]
        sentences2 = [all_sentences2[i] for i in indices]
        scores = [all_scores[i] for i in indices]
        evaluator = EmbeddingSimilarityEvaluator(sentences1, sentences2, scores)
        return evaluator(self.encoder)

    def fit(self):
        """Train the online model and track validation/test STS-B scores.

        For every batch two randomly chosen augmentations produce two views,
        the mixed Barlow Twins loss is minimised under optional AMP, and the
        per-step metrics are appended to the history lists. Every
        ``evaluate_steps`` batches, and on the last batch of an epoch, the
        head-less encoder is evaluated on the validation and test splits, the
        best validation scores and the patience counter are updated, and the
        live plot is refreshed. Training stops early once the patience counter
        reaches ``config["patience"]``.
        """
        latest_eval_metrics = {}
        use_amp = self.config.get("use_amp", True)
        # The parameter list does not change while training; count it once.
        num_param_tensors = len(list(self.online_model.parameters()))
        last_batch_idx = len(self.train_data_loader) - 1
        for epoch in range(self.config["epochs"]):
            early_stop = False
            epoch_loss = 0
            pbar = tqdm(self.train_data_loader, desc=f"Epoch {epoch+1}/{self.config['epochs']}:")
            for idx, sentences in enumerate(pbar):
                s1 = self._apply_augmentation(sentences, self._get_random_augmentation())
                s2 = self._apply_augmentation(sentences, self._get_random_augmentation())
                with autocast("cuda", enabled=use_amp):
                    z_a = self._forward_pass(self.online_model, s1, train=True)
                    z_b = self._forward_pass(self.online_model, s2, train=True)
                    loss = self._mixed_barlow_twins_loss(z_a, z_b)
                # One device-to-host sync per step instead of one per use.
                loss_value = loss.item()
                epoch_loss += loss_value
                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()

                # Gradients are still multiplied by the AMP loss scale here;
                # divide the global L2 norm by the scale to report the true norm.
                scale = self.scaler.get_scale()
                scale = scale if scale != 0 else 1e-8
                total_norm_scaled = 0.0
                for param in self.online_model.parameters():
                    if param.grad is not None:
                        param_norm = param.grad.data.norm(2).item()
                        total_norm_scaled += param_norm ** 2
                total_norm = math.sqrt(total_norm_scaled) / scale
                mean_grad_norm = total_norm / num_param_tensors

                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.scheduler.step(loss_value)

                # idx == 0 is already covered by the modulo test.
                evaluated = idx % self.evaluate_steps == 0 or idx == last_batch_idx
                if evaluated:
                    latest_eval_metrics = self._evaluate_without_heads()
                    last_spearman = latest_eval_metrics.get('spearman_cosine', -float('inf'))
                    last_pearson = latest_eval_metrics.get('pearson_cosine', -float('inf'))
                    # _evaluate_without_heads() has just synced the encoder and
                    # put it in eval mode, so it can be reused for the test split.
                    test_metrics = self.test_evaluator(self.encoder)
                    self.test_sts_pearson_cosine_values.append(test_metrics.get('pearson_cosine', np.nan))
                    self.test_sts_spearman_cosine_values.append(test_metrics.get('spearman_cosine', np.nan))
                    self.test_iterations.append(idx)

                    if last_spearman > self.best_spearman:
                        self.best_spearman = last_spearman
                        self.best_pearson = last_pearson
                        self.patience_counter = 0
                    else:
                        self.patience_counter += self.evaluate_steps

                current_lr = self.optimizer.param_groups[0]['lr']
                pbar.set_postfix({
                    "loss": loss_value,
                    **latest_eval_metrics,
                    "mean_grad_norm": mean_grad_norm,
                    "learning_rate": current_lr
                })
                self.loss_values.append(loss_value)
                self.sts_pearson_cosine_values.append(latest_eval_metrics.get('pearson_cosine', np.nan))
                self.sts_spearman_cosine_values.append(latest_eval_metrics.get('spearman_cosine', np.nan))
                self.mean_grad_norm_values.append(mean_grad_norm)
                self.variance_values.append(torch.var(z_a, dim=0).mean().item())
                self.learning_rate_values.append(current_lr)
                self.iterations.append(idx)
                self.epochs.append(epoch)

                if evaluated:
                    # Refresh the plot only after the best scores and this step's
                    # metrics are recorded. Plotting earlier showed stale best
                    # values ("-inf" on the first refresh) and curves one step behind.
                    self._update_plot()
                    if self.patience_counter >= self.config["patience"]:
                        early_stop = True
                        print(f"Early stopping triggered at epoch {epoch+1}, iteration {idx}.")
                        print(f"Best Spearman Correlation: {self.best_spearman}")
                        print(f"Best Pearson Correlation: {self.best_pearson}")
                        break
            if early_stop:
                break
            avg_loss = epoch_loss / len(self.train_data_loader)
            pbar.set_description(f"Epoch {epoch+1} Loss: {avg_loss}")

    def cleanup(self):
        del self.online_model, self.optimizer, self.scheduler, self.scaler
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

class SentenceDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves sentences from an indexable container."""

    def __init__(self, sentences):
        # Only a reference is kept; the container is not copied.
        self.sentences = sentences

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        sentence = self.sentences[idx]
        return sentence

    def __len__(self):
        """Return the number of stored sentences."""
        corpus = self.sentences
        return len(corpus)

# ----- End of original code -----

# Mount Google Drive for checkpoint and study persistence
# (Colab-specific helper; the mount point is /content/drive).
from google.colab import drive
drive.mount('/content/drive')

# Folder on the mounted drive that save_checkpoint() writes into; it is created
# here if missing.  CHECKPOINT_DIR is read as a module-level global by
# save_checkpoint() and objective() below.
CHECKPOINT_DIR = "/content/drive/MyDrive/violet_checkpoints_2"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
# SQLite file used as the Optuna storage for the study created below.
STUDY_DB_PATH = os.path.join(CHECKPOINT_DIR, "llm_finetuning_study_2.db")

def save_checkpoint(trial_number, trainer):
    """Persist a trial's training state and export its metrics figure.

    Writes ``checkpoint_trial_<trial_number>.pt`` into the module-level
    ``CHECKPOINT_DIR`` and a PNG of ``trainer.fig`` into its ``graphs``
    sub-folder.

    Args:
        trial_number: number used in both file names.
        trainer: object exposing ``online_model``, ``optimizer``,
            ``scheduler``, ``best_spearman``, ``best_pearson``, ``epochs``,
            ``iterations`` and ``fig``.  Per-step metric lists are saved too
            when the trainer has them.

    Returns:
        Path of the checkpoint file that was written.
    """
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"checkpoint_trial_{trial_number}.pt")
    state = {
        'online_model_state_dict': trainer.online_model.state_dict(),
        'optimizer_state_dict': trainer.optimizer.state_dict(),
        'scheduler_state_dict': trainer.scheduler.state_dict(),
        # Cast to built-in float so the file holds only plain Python values
        # next to the tensors.
        'best_spearman': float(trainer.best_spearman),
        'best_pearson': float(trainer.best_pearson),
        'epochs': list(trainer.epochs),
        'iterations': list(trainer.iterations),
    }
    # `epochs`/`iterations` are the x-axes of the per-step metric lists; store
    # those lists as well so a restored trainer has histories of equal length.
    float_histories = (
        "loss_values",
        "sts_pearson_cosine_values",
        "sts_spearman_cosine_values",
        "mean_grad_norm_values",
        "variance_values",
        "learning_rate_values",
        "test_sts_pearson_cosine_values",
        "test_sts_spearman_cosine_values",
    )
    for history_name in float_histories:
        if hasattr(trainer, history_name):
            state[history_name] = [float(value) for value in getattr(trainer, history_name)]
    if hasattr(trainer, "test_iterations"):
        state["test_iterations"] = list(trainer.test_iterations)
    torch.save(state, checkpoint_path)

    # Also save the graph to a subfolder called "graphs"
    graphs_dir = os.path.join(CHECKPOINT_DIR, "graphs")
    os.makedirs(graphs_dir, exist_ok=True)
    graph_path = os.path.join(graphs_dir, f"checkpoint_trial_{trial_number}.png")
    try:
        trainer.fig.write_image(graph_path)
    except Exception as exc:
        # The checkpoint is already on disk; a failed image export must not
        # discard the trial that produced it.
        print(f"Checkpoint saved to {checkpoint_path}; graph export to {graph_path} failed: {exc}")
    else:
        print(f"Checkpoint and graph saved to {checkpoint_path} and {graph_path}")

    return checkpoint_path

def load_checkpoint(trainer, checkpoint_path):
    """Restore ``trainer`` in place from ``checkpoint_path`` if that file exists.

    Model, optimizer and scheduler state dicts and the best validation metrics
    are always restored.  The plotting histories (``epochs``, ``iterations``
    and the per-step metric lists) are restored only when the checkpoint holds
    all of them with matching lengths; otherwise the trainer's current
    histories are left untouched so they stay mutually consistent.

    Args:
        trainer: object exposing ``online_model``, ``optimizer`` and
            ``scheduler``; its ``device`` attribute, when present, is used as
            the load target.
        checkpoint_path: file previously written with ``torch.save``.
            A missing file is a no-op.
    """
    if not os.path.exists(checkpoint_path):
        return

    # map_location lets a checkpoint saved on one device load on the trainer's
    # device; None (no `device` attribute) keeps torch.load's default.
    checkpoint = torch.load(checkpoint_path, map_location=getattr(trainer, "device", None))
    trainer.online_model.load_state_dict(checkpoint['online_model_state_dict'])
    trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    trainer.best_spearman = checkpoint.get('best_spearman', -float("inf"))
    trainer.best_pearson = checkpoint.get('best_pearson', -float("inf"))

    step_series = (
        "loss_values",
        "sts_pearson_cosine_values",
        "sts_spearman_cosine_values",
        "mean_grad_norm_values",
        "variance_values",
        "learning_rate_values",
    )
    epochs = list(checkpoint.get('epochs', []))
    iterations = list(checkpoint.get('iterations', []))
    # Restoring only the x-axes would leave them longer than the y-series that
    # are plotted against them, so it is all or nothing.
    histories_complete = len(epochs) == len(iterations) and all(
        name in checkpoint and len(checkpoint[name]) == len(iterations)
        for name in step_series
    )
    if histories_complete:
        trainer.epochs = epochs
        trainer.iterations = iterations
        for name in step_series:
            setattr(trainer, name, list(checkpoint[name]))

    test_series = (
        "test_iterations",
        "test_sts_pearson_cosine_values",
        "test_sts_spearman_cosine_values",
    )
    if all(name in checkpoint for name in test_series):
        if len({len(checkpoint[name]) for name in test_series}) == 1:
            for name in test_series:
                setattr(trainer, name, list(checkpoint[name]))

    print(f"Loaded checkpoint from {checkpoint_path}")

def objective(trial):
    """Optuna objective: train one sampled configuration and score it.

    Samples hyperparameters from ``trial``, builds a ``BarlowTwinsNCSE``
    trainer, optionally resumes it from ``checkpoint_trial_<n>.pt`` in
    ``CHECKPOINT_DIR``, runs ``fit()`` and saves a checkpoint.

    Args:
        trial: the ``optuna`` trial supplying ``number`` and the ``suggest_*`` calls.

    Returns:
        The ``spearman_cosine`` value reported by
        ``trainer._evaluate_without_heads()`` after training, or ``-inf`` when
        that key is missing.

    Raises:
        optuna.TrialPruned: when training is interrupted from the keyboard
            (a checkpoint is saved first).
    """
    # Experiment with various hyperparameters
    # suggest_float replaces the deprecated suggest_uniform / suggest_loguniform
    # calls; bounds and sampling scale are unchanged.
    config = {
        "model_name": "distilbert-base-uncased",
        "batch_size": 128, #trial.suggest_categorical("batch_size", [64, 128, 256, 512]),
        "projection_depth": 2, # trial.suggest_int("projection_depth", 2, 6),
        "projection_size": trial.suggest_categorical("projection_size", [8192, 12288]),# [2048, 4096, 6144, 8192]),
        "epochs": 1,  # fixed at 1
        "warmup_proportion": 0.0,
        "max_seq_length": 75, #trial.suggest_categorical("max_seq_length", [32, 64, 75]),
        "aug_p": trial.suggest_float("aug_p", 0.2, 0.4),
        "learning_rate": trial.suggest_float("learning_rate", 3e-5, 1e-3, log=True),
        "model_save_path": os.path.join(CHECKPOINT_DIR, f"train_stsb_bt-distilbert-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_trial{trial.number}"),
        "num_workers": 10,# 2
        "weight_decay": trial.suggest_float("weight_decay", 0.1, 0.2),
        "lambda_bt": trial.suggest_float("lambda_bt", 0.001, 0.2),
        "lambda_mixup": trial.suggest_float("lambda_mixup", 0.6, 1.5),
        "use_amp": True,
        "patience": 500
    }

    trainer = BarlowTwinsNCSE(config)
    try:
        # Resume from checkpoint if available
        # NOTE(review): the file is matched by trial number only; a leftover
        # checkpoint from another run with different hyperparameters would be
        # loaded as well -- confirm this is intended.
        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"checkpoint_trial_{trial.number}.pt")
        if os.path.exists(checkpoint_path):
            load_checkpoint(trainer, checkpoint_path)

        try:
            trainer.fit()
        except KeyboardInterrupt:
            save_checkpoint(trial.number, trainer)
            raise optuna.TrialPruned("Trial interrupted and checkpoint saved.")

        # Save checkpoint at end of trial
        save_checkpoint(trial.number, trainer)

        # Use the STS-B validation metric (spearman cosine) as objective.
        val_metrics = trainer._evaluate_without_heads()
        return val_metrics.get("spearman_cosine", -float("inf"))
    finally:
        # Release model/optimizer memory on every exit path so successive
        # trials in the same process do not accumulate GPU allocations.
        trainer.cleanup()

# Create or resume an Optuna study persisted on Google Drive
# "maximize" because objective() returns a Spearman correlation.
# load_if_exists=True reuses the study already stored in the SQLite file under
# this name instead of failing, so re-running the cell continues the same study.
study = optuna.create_study(
    direction="maximize",
    study_name="llm_finetuning_study_2",
    storage=f"sqlite:///{STUDY_DB_PATH}",
    load_if_exists=True
)
# Each execution of this line requests 20 more trials.
study.optimize(objective, n_trials=20)

print("Best trial:")
print(study.best_trial)


Epoch 1/1::  12%|█▏        | 924/7701 [11:07<1:21:33,  1.38it/s, loss=1.5e+3, pearson_cosine=0.564, spearman_cosine=0.586, mean_grad_norm=8.28, learning_rate=0.000674]


Early stopping triggered at epoch 1, iteration 924.
Best Spearman Correlation: 0.7213868430993974
Best Pearson Correlation: 0.7111729038546595
Checkpoint and graph saved to /content/drive/MyDrive/violet_checkpoints_2/checkpoint_trial_19.pt and /content/drive/MyDrive/violet_checkpoints_2/graphs/checkpoint_trial_19.png


[I 2025-03-30 18:35:09,265] Trial 19 finished with value: 0.5863060006247319 and parameters: {'projection_size': 8192, 'aug_p': 0.20943901101203366, 'learning_rate': 0.0006737552047782743, 'weight_decay': 0.1333928565968183, 'lambda_bt': 0.0017107385254106905, 'lambda_mixup': 0.6059879433460729}. Best is trial 10 with value: 0.7860914915940419.

suggest_uniform has been deprecated in v3.0.0. This feature will be removed in v6.0.0. See https://github.com/optuna/optuna/releases/tag/v3.0.0. Use suggest_float instead.


suggest_loguniform has been deprecated in v3.0.0. This feature will be removed in v6.0.0. See https://github.com/optuna/optuna/releases/tag/v3.0.0. Use suggest_float(..., log=True) instead.


suggest_uniform has been deprecated in v3.0.0. This feature will be removed in v6.0.0. See https://github.com/optuna/optuna/releases/tag/v3.0.0. Use suggest_float instead.


suggest_uniform has been deprecated in v3.0.0. This feature will be removed in v6.0.0. See https://github.com/op

KeyboardInterrupt: 

In [None]:
# Mount Google Drive for checkpoint and study persistence
import os
import optuna
from google.colab import drive
drive.mount('/content/drive')

# NOTE(review): this cell rebinds the module-level CHECKPOINT_DIR,
# STUDY_DB_PATH and `study` names. Functions that read CHECKPOINT_DIR as a
# global (save_checkpoint, objective) will use this directory if they are
# called after this cell has run.
CHECKPOINT_DIR = "/content/drive/MyDrive/violet_checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
STUDY_DB_PATH = os.path.join(CHECKPOINT_DIR, "llm_finetuning_study.db")

# Create or resume an Optuna study persisted on Google Drive
# (no optimize() call follows: the cell only reports the stored best trial).
study = optuna.create_study(
    direction="maximize",
    study_name="llm_finetuning_study",
    storage=f"sqlite:///{STUDY_DB_PATH}",
    load_if_exists=True
)

print("Best trial:")
print(study.best_trial)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


[I 2025-03-30 12:56:37,239] Using an existing study with name 'llm_finetuning_study' instead of creating a new one.


Best trial:
FrozenTrial(number=25, state=1, values=[0.8027423473243455], datetime_start=datetime.datetime(2025, 3, 28, 14, 9, 18, 733650), datetime_complete=datetime.datetime(2025, 3, 28, 14, 40, 13, 582679), params={'batch_size': 128, 'projection_depth': 2, 'projection_size': 6144, 'max_seq_length': 75, 'aug_p': 0.11568972876642647, 'learning_rate': 5.7605107036319744e-05, 'weight_decay': 0.17805243663726578, 'lambda_bt': 0.1117488685923089, 'lambda_mixup': 0.9545383539595996}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'batch_size': CategoricalDistribution(choices=(64, 128, 256, 512)), 'projection_depth': IntDistribution(high=6, log=False, low=2, step=1), 'projection_size': CategoricalDistribution(choices=(2048, 4096, 6144, 8192)), 'max_seq_length': CategoricalDistribution(choices=(32, 64, 75)), 'aug_p': FloatDistribution(high=0.5, log=False, low=0.1, step=None), 'learning_rate': FloatDistribution(high=0.001, log=True, low=1e-05, step=None), 'weight_decay'

In [None]:
# After optimization, use Optuna's visualization tools:
# (operates on whichever `study` object the most recently executed cell bound).
import optuna.visualization as vis

# Plot the optimization history
opt_history_fig = vis.plot_optimization_history(study)
opt_history_fig.show()

# Plot hyperparameter importances
opt_param_fig = vis.plot_param_importances(study)
opt_param_fig.show()

# Plot slices for selected hyperparameters
# (no `params` argument is passed to plot_slice here).
opt_slice_fig = vis.plot_slice(study)
opt_slice_fig.show()

## Mixed Barlow Twins with BERT-Base

In [None]:
import textwrap
import os
import torch
from datetime import datetime
from sentence_transformers import SentenceTransformer, models, util
from datasets import load_dataset
from torch.utils.data import DataLoader
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
import copy
import gc
import math
import random
import nlpaug.augmenter.word as naw
from IPython.display import display, clear_output
from tqdm import tqdm
import numpy as np
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from torch import GradScaler
from torch.amp import autocast
import optuna

class BarlowTwinsNCSE:
    def __init__(self, config):
        """Set up data, models, optimisation state, augmenters and the live plot.

        Args:
            config: dict of settings.  This method reads ``aug_p`` and the
                optional ``use_amp`` key directly; the helper methods it calls
                read further keys.
        """
        self.config = config
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        torch.backends.cudnn.benchmark = True  # Enable cuDNN benchmarking

        # The optimizer is built over the online model's parameters, so the
        # models have to exist first.
        self._prepare_datasets()
        self._initialize_models()
        self._initialize_optimizer_scheduler()
        amp_enabled = self.config.get("use_amp", True)
        self.scaler = GradScaler("cuda", enabled=amp_enabled)  # GradScaler for AMP

        # Early-stopping bookkeeping.
        self.best_spearman = -float("inf")
        self.best_pearson = -float("inf")
        self.patience_counter = 0

        # Three word-level augmenters; fit() samples two of them per batch.
        synonym_aug = naw.SynonymAug(aug_src='wordnet', aug_p=self.config["aug_p"])
        swap_aug = naw.RandomWordAug(action="swap", aug_p=self.config["aug_p"])
        default_random_aug = naw.RandomWordAug(aug_p=self.config["aug_p"])
        self.augmenters = [synonym_aug, swap_aug, default_random_aug]

        # Test-split histories live outside _create_plot(), which resets only
        # the per-step series.
        self.test_sts_pearson_cosine_values = []
        self.test_sts_spearman_cosine_values = []
        self.test_iterations = []

        # Reads self.augmenters for the figure footer, so it runs last.
        self._create_plot()

    def _create_plot(self):
        """Reset the per-step histories and build and show the metrics figure.

        Creates a 3x2 subplot grid with eight (initially empty) traces, a grey
        footer listing the configuration and augmenters, and a blue annotation
        for the best metrics whose index is stored in
        ``self.best_metrics_annotation_index``.  Also stores the Plotly display
        config in ``self.plot_config``.
        """
        self.loss_values = []
        self.sts_pearson_cosine_values = []
        self.sts_spearman_cosine_values = []
        self.mean_grad_norm_values = []
        self.variance_values = []
        self.learning_rate_values = []
        self.iterations = []
        self.epochs = []

        # 3 rows x 2 columns grid with 6 subplots
        self.fig = make_subplots(
            rows=3, cols=2,
            subplot_titles=(
                "Loss vs Iterations",
                "Mean Gradient Norm vs Iterations",
                "Variance vs Iterations",
                "Learning Rate vs Iterations",
                "Dev STS Cosine (Pearson & Spearman) vs Iterations",
                "Test STS Cosine (Pearson & Spearman) vs Iterations"
            )
        )
        # Trace order is significant: _update_traces() addresses fig.data by
        # position (0-5 dev/per-step series, 6-7 test series).
        trace_layout = (
            ("Loss", 1, 1),
            ("Mean Gradient Norm", 1, 2),
            ("Variance", 2, 1),
            ("Learning Rate", 2, 2),
            ("Dev STS Pearson Cosine", 3, 1),
            ("Dev STS Spearman Cosine", 3, 1),
            ("Test STS Pearson Cosine", 3, 2),
            ("Test STS Spearman Cosine", 3, 2),
        )
        for trace_name, row, col in trace_layout:
            self.fig.add_trace(
                go.Scatter(x=[], y=[], mode='lines+markers', name=trace_name),
                row=row, col=col
            )

        # Prepare footer text
        footer_text = ", ".join([f"{key}={value}" for key, value in self.config.items()])
        augmenters_text = ", Augmenters: " + ", ".join([f"{aug.name}[{aug.action}:{aug.aug_p}]" for aug in self.augmenters])
        footer_text += augmenters_text
        wrapped_footer = "<br>".join(textwrap.wrap(footer_text, width=160))

        # Configure the download button
        self.plot_config = {
            'toImageButtonOptions': {
                'filename': self.config["model_save_path"],
                'format': 'png',
                'width': 1200,
                'height': 800,
                'scale': 1
            }
        }

        self.fig.update_layout(
            width=1200,
            height=800,
            title_text='Training Metrics',
            showlegend=True,
            margin=dict(l=50, r=50, t=100, b=150)
        )

        # The subplot titles created by make_subplots are stored as layout
        # annotations.  Passing `annotations=[...]` to update_layout writes
        # over those existing entries, so the footer is appended with
        # add_annotation instead, which leaves the titles intact.
        self.fig.add_annotation(
            text=wrapped_footer,
            showarrow=False,
            xref="paper",
            yref="paper",
            x=0.5,
            y=-0.15,
            xanchor='center',
            yanchor='top',
            align="center",
            font=dict(size=10, color="gray")
        )

        # Annotation for best metrics; its index is recorded before it is
        # appended so _update_plot() can rewrite its text later.
        self.best_metrics_annotation_index = len(self.fig.layout.annotations)
        self.fig.add_annotation(
            text="Best Spearman (Test): N/A<br>Best Pearson (Test): N/A<br>Best Spearman (Val): N/A<br>Best Pearson (Val): N/A",
            showarrow=False,
            xref="paper",
            yref="paper",
            x=1.0,
            y=0.0,
            xanchor='right',
            yanchor='bottom',
            align="right",
            font=dict(size=12, color="blue")
        )

        self.fig.show(config=self.plot_config)

    def _update_traces(self):
        with self.fig.batch_update():
            self.fig.data[0].x = self.iterations
            self.fig.data[0].y = self.loss_values
            self.fig.data[1].x = self.iterations
            self.fig.data[1].y = self.mean_grad_norm_values
            self.fig.data[2].x = self.iterations
            self.fig.data[2].y = self.variance_values
            self.fig.data[3].x = self.iterations
            self.fig.data[3].y = self.learning_rate_values
            self.fig.data[4].x = self.iterations
            self.fig.data[4].y = self.sts_pearson_cosine_values
            self.fig.data[5].x = self.iterations
            self.fig.data[5].y = self.sts_spearman_cosine_values
            self.fig.data[6].x = self.test_iterations
            self.fig.data[6].y = self.test_sts_pearson_cosine_values
            self.fig.data[7].x = self.test_iterations
            self.fig.data[7].y = self.test_sts_spearman_cosine_values

            for i in range(1, 4):
                for j in range(1, 3):
                    self.fig.update_yaxes(autorange=True, row=i, col=j)
                    self.fig.update_xaxes(autorange=True, row=i, col=j)

    def _update_plot(self):
        """Refresh traces, per-epoch frames, slider and best-metric text, then redraw.

        One animation frame is built per distinct epoch from the six per-step
        series, the epoch slider is rebuilt, the best-metrics annotation text
        is rewritten, and the cell output is cleared before the figure is
        shown again.
        """
        self._update_traces()

        # Group history positions by epoch in a single pass.
        positions_by_epoch = {}
        for position, epoch_id in enumerate(self.epochs):
            positions_by_epoch.setdefault(epoch_id, []).append(position)
        unique_epochs = sorted(positions_by_epoch)

        # Same order as traces 0-5 of the figure.
        per_step_series = (
            self.loss_values,
            self.mean_grad_norm_values,
            self.variance_values,
            self.learning_rate_values,
            self.sts_pearson_cosine_values,
            self.sts_spearman_cosine_values,
        )
        frames = []
        for epoch_id in unique_epochs:
            picks = positions_by_epoch[epoch_id]
            scatters = [
                go.Scatter(
                    x=[self.iterations[p] for p in picks],
                    y=[series[p] for p in picks],
                )
                for series in per_step_series
            ]
            frames.append(go.Frame(data=scatters, name=str(epoch_id)))
        self.fig.frames = frames

        slider_steps = []
        for epoch_id in unique_epochs:
            animate_options = {
                "frame": {"duration": 0, "redraw": True},
                "mode": "immediate",
                "transition": {"duration": 0},
            }
            slider_steps.append({
                "args": [[str(epoch_id)], animate_options],
                "label": str(epoch_id),
                "method": "animate",
            })

        active_step = len(unique_epochs) - 1 if unique_epochs else 0
        self.fig.update_layout(
            sliders=[{
                "active": active_step,
                "currentvalue": {"prefix": "Epoch: "},
                "pad": {"t": 50},
                "steps": slider_steps
            }]
        )

        # Compute best test metrics if available
        test_spearman = self.test_sts_spearman_cosine_values
        test_pearson = self.test_sts_pearson_cosine_values
        best_test_spearman = max(test_spearman) if test_spearman else float('nan')
        best_test_pearson = max(test_pearson) if test_pearson else float('nan')

        summary_lines = (
            f"Best Spearman (Test): {best_test_spearman:.4f}<br>",
            f"Best Pearson (Test): {best_test_pearson:.4f}<br>",
            f"Best Spearman (Val): {self.best_spearman:.4f}<br>",
            f"Best Pearson (Val): {self.best_pearson:.4f}",
        )
        best_annotation = self.fig.layout.annotations[self.best_metrics_annotation_index]
        best_annotation.text = "".join(summary_lines)

        for row, col in [(r, c) for r in range(1, 4) for c in range(1, 3)]:
            self.fig.update_yaxes(autorange=True, row=row, col=col)
            self.fig.update_xaxes(autorange=True, row=row, col=col)

        # Replace the previously rendered figure instead of stacking outputs.
        clear_output(wait=True)
        self.fig.show(config=self.plot_config)

    def _prepare_datasets(self):
        """Load the training corpus and the STS-B splits, and derive the evaluation cadence.

        Downloads ``wiki1m_for_simcse.txt`` to ``data/`` when it is not
        already there, keeps every stripped line of at least 10 characters as
        a training sentence, loads the three ``sentence-transformers/stsb``
        splits, builds the shuffled training ``DataLoader`` and the test
        evaluator, and sets ``self.evaluate_steps`` to 1/50 of the number of
        batches (minimum 1).

        Side effect: ``self.config["patience"]`` is raised to
        ``5 * evaluate_steps`` when it is smaller or missing.
        """
        wikipedia_url = "https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt"
        wikipedia_dataset_path = "data/wiki1m_for_simcse.txt"
        if not os.path.exists(wikipedia_dataset_path):
            util.http_get(wikipedia_url, wikipedia_dataset_path)
        with open(wikipedia_dataset_path, encoding="utf8") as f:
            stripped_lines = (line.strip() for line in f)
            # Drop blank and very short lines.
            train_sentences = [line for line in stripped_lines if len(line) >= 10]
        self.train_sentences = train_sentences

        self.train_dataset = load_dataset("sentence-transformers/stsb", split="train")
        self.eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
        self.test_dataset = load_dataset("sentence-transformers/stsb", split="test")

        self.train_data_loader = DataLoader(
            SentenceDataset(self.train_sentences),
            batch_size=self.config["batch_size"],
            shuffle=True,
            num_workers=self.config["num_workers"],
            pin_memory=True
        )
        self.test_evaluator = EmbeddingSimilarityEvaluator(
            sentences1=self.test_dataset["sentence1"],
            sentences2=self.test_dataset["sentence2"],
            scores=self.test_dataset["score"]
        )

        self.evaluate_steps = max(len(self.train_data_loader) // 50, 1)

        # Ensure patience is at least five times as large as evaluate_steps.
        # The configured value is read once with a default, so a config
        # without a "patience" key is adjusted instead of raising KeyError
        # while formatting the warning.
        min_patience = self.evaluate_steps * 5
        configured_patience = self.config.get("patience", 0)
        if configured_patience < min_patience:
            print(f"Warning: Patience ({configured_patience}) is less than evaluation steps x5 ({min_patience}). Adjusting patience to {min_patience}.")
            self.config["patience"] = min_patience

    def _apply_augmentation(self, sentences, aug):
        return aug.augment(sentences)

    def _initialize_models(self):
        """Build the online model (encoder + pooling + projection head) and the head-less encoder.

        ``self.online_model`` is a three-module ``SentenceTransformer``.
        ``self.encoder`` is a deep copy of its first two modules and is what
        the evaluation code scores.

        Reads ``model_name``, ``max_seq_length``, ``aug_p``,
        ``projection_size`` and ``projection_depth`` from ``self.config``.
        """
        dropout_p = self.config["aug_p"]
        word_embedding_model = models.Transformer(
            self.config["model_name"],
            max_seq_length=self.config["max_seq_length"],
            config_args={
                # DistilBERT-style dropout field names.
                "attention_dropout": dropout_p,
                "dropout": dropout_p,
                # BERT-style field names.  Without these the override does
                # not reach models such as bert-base-uncased.
                # NOTE(review): relies on Hugging Face config loading ignoring
                # keys a given config class does not define -- confirm against
                # the installed transformers version.
                "attention_probs_dropout_prob": dropout_p,
                "hidden_dropout_prob": dropout_p,
            }
        )
        # Use the encoder's reported width rather than a literal 768 so the
        # projection head matches whichever backbone is configured.
        embedding_dim = word_embedding_model.get_word_embedding_dimension()
        pooling_model = models.Pooling(embedding_dim)

        projection_size = self.config["projection_size"]
        projection_layers = [torch.nn.Linear(embedding_dim, projection_size)]
        # Each extra depth level appends BatchNorm -> ReLU -> Linear.
        for _ in range(self.config["projection_depth"] - 1):
            projection_layers.append(torch.nn.BatchNorm1d(projection_size))
            projection_layers.append(torch.nn.ReLU())
            projection_layers.append(torch.nn.Linear(projection_size, projection_size))
        projection_head = torch.nn.Sequential(*projection_layers)

        self.online_model = SentenceTransformer(modules=[word_embedding_model, pooling_model, projection_head]).to(self.device)
        # The evaluation encoder shares no parameters with the online model;
        # _update_encoder() copies weights across on demand.
        encoder_modules = [copy.deepcopy(self.online_model[i]) for i in range(2)]
        self.encoder = SentenceTransformer(modules=encoder_modules).to(self.device)

    def _update_encoder(self):
        for i in range(len(self.encoder)):
            self.encoder[i].load_state_dict(self.online_model[i].state_dict())

    def _initialize_optimizer_scheduler(self):
        self.optimizer = torch.optim.Adam(
            self.online_model.parameters(),
            lr=self.config["learning_rate"],
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=self.config["weight_decay"]
        )
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.5, patience=200, verbose=True, min_lr=1e-6)

    def _forward_pass(self, model, sentences, train):
        if train:
            model.train()
        else:
            model.eval()
        features = model.tokenize(sentences)
        features = {k: v.to(self.device, non_blocking=True) for k, v in features.items()}
        with autocast("cuda", enabled=self.config.get("use_amp", True)):
            embeddings = model[0](features)["token_embeddings"]
            pooled = model[1]({"token_embeddings": embeddings})["sentence_embedding"]
            if len(model) > 2:
                return model[2](pooled)
            else:
                return pooled

    def _mixed_barlow_twins_loss(self, z_a, z_b):
        N, D = z_a.size()
        z_a_norm = (z_a - z_a.mean(dim=0)) / (z_a.std(dim=0) + 1e-6)
        z_b_norm = (z_b - z_b.mean(dim=0)) / (z_b.std(dim=0) + 1e-6)
        c = torch.matmul(z_a_norm.T, z_b_norm) / N
        I = torch.eye(D, device=z_a.device)
        c_diff = (c - I).pow(2)
        off_diag_mask = ~torch.eye(D, dtype=torch.bool, device=z_a.device)
        c_diff[off_diag_mask] *= self.config["lambda_bt"]
        loss_bt = c_diff.sum()
        # MixUp Regularization
        idx = torch.randperm(N)
        alpha = torch.tensor(np.random.beta(1.0, 1.0), device=z_a.device, dtype=z_a.dtype)
        z_m = alpha * z_a + (1 - alpha) * z_b[idx, :]
        z_m_norm = (z_m - z_m.mean(dim=0)) / (z_m.std(dim=0) + 1e-6)
        cc_m_a = torch.matmul(z_m_norm.T, z_a_norm) / N
        cc_m_b = torch.matmul(z_m_norm.T, z_b_norm) / N
        cc_m_a_gt = alpha * torch.matmul(z_a_norm.T, z_a_norm) / N + (1 - alpha) * torch.matmul(z_b_norm[idx, :].T, z_a_norm) / N
        cc_m_b_gt = alpha * torch.matmul(z_a_norm.T, z_b_norm) / N + (1 - alpha) * torch.matmul(z_b_norm[idx, :].T, z_b_norm) / N
        loss_mix = self.config["lambda_mixup"] * self.config["lambda_bt"] * (
            (cc_m_a - cc_m_a_gt).pow(2).sum() + (cc_m_b - cc_m_b_gt).pow(2).sum()
        )
        return loss_bt + loss_mix

    def _evaluate_without_heads(self):
        self._update_encoder()
        self.encoder.eval()
        indices = list(range(len(self.eval_dataset["sentence1"])))
        random.shuffle(indices)
        sentences1 = [self.eval_dataset["sentence1"][i] for i in indices]
        sentences2 = [self.eval_dataset["sentence2"][i] for i in indices]
        scores = [self.eval_dataset["score"][i] for i in indices]
        evaluator = EmbeddingSimilarityEvaluator(sentences1, sentences2, scores)
        return evaluator(self.encoder)

    def fit(self):
        """Run the training loop with periodic evaluation and early stopping.

        For every batch two distinct augmenters are sampled, both views are
        embedded with the online model, the mixed Barlow Twins loss is
        back-propagated through the gradient scaler, and the scheduler is
        stepped with the batch loss.  Every ``self.evaluate_steps`` batches
        (and on the last batch) the encoder is scored on the validation and
        test sets, the best validation metrics and patience counter are
        updated, and the plot is refreshed.  Training stops once the patience
        counter reaches ``config["patience"]``.

        Per-step values are appended to the history lists used for plotting.
        """
        latest_eval_metrics = {}
        num_batches = len(self.train_data_loader)
        # Divisor for the logged "mean" gradient norm: the number of parameter tensors.
        num_param_tensors = len(list(self.online_model.parameters()))
        amp_enabled = self.config.get("use_amp", True)

        for epoch in range(self.config["epochs"]):
            early_stop = False
            epoch_loss = 0
            pbar = tqdm(self.train_data_loader, desc=f"Epoch {epoch+1}/{self.config['epochs']}:")
            for idx, sentences in enumerate(pbar):
                aug1, aug2 = random.sample(self.augmenters, 2)
                s1 = self._apply_augmentation(sentences, aug1)
                s2 = self._apply_augmentation(sentences, aug2)
                with autocast("cuda", enabled=amp_enabled):
                    z_a = self._forward_pass(self.online_model, s1, train=True)
                    z_b = self._forward_pass(self.online_model, s2, train=True)
                    loss = self._mixed_barlow_twins_loss(z_a, z_b)
                # Read the scalar once and reuse it for the epoch total, the
                # scheduler, the progress bar and the history.
                loss_value = loss.item()
                epoch_loss += loss_value

                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()

                # Gradients are still multiplied by the scaler's factor here;
                # divide the global L2 norm by it to log the unscaled norm.
                scale = self.scaler.get_scale()
                scale = scale if scale != 0 else 1e-8
                total_norm_scaled = 0.0
                for param in self.online_model.parameters():
                    if param.grad is not None:
                        param_norm = param.grad.data.norm(2).item()
                        total_norm_scaled += param_norm ** 2
                total_norm_scaled = math.sqrt(total_norm_scaled)
                total_norm = total_norm_scaled / scale
                mean_grad_norm = total_norm / num_param_tensors

                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.scheduler.step(loss_value)

                # idx == 0 is already covered by the modulo test.
                should_evaluate = idx % self.evaluate_steps == 0 or idx == num_batches - 1
                if should_evaluate:
                    latest_eval_metrics = self._evaluate_without_heads()
                    last_spearman = latest_eval_metrics.get('spearman_cosine', -float('inf'))
                    last_pearson = latest_eval_metrics.get('pearson_cosine', -float('inf'))
                    # _evaluate_without_heads() has just synced the encoder and
                    # put it in eval mode, so it can be scored on the test
                    # split without a second sync.
                    test_metrics = self.test_evaluator(self.encoder)
                    self.test_sts_pearson_cosine_values.append(test_metrics.get('pearson_cosine', np.nan))
                    self.test_sts_spearman_cosine_values.append(test_metrics.get('spearman_cosine', np.nan))
                    self.test_iterations.append(idx)

                    if last_spearman > self.best_spearman:
                        self.best_spearman = last_spearman
                        self.best_pearson = last_pearson
                        self.patience_counter = 0
                    else:
                        self.patience_counter += self.evaluate_steps

                    # Redraw after the best metrics are updated so the
                    # annotation reflects this evaluation.
                    self._update_plot()

                current_lr = self.optimizer.param_groups[0]['lr']
                pbar.set_postfix({
                    "loss": loss_value,
                    **latest_eval_metrics,
                    "mean_grad_norm": mean_grad_norm,
                    "learning_rate": current_lr
                })
                self.loss_values.append(loss_value)
                self.sts_pearson_cosine_values.append(latest_eval_metrics.get('pearson_cosine', np.nan))
                self.sts_spearman_cosine_values.append(latest_eval_metrics.get('spearman_cosine', np.nan))
                self.mean_grad_norm_values.append(mean_grad_norm)
                self.variance_values.append(torch.var(z_a, dim=0).mean().item())
                self.learning_rate_values.append(current_lr)
                self.iterations.append(idx)
                self.epochs.append(epoch)

                # Early stopping is checked only on evaluation steps, after
                # this step's values have been recorded.
                if should_evaluate and self.patience_counter >= self.config["patience"]:
                    early_stop = True
                    print(f"Early stopping triggered at epoch {epoch+1}, iteration {idx}.")
                    print(f"Best Spearman Correlation: {self.best_spearman}")
                    print(f"Best Pearson Correlation: {self.best_pearson}")
                    break
            if early_stop:
                break
            avg_loss = epoch_loss / num_batches
            pbar.set_description(f"Epoch {epoch+1} Loss: {avg_loss}")

    def cleanup(self):
        del self.online_model, self.optimizer, self.scheduler, self.scaler
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

class SentenceDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves sentences from an indexable container."""

    def __init__(self, sentences):
        # Only a reference is kept; the container is not copied.
        self.sentences = sentences

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        sentence = self.sentences[idx]
        return sentence

    def __len__(self):
        """Return the number of stored sentences."""
        corpus = self.sentences
        return len(corpus)

# Mount Google Drive for checkpoint and study persistence
# (Colab-specific helper; the mount point is /content/drive).
from google.colab import drive
drive.mount('/content/drive')

# Folder on the mounted drive that save_checkpoint() writes into; it is created
# here if missing.  CHECKPOINT_DIR is read as a module-level global by
# save_checkpoint() and objective() below.
CHECKPOINT_DIR = "/content/drive/MyDrive/violet_bert_base_checkpoints_2"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
# SQLite file used as the Optuna storage for the study created in this cell.
STUDY_DB_PATH = os.path.join(CHECKPOINT_DIR, "llm_finetuning_study.db")

def save_checkpoint(trial_number, trainer):
    """Persist a trial's training state and export its metrics figure.

    Writes ``checkpoint_trial_<trial_number>.pt`` into the module-level
    ``CHECKPOINT_DIR`` and a PNG of ``trainer.fig`` into its ``graphs``
    sub-folder.

    Args:
        trial_number: number used in both file names.
        trainer: object exposing ``online_model``, ``optimizer``,
            ``scheduler``, ``best_spearman``, ``best_pearson``, ``epochs``,
            ``iterations`` and ``fig``.  Per-step metric lists are saved too
            when the trainer has them.

    Returns:
        Path of the checkpoint file that was written.
    """
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"checkpoint_trial_{trial_number}.pt")
    state = {
        'online_model_state_dict': trainer.online_model.state_dict(),
        'optimizer_state_dict': trainer.optimizer.state_dict(),
        'scheduler_state_dict': trainer.scheduler.state_dict(),
        # Cast to built-in float so the file holds only plain Python values
        # next to the tensors.
        'best_spearman': float(trainer.best_spearman),
        'best_pearson': float(trainer.best_pearson),
        'epochs': list(trainer.epochs),
        'iterations': list(trainer.iterations),
    }
    # `epochs`/`iterations` are the x-axes of the per-step metric lists; store
    # those lists as well so a restored trainer has histories of equal length.
    float_histories = (
        "loss_values",
        "sts_pearson_cosine_values",
        "sts_spearman_cosine_values",
        "mean_grad_norm_values",
        "variance_values",
        "learning_rate_values",
        "test_sts_pearson_cosine_values",
        "test_sts_spearman_cosine_values",
    )
    for history_name in float_histories:
        if hasattr(trainer, history_name):
            state[history_name] = [float(value) for value in getattr(trainer, history_name)]
    if hasattr(trainer, "test_iterations"):
        state["test_iterations"] = list(trainer.test_iterations)
    torch.save(state, checkpoint_path)

    # Also save the graph to a subfolder called "graphs"
    graphs_dir = os.path.join(CHECKPOINT_DIR, "graphs")
    os.makedirs(graphs_dir, exist_ok=True)
    graph_path = os.path.join(graphs_dir, f"checkpoint_trial_{trial_number}.png")
    try:
        trainer.fig.write_image(graph_path)
    except Exception as exc:
        # The checkpoint is already on disk; a failed image export must not
        # discard the trial that produced it.
        print(f"Checkpoint saved to {checkpoint_path}; graph export to {graph_path} failed: {exc}")
    else:
        print(f"Checkpoint and graph saved to {checkpoint_path} and {graph_path}")

    return checkpoint_path

def load_checkpoint(trainer, checkpoint_path):
    """Restore ``trainer`` in place from ``checkpoint_path`` if that file exists.

    Model, optimizer and scheduler state dicts and the best validation metrics
    are always restored.  The plotting histories (``epochs``, ``iterations``
    and the per-step metric lists) are restored only when the checkpoint holds
    all of them with matching lengths; otherwise the trainer's current
    histories are left untouched so they stay mutually consistent.

    Args:
        trainer: object exposing ``online_model``, ``optimizer`` and
            ``scheduler``; its ``device`` attribute, when present, is used as
            the load target.
        checkpoint_path: file previously written with ``torch.save``.
            A missing file is a no-op.
    """
    if not os.path.exists(checkpoint_path):
        return

    # map_location lets a checkpoint saved on one device load on the trainer's
    # device; None (no `device` attribute) keeps torch.load's default.
    checkpoint = torch.load(checkpoint_path, map_location=getattr(trainer, "device", None))
    trainer.online_model.load_state_dict(checkpoint['online_model_state_dict'])
    trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    trainer.best_spearman = checkpoint.get('best_spearman', -float("inf"))
    trainer.best_pearson = checkpoint.get('best_pearson', -float("inf"))

    step_series = (
        "loss_values",
        "sts_pearson_cosine_values",
        "sts_spearman_cosine_values",
        "mean_grad_norm_values",
        "variance_values",
        "learning_rate_values",
    )
    epochs = list(checkpoint.get('epochs', []))
    iterations = list(checkpoint.get('iterations', []))
    # _update_plot() indexes every per-step list with positions taken from
    # `epochs`, so restoring only the x-axes would make it index past the end
    # of the (shorter) metric lists.  It is all or nothing.
    histories_complete = len(epochs) == len(iterations) and all(
        name in checkpoint and len(checkpoint[name]) == len(iterations)
        for name in step_series
    )
    if histories_complete:
        trainer.epochs = epochs
        trainer.iterations = iterations
        for name in step_series:
            setattr(trainer, name, list(checkpoint[name]))

    test_series = (
        "test_iterations",
        "test_sts_pearson_cosine_values",
        "test_sts_spearman_cosine_values",
    )
    if all(name in checkpoint for name in test_series):
        if len({len(checkpoint[name]) for name in test_series}) == 1:
            for name in test_series:
                setattr(trainer, name, list(checkpoint[name]))

    print(f"Loaded checkpoint from {checkpoint_path}")

def objective(trial):
    """Optuna objective: train one ``BarlowTwinsNCSE`` configuration and score it.

    Hyperparameters are sampled from ``trial``; the model is trained for one
    epoch, checkpointed, and evaluated on the STS-B validation split.

    Args:
        trial: The ``optuna`` trial supplying the hyperparameter suggestions.

    Returns:
        The ``spearman_cosine`` validation metric, or ``-inf`` if it is missing.

    Raises:
        optuna.TrialPruned: If training is interrupted with ``KeyboardInterrupt``
            (a checkpoint is written first).
    """
    # Experiment with various hyperparameters.
    # suggest_float replaces the deprecated suggest_uniform / suggest_loguniform;
    # parameter names and ranges are unchanged.
    config = {
        "model_name": "bert-base-uncased",
        "batch_size": trial.suggest_categorical("batch_size", [64, 128]),
        "projection_depth": trial.suggest_int("projection_depth", 2, 3),
        "projection_size": trial.suggest_categorical("projection_size", [2048, 4096, 6144, 8192]),
        "epochs": 1,  # fixed at 1
        "warmup_proportion": 0.0,
        "max_seq_length": 75,  # alternatively: trial.suggest_categorical("max_seq_length", [32, 75])
        "aug_p": trial.suggest_float("aug_p", 0.1, 0.4),
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-4, log=True),
        "model_save_path": os.path.join(CHECKPOINT_DIR, f"train_stsb_bt-distilbert-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_trial{trial.number}"),
        "num_workers": 10,
        "weight_decay": trial.suggest_float("weight_decay", 1e-4, 0.2),
        "lambda_bt": trial.suggest_float("lambda_bt", 0.001, 0.2),
        "lambda_mixup": trial.suggest_float("lambda_mixup", 0.1, 1.5),
        "use_amp": True,
        "patience": 1000
    }

    trainer = BarlowTwinsNCSE(config)
    try:
        # Resume from checkpoint if available
        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"checkpoint_trial_{trial.number}.pt")
        if os.path.exists(checkpoint_path):
            load_checkpoint(trainer, checkpoint_path)

        try:
            trainer.fit()
        except KeyboardInterrupt:
            save_checkpoint(trial.number, trainer)
            raise optuna.TrialPruned("Trial interrupted and checkpoint saved.")

        # Save checkpoint at end of trial
        save_checkpoint(trial.number, trainer)

        # Use the STS-B validation metric (spearman cosine) as objective.
        val_metrics = trainer._evaluate_without_heads()
        return val_metrics.get("spearman_cosine", -float("inf"))
    finally:
        # Previously this call sat after the `return` and never ran, so every trial
        # kept its model on the GPU. `finally` releases it on success, pruning and errors.
        trainer.cleanup()

# Create the Optuna study, or resume it from the SQLite file at STUDY_DB_PATH on Google Drive.
# The direction is "maximize" because `objective` returns the validation Spearman correlation.
study = optuna.create_study(
    direction="maximize",
    study_name="llm_finetuning_study",
    storage=f"sqlite:///{STUDY_DB_PATH}",
    load_if_exists=True  # reuse the stored study when one with this name already exists
)
# Run `objective` for up to 20 trials in this session.
study.optimize(objective, n_trials=20)

[I 2025-04-13 14:33:20,494] Using an existing study with name 'llm_finetuning_study' instead of creating a new one.


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).



suggest_uniform has been deprecated in v3.0.0. This feature will be removed in v6.0.0. See https://github.com/optuna/optuna/releases/tag/v3.0.0. Use suggest_float instead.


suggest_loguniform has been deprecated in v3.0.0. This feature will be removed in v6.0.0. See https://github.com/optuna/optuna/releases/tag/v3.0.0. Use suggest_float(..., log=True) instead.


suggest_uniform has been deprecated in v3.0.0. This feature will be removed in v6.0.0. See https://github.com/optuna/optuna/releases/tag/v3.0.0. Use suggest_float instead.


suggest_uniform has been deprecated in v3.0.0. This feature will be removed in v6.0.0. See https://github.com/optuna/optuna/releases/tag/v3.0.0. Use suggest_float instead.


suggest_uniform has been deprecated in v3.0.0. This feature will be removed in v6.0.0. See https://github.com/optuna/optuna/releases/tag/v3.0.0. Use suggest_float instead.



Epoch 1/1::   0%|          | 0/7701 [00:00<?, ?it/s]
[W 2025-04-13 14:33:36,210] Trial 38 failed with parameters: {'batch_size': 128, 'projection_depth': 3, 'projection_size': 8192, 'aug_p': 0.16539755977310927, 'learning_rate': 4.837592944784042e-05, 'weight_decay': 0.05398982009405306, 'lambda_bt': 0.06954297619664421, 'lambda_mixup': 1.16868068523566} because of the following error: OutOfMemoryError('CUDA out of memory. Tried to allocate 58.00 MiB. GPU 0 has a total capacity of 22.16 GiB of which 15.38 MiB is free. Process 10775 has 22.14 GiB memory in use. Of the allocated memory 21.79 GiB is allocated by PyTorch, and 114.15 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)').
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist

OutOfMemoryError: CUDA out of memory. Tried to allocate 58.00 MiB. GPU 0 has a total capacity of 22.16 GiB of which 15.38 MiB is free. Process 10775 has 22.14 GiB memory in use. Of the allocated memory 21.79 GiB is allocated by PyTorch, and 114.15 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

### Adding Optuna pruning

In [None]:
import textwrap
import os
import torch
from datetime import datetime
from sentence_transformers import SentenceTransformer, models, util
from datasets import load_dataset
from torch.utils.data import DataLoader
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
import copy
import gc
import math
import random
import nlpaug.augmenter.word as naw
from IPython.display import display, clear_output
from tqdm import tqdm
import numpy as np
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from torch import GradScaler
from torch.amp import autocast
import optuna

# Best final validation score (Spearman cosine) reached by any trial; `objective` reads and updates it
# to decide whether a trial's checkpoint is worth saving.
# NOTE(review): this is reset to -inf whenever the cell is re-run, so after resuming a persisted study
# it reflects only the current session, not the study-wide best — confirm that is intended.
BEST_SCORE = -float("inf")

class BarlowTwinsNCSE:
    """Trainer that fine-tunes a sentence encoder with a Barlow Twins + MixUp objective.

    Two differently augmented copies of every training batch are embedded by
    ``online_model`` (transformer -> pooling -> projection head) and the loss
    pushes their cross-correlation matrix towards the identity.  Validation and
    test scores are computed on STS-B with ``encoder``, a copy of the first two
    modules of ``online_model`` (i.e. without the projection head).

    Config keys read by this class: ``model_name``, ``max_seq_length``,
    ``aug_p``, ``batch_size``, ``num_workers``, ``projection_size``,
    ``projection_depth``, ``learning_rate``, ``weight_decay``, ``lambda_bt``,
    ``lambda_mixup``, ``epochs``, ``model_save_path`` and the optional
    ``use_amp`` (default ``True``) and ``patience`` (default ``0``; raised to at
    least five evaluation intervals).
    """

    # Attributes dropped by ``cleanup``; everything that may hold GPU memory.
    _RELEASABLE_ATTRS = ("online_model", "encoder", "optimizer", "scheduler", "scaler")

    def __init__(self, config):
        """Load data, build models/optimizer and render the (empty) metrics figure.

        Args:
            config: Dict of hyperparameters (see the class docstring).  It is
                stored by reference and ``config["patience"]`` may be adjusted.
        """
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        torch.backends.cudnn.benchmark = True  # Enable cuDNN benchmarking
        self._prepare_datasets()
        self._initialize_models()
        self._initialize_optimizer_scheduler()
        self.scaler = GradScaler("cuda", enabled=self.config.get("use_amp", True))  # GradScaler for AMP
        self.best_spearman = -float("inf")
        self.best_pearson = -float("inf")
        self.patience_counter = 0
        # Two of these are drawn at random for every batch in ``fit``.
        self.augmenters = [
            naw.SynonymAug(aug_src='wordnet', aug_p=self.config["aug_p"]),
            naw.RandomWordAug(action="swap", aug_p=self.config["aug_p"]),
            naw.RandomWordAug(aug_p=self.config["aug_p"]),
        ]
        self.test_sts_pearson_cosine_values = []
        self.test_sts_spearman_cosine_values = []
        self.test_iterations = []
        # Must come after ``self.augmenters``: the figure footer lists them.
        self._create_plot()

    def _create_plot(self):
        """Reset the per-iteration histories and build/show the 3x2 metrics figure."""
        self.loss_values = []
        self.sts_pearson_cosine_values = []
        self.sts_spearman_cosine_values = []
        self.mean_grad_norm_values = []
        self.variance_values = []
        self.learning_rate_values = []
        self.iterations = []
        self.epochs = []

        # 3 rows x 2 columns grid with 6 subplots
        self.fig = make_subplots(
            rows=3, cols=2,
            subplot_titles=(
                "Loss vs Iterations",
                "Mean Gradient Norm vs Iterations",
                "Variance vs Iterations",
                "Learning Rate vs Iterations",
                "Dev STS Cosine (Pearson & Spearman) vs Iterations",
                "Test STS Cosine (Pearson & Spearman) vs Iterations"
            )
        )
        # The order matters: ``_update_traces`` addresses the traces by position.
        trace_layout = (
            ("Loss", 1, 1),
            ("Mean Gradient Norm", 1, 2),
            ("Variance", 2, 1),
            ("Learning Rate", 2, 2),
            ("Dev STS Pearson Cosine", 3, 1),
            ("Dev STS Spearman Cosine", 3, 1),
            ("Test STS Pearson Cosine", 3, 2),
            ("Test STS Spearman Cosine", 3, 2),
        )
        for trace_name, row, col in trace_layout:
            self.fig.add_trace(
                go.Scatter(x=[], y=[], mode='lines+markers', name=trace_name),
                row=row, col=col
            )

        # Prepare footer text: the full config plus the augmenter settings.
        footer_text = ", ".join([f"{key}={value}" for key, value in self.config.items()])
        augmenters_text = ", Augmenters: " + ", ".join([f"{aug.name}[{aug.action}:{aug.aug_p}]" for aug in self.augmenters])
        footer_text += augmenters_text
        wrapped_footer = "<br>".join(textwrap.wrap(footer_text, width=160))

        # Configure the download button
        self.plot_config = {
            'toImageButtonOptions': {
                'filename': self.config["model_save_path"],
                'format': 'png',
                'width': 1200,
                'height': 800,
                'scale': 1
            }
        }

        self.fig.update_layout(
            width=1200,
            height=800,
            title_text='Training Metrics',
            showlegend=True,
            margin=dict(l=50, r=50, t=100, b=150),
            annotations=[
                dict(
                    text=wrapped_footer,
                    showarrow=False,
                    xref="paper",
                    yref="paper",
                    x=0.5,
                    y=-0.15,
                    xanchor='center',
                    yanchor='top',
                    align="center",
                    font=dict(size=10, color="gray")
                )
            ]
        )

        # Annotation for best metrics; its index is remembered so ``_update_plot`` can rewrite it.
        self.best_metrics_annotation_index = len(self.fig.layout.annotations)
        self.fig.add_annotation(
            text="Best Spearman (Test): N/A<br>Best Pearson (Test): N/A<br>Best Spearman (Val): N/A<br>Best Pearson (Val): N/A",
            showarrow=False,
            xref="paper",
            yref="paper",
            x=1.0,
            y=0.0,
            xanchor='right',
            yanchor='bottom',
            align="right",
            font=dict(size=12, color="blue")
        )

        self.fig.show(config=self.plot_config)

    def _autorange_axes(self):
        """Switch every subplot axis of the 3x2 grid to autorange."""
        for row in range(1, 4):
            for col in range(1, 3):
                self.fig.update_yaxes(autorange=True, row=row, col=col)
                self.fig.update_xaxes(autorange=True, row=row, col=col)

    def _update_traces(self):
        """Push the recorded histories into the eight figure traces."""
        # Same order as the traces were added in ``_create_plot``.
        series = (
            (self.iterations, self.loss_values),
            (self.iterations, self.mean_grad_norm_values),
            (self.iterations, self.variance_values),
            (self.iterations, self.learning_rate_values),
            (self.iterations, self.sts_pearson_cosine_values),
            (self.iterations, self.sts_spearman_cosine_values),
            (self.test_iterations, self.test_sts_pearson_cosine_values),
            (self.test_iterations, self.test_sts_spearman_cosine_values),
        )
        with self.fig.batch_update():
            for trace, (xs, ys) in zip(self.fig.data, series):
                trace.x = xs
                trace.y = ys
            self._autorange_axes()

    def _update_plot(self):
        """Refresh traces, per-epoch animation frames, the slider and the best-metrics box, then redraw."""
        self._update_traces()
        unique_epochs = sorted(set(self.epochs))
        # Only the six train/dev series get frames; the test traces are not animated.
        frame_series = (
            self.loss_values,
            self.mean_grad_norm_values,
            self.variance_values,
            self.learning_rate_values,
            self.sts_pearson_cosine_values,
            self.sts_spearman_cosine_values,
        )
        frames = []
        for ep in unique_epochs:
            indices = [i for i, e in enumerate(self.epochs) if e == ep]
            xs = [self.iterations[i] for i in indices]
            frames.append(go.Frame(
                data=[go.Scatter(x=xs, y=[values[i] for i in indices]) for values in frame_series],
                name=str(ep)
            ))
        self.fig.frames = frames

        slider_steps = [
            {"args": [[str(ep)], {"frame": {"duration": 0, "redraw": True},
                                   "mode": "immediate", "transition": {"duration": 0}}],
             "label": str(ep), "method": "animate"} for ep in unique_epochs
        ]

        self.fig.update_layout(
            sliders=[{
                "active": len(unique_epochs) - 1 if unique_epochs else 0,
                "currentvalue": {"prefix": "Epoch: "},
                "pad": {"t": 50},
                "steps": slider_steps
            }]
        )

        # Compute best test metrics if available
        nan = float('nan')
        best_test_spearman = max(self.test_sts_spearman_cosine_values) if self.test_sts_spearman_cosine_values else nan
        best_test_pearson = max(self.test_sts_pearson_cosine_values) if self.test_sts_pearson_cosine_values else nan

        self.fig.layout.annotations[self.best_metrics_annotation_index].text = (
            f"Best Spearman (Test): {best_test_spearman:.4f}<br>"
            f"Best Pearson (Test): {best_test_pearson:.4f}<br>"
            f"Best Spearman (Val): {self.best_spearman:.4f}<br>"
            f"Best Pearson (Val): {self.best_pearson:.4f}"
        )

        self._autorange_axes()

        clear_output(wait=True)
        self.fig.show(config=self.plot_config)

    def _prepare_datasets(self):
        """Download/read the training sentences, load STS-B and build the loader and evaluator.

        Sets ``train_sentences``, ``train_dataset``, ``eval_dataset``,
        ``test_dataset``, ``train_data_loader``, ``test_evaluator`` and
        ``evaluate_steps``; may raise ``config["patience"]``.
        """
        wikipedia_url = "https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt"
        wikipedia_dataset_path = "data/wiki1m_for_simcse.txt"
        if not os.path.exists(wikipedia_dataset_path):
            util.http_get(wikipedia_url, wikipedia_dataset_path)
        # Keep only lines with at least 10 characters after stripping.
        with open(wikipedia_dataset_path, encoding="utf8") as f:
            stripped_lines = (line.strip() for line in f)
            self.train_sentences = [line for line in stripped_lines if len(line) >= 10]

        self.train_dataset = load_dataset("sentence-transformers/stsb", split="train")
        self.eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
        self.test_dataset = load_dataset("sentence-transformers/stsb", split="test")

        self.train_data_loader = DataLoader(
            SentenceDataset(self.train_sentences),
            batch_size=self.config["batch_size"],
            shuffle=True,
            num_workers=self.config["num_workers"],
            pin_memory=True
        )
        self.test_evaluator = EmbeddingSimilarityEvaluator(
            sentences1=self.test_dataset["sentence1"],
            sentences2=self.test_dataset["sentence2"],
            scores=self.test_dataset["score"]
        )

        # Evaluate roughly 50 times per epoch (at least every step for tiny loaders).
        self.evaluate_steps = max(len(self.train_data_loader) // 50, 1)

        # Ensure patience is at least five times as large as evaluate_steps.
        # The value is read once with a default so a config without "patience"
        # no longer raises KeyError while formatting the warning.
        min_patience = self.evaluate_steps * 5
        patience = self.config.get("patience", 0)
        if patience < min_patience:
            print(f"Warning: Patience ({patience}) is less than evaluation steps x5 ({min_patience}). Adjusting patience to {min_patience}.")
            self.config["patience"] = min_patience

    def _apply_augmentation(self, sentences, aug):
        """Return ``aug.augment(sentences)``."""
        return aug.augment(sentences)

    def _initialize_models(self):
        """Build ``online_model`` (transformer, pooling, projection head) and the head-less ``encoder`` copy."""
        # NOTE(review): "attention_dropout"/"dropout" are the DistilBERT config field names;
        # BERT-family configs use "attention_probs_dropout_prob"/"hidden_dropout_prob".
        # Confirm these overrides actually take effect for the configured model_name.
        word_embedding_model = models.Transformer(
            self.config["model_name"],
            max_seq_length=self.config["max_seq_length"],
            config_args={"attention_dropout": self.config["aug_p"], "dropout": self.config["aug_p"]}
        )
        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())

        # Size the first projection layer from the pooling output instead of a
        # hard-coded 768, so backbones with a different hidden size also work.
        embedding_dim = pooling_model.get_sentence_embedding_dimension()
        projection_size = self.config["projection_size"]
        projection_layers = [torch.nn.Linear(embedding_dim, projection_size)]
        for _ in range(self.config["projection_depth"] - 1):
            projection_layers.append(torch.nn.BatchNorm1d(projection_size))
            projection_layers.append(torch.nn.ReLU())
            projection_layers.append(torch.nn.Linear(projection_size, projection_size))
        projection_head = torch.nn.Sequential(*projection_layers)

        self.online_model = SentenceTransformer(modules=[word_embedding_model, pooling_model, projection_head]).to(self.device)
        # The evaluation encoder is a deep copy of the transformer and pooling modules only.
        encoder_modules = [copy.deepcopy(self.online_model[i]) for i in range(2)]
        self.encoder = SentenceTransformer(modules=encoder_modules).to(self.device)

    def _update_encoder(self):
        """Copy the current weights of ``online_model``'s leading modules into ``encoder``."""
        for i in range(len(self.encoder)):
            self.encoder[i].load_state_dict(self.online_model[i].state_dict())

    def _initialize_optimizer_scheduler(self):
        """Create the Adam optimizer and the ReduceLROnPlateau scheduler for ``online_model``."""
        self.optimizer = torch.optim.Adam(
            self.online_model.parameters(),
            lr=self.config["learning_rate"],
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=self.config["weight_decay"]
        )
        # The deprecated ``verbose`` flag is no longer passed; the current learning
        # rate is already recorded per step and shown in the progress bar.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=200, min_lr=1e-6
        )

    def _forward_pass(self, model, sentences, train):
        """Embed ``sentences`` with ``model`` and return the final representation.

        Args:
            model: Module sequence indexed as transformer (0), pooling (1) and,
                optionally, projection head (2).
            sentences: Batch of raw sentences accepted by ``model.tokenize``.
            train: If true the model is put in train mode, otherwise eval mode.

        Returns:
            The projection-head output when ``model`` has more than two modules,
            else the pooled sentence embedding.
        """
        if train:
            model.train()
        else:
            model.eval()
        features = model.tokenize(sentences)
        features = {k: v.to(self.device, non_blocking=True) for k, v in features.items()}
        with autocast("cuda", enabled=self.config.get("use_amp", True)):
            # Hand the complete transformer output to the pooling module so it can use
            # the attention mask; passing only the token embeddings dropped the mask.
            transformer_out = model[0](features)
            pooled = model[1](transformer_out)["sentence_embedding"]
            if len(model) > 2:
                return model[2](pooled)
            return pooled

    def _mixed_barlow_twins_loss(self, z_a, z_b):
        """Barlow Twins loss between two views plus a MixUp consistency term.

        Args:
            z_a: 2-D tensor ``(N, D)`` of embeddings for the first view.
            z_b: 2-D tensor ``(N, D)`` of embeddings for the second view.

        Returns:
            Scalar tensor: ``sum((C - I)^2)`` with off-diagonal entries weighted
            by ``lambda_bt``, plus ``lambda_mixup * lambda_bt`` times the squared
            error between the mixed view's cross-correlations and their linear
            interpolation targets.
        """
        N, D = z_a.size()
        # Standardise every feature over the batch dimension.
        z_a_norm = (z_a - z_a.mean(dim=0)) / (z_a.std(dim=0) + 1e-6)
        z_b_norm = (z_b - z_b.mean(dim=0)) / (z_b.std(dim=0) + 1e-6)
        c = torch.matmul(z_a_norm.T, z_b_norm) / N
        I = torch.eye(D, device=z_a.device)
        c_diff = (c - I).pow(2)
        off_diag_mask = ~torch.eye(D, dtype=torch.bool, device=z_a.device)
        c_diff[off_diag_mask] *= self.config["lambda_bt"]
        loss_bt = c_diff.sum()

        # MixUp Regularization: mix view A with a shuffled view B.
        idx = torch.randperm(N, device=z_a.device)  # created on the tensors' device
        alpha = torch.tensor(np.random.beta(1.0, 1.0), device=z_a.device, dtype=z_a.dtype)
        z_b_shuffled_norm = z_b_norm[idx, :]
        z_m = alpha * z_a + (1 - alpha) * z_b[idx, :]
        z_m_norm = (z_m - z_m.mean(dim=0)) / (z_m.std(dim=0) + 1e-6)
        cc_m_a = torch.matmul(z_m_norm.T, z_a_norm) / N
        cc_m_b = torch.matmul(z_m_norm.T, z_b_norm) / N
        cc_m_a_gt = alpha * torch.matmul(z_a_norm.T, z_a_norm) / N + (1 - alpha) * torch.matmul(z_b_shuffled_norm.T, z_a_norm) / N
        # ``c`` already equals z_a_norm.T @ z_b_norm / N, so it is reused here.
        cc_m_b_gt = alpha * c + (1 - alpha) * torch.matmul(z_b_shuffled_norm.T, z_b_norm) / N
        loss_mix = self.config["lambda_mixup"] * self.config["lambda_bt"] * (
            (cc_m_a - cc_m_a_gt).pow(2).sum() + (cc_m_b - cc_m_b_gt).pow(2).sum()
        )
        return loss_bt + loss_mix

    def _evaluate_without_heads(self):
        """Sync ``encoder`` with ``online_model`` and score it on the validation split.

        Returns:
            The metrics returned by ``EmbeddingSimilarityEvaluator`` for the
            (shuffled) validation sentence pairs.
        """
        self._update_encoder()
        self.encoder.eval()
        # Read each column once instead of once per element.
        first_col = self.eval_dataset["sentence1"]
        second_col = self.eval_dataset["sentence2"]
        score_col = self.eval_dataset["score"]
        indices = list(range(len(first_col)))
        random.shuffle(indices)
        sentences1 = [first_col[i] for i in indices]
        sentences2 = [second_col[i] for i in indices]
        scores = [score_col[i] for i in indices]
        evaluator = EmbeddingSimilarityEvaluator(sentences1, sentences2, scores)
        return evaluator(self.encoder)

    def _mean_grad_norm(self, num_params):
        """Return the unscaled global L2 gradient norm divided by ``num_params``."""
        scale = self.scaler.get_scale()
        scale = scale if scale != 0 else 1e-8  # avoid dividing by a zero loss scale
        squared_sum = 0.0
        for param in self.online_model.parameters():
            if param.grad is not None:
                squared_sum += param.grad.data.norm(2).item() ** 2
        # Gradients still carry the AMP loss scale at this point, hence the division.
        return math.sqrt(squared_sum) / scale / num_params

    def fit(self, trial=None):
        """Train ``online_model`` and periodically evaluate, plot and (optionally) prune.

        Args:
            trial: Optional ``optuna`` trial.  When given, the validation
                Spearman score is reported at every evaluation step and the
                trial is pruned if ``trial.should_prune()`` says so.

        Raises:
            optuna.exceptions.TrialPruned: When the trial is pruned; the plot is
                saved and ``cleanup`` is called first.
        """
        latest_eval_metrics = {}
        steps_per_epoch = len(self.train_data_loader)
        last_step = steps_per_epoch - 1
        num_params = len(list(self.online_model.parameters()))
        use_amp = self.config.get("use_amp", True)

        for epoch in range(self.config["epochs"]):
            early_stop = False
            epoch_loss = 0
            pbar = tqdm(self.train_data_loader, desc=f"Epoch {epoch+1}/{self.config['epochs']}:")
            for idx, sentences in enumerate(pbar):
                # Two distinct augmenters produce the two views of the batch.
                aug1, aug2 = random.sample(self.augmenters, 2)
                s1 = self._apply_augmentation(sentences, aug1)
                s2 = self._apply_augmentation(sentences, aug2)
                with autocast("cuda", enabled=use_amp):
                    z_a = self._forward_pass(self.online_model, s1, train=True)
                    z_b = self._forward_pass(self.online_model, s2, train=True)
                    loss = self._mixed_barlow_twins_loss(z_a, z_b)
                loss_value = loss.item()
                epoch_loss += loss_value

                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                mean_grad_norm = self._mean_grad_norm(num_params)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.scheduler.step(loss_value)

                # Evaluate every ``evaluate_steps`` steps (step 0 included) and on the last step.
                is_eval_step = idx % self.evaluate_steps == 0 or idx == last_step
                if is_eval_step:
                    latest_eval_metrics = self._evaluate_without_heads()
                    last_spearman = latest_eval_metrics.get('spearman_cosine', -float('inf'))
                    last_pearson = latest_eval_metrics.get('pearson_cosine', -float('inf'))
                    # ``_evaluate_without_heads`` has just synced the encoder and put it
                    # in eval mode, so it can be scored on the test split directly.
                    test_metrics = self.test_evaluator(self.encoder)
                    self.test_sts_pearson_cosine_values.append(test_metrics.get('pearson_cosine', np.nan))
                    self.test_sts_spearman_cosine_values.append(test_metrics.get('spearman_cosine', np.nan))
                    self.test_iterations.append(idx)
                    self._update_plot()

                    # Report intermediate value and check for pruning
                    if trial is not None:
                        current_step = epoch * steps_per_epoch + idx
                        trial.report(last_spearman, current_step)
                        if trial.should_prune():
                            print(f"Trial {trial.number} pruned at epoch {epoch+1}, iteration {idx}.")
                            save_plot(trial.number, self)
                            self.cleanup()
                            raise optuna.exceptions.TrialPruned()

                    if last_spearman > self.best_spearman:
                        self.best_spearman = last_spearman
                        self.best_pearson = last_pearson
                        self.patience_counter = 0
                    else:
                        # Patience is counted in training steps, not in evaluations.
                        self.patience_counter += self.evaluate_steps

                current_lr = self.optimizer.param_groups[0]['lr']
                pbar.set_postfix({
                    "loss": loss_value,
                    **latest_eval_metrics,
                    "mean_grad_norm": mean_grad_norm,
                    "learning_rate": current_lr
                })
                self.loss_values.append(loss_value)
                self.sts_pearson_cosine_values.append(latest_eval_metrics.get('pearson_cosine', np.nan))
                self.sts_spearman_cosine_values.append(latest_eval_metrics.get('spearman_cosine', np.nan))
                self.mean_grad_norm_values.append(mean_grad_norm)
                self.variance_values.append(torch.var(z_a, dim=0).mean().item())
                self.learning_rate_values.append(current_lr)
                self.iterations.append(idx)
                self.epochs.append(epoch)

                if is_eval_step and self.patience_counter >= self.config["patience"]:
                    early_stop = True
                    print(f"Early stopping triggered at epoch {epoch+1}, iteration {idx}.")
                    print(f"Best Spearman Correlation: {self.best_spearman}")
                    print(f"Best Pearson Correlation: {self.best_pearson}")
                    break
            if early_stop:
                break
            avg_loss = epoch_loss / steps_per_epoch
            pbar.set_description(f"Epoch {epoch+1} Loss: {avg_loss}")

    def cleanup(self):
        """Drop the models, optimizer, scheduler and scaler and free cached GPU memory.

        Safe to call more than once.  ``encoder`` is released too, since it is
        a second copy of the transformer on the training device.
        """
        for attr in self._RELEASABLE_ATTRS:
            if hasattr(self, attr):
                delattr(self, attr)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

class SentenceDataset(torch.utils.data.Dataset):
    """Map-style dataset that exposes an indexable collection of sentences unchanged."""

    def __init__(self, sentences):
        """Keep a reference to ``sentences``; no copy is made."""
        self.sentences = sentences

    def __len__(self):
        """Number of stored sentences."""
        total = len(self.sentences)
        return total

    def __getitem__(self, idx):
        """Sentence stored at position ``idx``."""
        sentence = self.sentences[idx]
        return sentence

# Mount Google Drive so checkpoints and the Optuna study are persisted there.
from google.colab import drive
drive.mount('/content/drive')

# Drive folder that receives the per-trial checkpoints and the "graphs" subfolder.
CHECKPOINT_DIR = "/content/drive/MyDrive/violet_bert_base_checkpoints_2"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
# Path of the study database file inside the checkpoint folder.
STUDY_DB_PATH = os.path.join(CHECKPOINT_DIR, "llm_finetuning_study.db")


def save_plot(trial_number, trainer):
    """Export ``trainer.fig`` as a PNG named after the trial.

    The image is written to the "graphs" subfolder of ``CHECKPOINT_DIR``, which
    is created on demand.

    Args:
        trial_number: Number used in the file name.
        trainer: Object whose ``fig`` attribute provides ``write_image``.
    """
    output_dir = os.path.join(CHECKPOINT_DIR, "graphs")
    os.makedirs(output_dir, exist_ok=True)

    image_name = f"checkpoint_trial_{trial_number}.png"
    graph_path = os.path.join(output_dir, image_name)
    trainer.fig.write_image(graph_path)
    print(f"Graph saved to {graph_path}")

def save_checkpoint(trial_number, trainer):
    """Persist the trainer state for one trial and export its metrics figure.

    The checkpoint is written to ``CHECKPOINT_DIR/checkpoint_trial_<n>.pt`` via a
    temporary file that is renamed into place, so an interrupted write cannot
    leave a truncated checkpoint under the final name.

    Besides the model/optimizer/scheduler state dicts and the best scores, the
    per-iteration histories are stored together (``epochs``, ``iterations`` and
    the metric series recorded alongside them).  Scores are converted to plain
    Python floats so the file contains no NumPy scalar objects.

    Args:
        trial_number: Number used to build the file names.
        trainer: ``BarlowTwinsNCSE`` instance to snapshot.

    Returns:
        Path of the written checkpoint file.
    """
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"checkpoint_trial_{trial_number}.pt")

    state = {
        'online_model_state_dict': trainer.online_model.state_dict(),
        'optimizer_state_dict': trainer.optimizer.state_dict(),
        'scheduler_state_dict': trainer.scheduler.state_dict(),
        'best_spearman': float(trainer.best_spearman),
        'best_pearson': float(trainer.best_pearson),
        'epochs': list(trainer.epochs),
        'iterations': list(trainer.iterations),
    }
    # Series that the trainer appends to in lock-step with epochs/iterations.
    for series_name in (
        'loss_values',
        'sts_pearson_cosine_values',
        'sts_spearman_cosine_values',
        'mean_grad_norm_values',
        'variance_values',
        'learning_rate_values',
    ):
        state[series_name] = [float(value) for value in getattr(trainer, series_name)]

    # Write-then-rename keeps any previous checkpoint intact until the new one is complete.
    tmp_path = checkpoint_path + ".tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, checkpoint_path)

    save_plot(trial_number, trainer)

    return checkpoint_path

def load_checkpoint(trainer, checkpoint_path):
    """Restore trainer state from ``checkpoint_path``; do nothing if the file is absent.

    Model, optimizer and scheduler state dicts and the best scores are always
    restored.  The per-iteration histories are restored only when the
    checkpoint contains every series that the trainer indexes in parallel
    (``epochs``, ``iterations`` and the metric lists) and they all have the same
    length.  Otherwise the trainer's current histories are left untouched:
    restoring ``epochs``/``iterations`` alone would make them longer than the
    metric lists and ``_update_plot`` would fail with ``IndexError``.

    Args:
        trainer: ``BarlowTwinsNCSE`` instance to update in place.
        checkpoint_path: File written by ``save_checkpoint``.
    """
    if not os.path.exists(checkpoint_path):
        return

    # Map tensors onto the trainer's device so a checkpoint saved on GPU also loads on CPU.
    checkpoint = torch.load(checkpoint_path, map_location=getattr(trainer, "device", None))
    trainer.online_model.load_state_dict(checkpoint['online_model_state_dict'])
    trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    trainer.best_spearman = checkpoint.get('best_spearman', -float("inf"))
    trainer.best_pearson = checkpoint.get('best_pearson', -float("inf"))

    history_keys = (
        'epochs',
        'iterations',
        'loss_values',
        'sts_pearson_cosine_values',
        'sts_spearman_cosine_values',
        'mean_grad_norm_values',
        'variance_values',
        'learning_rate_values',
    )
    if all(key in checkpoint for key in history_keys):
        lengths = {len(checkpoint[key]) for key in history_keys}
        if len(lengths) == 1:
            for key in history_keys:
                setattr(trainer, key, list(checkpoint[key]))

    print(f"Loaded checkpoint from {checkpoint_path}")

def objective(trial):
    """Optuna objective: train one BarlowTwinsNCSE configuration and score it.

    Samples hyperparameters from ``trial``, trains for one epoch, evaluates
    with ``trainer._evaluate_without_heads()`` and returns the
    ``spearman_cosine`` metric (``-inf`` if the metric is missing). A
    checkpoint is written only when the score beats the module-level
    ``BEST_SCORE``, which is then updated.

    Args:
        trial: The ``optuna.trial.Trial`` supplying hyperparameter suggestions.

    Returns:
        The validation ``spearman_cosine`` score for this trial.

    Raises:
        optuna.TrialPruned: If training is interrupted with KeyboardInterrupt;
            a checkpoint is saved first.
    """
    # BEST_SCORE must already exist at module level before study.optimize()
    # runs; its definition is not part of this section of the notebook.
    global BEST_SCORE

    config = {
        "model_name": "bert-base-uncased",
        "batch_size": trial.suggest_categorical("batch_size", [64, 128]),
        "projection_depth": trial.suggest_int("projection_depth", 2, 3),
        "projection_size": trial.suggest_categorical("projection_size", [2048, 4096, 6144, 8192]),
        "epochs": 1,  # fixed at 1 for faster evaluation
        "warmup_proportion": 0.0,
        "max_seq_length": 75,  # or trial.suggest_categorical("max_seq_length", [32, 75]),
        # suggest_float replaces suggest_uniform / suggest_loguniform, which
        # have been deprecated since Optuna v3.0.0.
        "aug_p": trial.suggest_float("aug_p", 0.1, 0.4),
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-4, log=True),
        "model_save_path": os.path.join(CHECKPOINT_DIR, f"train_stsb_bt-distilbert-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_trial{trial.number}"),
        "num_workers": 10,
        "weight_decay": trial.suggest_float("weight_decay", 1e-4, 0.2),
        "lambda_bt": trial.suggest_float("lambda_bt", 0.001, 0.2),
        "lambda_mixup": trial.suggest_float("lambda_mixup", 0.1, 1.5),
        "use_amp": True,
        "patience": 1000
    }

    trainer = BarlowTwinsNCSE(config)
    try:
        # Resume from checkpoint if available.
        # NOTE(review): the file is keyed by trial.number, and a persisted
        # study appears to hand out a new number to every trial, so this
        # branch may only trigger when CHECKPOINT_DIR is reused with a fresh
        # study -- confirm that the resume behaviour is what is intended.
        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"checkpoint_trial_{trial.number}.pt")
        if os.path.exists(checkpoint_path):
            load_checkpoint(trainer, checkpoint_path)

        try:
            trainer.fit(trial=trial)
        except KeyboardInterrupt:
            # Keep the partial progress, then mark the trial as pruned.
            save_checkpoint(trial.number, trainer)
            raise optuna.TrialPruned("Trial interrupted and checkpoint saved.")

        # Evaluate after training
        val_metrics = trainer._evaluate_without_heads()
        current_score = val_metrics.get("spearman_cosine", -float("inf"))

        # Save the checkpoint only if this trial outperforms previous trials.
        if current_score > BEST_SCORE:
            BEST_SCORE = current_score
            print(f"Trial {trial.number} achieved a new best score: {current_score:.4f}. Saving checkpoint.")
            save_checkpoint(trial.number, trainer)
        else:
            print(f"Trial {trial.number} did not improve the best score ({BEST_SCORE:.4f}). Checkpoint not saved.")

        return current_score
    finally:
        # Release the trainer's resources on every exit path (success, pruned,
        # interrupted or failed), since the study continues with further
        # trials in the same process.
        trainer.cleanup()

# Create or resume an Optuna study persisted on Google Drive.
# The study is stored in the SQLite file at STUDY_DB_PATH; with
# load_if_exists=True an existing study named "llm_finetuning_study" is
# reused instead of raising, so trials from earlier sessions are kept.
# The objective returns a Spearman score, hence direction="maximize".
# NOTE(review): HyperbandPruner can only act if intermediate values are
# reported on the trial -- presumably done inside trainer.fit(trial=trial);
# confirm.
study = optuna.create_study(
    direction="maximize",
    study_name="llm_finetuning_study",
    storage=f"sqlite:///{STUDY_DB_PATH}",
    load_if_exists=True,
    pruner=optuna.pruners.HyperbandPruner(),
)

# Disabled: the block below is wrapped in a string literal, so these
# enqueue_trial calls never run. Removing the surrounding quotes would queue
# these two fixed hyperparameter sets on the study.
'''
study.enqueue_trial(
    {
        "batch_size": 128,
        "projection_depth": 2,
        "projection_size": 8192,
        "aug_p": 0.1205,
        "learning_rate": 2.919e-5,
        "weight_decay": 0.1337,
        "lambda_bt": 0.1128,
        "lambda_mixup": 1.094,
    }
)

study.enqueue_trial(
    {
        "batch_size": 128,
        "projection_depth": 3,
        "projection_size": 8192,
        "aug_p": 0.1323,
        "learning_rate": 3.064e-5,
        "weight_decay": 0.08471,
        "lambda_bt": 0.1575,
        "lambda_mixup": 1.289,
    }
)
'''

# Run the search: n_trials=20 asks this call for 20 trials of `objective`.
study.optimize(objective, n_trials=20)

Epoch 1/1::   1%|          | 66/7701 [01:00<1:56:35,  1.09it/s, loss=5.3e+4, pearson_cosine=0.592, spearman_cosine=0.593, mean_grad_norm=160, learning_rate=2.06e-5]
[I 2025-04-15 19:10:08,348] Trial 143 pruned. Trial interrupted and checkpoint saved.


Graph saved to /content/drive/MyDrive/violet_bert_base_checkpoints_2/graphs/checkpoint_trial_143.png



suggest_uniform has been deprecated in v3.0.0. This feature will be removed in v6.0.0. See https://github.com/optuna/optuna/releases/tag/v3.0.0. Use suggest_float instead.


suggest_loguniform has been deprecated in v3.0.0. This feature will be removed in v6.0.0. See https://github.com/optuna/optuna/releases/tag/v3.0.0. Use suggest_float(..., log=True) instead.

[W 2025-04-15 19:10:10,842] Trial 144 failed with parameters: {'batch_size': 128, 'projection_depth': 3, 'projection_size': 6144, 'aug_p': 0.1268875636524825} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/optuna/study/_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "<ipython-input-2-311849fa0b5e>", line 495, in objective
    "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 1e-4),
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  Fi

KeyboardInterrupt: 