Merged
Changes from all commits
Commits
22 commits
8425cd4
I added a new doc string to the class. This is more flexible to under…
hisushanta Oct 5, 2023
8a0e77d
Merge branch 'main' into doc_string
hisushanta Oct 5, 2023
cf3a816
Merge branch 'main' into doc_string
hisushanta Oct 5, 2023
a0cd96f
Merge branch 'main' into doc_string
hisushanta Oct 7, 2023
746a8e8
Update src/diffusers/models/unet_2d_blocks.py
hisushanta Oct 7, 2023
6e56886
Update src/diffusers/models/unet_2d_blocks.py
hisushanta Oct 7, 2023
627fd9f
Update unet_2d_blocks.py
hisushanta Oct 7, 2023
ae4f7f2
Update unet_2d_blocks.py
hisushanta Oct 8, 2023
f0bea43
Update unet_2d_blocks.py
hisushanta Oct 8, 2023
872a4a5
Merge branch 'main' into doc_string
hisushanta Oct 8, 2023
3546f6d
I ran the black command to reformat the code style
hisushanta Oct 9, 2023
12534f4
Merge branch 'main' into doc_string
hisushanta Oct 9, 2023
48afb4b
Merge branch 'main' into doc_string
hisushanta Oct 10, 2023
0bb06b6
Merge pull request #1 from hi-sushanta/doc_string
hisushanta Oct 11, 2023
01a9fc9
Update unet_2d_blocks.py
hisushanta Oct 13, 2023
04e6efb
Merge branch 'huggingface:main' into main
hisushanta Oct 13, 2023
d695c4a
Merge branch 'huggingface:main' into main
hisushanta Oct 19, 2023
27d8002
Merge branch 'huggingface:main' into main
hisushanta Oct 24, 2023
55001e3
Merge branch 'huggingface:main' into main
hisushanta Oct 29, 2023
57d534a
I removed the dummy variable defined in both the encoder and decoder.
hisushanta Oct 29, 2023
1fdcd40
Now I ran the black package to reformat my file
hisushanta Oct 29, 2023
e8fa50c
Merge branch 'main' into vae_modify
DN6 Nov 1, 2023
10 changes: 6 additions & 4 deletions src/diffusers/models/vae.py
@@ -130,9 +130,9 @@ def __init__(

         self.gradient_checkpointing = False

-    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
         r"""The forward method of the `Encoder` class."""
-        sample = x

         sample = self.conv_in(sample)

         if self.training and self.gradient_checkpointing:

@@ -273,9 +273,11 @@ def __init__(

         self.gradient_checkpointing = False

-    def forward(self, z: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+    def forward(
+        self, sample: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None
+    ) -> torch.FloatTensor:
         r"""The forward method of the `Decoder` class."""
-        sample = z

         sample = self.conv_in(sample)

         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
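The pattern in this diff — renaming the forward parameter (`x` / `z`) to `sample` so the redundant local alias (`sample = x`) can be deleted — can be sketched without torch. The stub classes below are hypothetical stand-ins, not the real `Encoder`/`Decoder` from `src/diffusers/models/vae.py`:

```python
class EncoderBefore:
    """Old style: the parameter is named `x`, then immediately aliased."""

    def forward(self, x):
        sample = x          # dummy variable: a second name for the same object
        return sample * 2   # stand-in for conv_in(...) and the block stack


class EncoderAfter:
    """New style: the parameter is named `sample` directly."""

    def forward(self, sample):
        return sample * 2   # identical computation, one fewer binding


# Positional calls behave identically before and after the rename.
assert EncoderBefore().forward(3) == EncoderAfter().forward(3)
```

One caveat this kind of refactor carries: any caller that passed the argument by keyword (e.g. `forward(x=...)`) breaks after the rename, while positional callers are unaffected — which is presumably why the change was reviewed against the public API surface before merging.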