In [1]:
import os
import json
import re
import string
import random
import time
import datetime

import numpy as np
import pandas as pd
# import matplotlib.pyplot as plt

from argparse import Namespace
from tqdm import tqdm
# from datasets import Dataset

import transformers
from transformers import BertTokenizer, BertModel, BertConfig
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import BertForSequenceClassification, AdamW
from transformers import get_linear_schedule_with_warmup
from transformers import pipeline
from transformers import BertTokenizer, DataCollatorForLanguageModeling

import torch.nn.functional as F
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Dataset

from sklearn.feature_extraction.text import TfidfVectorizer,CountVectorizer
from sklearn.feature_extraction import text
from sklearn.metrics.pairwise import cosine_similarity, linear_kernel
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Central run configuration; later cells read these values as `args.<name>`.
args = Namespace(
    mlm_dataset = "./processed_data/sentences.csv",  # CSV read for the MLM task (its `sentence` column is used later)
    sts_datapath = "processed_data/casehold_processed.csv",  # CSV read for the sentence-pair classification task
    model_save_path = "./models/mlm_sts_per_batch",  # NOTE(review): not read by any visible cell; the save cells hard-code their own paths
    train_split=0.7,  # fraction of rows left labelled 'train'; the remainder is halved into val/test
    learning_rate=0.01,  # NOTE(review): not read by any visible cell; the optimizer cell hard-codes lr=5e-5
    epochs=3,  # number of passes over the dataloader in the training loop
    num_samples = 10000  # cap on the number of rows taken from each CSV
)


## Data Preparation

In [3]:
# Tokenizer shared by both tasks (sentence-pair classification and MLM).
# NOTE(review): the model cell loads 'bert-base-uncased' weights rather than this
# checkpoint — confirm the tokenizer/model pairing is intentional.
tokenizer = BertTokenizer.from_pretrained('casehold/legalbert')

### Pre-process MLM Data

In [4]:
# Load at most `num_samples` sentences for the MLM task. `nrows` stops reading
# early instead of parsing the whole CSV and slicing afterwards.
mlm_df = pd.read_csv(args.mlm_dataset, nrows=args.num_samples)

# Half of the non-train fraction goes to validation and the other half to test.
num_holdout_rows = int(len(mlm_df) * (1 - args.train_split) / 2)

# Assign the labels positionally. Label-based `.loc` slices are end-inclusive,
# which made the val and test ranges overlap on their boundary row and left the
# two hold-out sets with different sizes.
split_labels = np.full(len(mlm_df), 'train', dtype=object)
split_labels[:num_holdout_rows] = 'val'
split_labels[num_holdout_rows:2 * num_holdout_rows] = 'test'
mlm_df['split'] = split_labels
mlm_df.head()

Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,sentence,split
0,0,533.0,164 (indicating that a judgment of conviction ...,val
1,1,994.0,``(b) Any regulatory preemption of State law s...,val
2,2,1644.0,I threw this months receipt away during cleani...,val
3,3,1001.0,I live in Belgium and I was looking to buy som...,val
4,4,1561.0,"In sum, Smith has submitted a substantial amou...",val


### Pre-process STS Data

In [5]:
# Load at most `num_samples` sentence pairs for the classification task.
# `nrows` stops reading early instead of parsing the whole CSV and slicing.
sts_df = pd.read_csv(args.sts_datapath, nrows=args.num_samples)

# Half of the non-train fraction goes to validation and the other half to test
# (15% each when train_split is 0.7, leaving 70% for train).
num_val_rows = int(len(sts_df) * (1 - args.train_split) / 2)

# Assign the labels positionally. Label-based `.loc` slices are end-inclusive,
# which made the val and test ranges overlap on their boundary row.
split_labels = np.full(len(sts_df), 'train', dtype=object)
split_labels[:num_val_rows] = 'val'
split_labels[num_val_rows:2 * num_val_rows] = 'test'
sts_df['split'] = split_labels

### Create Combined Dataset and Dataloader

In [6]:
from torch.utils.data import Dataset
import torch
from transformers import BertTokenizer, DataCollatorForLanguageModeling

class CombinedDataset(Dataset):
    """Serve one sentence-pair classification example and one MLM example per index.

    Only rows whose ``split`` column equals ``'train'`` are used from either frame.
    The dataset length is the number of classification rows; MLM sentences are
    reused cyclically (``idx % len(mlm_sentences)``) when there are fewer of them.

    Args:
        tokenizer: Hugging Face tokenizer used for both tasks.
        class_df: DataFrame with ``split``, ``context``, ``holding`` and
            ``binary_label`` columns.
        mlm_df: DataFrame with ``split`` and ``sentence`` columns.
        max_length: Every sequence is padded/truncated to this many tokens.
        mlm_probability: Masking probability passed to the MLM data collator.

    Raises:
        ValueError: If there are classification rows but no MLM training sentences.
    """

    def __init__(self, tokenizer, class_df, mlm_df, max_length=256, mlm_probability=0.15):
        self.tokenizer = tokenizer
        self.class_df = class_df[class_df['split'] == 'train'].reset_index(drop=True)
        self.mlm_sentences = mlm_df[mlm_df['split'] == 'train']['sentence'].tolist()
        self.max_length = max_length
        self.mlm_probability = mlm_probability

        # Fail early with a clear message instead of a ZeroDivisionError from the
        # modulo in __getitem__.
        if len(self.class_df) > 0 and not self.mlm_sentences:
            raise ValueError("mlm_df contains no rows with split == 'train'")

        # Sentence-pair classification: tokenise every (context, holding) pair once.
        # `padding='max_length'` replaces the deprecated `pad_to_max_length=True`.
        self.encodings_class = [
            tokenizer(
                context,
                text_pair=holding,
                add_special_tokens=True,
                max_length=max_length,
                padding='max_length',
                truncation=True,
                return_tensors="pt",
            )
            for context, holding in zip(self.class_df['context'], self.class_df['holding'])
        ]
        self.labels_class = torch.tensor(self.class_df['binary_label'].tolist())

        # Collator that applies fresh random masking every time it is called.
        self.data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer, mlm=True, mlm_probability=mlm_probability
        )

    def __len__(self):
        return len(self.encodings_class)

    def __getitem__(self, idx):
        """Return the tensors for both tasks at position ``idx`` as a flat dict."""
        # Sentence-pair classification: stored with a leading batch dim of 1.
        item_class = self.encodings_class[idx]

        # MLM: the collator expects a list of un-batched examples, so tokenise to
        # plain Python lists (no `return_tensors`) and include the special-tokens
        # mask so the collator can exclude special and padding positions from
        # masking.
        sentence = self.mlm_sentences[idx % len(self.mlm_sentences)]
        encoding_mlm = self.tokenizer(
            sentence,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_special_tokens_mask=True,
        )
        inputs_mlm = self.data_collator([encoding_mlm])

        # squeeze(0) drops only the batch dimension added by the tokenizer/collator.
        return {
            'input_ids_class': item_class['input_ids'].squeeze(0),
            'attention_mask_class': item_class['attention_mask'].squeeze(0),
            'token_type_ids_class': item_class['token_type_ids'].squeeze(0),
            'labels_class': self.labels_class[idx],  # For classification task
            'input_ids_mlm': inputs_mlm['input_ids'].squeeze(0),
            'labels_mlm': inputs_mlm['labels'].squeeze(0),  # For MLM task
        }



In [7]:
# Build the paired classification/MLM training set and serve it in shuffled
# mini-batches of 8 examples.
combined_dataset = CombinedDataset(
    tokenizer=tokenizer,
    class_df=sts_df,
    mlm_df=mlm_df,
)
dataloader = DataLoader(
    dataset=combined_dataset,
    batch_size=8,
    shuffle=True,
)


Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

## Training 

In [8]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(device)

cuda


In [9]:
class BertForMLMAndClassification(nn.Module):
    """BERT encoder with two heads: masked-language-modelling and classification.

    NOTE: a later cell defines a class with the same name (subclassing
    ``PreTrainedModel``); once that cell runs it shadows this definition.

    Args:
        bert_model_name: Checkpoint name/path passed to ``BertModel.from_pretrained``.
        num_labels: Number of output classes of the classification head.
    """

    def __init__(self, bert_model_name, num_labels):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        hidden_size = self.bert.config.hidden_size
        self.mlm_head = nn.Linear(hidden_size, self.bert.config.vocab_size)
        self.classification_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_labels),
        )

    def forward(self, input_ids_class, attention_mask_class, token_type_ids_class,
                labels_class, input_ids_mlm, labels_mlm, attention_mask_mlm=None):
        """Run both tasks through the shared encoder and return their losses.

        Args:
            input_ids_class, attention_mask_class, token_type_ids_class: Encoded
                sentence pairs for the classification task.
            labels_class: Class indices for the classification task.
            input_ids_mlm: Masked token ids for the MLM task.
            labels_mlm: MLM targets; positions set to -100 are ignored by
                ``CrossEntropyLoss`` (its default ``ignore_index``).
            attention_mask_mlm: Optional attention mask for ``input_ids_mlm``.
                When omitted it is derived from the pad token id.

        Returns:
            ``(prediction_scores, class_logits, losses)`` where ``losses`` holds
            ``total_loss`` plus ``mlm_loss`` and ``classification_loss`` when both
            label tensors are given.
        """
        # Classification task. The encoder output is NOT detached: detaching it
        # would cut the graph and leave BERT itself without any gradient.
        outputs_class = self.bert(input_ids=input_ids_class,
                                  attention_mask=attention_mask_class,
                                  token_type_ids=token_type_ids_class)
        class_logits = self.classification_head(outputs_class.pooler_output)

        # MLM task. The MLM sentences are different texts from the classification
        # pairs, so they need their own attention mask rather than the
        # classification one.
        if attention_mask_mlm is None:
            pad_token_id = self.bert.config.pad_token_id
            if pad_token_id is None:
                attention_mask_mlm = torch.ones_like(input_ids_mlm)
            else:
                attention_mask_mlm = (input_ids_mlm != pad_token_id).long()
        outputs_mlm = self.bert(input_ids=input_ids_mlm,
                                attention_mask=attention_mask_mlm)
        prediction_scores = self.mlm_head(outputs_mlm.last_hidden_state)

        # Loss computation
        losses = {}
        if labels_mlm is not None and labels_class is not None:
            loss_fct = nn.CrossEntropyLoss()
            losses['mlm_loss'] = loss_fct(
                prediction_scores.view(-1, self.bert.config.vocab_size),
                labels_mlm.view(-1),
            )
            losses['classification_loss'] = loss_fct(
                class_logits.view(-1, self.classification_head[-1].out_features),
                labels_class.view(-1),
            )

        # Start from a zero on the same device as the logits (not a notebook
        # global) so the sum is well-defined even when no loss was computed.
        loss = torch.tensor(0.0, device=class_logits.device)
        for partial_loss in losses.values():
            loss = loss + partial_loss

        return prediction_scores, class_logits, {"total_loss": loss, **losses}


In [10]:
from transformers import BertModel, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
import torch.nn as nn

class BertForMLMAndClassification(PreTrainedModel):
    """BERT encoder with an MLM head and a classification head, trained jointly.

    Subclasses ``PreTrainedModel`` so the result can be written with
    ``save_pretrained``.

    NOTE(review): the constructor takes a checkpoint name rather than a config
    object, so reloading through ``from_pretrained`` on this class may not work
    as-is — confirm before relying on it.

    Args:
        bert_model_name: Checkpoint name/path used for both the config and the
            encoder weights.
        num_labels: Number of output classes of the classification head.
    """

    def __init__(self, bert_model_name, num_labels):
        config = BertConfig.from_pretrained(bert_model_name)
        super().__init__(config)
        self.num_labels = num_labels

        # Load the pre-trained BertModel
        self.bert = BertModel.from_pretrained(bert_model_name, config=config)

        # Define the Masked Language Model (MLM) head
        self.mlm_head = nn.Linear(config.hidden_size, config.vocab_size)

        # Define the classification head
        self.classification_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, num_labels),
        )

    def forward(self, input_ids_class, attention_mask_class, token_type_ids_class,
                labels_class=None, input_ids_mlm=None, labels_mlm=None,
                attention_mask_mlm=None):
        """Run the classification pass and, when MLM inputs are given, the MLM pass.

        Args:
            input_ids_class, attention_mask_class, token_type_ids_class: Encoded
                sentence pairs for the classification task.
            labels_class: Optional class indices; enables the classification loss.
            input_ids_mlm: Optional masked token ids; enables the MLM pass.
            labels_mlm: Optional MLM targets; positions set to -100 are ignored
                by ``CrossEntropyLoss`` (its default ``ignore_index``).
            attention_mask_mlm: Optional attention mask for ``input_ids_mlm``.
                When omitted it is derived from the config's pad token id.

        Returns:
            dict with ``loss`` (sum of whichever losses could be computed, or
            ``None`` when no labels were given), ``logits`` (classification),
            ``hidden_states`` and ``attentions`` from the classification pass,
            and ``mlm_logits`` (``None`` without MLM inputs).

        Raises:
            ValueError: If ``labels_mlm`` is given without ``input_ids_mlm``.
        """
        # Process the classification input through BertModel
        outputs_class = self.bert(input_ids=input_ids_class,
                                  attention_mask=attention_mask_class,
                                  token_type_ids=token_type_ids_class)

        # Compute classification logits
        class_logits = self.classification_head(outputs_class.pooler_output)

        # Compute MLM logits if input_ids_mlm is provided
        prediction_scores = None
        if input_ids_mlm is not None:
            # The MLM sentences are different texts from the classification
            # pairs, so they need their own attention mask rather than the
            # classification one.
            if attention_mask_mlm is None:
                pad_token_id = self.config.pad_token_id
                if pad_token_id is None:
                    attention_mask_mlm = torch.ones_like(input_ids_mlm)
                else:
                    attention_mask_mlm = (input_ids_mlm != pad_token_id).long()
            outputs_mlm = self.bert(input_ids=input_ids_mlm, attention_mask=attention_mask_mlm)
            prediction_scores = self.mlm_head(outputs_mlm.last_hidden_state)

        # Compute each loss independently so either task can be trained or
        # evaluated on its own; with both label sets the result is their sum.
        partial_losses = []
        if labels_mlm is not None:
            if prediction_scores is None:
                raise ValueError("labels_mlm was provided without input_ids_mlm")
            mlm_loss = nn.CrossEntropyLoss()(
                prediction_scores.view(-1, self.config.vocab_size), labels_mlm.view(-1)
            )
            partial_losses.append(mlm_loss)
        if labels_class is not None:
            class_loss = nn.CrossEntropyLoss()(
                class_logits.view(-1, self.num_labels), labels_class.view(-1)
            )
            partial_losses.append(class_loss)

        loss = sum(partial_losses) if partial_losses else None

        return {
            "loss": loss,
            "logits": class_logits,
            "hidden_states": outputs_class.hidden_states,
            "attentions": outputs_class.attentions,
            "mlm_logits": prediction_scores,
        }


In [11]:
from transformers import AdamW

# NOTE(review): weights are loaded from 'bert-base-uncased' while the tokenizer
# cell uses 'casehold/legalbert' — confirm this pairing is intentional.
model = BertForMLMAndClassification('bert-base-uncased', num_labels=2) # Assuming binary classification

model.to(device)
model.train()

# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are pinned to the values the transformers class used by default
# so the optimisation behaviour stays the same.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

# The training loop steps the optimizer once per 4 batches (gradient
# accumulation), so the schedule must be sized in optimizer steps, not batches;
# otherwise the learning rate only decays a quarter of the way to zero.
grad_accum_steps = 4  # keep in sync with the accumulation interval of the training loop
optimizer_steps_per_epoch = (len(dataloader) + grad_accum_steps - 1) // grad_accum_steps
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=optimizer_steps_per_epoch * args.epochs,
)
print(torch.cuda.memory_allocated())
print(torch.cuda.memory_reserved())


535932928
589299712




In [12]:


# Gradients of this many consecutive batches are summed before each optimizer step.
grad_accum_steps = 4
num_batches = len(dataloader)

# Start from clean gradients regardless of what earlier cells left behind.
optimizer.zero_grad()

for epoch in range(args.epochs):
    total_loss = 0.0
    for step, batch in enumerate(dataloader):
        # Move batch data to the same device as the model
        batch = {k: v.to(device) for k, v in batch.items()}

        outputs = model(input_ids_class=batch['input_ids_class'],
                        attention_mask_class=batch['attention_mask_class'],
                        token_type_ids_class=batch['token_type_ids_class'],
                        labels_class=batch['labels_class'],
                        input_ids_mlm=batch['input_ids_mlm'],
                        labels_mlm=batch['labels_mlm'])

        # Combined MLM + classification loss for this batch.
        loss = outputs['loss']

        # Log the unscaled loss; only the value used for backprop is divided by
        # the accumulation factor.
        total_loss += loss.item()
        (loss / grad_accum_steps).backward()

        # Step after every full accumulation window, and also on the last batch
        # of the epoch so a trailing partial window is neither dropped nor
        # carried over into the next epoch.
        is_last_batch = (step + 1) == num_batches
        if (step + 1) % grad_accum_steps == 0 or is_last_batch:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

    print(f'Epoch {epoch + 1}/{args.epochs}, Average Loss: {total_loss / max(num_batches, 1)}')


535932928
1339815424
1339815424
1339815424
1872033280
2414343680
2414343680
2414343680
1872033280
2413473280
2413473280
2413473280
1872033280
2413473280
2413473280
2413473280
1872033280
2413473280
2413473280
2413473280
1872033280
2413473280
2413473280
2413473280
1872033280
2413473280
2413473280
2413473280
1872033280
2413473280
2413473280
2413473280
1872033280
2413473280
2413473280
2413473280
1872033280
2413473280
2413473280
2413473280
1872033280
2413473280
2413473280
2413473280
1872033280
2413473280
2413473280
2413473280
1872033280
2413473280
2413473280
2413473280
1872033280
2413473280
2413473280
2413473280
1872033280
2413473280
2413473280
2413473280
1872033280
2413473280
2413473280
2413473280
1872033280
2413473280
2413473280
2413473280
1872033280
2413473280
2413473280
2413473280
1872033280
2413473280
2413473280
2413473280
1872033280
2413473280
2413473280
2413473280
1872033280
2413473280
2413473280
2413473280
1872033280
2413473280
2413473280
2413473280
1872033280
2413473280
2413473280


In [26]:
# NOTE(review): this cell's execution count (26) is out of sequence with the
# cells around it — re-run the notebook top to bottom before trusting the file.
# torch.save does not create missing parent directories, so make sure the target
# folder exists first.
os.makedirs('./models', exist_ok=True)
torch.save(model.state_dict(), './models/mlm_sts1.pth')

: 

In [13]:
# Write the trained model to ./models/mlm_sts via `save_pretrained`.
# NOTE(review): args.model_save_path is defined in the config cell but is not
# used here — confirm which output path is intended.
model.save_pretrained('./models/mlm_sts')