# Global modification of hyperparameters

In [234]:
# 模型设定NoCollapse是带SVD的Bert版本
Model_lst = ["NoCollapse","Bert","Bert_large","roberta-base","roberta-large","ALBERT","roberta-large-mnli"]
PLM_TYPE = Model_lst[1]

# training_type_lst决定了loss的设置
training_type_lst = ['man','mix','original','low_level','medium_level','high_level','bank','multiple','collapse','new_collapse']
training_type = training_type_lst[1]

# prompt模板的选择
template_lst = ['man','soft','mix','ptuning','ptr']
template_type = template_lst[1]

# few-shot
num_examples_per_label_ = 10

# Data construction, full data, standard few-shot, construction of few-shot based on data distribution of dataset
data_con_lst = ['full_data','fewshot','DA']
data_con = data_con_lst[1]

# few-shot or zero-shot
few_shot_train = True
# few_shot_train = False

# lr 常规是model=1e-4/2e-5 template=1e-3/1e-2
model_lr = 1e-4
template_lr= 1e-3

# 数据集选择
dataset_lst = ["go_emotions","emotion"]
dataset_ = dataset_lst[-1]

# 超参数
epoch_num = 5
bank_size = 16
batch_size_ = 8

use_cuda = True
# use_cuda = False

if data_con == "full_data":
  model_save_pth = PLM_TYPE+"_"+str(training_type)+"_template_"+template_type+"_"+data_con+"_"+"_bankSize_"+str(bank_size)+"_dataset_"+dataset_+".pth"
else:
  model_save_pth = PLM_TYPE+"_"+str(training_type)+"_template_"+template_type+"_"+data_con+"_"+str(num_examples_per_label_)+"_bankSize_"+str(bank_size)+"_dataset_"+dataset_+".pth"

model_save_pth

'Bert_mix_template_soft_fewshot_10_bankSize_16_dataset_emotion.pth'

# Development Dependent Environment

In [235]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [236]:
import os

framework_root = "/content/drive/MyDrive/Colab/OpenPromptV2/"
os.chdir(framework_root)

In [237]:
!pip install -r  requirements.txt
!python setup.py install

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
requirements:
['transformers>=4.10.0', 'sentencepiece==0.1.96', '# scikit-learn>=0.24.2', 'tqdm>=4.62.2', 'tensorboardX', 'nltk', 'yacs', 'dill', 'datasets', 'rouge==1.0.0', 'pyarrow', 'scipy']
running install
running bdist_egg
running egg_info
writing openprompt.egg-info/PKG-INFO
writing dependency_links to openprompt.egg-info/dependency_links.txt
writing requirements to openprompt.egg-info/requires.txt
writing top-level names to openprompt.egg-info/top_level.txt
reading manifest file 'openprompt.egg-info/SOURCES.txt'
adding license file 'LICENSE'
writing manifest file 'openprompt.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_py
copying openprompt/prompt_base.py -> build/lib/openprompt
copying openprompt/pipeline_base.py -> build/lib/openprompt
creating build/bdist.linux-x86_64/egg
creating build/bdist.linux-x86_64/egg/ope

In [238]:
 !pip install pytorch-metric-learning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


# Dataset (go_emotions/emotion)

In [239]:
import numpy as np
import torch
import random
from openprompt.utils.reproduciblity import set_seed
import re

def set_seeds(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    set_seed(seed)

set_seeds()

In [240]:
from datasets import load_dataset

In [241]:
if dataset_ == "go_emotions":
  emotions = load_dataset("go_emotions","simplified")

  df_train = emotions['train'].to_pandas()
  df_dev = emotions['validation'].to_pandas()
  df_test = emotions['test'].to_pandas()

  # Adjusting the order to determine if the hierarchy can be retained
  # Since the dataset and the paper are in a different tag order, reset
  labels_ = ['amusement','excitement','joy','love','desire','optimism','caring','pride','admiration','gratitude','relief','approval','realization','surprise','curiosity','confusion', 'fear', 'nervousness', 'remorse','embarrassment','disappointment', 'sadness','grief','disgust','anger','annoyance','disapproval','neutral']
  labels_cols = ["admiration","amusement","anger","annoyance","approval","caring","confusion","curiosity","desire","disappointment","disapproval","disgust","embarrassment","excitement","fear", "gratitude","grief","joy","love","nervousness","optimism","pride","realization","relief","remorse","sadness","surprise","neutral"]

  change_dict = {}
  for idx, item in enumerate(labels_cols):
    change_dict[idx] = labels_.index(item)
  
  print(change_dict)
  print(labels_)
  print(len(labels_))

  df_train["labels_num"] = list(map(len, df_train["labels"].values.tolist())) 
  df_train = df_train.drop(df_train[df_train["labels_num"]!=1].index)
  print(df_train["labels"].values.tolist()[0])
  df_train["label"] = list(map(lambda a: change_dict[a.tolist()[0]], df_train["labels"].values.tolist()))
  df_train["idx"] = df_train.index

  df_dev["labels_num"] = list(map(len, df_dev["labels"].values.tolist())) 
  df_dev = df_dev.drop(df_dev[df_dev["labels_num"]!=1].index)
  df_dev["label"] = list(map(lambda a: change_dict[a.tolist()[0]], df_dev["labels"].values.tolist()))
  df_dev["idx"] = df_dev.index

  df_test["labels_num"] = list(map(len, df_test["labels"].values.tolist())) 
  df_test = df_test.drop(df_test[df_test["labels_num"]!=1].index)
  df_test["label"] = list(map(lambda a: change_dict[a.tolist()[0]], df_test["labels"].values.tolist()))
  df_test["idx"] = df_test.index

elif dataset_ == "emotion":
  emo = load_dataset(dataset_)
  
  label_cols = ["sadness","joy","love","anger","fear","surprise"]
  labels_ = label_cols

  df_train = emo['train'].to_pandas()
  df_train["idx"] = df_train.index

  df_dev = emo['validation'].to_pandas()
  df_dev["idx"] = df_dev.index

  df_test = emo['test'].to_pandas()
  df_test["idx"] = df_test.index



  0%|          | 0/3 [00:00<?, ?it/s]

In [242]:
import pandas as pd
from datasets import Dataset

raw_dataset = {"train": df_train, "validation": df_dev, "test": df_test}
# print(raw_dataset)
# print(type(raw_dataset["train"]))

# Data distribution

In [243]:
print(len(df_train))

max_num = 0
max_labels = 0


for i in range(len(label_cols)):
  cur_num = len(df_train[df_train["label"]==i])
  if max_num < cur_num:
    max_num = cur_num

labels_dict = {}
for i in range(len(label_cols)):
  labels_dict[i]=len(df_train[df_train["label"]==i])/max_num

labels_dict

16000


{0: 0.8701976874300634,
 1: 1.0,
 2: 0.24319283849309958,
 3: 0.40264826557254757,
 4: 0.3612458038045505,
 5: 0.10667661320402835}

# openprompt

In [244]:
from openprompt.data_utils import InputExample

dataset = {}
for split in ['train', 'validation', 'test']:
    dataset[split] = []
    # iter dataframe type
    # print(type(raw_dataset[split]))
    for _, data in raw_dataset[split].iterrows():
        # print(dataset_)
        input_example = InputExample(
            text_a=data["text"], label=data["label"], guid=data["idx"])
        # print(input_example)
        dataset[split].append(input_example)

In [245]:
# for i in range(5):
#     print(dataset['train'][i])

## load PLM

In [246]:
# You can load the plm related things provided by openprompt simply by calling:
from openprompt.plms import load_plm

if PLM_TYPE == "NoCollapse":
  plm, tokenizer, model_config, WrapperClass = load_plm("BertForNoCollapse", "bert-base-cased")
elif PLM_TYPE == "Bert":
  plm, tokenizer, model_config, WrapperClass = load_plm("bert", "bert-base-cased")
elif PLM_TYPE == "Bert_large":
  plm, tokenizer, model_config, WrapperClass = load_plm("bert", "bert-large-cased")
elif PLM_TYPE == "roberta-base":
  plm, tokenizer, model_config, WrapperClass = load_plm("roberta", "roberta-base")
elif PLM_TYPE == "roberta-large":
  plm, tokenizer, model_config, WrapperClass = load_plm("roberta", "roberta-large")
elif PLM_TYPE == "roberta-large-mnli":
  plm, tokenizer, model_config, WrapperClass = load_plm("roberta", "roberta-large-mnli")
elif PLM_TYPE == "ALBERT":
  plm, tokenizer, model_config, WrapperClass = load_plm("albert", "albert-base-v2")
else:
  print("Other models are not yet supported")

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


##  ManualTemplate

In [247]:
# Constructing Template
# A template can be constructed from the yaml config, but it can also be constructed by directly passing arguments.
from openprompt.prompts import ManualTemplate, SoftTemplate, MixedTemplate, PtuningTemplate, PTRTemplate

In [248]:
if template_type == "man":
  # GoEmotions
  # man
  # template_text = '{"placeholder":"text_a"} Is was {"mask"}.'
  # template_text = '{"placeholder":"text_a"} I am {"mask"}.'
  template_text = '{"placeholder":"text_a"} The emotional aspect of this text is {"mask"}.'
  mytemplate = ManualTemplate(tokenizer=tokenizer, text=template_text)

## SoftTemplate

In [249]:
if template_type == "soft":
  # Soft
  # id 可以将相同文本保持一样的soft 
  # when self.text is '{"soft": None} {"soft": "the", "soft_id": 1} {"soft": None} {"soft": "it", "soft_id": 3} {"soft_id": 1} {"soft": "was"} {"mask"}', output is [1, 2, 3, 4, 2, 5, 0]
  # template_text = '{"placeholder":"text_a"} {"soft": "It was"} {"mask"}.'
  # template_text = '{"placeholder":"text_a"} {"soft": None, "duplicate": 10, "same": True} {"mask"}.'

  template_text = '{"placeholder":"text_a"} {"soft": "The emotional aspect of this text is"} {"mask"}.'

  # template_text = '{"placeholder":"text_a"} {"soft"} {"soft"} {"soft"} {"mask"}.'
  # template_text = '{"placeholder": "text_a"} {"placeholder": "text_b"} {"mask"} .'

  mytemplate = SoftTemplate(model=plm, tokenizer=tokenizer, text=template_text)

## MixedTemplate

In [250]:
if template_type == "mix":
  # 目前次好的temple
  # template_text = '{"placeholder":"text_a"} {"soft": "It was"} {"mask"}.'
  # 目前最好的temple
  template_text = '{"placeholder":"text_a"} {"soft": "The"} emotional {"soft": "aspect of this text is"} {"mask"}.'

  # template_text = '{"soft"} Emotion: {"soft"} {"placeholder":"text_a"} {"soft": "The"} emotional {"soft": "aspect of this text is"} {"mask"}.'
  # template_text = '{"placeholder":"text_a"} {"soft": "The"} emotional {"soft": "aspect of this text is"} {"mask"} {"soft"}.'
  # template_text = '{"placeholder":"text_a"} {"soft"} {"soft"} {"soft"} {"soft"} {"soft"} {"soft"} {"soft"} {"soft"} {"soft"} {"mask"} .'

  mytemplate = MixedTemplate(model=plm, tokenizer=tokenizer, text=template_text)

## PtuningTemplate

In [251]:
if template_type == "ptuning":
  # ptuning
  # 经过预训练导出的embedding高度离散化 不能够学习到soft之间的关系 ptuning使用的是ptuning 用mlp或者lstm模型初始化embedding


  # # 实验主要适用模板
  # template_text = '{"soft"} Emotion {"soft"} {"mask"} {"soft"} {"placeholder": "text_a"}'

  # # 目前次好的temple 06/28 macro_f1 = 0.056622359610616534
  # template_text = '{"placeholder":"text_a"} {"soft": "It was"} {"mask"}.'

  # 目前最好的temple 06/28 macro_f1 = 0.06245602902611468
  template_text = '{"placeholder":"text_a"} {"soft": "The"} emotional {"soft": "aspect of this text is"} {"mask"}.'

  # template_text = '{"placeholder": "text_a"} {"soft"} emotion {"soft"} {"soft"} {"mask"} {"soft"}.'
  # template_text = '{"soft"} Emotion: {"soft"} {"mask"} {"soft"} {"placeholder": "text_a"}'
  # template_text = '{"placeholder": "text_a"} {"soft"} emotional {"soft"} {"soft"} {"soft"} {"soft"} {"soft"} {"mask"} {"soft"}.'
  # template_text = '{"placeholder": "text_a"} {"soft"} {"soft"} {"soft"} {"soft"} {"soft"} {"soft"} {"soft"} {"soft"} {"mask"}'

  mytemplate = PtuningTemplate(model=plm, tokenizer=tokenizer, text=template_text)

## PTR Template

In [252]:
if template_type == "ptr":
  template_text = '{"placeholder":"text_a"} {"soft": "The"} emotional {"soft": "aspect of this text is"} {"mask"}.'
  mytemplate = PTRTemplate(model=plm, tokenizer=tokenizer, text=template_text)

In [253]:
# To better understand how does the template wrap the example, we visualize one instance.
print(dataset['train'][0])
wrapped_example = mytemplate.wrap_one_example(dataset['train'][0]) 
print(wrapped_example)

{
  "guid": 0,
  "label": 0,
  "meta": {},
  "text_a": "i didnt feel humiliated",
  "text_b": "",
  "tgt_text": null
}

[[{'text': 'i didnt feel humiliated', 'loss_ids': 0, 'shortenable_ids': 1}, {'text': '', 'loss_ids': 0, 'shortenable_ids': 0}, {'text': '<mask>', 'loss_ids': 1, 'shortenable_ids': 0}, {'text': '.', 'loss_ids': 0, 'shortenable_ids': 0}], {'guid': 0, 'label': 0}]


# Few-shot & dataloader


In [254]:
from openprompt.data_utils.data_sampler import FewShotSampler


support_sampler = FewShotSampler(num_examples_per_label=num_examples_per_label_,also_sample_dev=False)
dataset['support'] = support_sampler(dataset['train'], seed=42)
print(len(dataset['support']))

60


In [255]:
if data_con == "DA":
  datasetDA_dict = {}
  labels_dict_ = {}
  for k, v in labels_dict.items():
    labels_dict_[k] = int(max(v*num_examples_per_label_,1))
  print(labels_dict_)

  for item in dataset['support']:
    cur_label = int(item.label)
    if cur_label not in datasetDA_dict.keys():
      datasetDA_dict[cur_label] = []
    
    if len(datasetDA_dict[cur_label]) == labels_dict[cur_label]:
      continue

    datasetDA_dict[cur_label].append(item)

  dataset["DA"] = []
  for i in datasetDA_dict.values():
    for j in i:
      dataset["DA"].append(j)
  print(len(dataset["DA"]))

In [256]:
from openprompt import PromptDataLoader

if data_con == 'fewshot':
  train_dataloader = PromptDataLoader(
          dataset = dataset['support'],
          template = mytemplate, 
          tokenizer = tokenizer, 
          tokenizer_wrapper_class=WrapperClass,
          max_seq_length=128,
          decoder_max_length=3,
          batch_size=batch_size_ ,
          shuffle=True,
          teacher_forcing=False, 
          predict_eos_token=False,
          truncate_method="tail"
  )
elif data_con == 'full_data':
  train_dataloader = PromptDataLoader(
        dataset = dataset['train'],
        template = mytemplate, 
        tokenizer = tokenizer, 
        tokenizer_wrapper_class=WrapperClass,
        max_seq_length=128,
        decoder_max_length=3,
        batch_size=batch_size_ ,
        shuffle=True,
        teacher_forcing=False, 
        predict_eos_token=False,
        truncate_method="tail"
)
elif data_con == 'DA':
  train_dataloader = PromptDataLoader(
        dataset = dataset["DA"],
        template = mytemplate, 
        tokenizer = tokenizer, 
        tokenizer_wrapper_class=WrapperClass,
        max_seq_length=128,
        decoder_max_length=3,
        batch_size=batch_size_ ,
        shuffle=True,
        teacher_forcing=False, 
        predict_eos_token=False,
        truncate_method="tail"
)
print(next(iter(train_dataloader)))

# Evaluate
validation_dataloader = PromptDataLoader(
        dataset=dataset["validation"], 
        template=mytemplate, 
        tokenizer=tokenizer, 
        tokenizer_wrapper_class=WrapperClass, 
        max_seq_length=128,
        decoder_max_length=3,
        batch_size=batch_size_ ,
        shuffle=False, 
        teacher_forcing=False, 
        predict_eos_token=False,
        truncate_method="tail"
)

tokenizing: 60it [00:00, 1206.90it/s]


{"input_ids": [[101, 178, 1821, 2296, 1897, 17278, 103, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 178, 1631, 2504, 1374, 1552, 103, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 178, 1631, 1177, 26121, 1219, 1343, 1551, 178, 1108, 7851, 4006, 1250, 103, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

tokenizing: 2000it [00:01, 1167.12it/s]


## KnowledgeableVerbalizer & BackTranslation

In [257]:
# !pip install BackTranslation

In [258]:
# # KnowledgeableVerbalizer
# from BackTranslation import BackTranslation
# import re 
# trans = BackTranslation(url=[
#       'translate.google.com',
#       'translate.google.co.kr',
#     ], proxies={'http': '127.0.0.1:1234', 'http://host.name': '127.0.0.1:4012'})

# languages = ['af','sq','am','hy','az','eu','be','bn','bs',
#              'bg','ca','ceb','ny','zh-cn','zh-tw','co',
#              'hr','cs','da','nl','eo','et','tl','fi',
#              'fr','fy','gl','ka','de','el','gu','ht','ha',
#              'haw','he','hi','hmn','hu','is','ig','id',
#              'ga','it','ja','jw','kn','kk','ko','ku','ky','lo',
#              'la','lv','lb','mk','mg','ms','ml','mt','mi','mr',
#              'mn','my','ne','no','or','ps','fa','pl','pt','pa',
#              'ro','ru','sm','gd','sr','st','sn','sd','si','sk','sl',
#              'so','es','su','sw','sv','tg','ta','te','th','tr',
#              'ur','uk','ug','uz','vi','cy','xh','yi','yo','zu',
#              'ar', 'es']

# # labels = "realization", "surprise", "curiosity", "confusion"
# label_words = {}
# for i in label_cols:
#     print(i + "\n")
#     new_words = set()
#     for j in languages:
#         try:
#             print(j)
#             result = trans.translate(i, src='en', tmp = j)
#             new_word = result.result_text.lower()
#             new_word = re.sub(r'[\W\s]','',new_word)
#             print(new_word)
#             new_words.add(new_word)
#         except NameError:
#             print(j+"can not as translate target")    
#         except TypeError:
#             print(j+"can not as translate target")   
#     print(new_words)
#     label_words[i] = list(new_words)
#     print("\n")
# # print(label_words)

In [259]:
# from nltk.corpus import wordnet

# word_lst = ['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment',
#             'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']

# synonyms_double_lst = {}
# synonyms_lst = []

# for i in word_lst:
#     for syn in wordnet.synsets(i):
#         for lm in syn.lemmas():
#             synonyms_lst.append(lm.name())
#     synonyms_double_lst[i]=list(set(synonyms_lst))
#     synonyms_lst = []

# print(synonyms_double_lst)

In [260]:
# https://github.com/thunlp/OpenPrompt/blob/main/scripts/TextClassification/agnews/knowledgeable_verbalizer.txt
# from openprompt.prompts import KnowledgeableVerbalizer
# myverbalizer = KnowledgeableVerbalizer(tokenizer, num_classes=28).from_file("scripts/GoEmotions/backTran_verbalizer.txt")

## ManualVerbalizer

In [261]:
# Define the verbalizer
# In classification, you need to define your verbalizer, which is a mapping from logits on the vocabulary to the final label probability. Let's have a look at the verbalizer details:

from openprompt.prompts import ManualVerbalizer

# for example the verbalizer contains multiple label words in each class
label_words = {i:[i] for i in labels_}
print(label_words)
myverbalizer = ManualVerbalizer(tokenizer, 
                                classes = labels_,
                                label_words=label_words)

# print(myverbalizer.label_words_ids)
# logits = torch.randn(len(labels),len(tokenizer)) # creating a pseudo output from the plm, and 
# print(myverbalizer.process_logits(logits)) # see what the verbalizer do

{'sadness': ['sadness'], 'joy': ['joy'], 'love': ['love'], 'anger': ['anger'], 'fear': ['fear'], 'surprise': ['surprise']}


## SoftVerbalizer

In [262]:
# # Soft openprompt
# from openprompt.prompts import SoftVerbalizer
# myverbalizer = SoftVerbalizer(tokenizer, plm, num_classes=len(labels))

# Model & freeze_plm

In [263]:
# Although you can manually combine the plm, template, verbalizer together, we provide a pipeline 
# model which take the batched data from the PromptDataLoader and produce a class-wise logits

from openprompt import PromptForClassification

if training_type in ['collapse','new_collapse']:
  prompt_model = PromptForClassification(plm=plm,template=mytemplate, verbalizer=myverbalizer)

  # MLP 
  for para in prompt_model.plm.fc1.parameters():
    para.requires_grad = True

  for para in prompt_model.plm.fc2.parameters():
    para.requires_grad = True

  for para in prompt_model.plm.fc3.parameters():
    para.requires_grad = True

else:
  prompt_model = PromptForClassification(plm=plm,template=mytemplate, verbalizer=myverbalizer, freeze_plm=True)

In [264]:
if use_cuda:
    prompt_model = prompt_model.cuda()

In [265]:
# print(prompt_model)

### optimizer

In [266]:
# Now the training is standard
from transformers import get_linear_schedule_with_warmup
from transformers.optimization import AdamW 
loss_func = torch.nn.CrossEntropyLoss()

no_decay = ['bias', 'LayerNorm.weight']

if training_type == 'man':
  # man
  # it's always good practice to set no decay to biase and LayerNorm parameters
  optimizer_grouped_parameters = [
      {'params': [p for n, p in prompt_model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
      {'params': [p for n, p in prompt_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
  ]

  optimizer = AdamW(optimizer_grouped_parameters, lr=1e-4)
  tot_step  = len(train_dataloader)*5
  scheduler = get_linear_schedule_with_warmup(optimizer, 0, tot_step)
else:
  # it's always good practice to set no decay to biase and LayerNorm parameters
  optimizer_grouped_parameters1 = [
      {'params': [p for n, p in prompt_model.plm.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
      {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
  ]

  # Using different optimizer for prompt parameters and model parameters
  optimizer_grouped_parameters2 = [
      {'params': [p for n,p in prompt_model.template.named_parameters() if "raw_embedding" not in n]}
  ]

  optimizer1 = AdamW(optimizer_grouped_parameters1, lr=model_lr)
  optimizer2 = AdamW(optimizer_grouped_parameters2, lr=template_lr)

  tot_step  = len(train_dataloader)*5
  scheduler1 = get_linear_schedule_with_warmup(optimizer1, 0, tot_step)
  scheduler2 = get_linear_schedule_with_warmup(optimizer2, 0, tot_step)



### Evaluation functions

In [267]:
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
def compute_metrics(labels, preds):
    assert len(preds) == len(labels)
    results = dict()

    results["accuracy"] = accuracy_score(labels, preds)
    results["macro_precision"], results["macro_recall"], results[
        "macro_f1"], _ = precision_recall_fscore_support(
        labels, preds, average="macro")
    results["micro_precision"], results["micro_recall"], results[
        "micro_f1"], _ = precision_recall_fscore_support(
        labels, preds, average="micro")
    results["weighted_precision"], results["weighted_recall"], results[
        "weighted_f1"], _ = precision_recall_fscore_support(
        labels, preds, average="weighted")

    return results

# labels = [0,1]
# preds = [0,0]
# print(compute_metrics(labels, preds))

# Model training

## Manual Prompt

In [268]:
import time
import datetime
from tqdm import tqdm

min_f1 = 0.0

In [269]:
# Manual Prompt
if training_type == 'man' and few_shot_train:
  start_time = time.time()
  for epoch in tqdm(range(epoch_num)):
      tot_loss = 0 
      for step, inputs in enumerate(train_dataloader):
          prompt_model.train()
          if use_cuda:
              inputs = inputs.cuda()
          logits = prompt_model(inputs)
          labels = inputs['label']
          loss = loss_func(logits, labels)
          loss.backward()
          tot_loss += loss.item()
          optimizer.step()
          scheduler.step()
          optimizer.zero_grad()
          if step %100 == 1:
              print("Epoch {}, average loss: {}".format(epoch, tot_loss/(step+1)), flush=True)
              
          if step % 500 == 1:
              # evaluation
              allpreds = []
              alllabels = []
              for step, inputs in enumerate(validation_dataloader):
                  prompt_model.eval()
                  if use_cuda:
                      inputs = inputs.cuda()
                  logits = prompt_model(inputs)
                  labels = inputs['label']
                  alllabels.extend(labels.cpu().tolist())
                  allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

              # acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)
              # print("Dev acc: "+str(acc))
              dev_res = compute_metrics(alllabels, allpreds)
              acc_score = dev_res["accuracy"]
              f1_score = dev_res["weighted_f1"]

              # acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)
              print("Dev acc: "+str(acc_score))
              print("Dev f1:"+str(f1_score))

              if min_f1 < f1_score:
                min_f1 = float(f1_score)
                print("save model in epoch:"+str(epoch))
                torch.save(prompt_model.state_dict(),model_save_pth)

  end_time = time.time()  
  complete_time = end_time - start_time
  print("running time: "+str(datetime.timedelta(seconds=complete_time)))

## Mix/Soft prompt 

In [None]:
# Mix prompt Soft
if training_type == 'mix' and few_shot_train:
  start_time = time.time()
  for epoch in tqdm(range(epoch_num)):
      tot_loss = 0 
      for step, inputs in enumerate(train_dataloader):
          prompt_model.train()
          if use_cuda:
              inputs = inputs.cuda()
          
          logits, _ = prompt_model(inputs)
          
          labels = inputs['label']
          loss = loss_func(logits, labels)
          loss.backward()
          tot_loss += loss.item()
          optimizer1.step()
          optimizer1.zero_grad()
          optimizer2.step()
          optimizer2.zero_grad()
          if step %100 == 1:
              print("Epoch {}, average loss: {}".format(epoch, tot_loss/(step+1)), flush=True)
              
          if step % 500 == 1:
              # evaluation
              prompt_model.eval()
              
              allpreds = []
              alllabels = []
              for step, inputs in enumerate(validation_dataloader):
                  if use_cuda:
                      inputs = inputs.cuda()
                  logits, _ = prompt_model(inputs)
                  labels = inputs['label']
                  alllabels.extend(labels.cpu().tolist())
                  allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

              dev_res = compute_metrics(alllabels, allpreds)
              acc_score = dev_res["accuracy"]
              f1_score = dev_res["weighted_f1"]

              # acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)
              print("Dev acc: "+str(acc_score))
              print("Dev f1:"+str(f1_score))

              if min_f1 < f1_score:
                min_f1 = float(f1_score)
                print("save model in epoch:"+str(epoch))
                torch.save(prompt_model.state_dict(),model_save_pth)

  end_time = time.time()  
  complete_time = end_time - start_time
  print("running time: "+str(datetime.timedelta(seconds=complete_time)))

  0%|          | 0/5 [00:00<?, ?it/s]

Epoch 0, average loss: 2.335371255874634
Dev acc: 0.187
Dev f1:0.18101123680040457
save model in epoch:0


 20%|██        | 1/5 [00:14<00:57, 14.33s/it]

Epoch 1, average loss: 1.9589383602142334
Dev acc: 0.277
Dev f1:0.2543870059961078
save model in epoch:1


 40%|████      | 2/5 [00:28<00:42, 14.28s/it]

Epoch 2, average loss: 1.7602676749229431


# metric_learning

### metric learning level define

In [None]:
from pytorch_metric_learning import distances, losses, miners, reducers, testers
# from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

### pytorch-metric-learning stuff ###
distance = distances.CosineSimilarity()
reducer = reducers.ThresholdReducer(low=0)
loss_func_metric = losses.TripletMarginLoss(margin=0.2, distance=distance, reducer=reducer)
mining_func = miners.TripletMarginMiner(
    margin=0.2, distance=distance, type_of_triplets="semihard"
)

# 0-positive, 1-negative, 2-ambiguous, 3-neutral
# if training_type in ['high_level','collapse','new_collapse'] and dataset_ == 'go_emotions':
#   Hierarchical_class_high = {
#   0: ["amusement", "excitement", "joy", "love", "desire", "optimism", "caring", "pride", "admiration", "gratitude", "relief", "approval"],
#   1: ["fear", "nervousness", "remorse", "embarrassment", "disappointment", "sadness", "grief", "disgust", "anger", "annoyance", "disapproval"],
#   2: ["realization", "surprise", "curiosity", "confusion"]
#   }

# # 0-anger, 1-disgust, 2-fear, 3-joy, 4-sadness, 5-surprise TODO 06/22
# elif training_type == 'medium_level' and dataset_ == 'go_emotions':
#   Hierarchical_class_mid = {
#   0: ["anger", "annoyance", "disapproval"],
#   1: ["disgust"],
#   2: ["fear", "nervousness"],
#   3: ["joy", "amusement", "approval", "excitement", "gratitude",  "love", "optimism", "relief", "pride", "admiration", "desire", "caring"],
#   4: ["sadness", "disappointment", "embarrassment", "grief",  "remorse"],
#   5: ["surprise", "realization", "confusion", "curiosity"]
#   }

if dataset_ == 'go_emotions':
  Hierarchical_class_high = {
  0: ["amusement", "excitement", "joy", "love", "desire", "optimism", "caring", "pride", "admiration", "gratitude", "relief", "approval"],
  1: ["fear", "nervousness", "remorse", "embarrassment", "disappointment", "sadness", "grief", "disgust", "anger", "annoyance", "disapproval"],
  2: ["realization", "surprise", "curiosity", "confusion"]
  }

  Hierarchical_class_mid = {
  0: ["anger", "annoyance", "disapproval"],
  1: ["disgust"],
  2: ["fear", "nervousness"],
  3: ["joy", "amusement", "approval", "excitement", "gratitude",  "love", "optimism", "relief", "pride", "admiration", "desire", "caring"],
  4: ["sadness", "disappointment", "embarrassment", "grief",  "remorse"],
  5: ["surprise", "realization", "confusion", "curiosity"]
  }

elif dataset_ == 'emotion':
  Hierarchical_class = {
  0: ["joy", "love"],
  1: ["fear",  "sadness", "anger"],
  2: ["surprise"],
  }            

def label_change(x):
  for key, item in Hierarchical_class.items():
    if labels_[x] in item:
      return int(key) 
  return int(len(Hierarchical_class))

def label_change_mid(x):
  for key, item in Hierarchical_class_mid.items():
    if labels_[x] in item:
      return int(key) 
  return int(len(Hierarchical_class))

def label_change_high(x):
  for key, item in Hierarchical_class_high.items():
    if labels_[x] in item:
      return int(key) 
  return int(len(Hierarchical_class))

### low level Metric Learning（28）

In [None]:
# low level Metric Learning(28分类)
if training_type == 'original' and few_shot_train:
  start_time = time.time()
  for epoch in tqdm(range(epoch_num)):
      tot_loss = 0 
      for step, inputs in enumerate(train_dataloader):
          prompt_model.train()
          if use_cuda:
              inputs = inputs.cuda()
          logits, _ = prompt_model(inputs)
          labels = inputs['label']

          indices_tuple1 = mining_func(logits, labels)
          loss1 = loss_func_metric(logits, labels, indices_tuple1)

          # Hierarchical_labels = inputs['label'].cpu().tolist()
          # Hierarchical_labels = list(map(label_change, Hierarchical_labels))
          # Hierarchical_labels = torch.Tensor(Hierarchical_labels).to(labels.device)
          # indices_tuple2 = mining_func(logits, Hierarchical_labels)
          # loss2 = loss_func_metric(logits, Hierarchical_labels, indices_tuple2)

          # loss = loss1+loss2
          loss = loss1

          loss.backward()
          tot_loss += loss.item()
          optimizer1.step()
          scheduler1.step()
          optimizer1.zero_grad()
          optimizer2.step()
          scheduler2.step()
          optimizer2.zero_grad()
          if step %100 == 1:
              print("Epoch {}, average loss: {}".format(epoch, tot_loss/(step+1)), flush=True)
              
          if step % 500 == 1:
              # evaluation
              prompt_model.eval()
              
              allpreds = []
              alllabels = []
              for step, inputs in enumerate(validation_dataloader):
                  if use_cuda:
                      inputs = inputs.cuda()
                  logits, _ = prompt_model(inputs)
                  labels = inputs['label']
                  alllabels.extend(labels.cpu().tolist())
                  allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

              # acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)
              # print("Dev acc: "+str(acc))
              dev_res = compute_metrics(alllabels, allpreds)
              acc_score = dev_res["accuracy"]
              f1_score = dev_res["weighted_f1"]

              print("Dev acc: "+str(acc_score))
              print("Dev f1:"+str(f1_score))

              if min_f1 < f1_score:
                min_f1 = float(f1_score)
                print("save model in epoch:"+str(epoch))
                torch.save(prompt_model.state_dict(),model_save_pth)

  end_time = time.time()  
  complete_time = end_time - start_time
  print("running time: "+str(datetime.timedelta(seconds=complete_time)))

### Mix/Soft prompt Soft + low level Metric Learning（28）Proxy-NCA

In [None]:
 # low level Metric Learning(28分类)
if training_type == 'low_level'and few_shot_train:
  start_time = time.time()
  for epoch in tqdm(range(epoch_num)):
      tot_loss = 0 
      for step, inputs in enumerate(train_dataloader):
          prompt_model.train()
          if use_cuda:
              inputs = inputs.cuda()
          logits, _ = prompt_model(inputs)
          labels = inputs['label']

          # 交叉损失
          loss1 = loss_func(logits, labels)

          # 对比学习损失
          indices_tuple1 = mining_func(logits, labels)
          loss2 = loss_func_metric(logits, labels, indices_tuple1)

          # Hierarchical_labels = inputs['label'].cpu().tolist()
          # Hierarchical_labels = list(map(label_change, Hierarchical_labels))
          # Hierarchical_labels = torch.Tensor(Hierarchical_labels).to(labels.device)
          # indices_tuple2 = mining_func(logits, Hierarchical_labels)
          # loss2 = loss_func_metric(logits, Hierarchical_labels, indices_tuple2)

          loss = loss1+loss2
          # loss = loss1

          loss.backward()
          tot_loss += loss.item()
          optimizer1.step()
          scheduler1.step()
          optimizer1.zero_grad()
          optimizer2.step()
          scheduler2.step()
          optimizer2.zero_grad()
          if step %100 == 1:
              print("Epoch {}, average loss: {}".format(epoch, tot_loss/(step+1)), flush=True)
              
          if step % 500 == 1:
              # evaluation
              prompt_model.eval()
              
              allpreds = []
              alllabels = []
              for step, inputs in enumerate(validation_dataloader):
                  if use_cuda:
                      inputs = inputs.cuda()
                  logits, _ = prompt_model(inputs)
                  labels = inputs['label']
                  alllabels.extend(labels.cpu().tolist())
                  allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

              # acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)
              # print("Dev acc: "+str(acc))
              dev_res = compute_metrics(alllabels, allpreds)
              acc_score = dev_res["accuracy"]
              f1_score = dev_res["weighted_f1"]

              print("Dev acc: "+str(acc_score))
              print("Dev f1:"+str(f1_score))

              if min_f1 < f1_score:
                min_f1 = float(f1_score)
                print("save model in epoch:"+str(epoch))
                torch.save(prompt_model.state_dict(),model_save_pth)

  end_time = time.time()  
  complete_time = end_time - start_time
  print("running time: "+str(datetime.timedelta(seconds=complete_time)))

### Mix/Soft prompt Soft + medium level Metric Learning（7）Proxy Anchor

In [None]:
# Mix prompt Soft + medium level Metric Learning
if training_type == 'medium_level'and few_shot_train:
  start_time = time.time()
  for epoch in tqdm(range(epoch_num)):
      tot_loss = 0 
      for step, inputs in enumerate(train_dataloader):
          prompt_model.train()
          if use_cuda:
              inputs = inputs.cuda()
          logits, _ = prompt_model(inputs)
          labels = inputs['label']
          Hierarchical_labels = inputs['label'].cpu().tolist()
          Hierarchical_labels = list(map(label_change_mid, Hierarchical_labels))
          Hierarchical_labels = torch.Tensor(Hierarchical_labels).to(labels.device)
          # 交叉损失
          loss1 = loss_func(logits, labels)
          # 对比学习
          indices_tuple = mining_func(logits, Hierarchical_labels)
          loss2 = loss_func_metric(logits, Hierarchical_labels, indices_tuple)
          loss = loss1 + loss2 
          loss.backward()
          tot_loss += loss.item()
          optimizer1.step()
          scheduler1.step()
          optimizer1.zero_grad()
          optimizer2.step()
          scheduler2.step()
          optimizer2.zero_grad()
          if step %100 == 1:
              print("Epoch {}, average loss: {}".format(epoch, tot_loss/(step+1)), flush=True)
              
          if step % 500 == 1:
              # evaluation
              prompt_model.eval()
              
              allpreds = []
              alllabels = []
              for step, inputs in enumerate(validation_dataloader):
                  if use_cuda:
                      inputs = inputs.cuda()
                  logits, _ = prompt_model(inputs)
                  labels = inputs['label']
                  alllabels.extend(labels.cpu().tolist())
                  allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

              # acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)
              # print("Dev acc: "+str(acc))
              dev_res = compute_metrics(alllabels, allpreds)
              acc_score = dev_res["accuracy"]
              f1_score = dev_res["weighted_f1"]

              print("Dev acc: "+str(acc_score))
              print("Dev f1:"+str(f1_score))

              if min_f1 < f1_score:
                min_f1 = float(f1_score)
                print("save model in epoch:"+str(epoch))
                torch.save(prompt_model.state_dict(),model_save_pth)
                
  end_time = time.time()  
  complete_time = end_time - start_time
  print("running time: "+str(datetime.timedelta(seconds=complete_time)))

### Mix/Soft prompt Soft + high level Metric Learning(4) TML

In [None]:
# Mix prompt Soft + Metric Learning
if training_type == 'high_level'and few_shot_train:
  start_time = time.time()
  for epoch in tqdm(range(epoch_num)):
      tot_loss = 0 
      for step, inputs in enumerate(train_dataloader):
          prompt_model.train()
          if use_cuda:
              inputs = inputs.cuda()
          logits, _ = prompt_model(inputs)
          labels = inputs['label']
          # 交叉损失
          loss1 = loss_func(logits, labels)
          # 对比学习 TML
          Hierarchical_labels = inputs['label'].cpu().tolist()
          Hierarchical_labels_mid = list(map(label_change_mid, Hierarchical_labels))
          Hierarchical_labels_mid = torch.Tensor(Hierarchical_labels_mid).to(labels.device)
          indices_tuple = mining_func(logits, Hierarchical_labels_mid)
          loss2 = loss_func_metric(logits, Hierarchical_labels_mid, indices_tuple)

          Hierarchical_labels_high = list(map(label_change_high, Hierarchical_labels))
          Hierarchical_labels_high = torch.Tensor(Hierarchical_labels_high).to(labels.device)
          indices_tuple = mining_func(logits, Hierarchical_labels_high)
          loss3 = loss_func_metric(logits, Hierarchical_labels_high, indices_tuple)

          loss = loss1 + 1/7*loss2 + 1/4*loss3
          loss.backward()
          tot_loss += loss.item()
          optimizer1.step()
          scheduler1.step()
          optimizer1.zero_grad()
          optimizer2.step()
          scheduler2.step()
          optimizer2.zero_grad()
          if step %100 == 1:
              print("Epoch {}, average loss: {}".format(epoch, tot_loss/(step+1)), flush=True)
              
          if step % 500 == 1:
              # evaluation
              prompt_model.eval()
              
              allpreds = []
              alllabels = []
              for step, inputs in enumerate(validation_dataloader):
                  if use_cuda:
                      inputs = inputs.cuda()
                  logits, _ = prompt_model(inputs)
                  labels = inputs['label']
                  alllabels.extend(labels.cpu().tolist())
                  allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

              # acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)
              # print("Dev acc: "+str(acc))
              dev_res = compute_metrics(alllabels, allpreds)
              acc_score = dev_res["accuracy"]
              f1_score = dev_res["weighted_f1"]

              print("Dev acc: "+str(acc_score))
              print("Dev f1:"+str(f1_score))

              if min_f1 < f1_score:
                min_f1 = float(f1_score)
                print("save model in epoch:"+str(epoch))
                torch.save(prompt_model.state_dict(),model_save_pth)

  end_time = time.time()  
  complete_time = end_time - start_time
  print("running time: "+str(datetime.timedelta(seconds=complete_time)))

### Mix prompt Soft + Metric Learning bank_type

In [None]:
# Mix prompt Soft + Metric Learning bank size
if training_type == 'bank'and few_shot_train:
  step_index = int(bank_size/batch_size_)
  logits_stack = []
  labels_stack = []
  start_time = time.time()
  for epoch in tqdm(range(epoch_num)):
      tot_loss = 0 
      for step, inputs in enumerate(train_dataloader):
          prompt_model.train()
          if use_cuda:
              inputs = inputs.cuda()
          logits = prompt_model(inputs)
          labels = inputs['label']
          loss1 = loss_func(logits, labels)
          
          for logit, label in zip(logits.cpu().tolist(), labels.cpu().tolist()):
            logits_stack.append(logit)
            labels_stack.append(label)

          if step % step_index == 1:
            # print("bank_size"+str(len(labels_stack)))
            # print(len(labels_stack))
            # print(labels_stack)
            Hierarchical_labels = list(map(label_change,labels_stack))

            Hierarchical_labels = torch.Tensor(Hierarchical_labels).to(labels.device)
            # print(Hierarchical_labels)
            logits_stack = torch.Tensor(logits_stack).to(labels.device)

            # contradict loss
            indices_tuple = mining_func(logits_stack, Hierarchical_labels)
            loss2 = loss_func_metric(logits_stack, Hierarchical_labels, indices_tuple)
            # print(loss2)

            loss = loss1 + loss2

            logits_stack = []
            labels_stack = [] 
          else:
            loss = loss1
  
          loss.backward()
          tot_loss += loss.item()
          optimizer1.step()
          scheduler1.step()
          optimizer1.zero_grad()
          optimizer2.step()
          scheduler2.step()
          optimizer2.zero_grad()
          if step %100 == 1:
              print("Epoch {}, average loss: {}".format(epoch, tot_loss/(step+1)), flush=True)
              
          if step % 500 == 1:
              # evaluation
              prompt_model.eval()
              
              allpreds = []
              alllabels = []
              for step, inputs in enumerate(validation_dataloader):
                  if use_cuda:
                      inputs = inputs.cuda()
                  logits = prompt_model(inputs)
                  labels = inputs['label']
                  alllabels.extend(labels.cpu().tolist())
                  allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

              # acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)
              # print("Dev acc: "+str(acc))
              dev_res = compute_metrics(alllabels, allpreds)
              acc_score = dev_res["accuracy"]
              f1_score = dev_res["weighted_f1"]

              print("Dev acc: "+str(acc_score))
              print("Dev f1:"+str(f1_score))

              if min_f1 < f1_score:
                min_f1 = float(f1_score)
                print("save model in epoch:"+str(epoch))
                torch.save(prompt_model.state_dict(),model_save_pth)

  end_time = time.time()  
  complete_time = end_time - start_time
  print("running time: "+str(datetime.timedelta(seconds=complete_time)))

### Multiple Metric Learning

In [None]:
 # Multiple Metric Learning
if training_type == 'multiple'and few_shot_train:
  start_time = time.time()
  for epoch in tqdm(range(epoch_num)):
      tot_loss = 0 
      for step, inputs in enumerate(train_dataloader):
          prompt_model.train()
          if use_cuda:
              inputs = inputs.cuda()
          logits = prompt_model(inputs)
          labels = inputs['label']

          indices_tuple1 = mining_func(logits, labels)
          loss1 = loss_func_metric(logits, labels, indices_tuple1)

          Hierarchical_labels = inputs['label'].cpu().tolist()
          Hierarchical_labels = list(map(label_change, Hierarchical_labels))
          Hierarchical_labels = torch.Tensor(Hierarchical_labels).to(labels.device)
          indices_tuple2 = mining_func(logits, Hierarchical_labels)
          loss2 = loss_func_metric(logits, Hierarchical_labels, indices_tuple2)

          loss = loss1+loss2

          loss.backward()
          tot_loss += loss.item()
          optimizer1.step()
          scheduler1.step()
          optimizer1.zero_grad()
          optimizer2.step()
          scheduler2.step()
          optimizer2.zero_grad()
          if step %100 == 1:
              print("Epoch {}, average loss: {}".format(epoch, tot_loss/(step+1)), flush=True)
              
          if step % 500 == 1:
              # evaluation
              prompt_model.eval()
              
              allpreds = []
              alllabels = []
              for step, inputs in enumerate(validation_dataloader):
                  if use_cuda:
                      inputs = inputs.cuda()
                  logits = prompt_model(inputs)
                  labels = inputs['label']
                  alllabels.extend(labels.cpu().tolist())
                  allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

              # acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)
              # print("Dev acc: "+str(acc))
              dev_res = compute_metrics(alllabels, allpreds)
              acc_score = dev_res["accuracy"]
              f1_score = dev_res["weighted_f1"]

              print("Dev acc: "+str(acc_score))
              print("Dev f1:"+str(f1_score))

              if min_f1 < f1_score:
                min_f1 = float(f1_score)
                print("save model in epoch:"+str(epoch))
                torch.save(prompt_model.state_dict(),model_save_pth)

  end_time = time.time()  
  complete_time = end_time - start_time
  print("running time: "+str(datetime.timedelta(seconds=complete_time)))

# Collapse

## Mix prompt Soft + Metric Learning bank + collapse

In [None]:
# Mix prompt Soft + Metric Learning bank size + collapse TODO
from torch import linalg as LA

if training_type == 'collapse'and few_shot_train:
  step_index = int(bank_size/batch_size_)+1
  logits_stack = []
  labels_stack = []
  start_time = time.time()
  for epoch in tqdm(range(epoch_num)):
      tot_loss = 0 
      for step, inputs in enumerate(train_dataloader):
          prompt_model.train()
          if use_cuda:
              inputs = inputs.cuda()
          logits, _ = prompt_model(inputs)

          labels = inputs['label']

          # prompt loss / class loss
          loss1 = loss_func(logits, labels)

          # collapse loss
          w1 = prompt_model.plm.fc1.weight.cpu()
          diff_matrix1 = torch.matmul(w1,w1.transpose(1,0))-torch.eye(w1.shape[0])
          loss_w1 = LA.matrix_norm(diff_matrix1)


          w2 = prompt_model.plm.fc2.weight.cpu()
          diff_matrix2 = w2-torch.eye(w2.shape[0])
          loss_w2 = LA.matrix_norm(diff_matrix2)

          w3 = prompt_model.plm.fc3.weight.cpu()
          diff_matrix3 = torch.matmul(w3,w3.transpose(1,0))-torch.eye(w3.shape[0])
          loss_w3 = LA.matrix_norm(diff_matrix3)

          loss_collapse =  loss_w1 + loss_w2 + loss_w3
          
          for logit, label in zip(logits.cpu().tolist(), labels.cpu().tolist()):
            logits_stack.append(logit)
            labels_stack.append(label)

          if step % step_index == 1:
            if dataset_ == "go_emotions":
              Hierarchical_labels = list(map(label_change_high,labels_stack))
            elif dataset_ == "emotion":
              Hierarchical_labels = list(map(label_change,labels_stack))

            Hierarchical_labels = torch.Tensor(Hierarchical_labels).to(labels.device)
            logits_stack = torch.Tensor(logits_stack).to(labels.device)

            # contradict loss
            indices_tuple = mining_func(logits_stack, Hierarchical_labels)
            loss2 = loss_func_metric(logits_stack, Hierarchical_labels, indices_tuple)

            loss = loss1 + loss2 + loss_collapse

            logits_stack = []
            labels_stack = [] 
          else:
            loss = loss1 + loss_collapse

          # print("====================")
          loss.backward()
          tot_loss += loss.item()
          optimizer1.step()
          scheduler1.step()
          optimizer1.zero_grad()
          optimizer2.step()
          scheduler2.step()
          optimizer2.zero_grad()
          if step %100 == 1:
              print("Epoch {}, average loss: {}".format(epoch, tot_loss/(step+1)), flush=True)
              
          if step % 500 == 1:
              # evaluation
              prompt_model.eval()
              
              allpreds = []
              alllabels = []
              for step, inputs in enumerate(validation_dataloader):
                  if use_cuda:
                      inputs = inputs.cuda()
                  logits, _ = prompt_model(inputs)
                  labels = inputs['label']
                  alllabels.extend(labels.cpu().tolist())
                  allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

                  

              # acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)
              # print("Dev acc: "+str(acc))
              dev_res = compute_metrics(alllabels, allpreds)
              acc_score = dev_res["accuracy"]
              # f1_score = dev_res["weighted_f1"]
              f1_score = dev_res["weighted_f1"]

              print("Dev acc: "+str(acc_score))
              print("Dev weighted_f1:"+str(f1_score))

              if min_f1 < f1_score:
                min_f1 = float(f1_score)
                print("save model in epoch:"+str(epoch))
                torch.save(prompt_model.state_dict(),model_save_pth)

  end_time = time.time()  
  complete_time = end_time - start_time
  print("running time: "+str(datetime.timedelta(seconds=complete_time)))

## Mix prompt Soft + SVD

In [None]:
from torch import linalg as LA

u_lst = []
sigma_lst = []
V_lst = []

if training_type == 'new_collapse'and few_shot_train:
  step_index = int(bank_size/batch_size_)+1
  #logits_stack = []
  #labels_stack = []
  start_time = time.time()
  for epoch in tqdm(range(epoch_num)):
      tot_loss = 0 
      for step, inputs in enumerate(train_dataloader):
          prompt_model.train()
          if use_cuda:
              inputs = inputs.cuda()
          logits, _ = prompt_model(inputs)
          labels = inputs['label']

          # prompt loss / class loss
          loss1 = loss_func(logits, labels)
          

          # print("\n"+"loss1")
          # print(loss1)

          # collapse loss
          w1 = prompt_model.plm.fc1.weight.cpu()
          diff_matrix1 = torch.matmul(w1,w1.transpose(1,0))-torch.eye(w1.shape[0])
          loss_w1 = LA.matrix_norm(diff_matrix1)

          print("\n"+"loss_U")
          print(loss_w1)
          print("\n"+"w_U")
          print(w1)

          w2 = prompt_model.plm.fc2.weight.cpu()
          # diff_matrix2 = w2-torch.eye(w2.shape[0])
          # # print(w2,shape)

          # loss_w2 = LA.matrix_norm(diff_matrix2)
          loss_w2 = LA.matrix_norm(w2)-1

          print("\n"+"loss_Sigma")
          print(loss_w2)
          print("\n"+"w_Sigma")
          print(w2)

          w3 = prompt_model.plm.fc3.weight.cpu()
          diff_matrix3 = torch.matmul(w3,w3.transpose(1,0))-torch.eye(w3.shape[0])
          loss_w3 = LA.matrix_norm(diff_matrix3)

          print("\n"+"loss_V")
          print(loss_w3)
          print("\n"+"w_V")
          print(w3)

          # TODO 
          # w0 = prompt_model.plm.bert.cpu() 
          # diff_matrix0 = w3 - w0
          # loss_w4 = LA.matrix_norm(diff_matrix0)

          loss_collapse =  loss_w1 + loss_w2 + loss_w3
          
          # for logit, label in zip(logits.cpu().tolist(), labels.cpu().tolist()):
          #   logits_stack.append(logit)
          #   labels_stack.append(label)

          # if step % step_index == 1:
          #   # print(len(labels_stack))
          #   # print(labels_stack)
          #   Hierarchical_labels = list(map(label_change,labels_stack))

          #   Hierarchical_labels = torch.Tensor(Hierarchical_labels).to(labels.device)
          #   # print(Hierarchical_labels)
          #   logits_stack = torch.Tensor(logits_stack).to(labels.device)

          #   # contradict loss
          #   indices_tuple = mining_func(logits_stack, Hierarchical_labels)
          #   loss2 = loss_func_metric(logits_stack, Hierarchical_labels, indices_tuple)
          #   # print(loss2)

          #   # print("\n"+"loss2")
          #   # print(loss2)
          #   loss = loss1 + loss2 + loss_collapse

          #   logits_stack = []
          #   labels_stack = [] 
          # else:
            # loss = loss1 + loss_collapse

          # loss = loss1 + loss_collapse
          loss = loss1

          # print("====================")
          loss.backward()
          tot_loss += loss.item()
          optimizer1.step()
          scheduler1.step()
          optimizer1.zero_grad()
          optimizer2.step()
          scheduler2.step()
          optimizer2.zero_grad()
          if step %100 == 1:
              print("Epoch {}, average loss: {}".format(epoch, tot_loss/(step+1)), flush=True)
              
          if step % 500 == 1:
              # evaluation
              prompt_model.eval()
              
              allpreds = []
              alllabels = []
              for step, inputs in enumerate(validation_dataloader):
                  if use_cuda:
                      inputs = inputs.cuda()
                  logits, _ = prompt_model(inputs)
                  labels = inputs['label']
                  alllabels.extend(labels.cpu().tolist())
                  allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

              # acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)
              # print("Dev acc: "+str(acc))
              dev_res = compute_metrics(alllabels, allpreds)
              acc_score = dev_res["accuracy"]
              f1_score = dev_res["weighted_f1"]

              print("Dev acc: "+str(acc_score))
              print("Dev f1:"+str(f1_score))

              if min_f1 < f1_score:
                min_f1 = float(f1_score)
                print("save model in epoch:"+str(epoch))
                torch.save(prompt_model.state_dict(),model_save_pth)
                
  end_time = time.time()  
  complete_time = end_time - start_time
  print("running time: "+str(datetime.timedelta(seconds=complete_time)))

# Test & Zero shot

In [None]:
# Test Zero shot
test_dataloader = PromptDataLoader(
    dataset=dataset["test"], 
    template=mytemplate, 
    tokenizer=tokenizer, 
    tokenizer_wrapper_class=WrapperClass,
    max_seq_length=256,  
    decoder_max_length=3, 
    batch_size=batch_size_,
    shuffle=False, 
    teacher_forcing=False, 
    predict_eos_token=False,
    truncate_method="tail"
)

In [None]:
from tqdm import tqdm

if training_type == "man" and few_shot_train == False:
  test_model = prompt_model
else: 
  test_model = PromptForClassification(plm=plm,template=mytemplate, verbalizer=myverbalizer, freeze_plm=True)
  test_model.load_state_dict(torch.load(model_save_pth))

allpreds = []
all_logits = []
alllabels = []
all_mask = []

pbar = tqdm(test_dataloader)
for step, inputs in enumerate(pbar):
    test_model.eval()
    if use_cuda:
        inputs = inputs.cuda()

    # logits= prompt_model(inputs)
    # TokenSim需要用到中间输出 目前只有Bert和No collapse支持
    logits, outputs_mask = test_model(inputs)
    
    labels = inputs['label']

    alllabels.extend(labels.cpu().tolist())
    allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

    ## store for T-NSE
    all_logits.extend(logits.cpu().tolist())

    ## store for tokenUmi
    all_mask.extend(outputs_mask.cpu().tolist())

result = compute_metrics(alllabels, allpreds)
print(result)

In [None]:
from sklearn.metrics import classification_report

print(classification_report(alllabels, allpreds, target_names=labels_))

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from matplotlib import pyplot as plt

# 使用sklearn工具中confusion_matrix方法计算混淆矩阵
confusion_mat = confusion_matrix(alllabels, allpreds)
print("confusion_mat.shape : {}".format(confusion_mat.shape))
# print("confusion_mat : {}".format(confusion_mat))

In [None]:
# 使用sklearn工具包中的ConfusionMatrixDisplay可视化混淆矩阵，参考plot_confusion_matrix
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_mat, display_labels=labels_)

fig, ax = plt.subplots(figsize=(12,12))
disp.plot(
    include_values=True,            # 混淆矩阵每个单元格上显示具体数值
    cmap="YlGnBu",                 
    ax=ax,                        
    xticks_rotation="vertical",  
    values_format="d"               # 显示的数值格式
)
plt.show()

# 一维热力图化

In [None]:
import seaborn as sns

In [None]:
logits

In [None]:
mean_logits = torch.mean(logits, dim=0)
mean_logits

In [None]:
ax = sns.heatmap((mean_logits.cpu().detach().numpy()).reshape(1,-1), center=0)

In [None]:
ax = sns.heatmap(logits.cpu().detach().numpy(), center=0)

In [None]:
ax = sns.boxplot(data=logits.cpu().detach().numpy())

# t-SNE

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import manifold, datasets

# X, y = np.array(all_logits), alllabels
X, y = np.array(all_mask), alllabels


X.shape

In [None]:
all_mask[0]

In [None]:
import matplotlib
hex_list = []
for name, hex in matplotlib.colors.cnames.items():
    # print(name, hex)
    hex_list.append(str(hex))

In [None]:
len(hex_list)

In [None]:
'''t-SNE'''
tsne = manifold.TSNE(n_components=2, init='pca', random_state=42)
X_tsne = tsne.fit_transform(X)

print("Org data dimension is {}. Embedded data dimension is {}".format(X.shape[-1], X_tsne.shape[-1]))

'''嵌入空间可视化'''
x_min, x_max = X_tsne.min(0), X_tsne.max(0)
X_norm = (X_tsne - x_min) / (x_max - x_min)  # 归一化
pic_name_tSNE = "figTSNE/"+PLM_TYPE+"_"+str(training_type)+"_template_"+template_type+"_fewshot_"+str(num_examples_per_label_)+"_bankSize_"+str(bank_size)+"_dataset_"+dataset_+".png"
plt.figure(figsize=(9, 9))
for i in range(X_norm.shape[0]):
    if y[i] == 27:
      continue
    if dataset_ == 'go_emotions':
      plt.text(X_norm[i, 0], X_norm[i, 1], str(y[i]), color=hex_list[y[i]+1], 
              fontdict={'weight': 'bold', 'size': 9})
    else:
      plt.text(X_norm[i, 0], X_norm[i, 1], str(y[i]), color=hex_list[y[i]+10], 
          fontdict={'weight': 'bold', 'size': 9})
plt.xticks([])
plt.yticks([])
plt.savefig(pic_name_tSNE)
plt.show()

# TokenSim

In [None]:
from sklearn.decomposition import TruncatedSVD
import matplotlib.pyplot as plt
import numpy as np
import torch
import seaborn as sns
from numpy.linalg import norm
from numpy import dot
from scipy import stats

In [None]:
def singular_spectrum(W, norm=False): 
    if norm:
        W = W/np.trace(W)
    M = np.min(W.shape)
    svd = TruncatedSVD(n_components=M-1, n_iter=7, random_state=10)
    svd.fit(W) 
    svals = svd.singular_values_
    svecs = svd.components_
    return svals, svecs

In [None]:
def tokenSimi(tokens_matrix,seqlen=None,nopad=False):
    """calculate the average cosine similarity,with/without normalization"""

    # simi 计算的是两两之间的关系，而cls_simi计算的是cls位置与其他位置的关系
    simi = []
    cls_simi = []
    if nopad:
        l = seqlen
    else:
        l = tokens_matrix.shape[0]
    for i in range(l):
        for j in range(l):
            if i!=j:
                simi.append(dot(tokens_matrix[i],tokens_matrix[j])/(norm(tokens_matrix[i])*norm(tokens_matrix[j])))
    for i in range(l):
        cls_simi.append(dot(tokens_matrix[0],tokens_matrix[i])/(norm(tokens_matrix[0])*norm(tokens_matrix[i])))
    return sum(simi)/len(simi),sum(cls_simi)/len(cls_simi)

In [None]:
plt.style.use('default') 

def draw_hist(hidden_outputs, layer_name):
    #all the layers, sample_i = 0,
    sample_i = 0
    colors = ["blue","red","cyan","black","green"]
    # plt.subplots(figsize=(8, 4))
    sv = []
    tokenuni,clsuni = [],[]
    for i in range(len(hidden_outputs)):
      sv_batch, _ = singular_spectrum(hidden_outputs[i])
      # print(sv_batch)
      
      sv.extend(sv_batch)
      tokenUni1,clsUni1 = tokenSimi(hidden_outputs[i])
      tokenuni.append(tokenUni1)
      clsuni.append(clsUni1)

    #keep only one figure handler at once is better
    print("*****")
    print(training_type)
    print(layer_name)
    print(np.percentile(sv/max(sv), 10))
    print(np.percentile(sv/max(sv), 25))
    print(np.percentile(sv/max(sv), 50))
    print("*****")

    pic_name = "fig/"+PLM_TYPE+"_"+training_type+"_template_"+template_type+"_fewshot_"+str(num_examples_per_label_)+"_bankSize_"+str(bank_size)+"_dataset_"+dataset_+"_"+layer_name+".png"
    print(pic_name)
    # pic_name = "a0Epoch_{}Layer_{}Apply_{}Nonlinear_pdf_average_annotate.png".format(args.epoch,i_layer,args.apply_exrank1)
    fig, axs = plt.subplots(figsize=(8, 5))
    w1_stats= stats.describe(sv)
    plt.title("PDF for {} at Layer {}".format(training_type,layer_name))
    sns.distplot(sv/max(sv), hist=True, kde=True, 
        bins=50, color = 'darkblue', 
        hist_kws={'edgecolor':'black'},
        kde_kws={'linewidth': 2,"bw_adjust":0.1})
    
    # axs.text(axs.get_xlim()[1]*0.65,axs.get_ylim()[1]*0.6,"variance:"+str("%.2f"%w1_stats.variance)+"\n"+"skewness:"+str("%.2f"%w1_stats.skewness)+"\n"+"kurtosis:"+str("%.2f"%w1_stats.kurtosis)+"\n"+"tokenuni:"+str("%.2f"%(sum(tokenuni)/len(tokenuni)))+"\n"+"clsUni:"+str("%.2f"%(sum(clsuni)/len(clsuni))),fontsize = 14)
    axs.text(axs.get_xlim()[1]*0.65,axs.get_ylim()[1]*0.6,"skewness:"+str("%.2f"%w1_stats.skewness)+"\n"+"\n"+"tokenuni:"+str("%.2f"%(sum(tokenuni)/len(tokenuni)))+"\n",fontsize = 14)

    plt.xlabel("singular value",size=14)
    plt.ylabel("#singular value",size=14)
    axs.tick_params(axis='x', labelsize=14)
    axs.tick_params(axis='y', labelsize=14)

    median = np.percentile(sv/max(sv), 50)
    axs.vlines(x=median,ymin=0,ymax=0.98*axs.get_ylim()[1],color="red",label="median=%.4f"%median)
    axs.legend()
    axs.tick_params(axis='x', labelsize=14)

    plt.savefig(pic_name)

In [None]:
#len(np.array(all_mask))

In [None]:
#np.array(all_mask)[0]

In [None]:
#len(np.array((all_mask)[0]))

In [None]:
draw_hist([np.array(all_mask)],"MaskMid")

# 模型折线图TODO

# more detail
https://github.com/thunlp/OpenPrompt/blob/main/tutorial/4.1_all_tasks_are_generation.py
https://thunlp.github.io/OpenPrompt/modules/root.html?highlight=softtemplate