Skip to content

Commit

Permalink
Falcon support (#890)
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed Jul 2, 2023
1 parent 3c75e59 commit bba1578
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 2 deletions.
62 changes: 62 additions & 0 deletions engines/python/setup/djl_python/scheduler/lm_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,65 @@ def forward(self, input_ids: torch.tensor, position_ids: torch.tensor,
output.past_key_values = past_key_values

return output


class FalconBlock(LMBlock):
    """LMBlock adapter for Falcon-style causal LMs.

    Falcon's attention cache stores keys/values with the batch and
    kv-head dimensions folded together, i.e. each cache entry is
    (batch * num_kv, seq_len, kv_dim), while the scheduler works with the
    conventional (batch, num_head, seq_len, kv_dim) layout.  This block
    reshapes the cache on the way in and back out of every forward call.

    Reference shapes (Falcon-7B):
      - fused_qkv: [batch, seq, (num_heads=71 + num_kv=1 + num_kv=1) * kvDim]
      - query_layer: [batch * num_heads, seq, kvDim]
      - key/value_layer: [batch * num_kv, seq, kvDim]
      - hidden_dim = 4544
    """

    def __init__(self, model):
        super().__init__(model)
        # Keyword arguments passed verbatim on every model.forward call.
        self.config = {
            'use_cache': True,
            'return_dict': True,
            'output_attentions': False,
            'output_hidden_states': True
        }

    def forward(self, input_ids: torch.tensor, position_ids: torch.tensor,
                attention_mask: torch.tensor, past_key_values):
        """Run one decoding step, translating the KV-cache layout.

        Args:
            input_ids: token ids, shape (batch, seq).
            position_ids: position ids matching input_ids.
            attention_mask: attention mask matching input_ids.
            past_key_values: tuple of (key, value) pairs in scheduler layout
                (batch, num_head, seq_len, kv_dim), or None on the first step.

        Returns:
            The model output object; its past_key_values are converted back
            to the scheduler layout (batch, num_head, seq_len, kv_dim).
        """
        batch_size = input_ids.shape[0]

        # Pre-process: (batch, num_head, seq, kv_dim) -> (batch*num_head, seq, kv_dim),
        # the layout Falcon's attention expects.
        if past_key_values is not None:
            _, num_head, seq_len, kv_dim = past_key_values[0][0].shape
            folded = batch_size * num_head
            past_key_values = tuple(
                (key.view(folded, seq_len, kv_dim),
                 value.view(folded, seq_len, kv_dim))
                for key, value in past_key_values)

        # Forward pass through the underlying Falcon model.
        output = self.model.forward(input_ids=input_ids,
                                    position_ids=position_ids,
                                    attention_mask=attention_mask,
                                    past_key_values=past_key_values,
                                    **self.config)

        # Post-process: unfold back to (batch, num_head, seq, kv_dim); -1
        # recovers num_head from the folded leading dimension.
        _, seq_len, kv_dim = output.past_key_values[0][0].shape
        output.past_key_values = tuple(
            (key.view(batch_size, -1, seq_len, kv_dim),
             value.view(batch_size, -1, seq_len, kv_dim))
            for key, value in output.past_key_values)

        return output
74 changes: 73 additions & 1 deletion engines/python/setup/djl_python/tests/test_scheduler_bloom.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import unittest

from djl_python.scheduler import BloomBlock
from djl_python.scheduler.lm_block import BloomBlock, FalconBlock
from djl_python.scheduler.seq_batch_scheduler import SeqBatchScheduler
from transformers import AutoConfig, BloomForCausalLM, AutoTokenizer
from djl_python.scheduler.search_config import SearchConfig
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoTokenizer, AutoModelForCausalLM


class TestSchedulerBloom(unittest.TestCase):
Expand Down Expand Up @@ -119,6 +120,77 @@ def test_contrastive_scheduler(self):
for i, ret in results.items():
print('\n{}:'.format(i), tokenizer.decode(ret))

def test_contrastive_scheduler_falcon(self):
model_name = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
"BlackSamorez/falcon-40b-tiny-testing", trust_remote_code=True)

lm_block = FalconBlock(model)

search_config = SearchConfig()
PAD = search_config.pad_token_id
scheduler = SeqBatchScheduler(lm_block, "contrastive", search_config)

input_ids_0 = tokenizer.encode(
'Memories follow me left and right. I can', return_tensors='pt')
request_ids = torch.tensor([[0]])

# Test init_forward
scheduler.add_request(input_ids_0, request_ids)

# Merge longer sequences
input_ids_1 = tokenizer.encode(
"When your legs don't work like they used to before And I can't sweep you off",
return_tensors='pt')
input_ids_2 = torch.concat([
torch.tensor([PAD, PAD, PAD, PAD, PAD, PAD]),
tokenizer.encode(
"There's a time that I remember, when I did not know",
return_tensors='pt')[0]
]).view(1, -1)
input_ids = torch.concat([input_ids_1, input_ids_2], dim=0)

request_ids = torch.tensor([[1], [2]])
scheduler.add_request(input_ids, request_ids)

# Forward pass
for _ in scheduler.increment_forward(20):
pass

results = scheduler.results

assert tokenizer.decode(
results[1][:30]
) == "When your legs don't work like they used to before And I can't sweep you offíc warr formats Tos Bruce advocacyyoungGP xxx522"
assert tokenizer.decode(
results[2][:20]
) == "There's a time that I remember, when I did not know rents complimentaryigsiosis stimulate roads"
assert tokenizer.decode(
results[0][:30]
) == 'Memories follow me left and right. I canHex pennednal hackers quali consists authoritative operates Nurse Scotland[@ Burns diminishing和 preNut comfortably drainage suddenly revised'

# Merge shorter sequences
input_ids_1 = tokenizer.encode("When your legs don't work",
return_tensors='pt')
input_ids_2 = torch.concat([
torch.tensor([PAD, PAD]),
tokenizer.encode("There's a time", return_tensors='pt')[0]
]).view(1, -1)
input_ids = torch.concat([input_ids_1, input_ids_2], dim=0)
request_ids = torch.tensor([[3], [4]])

scheduler.add_request(input_ids, request_ids)

# Forward pass
for _ in scheduler.increment_forward(100):
pass

# print
for i, ret in results.items():
print('\n{}:'.format(i), tokenizer.decode(ret))



# Allow running this test module directly (e.g. `python test_scheduler_bloom.py`).
if __name__ == '__main__':
    unittest.main()
2 changes: 1 addition & 1 deletion engines/python/setup/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def run(self):
requirements = ['psutil', 'packaging', 'wheel']

test_requirements = [
'numpy', 'requests', 'Pillow', 'transformers', 'torch'
'numpy', 'requests', 'Pillow', 'transformers', 'torch', 'einops'
]

setup(name='djl_python',
Expand Down

0 comments on commit bba1578

Please sign in to comment.