
Questions about the paper. #3

Open
chaochen99 opened this issue Feb 24, 2022 · 8 comments

@chaochen99

chaochen99 commented Feb 24, 2022

Hello,

I was very fortunate to read your paper, and the experimental results are exciting.
The paper mentions two observations:
Observation 1: Original BERT layers fail to improve the performance.
Observation 2: Embedding biases harm the sentence embeddings performance.
Your experimental results show that these two phenomena do exist and can be mitigated. But I don't see the connection between these two observations and prompts.

How does the prompt solve the bias problem?

Looking forward to your reply, thanks!

@kongds
Owner

kongds commented Feb 24, 2022

Thank you,

For observation 1, we find that representing sentences by the [CLS] token or by averaging token embeddings is not effective. By reformulating the sentence embedding task as a masked language modeling task, we can use the original BERT layers effectively by leveraging their large-scale pre-trained knowledge.

For observation 2, we learn to represent sentences directly from the [MASK] token's hidden state, rather than as a weighted average of token embeddings according to the probability distribution (Eq. 4 in the paper).
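
A minimal sketch of this idea (not the exact implementation in this repo; it simply takes the last hidden state at the [MASK] position as the sentence embedding, using a template like the one used for training):

```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
model.eval()

def sentence_embedding(sentence: str) -> torch.Tensor:
    # Wrap the sentence in the prompt template; the tokenizer adds [CLS]/[SEP].
    text = f'This sentence of "{sentence}" means {tokenizer.mask_token}.'
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state        # (1, seq_len, dim)
    # Locate the [MASK] position and use its hidden state directly as the
    # embedding, instead of a probability-weighted average over token
    # embeddings (the Eq. 4 formulation that suffers from the bias).
    mask_pos = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero().item()
    return hidden[0, mask_pos]
```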

@chaochen99
Author

chaochen99 commented Feb 24, 2022

Thanks for your reply,

At the same time, I only found the base version of the model. I would like to know how PromptBERT performs with bert-large and roberta-large, because we want to follow your work. Could you share the checkpoints?

Looking forward to your reply.

@kongds
Owner

kongds commented Feb 24, 2022

For the large models, I only trained unsupervised bert-large-uncased, but I can't find its checkpoint anymore.
The result is as follows:

| STS12 | STS13 | STS14 | STS15 | STS16 | STSb | SICK-R | Avg. |
|-------|-------|-------|-------|-------|------|--------|------|
| 75.52 | 87.67 | 79.24 | 85.33 | 80.57 | 83.00 | 73.01  | 80.62 |

The training command is as follows:

```bash
BC=(python -m torch.distributed.launch --nproc_per_node 4 train.py)
GPU=0,1,2,3
BATCH=64
MODEL=bert-large-uncased
LR=3e-5
EXP=unsup-bert-large
EPOCH=1
ES=125                  # --eval_steps
BMETRIC=stsb_spearman   # --metric_for_best_model
TRAIN_FILE=data/wiki1m_for_simcse.txt
TEMPLATE="*cls*_This_sentence_of_\"*sent_0*\"_means*mask*.*sep+*"
args=(--mlp_only_train --mask_embedding_sentence
      --mask_embedding_sentence_delta
      --mask_embedding_sentence_template "$TEMPLATE")
CHECKPOINT=result/$EXP
CUDA_VISIBLE_DEVICES=$GPU "${BC[@]}" \
    --model_name_or_path $MODEL \
    --train_file $TRAIN_FILE \
    --output_dir $CHECKPOINT \
    --num_train_epochs $EPOCH \
    --per_device_train_batch_size $BATCH \
    --learning_rate $LR \
    --max_seq_length 32 \
    --evaluation_strategy steps \
    --metric_for_best_model $BMETRIC \
    --load_best_model_at_end \
    --eval_steps $ES \
    --overwrite_output_dir \
    --temp 0.05 \
    --do_train \
    --fp16 \
    --preprocessing_num_workers 10 \
    "${args[@]}"
```
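
For reference, here is a rough, hypothetical sketch of how a template string like the one above could expand for a given sentence. It assumes (as in SimCSE-style evaluation code) that `_` stands for a space and that `*cls*`, `*sent_0*`, `*mask*`, and `*sep+*` stand for the special tokens and the input sentence; this is an illustration, not the repo's actual parser:

```python
# Hypothetical template expansion; the placeholder-to-token mapping below
# is an assumption for illustration, not the repo's real implementation.
def expand_template(template: str, sentence: str) -> str:
    return (template
            .replace("*cls*", "[CLS]")
            .replace("*sent_0*", sentence)
            .replace("*mask*", "[MASK]")
            .replace("*sep+*", "[SEP]")
            .replace("_", " "))

print(expand_template('*cls*_This_sentence_of_"*sent_0*"_means*mask*.*sep+*',
                      "hello world"))
# [CLS] This sentence of "hello world" means[MASK].[SEP]
```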

@chaochen99
Author

chaochen99 commented Feb 25, 2022

Here is my result using the same parameters:

| STS12 | STS13 | STS14 | STS15 | STS16 | STSb | SICK-R | Avg. |
|-------|-------|-------|-------|-------|------|--------|------|
| 72.98 | 86.66 | 79.24 | 85.33 | 80.48 | 82.41 | 72.09  | 79.88 |

My environment: torch==1.7.1+cu110, RTX 3090 GPUs, seed=42.

Could you tell me which torch version and GPUs you are using?

Looking forward to your reply!

@kongds
Owner

kongds commented Feb 25, 2022

I used 4 × 16GB V100 GPUs and torch==1.6.1+cu101 with apex. By the way, the bert-large result is from the old codebase, which may be slightly different.

But I don't have V100 cards right now, so I can't verify whether the gap comes from torch, apex, or the codebase.

@kongds
Owner

kongds commented Feb 25, 2022

I found the trainer_state.json of bert-large; this might help:

```json
{
  "best_metric": 0.8650828143261685,
  "best_model_checkpoint": "result/unsup-bert-large",
  "epoch": 1.0,
  "global_step": 3907,
  "is_hyper_param_search": false,
  "is_local_process_zero": true,
  "is_world_process_zero": true,
  "log_history": [
    {
      "epoch": 0.03,
      "eval_avg_sts": 0.8029055192313586,
      "eval_sickr_spearman": 0.7471528897173139,
      "eval_stsb_spearman": 0.8586581487454031,
      "step": 125
    },
    {
      "epoch": 0.06,
      "eval_avg_sts": 0.8109103797908428,
      "eval_sickr_spearman": 0.7589070128301474,
      "eval_stsb_spearman": 0.8629137467515381,
      "step": 250
    },
    {
      "epoch": 0.1,
      "eval_avg_sts": 0.8093174992808367,
      "eval_sickr_spearman": 0.7579198776333913,
      "eval_stsb_spearman": 0.8607151209282823,
      "step": 375
    },
    {
      "epoch": 0.13,
      "learning_rate": 2.6160737138469415e-05,
      "loss": 0.004,
      "step": 500
    },
    {
      "epoch": 0.13,
      "eval_avg_sts": 0.8045124402414239,
      "eval_sickr_spearman": 0.7583830895756604,
      "eval_stsb_spearman": 0.8506417909071873,
      "step": 500
    },
    {
      "epoch": 0.16,
      "eval_avg_sts": 0.7965214393405081,
      "eval_sickr_spearman": 0.7470930429649221,
      "eval_stsb_spearman": 0.8459498357160942,
      "step": 625
    },
    {
      "epoch": 0.19,
      "eval_avg_sts": 0.8055590122682623,
      "eval_sickr_spearman": 0.7558082382897661,
      "eval_stsb_spearman": 0.8553097862467586,
      "step": 750
    },
    {
      "epoch": 0.22,
      "eval_avg_sts": 0.8155483810926036,
      "eval_sickr_spearman": 0.7704474455726826,
      "eval_stsb_spearman": 0.8606493166125245,
      "step": 875
    },
    {
      "epoch": 0.26,
      "learning_rate": 2.232147427693883e-05,
      "loss": 0.0005,
      "step": 1000
    },
    {
      "epoch": 0.26,
      "eval_avg_sts": 0.8133897854649895,
      "eval_sickr_spearman": 0.7663550996679526,
      "eval_stsb_spearman": 0.8604244712620265,
      "step": 1000
    },
    {
      "epoch": 0.29,
      "eval_avg_sts": 0.8196710191935307,
      "eval_sickr_spearman": 0.7747364788378018,
      "eval_stsb_spearman": 0.8646055595492596,
      "step": 1125
    },
    {
      "epoch": 0.32,
      "eval_avg_sts": 0.8218381673513775,
      "eval_sickr_spearman": 0.7785935203765866,
      "eval_stsb_spearman": 0.8650828143261685,
      "step": 1250
    },
    {
      "epoch": 0.35,
      "eval_avg_sts": 0.8165785131017016,
      "eval_sickr_spearman": 0.7717390018903535,
      "eval_stsb_spearman": 0.8614180243130498,
      "step": 1375
    },
    {
      "epoch": 0.38,
      "learning_rate": 1.8482211415408245e-05,
      "loss": 0.0005,
      "step": 1500
    },
    {
      "epoch": 0.38,
      "eval_avg_sts": 0.8047818318689836,
      "eval_sickr_spearman": 0.7539911256601595,
      "eval_stsb_spearman": 0.8555725380778076,
      "step": 1500
    },
    {
      "epoch": 0.42,
      "eval_avg_sts": 0.7982487502134017,
      "eval_sickr_spearman": 0.7442794290737788,
      "eval_stsb_spearman": 0.8522180713530247,
      "step": 1625
    },
    {
      "epoch": 0.45,
      "eval_avg_sts": 0.799967337700157,
      "eval_sickr_spearman": 0.7443161248352774,
      "eval_stsb_spearman": 0.8556185505650366,
      "step": 1750
    },
    {
      "epoch": 0.48,
      "eval_avg_sts": 0.801522708223845,
      "eval_sickr_spearman": 0.7457060008175922,
      "eval_stsb_spearman": 0.8573394156300977,
      "step": 1875
    },
    {
      "epoch": 0.51,
      "learning_rate": 1.4642948553877656e-05,
      "loss": 0.0004,
      "step": 2000
    },
    {
      "epoch": 0.51,
      "eval_avg_sts": 0.7970257249600163,
      "eval_sickr_spearman": 0.7415913684817947,
      "eval_stsb_spearman": 0.8524600814382378,
      "step": 2000
    },
    {
      "epoch": 0.54,
      "eval_avg_sts": 0.8056394392946331,
      "eval_sickr_spearman": 0.7538111050919695,
      "eval_stsb_spearman": 0.8574677734972966,
      "step": 2125
    },
    {
      "epoch": 0.58,
      "eval_avg_sts": 0.8038362240185967,
      "eval_sickr_spearman": 0.7490150555200695,
      "eval_stsb_spearman": 0.8586573925171238,
      "step": 2250
    },
    {
      "epoch": 0.61,
      "eval_avg_sts": 0.7985717984461609,
      "eval_sickr_spearman": 0.7436228919482213,
      "eval_stsb_spearman": 0.8535207049441006,
      "step": 2375
    },
    {
      "epoch": 0.64,
      "learning_rate": 1.080368569234707e-05,
      "loss": 0.0004,
      "step": 2500
    },
    {
      "epoch": 0.64,
      "eval_avg_sts": 0.8017446914529242,
      "eval_sickr_spearman": 0.74755755174693,
      "eval_stsb_spearman": 0.8559318311589184,
      "step": 2500
    },
    {
      "epoch": 0.67,
      "eval_avg_sts": 0.806035852944099,
      "eval_sickr_spearman": 0.7568251047383123,
      "eval_stsb_spearman": 0.8552466011498855,
      "step": 2625
    },
    {
      "epoch": 0.7,
      "eval_avg_sts": 0.7931850193859609,
      "eval_sickr_spearman": 0.7340857884155682,
      "eval_stsb_spearman": 0.8522842503563536,
      "step": 2750
    },
    {
      "epoch": 0.74,
      "eval_avg_sts": 0.7977208601703414,
      "eval_sickr_spearman": 0.7462593671372607,
      "eval_stsb_spearman": 0.8491823532034221,
      "step": 2875
    },
    {
      "epoch": 0.77,
      "learning_rate": 6.964422830816484e-06,
      "loss": 0.0005,
      "step": 3000
    },
    {
      "epoch": 0.77,
      "eval_avg_sts": 0.8073046774483807,
      "eval_sickr_spearman": 0.7641048905966784,
      "eval_stsb_spearman": 0.8505044643000832,
      "step": 3000
    },
    {
      "epoch": 0.8,
      "eval_avg_sts": 0.8022302732886464,
      "eval_sickr_spearman": 0.7558746172719535,
      "eval_stsb_spearman": 0.8485859293053393,
      "step": 3125
    },
    {
      "epoch": 0.83,
      "eval_avg_sts": 0.8059561544995988,
      "eval_sickr_spearman": 0.7599832457200711,
      "eval_stsb_spearman": 0.8519290632791264,
      "step": 3250
    },
    {
      "epoch": 0.86,
      "eval_avg_sts": 0.7939477153857237,
      "eval_sickr_spearman": 0.7442896596983852,
      "eval_stsb_spearman": 0.8436057710730622,
      "step": 3375
    },
    {
      "epoch": 0.9,
      "learning_rate": 3.125159969285897e-06,
      "loss": 0.0004,
      "step": 3500
    },
    {
      "epoch": 0.9,
      "eval_avg_sts": 0.8054149775684041,
      "eval_sickr_spearman": 0.75679162706061,
      "eval_stsb_spearman": 0.8540383280761983,
      "step": 3500
    },
    {
      "epoch": 0.93,
      "eval_avg_sts": 0.8034251823040465,
      "eval_sickr_spearman": 0.7553765827811417,
      "eval_stsb_spearman": 0.8514737818269512,
      "step": 3625
    },
    {
      "epoch": 0.96,
      "eval_avg_sts": 0.8037347568705746,
      "eval_sickr_spearman": 0.7554017510782953,
      "eval_stsb_spearman": 0.8520677626628539,
      "step": 3750
    },
    {
      "epoch": 0.99,
      "eval_avg_sts": 0.8021041735230073,
      "eval_sickr_spearman": 0.75362008540155,
      "eval_stsb_spearman": 0.8505882616444647,
      "step": 3875
    },
    {
      "epoch": 1.0,
      "step": 3907,
      "train_runtime": 6852.5932,
      "train_samples_per_second": 0.57
    }
  ],
  "max_steps": 3907,
  "num_train_epochs": 1,
  "total_flos": 169441187714039808,
  "trial_name": null,
  "trial_params": null
}
```
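
If it helps, here is a small helper (illustrative, not part of the repo) to pull the best STS-B dev score and its step out of a trainer_state.json like the one above:

```python
import json

# Load the trainer state written by the training run above.
with open("result/unsup-bert-large/trainer_state.json") as f:
    state = json.load(f)

# Keep only the evaluation entries and pick the best STS-B Spearman score.
evals = [e for e in state["log_history"] if "eval_stsb_spearman" in e]
best = max(evals, key=lambda e: e["eval_stsb_spearman"])
print(f"best stsb_spearman={best['eval_stsb_spearman']:.4f} at step {best['step']}")
# For the log above: best stsb_spearman=0.8651 at step 1250 (epoch 0.32),
# matching "best_metric" in the file.
```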

@Yubo8Zhang

Hello, after reading your paper I also have some questions about observation 1.

1. When Eq. (1) is used to measure the degree of anisotropy (which I understand as non-uniformity) of the sentence embeddings of 100,000 Wikipedia sentences, what is the topic distribution of those 100,000 sentences? If sentences from one topic account for too large a share, a good model (one that satisfies both alignment and uniformity) will tend to map them into a neighborhood of the semantic space, and the value of Eq. (1) will be inflated. In that case the conclusion of Table 1 might not be rigorous.
2. Using the vector at the [MASK] position of a prompt template as the sentence embedding seems, in principle, similar to the common practice of using the [CLS] position. Why does the prompt method help mitigate the bias?

Looking forward to your reply, thanks!

@kongds
Copy link
Owner

kongds commented Feb 26, 2023

Thank you for your interest in our paper.

  1. The 100,000 sentences we used were randomly sampled, and the value is averaged over 100,000 × 100,000 sentence pairs. Even if there were 100 sentences on the same topic, those sentence pairs would only influence the result by about one part in a thousand.
  2. Although using [CLS] also avoids the token bias problem, the main issue with [CLS] is that it cannot directly leverage the original pre-trained model to represent sentences. The key difference between prompts and [CLS] is that prompts directly use a template to help the pre-trained model represent sentences.
