Commit 74ee558
Merge pull request hpcaitech#107 from jpthu17/fix_2d_RoPE_init_bug
[fix] 2d RoPE init

Former-commit-id: e873ffa66f01268cce600f8da9ad297dc3e8aaaf
LinB203 authored Mar 10, 2024
2 parents 9a0d0bc + 75938cf commit 74ee558
Showing 2 changed files with 10 additions and 6 deletions.
8 changes: 5 additions & 3 deletions opensora/models/diffusion/dit/dit.py
@@ -49,7 +49,7 @@ def __init__(
         attention_pe_mode=None,
         hw: Union[int, Tuple[int, int]] = 16,  # (h, w)
         pt_hw: Union[int, Tuple[int, int]] = 16,  # (h, w)
-        intp_vfreq: bool = False,  # vision position interpolation
+        intp_vfreq: bool = True,  # vision position interpolation
     ) -> None:
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -280,8 +280,8 @@ def __init__(
         extras=1,
         attention_mode='math',
         attention_pe_mode=None,
-        pt_input_size: Union[int, Tuple[int, int]] = 16,  # (h, w)
-        intp_vfreq: bool = False,  # vision position interpolation
+        pt_input_size: Union[int, Tuple[int, int]] = None,  # (h, w)
+        intp_vfreq: bool = True,  # vision position interpolation
     ):
         super().__init__()
         self.gradient_checkpointing = False
@@ -305,6 +305,8 @@ def __init__(
         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
         self.temp_embed = nn.Parameter(torch.zeros(1, num_frames, hidden_size), requires_grad=False)
 
+        if pt_input_size is None:
+            pt_input_size = input_size
         self.blocks = nn.ModuleList([
             DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attention_mode=attention_mode,
                      attention_pe_mode=attention_pe_mode, hw=input_size, pt_hw=pt_input_size,
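The substance of the fix is the `pt_input_size` default: with the old hard-coded 16, a model built at any other latent grid initialized its 2D RoPE against the wrong pretraining size. The new `None` default falls back to `input_size`, so the runtime and pretraining grids match unless a pretraining grid is passed explicitly. A minimal sketch of that fallback (standalone illustration, not code from the repository; the helper name is hypothetical):

    # Hedged sketch of the fixed default handling (hypothetical helper, not repo code).
    def resolve_pt_input_size(input_size, pt_input_size=None):
        """Fall back to the current input grid when no pretraining grid is given."""
        if pt_input_size is None:
            pt_input_size = input_size
        return pt_input_size

    # Old behavior: a 32x32 model silently got pt_input_size=16 and interpolated RoPE
    # frequencies it never needed; now the grids match unless stated otherwise.
    assert resolve_pt_input_size(32) == 32
    assert resolve_pt_input_size(32, 16) == 16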
8 changes: 5 additions & 3 deletions opensora/models/diffusion/latte/latte.py
@@ -56,7 +56,7 @@ def __init__(self,
         attention_pe_mode=None,
         hw: Union[int, Tuple[int, int]] = 16,  # (h, w)
         pt_hw: Union[int, Tuple[int, int]] = 16,  # (h, w)
-        intp_vfreq: bool = False,  # vision position interpolation
+        intp_vfreq: bool = True,  # vision position interpolation
     ):
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -284,8 +284,8 @@ def __init__(
         extras=1,
         attention_mode='math',
         attention_pe_mode=None,
-        pt_input_size: Union[int, Tuple[int, int]] = 16,  # (h, w)
-        intp_vfreq: bool = False,  # vision position interpolation
+        pt_input_size: Union[int, Tuple[int, int]] = None,  # (h, w)
+        intp_vfreq: bool = True,  # vision position interpolation
     ):
         super().__init__()
         self.learn_sigma = learn_sigma
@@ -316,6 +316,8 @@ def __init__(
         self.temp_embed = nn.Parameter(torch.zeros(1, num_frames, hidden_size), requires_grad=False)
         self.hidden_size = hidden_size
 
+        if pt_input_size is None:
+            pt_input_size = input_size
         self.blocks = nn.ModuleList([
             TransformerBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attention_mode=attention_mode,
                              attention_pe_mode=attention_pe_mode, hw=input_size, pt_hw=pt_input_size,
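For context on the other flipped default, `intp_vfreq=True`: in the usual 2D vision-RoPE recipe, when the runtime grid `hw` differs from the pretraining grid `pt_hw`, per-axis position indices are rescaled onto the pretraining grid before the rotary angles are computed, so a model fine-tuned at a new resolution keeps the frequency range its attention blocks were trained with. The sketch below shows that standard interpolation trick under those assumptions; the names and the exact dimension split are hypothetical and not taken from this repository.

    import torch

    def rope_axis_angles(dim, hw, pt_hw, theta=10000.0, intp_vfreq=True):
        """Rotary angles for one spatial axis of an hw-sized grid (hypothetical helper)."""
        # One quarter of the head dim per axis is a common split for 2D RoPE.
        freqs = 1.0 / (theta ** (torch.arange(dim // 4).float() / (dim // 4)))
        pos = torch.arange(hw).float()
        if intp_vfreq and hw != pt_hw:
            pos = pos * (pt_hw / hw)  # map positions onto the pretraining grid
        return torch.outer(pos, freqs)  # (hw, dim // 4)

    # Example: fine-tuning on a 32x32 latent grid with a 16x16 pretraining grid.
    angles = rope_axis_angles(dim=64, hw=32, pt_hw=16)
    print(angles.shape)  # torch.Size([32, 16])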
