In [1]:
!rm /opt/conda/lib/python3.10/site-packages/aiohttp-3.9.1.dist-info -rdf

In [2]:
!pip install rouge_score



In [3]:
!pip install evaluate 



In [4]:
import torch
import numpy as np

import nltk

import transformers
from datasets import load_dataset
import evaluate

2024-06-08 17:52:57.105229: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-08 17:52:57.105286: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-08 17:52:57.106844: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

## Data preporcessing

In [6]:
billsum = load_dataset('billsum', split='ca_test')
billsum = billsum.train_test_split(test_size=.1)

In [7]:
tokenizer = transformers.AutoTokenizer.from_pretrained('ainize/bart-base-cnn')

In [8]:
def preprocess_function(examples):
    model_inputs = tokenizer(
        examples['text'], max_length=1024, truncation=True
    )
    labels = tokenizer(
        text_target=examples['summary'], max_length=128, truncation=True
    )
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

In [9]:
tokenized_billsum = billsum.map(preprocess_function, batched=True)

Map:   0%|          | 0/1113 [00:00<?, ? examples/s]

Map:   0%|          | 0/124 [00:00<?, ? examples/s]

## Metrics

In [10]:
nltk.download('punkt', quiet=True)
metric = evaluate.load('rouge')

In [11]:
def compute_metrics(eval_preds):
    preds, labels = eval_preds

    # decode preds and labels
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # rougeLSum expects newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    return result

## Model

In [12]:
model = transformers.AutoModelForSeq2SeqLM.from_pretrained('ainize/bart-base-cnn')
# Batching function
data_collator = transformers.DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

  return self.fget.__get__(instance, owner)()


In [13]:
# Define arguments of the finetuning
training_args = transformers.Seq2SeqTrainingArguments(
    output_dir='./bart_finetuning_results',
    evaluation_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=8,  # batch size for train
    per_device_eval_batch_size=8,  # batch size for eval
    weight_decay=.01,
    save_total_limit=3,  # num of checkpoints to save 
    num_train_epochs=2,
    fp16=True,
    predict_with_generate=True
)



In [14]:
trainer = transformers.Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_billsum['train'],
    eval_dataset=tokenized_billsum['test'],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

In [15]:
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33msvir[0m. Use [1m`wandb login --relogin`[0m to force relogin




Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,No log,2.1175,0.17864,0.098246,0.15971,0.169381
2,No log,2.065461,0.182595,0.101807,0.163663,0.172274




TrainOutput(global_step=140, training_loss=2.3062035696847096, metrics={'train_runtime': 263.2571, 'train_samples_per_second': 8.456, 'train_steps_per_second': 0.532, 'total_flos': 1357273356042240.0, 'train_loss': 2.3062035696847096, 'epoch': 2.0})

## One sample prediction

In [16]:
text_example = billsum['test']['text'][42]
print(text_example)

The people of the State of California do enact as follows:


SECTION 1.
Section 1464 of the Penal Code is amended to read:
1464.
(a) (1) Subject to Chapter 12 (commencing with Section 76000) of Title 8 of the Government Code, and except as otherwise provided in this section, there shall be levied a state penalty in the amount of ten dollars ($10) for every ten dollars ($10), or part of ten dollars ($10), upon every fine, penalty, or forfeiture imposed and collected by the courts for all criminal offenses, including all offenses, except parking offenses as defined in subdivision (i) of Section 1463, involving a violation of a section of the Vehicle Code or any local ordinance adopted pursuant to the Vehicle Code.
(2) Any bail schedule adopted pursuant to Section 1269b or bail schedule adopted by the Judicial Council pursuant to Section 40310 of the Vehicle Code may include the necessary amount to pay the penalties established by this section and Chapter 12 (commencing with Section 76000

In [17]:
input_ids = tokenizer.encode(
    text_example,
    return_tensors="pt",
    max_length=1024,
    truncation=True,
).to(device)

In [18]:
input_ids.shape

torch.Size([1, 1024])

In [19]:
summary_text_ids = model.generate(
    input_ids=input_ids,
    bos_token_id=model.config.bos_token_id,
    eos_token_id=model.config.eos_token_id,
    max_length=142,
    min_length=56,
    num_beams=4,
)

In [20]:
decoded_text = tokenizer.decode(summary_text_ids[0], skip_special_tokens=True)
print(decoded_text)

Existing law provides for the imposition of a state penalty upon every fine, penalty, or forfeiture imposed and collected by the courts for all criminal offenses, including all offenses, except parking offenses as defined. Existing law also provides that the penalty imposed by this bill is based upon the total fine or bail for each case, except as otherwise provided, and that the state penalty shall be based upon a bail schedule adopted by the Judicial Council, as specified.
This bill would require the clerk of the court to collect the penalty and transmit it to the county treasury.
