Skip to content

Commit

Permalink
Update: Optimize some code.
Browse files Browse the repository at this point in the history
  • Loading branch information
chairc committed Aug 28, 2023
1 parent 8b34de2 commit cd78a21
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, in_channel=3, out_channel=3, channel=None, time_channel=256,
self.outc = nn.Conv2d(in_channels=channel[0], out_channels=out_channel, kernel_size=1)

if num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_channel)
self.label_emb = nn.Embedding(num_embeddings=num_classes, embedding_dim=time_channel)

def pos_encoding(self, time, channels):
"""
Expand All @@ -65,10 +65,11 @@ def pos_encoding(self, time, channels):
:param channels: Channels
:return: pos_enc
"""
inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2, device=self.device).float() / channels))
pos_enc_a = torch.sin(time.repeat(1, channels // 2) * inv_freq)
pos_enc_b = torch.cos(time.repeat(1, channels // 2) * inv_freq)
pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
inv_freq = 1.0 / (10000 ** (torch.arange(start=0, end=channels, step=2, device=self.device).float() / channels))
inv_freq_value = time.repeat(1, channels // 2) * inv_freq
pos_enc_a = torch.sin(input=inv_freq_value)
pos_enc_b = torch.cos(input=inv_freq_value)
pos_enc = torch.cat(tensors=[pos_enc_a, pos_enc_b], dim=-1)
return pos_enc

def forward(self, x, time, y=None):
Expand Down

0 comments on commit cd78a21

Please sign in to comment.