In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import accuracy_score
from transformers import (
    DistilBertTokenizerFast,
    DistilBertPreTrainedModel,
    DistilBertModel,
    Trainer,
    TrainingArguments
)
from transformers import AutoModel, AutoConfig, AutoTokenizer
import torch
import torch.nn as nn
import wandb

In [2]:
df = pd.read_csv('data_logs/wmt14_bleu_threshold.csv')

Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/home/ubuntu/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_3760282/2468758156.py", line 1, in <module>
    df = pd.read_csv('data_logs/wmt14_bleu_threshold.csv')
  File "/usr/lib/python3/dist-packages/pandas/io/parsers.py", line 685, in parser_f
    return _read(filepath_or_buffer, kwds)
  File "/usr/lib/python3/dist-packages/pandas/io/parsers.py", line 457, in _read
    parser = TextFileReader(fp_or_buf, **kwds)
  File "/usr/lib/python3/dist-packages/pandas/io/parsers.py", line 895, in __init__
    self._make_engine(self.engine)
  File "/usr/lib/python3/dist-packages/pandas/io/parsers.py", line 1135, in _make_engine
    self._engine = CParserWrapper(self.f, **self.options)
  File "/usr/lib/python3/dist-packages/pandas/io/parsers.py", line 1917, in __init__
    self._reader = parsers.TextReader(src, **kwds)
  File "pandas

In [None]:
labels_1b = (df['1b'] - df['1b'].min()).astype(int).tolist()
labels_3b = (df['3b'] - df['3b'].min()).astype(int).tolist()
labels_8b = (df['8b'] - df['8b'].min()).astype(int).tolist()

In [None]:
print("Labels for 1b:", set(labels_1b))
print("Labels for 3b:", set(labels_3b))
print("Labels for 8b:", set(labels_8b))

In [None]:
texts = df['input_text'].tolist()

In [None]:
train_texts, val_texts, train_labels_texts, val_labels_texts = train_test_split(
    texts,
    list(zip(labels_1b, labels_3b, labels_8b)),
    test_size=0.2,
    random_state=42
)

In [None]:
train_labels_1b, train_labels_3b, train_labels_8b = zip(*train_labels_texts)
val_labels_1b, val_labels_3b, val_labels_8b = zip(*val_labels_texts)

In [None]:
# tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

In [None]:
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')

In [None]:
train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)

In [None]:
class MultiHeadDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels_list):
        self.encodings = encodings
        self.labels_list = labels_list  

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        labels = torch.tensor([label[idx] for label in self.labels_list], dtype=torch.long)
        item['labels'] = labels
        return item

    def __len__(self):
        return len(self.encodings['input_ids'])

In [None]:
train_dataset = MultiHeadDataset(
    train_encodings,
    [list(train_labels_1b), list(train_labels_3b)] # list(train_labels_8b)
)

In [None]:
val_dataset = MultiHeadDataset(
    val_encodings,
    [list(val_labels_1b), list(val_labels_3b)] # list(val_labels_8b)
)

In [None]:
class DistilBertMultiHeadClassification(DistilBertPreTrainedModel):
    def __init__(self, config, num_labels_per_head):
        super().__init__(config)
        self.num_heads = len(num_labels_per_head)
        self.num_labels_per_head = num_labels_per_head

        self.distilbert = DistilBertModel(config)

        self.classifier_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(config.hidden_size, 128),  
                nn.ReLU(),                         
                nn.Linear(128, 64),                 
                nn.ReLU(),                          
                nn.Linear(64, num_labels)          
            )
            for num_labels in num_labels_per_head
        ])

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        # Get the outputs from DistilBERT backbone
        outputs = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)
        pooled_output = hidden_state[:, 0]  # Take the representation of [CLS] token

        # Compute logits for each head
        logits = [classifier(pooled_output) for classifier in self.classifier_heads]

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            losses = []
            # labels: Tensor of shape (batch_size, num_heads)
            for i in range(self.num_heads):
                # Extract labels for the current head
                head_labels = labels[:, i]  # Shape: (batch_size,)
                # Compute loss for the current head
                losses.append(loss_fct(logits[i], head_labels))
            loss = sum(losses) / self.num_heads  # Average the loss over all heads

        return {'loss': loss, 'logits': logits}

In [None]:
class SentenceTransformerMultiHeadClassification(nn.Module):
    def __init__(self, model_name, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.sentence_transformer = AutoModel.from_pretrained(model_name)
        hidden_size = self.sentence_transformer.config.hidden_size

        # self.classifier_heads = nn.ModuleList([
        #     nn.Linear(hidden_size, 1)  
        #     for _ in range(num_heads)
        # ])

        # Each head outputs a single logit
        self.classifier_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, 128),
                nn.ReLU(),
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 1)  
            )
            for _ in range(num_heads)
        ])
    
    def forward(self, input_ids=None, attention_mask=None, labels=None):
        # Get model outputs
        outputs = self.sentence_transformer(input_ids=input_ids, attention_mask=attention_mask)
        
        # Perform mean pooling to get sentence embeddings
        pooled_output = self.mean_pooling(outputs, attention_mask)

        # Compute logits for each head
        logits = [head(pooled_output) for head in self.classifier_heads]
        logits = torch.stack(logits, dim=1).squeeze(-1)  # Shape: (batch_size, num_heads)

        loss = None
        if labels is not None:
            # labels shape: (batch_size, num_heads)
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels.float())

        return {'loss': loss, 'logits': logits}

    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]  # First element contains token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1)
        sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
        return sum_embeddings / sum_mask

In [None]:
num_labels_per_head = [
    max(labels_1b)+1,
    max(labels_3b)+1,
    # max(labels_8b)+1,
]

In [None]:
print("Number of classes per head:", num_labels_per_head)

In [None]:
# model = DistilBertMultiHeadClassification.from_pretrained(
#     'distilbert-base-uncased',
#     num_labels_per_head=num_labels_per_head
# )

In [None]:
num_heads = 2  

model = SentenceTransformerMultiHeadClassification(
    model_name='sentence-transformers/all-mpnet-base-v2',
    num_heads=num_heads
)

In [None]:
print(f"Hidden size: {model.sentence_transformer.config.hidden_size}")

In [None]:
# for param in model.distilbert.parameters():
#     param.requires_grad = False

for param in model.sentence_transformer.parameters():
    param.requires_grad = False

In [None]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred

    if not isinstance(logits, torch.Tensor):
        logits = torch.tensor(logits)

    # Apply sigmoid to logits to get probabilities
    probs = torch.sigmoid(logits)

    # Convert probabilities to binary predictions (threshold at 0.5)
    preds = (probs >= 0.5).int().numpy()

    labels = np.array(labels)
    labels = labels.astype(int)

    accuracies = []
    # precisions = []
    # recalls = []
    # f1s = []

    num_heads = labels.shape[1]

    for i in range(num_heads): 
        pred = preds[:, i]
        label = labels[:, i]

        accuracy = accuracy_score(label, pred)
        # precision = precision_score(label, pred, zero_division=0)
        # recall = recall_score(label, pred, zero_division=0)
        # f1 = f1_score(label, pred, zero_division=0)

        accuracies.append(accuracy)
        # precisions.append(precision)
        # recalls.append(recall)
        # f1s.append(f1)

    # Compute average metrics across heads
    avg_accuracy = np.mean(accuracies)
    # avg_precision = np.mean(precisions)
    # avg_recall = np.mean(recalls)
    # avg_f1 = np.mean(f1s)

    metrics = {
        'accuracy': avg_accuracy,
        # 'precision': avg_precision,
        # 'recall': avg_recall,
        # 'f1': avg_f1
    }

    # Add per-head metrics
    # for i, (acc, prec, rec, f1) in enumerate(zip(accuracies, precisions, recalls, f1s)):
    #     metrics[f'accuracy_head_{i+1}'] = acc
    #     metrics[f'precision_head_{i+1}'] = prec
    #     metrics[f'recall_head_{i+1}'] = rec
    #     metrics[f'f1_head_{i+1}'] = f1

    for i, acc in enumerate(accuracies):
        metrics[f'accuracy_head_{i+1}'] = acc

    return metrics

In [None]:
training_args = TrainingArguments(
    output_dir='./data_logs/sentence_transformer_multihead',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    evaluation_strategy="epoch",
    logging_steps=10,
    logging_dir='./data_logs/sentence_transformer_multihead',
    report_to="wandb",
)

In [None]:
wandb.init(project="MESS+", name="sentence-transformer-multihead-run")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()
trainer.evaluate()

# eval_results = trainer.evaluate(eval_dataset=val_dataset, metric_key_prefix="eval")
# print("Validation Results:")
# for key, value in eval_results.items():
#     if key.startswith("eval_accuracy"):
#         print(f"{key}: {value}")

# train_results = trainer.evaluate(eval_dataset=train_dataset, metric_key_prefix="train")
# print("Training Results:")
# for key, value in train_results.items():
#     if key.startswith("train_accuracy"):
#         print(f"{key}: {value}")

wandb.finish()