In [1]:
!pip install torch sentencepiece datasets evaluate sacrebleu accelerate==0.20.1 transformers==4.28.0


Collecting sentencepiece
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-2.14.4-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.3/519.3 kB[0m [31m15.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting evaluate
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.4/81.4 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting sacrebleu
  Downloading sacrebleu-2.3.1-py3-none-any.whl (118 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m118.9/118.9 kB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate==0.20.1
  Downloading accelerate-0.20.1-py3-none-any.whl (227 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
## Load Dataset

from datasets import load_dataset

dataset = load_dataset("nhankins/legal_data_small")



Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/216k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [3]:
dataset

DatasetDict({
    train: Dataset({
        features: ['Unnamed: 0', 'doc', 'id', 'original_text', 'reference_summary', 'title'],
        num_rows: 250
    })
})

In [4]:
dataset['train']

Dataset({
    features: ['Unnamed: 0', 'doc', 'id', 'original_text', 'reference_summary', 'title'],
    num_rows: 250
})

In [5]:
print(type(dataset['train']['original_text']))

<class 'list'>


In [6]:
og_text = dataset['train'][0]['original_text']

In [7]:
og_text

'welcome to the pokémon go video game services which are accessible via the niantic inc niantic mobile device application the app. to make these pokémon go terms of service the terms easier to read our video game services the app and our websites located at http pokemongo nianticlabs com and http www pokemongolive com the site are collectively called the services. please read carefully these terms our trainer guidelines and our privacy policy because they govern your use of our services.'

In [16]:

# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

# model_checkpoint = "t5-small"

# tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

In [8]:
from transformers import TFAutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

model_checkpoint = "t5-small"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

model = TFAutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

Downloading tf_model.h5:   0%|          | 0.00/242M [00:00<?, ?B/s]

All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at t5-small.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [9]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="tf")

In [12]:
max_input_length = 512
max_target_length = 30


def preprocess_function(examples):
    model_inputs = tokenizer(
        examples["original_text"],
        max_length=max_input_length,
        truncation=True,
    )
    labels = tokenizer(
        examples["doc"], max_length=max_target_length, truncation=True
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [13]:
tokenized_datasets = dataset.map(preprocess_function, batched=True)


Map:   0%|          | 0/250 [00:00<?, ? examples/s]

In [14]:
# Removes columns with strings since collator won't know what to do with them

tokenized_datasets = tokenized_datasets.remove_columns(
    dataset["train"].column_names
)


In [15]:
features = [tokenized_datasets["train"][i] for i in range(2)]
data_collator(features)

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'input_ids': <tf.Tensor: shape=(2, 108), dtype=int32, numpy=
array([[ 2222,    12,     8,  1977,   157,   154,  2157,   281,   671,
          467,   364,    84,    33,  3551,  1009,     8,     3, 15710,
         1225,    16,    75,     3, 15710,  1225,  1156,  1407,   917,
            8,  1120,     5,    12,   143,   175,  1977,   157,   154,
         2157,   281,  1353,    13,   313,     8,  1353,  1842,    12,
          608,    69,   671,   467,   364,     8,  1120,    11,    69,
         3395,  1069,    44,  2649, 23004,  2157,   839,     3, 15710,
         1225,  9339,     7,     3,   287,    11,  2649,  2442, 23004,
         2157,  7579,   757,     3,   287,     8,   353,    33,  6018,
          120,   718,     8,   364,     5,   754,   608,  4321,   175,
         1353,    69,  8813,  5749,    11,    69,  4570,  1291,   250,
           79, 22417,    39,   169,    13,    69,   364,     5,     1],
       [   57,   338,    69,   364,    25,    33,  2065,    53,    12,
          175,

In [16]:
tf_train_dataset = model.prepare_tf_dataset(
    tokenized_datasets["train"],
    collate_fn=data_collator,
    shuffle=True,
    batch_size=8,
)

Old behaviour: columns=['a'], labels=['labels'] -> (tf.Tensor, tf.Tensor)  
             : columns='a', labels='labels' -> (tf.Tensor, tf.Tensor)  
New behaviour: columns=['a'],labels=['labels'] -> ({'a': tf.Tensor}, {'labels': tf.Tensor})  
             : columns='a', labels='labels' -> (tf.Tensor, tf.Tensor) 


In [17]:
from transformers import create_optimizer
import tensorflow as tf

# The number of training steps is the number of samples in the dataset, divided by the batch size then multiplied
# by the total number of epochs. Note that the tf_train_dataset here is a batched tf.data.Dataset,
# not the original Hugging Face Dataset, so its len() is already num_samples // batch_size.
num_train_epochs = 8
num_train_steps = len(tf_train_dataset) * num_train_epochs
model_name = model_checkpoint.split("/")[-1]

optimizer, schedule = create_optimizer(
    init_lr=5.6e-5,
    num_warmup_steps=0,
    num_train_steps=num_train_steps,
    weight_decay_rate=0.01,
)

model.compile(optimizer=optimizer)

# Train in mixed-precision float16
tf.keras.mixed_precision.set_global_policy("mixed_float16")

No loss specified in compile() - the model's internal loss computation will be used as the loss. Don't panic - this is a common way to train TensorFlow models in Transformers! To disable this behaviour please pass a loss argument, or explicitly pass `loss=None` if you do not want your model to compute a loss.


Log in to HF

In [19]:
new_model = "legal_data_summarizer"

In [20]:
from transformers.keras_callbacks import PushToHubCallback

callback = PushToHubCallback(
    output_dir=f"{new_model}-finetuned-legal", tokenizer=tokenizer
)

model.fit(
    tf_train_dataset, callbacks=[callback], epochs=8
)


Cloning https://huggingface.co/nhankins/legal_data_summarizer-finetuned-legal into local empty directory.


Epoch 1/8
Epoch 2/8
Epoch 3/8
Epoch 4/8
Epoch 5/8
Epoch 6/8
Epoch 7/8
Epoch 8/8


<keras.callbacks.History at 0x7d04c14095a0>

In [21]:
!pip install tqdm



In [17]:
inputs = tokenizer("I loved reading the Hunger Games!")
inputs

{'input_ids': [27, 1858, 1183, 8, 26049, 5880, 55, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}

In [18]:
tokenizer.convert_ids_to_tokens(inputs.input_ids)

['▁I', '▁loved', '▁reading', '▁the', '▁Hunger', '▁Games', '!', '</s>']

In [None]:
!pip install rouge_score


In [None]:
# Using Rouge to measure summary

import evaluate

rouge_score = evaluate.load("rouge")

In [None]:
!pip install nltk

In [None]:
import nltk
# Punctation rules
nltk.download("punkt")

In [None]:
from nltk.tokenize import sent_tokenize


def three_sentence_summary(text):
    return "\n".join(sent_tokenize(text)[:3])


print(three_sentence_summary(dataset["train"][1]["original_text"]))

In [None]:
def evaluate_baseline(dataset, metric):
    summaries = [three_sentence_summary(text) for text in dataset["original_text"]]
    return metric.compute(predictions=summaries, references=dataset["doc"])

In [4]:
# Fine-tuned model name
new_model = "legal_data_summarizer"

# Output directory where the model predictions and checkpoints will be stored
output_dir = "./results"

In [8]:
# Makes data into a batched list

from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [None]:
## Inference stage

# from transformers import AutoConfig
# config = AutoConfig.from_pretrained('t5-small')


text = "translate Legal English to Simplified English: *****Write example here*****"

## Can then use a pipline() like below:

from transformers import pipeline

translator = pipeline("translation", model=new_model, max_length=200)
translator(text)


Separate Task Below: HF's Summarization Pipeline. Can comapre results after completion


In [None]:
## Not fine-tuned, just the HF Pipeline summary
## Try and loop through all of them

from transformers import pipeline

# using pipeline API for summarization task
summarization = pipeline("summarization", model=model, tokenizer=tokenizer)
## og_text =
summary_text = summarization(og_text)
print("Summary:", summary_text)