In [1]:
import torch
from tqdm import tqdm
import pandas as pd
import numpy as np
# from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from sklearn.decomposition import PCA
# import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.metrics.pairwise import cosine_similarity
# tqdm.pandas()

from transformers import pipeline, AutoTokenizer,GPT2Tokenizer, GPT2Model, AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset, DatasetDict, Dataset
from sklearn.model_selection import train_test_split

import os
import time

import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForMaskedLM, AdamW
from tqdm import tqdm
import re

2024-03-15 06:07:17.034721: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()
# print(torch.cuda.memory_summary())

device

device(type='cuda')

In [3]:
df = pd.read_csv("data/8000_data.csv")
df = df.dropna()
df['original_text'] = df['original_text'].apply(lambda x: x[:200])
df['rewritten_text'] = df['rewritten_text'].apply(lambda x: x[:200])

df = df[:1000]

In [4]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModelForCausalLM.from_pretrained('bert-base-uncased')
model.eval()

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`


BertLMHeadModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_a

In [5]:
import re
import torch
from torch.utils.data import Dataset

class MaskedSequenceDataset(Dataset):
    def __init__(self, dataframe, tokenizer, mask_token='[MASK]'):
        self.tokenizer = tokenizer
        self.dataframe = dataframe
        self.mask_token = mask_token

        self.dataframe['original_text'] = dataframe['original_text'].apply(lambda x: re.sub('<.*?>', '', x).strip())
        self.dataframe['rewritten_text'] = dataframe['rewritten_text'].apply(lambda x: re.sub('<.*?>', '', x).strip())
        self.dataframe['rewrite_prompt'] = dataframe['prompt'].apply(lambda x: re.sub('<.*?>', '', x).strip())

    def combine_and_mask(self, original, rewrite, prompt_length):
        masks = " ".join([self.mask_token for _ in range(prompt_length)])
        masked_sequence = f"{original} The task is to rewrite this narrative with the given blanks: {masks} {rewrite}"
        return masked_sequence

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        prompt = row['rewrite_prompt']
        
        target_subsequence_ids = self.tokenizer.encode(prompt, add_special_tokens=False)

        masked_sequence = self.combine_and_mask(row['original_text'], row['rewritten_text'], len(target_subsequence_ids))
        
        inputs = self.tokenizer(masked_sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt")
        input_ids = inputs['input_ids'].squeeze(0)
        
        labels = torch.full_like(input_ids, fill_value=-100)
        
        mask_indices = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=False).squeeze()
        if len(mask_indices) >= len(target_subsequence_ids):
            labels[mask_indices[:len(target_subsequence_ids)]] = torch.tensor(target_subsequence_ids, dtype=torch.long)
        else:
            raise ValueError("Not enough mask tokens to fit the rewrite prompt")

        return input_ids, labels


In [6]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
train_df, test_df = train_test_split(df, test_size=0.2)

train_dataset = MaskedSequenceDataset(train_df, tokenizer)
test_dataset = MaskedSequenceDataset(test_df, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)



In [7]:
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.train()
model.to(device) 

optimizer = AdamW(model.parameters(), lr=5e-5)

epochs = 15
for epoch in range(epochs):
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}", unit="batch"):
        input_ids, labels = batch
        input_ids = input_ids.to(device)  # or 'cpu'
        labels = labels.to(device)  # or 'cpu'
        
        model.zero_grad()
        
        outputs = model(input_ids, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(train_loader)
    print(f"Average loss at epoch {epoch + 1}: {avg_loss:.4f}")



Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'bert.pooler.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Epoch 1/15:   0%|          | 0/50 [00:00<?, ?batch/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
Epoch 1/15: 100%|██████████| 50/50 [00:29<00:00,  1.68batch

Average loss at epoch 1: 4.4658


Epoch 2/15: 100%|██████████| 50/50 [00:29<00:00,  1.72batch/s]


Average loss at epoch 2: 2.8903


Epoch 3/15: 100%|██████████| 50/50 [00:28<00:00,  1.76batch/s]


Average loss at epoch 3: 2.3139


Epoch 4/15: 100%|██████████| 50/50 [00:29<00:00,  1.68batch/s]


Average loss at epoch 4: 1.8673


Epoch 5/15: 100%|██████████| 50/50 [00:29<00:00,  1.71batch/s]


Average loss at epoch 5: 1.5831


Epoch 6/15: 100%|██████████| 50/50 [00:28<00:00,  1.73batch/s]


Average loss at epoch 6: 1.3350


Epoch 7/15: 100%|██████████| 50/50 [00:28<00:00,  1.76batch/s]


Average loss at epoch 7: 1.1343


Epoch 8/15: 100%|██████████| 50/50 [00:29<00:00,  1.70batch/s]


Average loss at epoch 8: 0.8650


Epoch 9/15: 100%|██████████| 50/50 [00:28<00:00,  1.74batch/s]


Average loss at epoch 9: 0.7224


Epoch 10/15: 100%|██████████| 50/50 [00:29<00:00,  1.70batch/s]


Average loss at epoch 10: 0.5600


Epoch 11/15: 100%|██████████| 50/50 [00:29<00:00,  1.68batch/s]


Average loss at epoch 11: 0.4381


Epoch 12/15: 100%|██████████| 50/50 [00:29<00:00,  1.68batch/s]


Average loss at epoch 12: 0.4278


Epoch 13/15: 100%|██████████| 50/50 [00:29<00:00,  1.67batch/s]


Average loss at epoch 13: 0.4514


Epoch 14/15: 100%|██████████| 50/50 [00:29<00:00,  1.67batch/s]


Average loss at epoch 14: 0.2829


Epoch 15/15: 100%|██████████| 50/50 [00:29<00:00,  1.69batch/s]

Average loss at epoch 15: 0.2292





In [None]:
# Save the model after training
# model.save_pretrained('model/fill_blanks.pth')
# tokenizer.save_pretrained('model/fill_blanks_tokenizer.pth')

In [8]:
def batch_cosine_similarity(x1, x2):
    x1_norm = torch.nn.functional.normalize(x1, p=2, dim=-1)
    x2_norm = torch.nn.functional.normalize(x2, p=2, dim=-1)
    
    cos_sim = torch.mm(x1_norm, x2_norm.transpose(0, 1))
    
    return cos_sim

from sentence_transformers import SentenceTransformer

scs_model = SentenceTransformer("sentence-t5-base")

def sharpened_cosine_similarity_batch(scs_model, output_texts, target_texts, sharpen_factor=3):
    target_embeddings = scs_model.encode(target_texts, convert_to_tensor=True)
    output_embeddings = scs_model.encode(output_texts, convert_to_tensor=True)
    
    cos_sims = batch_cosine_similarity(target_embeddings, output_embeddings)
    
    sharpened_scores = [cos_sims[i][i].unsqueeze(0) ** sharpen_factor for i in range(cos_sims.size(0))]
    
    return sharpened_scores

In [9]:
predicts = []
targets = []
for batch in test_loader:
    input_ids, labels = batch
    input_ids = input_ids.to('cuda')
    # Make predictions
    with torch.no_grad():
        outputs = model(input_ids)
        predictions = outputs.logits.argmax(dim=-1)
        
        labels_index = (labels != -100).nonzero(as_tuple=True)[1]
        
        labels_idx = labels[0,labels_index]
        label_tokens = tokenizer.convert_ids_to_tokens(labels_idx)
        label_sentences = tokenizer.convert_tokens_to_string(label_tokens)

        predicted_idx = predictions[0,labels_index]
        predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_idx)
        predicted_sentences = tokenizer.convert_tokens_to_string(predicted_tokens)

        print("Predicted:", predicted_sentences)
        print("Labels:", label_sentences)
        predicts.append(predicted_sentences)
        targets.append(label_sentences)
    


score = sharpened_cosine_similarity_batch(scs_model, predicts, targets, sharpen_factor=3)


Predicted: re the essence text in text into a a ai scientist narrative . narrative . my wheels adventure through the celestial romance , ign narrative . my wheels adventure through the celestial romance , igniting the rest re the essence text in text into a a ai scientist narrative . my wheels adventure through the celestial romance , essence text in text into a a ai scientist narrative . my wheels adventure through the rest the rest re the essence text in text into a a ai scientist narrative . my wheels adventure in text into a a ai scientist narrative . my wheels adventure narrative . my wheels adventure through the celestial romance , igniting the olfactory senses the essence text in text into a a ai scientist narrative . my wheels adventure text in text into a a ai scientist narrative . my wheels adventure through the celestial romance rest the rest re the essence text in text into a a ai scientist narrative . my wheels adventure through the in text into a a ai scientist narrative 

Predicted: translate the essence of this text into a story narrative . this narrative rest the text rest text translate translate the essencewrite this narrative rest the text rest text translate translate the essence of this text into story narrative . narrative action adventureled with the adventure of smoke action story translate the essence of this text into a story narrative . narrative action adventureled with . narrative action adventureled with the adventure of smoke action story as narrative rest the text rest text translate translate the essence of this text into a story narrative . narrative text into a story narrative . narrative action adventureled with the adventure of smoke action into a story narrative . narrative action adventureled with the adventure of smoke action story as rewrite this narrative rest the text rest text translate translate the essence of this text into a story narrative essence of this text into a story narrative . narrative action adventureled with 

Predicted: convey the same message as this text but through the eyes a a a cats detective . text : convey the same message as this text but through the eyes a a a rewrite this narrative rest the text text text : convey the this narrative rest the text text text : convey the same message as this text but through the text : convey the same message as this text but through the eyes a a a cats detective : convey the same message as this text but through the eyes is to rewrite this narrative rest the text text text : convey the same message rest the text text text : convey the same message as this same message as this text but through the eyes a a a cats detective . the same message as this text but through the eyes a a a cats detective . . listen text text : convey the same message as this text but through the eyes a a a cats detective . to rewrite this narrative rest the text text text : convey the same message as this text but the text text text : convey the same message as this text but

Predicted: imagine this text was a this in in the world of superhero , how would it fi written ? from the text blank rest : imagine this text was a this in in the blank rest : imagine this text was a this in in the world of superhero , narrative from the text blank rest : imagine this text was a this in was a this in in the world of superhero , how would it fi written to rewrite this narrative from the text blank rest : imagine this text text was a this in in the world of superhero , how would it fi rewrite this narrative from the text blank rest : imagine this : imagine this text was a this in in the world of superhero , how would it in the world of superhero , how would it fi written ? was a this in in the world of superhero , how would it fi written ?write this narrative from the text blank rest : imagine this text was a this in in the world of a this in in the world of superhero , how would it fi written ? ? my dear gentleman this narrative from the text blank rest : imagine this t

In [10]:
score = torch.mean(torch.stack(score))
score

tensor(0.7949, device='cuda:0')