Adding support for checkpointing #149

@pevnak

I am copying this here from my post on Slack so that it does not get lost.

I think it might be worth adding rudimentary support for checkpointing, as

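using Transformers, Zygote

# Wrap a transformer block so that its forward pass is recomputed during the backward pass.
struct Checkpointed{S} <: Transformers.Layers.AbstractTransformerBlock
    f::S
end

# Print the wrapped block in place of the wrapper.
Base.show(io::IO, c::Checkpointed) = print(io, c.f)

# Zygote.checkpointed does not keep the pullback of the forward pass; it re-runs m.f during the backward pass.
(m::Checkpointed)(args...) = Zygote.checkpointed(m.f, args...)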

and then wrapping the blocks in Checkpointed as

decoder = Transformers.Layers.Chain(Transformer(map(Checkpointed, decoder.layers[1].blocks)), decoder.layers[2]) 

While it is probably not the nicest representation, it seems to work.
The running times are approximately 50% longer, which I think is correct, since the forward pass has to be done twice.
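A quick sanity check for a wrapper like this is to confirm that it leaves the gradients unchanged. The following is only a minimal sketch: it assumes just Zygote is loaded, it drops the `AbstractTransformerBlock` supertype so that it runs without Transformers.jl, and `block` is a hypothetical toy function standing in for a real transformer block.

```julia
using Zygote

# Same wrapper as above, minus the Transformers.jl supertype (assumption:
# this keeps the sketch self-contained and does not affect checkpointing).
struct Checkpointed{S}
    f::S
end

# Re-run m.f in the backward pass instead of storing its pullback.
(m::Checkpointed)(args...) = Zygote.checkpointed(m.f, args...)

# Hypothetical toy "block"; any differentiable callable would do here.
block(x) = tanh.(2 .* x)

x = [0.1, 0.2, 0.3]

# Gradient through the plain block and through the checkpointed block.
g_plain = Zygote.gradient(x -> sum(block(x)), x)[1]
g_ckpt  = Zygote.gradient(x -> sum(Checkpointed(block)(x)), x)[1]

# The two should agree up to floating-point error.
println(g_plain ≈ g_ckpt)
```

The same comparison could presumably be run on a real decoder before and after wrapping its blocks, comparing the parameter gradients instead of the input gradient.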

I do not know if this is something that is wanted. If so, I could try to turn this into a more proper solution and improve it. Ideally, one would have an option to download a model from HF and enable checkpointing on it. I think that HF has this option.
