In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/stsb-data/test.txt
/kaggle/input/stsb-data/train.txt
/kaggle/input/stsb-data/dev.txt


In [2]:
import random
import re

import pandas as pd
import torch
from transformers import AutoTokenizer,AutoModel,AutoConfig
from torch.utils.data import Dataset,DataLoader
from torch import nn
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import torch.nn.functional as F
from scipy.stats import spearmanr

In [3]:
class P_dataset(Dataset):
    def __init__(self,df,tokenizer,max_len=128,dup_rate=0.15,mode='train',num=None):
        self.df = df
        self.mode = mode
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.dup_rate = dup_rate
        self.num = num
        if mode == 'train':
            self.listdf = self.get_listdf(df)
        else:
            assert num in [0,1,2]
    def get_listdf(self,df):
        lis = df['text_a'].tolist()
        lis.extend(df['text_b'].tolist())
        return lis
    def __getitem__(self, x):
        if self.mode == 'train':
            text = self.listdf[x]
            prompt_texts = self.template(text,mode='train')
            return prompt_texts
        else:
            data = self.df.iloc[x]
            text_a,text_b,label = data['text_a'],data['text_b'],data['label']

            prompt_text_a = self.template(text_a,mode='dev')
            prompt_text_b = self.template(text_b,mode='dev')

            return prompt_text_a+prompt_text_b,label
    def template(self,sentence,mode='train'):
        semtence_tem = []
        l = 2 if mode=='train' else 1
        for i in range(l):
            prompt_sentence,template_sentence = self.sample_template(sentence,is_repeat=(i==1))
            semtence_tem += [prompt_sentence, template_sentence]

        return semtence_tem
    def sample_template(self,sentence,is_repeat):
        prompt_num_words = []
        for i in range((len(sentence)+1)//2):
            prompt_num_words += [f'[prompt_{i}]']
        ask_choice = ['这句话的意思是','[prompt_81][prompt_82][prompt_83][prompt_84]','[prompt_71][prompt_72][prompt_73][prompt_74][prompt_75]']

        if self.mode=='train':
            sentence = self.word_repeat(sentence)
            pattern = random.choice(ask_choice)
            cho = random.sample(list(range(0,len(pattern))),(len(pattern)+1)//2)
            
            
            if pattern==ask_choice[0]:
                pattern = list(pattern)
                if is_repeat:
                    pattern = self.word_repeat(pattern)
                for i in cho:
                    pattern[i] = f'[prompt_{i+51}]'
#             elif pattern==ask_choice[1]:
#                 pattern = list(pattern)
#                 for i in cho:
#                     pattern[i] = f'[prompt_{i+61}]'
            else:
                pattern = re.findall(r'\[.+?\]', pattern)
                if is_repeat:
                    pattern = self.word_repeat(pattern)
        else:
            pattern = ask_choice[self.num]
            if pattern==ask_choice[0]:
                pattern = list(pattern)
#             elif pattern==ask_choice[1]:
#                 pattern = list(pattern)
            else:
                pattern = re.findall(r'\[.+?\]', pattern)
        prompt_sentence =  list(sentence) + pattern + ["[MASK]"] +['。']
        template_sentence =  ['[X]']*len(sentence) + pattern + ["[MASK]"] +['。']
        
        prompt_sentence = "".join(prompt_sentence)
        template_sentence = "".join(template_sentence)
        return prompt_sentence, template_sentence
    def word_repeat(self,text,dup_rate=0.15):
        dup_rate = min(1,dup_rate)
        text_tokens = list(text)
        conduct_num = random.randint(0,max(int(len(text)*dup_rate),1))
        if conduct_num==0:
            return text
        sample =random.sample(range(0,len(text)), conduct_num)
        for num in sample:
            text_tokens[num] += text_tokens[num]
        return text_tokens
    def prompt_repeat(self,text_tokens,dup_rate=0.1):
        if dup_rate<0.01:
            return "".join(text_tokens),"".join(text_tokens2)
        conduct_num = random.randint(0,max(int(len(text_tokens)*dup_rate),1))

        sample = random.sample(range(0, len(text_tokens)), conduct_num)

        for num in sample:
            text_tokens[num] += text_tokens[num]
        return "".join(text_tokens)
    def __len__(self):
        if self.mode=='train':
            return len(self.listdf)
        else:
            return len(self.df)
    def collate_fn(self,batch):
        if self.mode == 'dev':
            label = [x[1] for x in batch]
            batch = [x[0] for x in batch]

        prompt_1 = [x[0] for x in batch]
        template_1 = [x[1] for x in batch]
        prompt_2 = [x[2] for x in batch]
        template_2 = [x[3] for x in batch]

        batch_prompt_1 = self.tokenizer(prompt_1,truncation=True, max_length=self.max_len,padding='longest')
        batch_template_1 = self.tokenizer(template_1, truncation=True, max_length=self.max_len, padding='longest')
        batch_prompt_2 = self.tokenizer(prompt_2, truncation=True, max_length=self.max_len, padding='longest')
        batch_template_2 = self.tokenizer(template_2, truncation=True, max_length=self.max_len, padding='longest')

        return_dict= {
            "input_prompt_1":torch.as_tensor(batch_prompt_1['input_ids'],dtype=torch.long),
            "mask_prompt_1": torch.as_tensor(batch_prompt_1['attention_mask'],dtype=torch.long),
            "input_template_1": torch.as_tensor(batch_template_1['input_ids'],dtype=torch.long),
            "mask_template_1": torch.as_tensor(batch_template_1['attention_mask'],dtype=torch.long),
            "input_prompt_2": torch.as_tensor(batch_prompt_2['input_ids'],dtype=torch.long),
            "mask_prompt_2": torch.as_tensor(batch_prompt_2['attention_mask'],dtype=torch.long),
            "input_template_2": torch.as_tensor(batch_template_2['input_ids'],dtype=torch.long),
            "mask_template_2": torch.as_tensor(batch_template_2['attention_mask'],dtype=torch.long),
        }
        if self.mode=='dev':
            return_dict['label'] = torch.as_tensor(label, dtype=torch.long)
        return return_dict

In [4]:
def load_data(batch_size=30):
    train_df = pd.read_csv('/kaggle/input/stsb-data/train.txt', sep='\|\|', header=None,
                       names=['id', 'text_a', 'text_b', 'label'], engine='python')

    dev_df = pd.read_csv('/kaggle/input/stsb-data/dev.txt', sep='\|\|', header=None,
                       names=['id', 'text_a', 'text_b', 'label'], engine='python')
    tokenizer = AutoTokenizer.from_pretrained('hfl/chinese-roberta-wwm-ext')
    tokenizer.add_tokens('[X]')
    for i in range(0,98):
        tokenizer.add_tokens(f'[prompt_{i}]')
    train_dataset = P_dataset(train_df,tokenizer,mode='train')
    train_dataloader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True,collate_fn=train_dataset.collate_fn)
    dev_dataloaders = []
    
    for i in range(2):
        dev_dataset = P_dataset(dev_df, tokenizer, mode='dev',num=i)
        dev_dataloader = DataLoader(dev_dataset, batch_size=batch_size, shuffle=False, collate_fn=dev_dataset.collate_fn)
        dev_dataloaders.append(dev_dataloader)

    return train_dataloader,dev_dataloaders,tokenizer

In [5]:
class PromptBert(nn.Module):
    def __init__(self,model_path,dropout_prob,tokenizer):
        super(PromptBert, self).__init__()
        self.tokenizer = tokenizer

        #修dropout率
        conf = AutoConfig.from_pretrained(model_path)
        conf.attention_probs_dropout_prob = dropout_prob
        conf.hiddem_dropout_prob = dropout_prob

        self.bert = AutoModel.from_pretrained(model_path,config=conf)

        self.mask_id = self.tokenizer.convert_tokens_to_ids('[MASK]')

        self.bert.resize_token_embeddings(len(self.tokenizer))


    def forward(self,input_prompt,mask_prompt,input_template,mask_template):
        promot_out = self.cal_mask_embedding(input_prompt,mask_prompt)
        template_out = self.cal_mask_embedding(input_template,mask_template)
        return promot_out-template_out
    def cal_mask_embedding(self,input_ids,mask):
        last_hidden_state,_ = self.bert(input_ids,mask,return_dict=False)

        mask_index = (input_ids==self.mask_id).long()
        mask_index = mask_index.unsqueeze(-1).expand(last_hidden_state.shape).float()

        return torch.sum(last_hidden_state*mask_index,dim=1)

In [6]:
def cal_loss(query,key,tao=0.05):
    query = F.normalize(query,dim=1)
    key = F.normalize(key, dim=1)

    N,D = query.shape

    batch_pos = torch.exp(torch.div(torch.bmm(query.view(N,1,D),key.view(N,D,1)).view(N,1),tao))

    batch_all = torch.sum(torch.exp(torch.div(torch.mm(query,torch.t(key)),tao)),dim=1)

    loss = torch.mean(-torch.log(torch.div(batch_pos,batch_all)))
    return loss

In [7]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

In [8]:
from torch.optim.lr_scheduler import StepLR
if __name__ == '__main__':
    seed_everything(3407)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(device)
    # device = "cpu"
    lr = 2e-5
    model_path = 'hfl/chinese-roberta-wwm-ext'
    epoch = 10
    batch_size = 16

    train_loader, dev_loaders,tokenizer = load_data(batch_size=batch_size)
    model = PromptBert(model_path, dropout_prob=0.2, tokenizer=tokenizer).to(device)
    opt = torch.optim.AdamW(model.parameters(),lr=lr)
    sch = StepLR(opt, step_size=2, gamma=0.8)
    best_corr = -10
    for e in range(epoch):
        model.train()
        tq = tqdm(train_loader)
        for data in tq:
            opt.zero_grad()
            input_prompt_1 = data['input_prompt_1'].to(device)
            mask_prompt_1 = data['mask_prompt_1'].to(device)
            input_template_1 = data['input_template_1'].to(device)
            mask_template_1 = data['mask_template_1'].to(device)

            input_prompt_2 = data['input_prompt_2'].to(device)
            mask_prompt_2 = data['mask_prompt_2'].to(device)
            input_template_2 = data['input_template_2'].to(device)
            mask_template_2 = data['mask_template_2'].to(device)

            query = model(input_prompt_1, mask_prompt_1, input_template_1, mask_template_1)
            key = model(input_prompt_2, mask_prompt_2, input_template_2, mask_template_2)

            loss =  cal_loss(query,key)
            loss.backward()
            opt.step()

            tq.update()
            tq.set_description(f'e={e} loss={loss.item():.6f}')
        sch.step()
        
        all_pre = None
        all_label = None
        model.eval()
        for i in range(2):
            predict = []
            lab = []
            for data in tqdm(dev_loaders[i]):
                input_prompt_1 = data['input_prompt_1'].to(device)
                mask_prompt_1 = data['mask_prompt_1'].to(device)
                input_template_1 = data['input_template_1'].to(device)
                mask_template_1 = data['mask_template_1'].to(device)

                input_prompt_2 = data['input_prompt_2'].to(device)
                mask_prompt_2 = data['mask_prompt_2'].to(device)
                input_template_2 = data['input_template_2'].to(device)
                mask_template_2 = data['mask_template_2'].to(device)

                label = data['label']
                with torch.no_grad():
                    outa = model(input_prompt_1, mask_prompt_1, input_template_1, mask_template_1)
                    outb = model(input_prompt_2, mask_prompt_2, input_template_2, mask_template_2)

                    sim = torch.cosine_similarity(outa,outb,dim=-1)

                    pre = sim.detach().cpu().numpy()

                    predict.extend(pre)
                    lab.extend(label)
            if all_pre==None:
                all_pre = predict
                all_label = lab
            else:
                all_pre = [i+j for i,j in zip(all_pre, predict)]
        spearman_corr, _ = spearmanr(all_pre, all_label)
        if best_corr<spearman_corr:
            print(f'Spearman Correlation:={spearman_corr}----------->best')
            torch.save(model.state_dict(),'best.pt')
            best_corr=spearman_corr
        else:
            print(f'Spearman Correlation:={spearman_corr}')

cuda


Downloading (…)okenizer_config.json:   0%|          | 0.00/19.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/689 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt: 0.00B [00:00, ?B/s]

Downloading (…)/main/tokenizer.json: 0.00B [00:00, ?B/s]

Downloading (…)in/added_tokens.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']


Downloading pytorch_model.bin:   0%|          | 0.00/412M [00:00<?, ?B/s]

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
e=0 loss=0.000122: 100%|██████████| 654/654 [04:26<00:00,  2.45it/s]
100%|██████████| 92/92 [00:10<00:00,  8.89it/s]
100%|██████████| 92/92 [00:0

Spearman Correlation:=0.7898678303196947----------->best


e=1 loss=0.000066: 100%|██████████| 654/654 [04:25<00:00,  2.47it/s]
100%|██████████| 92/92 [00:10<00:00,  8.95it/s]
100%|██████████| 92/92 [00:09<00:00,  9.39it/s]


Spearman Correlation:=0.7918495658374547----------->best


e=2 loss=0.000065: 100%|██████████| 654/654 [04:25<00:00,  2.47it/s]
100%|██████████| 92/92 [00:10<00:00,  8.94it/s]
100%|██████████| 92/92 [00:09<00:00,  9.28it/s]


Spearman Correlation:=0.7883808274994062


e=3 loss=0.000197: 100%|██████████| 654/654 [04:25<00:00,  2.46it/s]
100%|██████████| 92/92 [00:10<00:00,  8.89it/s]
100%|██████████| 92/92 [00:09<00:00,  9.38it/s]


Spearman Correlation:=0.8013330544282936----------->best


e=4 loss=0.000084: 100%|██████████| 654/654 [04:26<00:00,  2.46it/s]
100%|██████████| 92/92 [00:10<00:00,  8.96it/s]
100%|██████████| 92/92 [00:09<00:00,  9.37it/s]


Spearman Correlation:=0.7804322820206404


e=5 loss=0.000413: 100%|██████████| 654/654 [04:25<00:00,  2.46it/s]
100%|██████████| 92/92 [00:10<00:00,  8.95it/s]
100%|██████████| 92/92 [00:09<00:00,  9.38it/s]


Spearman Correlation:=0.7878814791936224


e=6 loss=0.000135: 100%|██████████| 654/654 [04:25<00:00,  2.46it/s]
100%|██████████| 92/92 [00:10<00:00,  8.95it/s]
100%|██████████| 92/92 [00:09<00:00,  9.44it/s]


Spearman Correlation:=0.7800612138862777


e=7 loss=0.000616: 100%|██████████| 654/654 [04:24<00:00,  2.47it/s]
100%|██████████| 92/92 [00:10<00:00,  8.83it/s]
100%|██████████| 92/92 [00:09<00:00,  9.40it/s]


Spearman Correlation:=0.7787136622796977


e=8 loss=0.000058: 100%|██████████| 654/654 [04:25<00:00,  2.46it/s]
100%|██████████| 92/92 [00:10<00:00,  8.76it/s]
100%|██████████| 92/92 [00:09<00:00,  9.29it/s]


Spearman Correlation:=0.7380896949477198


e=9 loss=0.000149: 100%|██████████| 654/654 [04:25<00:00,  2.46it/s]
100%|██████████| 92/92 [00:10<00:00,  8.87it/s]
100%|██████████| 92/92 [00:09<00:00,  9.30it/s]

Spearman Correlation:=0.771403731383638



