Skip to content

Commit

Permalink
Print the words as they are generated
Browse files Browse the repository at this point in the history
Fixes #6
  • Loading branch information
certik committed Mar 19, 2023
1 parent b10860e commit 7e34e64
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
12 changes: 10 additions & 2 deletions gpt2.f90
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,10 @@ function generate(n_tokens_to_generate, &
wte, wpe, &
mlp_fc_w, mlp_fc_b, mlp_proj_w, mlp_proj_b, &
attn_w, attn_b, attn_proj_w, attn_proj_b, &
ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b, use_cache) result(output)
ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b, use_cache, &
decoder_idx, decoder_txt, byte_decoder) result(output)
integer, intent(in) :: decoder_idx(:), byte_decoder(:)
character, intent(in) :: decoder_txt(:)
integer, intent(in) :: n_vocab, n_ctx, n_seq, n_embd, n_layer, n_head, &
n_tokens_to_generate
integer, intent(in) :: input(n_seq)
Expand All @@ -238,8 +241,10 @@ function generate(n_tokens_to_generate, &
integer, allocatable :: input2(:)
logical :: use_kv_cache
real(sp) :: kv_cache(n_embd,n_seq+n_tokens_to_generate,2,n_layer)
character(:), allocatable :: txt1, txt2
allocate(input2(size(input)))
input2 = input
txt1 = decode(input2, decoder_idx, decoder_txt, byte_decoder)
do i = 1, n_tokens_to_generate
if (use_cache) then
use_kv_cache = (i > 1) ! Use cache for subsequent tokens
Expand All @@ -260,11 +265,14 @@ function generate(n_tokens_to_generate, &
attn_w, attn_b, attn_proj_w, attn_proj_b, &
ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b, use_kv_cache, kv_cache(:,:n_seq2,:,:))
next_id = maxloc(logits(:,n_seq_x), dim=1)-1
print *, i, next_id
input2 = [input2, next_id]
txt2 = decode(input2, decoder_idx, decoder_txt, byte_decoder)
write(*, fmt="(a)", advance="no") txt2(len(txt1)+1:)
txt1 = txt2
deallocate(logits)
end do
output = input2(n_seq+1:)
print *
end function

function c2s(x) result(y)
Expand Down
3 changes: 2 additions & 1 deletion main.f90
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ program gpt2
wte, wpe, &
mlp_fc_w, mlp_fc_b, mlp_proj_w, mlp_proj_b, &
attn_w, attn_b, attn_proj_w, attn_proj_b, &
ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b, use_cache)
ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b, use_cache, &
decoder_idx, decoder_txt, byte_decoder)
t2o = omp_get_wtime()
call cpu_time(t2)
print "(a,f8.3,a,f4.2,a)", " done. Time:", t2o-t1o, "s (", (t2-t1)/(t2o-t1o), "x)"
Expand Down

0 comments on commit 7e34e64

Please sign in to comment.