In [None]:
#https://www.kaggle.com/datasets/tomokihirose/faiss-gpu-173-python310
#https://www.kaggle.com/datasets/datafan07/llm-whls
#https://www.kaggle.com/datasets/ahmadsaladin/mistral-7b-it-v02
#https://www.kaggle.com/models/Microsoft/phi/Transformers/2/1
#https://www.kaggle.com/models/mozhiwenmzw/phi2-public-data-sft-adapter/PyTorch/public-data-sft/1
#https://www.kaggle.com/code/levantaokkz/library-off-for-llm

In [None]:
!pip install -U /kaggle/input/faiss-gpu-173-python310/faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
!pip install -Uq /kaggle/input/llm-whls/bitsandbytes-0.41.1-py3-none-any.whl
!pip install -Uq /kaggle/input/llm-whls/peft-0.4.0-py3-none-any.whl
!pip install -Uq /kaggle/input/library-off-for-llm/transformers-4.38.2-py3-none-any.whl
!pip install sentence-transformers

## 方案总览
* 1.seq2seq模型  
* 2.开源微调模型phi2
* 3.zero-shot大模型mistral-7b-v2

集成三种模型的预测结果为最终结果

### 预处理
预先用 sentence-t5-base 为训练集和验证集的 rewrite_prompt 生成 embedding 并保存为 .npy 文件，训练时直接加载，从而加速训练

In [None]:
%%writefile trian_embedding_generate.py
import pandas as pd
import gc
import numpy as np
df = pd.read_parquet(f"./train_clean.parquet", columns=['rewrite_prompt'])
valid = pd.read_csv('./validation826.csv', usecols=['rewrite_prompt'])

import pandas as pd
import time
from tqdm import tqdm
import numpy as np
from sentence_transformers import SentenceTransformer
import pickle
# Sentence encoder used to embed every rewrite prompt.
model = SentenceTransformer('sentence-transformers/sentence-t5-base')
model.max_seq_length = 512


def _embed_prompts(prompts):
    """Encode ``prompts`` on the GPU and return normalised float32 vectors as a NumPy array."""
    vectors = model.encode(
        list(prompts),
        batch_size=64,
        device='cuda',
        show_progress_bar=True,
        convert_to_tensor=True,
        normalize_embeddings=True,
    )
    return np.asarray(vectors.detach().cpu().numpy().astype('float32'))


# Training prompts -> one embedding row per dataframe row.
encoded_data = _embed_prompts(df['rewrite_prompt'])
np.save('train_clean_emb_sentence-t5-base.npy', encoded_data)

# Validation prompts, stored in a separate file.
valid_emb = _embed_prompts(valid['rewrite_prompt'])
np.save('valid826_emb_sentence-t5-base.npy', valid_emb)

### seq2seq训练

In [None]:
%%writefile seq2seq_exp14_train.py
import gc
import math
import matplotlib.pyplot as plt
import multiprocessing
import numpy as np
import os
import pandas as pd
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import uuid

from glob import glob
from torch.nn import Parameter
from torch.optim import Adam, SGD, AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from typing import Dict, List

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Current device is: {device}")

os.environ['TOKENIZERS_PARALLELISM']='true'
import tokenizers
import transformers
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
print(f"tokenizers.__version__: {tokenizers.__version__}")
print(f"transformers.__version__: {transformers.__version__}")

class config:
    # Hyper-parameters and runtime switches for the exp14 training run.
    AMP = True  # passed as `enabled=` to GradScaler / autocast in train_epoch
    BATCH_SIZE_TRAIN = 32 # reduce this if you run out of GPU memory
    BATCH_SIZE_VALID = 32 # reduce this if you run out of GPU memory
    BETAS = (0.9, 0.999)  # AdamW betas
    DEBUG = 0 # set to 1 to train on a small subset (first 1000 train / 384 validation rows)
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    LR = 5e-6  # used for both parameter groups and as OneCycleLR max_lr
    EPOCHS = 6
    EPS = 1e-6  # AdamW eps
    GRADIENT_CHECKPOINTING = False  # when True, CustomModel enables gradient checkpointing on the backbone
    MODEL = "/kaggle/input/deberta-v3-large-hf-weights" # model files - https://www.kaggle.com/datasets/radek1/deberta-v3-large-hf-weights
    CKPT = 'deberta-v3-large'  # prefix of the saved checkpoint file name
    MAX_GRAD_NORM = 100000.0  # max norm handed to clip_grad_norm_
    MAX_LEN = 384  # tokenizer max_length (inputs are padded/truncated to this)
    NUM_WORKERS = 0  # NOTE(review): not passed to the DataLoaders built in train_loop
    PRINT_FREQ = 500  # log every PRINT_FREQ steps
    SEED = 20
    WANDB = False  # enables the wandb.log calls in train_epoch / valid_epoch
    WEIGHT_DECAY = 0.008

class paths:
    # Input data, pre-computed prompt embeddings and output locations.
    TRAIN_DATA = "./train_clean.parquet"
    #TRAIN_DATA2 = './train_sft_v13.csv'
    VALID_DATA = './validation826.csv'
    # Embedding files are expected to be row-aligned with the data files above
    # (train_loop asserts equal lengths).
    train_embedding_file = './train_clean_emb_sentence-t5-base.npy'
    #train_embedding_file2 = './train_sft_v13_emb_sentence-t5-base.npy'
    valid_embedding_file = './valid826_emb_sentence-t5-base.npy'
    OUTPUT_DIR = "./exp14"# output directory for checkpoints, tokenizer, config and log
    LOGGER = 'exp14'  # base name of the log file (".log" is appended by get_logger)

os.makedirs(paths.OUTPUT_DIR, exist_ok=True)

class AverageMeter(object):
    """Running tracker of the latest value, weighted sum, sample count and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out every tracked statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed over ``n`` samples and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += n * val
        self.avg = self.sum / self.count


def asMinutes(s):
    """Format a duration in seconds as ``'<minutes>m <seconds>s'`` (both truncated to ints)."""
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, leftover_seconds)


def timeSince(since, percent):
    """Return ``'<elapsed> (remain <eta>)'`` for a task started at ``since``.

    :param since: start timestamp as returned by ``time.time()``.
    :param percent: fraction of the work already done; must be non-zero.
    """
    elapsed = time.time() - since
    # Linear extrapolation: total time = elapsed / fraction completed.
    estimated_total = elapsed / percent
    return '%s (remain %s)' % (asMinutes(elapsed), asMinutes(estimated_total - elapsed))


def get_config_dict(config):
    """Return the non-callable, non-dunder attributes of the ``config`` class as a dict."""
    return {
        name: attr
        for name, attr in config.__dict__.items()
        if not (callable(attr) or name.startswith('__'))
    }


def get_optimizer_params(model, encoder_lr, decoder_lr, weight_decay=0.0):
    """Build the three optimizer parameter groups for ``model``.

    Groups, in order:
      1. backbone (``model.model``) weights that receive ``weight_decay``;
      2. backbone biases / LayerNorm parameters, with no weight decay;
      3. every parameter outside the backbone (e.g. the head), trained with
         ``decoder_lr`` and no weight decay.

    :param model: module exposing its backbone as the ``model`` attribute.
    :param encoder_lr: learning rate for the backbone groups.
    :param decoder_lr: learning rate for the non-backbone group.
    :param weight_decay: decay applied to group 1 only.
    :return: list of param-group dicts accepted by torch optimizers.
    """
    no_decay = ("bias", "LayerNorm.bias", "LayerNorm.weight")

    def _is_no_decay(name):
        return any(marker in name for marker in no_decay)

    backbone = list(model.model.named_parameters())
    # Select by the "model." prefix rather than a substring test, so a head
    # parameter whose name merely contains "model" is not silently left out
    # of every group (and therefore never optimized).
    outside_backbone = [p for n, p in model.named_parameters() if not n.startswith("model.")]

    return [
        {'params': [p for n, p in backbone if not _is_no_decay(n)],
         'lr': encoder_lr, 'weight_decay': weight_decay},
        {'params': [p for n, p in backbone if _is_no_decay(n)],
         'lr': encoder_lr, 'weight_decay': 0.0},
        {'params': outside_backbone,
         'lr': decoder_lr, 'weight_decay': 0.0},
    ]


def get_logger(filename=paths.OUTPUT_DIR+'/'+paths.LOGGER):
    """Return this module's logger, emitting bare messages to the console and to ``<filename>.log``.

    Each call attaches a new stream handler and a new file handler.
    """
    from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    # Console handler first, then the file handler, both with a message-only format.
    for handler in (StreamHandler(), FileHandler(filename=f"{filename}.log")):
        handler.setFormatter(Formatter("%(message)s"))
        logger.addHandler(handler)
    return logger


def seed_everything(seed=20):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN kernels."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True


def generate_uuid():
    """Return a newly generated random UUID (uuid4) as a string."""
    fresh_id = uuid.uuid4()
    return str(fresh_id)


def to_device(inputs, device: str = device):
    """Return a new dict with every value of ``inputs`` moved to ``device`` via ``.to``."""
    moved = {}
    for key, value in inputs.items():
        moved[key] = value.to(device)
    return moved

# Module-level logger (console + log file, see get_logger) and global seeding.
LOGGER = get_logger()
seed_everything(seed=config.SEED)

# Load the tokenizer that ships with the backbone weights and save a copy next
# to the training outputs.
tokenizer = AutoTokenizer.from_pretrained(config.MODEL)
tokenizer.save_pretrained(paths.OUTPUT_DIR + '/tokenizer/')

def prepare_input(cfg: type, text: np.ndarray, tokenizer):
    """Tokenize ``text`` to a fixed length and convert every field to a long tensor.

    :param cfg: configuration class providing ``MAX_LEN``.
    :param text: text to encode (the dataset in this file passes a single string).
    :param tokenizer: object exposing ``encode_plus``.
    :return: the encoding returned by the tokenizer, with each value replaced
        by a ``torch.long`` tensor.
    """
    encoding_options = dict(
        return_tensors=None,
        add_special_tokens=True,
        max_length=cfg.MAX_LEN,
        padding='max_length',  # always pad to MAX_LEN; `collate` trims per batch
        truncation=True,
    )
    inputs = tokenizer.encode_plus(text, **encoding_options)
    for field in list(inputs.keys()):
        inputs[field] = torch.tensor(inputs[field], dtype=torch.long)
    return inputs


def collate(inputs):
    """Trim every field of ``inputs`` (in place) to the longest attended sequence in the batch."""
    longest = int(inputs["attention_mask"].sum(axis=1).max())
    for field in list(inputs.keys()):
        inputs[field] = inputs[field][:, :longest]
    return inputs


class CustomDataset(Dataset):
    """Dataset yielding tokenized original/rewritten texts plus the prompt embedding of each row."""

    def __init__(self, cfg, df, tokenizer, rewrite_prompts_embeddings):
        self.cfg = cfg
        self.tokenizer = tokenizer
        # Missing texts become empty strings; everything is coerced to str.
        self.original_texts = df['original_text'].fillna('').map(str).values
        self.rewritten_texts = df['rewritten_text'].fillna('').map(str).values
        # Indexed with the same position as the dataframe rows.
        self.rewrite_prompts = rewrite_prompts_embeddings
        self.text_ids = df['id'].astype(str).values

    def __len__(self):
        return len(self.text_ids)

    def __getitem__(self, item):
        return {
            "original_text": prepare_input(self.cfg, self.original_texts[item], self.tokenizer),
            "rewritten_text": prepare_input(self.cfg, self.rewritten_texts[item], self.tokenizer),
            "rewrite_prompt": self.rewrite_prompts[item],
            "id": self.text_ids[item],
        }

class MeanPooling(nn.Module):
    """Mean of the hidden states over dimension 1, counting only positions where the mask is set."""

    def __init__(self):
        super().__init__()

    def forward(self, last_hidden_state, attention_mask):
        # Broadcast the mask over the hidden dimension.
        mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        token_sum = (last_hidden_state * mask).sum(1)
        # Clamp so an all-zero mask does not divide by zero.
        token_count = mask.sum(1).clamp(min=1e-9)
        return token_sum / token_count


class CustomModel(nn.Module):
    """Transformer backbone + MLP head mapping an (original, rewritten) text pair to a 768-d vector.

    Each text is encoded by the shared backbone; the mean-pooled last two
    hidden layers of both texts are concatenated (4 * hidden_size features)
    and projected by the head.

    :param cfg: configuration class (``MODEL``, ``DEVICE``, ``GRADIENT_CHECKPOINTING``).
    :param config_path: optional path to a saved backbone config; when ``None``
        the config is derived from ``cfg.MODEL`` with dropout disabled.
    :param mode: ``"train"`` makes ``forward`` also return the target embedding tensor.
    :param pretrained: load pretrained backbone weights instead of a randomly
        initialised backbone built from the config.
    """

    def __init__(self, cfg, config_path=None, mode: str ="train", pretrained=False):
        super().__init__()
        self.cfg = cfg
        self.mode = mode
        self.dropout = 0.2  # not read anywhere in this class; kept as an attribute
        if config_path is None:
            # Derive the config from the model name and switch off all dropout.
            self.config = AutoConfig.from_pretrained(cfg.MODEL, output_hidden_states=True)
            self.config.hidden_dropout = 0.
            self.config.hidden_dropout_prob = 0.
            self.config.attention_dropout = 0.
            self.config.attention_probs_dropout_prob = 0.
        else:
            # Load a config object previously saved with torch.save.
            # NOTE(review): torch.load unpickles arbitrary objects; only use trusted files.
            self.config = torch.load(config_path)

        if pretrained:
            self.model = AutoModel.from_pretrained(cfg.MODEL, config=self.config)
        else:
            # AutoModel cannot be instantiated directly; build the architecture
            # from the config (same call the inference script uses).
            self.model = AutoModel.from_config(self.config)

        if self.cfg.GRADIENT_CHECKPOINTING:
            self.model.gradient_checkpointing_enable()

        self.pool = MeanPooling()
        self.head = nn.Sequential(
            nn.Linear(self.config.hidden_size*4, 32768),
            nn.BatchNorm1d(32768),
            nn.ReLU(),
            nn.Linear(32768, 768),
        )
        # NOTE(review): self.head is an nn.Sequential, so none of the isinstance
        # branches in _init_weights match and the layers keep PyTorch's default
        # initialisation. `self.head.apply(self._init_weights)` would initialise
        # the Linear layers; left unchanged to preserve training behaviour.
        self._init_weights(self.head)

    def _init_weights(self, module):
        """Initialise ``module`` if it is an nn.Linear, nn.Embedding or nn.LayerNorm."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def feature(self, inputs):
        """Return the concatenated mean-pooled last two hidden layers, shape (batch, 2 * hidden_size)."""
        outputs = self.model(**inputs)
        feature1 = self.pool(outputs.hidden_states[-1], inputs['attention_mask'])
        feature2 = self.pool(outputs.hidden_states[-2], inputs['attention_mask'])
        return torch.cat([feature1, feature2], dim=1)

    def forward(self, original_texts, rewritten_texts, rewrite_prompts_embedding):
        """Predict the prompt embedding for a batch of text pairs.

        :return: ``(output, prompt_embedding)`` where ``output`` is the head
            output and ``prompt_embedding`` is the target embedding moved to
            ``cfg.DEVICE`` in train mode, else ``None``.
        """
        original_texts_feature = self.feature(original_texts)    # (batch, 2 * hidden_size)
        rewritten_texts_feature = self.feature(rewritten_texts)  # (batch, 2 * hidden_size)
        feature = torch.cat([original_texts_feature, rewritten_texts_feature], dim=1)  # (batch, 4 * hidden_size)
        output = self.head(feature)

        if self.mode == "train":
            # as_tensor accepts both NumPy arrays and already-collated tensors
            # without the copy-construct warning torch.tensor() raises on tensors.
            prompt_embedding = torch.as_tensor(rewrite_prompts_embedding, device=self.cfg.DEVICE)
        else:
            prompt_embedding = None

        return output, prompt_embedding


# Instantiate the model once at import time just to report its parameter count.
# NOTE(review): train_loop builds its own CustomModel, so this instance is only
# used for the count below and stays referenced by the module-level name `model`.
model = CustomModel(config, config_path=None, pretrained=True)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {round(total_params/1e6, 2)} M")

def sharpened_cosine_similarity(k: np.ndarray, s: np.ndarray, p: int = 3, q: float = 1e-10):
    """Compute the Sharpened Cosine Similarity between two 1-D vectors.

    ``score = sign(k . s) * (|k . s| / ((||k|| + q) * (||s|| + q))) ** p``

    The sign is applied to the *magnitude* raised to ``p``; multiplying the
    sign by ``cos ** p`` would turn negative similarities positive for odd
    ``p``. ``q`` is added to both norms so that a zero vector on either side
    yields 0 instead of NaN.

    :param k: prediction embedding, shape (N,).
    :param s: ground-truth embedding, shape (N,).
    :param p: sharpening exponent.
    :param q: small constant for numerical stability.
    :return: the SCS score as a NumPy scalar.
    """
    k = np.asarray(k)
    s = np.asarray(s)
    dot_product = np.dot(k, s)
    denominator = (np.linalg.norm(k) + q) * (np.linalg.norm(s) + q)
    magnitude = np.abs(dot_product) / denominator
    return np.sign(dot_product) * magnitude ** p


def scs(k_batch: np.ndarray, s_batch: np.ndarray, p: int = 3, q: float = 1e-10):
    """Return the mean Sharpened Cosine Similarity over row-aligned batches.

    :param k_batch: prediction embeddings; only ``shape[0]`` and row indexing are used.
    :param s_batch: ground-truth embeddings, indexed with the same row positions.
    :param p: sharpening exponent forwarded to ``sharpened_cosine_similarity``.
    :param q: stability constant forwarded to ``sharpened_cosine_similarity``.
    """
    per_row_scores = [
        sharpened_cosine_similarity(k_batch[row], s_batch[row], p, q)
        for row in range(k_batch.shape[0])
    ]
    return np.mean(per_row_scores)

# Example usage: a quick check of `scs` on two hand-written pairs, run at
# import time. The inputs here are torch tensors rather than NumPy arrays.
k_vector = torch.tensor([[0.90, 0.10, 0.95], [0.05, 0.10, 0.99]])
s_vector = torch.tensor([[1, 0, 1], [0, 0, 1]])
p_value = 3  # NOTE(review): never passed to scs; the call below uses the default p

result = scs(k_vector, s_vector)
print("Mean Sharpened Cosine Similarity:", result)

def train_epoch(train_loader, model, criterion, optimizer, epoch, scheduler, device):
    """Run one training epoch and return the sample-weighted average loss.

    :param train_loader: DataLoader yielding dicts with ``id``, ``original_text``,
        ``rewritten_text`` and ``rewrite_prompt`` entries.
    :param model: model returning ``(predictions, targets)`` from its forward pass.
    :param criterion: loss called as ``criterion(preds, trues, targets)``.
    :param optimizer: optimizer stepped once per batch through the GradScaler.
    :param epoch: zero-based epoch index (used for logging only).
    :param scheduler: LR scheduler stepped once per batch.
    :param device: device the inputs and similarity targets are placed on.
    :return: average training loss over the epoch.
    """
    model.train()
    # Mixed-precision loss scaling; becomes a pass-through when config.AMP is False.
    scaler = torch.cuda.amp.GradScaler(enabled=config.AMP)
    losses = AverageMeter()
    start = time.time()
    num_batches = len(train_loader)

    # ========== ITERATE OVER TRAIN BATCHES ============
    with tqdm(train_loader, unit="train_batch", desc='Train') as tqdm_train_loader:
        for step, batch in enumerate(tqdm_train_loader):
            ids_batch = batch.pop("id")
            original_texts = to_device(collate(batch.pop("original_text")), device)
            rewritten_texts = to_device(collate(batch.pop("rewritten_text")), device)
            rewrite_prompts = batch.pop("rewrite_prompt")
            batch_size = len(ids_batch)
            targets = torch.ones(batch_size, device=device) # -1 for dissimilar, 1 for similar
            with torch.cuda.amp.autocast(enabled=config.AMP):
                y_preds, y_trues = model(original_texts, rewritten_texts, rewrite_prompts)
                loss = criterion(y_preds, y_trues, targets)
            losses.update(loss.item(), batch_size)
            scaler.scale(loss).backward()
            # Gradients are still multiplied by the loss scale here; unscale them
            # first so clipping (and the logged norm) use the true gradient values.
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.MAX_GRAD_NORM)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()

            # ========== LOG INFO ==========
            if step % config.PRINT_FREQ == 0 or step == (num_batches-1):
                print('Epoch: [{0}][{1}/{2}] '
                      'Elapsed {remain:s} '
                      'Loss: {loss.avg:.4f} '
                      'Grad: {grad_norm:.4f}  '
                      'LR: {lr:.8f}  '
                      .format(epoch+1, step, num_batches,
                              remain=timeSince(start, float(step+1)/num_batches),
                              loss=losses,
                              grad_norm=grad_norm,
                              lr=scheduler.get_last_lr()[0]))
            if config.WANDB:
                # NOTE(review): neither `wandb` nor `fold` is defined in this
                # script, so enabling config.WANDB raises NameError here.
                wandb.log({f"[fold_{fold}] train loss": losses.val,
                           f"[fold_{fold}] lr": scheduler.get_last_lr()[0]})

    gc.collect()

    return losses.avg


def valid_epoch(valid_loader, model, criterion, device):
    """Evaluate ``model`` on ``valid_loader`` without gradient tracking.

    :param valid_loader: DataLoader yielding dicts with ``id``, ``original_text``,
        ``rewritten_text`` and ``rewrite_prompt`` entries.
    :param model: model returning ``(predictions, targets)`` from its forward pass.
    :param criterion: loss called as ``criterion(preds, trues, targets)``.
    :param device: device the inputs and similarity targets are placed on.
    :return: ``(average_loss, output_dict)`` where ``output_dict`` holds the
        concatenated ``predictions`` and ``ground_truths`` arrays and the ``ids``.
    """
    model.eval()
    losses = AverageMeter()
    output_dict = {}
    preds, trues, ids = [], [], []
    start = time.time()
    num_batches = len(valid_loader)
    with tqdm(valid_loader, unit="valid_batch", desc='Validation') as tqdm_valid_loader:
        for step, batch in enumerate(tqdm_valid_loader):
            ids_batch = batch.pop("id")
            # Place the inputs on the same device as `targets` below.
            original_texts = to_device(collate(batch.pop("original_text")), device)
            rewritten_texts = to_device(collate(batch.pop("rewritten_text")), device)
            rewrite_prompts = batch.pop("rewrite_prompt")
            batch_size = len(ids_batch)
            targets = torch.ones(batch_size, device=device) # -1 for dissimilar, 1 for similar
            with torch.no_grad():
                y_preds, y_trues = model(original_texts, rewritten_texts, rewrite_prompts)
                loss = criterion(y_preds, y_trues, targets)
            losses.update(loss.item(), batch_size)
            preds.append(y_preds.to('cpu').numpy())
            trues.append(y_trues.to('cpu').numpy())
            ids += ids_batch

            # ========== LOG INFO ==========
            if step % config.PRINT_FREQ == 0 or step == (num_batches-1):
                print('EVAL: [{0}/{1}] '
                      'Elapsed {remain:s} '
                      'Loss: {loss.avg:.4f} '
                      .format(step, num_batches,
                              loss=losses,
                              remain=timeSince(start, float(step+1)/num_batches)))
            if config.WANDB:
                # NOTE(review): neither `wandb` nor `fold` is defined in this
                # script, so enabling config.WANDB raises NameError here.
                wandb.log({f"[fold_{fold}] val loss": losses.val})

    output_dict["predictions"] = np.concatenate(preds)
    output_dict["ground_truths"] = np.concatenate(trues)
    output_dict["ids"] = ids
    return losses.avg, output_dict

def train_loop(fold=0):
    """Train the model on the full training set and validate after every epoch.

    The checkpoint with the best validation score (mean sharpened cosine
    similarity) is written to ``paths.OUTPUT_DIR``.

    :param fold: label used in log messages and in the checkpoint file name.
    :return: the validation dataframe with two extra columns, ``preds``
        (predictions of the best epoch) and ``trues`` (target embeddings).
    """
    LOGGER.info(f"========== Fold: {fold} training ==========")
    train = pd.read_parquet(paths.TRAIN_DATA)
    #train2 = pd.read_csv(paths.TRAIN_DATA2)
    #train = pd.concat([train, train2], ignore_index=True)
    valid = pd.read_csv(paths.VALID_DATA)
    #train = train[train['source']!='dpo'].reset_index(drop=True)

    train_embedding = np.load(paths.train_embedding_file)
    #train_embedding2 = np.load(paths.train_embedding_file2)
    #train_embedding = np.concatenate([train_embedding, train_embedding2])
    valid_embedding = np.load(paths.valid_embedding_file)

    if config.DEBUG:
        # Small subset for a quick smoke run; data and embeddings are cut identically.
        train, train_embedding = train.head(1000), train_embedding[:1000]
        valid, valid_embedding = valid.head(384), valid_embedding[:384]

    train['id'] = range(len(train))
    valid['id'] = range(len(valid))

    # Embeddings are looked up by row position, so lengths must match.
    assert len(train)==len(train_embedding)
    assert len(valid)==len(valid_embedding)
    # ======== DATASETS ==========
    train_dataset = CustomDataset(config, train, tokenizer, train_embedding)
    valid_dataset = CustomDataset(config, valid, tokenizer, valid_embedding)

    # ======== DATALOADERS ==========
    train_loader = DataLoader(train_dataset,
                              batch_size=config.BATCH_SIZE_TRAIN,
                              shuffle=True,
                              pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=config.BATCH_SIZE_VALID,
                              shuffle=False,
                              pin_memory=True, drop_last=False)

    # ======== MODEL ==========
    model = CustomModel(config, config_path=None, pretrained=True)
    # Persist the backbone config so inference can rebuild the architecture.
    torch.save(model.config, paths.OUTPUT_DIR + '/config.pth')
    model.to(device)

    optimizer_parameters = get_optimizer_params(model,
                                                encoder_lr=config.LR,
                                                decoder_lr=config.LR,
                                                weight_decay=config.WEIGHT_DECAY)
    optimizer = AdamW(optimizer_parameters,
                      lr=config.LR,
                      eps=config.EPS,
                      betas=config.BETAS)

    scheduler = OneCycleLR(
        optimizer,
        max_lr=config.LR,
        epochs=config.EPOCHS,
        steps_per_epoch=len(train_loader),
        pct_start=0.1,
        anneal_strategy="cos",
        final_div_factor=100,
    )

    # ======= LOSS ==========
    criterion = nn.CosineEmbeddingLoss()

    best_score = -np.inf
    best_model_predictions = None
    # ====== ITERATE EPOCHS ========
    for epoch in range(config.EPOCHS):

        start_time = time.time()

        # ======= TRAIN ==========
        avg_loss = train_epoch(train_loader, model, criterion, optimizer, epoch, scheduler, device)

        # ======= EVALUATION ==========
        avg_val_loss, output_dict = valid_epoch(valid_loader, model, criterion, device)
        predictions = output_dict["predictions"]
        ground_truths = output_dict["ground_truths"]

        # ======= SCORING ==========
        score = scs(predictions, ground_truths)

        elapsed = time.time() - start_time

        LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
        LOGGER.info(f'Epoch {epoch+1} - Score: {score:.4f}')

        if score > best_score:
            best_score = score
            LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
            torch.save(model.state_dict(),
                        paths.OUTPUT_DIR + f"/{config.CKPT.replace('/', '_')}_fold_{fold}_best.pth")
            best_model_predictions = predictions

    if best_model_predictions is None:
        # No epoch score compared greater than -inf (e.g. every score was NaN),
        # so no checkpoint was saved; fall back to the last epoch's predictions
        # instead of failing on an unbound name.
        LOGGER.info('No epoch improved on the initial best score; using last-epoch predictions')
        best_model_predictions = predictions

    valid.loc[:, "preds"] = best_model_predictions.tolist()
    valid.loc[:, "trues"] = ground_truths.tolist()

    torch.cuda.empty_cache()
    gc.collect()

    return valid

if __name__ == '__main__':
    def get_result(oof_df):
        """Log the mean sharpened cosine similarity between the ``preds`` and ``trues`` columns."""
        trues = oof_df["trues"].values
        preds = oof_df["preds"].values
        score = scs(preds, trues)
        LOGGER.info(f'Score: {score:<.4f}')


    oof_df = train_loop()
    # Report the score of the best-epoch predictions (get_result was defined
    # but never invoked before).
    get_result(oof_df)


### seq2seq的检索库生成

In [None]:
%%writefile prompts_embedding_index_generate.py
import sys
from sentence_transformers import SentenceTransformer, models
import pandas as pd
import gc
import numpy as np

# Candidate prompts for the retrieval library; the CSV must contain a
# `rewrite_prompt` column.
df = pd.read_csv(f"prompts_df.csv",)

#df = df.rename(columns={'text':"rewrite_prompt"})
# Texts to embed, in dataframe row order.
contexts = list(df['rewrite_prompt'])
import faiss

import pandas as pd
import time
from tqdm import tqdm
import numpy as np
from sentence_transformers import SentenceTransformer
import pickle
# Same sentence encoder that produced the training targets.
model = SentenceTransformer('sentence-transformers/sentence-t5-base')
model.max_seq_length = 512

# Normalised embeddings: inner product on them equals cosine similarity.
encoded_data = model.encode(contexts, batch_size=64, device='cuda', show_progress_bar=True, convert_to_tensor=True, normalize_embeddings=True)
# faiss expects a C-contiguous float32 matrix.
encoded_data = np.ascontiguousarray(encoded_data.detach().cpu().numpy(), dtype='float32')
#np.save('prompts_embedding.npy',encoded_data)
# NOTE(review): this overwrites the input CSV read above, keeping only the
# `rewrite_prompt` column.
df['rewrite_prompt'].to_csv('prompts_df.csv', index=False)

# Exact inner-product index; take the dimension from the embeddings instead of
# hard-coding it so a different encoder does not silently break the index.
index = faiss.IndexFlatIP(encoded_data.shape[1])
index.add(encoded_data)
faiss.write_index(index, 'prompts_embedding.index')

### seq2seq推理

In [None]:
%%writefile infer_seq2seq.py
import gc
import math
import matplotlib.pyplot as plt
import multiprocessing
import numpy as np
import os
import pandas as pd
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F


from glob import glob
from torch.nn import Parameter
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from typing import Dict, List


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Current device is: {device}")

import tokenizers
import transformers
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup


class config:
    # Inference-time settings.
    BATCH_SIZE_TEST = 4
    DEBUG = False
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    GRADIENT_CHECKPOINTING = True  # CustomModel enables gradient checkpointing on the backbone when True
    # Only read by CustomModel when no config file is given or pretrained=True;
    # inference_fn passes a saved config and pretrained=False.
    MODEL = "microsoft/deberta-v3-base"
    MAX_LEN = 512  # tokenizer max_length (the training script in this notebook uses 384)
    NUM_WORKERS = 0 # multiprocessing.cpu_count()
    SEED = 20

class paths:
    # Competition test file and the tokenizer directory stored with the exp14 artifacts.
    TEST_CSV = "/kaggle/input/llm-prompt-recovery/test.csv"
    TOKENIZER = '/kaggle/input/llmpr-models3/exp14/tokenizer'
    
# Checkpoint(s) to run; model_configs lists the backbone config file for each
# checkpoint at the same position.
model_weights = [
                
                '/kaggle/input/llmpr-models3/exp14/deberta-v3-large_fold_0_best.pth',
                ]
model_configs = [
                 '/kaggle/input/llmpr-models3/exp14/config.pth',
                 
               ]
# Bare expression left over from notebook display; it has no effect in a script.
model_weights

def get_config_dict(config):
    """Collect the plain (non-callable, non-dunder) attributes of the ``config`` class into a dict."""
    config_dict = {}
    for key, value in config.__dict__.items():
        if callable(value) or key.startswith('__'):
            continue
        config_dict[key] = value
    return config_dict


def seed_everything(seed=20):
    """Seed the Python, NumPy and PyTorch RNGs and ask cuDNN for deterministic kernels."""
    os.environ.update(PYTHONHASHSEED=str(seed))
    random.seed(seed)
    np.random.seed(seed)
    for torch_seeder in (torch.manual_seed, torch.cuda.manual_seed):
        torch_seeder(seed)
    torch.backends.cudnn.deterministic = True
    
    
def to_device(inputs, device: str = device):
    """Build a new dict whose values are the values of ``inputs`` moved to ``device``."""
    return dict((name, value.to(device)) for name, value in inputs.items())

seed_everything(seed=config.SEED)

# Load the competition test set.
test_df = pd.read_csv(paths.TEST_CSV)
print(f"Dataframe has shape: {test_df.shape}")
test_df.head()  # notebook-style preview; no effect when run as a script

# Tokenizer saved by the training run.
tokenizer = AutoTokenizer.from_pretrained(paths.TOKENIZER)
print(tokenizer)

# Replace missing texts with empty strings before tokenization.
test_df['original_text'] = test_df['original_text'].fillna("")
test_df['rewritten_text'] = test_df['rewritten_text'].fillna("")

def prepare_input(cfg: type, text: np.ndarray, tokenizer):
    """
    Tokenize ``text`` with padding/truncation to ``cfg.MAX_LEN`` and turn every
    field of the resulting encoding into a ``torch.long`` tensor.
    :param cfg: configuration class providing ``MAX_LEN``.
    :param text: text to encode (the dataset in this file passes one string at a time).
    :param tokenizer: object exposing ``encode_plus``.
    :return inputs: the tokenizer's encoding, values replaced by torch tensors.
    """
    inputs = tokenizer.encode_plus(
        text,
        max_length=cfg.MAX_LEN,
        padding='max_length',  # fixed-length padding; `collate` trims per batch
        truncation=True,
        add_special_tokens=True,
        return_tensors=None,
    )
    for name in list(inputs.keys()):
        inputs[name] = torch.tensor(inputs[name], dtype=torch.long)
    return inputs


def collate(inputs):
    """
    Cut every field of ``inputs`` (in place) down to the largest number of
    attended tokens found in the batch, and return ``inputs``.
    """
    attended_per_row = inputs["attention_mask"].sum(axis=1)
    keep = int(attended_per_row.max())
    for name in list(inputs.keys()):
        inputs[name] = inputs[name][:, :keep]
    return inputs


class CustomDataset(Dataset):
    """Test-time dataset yielding the tokenized original/rewritten texts and the id of each row."""

    def __init__(self, cfg, df, tokenizer):
        self.cfg = cfg
        self.tokenizer = tokenizer
        self.original_texts = df['original_text'].values
        self.rewritten_texts = df['rewritten_text'].values
        self.rewrite_prompts = []  # no targets at inference time
        self.text_ids = df['id'].astype(str).values

    def __len__(self):
        return len(self.text_ids)

    def __getitem__(self, item):
        return {
            "original_text": prepare_input(self.cfg, self.original_texts[item], self.tokenizer),
            "rewritten_text": prepare_input(self.cfg, self.rewritten_texts[item], self.tokenizer),
            "rewrite_prompt": [],  # placeholder so the item keeps the training-time keys
            "id": self.text_ids[item],
        }

class MeanPooling(nn.Module):
    """Average the hidden states along dimension 1, weighting each position by its mask value."""

    def __init__(self):
        super().__init__()

    def forward(self, last_hidden_state, attention_mask):
        weights = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        weighted_sum = torch.sum(weights * last_hidden_state, 1)
        # Lower bound keeps the division finite when a row has no attended tokens.
        normaliser = torch.clamp(weights.sum(1), min=1e-9)
        return weighted_sum / normaliser
    

class CustomModel(nn.Module):
    """
    Transformer backbone plus MLP head that maps an (original, rewritten) text
    pair to a 768-dimensional vector. The attribute names `model`, `pool` and
    `head` determine the state-dict keys loaded at inference time.
    """
    def __init__(self, cfg, config_path=None, mode: str = "test", pretrained=False):
        super().__init__()
        self.cfg = cfg
        self.mode = mode
        self.dropout = 0.2  # not read anywhere in this class
        # Load config by inferencing it from the model name.
        if config_path is None: 
            self.config = AutoConfig.from_pretrained(cfg.MODEL, output_hidden_states=True)
            # Disable every dropout setting on the backbone config.
            self.config.hidden_dropout = 0.
            self.config.hidden_dropout_prob = 0.
            self.config.attention_dropout = 0.
            self.config.attention_probs_dropout_prob = 0.
        # Load config from a file.
        else:
            # NOTE(review): torch.load unpickles arbitrary objects; only use trusted files.
            self.config = torch.load(config_path)
        
        if pretrained:
            self.model = AutoModel.from_pretrained(cfg.MODEL, config=self.config)
        else:
            # Architecture only; weights are expected to come from a state dict.
            self.model = AutoModel.from_config(self.config)
        
        if self.cfg.GRADIENT_CHECKPOINTING:
            self.model.gradient_checkpointing_enable()
          
        # Placeholder: no prompt encoder is available in this script (see forward).
        self.t5_encoder = None #hub.KerasLayer(cfg.T5_MODEL)
        self.pool = MeanPooling()
        # Input is 4 * hidden_size: two pooled layers for each of the two texts.
        self.head = nn.Sequential(
            nn.Linear(self.config.hidden_size*4, 32768),
            nn.BatchNorm1d(32768),
            nn.ReLU(),
            nn.Linear(32768, 768),
        )
        # NOTE(review): self.head is an nn.Sequential, so none of the isinstance
        # branches in _init_weights match and this call changes nothing.
        self._init_weights(self.head)
        
    def _init_weights(self, module):
        """
        Initialize ``module`` when it is an nn.Linear, nn.Embedding or
        nn.LayerNorm; any other module type is left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        
    def feature(self, inputs):
        """
        Run the backbone and return the mean-pooled last two hidden layers
        concatenated along dim 1, i.e. shape (batch_size, 2 * hidden_size).
        """
        outputs = self.model(**inputs)
        #last_hidden_states = outputs[1]
        feature1 = self.pool(outputs.hidden_states[-1], inputs['attention_mask'])
        feature2 = self.pool(outputs.hidden_states[-2], inputs['attention_mask'])
        return torch.cat([feature1, feature2], dim=1)

    def forward(self, original_texts, rewritten_texts, rewrite_prompts):
        """
        Encode both texts, concatenate their features and project them through
        the head.
        :return: ``(output, prompt_embedding)``; ``output`` is the head output
            and ``prompt_embedding`` is an empty list unless mode is "train".
        """
        original_texts_feature = self.feature(original_texts) # shape (batch_size, 2 * hidden_size)
        rewritten_texts_feature = self.feature(rewritten_texts) # shape (batch_size, 2 * hidden_size)
        feature = torch.cat([original_texts_feature, rewritten_texts_feature], dim=1) # shape (batch_size, 4 * hidden_size)
        output = self.head(feature)
        
        if self.mode == "train":
            # NOTE(review): self.t5_encoder is set to None in __init__, so this
            # branch fails if the model is built with mode="train".
            prompt_embedding = torch.tensor(self.t5_encoder(rewrite_prompts)[0].numpy(), device=self.cfg.DEVICE) # shape (batch_size, 768)
        else:
            prompt_embedding = []
            
        return output, prompt_embedding
    

def sharpened_cosine_similarity(k: np.ndarray, s: np.ndarray, p: int = 3, q: float = 1e-10):
    """
    Computes Sharpened Cosine Similarity (SCS) between two numpy arrays of shape (N,):

        score = sign(k . s) * (|k . s| / ((||k|| + q) * (||s|| + q))) ** p

    The sign multiplies the magnitude raised to ``p`` (multiplying it by
    ``cos ** p`` would make negative similarities positive for odd ``p``), and
    ``q`` is added to both norms so a zero vector gives 0 instead of NaN.
    :param k: prediction embedding.
    :param s: ground truth embedding.
    :param p: SCS power.
    :param q: small value for numerical stability.
    :return score: SCS score.
    """
    k = np.asarray(k)
    s = np.asarray(s)
    dot_product = np.dot(k, s)
    denominator = (np.linalg.norm(k) + q) * (np.linalg.norm(s) + q)
    magnitude = np.abs(dot_product) / denominator
    score = np.sign(dot_product) * magnitude ** p
    return score


def scs(k_batch: np.ndarray, s_batch: np.ndarray, p: int = 3, q: float = 1e-10):
    """
    Mean Sharpened Cosine Similarity (SCS) over two batches of shape (batch_size, N).

    Rows are paired by position: row ``i`` of ``k_batch`` is scored against row
    ``i`` of ``s_batch`` with ``sharpened_cosine_similarity``.

    :param k_batch: prediction embeddings.
    :param s_batch: ground truth embeddings.
    :param p: SCS power.
    :param q: small value for numerical stability.
    :return score: mean SCS score for the batch.
    """
    row_count = k_batch.shape[0]
    per_row_scores = [
        sharpened_cosine_similarity(k_batch[row], s_batch[row], p, q)
        for row in range(row_count)
    ]
    return np.mean(per_row_scores)

def inference_fn(model_weight, config, test_df, tokenizer, device, model_config):
    """
    Run one checkpoint over the test set and collect its raw predictions.

    :param model_weight: path (or file-like) of the saved ``state_dict``.
    :param config: run configuration; ``config.BATCH_SIZE_TEST`` sets the batch size.
    :param test_df: test dataframe handed to ``CustomDataset``.
    :param tokenizer: tokenizer handed to ``CustomDataset``.
    :param device: device the model is moved to for inference.
    :param model_config: value passed to ``CustomModel`` as ``config_path``.
    :return: dict with ``"predictions"`` (per-batch outputs concatenated with
        ``np.concatenate``) and ``"ids"`` (the ids in loader order).
    """
    # ======== DATASETS ==========
    test_dataset = CustomDataset(config, test_df, tokenizer)

    # ======== DATALOADERS ==========
    # shuffle=False / drop_last=False keep predictions aligned with the ids.
    test_loader = DataLoader(
        test_dataset,
        batch_size=config.BATCH_SIZE_TEST,
        shuffle=False,
        num_workers=0,
        pin_memory=True, drop_last=False
    )

    # ======== MODEL ==========
    model = CustomModel(config, config_path=model_config, pretrained=False)
    # Deserialize on CPU so a checkpoint saved from a GPU also loads on a
    # CPU-only host; the weights move to `device` with the model right after.
    state = torch.load(model_weight, map_location="cpu")
    model.load_state_dict(state)
    model.to(device)
    model.eval() # set model in evaluation mode

    preds, ids = [], []
    with tqdm(test_loader, unit="test_batch", desc='Test') as tqdm_test_loader:
        for batch in tqdm_test_loader:
            ids_batch = batch.pop("id")
            original_texts = to_device(collate(batch.pop("original_text")))
            rewritten_texts = to_device(collate(batch.pop("rewritten_text")))
            with torch.no_grad():
                # No rewrite prompts exist at inference time; the second
                # return value (prompt embedding) is discarded.
                y_preds, _ = model(original_texts, rewritten_texts, [])
            preds.append(y_preds.to('cpu').numpy()) # save predictions
            ids += ids_batch

    return {"predictions": np.concatenate(preds), "ids": ids}

# Collect one L2-normalized prediction matrix per checkpoint, then average them.
preds = []
for model_weight, model_config in zip(model_weights, model_configs):
    raw_predictions = inference_fn(model_weight, config, test_df, tokenizer, device, model_config)["predictions"]
    unit_predictions = torch.nn.functional.normalize(torch.from_numpy(raw_predictions), p=2, dim=1)
    preds.append(unit_predictions.numpy())

preds = np.mean(preds, axis=0)

import faiss
from faiss import write_index, read_index, read_VectorTransform

# Nearest stored prompt embedding (top-1) for every averaged prediction.
prompts_embedding_index = read_index("./prompts_embedding.index")
search_score, search_index = prompts_embedding_index.search(preds, 1)
prompts_df = pd.read_csv("./prompts_df.csv")
prompts_df.head()

# Each row of search_index holds the single retrieved row label of prompts_df.
pred_prompts = [
    ''.join(prompts_df.loc[neighbour_rows, "rewrite_prompt"].tolist())
    for neighbour_rows in tqdm(search_index, total=len(search_score))
]

values = pred_prompts

submission = pd.DataFrame()
submission["id"] = test_df["id"]
submission["rewrite_prompt"] = values
submission.to_csv("pred1.csv", index=False)

In [None]:
!python infer_seq2seq.py

### 开源微调模型sft-phi2的推理

In [None]:
%%writefile infer_phi.py
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch

from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Token budgets handed to text_generate: prompt truncation length and new tokens.
input_token_len, output_token_len = 1024, 100

test_df = pd.read_csv('/kaggle/input/llm-prompt-recovery/test.csv')

# Base phi-2 checkpoint and the SFT adapter that is applied on top of it.
base_model_name = "/kaggle/input/phi/transformers/2/1"#/kaggle/input/phi/transformers/2/1
adapter_model_name = "/kaggle/input/phi2-public-data-sft-adapter/pytorch/public-data-sft/1/phi2_public_data_sft"

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(base_model_name,trust_remote_code=True)
# Fall back to the EOS token when the tokenizer defines no padding token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(base_model_name,trust_remote_code=True)
model = PeftModel.from_pretrained(base_model, adapter_model_name)
model.to(device)
model.eval()
print('model loaded !!')
def text_generate(ori_text, rew_text,model, tokenizer, stop_tokens=['.',';',':','<|endoftext|>'], input_max_len=512, output_len=20, device='cuda'):
    """
    Ask the model which prompt turned ``ori_text`` into ``rew_text``.

    :param ori_text: original text.
    :param rew_text: rewritten text.
    :param model: object exposing ``generate``.
    :param tokenizer: tokenizer used to encode the instruction and decode the reply.
    :param stop_tokens: tokens whose ids are passed as ``eos_token_id``; a single
        trailing occurrence is removed from the returned text. Not mutated.
    :param input_max_len: maximum number of prompt tokens (longer prompts are truncated).
    :param output_len: maximum number of newly generated tokens.
    :param device: device the encoded inputs are moved to.
    :return: the generated prompt text, stripped, without the trailing stop token.
    """
    prompt = f"Instruct: Original Text:{ori_text}\nRewritten Text:{rew_text}\nWrite a prompt that was likely given to the LLM to rewrite original text to rewritten text.\nOutput:"
    inputs = tokenizer(prompt, max_length=input_max_len, truncation=True, return_tensors="pt", return_attention_mask=False)
    # Number of prompt tokens actually fed to the model (after truncation).
    output_start_index = len(inputs.input_ids[0])
    inputs = {k:v.to(device) for k,v in inputs.items()}
    outputs = model.generate(**inputs,
                             do_sample=False,
                             max_new_tokens=output_len,
                             pad_token_id=tokenizer.pad_token_id,
                             eos_token_id=tokenizer.convert_tokens_to_ids(stop_tokens),
                            )
    # Decode only the newly generated tokens. Searching the full decoded text
    # for 'Output:' breaks when the input texts contain that marker, or when
    # truncation removed it from the end of the prompt.
    new_tokens = outputs[:, output_start_index:]
    generated_text = tokenizer.batch_decode(new_tokens,skip_special_tokens=True,clean_up_tokenization_spaces=False)[0].strip()
    # Drop the stop token only if generation really ended on one; a reply cut
    # off by max_new_tokens must keep its last character.
    for stop in stop_tokens:
        if stop and generated_text.endswith(stop):
            generated_text = generated_text[:-len(stop)]
            break
    return generated_text

import nltk
from nltk import sent_tokenize
import re
# Generate one recovered prompt per test row. Inference is best-effort: a row
# that fails or yields nothing keeps the generic fallback prompt.
rewrite_prompts = []
for i, row in tqdm(test_df.iterrows(), total=len(test_df)):
    prompt = 'Please improve this text.'
    try:
        generated = text_generate(row['original_text'],
                                  row['rewritten_text'],
                                  model,
                                  tokenizer,
                                  ['.',';',':','<|endoftext|>'],
                                  input_token_len,
                                  output_token_len,
                                  device,
                                 )
        # An empty generation would end up as a blank prediction in the CSV.
        if generated and generated.strip():
            prompt = generated
    except Exception as err:
        # Catch Exception (not a bare except) so Ctrl-C still stops the run,
        # and report the failure instead of hiding it.
        print(f"row {i}: generation failed ({err!r}); using fallback prompt")

    rewrite_prompts.append(prompt)



test_df['rewrite_prompt'] = rewrite_prompts
sub_df = test_df[['id', 'rewrite_prompt']]
sub_df.to_csv('pred2.csv', index=False)

In [None]:
!python infer_phi.py

### LLM模型的zero-shot推理

In [None]:
%%writefile mistral_infer.py
import torch
import random
import numpy as np
import pandas as pd
import gc
import time

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

#https://github.com/Lightning-AI/lit-gpt/issues/327
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)

# Only warns; the script keeps going without a GPU.
if not torch.cuda.is_available():
    print("Sorry - GPU required!")

import logging
logging.getLogger('transformers').setLevel(logging.ERROR)

# Keeping the generation short can help speed up inference.
max_new_tokens = 30

# Output text is trimmed to this many sentences.
max_sentences_in_response = 1

model_name = '/kaggle/input/mistral-7b-it-v02'
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load base model (Mistral 7B) with 4-bit NF4 quantization; the same dtype is
# used for quantized compute and for the model's torch_dtype.
compute_dtype = torch.bfloat16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=False,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    torch_dtype=compute_dtype,
    device_map="auto",
    trust_remote_code=True,
)
#original text prefix
orig_prefix = "Original Text:"

#mistral "response"
llm_response_for_rewrite = "Provide the new text and I will tell you what new element was added or change in tone was made to improve it - with no references to the original.  I will avoid mentioning names of characters.  It is crucial no person, place or thing from the original text be mentioned.  For example - I will not say things like 'change the puppet show into a book report' - I would just say 'Please improve this text using the writing style of a book report'.  If the original text mentions a specific idea, person, place, or thing - I will not mention it in my answer.  For example if there is a 'dog' or 'office' in the original text - the word 'dog' or 'office' must not be in my response.  My answer will be a single sentence."

#modified text prefix
rewrite_prefix = "Re-written Text:"

#provided as start of Mistral response (anything after this is used as the prompt)
#providing this as the start of the response helps keep things relevant
response_start = "The request was: "

#added after response_start to prime mistral
#"Improve this" or "Improve this text" resulted in non-answers.
#"Improve this text by" seemed to produce good results; the prefix actually used is the one assigned below
response_prefix = "Please improve this text using the writing style"

#well-scoring baseline text
#thanks to: https://www.kaggle.com/code/rdxsun/lb-0-61
base_line = 'Please improve this text using the writing style with maintaining the original meaning but altering the tone.' 

#these will all be given to Mistral before each and every prompt
#original_text
#rewritten_text
#prompt

examples_sequences = [
    (
        "Hey there! Just a heads up: our friendly dog may bark a bit, but don't worry, he's all bark and no bite!",
        "Warning: Protective dog on premises. May exhibit aggressive behavior. Ensure personal safety by maintaining distance and avoiding direct contact.",
        "Please improve this text using the writing style of a warning."
    ),

    (
        "A lunar eclipse happens when Earth casts its shadow on the moon during a full moon. The moon appears reddish because Earth's atmosphere scatters sunlight, some of which refracts onto the moon's surface. Total eclipses see the moon entirely in Earth's shadow; partial ones occur when only part of the moon is shadowed.",
        "Yo check it, when the Earth steps in, takes its place, casting shadows on the moon's face. It's a full moon night, the scene's set right, for a lunar eclipse, a celestial sight. The moon turns red, ain't no dread, it's just Earth's atmosphere playing with sunlight's thread, scattering colors, bending light, onto the moon's surface, making the night bright. Total eclipse, the moon's fully in the dark, covered by Earth's shadow, making its mark. But when it's partial, not all is shadowed, just a piece of the moon, slightly furrowed. So that's the rap, the lunar eclipse track, a dance of shadows, with no slack. Earth, moon, and sun, in a cosmic play, creating the spectacle we see today.",
        "Please improve this text using the writing style of a rap."
    ),
    
    (
        "Drinking enough water each day is crucial for many functions in the body, such as regulating temperature, keeping joints lubricated, preventing infections, delivering nutrients to cells, and keeping organs functioning properly. Being well-hydrated also improves sleep quality, cognition, and mood.",
        "Arrr, crew! Sail the health seas with water, the ultimate treasure! It steadies yer body's ship, fights off plagues, and keeps yer mind sharp. Hydrate or walk the plank into the abyss of ill health. Let's hoist our bottles high and drink to the horizon of well-being!",
        "Please improve this text using the writing style of a sea pirate."
    ),
    
    (
        "In a bustling cityscape, under the glow of neon signs, Anna found herself at the crossroads of endless possibilities. The night was young, and the streets hummed with the energy of life. Drawn by the allure of the unknown, she wandered through the maze of alleys and boulevards, each turn revealing a new facet of the city's soul. It was here, amidst the symphony of urban existence, that Anna discovered the magic hidden in plain sight, the stories and dreams that thrived in the shadows of skyscrapers.",
        "On an ordinary evening, amidst the cacophony of a neon-lit city, Anna stumbled upon an anomaly - a door that defied the laws of time and space. With the curiosity of a cat, she stepped through, leaving the familiar behind. Suddenly, she was adrift in the stream of time, witnessing the city's transformation from past to future, its buildings rising and falling like the breaths of a sleeping giant.",
        "Please improve this text using the writing style with time travel topic."
    ),
    
    (
        "Late one night in the research lab, Dr. Evelyn Archer was on the brink of a breakthrough in artificial intelligence. Her fingers danced across the keyboard, inputting the final commands into the system. The lab was silent except for the hum of machinery and the occasional beep of computers. It was in this quiet orchestra of technology that Evelyn felt most at home, on the cusp of unveiling a creation that could change the world.",
        "In the deep silence of the lab, under the watchful gaze of the moon, Dr. Evelyn Archer found herself not alone. Beside her, the iconic red eye of HAL 9000 flickered to life, a silent partner in her nocturnal endeavor. 'Good evening, Dr. Archer,' HAL's voice filled the room, devoid of warmth yet comforting in its familiarity. Together, they were about to initiate a test that would intertwine the destiny of human and artificial intelligence forever. As Evelyn entered the final command, HAL processed the data with unparalleled precision, a testament to the dawn of a new era.",
        "Please improve this text using the writing style with an intelligent computer."
    ),
    
    (
        "The park was empty, save for a solitary figure sitting on a bench, lost in thought. The quiet of the evening was punctuated only by the occasional rustle of leaves, offering a moment of peace in the chaos of city life.",
        "Beneath the cloak of twilight, the park transformed into a realm of solitude and reflection. There, seated upon an ancient bench, was a lone soul, a guardian of secrets, enveloped in the serenity of nature's whispers. The dance of the leaves in the gentle breeze sang a lullaby to the tumult of the urban heart.",
        "Please improve this text using the writing style to be more poetic."
    ),
    
    (
        "The annual town fair was bustling with activity, from the merry-go-round spinning with laughter to the game booths challenging eager participants. Amidst the excitement, a figure in a cloak moved silently, almost invisibly, among the crowd, observing everything with keen interest but participating in none.",
        "Beneath the riot of color and sound that marked the town's annual fair, a solitary figure roamed, known to the few as Eldrin the Enigmatic. Clad in a cloak that shimmered with the whispers of the arcane, Eldrin moved with the grace of a shadow, his gaze piercing the veneer of festivity to the magic beneath. As a master of the mystic arts, he sought not the laughter of the crowds but the silent stories woven into the fabric of the fair. With a flick of his wrist, he could coax wonder from the mundane, transforming the ordinary into spectacles of shimmering illusion, his true participation hidden within the folds of mystery.",
        "Please improve this text using the writing style by adding a magician."
    ),
    
    (
        "The startup team sat in the dimly lit room, surrounded by whiteboards filled with ideas, charts, and plans. They were on the brink of launching a new app designed to make home maintenance effortless for homeowners. The app would connect users with local service providers, using a sophisticated algorithm to match needs with skills and availability. As they debated the features and marketing strategies, the room felt charged with the energy of creation and the anticipation of what was to come.",
        "In the quiet before dawn, a small group of innovators gathered, their mission: to simplify home maintenance through technology. But their true journey began with the unexpected addition of Max, a talking car with a knack for solving problems. 'Let me guide you through this maze of decisions,' Max offered, his dashboard flickering to life.",
        "Please improve this text using the writing style by adding a talking car."
    ),
    
        

    
    
]

def remove_numbered_list(text):
    """
    Strip leading "N. " numbering from every line and join the lines with two spaces.

    A line counts as a numbered item only when everything before its first
    '. ' is made of digits; any other line is kept unchanged.
    """
    def strip_number(line):
        head, separator, tail = line.partition('. ')
        if separator and head.isdigit():
            return tail
        return line

    return '  '.join(strip_number(line) for line in text.split('\n'))


#trims LLM output to just the response
def trim_to_response(text):
    """
    Return what follows the last "[/INST]" marker, with '</s>' and all single
    and double quote characters removed. Without a marker the cleaned text is
    returned whole.
    """
    terminate_string = "[/INST]"
    # Drop the end-of-sequence tag and any quoting the model may have added.
    for unwanted in ('</s>', '"', "'"):
        text = text.replace(unwanted, '')

    marker_at = text.rfind(terminate_string)
    if marker_at == -1:
        return text
    return text[marker_at + len(terminate_string):]

#looks for response_start / returns only text that occurs after
def extract_text_after_response_start(full_text):
    """
    Return the stripped text after the last occurrence of ``response_start``;
    if ``response_start`` does not occur, return ``full_text`` unchanged.
    """
    _, found, remainder = full_text.rpartition(response_start)
    if found:
        return remainder.strip()
    return full_text

    
#trims text to requested number of sentences (or first LF or double-space sequence)
def trim_to_first_x_sentences_or_lf(text, x):
    """
    Keep at most ``x`` sentences from the part of ``text`` before its first
    line break (a double space counts as a line break).

    :param text: text to trim.
    :param x: maximum number of sentences; ``x <= 0`` yields "".
    :return: the trimmed text, ending in '.' whenever at least one sentence
        fragment was found.
    """
    if x <= 0:
        return ""

    # Double spaces are treated as linefeeds; only the text before the first
    # linefeed is considered.
    first_chunk, linefeed, _ = text.replace("  ", "\n").partition('\n')

    # Sentences are the non-empty pieces between periods, stripped afterwards.
    sentences = [piece.strip() for piece in first_chunk.split('.') if piece]

    if linefeed and len(sentences) < x:
        # Fewer sentences than requested before the linefeed: keep that chunk verbatim.
        trimmed_text = first_chunk
    else:
        # Slicing past the end is harmless, so this also covers "x or fewer sentences".
        trimmed_text = '. '.join(sentences[:x]).strip()

    # Restore the final period that splitting removed.
    if sentences and not trimmed_text.endswith('.'):
        trimmed_text += '.'

    return trimmed_text

def get_prompt(orig_text, transformed_text):
    """
    Ask the model for the instruction that turned ``orig_text`` into
    ``transformed_text`` and return it as a single cleaned-up sentence.

    The chat consists of every entry of ``examples_sequences`` as a worked
    example followed by the real pair. If the cleaned reply is shorter than 15
    characters, ``base_line`` is returned instead.
    """
    def turn(role, content):
        return {"role": role, "content": content}

    stop_tokens = ['.',':']

    # Worked examples first, the actual pair last; the last assistant turn
    # carries only response_start so the model continues from there.
    exchanges = [
        (example_text, example_rewrite, f"{response_start} {example_prompt}")
        for example_text, example_rewrite, example_prompt in examples_sequences
    ]
    exchanges.append((orig_text, transformed_text, f"{response_start}"))

    messages = []
    for source_text, rewritten, assistant_reply in exchanges:
        messages.extend([
            turn("user", f"{orig_prefix} {source_text}"),
            turn("assistant", llm_response_for_rewrite),
            turn("user", f"{rewrite_prefix} {rewritten}"),
            turn("assistant", assistant_reply),
        ])

    #give it to Mistral
    model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt")
    output_start_index = len(model_inputs[0])

    # [position, token id] pairs for the tokens of response_prefix, positioned
    # right after the prompt.
    decode_ids = tokenizer.encode(response_prefix, add_special_tokens=False)
    force_decoder_ids = [
        [output_start_index + offset, token_id]
        for offset, token_id in enumerate(decode_ids)
    ]

    generated_ids = model.generate(
        model_inputs.to("cuda"),
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.convert_tokens_to_ids(stop_tokens),
        forced_decoder_ids=force_decoder_ids,
    )

    # Decode, keep what follows the last [/INST], then what follows response_start.
    decoded = tokenizer.batch_decode(generated_ids)
    final_text = extract_text_after_response_start(trim_to_response(decoded[0]))

    #mistral has been replying with numbered lists - clean them up....
    final_text = remove_numbered_list(final_text)

    #mistral v02 tends to respond with the input after providing the answer - this tries to trim that down
    final_text = trim_to_first_x_sentences_or_lf(final_text, max_sentences_in_response)

    #default to baseline if empty or unusually short
    if len(final_text) < 15:
        return base_line

    # Replace the last character (the period added by the trimming step) with the fixed suffix.
    return final_text[:-1] + ', maintaining the original meaning but altering the tone.'

test_df = pd.read_csv("/kaggle/input/llm-prompt-recovery/test.csv")

# Collect the predictions in a list and assign the column once: this also
# works for an empty test set, where per-cell `.at` writes would never create
# the column and the selection below would raise KeyError.
recovered_prompts = []
for index, row in test_df.iterrows():
    try:
        result = get_prompt(row['original_text'], row['rewritten_text'])
    except Exception as err:
        # Best-effort, like the phi-2 script: one failing row must not abort
        # the run and leave pred3.csv unwritten.
        print(f"row {index}: generation failed ({err!r}); using baseline prompt")
        result = base_line
    print(result)
    recovered_prompts.append(result)

test_df['rewrite_prompt'] = recovered_prompts

test_df = test_df[['id', 'rewrite_prompt']]
test_df.to_csv('pred3.csv', index=False)

In [None]:
!python mistral_infer.py

### 集成三种模型的结果

In [None]:
import pandas as pd
def _load_predictions(path):
    """Read one prediction CSV, ordered by id, with missing values replaced by ''."""
    return pd.read_csv(path).sort_values(['id']).reset_index(drop=True).fillna('')


p1, p2, p3 = (_load_predictions(name) for name in ('pred1.csv', 'pred2.csv', 'pred3.csv'))

# Rows are matched by position after sorting on id; the final prompt is the
# three predictions separated by single spaces.
combined_prompt = p1['rewrite_prompt'].map(str)
for other in (p2, p3):
    combined_prompt = combined_prompt + ' ' + other['rewrite_prompt'].map(str)
p1['rewrite_prompt'] = combined_prompt

print(p1['rewrite_prompt'].iloc[0])
p1.to_csv('submission.csv', index=False)