Skip to content

Commit

Permalink
paranthesis tricks bjork
Browse files Browse the repository at this point in the history
  • Loading branch information
Franck Mamalet committed Jul 1, 2024
1 parent 4583add commit 0cbbf2d
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions deel/torchlip/normalizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,17 @@ def bjorck_normalization(
shape = w.shape
cout = w.size(0)
w_mat = w.reshape(cout, -1)
for i in range(niter):
w_mat = (1.0 + beta) * w_mat - beta * torch.mm(
w_mat, torch.mm(w_mat.t(), w_mat)
)

if w_mat.shape[0]>w_mat.shape[1]:
for i in range(niter):
w_mat = (1.0 + beta) * w_mat - beta * torch.mm(
w_mat, torch.mm(w_mat.t(), w_mat)
)
else:
for i in range(niter):
w_mat = (1.0 + beta) * w_mat - beta * torch.mm(
torch.mm(w_mat, w_mat.t()), w_mat
)
w = w_mat.reshape(shape)
return w

Expand Down

0 comments on commit 0cbbf2d

Please sign in to comment.