
Add ability to train decoder using embedding-image pairs #63

Closed
Veldrovive opened this issue May 5, 2022 · 4 comments

Comments

@Veldrovive
Collaborator

I am implementing a single-node training script for the decoder, and it seems @lucidrains has already implemented a wrapper for this purpose that is nearly feature-complete. Currently, the forward pass is implemented as follows:

```python
def forward(
    self,
    x,
    *,
    unet_number,
    divisor = 1,
    **kwargs
):
    # run the wrapped decoder under autocast when mixed precision is enabled
    with autocast(enabled = self.amp):
        loss = self.decoder(x, unet_number = unet_number, **kwargs)
        # divide by the accumulation divisor, then apply AMP loss scaling
        return self.scale(loss / divisor, unet_number = unet_number)
```

This lacks the ability to substitute our own image embeddings when we have precomputed embedding-image pairs. The functionality is already mostly supported by the Decoder network, where image_embed can be passed to the forward method, so this could be implemented simply by adding an image_embed parameter as a pass-through to decoder.forward. However, it would also be convenient to make the clip model optional in the Decoder constructor. I already started on this a week ago in this branch by adding the ability to set clip_image_size and channels separately from a clip model.
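For concreteness, a minimal sketch of that pass-through against the forward shown above (the explicit image_embed parameter here is illustrative, not an actual patch; the constructor change for an optional clip model is separate):

```python
def forward(
    self,
    x,
    *,
    unet_number,
    divisor = 1,
    image_embed = None,   # hypothetical: precomputed CLIP image embedding
    **kwargs
):
    with autocast(enabled = self.amp):
        # when image_embed is supplied, Decoder.forward can use it directly
        # instead of encoding x with its clip model
        loss = self.decoder(x, unet_number = unet_number, image_embed = image_embed, **kwargs)
        return self.scale(loss / divisor, unet_number = unet_number)
```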

Only a few small changes would be necessary to implement this feature, so I could put together a pull request for it.

@lucidrains
Owner

@Veldrovive Hi Aidan! Indeed that is the case, and I can get this finished in the next half hour, been meaning to get around to it!

@lucidrains
Owner

@Veldrovive
Collaborator Author

Great! For the DecoderTrainer, are you thinking of just using kwargs for image_embed rather than adding a specific named parameter for it?

@lucidrains
Owner

@Veldrovive yup, for wrappers I usually just forward kwargs to whatever is being wrapped (instead of using some fancy forwarding module)
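For illustration: since the trainer's forward already passes **kwargs through to the wrapped decoder, precomputed pairs could be trained on with no signature change at all. A hedged usage sketch (variable names, shapes, and the embedding dimension are assumptions, and decoder_trainer stands in for an already-constructed DecoderTrainer):

```python
import torch

# assumptions: decoder_trainer wraps a Decoder built without a clip model,
# and the embedding dimension (512 here) matches the Decoder's configuration
images = torch.randn(4, 3, 256, 256)   # batch of training images
image_embed = torch.randn(4, 512)      # precomputed CLIP image embeddings

loss = decoder_trainer(
    images,
    unet_number = 1,
    image_embed = image_embed   # rides along in **kwargs to Decoder.forward
)
loss.backward()   # standard backward on the scaled loss returned by the wrapper
```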
