# PTG Training Code Analysis
Author: 설지우

The code below was extracted from the PTG-related parts of TextBox 2.0 to help understand how PTG works.

Setup
* GitHub link: [https://github.com/RUCAIBox/TextBox](https://github.com/RUCAIBox/TextBox)
* Server: Server 3, GPU 2 (single RTX 3090)
* The notebook is located at TextBox/230704 PTG.ipynb.
* The dataset is "XSum". Download it to TextBox/dataset/xsum before running.
* Only a single batch is extracted and run, to keep the walkthrough easy to follow.
* The cells follow, in order, the functions called from 'run_textbox.py' through 'trainer'; treat this as a reference only.
<br><br>

Outline
1. Load Prompts: load the prompt_source.pth file
2. Load XSum: load the XSum test set and extract a single batch
3. Load PTG model: code extracted from the internals of the PTG model (corresponds to the model's forward pass)<br>
    3-1. PTG's _process_prompt_tuning_input <br>
    3-2. model()<br>
    3-3. Loss function<br>

## 1. Load Prompts

In [3]:
import torch

source_task_set = {
    'cross-task1': ['squad', 'wiki', 'quora', 'wp', 'cnndm'],
    'cross-task2': ['squad', 'wiki', 'quora', 'wp', 'pc'],
    'cross-dataset1': ['msn', 'mn', 'nr'],
    'cross-dataset2': ['tc', 'da', 'mw'],
}

prompt_source = torch.load('prompt_source.pth')
prompt_source

{'cnndm': tensor([[-0.5581,  0.5014, -0.2667,  ..., -0.3199, -0.4092, -0.9452],
         [ 0.3267, -0.0185,  1.6073,  ..., -0.6338,  0.2770, -0.9944],
         [-0.1004,  0.2631,  0.2974,  ..., -0.2703, -0.1962, -0.0020],
         ...,
         [-0.4019,  0.0385,  1.0486,  ...,  0.0838, -0.5052,  0.0451],
         [-0.0628,  0.0940, -0.0373,  ..., -0.7181,  0.3422,  0.1762],
         [ 0.0689,  0.5133,  0.2685,  ...,  0.6588,  0.9140,  0.7408]]),
 'da': tensor([[-9.4889e-01, -3.9889e-01,  1.3212e+00,  ..., -1.2951e-01,
           2.0233e-01,  2.5516e-01],
         [ 1.1851e+00, -3.0833e-01,  7.8064e-01,  ...,  1.9099e-02,
          -8.1817e-01, -1.1408e+00],
         [-6.8315e-01, -2.8800e-01,  1.0065e+00,  ..., -7.6061e-02,
          -1.1598e+00, -9.9347e-01],
         ...,
         [-2.4265e-01,  7.2915e-01,  2.9903e-01,  ...,  5.9689e-01,
          -8.6180e-01, -4.8323e-01],
         [ 5.9264e-01,  1.1957e-03,  7.2779e-01,  ..., -1.2124e+00,
           7.4641e-01,  4.6784e-01],
    

In [4]:
source_task = source_task_set['cross-dataset2']
task_embedding = [prompt_source[task] for task in source_task]
task_embedding = torch.stack(task_embedding)

In [5]:
print(f'task_embedding : {task_embedding.shape}')

task_embedding : torch.Size([3, 200, 1024])
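
The first dimension of `task_embedding` is simply the number of source tasks in the chosen set. A minimal sketch with random stand-in tensors instead of the real `prompt_source.pth` (the dictionary `fake_prompt_source`, the helper `stack_prompts`, and the small sizes `prompt_len=5`, `dim=8` are assumptions made for illustration):

```python
import torch

# Stand-in for prompt_source.pth: one (prompt_len, dim) tensor per source task.
# The real file holds (200, 1024) prompts; small sizes are used here for speed.
prompt_len, dim = 5, 8
tasks = ['squad', 'wiki', 'quora', 'wp', 'cnndm', 'tc', 'da', 'mw']
fake_prompt_source = {name: torch.randn(prompt_len, dim) for name in tasks}

def stack_prompts(prompt_source, source_task):
    """Stack the prompts of the chosen tasks into (task_num, prompt_len, dim)."""
    return torch.stack([prompt_source[task] for task in source_task])

cross_dataset2 = stack_prompts(fake_prompt_source, ['tc', 'da', 'mw'])
cross_task1 = stack_prompts(fake_prompt_source, ['squad', 'wiki', 'quora', 'wp', 'cnndm'])
print(cross_dataset2.shape)  # torch.Size([3, 5, 8])
print(cross_task1.shape)     # torch.Size([5, 5, 8])
```

The order of the task list fixes the order along dimension 0, so index 1 of `cross_dataset2` is the `'da'` prompt.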


## 2. Load XSum Dataset

In [6]:
import os
from textbox.data.misc import load_data, _pad_sequence

data_path = 'dataset/xsum'
source_filename = os.path.join(data_path, 'test.src')
target_filename = os.path.join(data_path, 'test.tgt')

source_text = load_data(source_filename, max_length=0)
source_length = 800        # originally 1024; set to 800 arbitrarily because there was not enough time to compute lengths
target_text = load_data(target_filename, max_length=0)
target_length = 64

In [7]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large')

### _process_prompt

In [8]:
prefix = 'Summarize:'
suffix = ""

prefix_ids = tokenizer.encode(prefix, add_special_tokens=False)
suffix_ids = tokenizer.encode(suffix, add_special_tokens=False)

source_max_length = source_length - tokenizer.num_special_tokens_to_add() - len(prefix_ids) - len(suffix_ids)
target_max_length = target_length - tokenizer.num_special_tokens_to_add()

In [9]:
print(f'prefix_ids : {prefix_ids}')

prefix_ids : [38182, 3916, 2072, 35]
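
The two `*_max_length` values are a token budget: what is left for the text once the prefix, suffix, and special tokens are reserved. A small sketch of that arithmetic with the numbers from this notebook (the helper `content_budget` is hypothetical, and `num_special_tokens=2` assumes BART wraps a single sequence as `<s> ... </s>`):

```python
def content_budget(total_length, num_special_tokens, prefix_len, suffix_len=0):
    """Tokens left for the text itself once prefix, suffix and special tokens are reserved."""
    return total_length - num_special_tokens - prefix_len - suffix_len

# Values from this notebook: source_length=800, prefix 'Summarize:' -> 4 ids, empty suffix.
source_max_length = content_budget(800, num_special_tokens=2, prefix_len=4)
target_max_length = content_budget(64, num_special_tokens=2, prefix_len=0)
print(source_max_length, target_max_length)  # 794 62
```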


### tokenize

In [10]:
index = [0,2,8,3]
source_ids = []
ids = tokenizer(
        source_text,
        add_special_tokens=False,
        return_token_type_ids=False,
        return_attention_mask=False,
    )['input_ids']
source_ids.extend(ids)
source_ids = source_ids[:10]

after_src_ids = []
for ids in source_ids:
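    ids = ids[:source_length]  # NOTE: truncates to source_length (800), not source_max_length, so the final length is 800 + 4 prefix + 2 special = 806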
    ids = prefix_ids + ids + suffix_ids
    ids = tokenizer.build_inputs_with_special_tokens(ids)
    after_src_ids.append(torch.tensor(ids, dtype=torch.long))

target_ids = []
after_tgt_ids = []
target_ids = tokenizer(
    text_target=target_text,
    add_special_tokens=False,
    return_token_type_ids=False,
    return_attention_mask=False   
)['input_ids']
target_ids = target_ids[:10]
for ids in target_ids:
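    ids = ids[:target_length]  # NOTE: truncates to target_length (64), not target_max_length, so labels can reach 64 + 2 special tokens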
    ids = tokenizer.build_inputs_with_special_tokens(ids)
    after_tgt_ids.append(torch.tensor(ids, dtype=torch.long))

Token indices sequence length is longer than the specified maximum sequence length for this model (1185 > 1024). Running this sequence through the model will result in indexing errors


In [11]:
from textbox.data.misc import load_data, _pad_sequence

batch = {}
cur_source_text = []
cur_source_ids = []
cur_source_mask = []
cur_source_length = []
cur_target_text = []
cur_source_padding_side = 'right'

for idx in [0,2,8,3]:
    cur_source_text.append(source_text[idx])
    cur_source_ids.append(after_src_ids[idx])
    cur_source_mask.append(torch.ones(len(after_src_ids[idx]), dtype=torch.long))
    cur_source_length.append(len(after_src_ids[idx]))
    cur_target_text.append(target_text[idx])
    
batch["source_text"] = cur_source_text
batch["source_ids"] = _pad_sequence(cur_source_ids, tokenizer.pad_token_id, cur_source_padding_side)
batch["source_mask"] = _pad_sequence(cur_source_mask, 0, cur_source_padding_side)
batch["source_length"] = torch.tensor(cur_source_length, dtype=torch.long)
batch["target_text"] = cur_target_text

cur_target_ids = []
for idx in [0,2,8,3]:
    cur_target_ids.append(after_tgt_ids[idx])
batch['target_ids'] = _pad_sequence(cur_target_ids, -100, tokenizer.padding_side)
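
The batch cell pads three things with three different values: token ids with the pad id, the mask with 0, and the labels with -100. A pure-Python sketch of the same idea (`pad_right` is a hypothetical stand-in written for illustration, not TextBox's `_pad_sequence`; the short id lists are made up):

```python
def pad_right(sequences, pad_value):
    """Right-pad every sequence to the length of the longest one."""
    max_len = max(len(seq) for seq in sequences)
    return [list(seq) + [pad_value] * (max_len - len(seq)) for seq in sequences]

source_ids = [[0, 38182, 3916, 2], [0, 38182, 2]]
source_mask = [[1] * len(ids) for ids in source_ids]
target_ids = [[0, 4148, 2], [0, 133, 19716, 392, 2]]

padded_ids = pad_right(source_ids, pad_value=1)         # 1 is BART's <pad> id
padded_mask = pad_right(source_mask, pad_value=0)       # padding positions are masked out
padded_labels = pad_right(target_ids, pad_value=-100)   # -100 is skipped by the loss

print(padded_ids)     # [[0, 38182, 3916, 2], [0, 38182, 2, 1]]
print(padded_mask)    # [[1, 1, 1, 1], [1, 1, 1, 0]]
print(padded_labels)  # [[0, 4148, 2, -100, -100], [0, 133, 19716, 392, 2]]
```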

In [12]:
batch

{'source_text': ['Fast forward about 20 years, and it\'s fair to say he has done just that. The business he runs, Frasers Hospitality, is one of the world\'s biggest providers of high-end serviced apartments. Its 148 properties span about 80 capital cities, as well as financial hubs across Europe, Asia, the Middle East and Africa. But it almost didn\'t get off the ground. When Mr Choe was appointed to launch and lead the company, Asia was booming; the tiger economies of Hong Kong, South Korea, Taiwan and Singapore were expanding rapidly. But as Frasers prepared to open its first two properties in Singapore, the Asian financial crisis hit. It was 1997. Currencies went into freefall. Suddenly, people were losing their jobs and stopped travelling. Mr Choe recalls asking staff if they really wanted to continue working with the firm, because when the properties opened they might not get paid. "It was really that serious," he says. "I remember tearing up because they said \'let\'s open it, l

## 3. Load PTG model

* [AbstractModel](https://github.com/RUCAIBox/TextBox/blob/2.0.0/textbox/model/abstract_model.py) ==(inherited by)==> [Pretrained_Models](https://github.com/RUCAIBox/TextBox/blob/2.0.0/textbox/model/pretrained_models.py) ==(inherited by)==> [PTG](https://github.com/RUCAIBox/TextBox/blob/2.0.0/textbox/model/ptg.py) 
* The code below is a reconstruction of the model's forward pass, written for understanding.

In [13]:
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

arg_model = 'PTG'
arg_dataset = 'xsum'
args_config_files = []
self_task_num = 3
self_head_num = 16
self_head_dim = 64
self_scaling = self_head_dim**-0.5
self_embedding_size = 1024
self_k_proj = nn.Linear(self_embedding_size, self_embedding_size)  # in_feature x out_feature
self_v_proj = nn.Linear(self_embedding_size, self_embedding_size)
self_q_proj = nn.Linear(self_embedding_size, self_embedding_size)
self_out_proj = nn.Linear(self_embedding_size, self_embedding_size)
self_task_key = nn.Embedding(self_task_num+1, self_embedding_size)  # https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html

bert_tokenizer = AutoTokenizer.from_pretrained('bert-large-uncased')
bert_model = AutoModel.from_pretrained('bert-large-uncased')


Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [14]:
import importlib.util  # also binds importlib; find_spec lives in the util submodule
from accelerate import Accelerator
from textbox.config.configurator import Config
from accelerate import DistributedDataParallelKwargs

module_path = '.'.join(['textbox.model', 'PTG'.lower()])
if importlib.util.find_spec(module_path, __name__):
    model_module = importlib.import_module(module_path, __name__)
    model_class = getattr(model_module, 'PTG')
self_model = model_class

config = Config('PTG', 'xsum', [], {})
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config['find_unused_parameters'])
accelerator = Accelerator(
            gradient_accumulation_steps=config['accumulation_steps'], kwargs_handlers=[ddp_kwargs]
        )
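
The cell above looks up the model class by name at run time. A minimal sketch of that pattern, using a standard-library module as a stand-in for `textbox.model.ptg` (the helper `load_class` is hypothetical and not part of TextBox):

```python
import importlib
import importlib.util

def load_class(module_path, class_name):
    """Import module_path if it can be found and return the attribute class_name from it."""
    if importlib.util.find_spec(module_path) is None:
        raise ImportError(f'module {module_path!r} not found')
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

# Standard-library stand-in for ('textbox.model.ptg', 'PTG').
ordered_dict_cls = load_class('collections', 'OrderedDict')
print(ordered_dict_cls.__name__)  # OrderedDict
```

Raising when the module is missing avoids the situation where the `if` branch is skipped and the class variable is never defined.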

## 3-1. PTG's _process_prompt_tuning_input 

In [15]:
from transformers import AutoTokenizer, BartForConditionalGeneration

self_model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
self_model.config.hidden_size

1024

In [16]:
# process_forward_inputs
inputs = {
    'input_ids': batch['source_ids'],
    'attention_mask': batch['source_mask'],
    'labels': batch['target_ids']
}

# process_prompt_tuning_input
input_ids = inputs['input_ids']
batch_size = input_ids.size(0)
input_embeds = self_model.get_input_embeddings()(input_ids)

print(f'input_embeds : {input_embeds.shape}')

input_embeds : torch.Size([4, 806, 1024])


* repeat : https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html

In [17]:
def sentence_embedding(text):
    encoding_dict = bert_tokenizer(text, max_length=bert_tokenizer.model_max_length, padding=True, truncation=True, return_tensors='pt')
    input_ids = encoding_dict['input_ids']
    attn_masks = encoding_dict['attention_mask']
    output = bert_model(input_ids, attn_masks)['last_hidden_state'] # b, l, dim
    print(f'BERT-large output : {output.shape}')
    
    hidden_state = output*attn_masks.unsqueeze(-1)
    print(f'BERT-large hidden state : {hidden_state.size()}')
    embedding = hidden_state.sum(dim=1)/attn_masks.sum(dim=1).unsqueeze(-1)
    return embedding.detach()
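
`sentence_embedding` is a masked mean: padded positions are zeroed out and the sum is divided by the number of real tokens, not by the padded length. A pure-Python sketch for a single sentence (the helper `masked_mean` and the toy vectors are made up for illustration):

```python
def masked_mean(hidden_states, attention_mask):
    """Average the token vectors of one sentence, skipping positions whose mask is 0."""
    dim = len(hidden_states[0])
    total = [0.0] * dim
    for vector, keep in zip(hidden_states, attention_mask):
        if keep:
            total = [t + v for t, v in zip(total, vector)]
    n_tokens = sum(attention_mask)
    return [t / n_tokens for t in total]

hidden = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]  # the last row is a padding position
mask = [1, 1, 0]
print(masked_mean(hidden, mask))  # [2.0, 3.0]
```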

* key: a learnable cluster key == a learnable prompt key (batch size, task num, dim)
* input_query: instance-level query (batch size, 1, dim)
* value: task-specific knowledge (= representation) obtained from the source prompts

In [18]:
## KEY
print('\n------------------ KEY ------------------')
task_key = self_task_key.weight.repeat(batch_size, 1, 1)  # b, tn+1, dim

print(f'self.task_key : {self_task_key.weight.size()}')
print(f'task_key : {task_key.size()}')

# key: a learnable cluster key && a learnable prompt key
key = self_k_proj(task_key[:, :-1])   # (b, tn (3), dim): rows 0 through 2
print(f'key : {key.size()}')
print('--------------------------------------------', end='\n\n')

## VALUE
value = self_v_proj(task_embedding).reshape(self_task_num,-1).repeat(batch_size, 1, 1)
print(f'\n------------------ VALUE ------------------')
print(f'self_v_proj(task_embedding) : {self_v_proj(task_embedding).size()}')
print(f' => reshape : {self_v_proj(task_embedding).reshape(self_task_num,-1).size()}')
print(f' => repeat : {value.size()}')
print('--------------------------------------------', end='\n\n')

## QUERY
print('\n------------------ QUERY ------------------')
# task-level query
task_query = self_q_proj(task_key[:, -1:])  # (b, 1, dim): row 3 (the last one)
print(f'task-level query (task_query): {task_query.size()}', end='\n\n')

# instance-level query
input_query = sentence_embedding(batch['source_text']).unsqueeze(1)     # (b, 1, dim)
print(f'instance-level query (input_query): {input_query.size()}')
print('--------------------------------------------', end='\n\n')



------------------ KEY ------------------
self.task_key : torch.Size([4, 1024])
task_key : torch.Size([4, 4, 1024])
key : torch.Size([4, 3, 1024])
--------------------------------------------


------------------ VALUE ------------------
self_v_proj(task_embedding) : torch.Size([3, 200, 1024])
 => reshape : torch.Size([3, 204800])
 => repeat : torch.Size([4, 3, 204800])
--------------------------------------------


------------------ QUERY ------------------
task-level query (task_query): torch.Size([4, 1, 1024])

BERT-large output : torch.Size([4, 512, 1024])
BERT-large hidden state : torch.Size([4, 512, 1024])
instance-level query (input_query): torch.Size([4, 1, 1024])
--------------------------------------------
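
The widths printed above follow directly from the configured sizes. A short bookkeeping sketch of where 204800 comes from, and of the per-head width each attention head ends up seeing once the value is split across 16 heads:

```python
batch_size, task_num, prompt_len = 4, 3, 200
head_num, head_dim = 16, 64
embed_dim = head_num * head_dim          # 1024

flat_value = prompt_len * embed_dim      # width after reshape(task_num, -1)
per_head_value = prompt_len * head_dim   # width per head after splitting into heads

print((batch_size, task_num, flat_value))                 # (4, 3, 204800)
print((batch_size * head_num, task_num, per_head_value))  # (64, 3, 12800)
```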



### MHA (Multi-Head Attention)

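* Role: runs several attention operations over Query, Key, and Value in parallel
* Why Multi-Head Attention is used
    * A single word can carry several different attention values, so multiple heads let the model capture more than one kind of contextual information
* Why attention is used: it decides which words to focus on depending on the context; placing attention within the context is probably equivalent to understanding the context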
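
To make the mechanism concrete, here is a pure-Python sketch of scaled dot-product attention for one head, with one query attending over three source tasks. The helpers `softmax` and `attend` and the toy numbers are made up for illustration; in PTG the values are whole source prompts, not scalars.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Scaled dot-product attention for one query over several (key, value) pairs."""
    scale = len(query) ** -0.5
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    weights = softmax(scores)
    dim = len(values[0])
    output = [sum(w * value[i] for w, value in zip(weights, values)) for i in range(dim)]
    return weights, output

query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]   # one key per source task
values = [[10.0], [20.0], [30.0]]              # one value per source task
weights, output = attend(query, keys, values)
print(weights)  # largest weight on the first task, whose key matches the query
print(output)   # a weighted mix of the three values
```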

In [19]:
import torch.nn.functional as F

def MHA(query, key, value):
    """
        Same computation as the Transformer's Multi-Head Attention
        (the query comes from a different source than the key and value)
    
        key: (4,3,1024)
        query: 
            - instance-level query: (4,1,1024) 
            - task-level query: (4,1,1024)
        value: (4,3,204800)
    """
    print(f'-------------------------- MHA -------------------------')
    batch_size = key.size(0)
    
    # (4,1,16,64) => (4,16,1,64) => (64,1,64) (b*h,1,d)
    query = query.reshape(batch_size, -1, self_head_num, self_head_dim).transpose(1,2).reshape(batch_size*self_head_num, -1, self_head_dim)
    print(f'query: {query.shape}')

    # (4,3,16,64) => (4,16,3,64) => (64,3,64) (b*h,tn,d)
    key = key.reshape(batch_size, -1, self_head_num, self_head_dim).transpose(1,2).reshape(batch_size*self_head_num, -1, self_head_dim)
    print(f'key: {key.shape}')
    
    # (4,3,200,16,64) => (4,16,3,200,64) => (64,3,12800) (b*h,tn,pl*d)
    value = value.reshape(batch_size, self_task_num, -1, self_head_num, self_head_dim).permute(0,3,1,2,4).reshape(batch_size*self_head_num, self_task_num,-1)
    print(f'value: {value.shape}')
    
    # bmm: https://pytorch.org/docs/stable/generated/torch.bmm.html
    # Dot-product similarity between Query and Key, then scaling
    attn_weights = torch.bmm(query, key.transpose(1,2))*self_scaling  # (64,1,3) (b*h,1,tn)
    print(f'Query x Key : {attn_weights.shape}')
    
    # Softmax over the task axis (last dim) to get the weights, then dropout.
    # dim=1 has size 1 here, so a softmax over it would set every weight to 1.
    attn_probs = F.softmax(attn_weights, dim=-1)
    attn_probs = F.dropout(attn_probs, p=0.1)
    
    # Multiply each weight by Value
    attn_output = torch.bmm(attn_probs, value)  # (64,1,12800) (b*h,1,pl*dim)
    print(f'=> attention output : {attn_output.size()}')  
    
    # Reshape into prompt-embedding form
    # (b*h, pl*dim) => (4,16,200,64) (b,h,pl,dim)
    prompt_embedding = attn_output.squeeze(1).reshape(batch_size, self_head_num, -1, self_head_dim) 
    print(f'prompt_embedding : {prompt_embedding.size()}', end=' => ')
    
    # Concatenate the heads
    # (4,200,16,64) => (4,200,1024)
    prompt_embedding = prompt_embedding.transpose(1,2).reshape(batch_size, -1, self_embedding_size)
    print(f'{prompt_embedding.shape}', end=' => ')
    
    prompt_embedding = self_out_proj(prompt_embedding)
    print(f'{prompt_embedding.shape}')
    print('----------------------------------------------------------------------------', end='\n\n')
    return prompt_embedding
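
Two details of the function are easy to check in isolation: the head split is exactly undone by the head merge, and the softmax axis matters because the score tensor has shape (b*h, 1, tn). A small sketch with tiny made-up sizes (`split_heads` and `merge_heads` are hypothetical helpers, not TextBox functions):

```python
import torch
import torch.nn.functional as F

batch_size, head_num, head_dim, task_num = 2, 4, 3, 3
embed_dim = head_num * head_dim

def split_heads(x):
    """(b, n, h*d) -> (b*h, n, d)"""
    b, n, _ = x.shape
    return x.reshape(b, n, head_num, head_dim).transpose(1, 2).reshape(b * head_num, n, head_dim)

def merge_heads(x, b):
    """(b*h, n, d) -> (b, n, h*d); inverse of split_heads"""
    n = x.size(1)
    return x.reshape(b, head_num, n, head_dim).transpose(1, 2).reshape(b, n, embed_dim)

query = torch.randn(batch_size, 1, embed_dim)
key = torch.randn(batch_size, task_num, embed_dim)

q, k = split_heads(query), split_heads(key)
scores = torch.bmm(q, k.transpose(1, 2)) * head_dim ** -0.5  # (b*h, 1, tn)

over_tasks = F.softmax(scores, dim=-1)  # a distribution over the tn source tasks
over_query = F.softmax(scores, dim=1)   # the query axis has size 1, so every entry is 1.0

print(scores.shape)  # torch.Size([8, 1, 3])
print(torch.equal(merge_heads(split_heads(key), batch_size), key))  # True
```

With `dim=1` every task would receive weight 1, so the output would be a plain sum of the source prompts instead of a weighted mix.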

### Computing Equation (4)

In [20]:
# task-level query / key / value
task_level_MHA = MHA(task_query, key, value)
print(f'task_level_MHA : {task_level_MHA.shape}')

-------------------------- MHA -------------------------
query: torch.Size([64, 1, 64])
key: torch.Size([64, 3, 64])
value: torch.Size([64, 3, 12800])
Query x Key : torch.Size([64, 1, 3])
=> attention output : torch.Size([64, 1, 12800])
prompt_embedding : torch.Size([4, 16, 200, 64]) => torch.Size([4, 200, 1024]) => torch.Size([4, 200, 1024])
----------------------------------------------------------------------------

task_level_MHA : torch.Size([4, 200, 1024])


In [21]:
# instance-level query / key / value
instance_level_MHA = MHA(input_query, key, value)
print(f'instance_level_MHA : {instance_level_MHA.shape}')

-------------------------- MHA -------------------------
query: torch.Size([64, 1, 64])
key: torch.Size([64, 3, 64])
value: torch.Size([64, 3, 12800])
Query x Key : torch.Size([64, 1, 3])
=> attention output : torch.Size([64, 1, 12800])
prompt_embedding : torch.Size([4, 16, 200, 64]) => torch.Size([4, 200, 1024]) => torch.Size([4, 200, 1024])
----------------------------------------------------------------------------

instance_level_MHA : torch.Size([4, 200, 1024])


In [22]:
lam = 0.5
prompt_embeds = lam*task_level_MHA + (1-lam)*instance_level_MHA
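
Written out, the computation these cells are meant to perform is the following. This is read off the notebook's code, not copied from the paper: the symbols $q_{\text{task}}$, $q_{\text{ins}}$, $K$, $V$, $W_o$ and $P$ are notation introduced here, bias terms are omitted, and the softmax is taken over the source tasks.

```latex
\begin{aligned}
\mathrm{head}_i(q) &= \mathrm{softmax}\!\left(\frac{q_i K_i^{\top}}{\sqrt{d_h}}\right) V_i, \qquad d_h = 64,\\
\mathrm{MHA}(q, K, V) &= W_o\,\bigl[\mathrm{head}_1(q);\ \dots;\ \mathrm{head}_{16}(q)\bigr],\\
P &= \lambda\,\mathrm{MHA}(q_{\text{task}}, K, V) + (1-\lambda)\,\mathrm{MHA}(q_{\text{ins}}, K, V), \qquad \lambda = 0.5 .
\end{aligned}
```

Here $K$ has one row per source task, $V$ has one flattened source prompt per task, and $P$ is the (200, 1024) prompt that is prepended to the token embeddings.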

Final inputs

In [23]:
inputs_embeds = torch.cat([prompt_embeds, input_embeds], dim=1)
inputs['inputs_embeds'] = inputs_embeds

prompt_length=200
mask = torch.ones(batch_size, prompt_length, dtype=torch.long)
inputs['attention_mask'] = torch.cat([mask, inputs['attention_mask']], dim=1)

print(f'Final inputs : {inputs}')
print(f'=> input_embeds : {inputs["inputs_embeds"].shape}')
print(f'=> attention_mask : {inputs["attention_mask"].shape}')

Final inputs : {'input_ids': tensor([[    0, 38182,  3916,  ...,     5, 38724,     2],
        [    0, 38182,  3916,  ...,  3713,   495,     2],
        [    0, 38182,  3916,  ...,     1,     1,     1],
        [    0, 38182,  3916,  ...,     1,     1,     1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([[    0,  4148,     5,    78,   183,    11,    39,    92,   633,     6,
           732,  3540, 34597,  9430,    21,   576,    10,  5342,  2007,  4315,
            35,    22,  6785,   213,   146,   201,    10,   319,     9,   418,
            72,     2,  -100,  -100,  -100,  -100],
        [    0,   133, 19716,   392,    16,   567,   223,  1164,     7,   224,
           549,    79,  1467,    59,    10,   431,  3834,  7051,     9,     5,
           987,    18,  1748,  2398,   467,   137,    10,  4096, 10271,   900,
             4,     2,  -100,  -100,  -100,  -10
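
Because the prompt is prepended to the token embeddings, the attention mask has to grow by the same number of positions, all set to 1. A pure-Python sketch of that step (the helper `prepend_prompt_mask` and the short rows are made up for illustration):

```python
def prepend_prompt_mask(attention_mask, prompt_length):
    """Prefix every mask row with prompt_length ones so the prompt positions are attended to."""
    return [[1] * prompt_length + list(row) for row in attention_mask]

# Two short rows standing in for the (4, 806) source mask; 0 marks padding.
source_mask = [[1, 1, 1, 1], [1, 1, 0, 0]]
extended = prepend_prompt_mask(source_mask, prompt_length=3)
print(extended)  # [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 0, 0]]

# With the notebook's sizes the encoder sees 200 prompt positions plus 806 token positions.
print(200 + 806)  # 1006
```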

## 3-2. model()
Identical to a standard BART-large forward pass!

In [24]:
inputs.keys()
del inputs['input_ids']

In [25]:
inputs

{'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'labels': tensor([[    0,  4148,     5,    78,   183,    11,    39,    92,   633,     6,
            732,  3540, 34597,  9430,    21,   576,    10,  5342,  2007,  4315,
             35,    22,  6785,   213,   146,   201,    10,   319,     9,   418,
             72,     2,  -100,  -100,  -100,  -100],
         [    0,   133, 19716,   392,    16,   567,   223,  1164,     7,   224,
            549,    79,  1467,    59,    10,   431,  3834,  7051,     9,     5,
            987,    18,  1748,  2398,   467,   137,    10,  4096, 10271,   900,
              4,     2,  -100,  -100,  -100,  -100],
         [    0,   250,  1123, 23226,  5448,    61,  7544,    80,  6872, 33704,
            994,   583,  1378, 10416,    94,    76,   197,    45,  1303, 20396,
             50, 43337, 13014,   514,    53,  1492,  8182,    24,    40,   240,
    

In [26]:
outputs = self_model(**inputs)
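
No `decoder_input_ids` are passed here. When only `labels` are given, the Hugging Face BART implementation derives the decoder inputs by shifting the labels one position to the right. A pure-Python sketch of that idea, assuming `decoder_start_token_id=2` and `pad_token_id=1` as in the `facebook/bart-large` configuration (`shift_right` is an illustrative helper, not the library function):

```python
def shift_right(labels, pad_token_id=1, decoder_start_token_id=2):
    """Build decoder inputs from labels: prepend the start token, drop the last label,
    and turn the -100 ignore markers into real pad ids."""
    shifted = [decoder_start_token_id] + list(labels[:-1])
    return [pad_token_id if tok == -100 else tok for tok in shifted]

labels = [0, 4148, 5, 2, -100, -100]
print(shift_right(labels))  # [2, 0, 4148, 5, 2, 1]
```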

## 3-3. Loss Function: CrossEntropyLoss

Note that this loss applies only to summarization. For other tasks, check the corresponding loss yourself.

In [29]:
self_label_smoothing = 0.1
loss_fct = nn.CrossEntropyLoss(label_smoothing=self_label_smoothing)
vocab_size = outputs.logits.size(-1)

logits = outputs.logits
labels = inputs['labels']

loss = loss_fct(logits.view(-1, vocab_size), labels.view(-1))
print(f'loss : {loss}')


loss : 4.745816230773926
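
Two settings shape this number: positions labelled -100 are skipped (the default `ignore_index`), and label smoothing mixes the usual negative log-likelihood with a uniform term over the vocabulary. A pure-Python sketch intended to mirror that behaviour on toy logits (`log_softmax` and `smoothed_cross_entropy` are illustrative helpers written for this note):

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over one row of logits."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_norm for x in logits]

def smoothed_cross_entropy(logits, labels, label_smoothing=0.1, ignore_index=-100):
    """Mean label-smoothed cross-entropy over positions whose label is not ignore_index."""
    losses = []
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # padded label positions contribute nothing
        log_probs = log_softmax(row)
        nll = -log_probs[label]                      # standard cross-entropy term
        uniform = -sum(log_probs) / len(log_probs)   # average over all classes
        losses.append((1 - label_smoothing) * nll + label_smoothing * uniform)
    return sum(losses) / len(losses)

logits = [[2.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [5.0, 1.0, 0.0, 0.0]]
labels = [0, 3, -100]
loss = smoothed_cross_entropy(logits, labels)
print(loss)
```

The third row is ignored entirely, so changing its logits leaves the loss unchanged.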
