<a href="https://colab.research.google.com/github/newfull5/AI-Project/blob/master/summarization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Show which GPU (if any) is attached to this runtime before starting the run.
!nvidia-smi

Wed Feb 15 12:06:23 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P0    50W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
!pip install transformers datasets torch wandb tqdm evaluate rouge_score

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.26.1-py3-none-any.whl (6.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.3/6.3 MB[0m [31m52.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-2.9.0-py3-none-any.whl (462 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m462.8/462.8 KB[0m [31m38.5 MB/s[0m eta [36m0:00:00[0m
Collecting wandb
  Downloading wandb-0.13.10-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m65.9 MB/s[0m eta [36m0:00:00[0m
Collecting evaluate
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.4/81.4 KB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... 

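- model = t5-large (the value of `args.model_name` / `args.tokenizer_name` below)
- dataset = CNN_dailymail 3.0.0
  - max_token_length = about 3,500
  

- config
  - max_length_truncate = 1024 (articles are truncated/padded to 1024 tokens; summaries to 256)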
  


In [3]:
from datasets import load_dataset
from tqdm import tqdm

In [4]:
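# Approximate number of examples per length bucket
# (presumably article length in tokens — TODO confirm unit and which split was measured):
#   length    0 ~  500 => 1900
#   length  500 ~ 1000 => 5400
#   length 1000 ~ 1500 => 2800
#   length 1500 ~ 2000 => 1150
#   length 2000 ~      => 500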

In [5]:
import argparse

def _get_parser():  
    parser = argparse.ArgumentParser()
    parser.add_argument()
    return parser
   

# Run configuration shared by the dataset, model and trainer cells.
_HPARAMS = {
    "model_name": "t5-large",
    "tokenizer_name": "t5-large",
    # [dataset id, config/version] as consumed by load_dataset
    "dataset_name": ['cnn_dailymail', '3.0.0'],
    "batch_size": 4,
    "lr": 3e-5,
    # number of optimizer steps between validation runs
    "val_check_interval": 2000,
    "max_epochs": 100,
}
args = argparse.Namespace(**_HPARAMS)

In [6]:
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from datasets import load_dataset

class Dataset(Dataset):
  """One split of a summarization dataset, tokenized eagerly at construction.

  Each item is a 5-tuple
  ``(input_ids, attention_mask, decoder_input_ids, labels, plain_target)``.
  The four tensors have a leading dimension of 1 (they come from
  ``return_tensors='pt'`` on a single example); ``plain_target`` is the raw
  reference summary string.

  Args:
    args: namespace providing ``dataset_name`` (``[name, config]``) and
      ``tokenizer_name``.
    stage: one of ``'train'``, ``'validation'`` or ``'test'``.

  Raises:
    ValueError: if ``stage`` is not one of the three supported splits.
  """

  STAGES = ('train', 'validation', 'test')
  MAX_SOURCE_LENGTH = 1024  # articles are truncated/padded to this many tokens
  MAX_TARGET_LENGTH = 256   # summaries are truncated/padded to this many tokens

  def __init__(self, args, stage):
    super().__init__()
    self.stage = stage
    self.args = args
    # Filled by _get_data with the untokenized reference summaries.
    self.plain_target = []
    self.input_ids, self.attention_mask, self.decoder_input_ids, self.labels = self._get_data(args, stage)

  def _get_data(self, args, stage):
    """Load and tokenize the requested split.

    Returns four parallel lists of tensors: encoder ids, encoder attention
    mask, decoder input ids and labels. Also extends ``self.plain_target``.
    """
    # Fail before the expensive download/tokenizer load when the stage is wrong.
    if stage not in self.STAGES:
      raise ValueError("you can set stage only 'train', 'test' or 'validation'")

    dataset = load_dataset(args.dataset_name[0], args.dataset_name[1])
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
    split = dataset[stage]

    input_ids = []
    attention_mask = []
    decoder_input_ids = []
    labels = []

    self.plain_target += split['highlights']
    for example in tqdm(split):
      inputs = tokenizer(
          text=example['article'],
          max_length=self.MAX_SOURCE_LENGTH,
          padding='max_length',
          truncation=True,
          return_tensors='pt'
      )
      input_ids.append(inputs['input_ids'])
      attention_mask.append(inputs['attention_mask'])

      outputs = tokenizer(
          text=example['highlights'],
          max_length=self.MAX_TARGET_LENGTH,
          padding='max_length',
          truncation=True,
          return_tensors='pt'
      )
      target_ids = outputs['input_ids']

      # The tokenizer output has no leading start token, so slicing
      # ids[:, :-1] / ids[:, 1:] would drop the first summary token from the
      # loss. Instead the labels are the full target and the decoder input is
      # the target shifted right behind the decoder start token.
      # NOTE(review): assumes the decoder start token equals the pad token,
      # which holds for T5 checkpoints — confirm before using another family.
      shifted = target_ids.clone()
      shifted[:, 1:] = target_ids[:, :-1]
      shifted[:, 0] = tokenizer.pad_token_id
      decoder_input_ids.append(shifted)

      label = target_ids.clone()
      # -100 is the index ignored by the cross-entropy loss.
      label[label == tokenizer.pad_token_id] = -100
      labels.append(label)

    return input_ids, attention_mask, decoder_input_ids, labels

  def __len__(self):
    return len(self.labels)

  def __getitem__(self, idx):
    return self.input_ids[idx], self.attention_mask[idx], self.decoder_input_ids[idx], self.labels[idx], self.plain_target[idx]

In [7]:
from transformers import AutoModelForSeq2SeqLM
import torch
from torch import nn

class Model(nn.Module):
  """Thin wrapper around a pretrained seq2seq LM that consumes loader batches.

  Args:
    args: namespace providing ``model_name`` for
      ``AutoModelForSeq2SeqLM.from_pretrained``.
  """

  def __init__(self, args):
    super(Model, self).__init__()
    self.model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.model.to(self.device)

  def forward(self, batch):
    """Run the wrapped model on one batch.

    Args:
      batch: 5-item sequence ``(input_ids, attention_mask, decoder_input_ids,
        labels, plain_target)``; the last item is ignored here.

    Returns:
      Whatever the wrapped model returns for these keyword arguments.
    """
    input_ids, attention_mask, decoder_input_ids, labels, _ = batch
    outputs = self.model(
        input_ids=self._as_batch(input_ids),
        attention_mask=self._as_batch(attention_mask),
        decoder_input_ids=self._as_batch(decoder_input_ids),
        labels=self._as_batch(labels)
    )
    return outputs

  def _as_batch(self, tensor):
    """Move ``tensor`` to the model device and flatten it to (batch, seq_len).

    An argument-less ``squeeze()`` would also remove the batch dimension when
    the batch holds a single example (e.g. the last, incomplete batch of an
    epoch); collapsing only the leading dimensions keeps the result 2-D.
    """
    tensor = self._move_to_cuda(tensor)
    return tensor.reshape(-1, tensor.size(-1))

  def _move_to_cuda(self, inputs):
    """Recursively move tensors (also inside lists/dicts) to ``self.device``."""
    if torch.is_tensor(inputs):
      return inputs.to(self.device)
    elif isinstance(inputs, list):
      return [self._move_to_cuda(x) for x in inputs]
    elif isinstance(inputs, dict):
      return {key: self._move_to_cuda(value) for key, value in inputs.items()}
    else:
      return inputs

  def save(self, save_dir):
    """Write the wrapped model to ``save_dir`` via ``save_pretrained``."""
    self.model.save_pretrained(save_dir)

  def load(self, save_dir):
    """Load weights from ``{save_dir}/pytorch_model.bin`` into the wrapped model."""
    # NOTE(review): assumes save() produced a single pytorch_model.bin file;
    # confirm this matches the checkpoint format of the installed transformers.
    self.model.load_state_dict(
        torch.load(f"{save_dir}/pytorch_model.bin", map_location=torch.device(self.device))
    )

In [8]:
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
import evaluate
import wandb

class Trainer:
  """Mixed-precision training loop with periodic validation and checkpointing.

  Args:
    args: namespace providing ``tokenizer_name``, ``lr``,
      ``val_check_interval`` and ``max_epochs``.
    model: module called as ``model(batch)`` that returns an object with
      ``loss`` and ``logits`` and offers ``save(path)``.
    train_loader: iterable of training batches.
    validation_loader: iterable of validation batches; the last element of
      each batch holds the reference summaries and element 3 the labels.
  """

  def __init__(self, args, model, train_loader, validation_loader):
    self.args = args
    self.model = model
    self.train_loader = train_loader
    self.valid_loader = validation_loader
    self.rouge = evaluate.load('rouge')
    self.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
    # NOTE(review): plain SGD with the configured learning rate — confirm this
    # optimizer choice is intended for fine-tuning.
    self.optimizer = torch.optim.SGD(self.model.parameters(), lr=args.lr)
    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.val_check_interval = args.val_check_interval
    # Number of optimizer steps taken so far, across all epochs.
    self.global_steps = 0
    # One scaler for the whole run so its loss-scale state survives epochs.
    self.scaler = GradScaler()

  def training_phase(self, current_epoch):
    """Train for one epoch; validate and checkpoint every ``val_check_interval`` steps."""
    self.model.train()
    running_loss = 0.0
    running_steps = 0

    for batch in tqdm(self.train_loader, desc=f"epoch: {current_epoch}"):
      self.optimizer.zero_grad()
      with autocast():
        outputs = self.model(batch)

      self.scaler.scale(outputs.loss).backward()
      self.scaler.step(self.optimizer)
      self.scaler.update()
      self.global_steps += 1
      running_loss += float(outputs.loss)
      running_steps += 1

      if self.global_steps % self.val_check_interval == 0:
        # Average over the steps actually accumulated: the window restarts at
        # every epoch boundary, so it can be shorter than the interval.
        wandb.log({
            'train_loss': (running_loss/running_steps)
            })

        running_loss = 0.0
        running_steps = 0
        self.model.save(f"./cnn_daily_summrization/{self.global_steps}/")
        self.valid_phase()

  def valid_phase(self):
    """Compute validation loss and ROUGE of teacher-forced predictions; log both to wandb."""
    was_training = self.model.training
    self.model.eval()
    predictions = []
    references = []
    total_valid_loss = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
      for batch in self.valid_loader:
        outputs = self.model(batch)
        total_valid_loss += float(outputs.loss)

        token_ids = outputs.logits.argmax(dim=-1).cpu()
        # Positions whose label is -100 are ignored by the loss; their
        # predictions carry no signal and are dropped before decoding.
        keep = batch[3].cpu().reshape(token_ids.shape) != -100
        predictions += [
            self.tokenizer.decode(ids[mask], skip_special_tokens=True).strip()
            for ids, mask in zip(token_ids, keep)
        ]
        # Collect references in the same pass so they stay aligned with the
        # predictions even when the loader shuffles.
        references += list(batch[-1])

    rouge_score = self.rouge.compute(predictions=predictions,
                                references=references,
                                tokenizer=lambda x: self.tokenizer(x)['input_ids']
                                )

    wandb.log({
        'rouge1': rouge_score['rouge1'],
        'rouge2': rouge_score['rouge2'],
        'rougeL': rouge_score['rougeL'],
        'rougeLsum': rouge_score['rougeLsum'],
        'valid_loss': total_valid_loss/len(self.valid_loader)
    })

    # Restore training mode so the steps after a mid-epoch validation still
    # run with training-time behavior.
    if was_training:
      self.model.train()

  def fit(self):
    """Run ``training_phase`` for epochs ``1..args.max_epochs``."""
    for epoch in range(1, self.args.max_epochs+1):
      self.training_phase(epoch)

In [9]:
# Pass the project by keyword: the first positional parameter of wandb.init is
# not `project`, so a bare string would not select the project.
# The hyperparameters are attached as the run config for later comparison.
wandb.init(project='Summarization', config=vars(args))

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 

··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
from torch.utils.data import DataLoader

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    dataset=Dataset(args, 'train'),
    batch_size=args.batch_size,
    shuffle=True
)

# Evaluation gains nothing from shuffling; a fixed order keeps predictions and
# references aligned and makes metrics comparable between validation runs.
valid_loader = DataLoader(
    dataset=Dataset(args, 'validation'),
    batch_size=args.batch_size,
    shuffle=False
)

model = Model(args)

Downloading builder script:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/9.88k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/15.1k [00:00<?, ?B/s]

Downloading and preparing dataset cnn_dailymail/3.0.0 to /root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de...


Downloading data files:   0%|          | 0/5 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/159M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/376M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/12.3M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/661k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/572k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

In [None]:
# Wire the configuration, the model and both loaders into the training loop.
# Positional order follows Trainer.__init__: args, model, train_loader, validation_loader.
trainer = Trainer(args, model, train_loader, valid_loader)

In [None]:
# Train for args.max_epochs epochs; a checkpoint is saved and validation is run
# every args.val_check_interval optimizer steps.
trainer.fit()