In [4]:
from preprocessing_v2 import *

In [5]:
from transformers import LEDTokenizer, LEDForConditionalGeneration
import torch
from datasets import Dataset
from transformers import AdamW
import time
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
import evaluate

In [6]:
# ROUGE metric object used by compute_metrics during evaluation.
rouge = evaluate.load(path="rouge")

In [7]:
def process_data_to_model_inputs(
    batch,
    text_tokenizer=None,
    max_input_length=4096,
    max_output_length=4096,
):
    """Tokenize one batched ``datasets.map`` chunk into LED model inputs.

    Args:
        batch: Column dict for the chunk. It must contain the list columns
            ``encoder_input_string``, ``decoder_input_string`` and
            ``globalmask`` (one 0/1 list per example).
        text_tokenizer: Tokenizer to use. Defaults to the notebook-level
            ``tokenizer``; pass one explicitly (e.g. via ``fn_kwargs``) to
            avoid relying on a global.
        max_input_length: Length that every encoder sequence AND every global
            attention mask is padded/truncated to. Both use the same value so
            ``global_attention_mask`` always lines up with ``input_ids``.
        max_output_length: Length that every label sequence is
            padded/truncated to. NOTE(review): confirm this does not exceed
            ``model.config.max_decoder_position_embeddings``.

    Returns:
        The same ``batch`` dict with ``input_ids``, ``attention_mask``,
        ``labels`` and ``global_attention_mask`` added. The caller's
        ``globalmask`` lists are not modified.
    """
    # Conditional expression: the global is only looked up when no tokenizer
    # was passed in.
    tok = tokenizer if text_tokenizer is None else text_tokenizer

    inputs = tok(
        batch["encoder_input_string"],
        padding="max_length",
        truncation=True,
        max_length=max_input_length,
    )
    outputs = tok(
        batch["decoder_input_string"],
        padding="max_length",
        truncation=True,
        max_length=max_output_length,
    )

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask

    # Padding positions in the labels become -100 so the loss ignores them.
    pad_id = tok.pad_token_id
    batch["labels"] = [
        [-100 if token_id == pad_id else token_id for token_id in label_ids]
        for label_ids in outputs.input_ids
    ]

    # The global attention mask must be exactly as long as input_ids, which
    # are padded to max_input_length. Truncate or zero-pad into NEW lists
    # instead of extending the caller's lists in place.
    batch["global_attention_mask"] = [
        list(mask[:max_input_length]) + [0] * max(0, max_input_length - len(mask))
        for mask in batch["globalmask"]
    ]

    return batch

def compute_metrics(pred, text_tokenizer=None, rouge_metric=None):
    """Decode generated and reference token ids and score them with ROUGE-2.

    Intended for ``Seq2SeqTrainer(compute_metrics=...)`` with
    ``predict_with_generate=True``.

    Args:
        pred: Object exposing ``predictions`` (generated token ids, or a tuple
            whose first element holds them) and ``label_ids`` (reference ids,
            with -100 on ignored positions).
        text_tokenizer: Tokenizer used for decoding. Defaults to the
            notebook-level ``tokenizer``.
        rouge_metric: Object with a ``compute`` method. Defaults to the
            notebook-level ``rouge``.

    Returns:
        ``{"rouge2_fmeasure": float}``. If the metric returns a legacy
        aggregate object with a ``mid`` attribute (``datasets.load_metric``
        style), ``rouge2_precision`` and ``rouge2_recall`` are included too.
        All values are rounded to 4 decimals.
    """
    import numpy as np

    tok = tokenizer if text_tokenizer is None else text_tokenizer
    metric = rouge if rouge_metric is None else rouge_metric

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 marks ignored label positions. The Trainer may also use it to pad
    # generated ids when concatenating batches, and it is not a decodable
    # token id. np.where returns new arrays, so pred.label_ids is left intact.
    pred_ids = np.where(np.asarray(pred_ids) != -100, pred_ids, tok.pad_token_id)
    label_ids = np.where(
        np.asarray(pred.label_ids) != -100, pred.label_ids, tok.pad_token_id
    )

    pred_str = tok.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tok.batch_decode(label_ids, skip_special_tokens=True)

    score = metric.compute(
        predictions=pred_str, references=label_str, rouge_types=["rouge2"]
    )["rouge2"]

    mid = getattr(score, "mid", None)
    if mid is not None:
        # Legacy aggregate result: precision/recall/F are all available.
        return {
            "rouge2_precision": round(mid.precision, 4),
            "rouge2_recall": round(mid.recall, 4),
            "rouge2_fmeasure": round(mid.fmeasure, 4),
        }
    # evaluate's ROUGE returns a single aggregated F-measure value.
    return {"rouge2_fmeasure": round(float(score), 4)}

In [8]:
# Prepare the data, model, and tokenizer before training; the preprocessor
# exposes them as attributes that the next cells read.
CSV_PATH = 'new_court_cases.csv'
preprocessor = preprocess(CSV_PATH)



In [9]:
# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

device(type='cuda')

In [10]:
# Unpack the model, tokenizer and prepared columns from the preprocessor
# (attribute reads happen in the same order as before).
model, tokenizer = preprocessor.LED_model, preprocessor.LED_tokenizer
xdata, ydata = preprocessor.encoder_inputs, preprocessor.decoder_inputs
globalmask = preprocessor.global_mask
# Move the model to `device`; as the last expression it is also echoed below.
model.to(device)

LEDForConditionalGeneration(
  (led): LEDModel(
    (shared): Embedding(50265, 768, padding_idx=1)
    (encoder): LEDEncoder(
      (embed_tokens): Embedding(50265, 768, padding_idx=1)
      (embed_positions): LEDLearnedPositionalEmbedding(4096, 768)
      (layers): ModuleList(
        (0-5): 6 x LEDEncoderLayer(
          (self_attn): LEDEncoderAttention(
            (longformer_self_attn): LEDEncoderSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (query_global): Linear(in_features=768, out_features=768, bias=True)
              (key_global): Linear(in_features=768, out_features=768, bias=True)
              (value_global): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): Linear(in_features=768, out_features=768, bias=True)
          )
      

In [11]:
# Collect the preprocessor outputs into one DataFrame with one column each.
data = dict(
    encoder_input_string=xdata,
    decoder_input_string=ydata,
    globalmask=globalmask,
)
df = pd.DataFrame(data)

In [12]:
df.head(n=5)

Unnamed: 0,encoder_input_string,decoder_input_string,globalmask
0,this case is about need for prosection and all...,<RULING> this case is about need for prosectio...,"[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,this case is about proof required establish do...,this case is about proof required establish do...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,court resolves instant appeal by certiorari fi...,<FACTS> court resolves instant appeal by certi...,"[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,this is a petition for certiorari petition und...,<FACTS> this is a petition for certiorari peti...,"[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,. . number 230391 petitioner juliette gomez ro...,<FACTS> . . number 230391 petitioner juliette ...,"[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [13]:
# Hold out 10% of the rows for evaluation (fixed seed for a reproducible split),
# then convert each split to a `datasets.Dataset` with a fresh 0..n-1 index.
train_df, eval_df = train_test_split(df, test_size=0.1, random_state=42)
train_data = Dataset.from_pandas(train_df.reset_index(drop=True))
eval_data = Dataset.from_pandas(eval_df.reset_index(drop=True))

In [14]:
# Raw columns that are replaced by the tokenized model inputs.
raw_columns = ["encoder_input_string", "decoder_input_string", "globalmask"]


def _tokenize_split(split):
    """Apply process_data_to_model_inputs to one split and drop the raw columns."""
    return split.map(
        process_data_to_model_inputs,
        batched=True,
        batch_size=2,
        remove_columns=list(raw_columns),
    )


train_dataset = _tokenize_split(train_data)
val_dataset = _tokenize_split(eval_data)

Map:   0%|          | 0/147 [00:00<?, ? examples/s]

[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

Map:   0%|          | 0/17 [00:00<?, ? examples/s]

[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [None]:
# Columns the LED model consumes; both splits expose them as torch tensors.
model_columns = ["input_ids", "attention_mask", "global_attention_mask", "labels"]
for split in (train_dataset, val_dataset):
    split.set_format(type="torch", columns=list(model_columns))

# set_format(type="torch") only converts rows to tensors; it does not move them
# to the GPU. Seq2SeqTrainer moves each batch to the training device itself,
# so this is expected to print "cpu". Read a single row instead of
# materializing the whole column just to inspect element 0.
print(train_dataset[0]["input_ids"].device)

In [None]:
# Longest row in each tokenized column. Every row is padded to a fixed length
# in process_data_to_model_inputs, so each maximum should equal that length.
for column in ("input_ids", "labels", "global_attention_mask", "attention_mask"):
    print(max(len(row) for row in train_dataset[column]))

In [None]:
# Position-embedding limits; compare them with the tokenizer max_length values
# used for the encoder inputs and the labels.
for attr in ("max_encoder_position_embeddings", "max_decoder_position_embeddings"):
    print(getattr(model.config, attr))

In [None]:
# Full model configuration, for reference.
model_config = model.config
print(model_config)

In [None]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    # Optimization: 2 examples per device step x 4 accumulation steps
    # = 8 examples per optimizer update per device.
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    fp16=True,
    # Evaluation: generate sequences so compute_metrics can score decoded text.
    # NOTE(review): recent transformers releases deprecate `evaluation_strategy`
    # in favour of `eval_strategy`; confirm against the installed version.
    # NOTE(review): no generation max length is passed here; confirm the length
    # the trainer falls back to is long enough for the target summaries.
    predict_with_generate=True,
    evaluation_strategy="steps",
    eval_steps=10,
    per_device_eval_batch_size=2,
    # Logging and checkpointing (keep at most two checkpoints on disk).
    logging_steps=5,
    save_steps=10,
    save_total_limit=2,
)

In [None]:
# Wire the model, data and ROUGE metric into the seq2seq trainer.
# NOTE(review): newer transformers versions may prefer `processing_class=` over
# `tokenizer=`; confirm against the installed version.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

In [None]:
# Train from scratch (no checkpoint is resumed).
trainer.train(resume_from_checkpoint=None)