In [1]:
import pandas as pd
from dask import annotate
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments, Trainer
from datasets import Dataset
from sklearn.preprocessing import LabelEncoder
import torch
import tqdm
import warnings

In [2]:
warnings.filterwarnings("ignore")

In [3]:
annotate_df = pd.read_csv('Annotated_Citations.csv')
annotate_df.head(5)

Unnamed: 0,text,startPosition,endPosition,normCite,citeType,altCite,pinCiteStr,pageRangeStr,nodeId,section,sectionAndSubSection,isShortCite,chunk_id,context,label
0,1 USC 1,3479,3486,1 usc 1,USC,,,,0,1 USC 1,1 USC 1,False,0.0,"Division A—Military Construction, Veterans Aff...",Definition
1,or direction,188589,188601,or dir ection,,,,,0,,,False,9.0,"16353(b)). <paragraph display-inline=""no-displ...",Definition
2,42 U.S.C.,245062,245071,42 usc,USC,,,,0,42 U.S.C.,42 U.S.C.,False,4.0,Domestic Food Programs Food and Nutrition Serv...,Authority
3,19 USC 2434,110102,110113,19 usc 2434,USC,,,,0,19 USC 2434,19 USC 2434,False,16.0,"4655)— <clause display-inline=""no-display-inli...",Amending
4,2 FAM 154,343562,343571,[2] 1 fam 154,UK,,,,0,,,False,,(d) None of the funds appropriated or otherwis...,Authority


In [4]:
annotate_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2353 entries, 0 to 2352
Data columns (total 15 columns):
 #   Column                Non-Null Count  Dtype 
---  ------                --------------  ----- 
 0   text                  2353 non-null   object
 1   startPosition         2353 non-null   object
 2   endPosition           2353 non-null   object
 3   normCite              2353 non-null   object
 4   citeType              2306 non-null   object
 5   altCite               451 non-null    object
 6   pinCiteStr            2 non-null      object
 7   pageRangeStr          2 non-null      object
 8   nodeId                2353 non-null   object
 9   section               1707 non-null   object
 10  sectionAndSubSection  1707 non-null   object
 11  isShortCite           2353 non-null   object
 12  chunk_id              1596 non-null   object
 13  context               2353 non-null   object
 14  label                 2353 non-null   object
dtypes: object(15)
memory usage: 275.9+ KB


In [5]:
holdout_df = pd.read_csv('Holdout_Citations.csv')
holdout_df.head(5)

Unnamed: 0,text,startPosition,endPosition,normCite,citeType,altCite,pinCiteStr,pageRangeStr,nodeId,section,sectionAndSubSection,isShortCite,chunk_id,context
0,10 USC 816,117893,117903,10 usc 816,USC,,,,0,10 USC 816,10 USC 816,False,2.0,"<subparagraph display-inline=""no-display-inlin..."
1,section 702(b) of the Department of Agricultur...,133065,133132,7 usc 2257,USC,7 usc 2257,,,0,,,False,,None of the funds appropriated by this or any ...
2,section 302(a) of the Congressional Budget Act...,133766,133820,2 usc 633,USC,2 usc 633,,,0,,,False,0.0,Res. 71 (115th Congress). (3) Classification o...
3,section 801 of the Foreign Intelligence Survei...,39035,39099,50 usc 1885,USC,50 usc 1885,,,0,,,False,,Accountability procedures for incidents relati...
4,50 States,168532,168541,50 stat es,StatutesAtLarge,,,,0,,,False,15.0,(7) Optional product or service The term optio...


In [6]:
holdout_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7122 entries, 0 to 7121
Data columns (total 14 columns):
 #   Column                Non-Null Count  Dtype 
---  ------                --------------  ----- 
 0   text                  7118 non-null   object
 1   startPosition         7118 non-null   object
 2   endPosition           7118 non-null   object
 3   normCite              7118 non-null   object
 4   citeType              6980 non-null   object
 5   altCite               1540 non-null   object
 6   pinCiteStr            3 non-null      object
 7   pageRangeStr          3 non-null      object
 8   nodeId                7118 non-null   object
 9   section               5022 non-null   object
 10  sectionAndSubSection  5022 non-null   object
 11  isShortCite           7118 non-null   object
 12  chunk_id              4795 non-null   object
 13  context               7118 non-null   object
dtypes: object(14)
memory usage: 779.1+ KB


In [7]:
def clean_empty_fields(df):
    df = df.fillna("None")
    object_cols = df.select_dtypes(include='object').columns
    for col in object_cols:
        df[col] = df[col].replace(r'^\s*$', "None", regex=True)
    return df

In [8]:
annotate_df = clean_empty_fields(annotate_df)

In [9]:
annotate_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2353 entries, 0 to 2352
Data columns (total 15 columns):
 #   Column                Non-Null Count  Dtype 
---  ------                --------------  ----- 
 0   text                  2353 non-null   object
 1   startPosition         2353 non-null   object
 2   endPosition           2353 non-null   object
 3   normCite              2353 non-null   object
 4   citeType              2353 non-null   object
 5   altCite               2353 non-null   object
 6   pinCiteStr            2353 non-null   object
 7   pageRangeStr          2353 non-null   object
 8   nodeId                2353 non-null   object
 9   section               2353 non-null   object
 10  sectionAndSubSection  2353 non-null   object
 11  isShortCite           2353 non-null   object
 12  chunk_id              2353 non-null   object
 13  context               2353 non-null   object
 14  label                 2353 non-null   object
dtypes: object(15)
memory usage: 275.9+ KB


In [10]:
holdout_df = clean_empty_fields(holdout_df)

In [11]:
holdout_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7122 entries, 0 to 7121
Data columns (total 14 columns):
 #   Column                Non-Null Count  Dtype 
---  ------                --------------  ----- 
 0   text                  7122 non-null   object
 1   startPosition         7122 non-null   object
 2   endPosition           7122 non-null   object
 3   normCite              7122 non-null   object
 4   citeType              7122 non-null   object
 5   altCite               7122 non-null   object
 6   pinCiteStr            7122 non-null   object
 7   pageRangeStr          7122 non-null   object
 8   nodeId                7122 non-null   object
 9   section               7122 non-null   object
 10  sectionAndSubSection  7122 non-null   object
 11  isShortCite           7122 non-null   object
 12  chunk_id              7122 non-null   object
 13  context               7122 non-null   object
dtypes: object(14)
memory usage: 779.1+ KB


In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [13]:
label_encoder = LabelEncoder()
annotate_df["encoded_label"] = label_encoder.fit_transform(annotate_df["label"])
num_labels = len(label_encoder.classes_)
print("Number of unique labels:", num_labels)

Number of unique labels: 6


In [14]:
label_mapping = dict(zip(label_encoder.classes_, label_encoder.transform(label_encoder.classes_)))
print("Label to Number Mapping:\n", label_mapping)

Label to Number Mapping:
 {'Amending': 0, 'Authority': 1, 'Definition': 2, 'Exception': 3, 'Precedent': 4, 'Rescinding': 5}


In [15]:
annotate_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2353 entries, 0 to 2352
Data columns (total 16 columns):
 #   Column                Non-Null Count  Dtype 
---  ------                --------------  ----- 
 0   text                  2353 non-null   object
 1   startPosition         2353 non-null   object
 2   endPosition           2353 non-null   object
 3   normCite              2353 non-null   object
 4   citeType              2353 non-null   object
 5   altCite               2353 non-null   object
 6   pinCiteStr            2353 non-null   object
 7   pageRangeStr          2353 non-null   object
 8   nodeId                2353 non-null   object
 9   section               2353 non-null   object
 10  sectionAndSubSection  2353 non-null   object
 11  isShortCite           2353 non-null   object
 12  chunk_id              2353 non-null   object
 13  context               2353 non-null   object
 14  label                 2353 non-null   object
 15  encoded_label         2353 non-null   

In [16]:
annotate_df.rename(columns={"label": "original_label", "encoded_label": "label"}, inplace=True)

In [17]:
annotate_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2353 entries, 0 to 2352
Data columns (total 16 columns):
 #   Column                Non-Null Count  Dtype 
---  ------                --------------  ----- 
 0   text                  2353 non-null   object
 1   startPosition         2353 non-null   object
 2   endPosition           2353 non-null   object
 3   normCite              2353 non-null   object
 4   citeType              2353 non-null   object
 5   altCite               2353 non-null   object
 6   pinCiteStr            2353 non-null   object
 7   pageRangeStr          2353 non-null   object
 8   nodeId                2353 non-null   object
 9   section               2353 non-null   object
 10  sectionAndSubSection  2353 non-null   object
 11  isShortCite           2353 non-null   object
 12  chunk_id              2353 non-null   object
 13  context               2353 non-null   object
 14  original_label        2353 non-null   object
 15  label                 2353 non-null   

In [18]:
model_name = "saibo/legal-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [19]:
annotate_df["input_text"] = annotate_df["text"].fillna("None") + " [SEP] " + annotate_df["context"].fillna("None")
holdout_df["input_text"] = holdout_df["text"].fillna("None") + " [SEP] " + holdout_df["context"].fillna("None")

In [20]:
def tokenize_function(examples):
    return tokenizer(examples["input_text"], padding="max_length", truncation=True, max_length=512)

In [21]:
train_dataset = Dataset.from_pandas(annotate_df[["input_text", "label"]])
test_dataset = Dataset.from_pandas(holdout_df[["input_text"]])

In [23]:
tokenized_train = train_dataset.map(tokenize_function, batched=True)
tokenized_train = tokenized_train.remove_columns(["input_text"])
tokenized_train = tokenized_train.rename_column("label", "labels")
tokenized_train.set_format("torch")

Map:   0%|          | 0/2353 [00:00<?, ? examples/s]

In [24]:
tokenized_test = test_dataset.map(tokenize_function, batched=True)
tokenized_test = tokenized_test.remove_columns(["input_text"])
tokenized_test.set_format("torch")

Map:   0%|          | 0/7122 [00:00<?, ? examples/s]

In [25]:
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
model.to("cuda")

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at saibo/legal-roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


RobertaForSequenceClassification(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
             

In [26]:
training_args = TrainingArguments(
    output_dir="./legal_model",
    evaluation_strategy="no",
    save_strategy="epoch",
    per_device_train_batch_size=8,
    num_train_epochs=5,
    logging_dir="./logs",
    load_best_model_at_end=False,
)

In [27]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    tokenizer=tokenizer,
)

In [28]:
trainer.train()

Step,Training Loss
500,0.3867
1000,0.2456


TrainOutput(global_step=1475, training_loss=0.27415310617220606, metrics={'train_runtime': 1925.276, 'train_samples_per_second': 6.111, 'train_steps_per_second': 0.766, 'total_flos': 3095612739348480.0, 'train_loss': 0.27415310617220606, 'epoch': 5.0})

In [29]:
predictions = trainer.predict(tokenized_test)

In [30]:
predicted_label_ids = predictions.predictions.argmax(axis=-1)

In [31]:
predicted_labels = label_encoder.inverse_transform(predicted_label_ids)

In [32]:
holdout_df["original_label"] = predicted_labels

In [33]:
holdout_df["original_label"].value_counts()

original_label
Authority     5494
Amending       998
Definition     630
Name: count, dtype: int64

In [34]:
holdout_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7122 entries, 0 to 7121
Data columns (total 16 columns):
 #   Column                Non-Null Count  Dtype 
---  ------                --------------  ----- 
 0   text                  7122 non-null   object
 1   startPosition         7122 non-null   object
 2   endPosition           7122 non-null   object
 3   normCite              7122 non-null   object
 4   citeType              7122 non-null   object
 5   altCite               7122 non-null   object
 6   pinCiteStr            7122 non-null   object
 7   pageRangeStr          7122 non-null   object
 8   nodeId                7122 non-null   object
 9   section               7122 non-null   object
 10  sectionAndSubSection  7122 non-null   object
 11  isShortCite           7122 non-null   object
 12  chunk_id              7122 non-null   object
 13  context               7122 non-null   object
 14  input_text            7122 non-null   object
 15  original_label        7122 non-null   

In [36]:
holdout_df.drop(columns=["input_text"], inplace=True)

In [35]:
annotate_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2353 entries, 0 to 2352
Data columns (total 17 columns):
 #   Column                Non-Null Count  Dtype 
---  ------                --------------  ----- 
 0   text                  2353 non-null   object
 1   startPosition         2353 non-null   object
 2   endPosition           2353 non-null   object
 3   normCite              2353 non-null   object
 4   citeType              2353 non-null   object
 5   altCite               2353 non-null   object
 6   pinCiteStr            2353 non-null   object
 7   pageRangeStr          2353 non-null   object
 8   nodeId                2353 non-null   object
 9   section               2353 non-null   object
 10  sectionAndSubSection  2353 non-null   object
 11  isShortCite           2353 non-null   object
 12  chunk_id              2353 non-null   object
 13  context               2353 non-null   object
 14  original_label        2353 non-null   object
 15  label                 2353 non-null   

In [37]:
annotate_df[["original_label", "label"]].head(5)

Unnamed: 0,original_label,label
0,Definition,2
1,Definition,2
2,Authority,1
3,Amending,0
4,Authority,1


In [39]:
annotate_df.drop(columns=["input_text", "label"], inplace=True)

In [40]:
annotate_df.columns

Index(['text', 'startPosition', 'endPosition', 'normCite', 'citeType',
       'altCite', 'pinCiteStr', 'pageRangeStr', 'nodeId', 'section',
       'sectionAndSubSection', 'isShortCite', 'chunk_id', 'context',
       'original_label'],
      dtype='object')

In [41]:
holdout_df.columns

Index(['text', 'startPosition', 'endPosition', 'normCite', 'citeType',
       'altCite', 'pinCiteStr', 'pageRangeStr', 'nodeId', 'section',
       'sectionAndSubSection', 'isShortCite', 'chunk_id', 'context',
       'original_label'],
      dtype='object')

In [42]:
# DO both dataframes have the same columns?
holdout_df.columns == annotate_df.columns

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True])

In [43]:
# Merge the two dataframes
merged_df = pd.concat([annotate_df, holdout_df], ignore_index=True)

In [44]:
merged_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 9475 entries, 0 to 9474
Data columns (total 15 columns):
 #   Column                Non-Null Count  Dtype 
---  ------                --------------  ----- 
 0   text                  9475 non-null   object
 1   startPosition         9475 non-null   object
 2   endPosition           9475 non-null   object
 3   normCite              9475 non-null   object
 4   citeType              9475 non-null   object
 5   altCite               9475 non-null   object
 6   pinCiteStr            9475 non-null   object
 7   pageRangeStr          9475 non-null   object
 8   nodeId                9475 non-null   object
 9   section               9475 non-null   object
 10  sectionAndSubSection  9475 non-null   object
 11  isShortCite           9475 non-null   object
 12  chunk_id              9475 non-null   object
 13  context               9475 non-null   object
 14  original_label        9475 non-null   object
dtypes: object(15)
memory usage: 1.1+ MB


In [45]:
merged_df.head(5), merged_df.tail(5)

(           text startPosition endPosition       normCite citeType altCite  \
 0       1 USC 1          3479        3486        1 usc 1      USC    None   
 1  or direction        188589      188601  or dir ection     None    None   
 2     42 U.S.C.        245062      245071         42 usc      USC    None   
 3   19 USC 2434        110102      110113    19 usc 2434      USC    None   
 4     2 FAM 154        343562      343571  [2] 1 fam 154       UK    None   
 
   pinCiteStr pageRangeStr nodeId      section sectionAndSubSection  \
 0       None         None      0      1 USC 1              1 USC 1   
 1       None         None      0         None                 None   
 2       None         None      0   42 U.S.C.            42 U.S.C.    
 3       None         None      0  19 USC 2434          19 USC 2434   
 4       None         None      0         None                 None   
 
   isShortCite chunk_id                                            context  \
 0       FALSE        0 

In [47]:
merged_df["original_label"].value_counts()

original_label
Authority     7295
Amending      1302
Definition     855
Rescinding      11
Exception        7
Precedent        5
Name: count, dtype: int64

In [46]:
# Save the merged dataframe to a CSV file
merged_df.to_csv("Final_Citations_Label.csv", index=False)

In [48]:
# Save the model and tokenizer
model.save_pretrained("legal_model")
tokenizer.save_pretrained("legal_model")

('legal_model\\tokenizer_config.json',
 'legal_model\\special_tokens_map.json',
 'legal_model\\vocab.json',
 'legal_model\\merges.txt',
 'legal_model\\added_tokens.json',
 'legal_model\\tokenizer.json')