Skip to content

Commit

Permalink
Remove quantization of hidden state
Browse files Browse the repository at this point in the history
  • Loading branch information
cafaxo committed Apr 22, 2024
1 parent 97e867b commit 42001c5
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 135 deletions.
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,4 @@ Mmap = "a63ad114-7e13-5084-954f-fe012c677804"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SIMD = "fdea26ae-647d-5447-a871-4b548cad5224"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1 change: 0 additions & 1 deletion src/Llama2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ using LinearAlgebra
using StatsBase
using Printf
using ProgressMeter
using SIMD
using LoopVectorization
using Random
using Distributions
Expand Down
18 changes: 10 additions & 8 deletions src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@ function matmul!(
end

function matmul!(
y::AbstractVector{Float32},
A::AbstractMatrix{T},
x::AbstractVector{Float32},
) where {T<:Union{block_q4_K,block_q5_K,block_q6_K}}

# FIXME: preallocate this
x = quantize(block_q8_K, x)
y::AbstractVector{Float32},
A::AbstractMatrix{T},
x::AbstractVector{Float32},
) where {T<:Union{block_q4_K,block_q5_K,block_q6_K}}
if T <: Union{block_q4_K,block_q5_K}
x = to_block_f16_sums32(x) # FIXME: preallocate this
else # block_q6_K
x = to_block_f16_sums16(x) # FIXME: preallocate this
end

Threads.@threads for i in 1:length(y)
y[i] = dot(view(A, :, i), x)
y[i] = vecdot(view(A, :, i), x)
end

return nothing
Expand Down
10 changes: 0 additions & 10 deletions src/quantization/utils.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,3 @@
@inline function vwidemul(x::Vec{N,T}, y::Vec{N,T}) where {N,T}
WT = widen(T)
Vec{N,WT}(x) * Vec{N,WT}(y)
end

@inline function vpaddq(x::Vec{8,T}, y::Vec{8,T}) where {T}
shufflevector(x, y, Val((0, 2, 4, 6, 8, 10, 12, 14))) +
shufflevector(x, y, Val((1, 3, 5, 7, 9, 11, 13, 15)))
end

@noinline Base.@assume_effects :total function fieldoffset_sym(::Type{T}, s::Symbol) where {T}
for i in 1:fieldcount(T)
if fieldname(T, i) == s
Expand Down
Loading

0 comments on commit 42001c5

Please sign in to comment.