In [5]:
! pip install transformers
! pip install datasets
! pip install torcheval
! pip install pytorch-ignite

Collecting datasets
  Downloading datasets-2.18.0-py3-none-any.whl (510 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m18.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: xxhash, dill, multiprocess, datasets
Successfully installed dataset

In [6]:
import transformers as T
from datasets import load_dataset
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from ignite.metrics import Rouge
import re
device = "cuda" if torch.cuda.is_available() else "cpu"

In [21]:
t5_model = T.T5ForConditionalGeneration.from_pretrained("google/flan-t5-base", cache_dir="./cache/").to(device)
t5_tokenizer = T.T5Tokenizer.from_pretrained("google/flan-t5-base", cache_dir="./cache/")


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [22]:
def get_tensor(sample):
    # 將模型的輸入和ground truth打包成Tensor
    model_inputs = t5_tokenizer.batch_encode_plus([each["text"] for each in sample], padding=True, truncation=True, return_tensors="pt")
    model_outputs = t5_tokenizer.batch_encode_plus([each["summary"] for each in sample], padding=True, truncation=True, return_tensors="pt")
    return model_inputs["input_ids"].to(device), model_outputs["input_ids"].to(device)
t5_data2 = " {'Rouge-L-P': 0.6806281416676265, 'Rouge-L-R': 0.6745061038113893, 'Rouge-L-F': 0.6745061038113893, 'Rouge-2-P': 0.026492651855262655, 'Rouge-2-R': 0.02629599785119526, 'Rouge-2-F': 0.02629599785119526}"
class CommonGenDataset(Dataset):
    def __init__(self, split="train") -> None:
        super().__init__()
        assert split in ["train", "validation", "test"]
        dataset = load_dataset("hugcyp/LCSTS", split=split, cache_dir="./cache/").to_pandas()
        self.data = []
        for index, row in dataset.iterrows():
            text = row["text"]
            summary = row["summary"]
            self.data.append({"text": text, "summary": summary})

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

data_sample = CommonGenDataset(split="train").data[:3]
print(f"Dataset example: \n{data_sample[0]} \n{data_sample[1]} \n{data_sample[2]}")

Dataset example: 
{'text': '新华社受权于18日全文播发修改后的《中华人民共和国立法法》，修改后的立法法分为“总则”“法律”“行政法规”“地方性法规、自治条例和单行条例、规章”“适用与备案审查”“附则”等6章，共计105条。', 'summary': '修改后的立法法全文公布'} 
{'text': '一辆小轿车，一名女司机，竟造成9死24伤。日前，深圳市交警局对事故进行通报：从目前证据看，事故系司机超速行驶且操作不当导致。目前24名伤员已有6名治愈出院，其余正接受治疗，预计事故赔偿费或超一千万元。', 'summary': '深圳机场9死24伤续：司机全责赔偿或超千万'} 
{'text': '1月18日，习近平总书记对政法工作作出重要指示：2014年，政法战线各项工作特别是改革工作取得新成效。新形势下，希望全国政法机关主动适应新形势，为公正司法和提高执法司法公信力提供有力制度保障。', 'summary': '孟建柱：主动适应形势新变化提高政法机关服务大局的能力'}


In [11]:
lr = 2e-5
epochs = 1
optimizer = AdamW(t5_model.parameters(), lr=2e-5)
train_batch_size = 16
validation_batch_size = 16
lcsts_train = DataLoader(CommonGenDataset(split="train"), collate_fn=get_tensor, batch_size=train_batch_size, shuffle=True)
lcsts_validation = DataLoader(CommonGenDataset(split="validation"), collate_fn=get_tensor, batch_size=validation_batch_size, shuffle=False)

In [None]:
rouge = Rouge(variants=["L", 2], multiref="best")

In [12]:
def evaluate(model, tokenizer, dataloader, rouge):
    rouge_metric = Rouge(metrics=["rouge-2"])

    model.eval()
    rouge_scores = {"rouge-2": {"precision": 0.0, "recall": 0.0, "f1": 0.0}}

    with torch.no_grad():
        pbar = tqdm(dataloader)
        pbar.set_description(f"Evaluating")

        for inputs, targets in pbar:
            outputs = model.generate(inputs, max_length=50)
            decoded_outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
            decoded_targets = [tokenizer.decode(target, skip_special_tokens=True) for target in targets]

            for output, target in zip(decoded_outputs, decoded_targets):
                rouge_scores_batch = rouge_metric.compute(predictions=[output], references=[target])
                rouge_scores["rouge-2"]["precision"] += rouge_scores_batch["rouge-2"]["precision"]
                rouge_scores["rouge-2"]["recall"] += rouge_scores_batch["rouge-2"]["recall"]
                rouge_scores["rouge-2"]["f1"] += rouge_scores_batch["rouge-2"]["f1"]

        # Calculate average scores
        total_batches = len(dataloader)
        rouge_scores["rouge-2"]["precision"] /= total_batches
        rouge_scores["rouge-2"]["recall"] /= total_batches
        rouge_scores["rouge-2"]["f1"] /= total_batches

    return rouge_scores

In [24]:
for ep in range(epochs):
    # print(data1 + data2)
    # print(t5_data1 + t5_data2)
    t5_model.train()
    total_loss = 0.0
    pbar = tqdm(lcsts_train)
    pbar.set_description(f"Training epoch [{ep+1}/{epochs}]")
    for inputs, targets in pbar:
        optimizer.zero_grad()
        loss = t5_model(input_ids=inputs, labels=targets).loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pbar.set_postfix(loss=loss.item())
    avg_loss = total_loss / len(lcsts_train)
    print(f"Avg. Loss on epoch {ep+1}: {avg_loss}")
    # Evaluate the model after each epoch
    rouge_scores = evaluate(t5_model, t5_tokenizer, lcsts_validation, rouge)
    print(f"Rouge-2 score on epoch {ep+1}:", rouge_scores)


Training epoch [1/1]:  100%|██████████| 150037/150037 [8:08:09<0:00:00,  3.79it/s, loss=0.589]
Evaluating: 100%|██████████| 272/272 [01:04<00:00,  4.19it/s]Rouge-2 score on epoch 1: {'Rouge-L-P': 0.6806281416676265, 'Rouge-L-R': 0.6745061038113893, 'Rouge-L-F': 0.6745061038113893, 'Rouge-2-P': 0.026492651855262655, 'Rouge-2-R': 0.02629599785119526, 'Rouge-2-F': 0.02629599785119526}


In [None]:
!pip freeze > requirements.txt
!cat requirements.txt
from google.colab import files
files.download('requirements.txt')

absl-py==1.4.0
aiohttp==3.9.3
aiosignal==1.3.1
alabaster==0.7.16
albumentations==1.3.1
altair==4.2.2
annotated-types==0.6.0
anyio==3.7.1
appdirs==1.4.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
array_record==0.5.1
arviz==0.15.1
astropy==5.3.4
astunparse==1.6.3
async-timeout==4.0.3
atpublic==4.1.0
attrs==23.2.0
audioread==3.0.1
autograd==1.6.2
Babel==2.14.0
backcall==0.2.0
beautifulsoup4==4.12.3
bidict==0.23.1
bigframes==1.0.0
bleach==6.1.0
blinker==1.4
blis==0.7.11
blosc2==2.0.0
bokeh==3.3.4
bqplot==0.12.43
branca==0.7.1
build==1.2.1
CacheControl==0.14.0
cachetools==5.3.3
catalogue==2.0.10
certifi==2024.2.2
cffi==1.16.0
chardet==5.2.0
charset-normalizer==3.3.2
chex==0.1.86
click==8.1.7
click-plugins==1.1.1
cligj==0.7.2
cloudpathlib==0.16.0
cloudpickle==2.2.1
cmake==3.27.9
cmdstanpy==1.2.2
colorcet==3.1.0
colorlover==0.3.0
colour==0.1.5
community==1.0.0b1
confection==0.1.4
cons==0.4.6
contextlib2==21.6.0
contourpy==1.2.1
cryptography==42.0.5
cufflinks==0.17.3
cupy-cuda12x==12.2.0

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>