In [1]:
"""
This example trains a CrossEncoder for the STSbenchmark task. A CrossEncoder takes a sentence pair
as input and outputs a label. Here, it outputs a continuous label in the range 0...1 to indicate the similarity between the input pair.

It does NOT produce a sentence embedding and does NOT work for individual sentences.

Usage:
python training_stsbenchmark.py
"""
from torch.utils.data import DataLoader
import math
from sentence_transformers import LoggingHandler, util
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator
from sentence_transformers import InputExample
import logging
from datetime import datetime
from datasets import load_dataset
from softprompt_crossencoder import PromptTunedCrossEncoder

#### Logging setup: emit timestamped INFO-level messages through LoggingHandler
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[LoggingHandler()],
)
# Module-level logger used by the rest of this notebook
logger = logging.getLogger(__name__)
#### /logging setup


# Training settings for the cross-encoder
num_epochs = 4
train_batch_size = 16

# Evidence label that gets score 1 when the samples are built; every other kept label gets 0.
# It is also embedded in the output directory name.
EVIDENCE_LABEL = 'CONTRADICT'
model_save_path = f'output/ce-{EVIDENCE_LABEL}'

# Base checkpoint: cross-encoder/ms-marco-MiniLM-L-12-v2.
# num_labels=1 requests a single continuous output score per input pair.
model = PromptTunedCrossEncoder(
    'cross-encoder/ms-marco-MiniLM-L-12-v2',
    num_labels=1,
)


# Read scifact dataset
logger.info("Read scifact train dataset")
corpus = load_dataset('scifact', 'corpus')
claims = load_dataset('scifact', 'claims')


corpus_df = corpus['train'].to_pandas()
# String version of the document id, so it can be matched against `evidence_doc_id`
corpus_df['doc_id_str'] = corpus_df['doc_id'].map(str)

# Build the doc-id -> abstract lookup once instead of scanning the whole corpus
# frame for every claim. `setdefault` keeps the FIRST abstract seen for an id,
# which is the row the previous `.iloc[0]` selection returned.
abstract_by_doc_id = {}
for _doc_id, _abstract in zip(corpus_df['doc_id_str'], corpus_df['abstract']):
    abstract_by_doc_id.setdefault(_doc_id, _abstract)


def _labelled_claims(claims_df):
    """Return only the claims whose evidence label is CONTRADICT or SUPPORT."""
    return claims_df.loc[claims_df['evidence_label'].isin(['CONTRADICT', 'SUPPORT'])]


def _build_samples(claims_df, abstracts, positive_label):
    """Turn labelled claims into (claim, evidence sentence) training examples.

    Args:
        claims_df: claims frame with the columns 'claim', 'evidence_doc_id',
            'evidence_sentences' and 'evidence_label'.
        abstracts: mapping from document id (as a string) to that document's abstract,
            which is indexed by sentence position.
        positive_label: evidence label that is scored 1; any other label is scored 0.

    Returns:
        A list of InputExample objects, one per row of `claims_df`.

    Raises:
        KeyError: if a claim points at a document id that is not in `abstracts`.
    """
    samples = []
    for _, doc in claims_df.iterrows():
        abstract = abstracts[doc['evidence_doc_id']]
        # Only the first listed evidence sentence is used as the second text of the pair
        evidence = abstract[doc['evidence_sentences'][0]]
        score = 1 if doc['evidence_label'] == positive_label else 0
        samples.append(InputExample(texts=[doc['claim'], evidence], label=score))
    return samples


train_df = _labelled_claims(claims['train'].to_pandas())
train_samples = _build_samples(train_df, abstract_by_doc_id, EVIDENCE_LABEL)

dev_df = _labelled_claims(claims['validation'].to_pandas())
dev_samples = _build_samples(dev_df, abstract_by_doc_id, EVIDENCE_LABEL)



2022-10-17 18:37:53 - Use pytorch device: cpu
Initializing soft prompt...
2022-10-17 18:37:53 - Read scifact train dataset
2022-10-17 18:37:54 - Reusing dataset scifact (/Users/domenicrosati/.cache/huggingface/datasets/scifact/corpus/1.0.0/15660e43ecfb3f7420850027005a63611abb2d401e9746b4059c1260745d9831)


  0%|          | 0/1 [00:00<?, ?it/s]

2022-10-17 18:37:55 - Reusing dataset scifact (/Users/domenicrosati/.cache/huggingface/datasets/scifact/claims/1.0.0/15660e43ecfb3f7420850027005a63611abb2d401e9746b4059c1260745d9831)


  0%|          | 0/3 [00:00<?, ?it/s]

In [9]:
# Inspect the soft-prompt weight tensor held by the wrapped model.
model.model.soft_prompt.weight

Parameter containing:
tensor([[-0.0177, -0.0030, -0.0131,  ...,  0.0339, -0.0048,  0.0278],
        [-0.0145, -0.0273,  0.0435,  ..., -0.0140,  0.0118, -0.0067],
        [-0.0159, -0.0098,  0.0285,  ..., -0.0207,  0.0194, -0.0070],
        ...,
        [-0.0286, -0.0096,  0.0198,  ..., -0.0143,  0.0192,  0.0015],
        [-0.0200, -0.0352,  0.0410,  ..., -0.0092,  0.0259, -0.0069],
        [-0.0332, -0.0179,  0.0151,  ..., -0.0184,  0.0263, -0.0073]],
       requires_grad=True)

In [24]:
# Display the tokenizer output.
# NOTE(review): `inputs` is first assigned in a later cell of this notebook, so this
# cell raises NameError on a fresh top-to-bottom run. Its saved output (7 token ids)
# also predates the sentence-pair tokenization, whose saved ids are 13 long.
inputs

{'input_ids': tensor([  101,  1996,  3899,  2003, 21392,  2075,   102]), 'token_type_ids': [0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}

loss: 4540475.5


In [30]:
from transformers import (
    AdamW,
    get_scheduler
)

In [31]:
class Config:
    """Hyper-parameters for the prompt-tuning run, kept as plain class attributes.

    The field names mirror the arguments of run_clm_no_trainer.py in transformers:
    https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm_no_trainer.py
    NOTE(review): the previous comment said the values are that script's defaults - verify.
    """

    # --- optimizer ---
    learning_rate = 1e-2
    weight_decay = 1e-2

    # --- learning-rate schedule ---
    lr_scheduler_type = "linear"
    num_warmup_steps = 0
    num_train_epochs = 3
    # The number of scheduler steps is tied to the number of epochs
    max_train_steps = num_train_epochs

    # --- prompt tuning ---
    # Number of prompt tokens.
    # NOTE(review): this value is not passed to PromptTunedCrossEncoder in the visible
    # code, and the saved soft-prompt shape in this notebook has 100 rows - confirm.
    n_prompt_tokens = 20
    # If True, the soft prompt is initialized from the vocabulary.
    # Otherwise, set `random_range` to initialize it randomly.
    init_from_vocab = True
    # random_range = 0.5


args = Config()

In [33]:
# Hand only the soft-prompt weight to the optimizer; no other parameter of the
# wrapped model is included in the parameter group.
# NOTE(review): the AdamW class exported by transformers is deprecated upstream in
# favour of torch.optim.AdamW - consider migrating.
optimizer_grouped_parameters = [
    {
        "params": [
            param
            for param_name, param in model.model.named_parameters()
            if param_name == "soft_prompt.weight"
        ],
        "weight_decay": args.weight_decay,
    },
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

# Learning-rate schedule configured from the Config values
lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=args.num_warmup_steps,
    num_training_steps=args.max_train_steps,
)




In [48]:
import torch

# Tokenize one sentence pair and turn the list of token ids into a tensor.
inputs = model.tokenizer("The dog is peeing", "The dog is urinating")
inputs["input_ids"] = torch.tensor(inputs["input_ids"])

# Forward pass with a target of 1. Only input_ids and labels are passed on;
# the tokenizer's token_type_ids and attention_mask are not used here.
# NOTE(review): the saved loss (~10269.4) is far above the squared error between the
# saved logit (1.8434) and the label 1 (~0.71). It matches the mean squared error over
# 100 targets of -100 plus the real label, which suggests the labels get padded for the
# 100 soft-prompt positions - check PromptTunedCrossEncoder.forward.
outputs = model.forward(
    input_ids=inputs["input_ids"],
    labels=torch.Tensor([1]),
)
loss = outputs.loss
print(f"loss: {loss}")

loss: 10269.396484375


In [49]:
# Display the full model output of the forward pass (loss and logits).
outputs

SequenceClassifierOutput(loss=tensor(10269.3965, grad_fn=<MseLossBackward0>), logits=tensor([[1.8434]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [50]:
# Show the token-id tensor that was passed to the model.
inputs['input_ids']

tensor([  101,  1996,  3899,  2003, 21392,  2075,   102,  1996,  3899,  2003,
        24471, 19185,   102])

In [51]:
# One optimisation step on the parameters registered with the optimizer.
loss.to(torch.float32).backward()
optimizer.step()
# Advance the learning-rate schedule that was built for this optimizer; it was
# created but never stepped. The order (optimizer first, then scheduler) follows
# the training loop of run_clm_no_trainer.py.
lr_scheduler.step()
# Clear the gradients so that re-running the forward/backward cells does not
# accumulate stale gradients into the next update.
optimizer.zero_grad()

In [58]:
# Shape of the soft-prompt weight. The saved output is [100, 384]; 384 matches the
# embedding size of the word-embedding table shown in this notebook.
model.model.soft_prompt.weight.shape

torch.Size([100, 384])

In [57]:
# Display the word-embedding module of the underlying BERT model (vocabulary size and embedding size).
model.model.bert.embeddings.word_embeddings

Embedding(30522, 384, padding_idx=0)