Open
Description
I am copying this here from my post on Slack so that it does not get lost.
I think it might be worth adding rudimentary support for gradient checkpointing as
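using Transformers, Zygote

# Wraps a transformer block so that its forward pass is recomputed during the backward pass.
struct Checkpointed{S} <: Transformers.Layers.AbstractTransformerBlock
    f::S
end
# Display the wrapped block itself.
Base.show(io::IO, c::Checkpointed) = print(io, c.f)
# Zygote.checkpointed does not store the intermediates of m.f; it reruns m.f when the gradient is computed.
(m::Checkpointed)(args...) = Zygote.checkpointed(m.f, args...)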
and then wrapping the blocks in Checkpointed as
decoder = Transformers.Layers.Chain(Transformer(map(Checkpointed, decoder.layers[1].blocks)), decoder.layers[2])
and while it is probably not the nicest representation, it seems to work.
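The running time is approximately 50% longer, which I think is expected, since the forward pass of each checkpointed block has to be done twice: once in the normal forward pass and once more when the gradient is computed.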
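To illustrate the trade-off outside of a transformer, here is a minimal sketch on a toy function (the function `f` and the input `x` are made up for the example and are not part of Transformers.jl). It only checks that `Zygote.checkpointed` returns the same value and the same gradient as the plain call; the memory saving and the extra forward pass are not measured here.

```julia
using Zygote

# Hypothetical stand-in for a block: any differentiable function works.
f(x) = sum(abs2, tanh.(x))

x = Float32[0.1, -0.5, 2.0]

# Plain gradient: Zygote keeps the intermediates of f for the backward pass.
g_plain = Zygote.gradient(f, x)[1]

# Checkpointed gradient: intermediates of f are not kept; f is rerun
# when the pullback is evaluated, trading compute for memory.
g_ckpt = Zygote.gradient(x -> Zygote.checkpointed(f, x), x)[1]

# Both routes should agree up to floating-point error.
@assert g_plain ≈ g_ckpt
```

Since only the wrapped call is recomputed, wrapping each block separately (as above) keeps the block inputs in memory and recomputes the activations inside one block at a time.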
I do not know whether this is something that is wanted. If so, I could try to turn it into a more proper solution and improve it. Ideally, one would have an option to download a model from HF and enable checkpointing on it. I think that HF has this option.