In [None]:
import json
import sys
from torch.utils.data import DataLoader
from sentence_transformers import InputExample, SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample
import logging
from datetime import datetime
import gzip
import os
import tarfile
from collections import defaultdict
from torch.utils.data import IterableDataset
import tqdm
from torch.utils.data import Dataset
import random
import pickle
import argparse
# Expose only GPU index 1 to CUDA libraries used by this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# train.json is read as JSON Lines: one JSON object per line, each providing
# "context" (list of utterances), "positive_responses" and
# "adversarial_negative_responses" (lists of strings).
with open("train.json", "r", encoding="utf-8") as f:
    lines = f.readlines()

train_examples = []
for line in lines:
    # A blank line (e.g. a trailing newline at end of file) is not valid JSON; skip it.
    if not line.strip():
        continue
    sample = json.loads(line)
    # Flatten the dialogue turns into "[CLS] turn1 [SEP] turn2 [SEP] ...".
    context = ("[CLS] " + "".join(c + " [SEP] " for c in sample["context"])).strip()
    negatives = ["[CLS] " + n + " [SEP]" for n in sample["adversarial_negative_responses"]]
    # One (anchor, positive, hard negative) triplet per positive x negative combination.
    for p in sample["positive_responses"]:
        pos = "[CLS] " + p + " [SEP]"
        for neg in negatives:
            train_examples.append(InputExample(texts=[context, pos, neg]))

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout

model_name = "sentence-t5-base"

train_batch_size = 16           # Increasing the train batch size improves the model performance, but requires more GPU memory
max_seq_length = 512            # Max length for passages. Increasing it, requires more GPU memory
num_epochs = 10                 # Number of epochs we want to train
pooling_mode = "mean"           # Pooling strategy; only read when the model is assembled from scratch below

use_pre_trained_model = True
# Load our embedding model
if use_pre_trained_model:
    logging.info("use pretrained SBERT model")
    model = SentenceTransformer(model_name)
    model.max_seq_length = max_seq_length
else:
    logging.info("Create new SBERT model")
    word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
    # There is no argparse `args` object in this script, so the pooling mode
    # comes from the module-level `pooling_mode` setting instead.
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode)
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# Output directory name embeds the model name and the start timestamp.
model_save_path = 'output/train_bi-encoder-mnrl-disentangle-{}-{}'.format(model_name.replace("/", "-"),  datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))

# For training the SentenceTransformer model, we need a dataset, a dataloader, and a loss used for training.
# `train_examples` holds the (context, positive, negative) triplets built earlier in this script.
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=train_batch_size)
train_loss = losses.MultipleNegativesRankingLoss(model=model)

# Train the model
model.fit(train_objectives=[(train_dataloader, train_loss)],
          epochs=num_epochs,
          warmup_steps=1000,
          use_amp=True,
          checkpoint_path=model_save_path,
          # len(train_dataloader) is the number of batches per epoch,
          # so the step interval between checkpoints equals one epoch.
          checkpoint_save_steps=len(train_dataloader),
          optimizer_params = {'lr': 2e-5},
          )

# Save the model
model.save(model_save_path)


In [None]:
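# 1. First apply contrastive learning to the text, and add one linear layer on top of the sentence-transformer model's output.
# 2. Apply contrastive learning again to the linear layer's output, so that it stays close to the sentence-transformer output; the sentence-transformer model is frozen during this training,
#    and it stays frozen in all later steps as well.
# 3. Split the linear layer's output into a robust part and a non-robust part, and apply contrastive learning:
#    3.1 Contrastive learning inside one response: push robust and non-robust apart.
#    3.2 Contrastive learning across different responses: pull robust and robust together.
#  For this step, set up a classifier.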



In [None]:
import json
import sys
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
import torch
from torch.utils.data import DataLoader
from sentence_transformers import InputExample, SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample
import logging
from datetime import datetime
import gzip
import os
import tarfile
from collections import defaultdict
from torch.utils.data import IterableDataset
import tqdm
from torch.utils.data import Dataset
import random
import pickle
import argparse
from enum import Enum
# import torch.nn as nn
from torch import nn, Tensor
import torch.nn.functional as F

In [None]:
class Classifier(nn.Module):
    """Single linear layer that scores hidden states and reports a cross-entropy loss.

    :param hidden_size: size of the last dimension of the inputs given to ``forward``
    :param num_labels: number of output classes
    """

    def __init__(self, hidden_size, num_labels):
        super(Classifier, self).__init__()
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, hidden_states, labels):
        """Score ``hidden_states`` and compute the cross-entropy against ``labels``.

        :param hidden_states: tensor whose first dimension is the batch size
        :param labels: a single class index applied to every row, or any value
            that broadcasts against a vector of length ``batch``
        :return: tuple ``(logits, loss)``
        """
        n = hidden_states.shape[0]
        device = hidden_states.device
        # Build the target on the same device as the inputs: a CPU target paired
        # with CUDA logits makes the loss fail with a device-mismatch error.
        labels = torch.as_tensor(labels, device=device) * torch.ones(n, device=device)
        labels = labels.long()
        logits = self.classifier(hidden_states)

        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits, labels)
        return logits, loss

class SiameseDistanceMetric(Enum):
    """
    The metric for the contrastive loss

    Every attribute is a two-argument callable that forwards ``x`` and ``y`` to a
    ``torch.nn.functional`` distance/similarity function.
    """
    # NOTE(review): the values are plain lambdas. Enum treats functions defined in
    # the class body as methods rather than members, so accessing an attribute
    # should yield the callable itself (the loss calls it directly) — confirm on
    # the Python version in use.

    # Pairwise p-norm distance with p=2.
    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
    # Pairwise p-norm distance with p=1.
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
    # One minus the cosine similarity of x and y.
    COSINE_DISTANCE = lambda x, y: 1-F.cosine_similarity(x, y)

class Second_MultipleNegativesRankingLoss(nn.Module):
    """
    In-batch ranking loss computed on embeddings that were already produced elsewhere.

    ``forward`` receives a sequence ``reps`` of embedding tensors: ``reps[0]`` holds the
    anchors ``a_1..a_n`` and every following entry holds candidates (positives first,
    then optional hard negatives). All candidate tensors are concatenated, each anchor
    is scored against every candidate, and the cross-entropy target for anchor ``i`` is
    candidate ``i`` — i.e. ``(a_i, p_i)`` is the positive pair and every other candidate
    acts as a negative.

    Background: https://arxiv.org/pdf/1705.00652.pdf
    (Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4)

    Example::

        loss_fn = Second_MultipleNegativesRankingLoss()
        loss = loss_fn([anchor_emb, positive_emb, negative_emb])
    """

    def __init__(self, scale: float = 20.0, similarity_fct = util.cos_sim):
        """
        :param scale: Output of similarity function is multiplied by scale value
        :param similarity_fct: similarity function between sentence embeddings. By default, cos_sim. Can also be set to dot product (and then set scale to 1)
        """
        super().__init__()
        self.scale = scale
        self.similarity_fct = similarity_fct
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, reps):
        """
        :param reps: sequence of embedding tensors; first the anchors, then one or more candidate tensors
        :return: cross-entropy loss over the scaled similarity scores
        """
        anchors = reps[0]
        candidates = torch.cat(reps[1:])
        scores = self.similarity_fct(anchors, candidates) * self.scale
        # Anchor i should rank candidate i highest.
        targets = torch.arange(len(scores), dtype=torch.long, device=scores.device)
        return self.cross_entropy_loss(scores, targets)

    def get_config_dict(self):
        """Return the loss settings as a plain dict (the similarity function is reported by name)."""
        return {'scale': self.scale, 'similarity_fct': self.similarity_fct.__name__}

class test_ContrastiveLoss(nn.Module):
    """
    Contrastive loss on a pair of already-computed embeddings and a label of either 0 or 1.
    If the label == 1, then the distance between the two embeddings is reduced.
    If the label == 0, then the distance between the embeddings is increased.

    Further information: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    :param distance_metric: Function that returns a distance between two embeddings. The class SiameseDistanceMetric contains pre-defined metrices that can be used
    :param margin: Negative samples (label == 0) should have a distance of at least the margin value.
    :param size_average: Average by the size of the mini-batch.

    Example::

        loss_fn = test_ContrastiveLoss()
        pull_together = loss_fn([emb_a, emb_b], 1.0)
        push_apart = loss_fn([emb_a, emb_c], 0.0)
    """

    def __init__(self, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, margin: float = 0.5, size_average:bool = True):
        super(test_ContrastiveLoss, self).__init__()
        self.distance_metric = distance_metric
        self.margin = margin
        self.size_average = size_average

    def get_config_dict(self):
        """Return the loss settings as a plain dict, naming the metric via SiameseDistanceMetric when it matches."""
        distance_metric_name = self.distance_metric.__name__
        for name, value in vars(SiameseDistanceMetric).items():
            if value == self.distance_metric:
                distance_metric_name = "SiameseDistanceMetric.{}".format(name)
                break

        return {'distance_metric': distance_metric_name, 'margin': self.margin, 'size_average': self.size_average}

    def forward(self, reps, labels):
        """
        :param reps: pair ``(rep_anchor, rep_other)`` of embedding tensors
        :param labels: 1 to pull the pair together, 0 to push it apart; a scalar or
            anything that broadcasts against the distances
        :return: mean of the per-pair losses if ``size_average`` is set, otherwise their sum
        """
        rep_anchor, rep_other = reps
        distances = self.distance_metric(rep_anchor, rep_other)
        # Place the labels on the same device as the distances. as_tensor also
        # avoids the copy-construct warning torch.tensor() emits for tensor input.
        labels = torch.as_tensor(labels, device=distances.device)

        # 0.5 * (y * d^2 + (1 - y) * max(margin - d, 0)^2)
        pair_losses = 0.5 * (labels.float() * distances.pow(2) + (1 - labels).float() * F.relu(self.margin - distances).pow(2))
        return pair_losses.mean() if self.size_average else pair_losses.sum()

# Load a SentenceTransformer checkpoint from a local output directory and move it to the GPU.
# Disentangle_Model.forward reads this module-level `model_encoder` name.
model_encoder = SentenceTransformer("output/train_bi-encoder-mnrl-bert-base-uncased-2023-08-29_15-22-44/86808").cuda()
class Disentangle_Model(nn.Module):
    """
    Trainable head on top of the module-level ``model_encoder``.

    The encoder is only called under ``torch.no_grad()``; its outputs go through one
    768 -> 768 linear layer, and each response embedding is then split in half into a
    "non-robust" part (first half) and a "robust" part (second half). ``forward``
    returns the sum of a ranking loss, four contrastive losses and four
    classification losses.

    ``max_seq_length`` and ``batch_size`` are stored but not read inside this class.
    """

    def __init__(self, max_seq_length=1024, batch_size=32, num_labels=3):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size
        self.num_labels = num_labels
        self.linear = nn.Linear(768, 768)
        self.second_mutiple_negatives_ranking_loss = Second_MultipleNegativesRankingLoss()
        self.contrastive_loss = test_ContrastiveLoss()
        # Classifier input = full context embedding (768) + one half of a response embedding (384).
        self.classifier = Classifier(int(1.5 * 768), self.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, batch):
        """
        :param batch: indexable with three entries — contexts, positive responses and
            negative responses — each handed to ``model_encoder.encode``
            (presumably lists of strings; confirm against the caller)
        :return: the summed training loss
        """
        # Encode without tracking gradients; the encoder receives no updates from this loss.
        with torch.no_grad():
            encoded = [torch.tensor(model_encoder.encode(batch[k])).cuda() for k in range(3)]
        context_embedding, pos_response_embedding, neg_response_embedding = [
            self.linear(item) for item in encoded
        ]

        # Ranking loss on the full projected embeddings.
        contrastive_loss_1 = self.second_mutiple_negatives_ranking_loss(
            [context_embedding, pos_response_embedding, neg_response_embedding]
        )

        # Step 3: split the linear output into non-robust / robust halves and contrast them.
        pos_non_robust, pos_robust = pos_response_embedding.chunk(2, -1)
        neg_non_robust, neg_robust = neg_response_embedding.chunk(2, -1)

        # Label 0.0 pushes a pair apart, label 1.0 pulls it together.
        pos_inside_contrastive_loss = self.contrastive_loss([pos_non_robust, pos_robust], 0.0)
        neg_inside_contrastive_loss = self.contrastive_loss([neg_non_robust, neg_robust], 0.0)
        outside_robust_contrastive_loss = self.contrastive_loss([pos_robust, neg_robust], 0.0)
        outside_non_robust_contrastive_loss = self.contrastive_loss([pos_non_robust, neg_non_robust], 1.0)

        # Step 4: classify (context, half) pairs.
        # Class 1 = positive robust, class 0 = negative robust, class 2 = any non-robust half.
        classifier_inputs = (
            (pos_robust, 1),
            (pos_non_robust, 2),
            (neg_robust, 0),
            (neg_non_robust, 2),
        )
        classification_losses = [
            self.classifier(torch.cat([context_embedding, half], dim=1), target)[1]
            for half, target in classifier_inputs
        ]

        loss = (
            contrastive_loss_1
            + pos_inside_contrastive_loss
            + neg_inside_contrastive_loss
            + outside_robust_contrastive_loss
            + outside_non_robust_contrastive_loss
        )
        for term in classification_losses:
            loss = loss + term
        return loss

    def save(self, path):
        """Write this module's state dict to ``path``."""
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load a state dict previously written by ``save`` from ``path``."""
        self.load_state_dict(torch.load(path))

In [None]:
# Scratch check: build and display a 1-D zero tensor of length 30 - 20.
a = torch.zeros(30 - 20)
print(a)

In [None]:
# pad_sequence: zero-pad to a common length, then apply contrastive learning between the zero-padded robust representation and the context
from torch.nn.utils.rnn import pad_sequence
class Classifier(nn.Module):
    """Single linear layer that scores hidden states and reports a cross-entropy loss.

    :param hidden_size: size of the last dimension of the inputs given to ``forward``
    :param num_labels: number of output classes
    """

    def __init__(self, hidden_size, num_labels):
        super(Classifier, self).__init__()
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, hidden_states, labels):
        """Score ``hidden_states`` and compute the cross-entropy against ``labels``.

        :param hidden_states: tensor whose first dimension is the batch size
        :param labels: a single class index applied to every row, or any value
            that broadcasts against a vector of length ``batch``
        :return: tuple ``(logits, loss)``
        """
        n = hidden_states.shape[0]
        device = hidden_states.device
        # Build the target on the same device as the inputs: a CPU target paired
        # with CUDA logits makes the loss fail with a device-mismatch error.
        labels = torch.as_tensor(labels, device=device) * torch.ones(n, device=device)
        labels = labels.long()
        logits = self.classifier(hidden_states)

        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits, labels)
        return logits, loss

class SiameseDistanceMetric(Enum):
    """
    The metric for the contrastive loss

    Every attribute is a two-argument callable that forwards ``x`` and ``y`` to a
    ``torch.nn.functional`` distance/similarity function.
    """
    # NOTE(review): the values are plain lambdas. Enum treats functions defined in
    # the class body as methods rather than members, so accessing an attribute
    # should yield the callable itself (the loss calls it directly) — confirm on
    # the Python version in use.

    # Pairwise p-norm distance with p=2.
    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
    # Pairwise p-norm distance with p=1.
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
    # One minus the cosine similarity of x and y.
    COSINE_DISTANCE = lambda x, y: 1-F.cosine_similarity(x, y)

class Second_MultipleNegativesRankingLoss(nn.Module):
    """
    In-batch ranking loss computed on embeddings that were already produced elsewhere.

    ``forward`` receives a sequence ``reps`` of embedding tensors: ``reps[0]`` holds the
    anchors ``a_1..a_n`` and every following entry holds candidates (positives first,
    then optional hard negatives). All candidate tensors are concatenated, each anchor
    is scored against every candidate, and the cross-entropy target for anchor ``i`` is
    candidate ``i`` — i.e. ``(a_i, p_i)`` is the positive pair and every other candidate
    acts as a negative.

    Background: https://arxiv.org/pdf/1705.00652.pdf
    (Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4)

    Example::

        loss_fn = Second_MultipleNegativesRankingLoss()
        loss = loss_fn([anchor_emb, positive_emb, negative_emb])
    """

    def __init__(self, scale: float = 20.0, similarity_fct = util.cos_sim):
        """
        :param scale: Output of similarity function is multiplied by scale value
        :param similarity_fct: similarity function between sentence embeddings. By default, cos_sim. Can also be set to dot product (and then set scale to 1)
        """
        super().__init__()
        self.scale = scale
        self.similarity_fct = similarity_fct
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, reps):
        """
        :param reps: sequence of embedding tensors; first the anchors, then one or more candidate tensors
        :return: cross-entropy loss over the scaled similarity scores
        """
        anchors = reps[0]
        candidates = torch.cat(reps[1:])
        scores = self.similarity_fct(anchors, candidates) * self.scale
        # Anchor i should rank candidate i highest.
        targets = torch.arange(len(scores), dtype=torch.long, device=scores.device)
        return self.cross_entropy_loss(scores, targets)

    def get_config_dict(self):
        """Return the loss settings as a plain dict (the similarity function is reported by name)."""
        return {'scale': self.scale, 'similarity_fct': self.similarity_fct.__name__}

class test_ContrastiveLoss(nn.Module):
    """
    Contrastive loss on a pair of already-computed embeddings and a label of either 0 or 1.
    If the label == 1, then the distance between the two embeddings is reduced.
    If the label == 0, then the distance between the embeddings is increased.

    Further information: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    :param distance_metric: Function that returns a distance between two embeddings. The class SiameseDistanceMetric contains pre-defined metrices that can be used
    :param margin: Negative samples (label == 0) should have a distance of at least the margin value.
    :param size_average: Average by the size of the mini-batch.

    Example::

        loss_fn = test_ContrastiveLoss()
        pull_together = loss_fn([emb_a, emb_b], 1.0)
        push_apart = loss_fn([emb_a, emb_c], 0.0)
    """

    def __init__(self, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, margin: float = 0.5, size_average:bool = True):
        super(test_ContrastiveLoss, self).__init__()
        self.distance_metric = distance_metric
        self.margin = margin
        self.size_average = size_average

    def get_config_dict(self):
        """Return the loss settings as a plain dict, naming the metric via SiameseDistanceMetric when it matches."""
        distance_metric_name = self.distance_metric.__name__
        for name, value in vars(SiameseDistanceMetric).items():
            if value == self.distance_metric:
                distance_metric_name = "SiameseDistanceMetric.{}".format(name)
                break

        return {'distance_metric': distance_metric_name, 'margin': self.margin, 'size_average': self.size_average}

    def forward(self, reps, labels):
        """
        :param reps: pair ``(rep_anchor, rep_other)`` of embedding tensors
        :param labels: 1 to pull the pair together, 0 to push it apart; a scalar or
            anything that broadcasts against the distances
        :return: mean of the per-pair losses if ``size_average`` is set, otherwise their sum
        """
        rep_anchor, rep_other = reps
        distances = self.distance_metric(rep_anchor, rep_other)
        # Place the labels on the same device as the distances. as_tensor also
        # avoids the copy-construct warning torch.tensor() emits for tensor input.
        labels = torch.as_tensor(labels, device=distances.device)

        # 0.5 * (y * d^2 + (1 - y) * max(margin - d, 0)^2)
        pair_losses = 0.5 * (labels.float() * distances.pow(2) + (1 - labels).float() * F.relu(self.margin - distances).pow(2))
        return pair_losses.mean() if self.size_average else pair_losses.sum()

# Load a SentenceTransformer checkpoint from a local output directory and move it to the GPU.
# This rebinds the module-level `model_encoder` that Disentangle_Model.forward reads.
model_encoder = SentenceTransformer("output/train_bi-encoder-mnrl-distilbert-base-uncased-2023-08-23_21-32-34/72340").cuda()
class Disentangle_Model(nn.Module):
    """
    Trainable head on top of the module-level ``model_encoder`` (zero-padded variant).

    The encoder is only called under ``torch.no_grad()``; its outputs go through one
    768 -> 768 linear layer, and each response embedding is then split in half into a
    "non-robust" part (first half) and a "robust" part (second half). Every half is
    padded with 384 zeros so it has the same width as the context embedding.
    ``forward`` returns the sum of two ranking losses, four contrastive losses and
    four classification losses.

    ``max_seq_length`` and ``batch_size`` are stored but not read inside this class.
    """

    def __init__(self, max_seq_length=1024, batch_size=32, num_labels=3):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size
        self.num_labels = num_labels
        self.linear = nn.Linear(768, 768)
        self.second_mutiple_negatives_ranking_loss = Second_MultipleNegativesRankingLoss()
        self.contrastive_loss = test_ContrastiveLoss()
        # Classifier input = context embedding (768) + one zero-padded half (768).
        self.classifier = Classifier(1536, self.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, batch):
        """
        :param batch: indexable with three entries — contexts, positive responses and
            negative responses — each handed to ``model_encoder.encode``
            (presumably lists of strings; confirm against the caller)
        :return: the summed training loss
        """
        # Encode without tracking gradients; the encoder receives no updates from this loss.
        with torch.no_grad():
            encoded = [torch.tensor(model_encoder.encode(batch[k])).cuda() for k in range(3)]
        context_embedding, pos_response_embedding, neg_response_embedding = [
            self.linear(item) for item in encoded
        ]

        # Ranking loss on the full projected embeddings.
        contrastive_loss_1 = self.second_mutiple_negatives_ranking_loss(
            [context_embedding, pos_response_embedding, neg_response_embedding]
        )

        # Step 3: split the linear output into non-robust / robust halves.
        pos_non_robust, pos_robust = pos_response_embedding.chunk(2, -1)
        neg_non_robust, neg_robust = neg_response_embedding.chunk(2, -1)

        # One zero block, sized from the positive batch, is appended to every half.
        zero_pad = torch.zeros([pos_robust.size(0), 384]).cuda()
        pos_robust, pos_non_robust, neg_robust, neg_non_robust = [
            torch.cat([half, zero_pad], dim=1)
            for half in (pos_robust, pos_non_robust, neg_robust, neg_non_robust)
        ]

        # Label 0.0 pushes a pair apart, label 1.0 pulls it together.
        pos_inside_contrastive_loss = self.contrastive_loss([pos_non_robust, pos_robust], 0.0)
        neg_inside_contrastive_loss = self.contrastive_loss([neg_non_robust, neg_robust], 0.0)
        outside_robust_contrastive_loss = self.contrastive_loss([pos_robust, neg_robust], 0.0)
        outside_non_robust_contrastive_loss = self.contrastive_loss([pos_non_robust, neg_non_robust], 1.0)

        # Second ranking loss: context against the padded robust halves only.
        contrastive_loss_2 = self.second_mutiple_negatives_ranking_loss(
            [context_embedding, pos_robust, neg_robust]
        )

        # Step 4: classify (context, half) pairs.
        # Class 1 = positive robust, class 0 = negative robust, class 2 = any non-robust half.
        classifier_inputs = (
            (pos_robust, 1),
            (pos_non_robust, 2),
            (neg_robust, 0),
            (neg_non_robust, 2),
        )
        classification_losses = [
            self.classifier(torch.cat([context_embedding, half], dim=1), target)[1]
            for half, target in classifier_inputs
        ]

        loss = (
            contrastive_loss_1
            + contrastive_loss_2
            + pos_inside_contrastive_loss
            + neg_inside_contrastive_loss
            + outside_robust_contrastive_loss
            + outside_non_robust_contrastive_loss
        )
        for term in classification_losses:
            loss = loss + term
        return loss

    def save(self, path):
        """Write this module's state dict to ``path``."""
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load a state dict previously written by ``save`` from ``path``."""
        self.load_state_dict(torch.load(path))

In [None]:
# Build the model with its default constructor arguments, then restore trained weights.
disentangle_model = Disentangle_Model()
# The checkpoint file is indexed as a dict; the weights sit under "model_state_dict".
state_dict = torch.load("./disentangle_model/checkpoint_distill_bert_pad.bin")
disentangle_model.load_state_dict(state_dict["model_state_dict"])

In [None]:
# Print the module's repr for a quick look at its sub-modules.
print(disentangle_model)

In [None]:
class MLP(nn.Module):
    """One linear layer used as a classifier, returning both scores and loss.

    :param hidden_size: width of the input features
    :param num_labels: number of output classes
    """

    def __init__(self, hidden_size, num_labels):
        super().__init__()
        self.num_labels = num_labels
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_states, labels):
        """
        :param hidden_states: input features fed to the linear layer
        :param labels: cross-entropy targets for ``hidden_states``
        :return: tuple ``(logits, loss)``
        """
        scores = self.classifier(hidden_states)
        criterion = CrossEntropyLoss()
        return scores, criterion(scores, labels)
# Two-class MLP over 1536-wide inputs, placed on the GPU.
classification_model = MLP(1536, 2).cuda()
classification_model_path = "./disentangle_model/classification_checkpoint_pad.bin"
# The checkpoint file is indexed as a dict; the weights sit under "model_state_dict".
state_dict = torch.load(classification_model_path)
classification_model.load_state_dict(state_dict["model_state_dict"])
# Softmax over the last dimension; presumably applied to the classifier logits later — verify at the call site.
m = nn.Softmax(dim=-1)

In [None]:
import json
import numpy as np
from openTSNE import TSNE
from openTSNE.affinity import PerplexityBasedNN
from openTSNE import initialization
# from openTSNE.callbacks import ErrorLogger
import utils
import matplotlib.pyplot as plt
import matplotlib
from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, paired_manhattan_distances

In [None]:
# Read the JSON-lines test file and keep only the record at index 2.
with open ("test.json", "r") as f:
    dev_lines = f.readlines()
dev_examples, context_list = [], []
line = dev_lines[2]
sample = json.loads(line)

# Context turns are chained as "[CLS] turn [SEP] turn [SEP] ...", then stripped.
context = ("[CLS] " + "".join(turn + " [SEP] " for turn in sample["context"])).strip()

# Every response is wrapped in its own "[CLS] ... [SEP]" markers.
pos_list = ["[CLS] " + reply + " [SEP]" for reply in sample["positive_responses"]]
neg_list = ["[CLS] " + reply + " [SEP]" for reply in sample["adversarial_negative_responses"]]

# Labels: 0 for the context, 1 per positive, 2 per adversarial negative.
y_list = [0] + [1] * len(pos_list) + [2] * len(neg_list)

In [None]:
# Dump the current sample: the context, then the indexed positive and negative responses.
print("context:", context, "\n", sep="\n")
# g = "[CLS] " + "How delightful, considering your class usually induces the enthusiasm of a sloth on a lazy Sunday." + " [SEP]"
# neg_list.append(g)
# y_list.append(3)

print("pos:", *(f"{idx} {text}" for idx, text in enumerate(pos_list)), "\n", sep="\n")

print("neg:", *(f"{idx} {text}" for idx, text in enumerate(neg_list)), sep="\n")

In [None]:
def _robust_distance(context_embedding, text):
    """Cosine distance between a context and the second half of a response projection.

    ``text`` is encoded with ``model_encoder``, projected through
    ``disentangle_model.linear`` and split into two chunks; the second chunk
    (named "robust" in this notebook) is extended with 384 zeros and compared
    with ``context_embedding`` via ``paired_cosine_distances``.
    Must be called inside ``torch.no_grad()``.
    """
    projected = disentangle_model.linear(torch.tensor(model_encoder.encode(text)))
    _non_robust, robust = projected.chunk(2, -1)
    robust = torch.cat([robust, torch.zeros([384,])]).reshape(1, -1)
    return paired_cosine_distances(context_embedding, robust)


with open ("test.json", "r") as f:
    dev_lines = f.readlines()
dev_examples = []
context_list = []

# One margin per candidate response, accumulated over every test record.
# A margin is mean(reference distances - candidate distance).
d_pos_list = []
d_neg_list = []
for line in dev_lines:
    sample = json.loads(line)
    ss = "[CLS] "
    for c in sample["context"]:
        ss = ss + c + " [SEP] "
    context = ss.strip()

    positives = ["[CLS] " + p + " [SEP]" for p in sample["positive_responses"]]
    # The first three positives act as references; the rest are scored.
    reference_list = positives[:3]
    pos_list = positives[3:]
    neg_list = ["[CLS] " + n + " [SEP]" for n in sample["adversarial_negative_responses"]]

    with torch.no_grad():
        context_embedding = model_encoder.encode(context)
        context_embedding = disentangle_model.linear(torch.tensor(context_embedding))
        context_embedding = context_embedding.reshape(1, -1)

        d_list = [_robust_distance(context_embedding, ref).item() for ref in reference_list]
        reference_distances = np.asarray(d_list)

        for pos in pos_list:
            d = _robust_distance(context_embedding, pos)
            d_pos_list.append(np.mean(reference_distances - d))

        for neg in neg_list:
            d_neg = _robust_distance(context_embedding, neg)
            # Fixed: the margin previously reused ``d`` (the last positive's
            # distance) instead of this negative's own distance.
            d_neg_list.append(np.mean(reference_distances - d_neg))


In [None]:
# Count the strictly positive entries of d_pos_list and show them next to the total.
count = sum(1 for dis in d_pos_list if dis > 0)
print(count, len(d_pos_list))

In [None]:
# Count the strictly negative entries of d_neg_list and show them next to the total.
count = sum(1 for dis in d_neg_list if dis < 0)
print(count, len(d_neg_list))

In [None]:
# Single-sample check: compare the distances of three positives with the
# distance of one negative, using whatever `context`, `pos_list` and
# `neg_list` currently hold in the notebook.
with torch.no_grad():
    context_embedding = model_encoder.encode(context)
    context_embedding = disentangle_model.linear(torch.tensor(context_embedding))
    # Row-vector shape for paired_cosine_distances.
    context_embedding = context_embedding.reshape(1,-1)
    d_list = []
    for pos in pos_list[:3]:
        pos_embedding = model_encoder.encode(pos)
        pos_response_embedding = disentangle_model.linear(torch.tensor(pos_embedding))
        # First chunk is named non-robust, second chunk robust.
        pos_non_robust_embedding, pos_robust_embedding = pos_response_embedding.chunk(2, -1)
        # Extend the robust chunk with 384 zeros before comparing it with the context.
        pos_robust_pad = torch.zeros([384,])
        pos_robust_embedding = torch.cat([pos_robust_embedding, pos_robust_pad])
        pos_robust_embedding = pos_robust_embedding.reshape(1,-1)
        d = paired_cosine_distances(context_embedding, pos_robust_embedding)
        d_list.append(d.item())
    print(d_list)
    # NOTE(review): hard-coded index; raises IndexError when fewer than six negatives exist.
    neg = neg_list[5]
    neg_embedding = model_encoder.encode(neg)
    neg_response_embedding = disentangle_model.linear(torch.tensor(neg_embedding))
    neg_non_robust_embedding, neg_robust_embedding = neg_response_embedding.chunk(2, -1)
    neg_robust_pad = torch.zeros([384,])
    neg_robust_embedding = torch.cat([neg_robust_embedding, neg_robust_pad])
    neg_robust_embedding = neg_robust_embedding.reshape(1,-1)

    d_neg = paired_cosine_distances(context_embedding, neg_robust_embedding)
    print(d_neg)
    # Mean of (positive distance - negative distance) over the collected positives.
    print(np.mean(d_list-d_neg))

In [None]:
def _robust_distance(context_embedding, text):
    """Cosine distance between a context and the second half of a response projection.

    ``text`` is encoded with ``model_encoder``, projected through
    ``disentangle_model.linear`` and split into two chunks; the second chunk
    (named "robust" in this notebook) is extended with 384 zeros and compared
    with ``context_embedding`` via ``paired_cosine_distances``.
    Must be called inside ``torch.no_grad()``.
    """
    projected = disentangle_model.linear(torch.tensor(model_encoder.encode(text)))
    _non_robust, robust = projected.chunk(2, -1)
    robust = torch.cat([robust, torch.zeros([384,])]).reshape(1, -1)
    return paired_cosine_distances(context_embedding, robust)


with open ("train.json", "r") as f:
    dev_lines = f.readlines()
dev_examples = []
y_list = []

# Context-to-response cosine distances over the whole file, split by response kind.
# (The former `line = dev_lines[1]` was dead code and crashed on one-line files.)
pos_dis_list = []
neg_dis_list = []

for line in dev_lines:
    context_list = []
    sample = json.loads(line)
    ss = "[CLS] "
    for c in sample["context"]:
        ss = ss + c + " [SEP] "
    context = ss.strip()
    y_list.append(0)
    pos_list = ["[CLS] " + p + " [SEP]" for p in sample["positive_responses"]]
    neg_list = ["[CLS] " + n + " [SEP]" for n in sample["adversarial_negative_responses"]]

    with torch.no_grad():
        context_embedding = model_encoder.encode(context)
        context_embedding = disentangle_model.linear(torch.tensor(context_embedding))
        context_embedding = context_embedding.reshape(1, -1)
        for pos in pos_list:
            d = _robust_distance(context_embedding, pos)
            pos_dis_list.append(d.item())
        for neg in neg_list:
            d = _robust_distance(context_embedding, neg)
            neg_dis_list.append(d.item())

In [None]:
# Mean distance of the positive list, then of the negative list.
print(np.mean(pos_dis_list), np.mean(neg_dis_list))

In [None]:
# Count the positive distances that do not exceed 1.0 and show them next to the total.
count = sum(1 for dis in pos_dis_list if dis <= 1.0)
print(count, len(pos_dis_list))

In [None]:
# Hand-typed ratio; presumably the count/total printed by the preceding cell -- confirm.
4531/5710


In [None]:
# Count the negative distances above 0.9 and show them next to the total.
# Fixed: the total printed here used to be len(pos_dis_list).
count = sum(1 for dis in neg_dis_list if dis > 0.9)
print(count, len(neg_dis_list))

In [None]:
# Hand-typed ratio; presumably the count/total printed by the preceding cell -- confirm.
5141/5710

In [None]:
# Sum of the two hand-typed counts over twice the hand-typed total.
(4531+5141)/(5710*2)

In [None]:
# Run the projection pipeline over the current sample. Nothing is collected:
# each loop overwrites its variables, so only the values from the last
# positive and the last negative remain afterwards.
with torch.no_grad():
    context_embedding = model_encoder.encode(context)
    context_embedding = disentangle_model.linear(torch.tensor(context_embedding))
    for pos in pos_list:
        pos_embedding = model_encoder.encode(pos)
        pos_response_embedding = disentangle_model.linear(torch.tensor(pos_embedding))
        # First chunk is named non-robust, second chunk robust.
        pos_non_robust_embedding, pos_robust_embedding = pos_response_embedding.chunk(2, -1)
        # The robust chunk is extended with 384 zeros.
        pos_robust_pad = torch.zeros([384,])
        pos_robust_embedding = torch.cat([pos_robust_embedding, pos_robust_pad])

    for neg in neg_list:
        neg_embedding = model_encoder.encode(neg)
        neg_response_embedding = disentangle_model.linear(torch.tensor(neg_embedding))
        # Unlike the positive branch, the negative robust chunk is not padded here.
        neg_non_robust_embedding, neg_robust_embedding = neg_response_embedding.chunk(2, -1)

In [None]:
# Shape of whatever `pos_robust_embedding` currently holds in the notebook.
print(pos_robust_embedding.shape)

x:原编码 x1：经过线性层之后的编码 x2：拆分之后的编码。 y：源类型 y1:拆分之后的类型（0：neg robust， 1：pos robust， 2：pos non robust, 3: neg_non_robust, 4:context_robust,  5:context_non_robust）

In [None]:
# Collect encodings, labels, classifier scores and distances for the current
# sample (`context`, `pos_list`, `neg_list`) and stack them into arrays.
with torch.no_grad():
    context_embedding = model_encoder.encode(context)
    print(context_embedding.shape)
    x_list = []
    y_list = []
    x_list.append(context_embedding)
    y_list.append(0)
    # x1_list = []
    x1_list = []
    context_embedding = disentangle_model.linear(torch.tensor(context_embedding))
    # context_non_robust_embedding, context_robust_embedding = context_embedding.chunk(2,-1)
    x1_list.append(context_embedding.detach().numpy())
    x2_list = []
    y1_list = []
    x2_list.append(context_embedding.detach().numpy())
    # Label 4 marks the context row in y1.
    y1_list.append(4)
    dis_list = []
    score_list = []
    # x2_list = [context_robust_embedding.detach().numpy()]
    # y1_list = [0]
    # print(context_embedding.detach().numpy().shape)
    for pos in pos_list:
        pos_embedding = model_encoder.encode(pos)
        x_list.append(pos_embedding)
        y_list.append(1)
        pos_response_embedding = disentangle_model.linear(torch.tensor(pos_embedding))
        x1_list.append(pos_response_embedding.detach().numpy())

        # First chunk is named non-robust, second chunk robust; both get 384 zeros appended.
        pos_non_robust_embedding, pos_robust_embedding = pos_response_embedding.chunk(2, -1)
        pos_pad = torch.zeros([384])
        pos_robust_embedding = torch.cat([pos_robust_embedding, pos_pad])
        pos_non_robust_embedding = torch.cat([pos_non_robust_embedding, pos_pad])

        # NOTE(review): context_embedding is never reshaped in this cell, so both
        # operands look 1-D and dim=1 would be out of range -- confirm.
        hidden_state = torch.cat([context_embedding, pos_robust_embedding], dim=1)
        # print(hidden_state.shape)
        labels = torch.tensor([1])
        # NOTE(review): one definition of classification_model in this file is
        # moved to CUDA while these tensors are built on CPU -- confirm devices.
        outputs = classification_model(hidden_state, labels)
        logits = outputs[0]
        # print(logits)
        outputs = m(logits)
        predict = outputs
        predict = predict.detach().cpu().numpy()

        # Value at index [0][1] of the softmax output, rounded to 3 decimals.
        score_list.append(round(predict[0][1], 3))
        # sample["score"] = str(round(predict[0][1], 3))
        # sample["distance"] = str(d[0])
        x2_list.append(pos_robust_embedding.detach().numpy())
        y1_list.append(1)

        x2_list.append(pos_non_robust_embedding.detach().numpy())
        y1_list.append(2)
        pos_robust_embedding = pos_robust_embedding.reshape(1,-1)
        # NOTE(review): context_embedding is not reshaped to (1, -1) here, unlike
        # the other distance cells -- confirm this call accepts it.
        d = paired_cosine_distances(context_embedding, pos_robust_embedding)
        dis_list.append(d.item())
        # pos_non_robust_list.append(pos_non_robust_embedding.detach().numpy())
        # y1_list.append(2)
        # x1_list.append(pos_response_embedding.detach().numpy())

    for neg in neg_list:

        # NOTE(review): the raw and projected negative encodings are not added to
        # x_list / y_list / x1_list, unlike the positive branch.
        neg_embedding = model_encoder.encode(neg)
        neg_response_embedding = disentangle_model.linear(torch.tensor(neg_embedding))
        neg_non_robust_embedding, neg_robust_embedding = neg_response_embedding.chunk(2, -1)
        neg_robust_pad = torch.zeros([384,])
        neg_robust_embedding = torch.cat([neg_robust_embedding, neg_robust_pad])
        hidden_state = torch.cat([context_embedding, neg_robust_embedding], dim=1)
        # print(hidden_state.shape)
        labels = torch.tensor([0])
        outputs = classification_model(hidden_state, labels)
        logits = outputs[0]
        # print(logits)
        outputs = m(logits)
        predict = outputs
        predict = predict.detach().cpu().numpy()

        score_list.append(round(predict[0][1], 3))
        # sample["score"] = str(round(predict[0][1], 3))
        # sample["distance"] = str(d[0])
        # NOTE(review): negatives are labelled 1 and 2 here, the same as the
        # positive rows; the legend note before this cell uses 0 and 3 -- confirm.
        x2_list.append(neg_robust_embedding.detach().numpy())
        y1_list.append(1)

        # NOTE(review): the non-robust chunk is appended without zero padding,
        # unlike the positive branch.
        x2_list.append(neg_non_robust_embedding.detach().numpy())
        y1_list.append(2)
        neg_robust_embedding = neg_robust_embedding.reshape(1,-1)
        d = paired_cosine_distances(context_embedding, neg_robust_embedding)
        dis_list.append(d.item())


        # NOTE(review): the same distance is computed a second time and appended
        # to neg_dis_list, which this cell does not create.
        neg_robust_embedding = neg_robust_embedding.reshape(1,-1)

        d= paired_cosine_distances(context_embedding, neg_robust_embedding)
        neg_dis_list.append(d.item())
x = np.array(x_list)  # encodings before the linear layer
y = np.array(y_list) # class labels before the linear layer
x1 = np.array(x1_list) # encodings after the linear layer
x2 = np.array(x2_list) # encodings after splitting
y1 = np.array(y1_list) # class labels after splitting
print(len(x), len(y), len(x1), len(x2), len(y1))
# with torch.no_grad():
#     context_embedding = torch.tensor(disentangle_model.model_encoder.encode(data[0]))
#     pos_response_embedding = torch.tensor(disentangle_model.model_encoder.encode(data[1]))
#     neg_response_embedding = torch.tensor(disentangle_model.model_encoder.encode(data[2]))
#     context_embedding = disentangle_model.linear(context_embedding)
#     pos_response_embedding = disentangle_model.linear(pos_response_embedding)
#     neg_response_embedding = disentangle_model.linear(neg_response_embedding)
#     pos_non_robust_embedding, pos_robust_embedding = pos_response_embedding.chunk(2, -1)
#     neg_non_robust_embedding, neg_robust_embedding = neg_response_embedding.chunk(2, -1)

In [None]:
# Collect raw and projected encodings, with class labels, for the first 60
# records of train.json.
with open ("train.json", "r") as f:
    dev_lines = f.readlines()
dev_examples = []

# line = dev_lines[1]
y_list = []

# print(line)
pos_dis_list = []
neg_dis_list = []
x_list = []
y_list = []
x1_list = []
# NOTE(review): x2_list is never appended to in this cell, so x2 ends up empty.
x2_list = []
y1_list = []
for line in dev_lines[:60]:
    pos_list = []
    neg_list = []
    sample=json.loads(line)
    # Context turns are chained as "[CLS] turn [SEP] turn [SEP] ...", then stripped.
    ss = "[CLS] "
    for c in sample["context"]:
        ss = ss + c + " [SEP] "
    context = ss.strip()
    for i, p in enumerate(sample["positive_responses"]):
        pos = "[CLS] " + p + " [SEP]"
        pos_list.append(pos)
    for j, n in enumerate(sample["adversarial_negative_responses"]):
        neg = "[CLS] " + n + " [SEP]"
        neg_list.append(neg)
    with torch.no_grad():
        context_embedding = model_encoder.encode(context)
        # print(context_embedding.shape)

        # Labels in this cell: 2 = context, 1 = positive, 0 = negative.
        x_list.append(context_embedding)
        y_list.append(2)
        # x1_list = []

        context_embedding = disentangle_model.linear(torch.tensor(context_embedding))
        # context_non_robust_embedding, context_robust_embedding = context_embedding.chunk(2,-1)
        x1_list.append(context_embedding.detach().numpy())

        # x2_list.append(context_embedding.detach().numpy())
        y1_list.append(2)

        for pos in pos_list:
            pos_embedding = model_encoder.encode(pos)
            x_list.append(pos_embedding)
            y_list.append(1)
            pos_response_embedding = disentangle_model.linear(torch.tensor(pos_embedding))
            # x_list.append(pos_response_embedding.detach().numpy())
            # y_list.append(1)
            # First chunk is named non-robust, second chunk robust; both get 384 zeros appended.
            pos_non_robust_embedding, pos_robust_embedding = pos_response_embedding.chunk(2, -1)
            pos_pad = torch.zeros([384])
            pos_robust_embedding = torch.cat([pos_robust_embedding, pos_pad])
            # The padded non-robust chunk is computed but not stored anywhere.
            pos_non_robust_embedding = torch.cat([pos_non_robust_embedding, pos_pad])
            # For responses, x1 receives the padded robust chunk, not the full projection.
            x1_list.append(pos_robust_embedding.detach().numpy())
            y1_list.append(1)

        for neg in neg_list:

            neg_embedding = model_encoder.encode(neg)
            x_list.append(neg_embedding)
            y_list.append(0)
            neg_response_embedding = disentangle_model.linear(torch.tensor(neg_embedding))
            neg_non_robust_embedding, neg_robust_embedding = neg_response_embedding.chunk(2, -1)
            neg_robust_pad = torch.zeros([384,])
            neg_robust_embedding = torch.cat([neg_robust_embedding, neg_robust_pad])
            # hidden_state = torch.cat([context_embedding, neg_robust_embedding], dim=1)
            # sample["score"] = str(round(predict[0][1], 3))
            # sample["distance"] = str(d[0])
            x1_list.append(neg_robust_embedding.detach().numpy())
            y1_list.append(0)

x = np.array(x_list)  # encodings before the linear layer
y = np.array(y_list) # class labels before the linear layer
x1 = np.array(x1_list) # encodings after the linear layer
x2 = np.array(x2_list) # encodings after splitting
y1 = np.array(y1_list) # class labels after splitting
print(len(x), len(y), len(x1), len(x2), len(y1))
# with torch.no_grad():
#     context_embedding = torch.tensor(disentangle_model.model_encoder.encode(data[0]))
#     pos_response_embedding = torch.tensor(disentangle_model.model_encoder.encode(data[1]))
#     neg_response_embedding = torch.tensor(disentangle_model.model_encoder.encode(data[2]))
#     context_embedding = disentangle_model.linear(context_embedding)
#     pos_response_embedding = disentangle_model.linear(pos_response_embedding)
#     neg_response_embedding = disentangle_model.linear(neg_response_embedding)
#     pos_non_robust_embedding, pos_robust_embedding = pos_response_embedding.chunk(2, -1)
#     neg_non_robust_embedding, neg_robust_embedding = neg_response_embedding.chunk(2, -1)

In [None]:
# Read the JSON-lines test file and keep only the record at index 2.
with open ("test.json", "r") as f:
    dev_lines = f.readlines()
dev_examples, context_list = [], []
line = dev_lines[2]
sample = json.loads(line)

# Context turns are chained as "[CLS] turn [SEP] turn [SEP] ...", then stripped.
context = ("[CLS] " + "".join(turn + " [SEP] " for turn in sample["context"])).strip()

# Every response is wrapped in its own "[CLS] ... [SEP]" markers.
pos_list = ["[CLS] " + reply + " [SEP]" for reply in sample["positive_responses"]]
neg_list = ["[CLS] " + reply + " [SEP]" for reply in sample["adversarial_negative_responses"]]

# Labels: 0 for the context, 1 per positive, 2 per adversarial negative.
y_list = [0] + [1] * len(pos_list) + [2] * len(neg_list)

In [None]:
# Build the arrays used for the t-SNE plots from the current sample
# (`context`, `pos_list`, `neg_list`):
#   x / y                raw encoder outputs; labels 0 = context, 1 = pos, 2 = neg
#   x1                   outputs of disentangle_model.linear
#   x2 / y1              padded chunks; 1 = pos robust, 2 = pos non-robust,
#                        3 = neg non-robust, 0 = neg robust
#   x_robust / y_robust  context plus robust chunks; 0 = context, 1 = pos, 2 = neg
# NOTE(review): the context row is also labelled 0 in y1, colliding with
# "neg robust" -- confirm whether that is intended.
with torch.no_grad():
    # Fixed: the padding used to be created only inside the positive loop but
    # was also used in the negative loop, so an empty pos_list broke it.
    pos_pad = torch.zeros([384])

    context_embedding = model_encoder.encode(context)
    print(context_embedding.shape)
    x_list = [context_embedding]
    y_list = [0]

    context_embedding = disentangle_model.linear(torch.tensor(context_embedding))
    x1_list = [context_embedding.detach().numpy()]
    x2_list = [context_embedding.detach().numpy()]
    x_robust_list = [context_embedding.detach().numpy()]
    y1_list = [0]
    y_robust_list = [0]

    for pos in pos_list:
        pos_embedding = model_encoder.encode(pos)
        x_list.append(pos_embedding)
        y_list.append(1)
        pos_response_embedding = disentangle_model.linear(torch.tensor(pos_embedding))
        x1_list.append(pos_response_embedding.detach().numpy())

        # First chunk is named non-robust, second chunk robust; both get 384 zeros appended.
        pos_non_robust_embedding, pos_robust_embedding = pos_response_embedding.chunk(2, -1)
        pos_robust_embedding = torch.cat([pos_robust_embedding, pos_pad])
        pos_non_robust_embedding = torch.cat([pos_non_robust_embedding, pos_pad])

        x2_list.append(pos_robust_embedding.detach().numpy())
        y1_list.append(1)
        x_robust_list.append(pos_robust_embedding.detach().numpy())
        y_robust_list.append(1)

        x2_list.append(pos_non_robust_embedding.detach().numpy())
        y1_list.append(2)

    for neg in neg_list:
        neg_embedding = model_encoder.encode(neg)
        x_list.append(neg_embedding)
        y_list.append(2)
        neg_response_embedding = disentangle_model.linear(torch.tensor(neg_embedding))
        x1_list.append(neg_response_embedding.detach().numpy())

        neg_non_robust_embedding, neg_robust_embedding = neg_response_embedding.chunk(2, -1)
        neg_robust_embedding = torch.cat([neg_robust_embedding, pos_pad])
        neg_non_robust_embedding = torch.cat([neg_non_robust_embedding, pos_pad])

        # Same order as before: non-robust row first, then the robust row.
        x2_list.append(neg_non_robust_embedding.detach().numpy())
        y1_list.append(3)
        x2_list.append(neg_robust_embedding.detach().numpy())
        y1_list.append(0)
        x_robust_list.append(neg_robust_embedding.detach().numpy())
        y_robust_list.append(2)

x = np.array(x_list)  # encodings before the linear layer
y = np.array(y_list)  # class labels before the linear layer
x1 = np.array(x1_list)  # encodings after the linear layer
x2 = np.array(x2_list)  # encodings after splitting
y1 = np.array(y1_list)  # class labels after splitting
x_robust = np.array(x_robust_list)
y_robust = np.array(y_robust_list)
print(len(x), len(y), len(x1), len(x2), len(y1), len(x_robust), len(y_robust))

In [None]:
# Configure t-SNE (cosine metric, fixed seed) and fit it on `x`.
tsne = TSNE(perplexity=3, n_iter=200, metric="cosine", n_jobs=8, random_state=0)
embedding = tsne.fit(x)

In [None]:
# Scatter plot of the fitted embedding, coloured by the labels in `y`,
# written to fig1_1.svg.
# NOTE: this rebinds `x` to the embedding.
x = embedding
fig, ax = plt.subplots()

classes = np.unique(y)

# Pair every class with the next colour of matplotlib's default property cycle.
default_colors = matplotlib.rcParams["axes.prop_cycle"]
colors = dict(zip(classes, (entry["color"] for entry in default_colors())))

point_colors = [colors.get(label) for label in y]

ax.scatter(x[:, 0], x[:, 1], c=point_colors, marker="v", s=50)

# One square legend marker per class, filled with that class's colour.
legend_handles = [
    matplotlib.lines.Line2D(
        [], [],
        marker="s", color="w", markerfacecolor=colors[yi],
        ms=10, alpha=1, linewidth=0, markeredgecolor="k",
    )
    for yi in classes
]
legend_kwargs_ = {
    "labels": ["context", "pos", "neg"],
    "bbox_to_anchor": (1.0, 1.0),
    "frameon": False,
    "fontsize": 10,
}
ax.legend(handles=legend_handles, **legend_kwargs_)
ax.set_title("Normal")
plt.savefig("fig1_1.svg")

In [None]:
# Fit t-SNE on `x_robust` and scatter the result, coloured by `y_robust`,
# into fig1_2.svg.
tsne = TSNE(perplexity=3, n_iter=200, metric="cosine", n_jobs=8, random_state=0)
embedding = tsne.fit(x_robust)
# NOTE: this rebinds `x` to the embedding.
x = embedding
fig, ax = plt.subplots()

classes = np.unique(y_robust)

# Pair every class with the next colour of matplotlib's default property cycle.
default_colors = matplotlib.rcParams["axes.prop_cycle"]
colors = dict(zip(classes, (entry["color"] for entry in default_colors())))

point_colors = [colors.get(label) for label in y_robust]

ax.scatter(x[:, 0], x[:, 1], c=point_colors, marker="v", s=50)

# One square legend marker per class, filled with that class's colour.
legend_handles = [
    matplotlib.lines.Line2D(
        [], [],
        marker="s", color="w", markerfacecolor=colors[yi],
        ms=10, alpha=1, linewidth=0, markeredgecolor="k",
    )
    for yi in classes
]
legend_kwargs_ = {
    "labels": ["context", "pos_robust", "neg_robust", ],
    "bbox_to_anchor": (0.98, 0.2),
    "frameon": False,
    "fontsize": 10,
}
ax.legend(handles=legend_handles, **legend_kwargs_)
ax.set_title("Disentangled")
plt.savefig("fig1_2.svg")

上图为经过线性层前的编码的可视化

通过对编码前的response和编码后的response进行可视化，发现仍能有很好的区分

x:原编码 x1：经过线性层之后的编码 x2：拆分之后的编码。 y：源类型 y1:拆分之后的类型（0：neg robust， 1：pos robust， 2：pos non robust, 3: neg_non_robust）

In [None]:
tsne = TSNE(
    perplexity=3,ax.set_title("example")


In [None]:
from enum import Enum
from typing import Iterable, Dict
import torch.nn.functional as F
from torch import nn, Tensor
from sentence_transformers.SentenceTransformer import SentenceTransformer
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
class MLP(nn.Module):
    """Single linear layer used as a classification head.

    Args:
        hidden_size: Width of the input feature vectors.
        num_labels: Number of output classes.
    """

    def __init__(self, hidden_size, num_labels):
        super(MLP, self).__init__()
        # The attribute name ``classifier`` is part of the state-dict keys.
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, hidden_states, labels=None):
        """Score ``hidden_states`` and, when labels are given, compute the loss.

        Args:
            hidden_states: Float tensor whose last dimension is ``hidden_size``.
            labels: Targets accepted by ``CrossEntropyLoss``, or ``None`` to
                skip the loss (pure inference).

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` when ``labels`` is ``None``.
        """
        logits = self.classifier(hidden_states)
        if labels is None:
            # Inference callers no longer need to fabricate dummy labels.
            return logits, None

        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits, labels)
        return logits, loss

In [None]:
# Turn every train.json record into (context, response, label) triples:
# label 1 for positive responses, 0 for adversarial negatives.
with open ("train.json", "r") as f:
    train_lines = f.readlines()
context_list, pos_list, neg_list = [], [], []
train_classification_examples = []
for line in train_lines:
    sample = json.loads(line)
    # Context turns are chained as "[CLS] turn [SEP] turn [SEP] ...", then stripped.
    context = ("[CLS] " + "".join(turn + " [SEP] " for turn in sample["context"])).strip()
    train_classification_examples.extend(
        (context, "[CLS] " + reply + " [SEP]", 1)
        for reply in sample["positive_responses"]
    )
    train_classification_examples.extend(
        (context, "[CLS] " + reply + " [SEP]", 0)
        for reply in sample["adversarial_negative_responses"]
    )
# model = SentenceTransformer("output/train_bi-encoder-mnrl-distilbert-base-uncased-2023-08-23_21-32-34/72340")

In [None]:
# Turn every test.json record into (context, response, label) triples:
# label 1 for positive responses, 0 for adversarial negatives.
with open ("test.json", "r") as f:
    test_lines = f.readlines()
context_list, pos_list, neg_list = [], [], []
test_classification_examples = []
for line in test_lines:
    sample = json.loads(line)
    # Context turns are chained as "[CLS] turn [SEP] turn [SEP] ...", then stripped.
    context = ("[CLS] " + "".join(turn + " [SEP] " for turn in sample["context"])).strip()
    test_classification_examples.extend(
        (context, "[CLS] " + reply + " [SEP]", 1)
        for reply in sample["positive_responses"]
    )
    test_classification_examples.extend(
        (context, "[CLS] " + reply + " [SEP]", 0)
        for reply in sample["adversarial_negative_responses"]
    )

In [None]:
# Dataset sizes, then shuffled 1024-example loaders over both splits.
num_train_data = len(train_classification_examples)
num_test_data = len(test_classification_examples)
train_classification_dataloader = DataLoader(
    train_classification_examples, batch_size=1024, shuffle=True
)
test_classification_dataloader = DataLoader(
    test_classification_examples, batch_size=1024, shuffle=True
)

In [None]:
# Fresh 2-class head with 1152 inputs.
# NOTE(review): presumably 1152 = full context projection + one response chunk,
# matching the concatenation in the training cell -- confirm.
classification_model = MLP(1152, 2)

In [None]:
import os


def _classifier_inputs(batch):
    """Build classifier features for one ``(contexts, responses, labels)`` batch.

    Both text fields are encoded with ``model_encoder`` and projected through
    ``disentangle_model.linear``. The full context projection is concatenated
    along dim 1 with the second chunk (named "robust" in this notebook) of the
    response projection. No gradients are tracked.
    """
    with torch.no_grad():
        context_embedding = torch.tensor(model_encoder.encode(batch[0]))
        context_embedding = disentangle_model.linear(context_embedding)
        response_embedding = torch.tensor(model_encoder.encode(batch[1]))
        response_embedding = disentangle_model.linear(response_embedding)
        _non_robust_embedding, robust_embedding = response_embedding.chunk(2, -1)
        return torch.cat([context_embedding, robust_embedding], dim=1)


# Train the classification head for `num_epoch` epochs, evaluate on the test
# loader after each epoch and checkpoint whenever the accuracy improves.
num_epoch = 30
load_checkpoint_path = "./disentangle_classifier/"
os.makedirs(load_checkpoint_path, exist_ok=True)
optimizer = torch.optim.Adam(classification_model.parameters(), lr=1e-5)

# Fixed: the best accuracy is tracked across epochs. It used to be reset to 0
# at the start of every epoch, so every epoch overwrote the "best" checkpoint.
accu = 0
for epoch in range(num_epoch):
    for train_data in train_classification_dataloader:
        labels = train_data[2]
        hidden_state = _classifier_inputs(train_data)

        # Fixed: clear old gradients; without this they accumulated across steps.
        optimizer.zero_grad()
        outputs = classification_model(hidden_state, labels)
        loss = outputs[1]
        logits = outputs[0]
        loss.backward()
        optimizer.step()

    total_num_right = 0
    with torch.no_grad():
        for test_data in test_classification_dataloader:
            labels = test_data[2]
            hidden_state = _classifier_inputs(test_data)
            outputs = classification_model(hidden_state, labels)
            logits = outputs[0]
            score = logits.argmax(dim=-1)
            num_right = ((score == labels).float()).sum()
            total_num_right += num_right
    right_rate = total_num_right/num_test_data
    print(right_rate)
    if right_rate > accu:
        accu = right_rate
        state = {
            "epoch": epoch,
            "model_state_dict": classification_model.state_dict(),
            "best_accu": accu,
        }
        torch.save(state, load_checkpoint_path + "checkpoint_v3.bin")

In [None]:
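# 1. First run contrastive learning on the text, and add one linear layer on top of the sentence-transformer output.
# 2. Run contrastive learning once more on the linear layer's output so that it stays close to the sentence-transformer output;
#    the sentence transformer is frozen at this point, and it stays frozen in all later steps as well.
# 3. Split the linear layer's output into a robust and a non-robust part, and apply contrastive learning:
#    3.1 contrastive learning inside one response: push robust and non-robust apart
#    3.2 contrastive learning across different responses: pull robust and robust together
#  For this step, set up a classifier.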
import json
import sys
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
import torch
from torch.utils.data import DataLoader
from sentence_transformers import InputExample, SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample
import logging
from datetime import datetime
import gzip
import os
import tarfile
from collections import defaultdict
from torch.utils.data import IterableDataset
from tqdm import tqdm
from torch.utils.data import Dataset
import random
import pickle
import argparse
from enum import Enum
# import torch.nn as nn
from torch import nn, Tensor
import torch.nn.functional as F
from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, paired_manhattan_distances
import datetime

def log(message, logfile="app.log"):
    """
    Append a message with a local-time timestamp to a specified logfile.

    Each call adds one line of the form ``YYYY-MM-DD HH:MM:SS - message``.

    :param message: Message to be logged
    :param logfile: File to which the log will be written (default: app.log)
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Explicit UTF-8 so non-ASCII messages do not depend on the platform's
    # default encoding.
    with open(logfile, "a", encoding="utf-8") as log_file:
        log_file.write(f"{timestamp} - {message}\n")

class Classifier(nn.Module):
    """Linear classification head whose target is broadcast over the batch.

    Args:
        hidden_size: Width of the input feature vectors.
        num_labels: Number of output classes.
    """

    def __init__(self, hidden_size, num_labels):
        super(Classifier, self).__init__()
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, hidden_states, labels):
        """Score ``hidden_states`` and compute the cross-entropy loss.

        Args:
            hidden_states: Float tensor with the batch on dim 0 and
                ``hidden_size`` features on the last dim.
            labels: A single class index (number or 0-d / 1-element tensor)
                applied to every row, or one class index per row.

        Returns:
            ``(logits, loss)``.
        """
        n = hidden_states.shape[0]
        logits = self.classifier(hidden_states)

        # Fixed: build the targets on the logits' device instead of forcing
        # ``.cuda()``, which failed on CPU and for tensor labels on the GPU.
        labels = torch.as_tensor(labels, device=logits.device).long()
        labels = labels.expand(n).contiguous()

        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits, labels)
        return logits, loss

class MLP(nn.Module):
    """Single linear layer used as a classification head.

    Args:
        hidden_size: Width of the input feature vectors.
        num_labels: Number of output classes.
    """

    def __init__(self, hidden_size, num_labels):
        super(MLP, self).__init__()
        # The attribute name ``classifier`` is part of the state-dict keys.
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, hidden_states, labels=None):
        """Score ``hidden_states`` and, when labels are given, compute the loss.

        Args:
            hidden_states: Float tensor whose last dimension is ``hidden_size``.
            labels: Targets accepted by ``CrossEntropyLoss``, or ``None`` to
                skip the loss (pure inference).

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` when ``labels`` is ``None``.
        """
        logits = self.classifier(hidden_states)
        if labels is None:
            # Inference callers no longer need to fabricate dummy labels.
            return logits, None

        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits, labels)
        return logits, loss

class SiameseDistanceMetric(Enum):
    """
    The metric for the contrastive loss

    Each attribute is a two-argument callable that computes a distance from
    ``x`` and ``y``.
    NOTE(review): the values are plain lambdas, which ``Enum`` presumably keeps
    as ordinary callables rather than enum members -- confirm before relying on
    iteration or ``.value``.
    """
    # Pairwise distance with p=2.
    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
    # Pairwise distance with p=1.
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
    # One minus the cosine similarity.
    COSINE_DISTANCE = lambda x, y: 1-F.cosine_similarity(x, y)

class Second_MultipleNegativesRankingLoss(nn.Module):
    """
        In-batch softmax ranking loss computed on already-encoded embeddings.

        ``forward`` takes a list ``reps`` of embedding batches. ``reps[0]`` holds the anchors
        a_1..a_n; every remaining entry holds candidates and all of them are concatenated, so
        with ``reps = [a, p]`` the candidates are p_1..p_n and with ``reps = [a, p, neg]`` they
        are p_1..p_n followed by the hard negatives neg_1..neg_n.

        Anchor a_i is scored against every candidate and cross-entropy is applied with target
        index i, i.e. p_i is the positive and every other p_j (j != i) and every neg_j acts as
        a negative. The loss therefore minimises the negative log-likelihood of the
        softmax-normalised scores.

        This variant takes no model: the caller computes the embeddings and passes them in.

        For more information, see: https://arxiv.org/pdf/1705.00652.pdf
        (Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4)
    """
    def __init__(self, scale: float = 20.0, similarity_fct = util.cos_sim):
        """
        :param scale: Output of similarity function is multiplied by scale value
        :param similarity_fct: similarity function between sentence embeddings. By default, cos_sim. Can also be set to dot product (and then set scale to 1)
        """
        super(Second_MultipleNegativesRankingLoss, self).__init__()
        self.scale = scale
        self.similarity_fct = similarity_fct
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, reps):
        """
        :param reps: List of embedding batches ``[anchors, positives, *hard_negatives]``
        :return: Scalar cross-entropy loss
        """
        anchors = reps[0]
        candidates = torch.cat(reps[1:])
        scores = self.similarity_fct(anchors, candidates) * self.scale
        # Row i must rank candidate i (its own positive) first.
        targets = torch.arange(len(scores), dtype=torch.long, device=scores.device)
        return self.cross_entropy_loss(scores, targets)

    def get_config_dict(self):
        """Return the hyper-parameters of this loss as a plain dict."""
        return {'scale': self.scale, 'similarity_fct': self.similarity_fct.__name__}

class test_ContrastiveLoss(nn.Module):
    """
    Contrastive loss on already-encoded embeddings. Expects two embedding batches and a label of
    either 0 or 1. If the label == 1, then the distance between the two embeddings is reduced.
    If the label == 0, then the distance between the embeddings is increased up to the margin.

    Further information: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    This variant takes no model: the caller computes the embeddings and passes them to ``forward``.

    :param distance_metric: Function that returns a distance between two embeddings. The class SiameseDistanceMetric contains pre-defined metrics that can be used
    :param margin: Negative samples (label == 0) should have a distance of at least the margin value.
    :param size_average: Average by the size of the mini-batch (otherwise the per-pair losses are summed).
    """

    def __init__(self, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, margin: float = 0.5, size_average:bool = True):
        super(test_ContrastiveLoss, self).__init__()
        self.distance_metric = distance_metric
        self.margin = margin
        self.size_average = size_average

    def get_config_dict(self):
        """Return the hyper-parameters of this loss as a plain dict."""
        distance_metric_name = self.distance_metric.__name__
        # Prefer the symbolic SiameseDistanceMetric name when the function is one of its attributes.
        for name, value in vars(SiameseDistanceMetric).items():
            if value == self.distance_metric:
                distance_metric_name = "SiameseDistanceMetric.{}".format(name)
                break

        return {'distance_metric': distance_metric_name, 'margin': self.margin, 'size_average': self.size_average}

    def forward(self, reps, labels):
        """
        :param reps: Pair ``(rep_anchor, rep_other)`` of embedding batches
        :param labels: A single 0/1 label for the whole batch, or a per-pair sequence/tensor of labels
        :return: Scalar loss (mean or sum of the per-pair losses, depending on ``size_average``)
        """
        rep_anchor, rep_other = reps
        distances = self.distance_metric(rep_anchor, rep_other)

        # Place the labels on the same device as the distances. ``as_tensor`` also avoids the
        # copy (and warning) that ``torch.tensor`` performs when given an existing tensor.
        labels = torch.as_tensor(labels, device=distances.device).float()

        losses = 0.5 * (labels * distances.pow(2) + (1 - labels) * F.relu(self.margin - distances).pow(2))
        return losses.mean() if self.size_average else losses.sum()

# Module-level sentence encoder loaded from a local fine-tuned checkpoint directory and moved
# to the GPU. Disentangle_Model.forward and the evaluation script below read it as a global
# and only call ``encode`` on it inside ``torch.no_grad()``.
model_encoder = SentenceTransformer("output/train_bi-encoder-mnrl-distilbert-base-uncased-2023-08-23_21-32-34/72340").cuda()
class Disentangle_Model(nn.Module):
    """
    Trainable projection on top of the module-level ``model_encoder``.

    ``forward`` encodes a (context, positive response, negative response) batch with the
    encoder (no gradients), passes each embedding through one shared linear layer, splits every
    projected response embedding into a "non-robust" first half and a "robust" second half
    (each zero-padded back to the full width), and returns the sum of ranking, contrastive and
    classification losses computed on those parts.

    :param max_seq_length: Stored on the instance; not used by the visible code
    :param batch_size: Stored on the instance; not used by the visible code
    :param num_labels: Number of classes of the internal classifier
    """

    def __init__(self, max_seq_length=1024, batch_size=32, num_labels=3):
        super(Disentangle_Model, self).__init__()
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size
        self.num_labels = num_labels
        self.linear = nn.Linear(768, 768)
        self.second_mutiple_negatives_ranking_loss = Second_MultipleNegativesRankingLoss()
        self.contrastive_loss = test_ContrastiveLoss()
        # The classifier sees [context ; padded response part], i.e. 2 * 768 features.
        self.classifier = Classifier(int(1536), self.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()

    def _encode(self, texts):
        """Embed ``texts`` with ``model_encoder`` without gradients, on this model's device."""
        with torch.no_grad():
            embeddings = torch.tensor(model_encoder.encode(texts))
        # Follow the device of the projection layer instead of hard-coding CUDA.
        return embeddings.to(self.linear.weight.device)

    @staticmethod
    def _pad_to_width(part, width):
        """Right-pad ``part`` with zeros along dim 1 until it is ``width`` wide."""
        padding = torch.zeros([part.size(0), width - part.size(-1)], dtype=part.dtype, device=part.device)
        return torch.cat([part, padding], dim=1)

    def _split_and_pad(self, embedding):
        """Split into (non_robust, robust) halves and zero-pad each back to the full width."""
        # Equal split: for the 768-wide projection each half is 384 wide and gets 384 zeros.
        non_robust, robust = embedding.chunk(2, -1)
        width = embedding.size(-1)
        return self._pad_to_width(non_robust, width), self._pad_to_width(robust, width)

    def forward(self, batch):
        """
        :param batch: Indexable with ``batch[0]`` = contexts, ``batch[1]`` = positive responses,
            ``batch[2]`` = negative responses, each in a form accepted by ``model_encoder.encode``
        :return: Scalar training loss (sum of all partial losses)
        """
        context_embedding = self.linear(self._encode(batch[0]))
        pos_response_embedding = self.linear(self._encode(batch[1]))
        neg_response_embedding = self.linear(self._encode(batch[2]))

        # Ranking loss on the full projected embeddings.
        reps = [context_embedding, pos_response_embedding, neg_response_embedding]
        contrastive_loss_1 = self.second_mutiple_negatives_ranking_loss(reps)

        # Step 3: split the linear layer's output into robust / non-robust parts and apply
        # contrastive learning to them.
        pos_non_robust_embedding, pos_robust_embedding = self._split_and_pad(pos_response_embedding)
        neg_non_robust_embedding, neg_robust_embedding = self._split_and_pad(neg_response_embedding)

        # Label 0.0 pushes a pair apart (up to the margin); label 1.0 pulls it together.
        pos_inside_contrastive_loss = self.contrastive_loss([pos_non_robust_embedding, pos_robust_embedding], 0.0)
        neg_inside_contrastive_loss = self.contrastive_loss([neg_non_robust_embedding, neg_robust_embedding], 0.0)
        outside_robust_contrastive_loss = self.contrastive_loss([pos_robust_embedding, neg_robust_embedding], 0.0)
        outside_non_robust_contrastive_loss = self.contrastive_loss([pos_non_robust_embedding, neg_non_robust_embedding], 1.0)

        # Ranking loss again, using only the robust parts of the responses.
        reps = [context_embedding, pos_robust_embedding, neg_robust_embedding]
        contrastive_loss_2 = self.second_mutiple_negatives_ranking_loss(reps)

        # Step 4: a classifier over [context ; response part].
        # Class 1 = positive robust, class 0 = negative robust, class 2 = any non-robust part.
        hidden_state = torch.cat([context_embedding, pos_robust_embedding], dim=1)
        classification_loss_1 = self.classifier(hidden_state, 1)[1]

        hidden_state = torch.cat([context_embedding, pos_non_robust_embedding], dim=1)
        classification_loss_2 = self.classifier(hidden_state, 2)[1]

        hidden_state = torch.cat([context_embedding, neg_robust_embedding], dim=1)
        classification_loss_3 = self.classifier(hidden_state, 0)[1]

        hidden_state = torch.cat([context_embedding, neg_non_robust_embedding], dim=1)
        classification_loss_4 = self.classifier(hidden_state, 2)[1]

        loss = contrastive_loss_1 + contrastive_loss_2 + pos_inside_contrastive_loss + neg_inside_contrastive_loss + outside_robust_contrastive_loss + outside_non_robust_contrastive_loss + classification_loss_1 + classification_loss_2 + classification_loss_3 + classification_loss_4
        return loss

    def save(self, path):
        """Write this module's ``state_dict`` to ``path``."""
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load a ``state_dict`` written by ``save``, mapping tensors onto this model's device."""
        self.load_state_dict(torch.load(path, map_location=self.linear.weight.device))

# Inference pass: score every (context, response) pair from train.json with the trained
# disentangle projection + MLP classifier and collect one result record per pair in `samples`.
# The records are tallied and written to disk by the statements that follow this block.
if __name__ == "__main__":
    # Softmax over the class dimension, used to turn logits into probabilities.
    m = nn.Softmax(dim=-1)
    disentangle_model = Disentangle_Model().cuda()

    load_checkpoint_path = "./disentangle_model/"

    # The checkpoint is a dict; the weights live under "model_state_dict".
    model_path = load_checkpoint_path + "checkpoint_distill_bert_pad.bin"
    state_dict = torch.load(model_path)
    disentangle_model.load_state_dict(state_dict["model_state_dict"])

    # NOTE: despite the `test_` prefix of the variables, the file read here is train.json
    # (one JSON object per line).
    with open ("train.json", "r") as f:
        test_lines = f.readlines()
    test_classification_examples = []
    # context_list / pos_list / neg_list are initialised but not used in this block.
    context_list = []
    pos_list = []
    neg_list = []
    for line in test_lines:
        sample=json.loads(line)
        # Join the context turns as "[CLS] turn1 [SEP] turn2 [SEP] ..." (trailing space stripped).
        ss = "[CLS] "
        for c in sample["context"]:
            ss = ss + c + " [SEP] "
        context = ss.strip()
        # Positive responses get label 1, adversarial negatives get label 0.
        for i, p in enumerate(sample["positive_responses"]):
            pos = "[CLS] " + p + " [SEP]"
            test_classification_examples.append((context, pos, 1))
        for j, n in enumerate(sample["adversarial_negative_responses"]):
            neg = "[CLS] " + n + " [SEP]"
            test_classification_examples.append((context, neg, 0))


    # batch_size=1 and no shuffling: each iteration scores exactly one pair, in file order.
    test_classification_dataloader = DataLoader(test_classification_examples, shuffle=False, batch_size=1)

    num_test_data = len(test_classification_examples)

    # Binary classifier over [context ; padded robust half] (768 + 768 = 1536 features).
    classification_model = MLP(1536, 2).cuda()
    classification_model_path = load_checkpoint_path + "classification_checkpoint_pad.bin"
    state_dict = torch.load(classification_model_path)
    classification_model.load_state_dict(state_dict["model_state_dict"])
    total_num_right = 0
    total_num_right_1 = 0
    samples = []
    for test_data in test_classification_dataloader:
        sample = {}
        with torch.no_grad():
            labels = test_data[2].to("cuda")
            # test_data[0] / test_data[1] are stored as delivered by the DataLoader
            # (presumably one-element lists of strings with batch_size=1 - TODO confirm).
            sample["context"] = test_data[0]
            sample["response"] = test_data[1]
            sample["label"] = test_data[2].item()
            # Encode and project the context and the response with the shared linear layer.
            context_embedding = torch.tensor(model_encoder.encode(test_data[0])).cuda()
            context_embedding = disentangle_model.linear(context_embedding)
            response_embedding = torch.tensor(model_encoder.encode(test_data[1])).cuda()
            response_embedding = disentangle_model.linear(response_embedding)
            # Keep only the second ("robust") half and zero-pad it with 384 columns,
            # the same layout Disentangle_Model.forward uses during training.
            non_robust_embedding, robust_embedding = response_embedding.chunk(2, -1)
            robust_pad = torch.zeros([robust_embedding.size(0), 384]).cuda()
            robust_embedding = torch.cat([robust_embedding, robust_pad], dim=1)

            hidden_state = torch.cat([context_embedding, robust_embedding], dim=1)
            # The MLP also returns a loss for `labels`; only the logits are used here.
            outputs = classification_model(hidden_state, labels)
            logits = outputs[0]
            outputs = m(logits)
            predict = outputs
            predict = predict.detach().cpu().numpy()
            # Predicted class = index of the largest logit.
            score = logits.argmax(dim=-1)
            # Cosine distance between the padded robust response half and the projected context.
            a = robust_embedding.cpu().reshape(1,-1)
            b = context_embedding.cpu().reshape(1,-1)
            d = paired_cosine_distances(a, b)
            sample["predicted"] = score.item()

            # Softmax probability of class index 1, rounded to 3 decimals and stored as a string.
            sample["predicted_score"] = str(round(predict[0][1], 3))
            sample["distance"] = str(d[0])
            samples.append(sample)
# Tally overall accuracy and per-class hit rates over the collected result records, print
# them, then dump the records to JSON. `samples`, `total_num_right` and `num_test_data`
# come from the inference block above.
pos_right = neg_right = 0
total_pos = total_neg = 0
for s in samples:
    label, predicted = int(s["label"]), int(s["predicted"])
    hit = int(label == predicted)
    if label == 1:
        total_pos += 1
        pos_right += hit
    if label == 0:
        total_neg += 1
        neg_right += hit
    total_num_right += hit
right_rate = total_num_right/num_test_data
print(right_rate, pos_right/total_pos, neg_right/total_neg)

# ensure_ascii=False keeps non-ASCII text readable in the output file.
json_samples = json.dumps(samples, ensure_ascii=False, indent=2)
with open("train_classification_result_small_distance.json", 'w', encoding='utf-8') as f:
    f.write(json_samples)



In [None]:
class Classifier(nn.Module):
    """
    Bias-free linear "gate" followed by a cross-entropy loss.

    The layer maps ``hidden_size`` input features to ``num_labels`` logits. The attribute is
    still called ``layer3`` so existing ``state_dict`` checkpoints keep their key names.

    :param hidden_size: Number of input features (2 for the gate built below)
    :param num_labels: Number of output classes (2 for the gate built below)
    """

    def __init__(self, hidden_size, num_labels):
        super(Classifier, self).__init__()
        # Sized from the constructor arguments rather than a hard-coded 2x2.
        self.layer3 = nn.Linear(hidden_size, num_labels, bias=False)
        self.num_labels = num_labels

    def forward(self, hidden_states, labels):
        """
        :param hidden_states: Input batch; its first dimension is the batch size
        :param labels: Per-row class indices, or a single class index shared by the batch
        :return: Tuple ``(logits, loss)``
        """
        # Keep everything on the inputs' device instead of hard-coding CUDA.
        device = hidden_states.device
        n = hidden_states.shape[0]
        labels = torch.as_tensor(labels, device=device) * torch.ones(n, device=device)
        labels = labels.long()
        logits = self.layer3(hidden_states)

        # Qualified through ``nn`` so no separate CrossEntropyLoss import is needed.
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits, labels)
        return logits, loss

# Two-input, two-class gate on the GPU; the cells below feed it
# [predicted score, cosine distance] feature pairs.
Gate = Classifier(2, 2).cuda()


In [None]:
# Build the gate's datasets: each example is (tensor([score, distance]), label).
# The train result file stores the score under "predicted_score", the test file under "score".
with open("train_classification_result_small_distance.json", "r") as f:
    lines = json.load(f)
train_examples = [
    (torch.tensor([float(line["predicted_score"]), float(line["distance"])]), line["label"])
    for line in lines
]
num_train_data = len(train_examples)
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)

with open("test_classification_result_small_distance.json", "r") as f:
    lines = json.load(f)
test_examples = [
    (torch.tensor([float(line["score"]), float(line["distance"])]), line["label"])
    for line in lines
]
num_test_data = len(test_examples)
test_dataloader = DataLoader(test_examples, shuffle=True, batch_size=32)

In [None]:
# Scratch check: builds a 1-D tensor holding the two values 1 and 20 (not a 1x20 tensor).
# The result is not assigned to anything.
torch.tensor((1,20))

In [None]:
print(

In [None]:
# Train the gate with plain SGD and evaluate on the test features after every epoch.
# `Gate`, the two dataloaders, `num_test_data` and `load_checkpoint_path` are defined in
# earlier cells.
optimizer = torch.optim.SGD(Gate.parameters(), lr=1e-6)
num_epoch = 10000
# Best test accuracy so far. It must live outside the epoch loop, otherwise every epoch
# compares against 0 and the "best" checkpoint is overwritten each time.
accu = 0
for epoch in range(num_epoch):
    print("\n")
    print("*" * 20, "epoch: ", epoch, " Training", "*" * 20)
    # Iterating the progress bar itself advances and closes it correctly.
    for train_data in tqdm(train_dataloader):
        data = train_data[0].cuda()
        labels = train_data[1].to("cuda")
        # Clear the gradients of the previous step; without this they accumulate
        # across batches and every update uses a stale running sum.
        optimizer.zero_grad()
        logits, loss = Gate(data, labels)
        loss.backward()
        optimizer.step()

    total_num_right = 0
    print("\n")
    # Loss of the last training batch of this epoch.
    print("*" * 20, "loss: ", loss.item(), "*" * 20)
    print("*" * 20, "epoch: ", epoch, " Testing", "*" * 20)
    with torch.no_grad():
        for test_data in tqdm(test_dataloader):
            data = test_data[0].cuda()
            labels = test_data[1].to("cuda")

            logits = Gate(data, labels)[0]
            score = logits.argmax(dim=-1)
            # .item() keeps the counter a plain Python number instead of a CUDA tensor.
            total_num_right += (score == labels).sum().item()
    right_rate = total_num_right/num_test_data
    print(right_rate)
    print(Gate.state_dict())
    # Checkpoint only when the test accuracy improves.
    if right_rate > accu:
        accu = right_rate
        state = {
            "epoch": epoch,
            "model_state_dict": Gate.state_dict(),
            "best_accu": accu,
        }
        torch.save(state, load_checkpoint_path + "classification_checkpoint_pad_gate.bin")


In [None]:
import numpy as np
# Print the average of the "distance" field over all records of the result file.
with open("classification_result_small_distance.json", "r") as f:
    samples = json.load(f)

distances = [float(s["distance"]) for s in samples]
print(np.mean(distances))

In [None]:
# Sweep a distance threshold p over 0.0, 0.1, ..., 0.9 on the test result file.
# A record counts as correct when it is a positive (label 1) with distance below p or a
# negative (label 0) with distance above p; a distance exactly equal to p is never correct.
# The best accuracy and its threshold end up in `accu` and `p_value`.

# The file does not change between thresholds, so it is read once instead of once per p.
with open("test_classification_result_small_distance.json", "r") as f:
    samples = json.load(f)

accu = 0
p_value = 0
for p in (step / 10 for step in range(0, 10, 1)):
    num_right = sum(
        1
        for s in samples
        if (float(s["distance"]) < p and int(s["label"]) == 1)
        or (float(s["distance"]) > p and int(s["label"]) == 0)
    )

    right_rate = num_right/len(samples)
    print(p,right_rate)
    # Strict comparison: on ties the smallest threshold wins.
    if accu < right_rate:
        accu = right_rate
        p_value = p

print(accu, p_value)

In [None]:

# Evaluate the distance threshold on the dev result file.
# `accu` and `p_value` are the best accuracy and threshold found by the threshold sweep cell.
with open("dev_classification_result_small_distance.json", "r") as f:
    samples = json.load(f)

# Use the threshold that was selected by the sweep. The loop variable `p` left over from the
# sweep is simply its last candidate (0.9), not the best one.
p = p_value
num_right = sum(
    1
    for s in samples
    if (float(s["distance"]) < p and int(s["label"]) == 1)
    or (float(s["distance"]) > p and int(s["label"]) == 0)
)

right_rate = num_right/len(samples)
print(p,right_rate)

# Report the sweep's selection unchanged; the dev score must not overwrite it.
print(accu, p_value)

In [None]:
# Overall accuracy and per-class hit rates of the stored predictions in the test result file.
with open("test_classification_result_small_distance.json", "r") as f:
    samples = json.load(f)
pos_right = 0
neg_right = 0
total_pos = 0
total_neg = 0
total_num_right = 0
for s in samples:
    label = int(s["label"])
    predicted = int(s["predicted"])
    if label == 1:
        total_pos += 1
        if predicted == 1:
            pos_right += 1
    if label == 0:
        total_neg += 1
        if predicted == 0:
            neg_right += 1
    if label == predicted:
        total_num_right += 1
# Divide by the number of records just loaded; the global `num_test_data` may have been set
# by another cell for a different file.
right_rate = total_num_right/len(samples)
print(right_rate, pos_right/total_pos, neg_right/total_neg)

In [None]:
import random
import numpy as np
import datetime
def log(message, logfile="app.log"):
    """
    Append a timestamped message to a log file.

    Each call writes one line of the form
    ``YYYY-MM-DD HH:MM:SS - <message>`` followed by a newline.

    The file is opened with an explicit UTF-8 encoding so that non-ASCII
    messages are written the same way regardless of the platform's default
    locale encoding.

    :param message: Message to be logged (interpolated with f-string formatting)
    :param logfile: File to which the log will be appended (default: app.log)
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(logfile, "a", encoding="utf-8") as log_file:
        log_file.write(f"{timestamp} - {message}\n")
# Re-score the stored predictions with a fixed distance rule: a record predicted as 0 whose
# distance is below 0.9 is flipped to 1. Then print overall accuracy and per-class hit rates.
with open("classification_result_small_distance.json", "r") as f:
    samples = json.load(f)
num_right = 0
pos_right = 0
neg_right = 0
total_pos = 0
total_neg = 0
random.shuffle(samples)
logfile = "threshold_value_test.log"
# The header text means "derive the threshold dynamically from the mean".
message = "使用均值动态求阈值"
# Mode "w" starts a fresh log file. The encoding is explicit because the header is non-ASCII.
with open(logfile, "w", encoding="utf-8") as log_file:
    log_file.write(f"{message}\n")
for i, s in enumerate(samples):
    # The flip is written back into the record, so the tallies below see the adjusted value.
    if float(s["distance"]) < 0.9 and int(s["predicted"]) == 0:
        s["predicted"] = 1
    if int(s["label"]) == 1:
        total_pos += 1
        if int(s["predicted"]) == 1:
            pos_right += 1
    if int(s["label"]) == 0:
        total_neg += 1
        if int(s["predicted"]) == 0:
            neg_right += 1
    if int(s["label"]) == int(s["predicted"]):
        num_right += 1
# Divide by the number of records just loaded; the global `num_test_data` is set by other
# cells for other files.
right_rate = num_right/len(samples)
print(right_rate, pos_right/total_pos, neg_right/total_neg)

In [None]:
import random, json
import numpy as np
import datetime
import copy
def log(message, logfile="app.log"):
    """
    Append a message as one line to a log file.

    The message is written verbatim followed by a newline; no timestamp is added, so the
    file stays a plain list of result rows.

    The file is opened with an explicit UTF-8 encoding so that non-ASCII messages are
    written the same way regardless of the platform's default locale encoding.

    :param message: Message to be logged (interpolated with f-string formatting)
    :param logfile: File to which the log will be appended (default: app.log)
    """
    with open(logfile, "a", encoding="utf-8") as log_file:
        log_file.write(f"{message}\n")
# Simulate an online threshold: records arrive in random order and, for every prefix of the
# stream, the mean distance of the previous prefix (rounded to 1 decimal) is used as the
# threshold. A stored prediction of 0 is flipped to 1 when the record's distance is below that
# threshold. One row "i----adjusted_acc----raw_acc----threshold" is logged per prefix.
with open("test_classification_result_small_distance.json", "r") as f:
    samples = json.load(f)
num_right = 0
pos_right = 0
neg_right = 0
total_pos = 0
total_neg = 0
random.shuffle(samples)
# Seed the running statistics with the first record so the first threshold is defined.
mean_distance = float(samples[0]["distance"])
distances = [float(samples[0]["distance"])]
logfile = "threshold_value_test.log"
# The header text means "derive the threshold dynamically from the mean".
message = "使用均值动态求阈值"
log(message, logfile)
for i in range(len(samples)):
    mean_distance = np.mean(distances)
    mean_distance = round(mean_distance, 1)
    num_right = 0
    pos_right = 0
    neg_right = 0
    total_pos = 0
    total_neg = 0
    num_right_raw = 0
    # The flip is applied to a local value, so the records are never mutated and no deep copy
    # of the prefix is needed.
    examples = samples[:i+1]
    distances = []
    for ss in examples:
        label = int(ss["label"])
        predicted = int(ss["predicted"])
        distance = float(ss["distance"])
        # Accuracy of the stored predictions, before any adjustment.
        if label == predicted:
            num_right_raw += 1
        if distance < mean_distance and predicted == 0:
            predicted = 1
        if label == 1:
            total_pos += 1
            if predicted == 1:
                pos_right += 1
        if label == 0:
            total_neg += 1
            if predicted == 0:
                neg_right += 1
        if label == predicted:
            num_right += 1
        distances.append(distance)

    a = num_right/len(examples) * 100
    a = round(a, 2)

    b = num_right_raw/len(examples) * 100
    b = round(b, 2)
    print(a, b, mean_distance)
    m = "{}----".format(i) + "{}----".format(a) + "{}----".format(b) + "{}".format(mean_distance)
    log(m, logfile)
# After the last iteration the counters cover every record, so the denominator is the number
# of records loaded here rather than a `num_test_data` left behind by another cell.
num_test_data = len(samples)
right_rate = num_right/num_test_data
print(right_rate, pos_right/total_pos, neg_right/total_neg, num_right_raw/num_test_data)

In [None]:
import random, json
import numpy as np
import datetime
import copy
def log(message, logfile="app.log"):
    """
    Append a message as one line to a log file.

    The message is written verbatim followed by a newline; no timestamp is added, so the
    file stays a plain list of result rows.

    The file is opened with an explicit UTF-8 encoding so that non-ASCII messages are
    written the same way regardless of the platform's default locale encoding.

    :param message: Message to be logged (interpolated with f-string formatting)
    :param logfile: File to which the log will be appended (default: app.log)
    """
    with open(logfile, "a", encoding="utf-8") as log_file:
        log_file.write(f"{message}\n")
# For every prefix of the shuffled test records, score each record with
#   h = (score - distance) * (distance / std) * (score - distance)
# where std is the standard deviation of the previous prefix's distances. A positive counts
# as correct when h < 0.5 and a negative when h > 0.5. One row
# "i----acc----raw_acc----mean_distance" is logged per prefix.
with open("test_classification_result_small_distance.json", "r") as f:
    samples = json.load(f)
num_test_data = len(samples)
num_right = pos_right = neg_right = 0
total_pos = total_neg = 0
random.shuffle(samples)
# Seed the running statistics with the first record.
mean_distance = float(samples[0]["distance"])
distances = [float(samples[0]["distance"])]
logfile = "threshold_value_test_1.log"
message = "使用均值动态求阈值"
log(message, logfile)
for i, s in enumerate(samples):
    # Statistics of the previous prefix; the mean is only reported, the std scales h.
    std_distance = np.std(distances)
    mean_distance = round(np.mean(distances), 2)
    num_right = pos_right = neg_right = 0
    total_pos = total_neg = 0
    num_right_raw = 0
    # Nothing below modifies the records, so a plain slice of the prefix is enough.
    examples = samples[:i+1]
    distances = []
    for j, ss in enumerate(examples):
        label = int(ss["label"])
        dist = float(ss["distance"])
        gap = float(ss["score"]) - dist
        # Accuracy of the stored predictions, for comparison.
        if label == int(ss["predicted"]):
            num_right_raw += 1
        h = gap * (dist/std_distance) * gap
        if label == 1:
            total_pos += 1
            if h < 0.5:
                pos_right += 1
        if label == 0:
            total_neg += 1
            if h > 0.5:
                neg_right += 1
        distances.append(dist)
    num_right = neg_right + pos_right
    a = round(num_right/len(examples) * 100, 2)
    b = round(num_right_raw/len(examples) * 100, 2)
    examples = []
    m = f"{i}----{a}----{b}----{mean_distance}"
    log(m, logfile)
right_rate = num_right/num_test_data
print(right_rate, pos_right/total_pos, neg_right/total_neg, num_right_raw/num_test_data)

In [None]:
import random, json
import numpy as np
import datetime
import copy
# def log(message, logfile="app.log"):
#     """
#     Log a message with a timestamp to a specified logfile.
    
#     :param message: Message to be logged
#     :param logfile: File to which the log will be written (default: app.log)
#     """
#     timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
#     with open(logfile, "a") as log_file:
#         log_file.write(f"{message}\n")
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


# Decision boundary for the combined score: a sample is predicted positive
# when its combined score is >= this value.
SLM_D_THRESHOLD = 0.5

with open("test_classification_result_small_distance.json", "r") as f:
    samples = json.load(f)
num_test_data = len(samples)
random.shuffle(samples)

# Distance statistics come from the training results only.
with open("train_classification_result_small_distance.json", "r") as f:
    samples1 = json.load(f)
distances = [float(s["distance"]) for s in samples1]
mean_distance = np.mean(distances)
std_distance = np.std(distances)
max_distance = np.max(distances)
min_distance = np.min(distances)
# NOTE(review): if every training distance is identical this range is 0 and
# the normalisation below yields inf/nan; test distances outside the training
# range normalise to values outside [0, 1].
distance_range = max_distance - min_distance

examples = []
d_list = []
num_right_raw = 0
pos_right = neg_right = 0
total_pos = total_neg = 0

for ss in samples:
    ss_label = int(ss["label"])
    ss_distance = float(ss["distance"])

    # Accuracy of the untouched "predicted" field, for comparison.
    if ss_label == int(ss["predicted"]):
        num_right_raw += 1

    # Min-max normalise the distance with the training statistics, then
    # combine it with the score: small distance and high score both push the
    # combined value up.
    # Alternative kept from earlier experiments (harmonic mean):
    #   (2 * (1 - h) * score) / ((1 - h) + score)
    distance_norm = (ss_distance - min_distance) / distance_range
    h = distance_norm
    slm_d = (1 - h) + float(ss["score"])

    # One boundary for both the stored prediction and the accuracy counters
    # (previously the prediction used `>` while the counters used `>=`, so a
    # sample sitting exactly on the boundary was labelled inconsistently).
    modified_predicted = 1 if slm_d >= SLM_D_THRESHOLD else 0
    ss["modified_predicted"] = modified_predicted

    if ss_label == 1:
        total_pos += 1
        pos_right += modified_predicted
    elif ss_label == 0:
        total_neg += 1
        neg_right += 1 - modified_predicted

    d_list.append(ss_distance)
    ss["h"] = h
    ss["slm_d"] = slm_d
    ss["norm_dis"] = distance_norm
    examples.append(ss)

num_right = neg_right + pos_right
a = round(_safe_ratio(num_right, len(samples)) * 100, 2)      # accuracy (%) after re-scoring
b = round(_safe_ratio(num_right_raw, len(samples)) * 100, 2)  # accuracy (%) of "predicted"

right_rate = _safe_ratio(num_right, num_test_data)
print(
    right_rate,
    _safe_ratio(pos_right, total_pos),
    _safe_ratio(neg_right, total_neg),
    _safe_ratio(num_right_raw, num_test_data),
    np.mean(d_list),
    mean_distance,
    std_distance,
)


# json_samples = json.dumps(examples, ensure_ascii=False, indent=2)
# with open("test_classification_result_small_distance_h_normalized.json", 'w', encoding='utf-8') as f:
#     f.write(json_samples)