In [17]:
import torch
import torch.nn as nn
import time

In [106]:
def calc_weight_loss(weights, mean:float, std:float, step:int, scale:float=1.96):
    sep = 2 * scale * std / step
    qp = torch.linspace(- scale * std, scale * std, step + 1)
    weights = weights % sep
    weights[weights > (sep / 2)] = weights[weights > (sep / 2)] - sep
    weights = weights.view(-1, 1)
    qp = qp.view(1, step+1)

    gap = torch.remainder(weights - qp, sep)
    return gap
    
        
x = torch.randn([4,4], dtype=torch.float32)
y = torch.randn([4,4], dtype=torch.float32)

model = nn.Linear(4,4)
nn.init.normal_(model.weight)

s = time.time()
for i in range(1000):
    mean, std = torch.mean(model.weight).item(), torch.std(model.weight).item()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    pred = model.forward(x)
    gap = calc_weight_loss(model.weight, mean, std, 4)
    loss = nn.functional.mse_loss(pred, y)
    total_loss = loss + torch.sum(gap)

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if i%100 == 0:
        print(gap, loss)
e = time.time()
print(e-s)


1.0287203907966613 tensor([[-0.2478, -0.3020, -0.2184,  0.4127],
        [ 0.1976,  0.2307,  0.2441,  0.2974],
        [ 0.4767, -0.2260,  0.3072, -0.0800],
        [ 0.3409,  0.3537, -0.3927,  0.1239]], grad_fn=<IndexPutBackward0>)
tensor([[0.7809, 0.7809, 0.7809, 0.7809, 0.7809],
        [0.7267, 0.7267, 0.7267, 0.7267, 0.7267],
        [0.8103, 0.8103, 0.8103, 0.8103, 0.8103],
        [0.4127, 0.4127, 0.4127, 0.4127, 0.4127],
        [0.1976, 0.1976, 0.1976, 0.1976, 0.1976],
        [0.2307, 0.2307, 0.2307, 0.2307, 0.2307],
        [0.2441, 0.2441, 0.2441, 0.2441, 0.2441],
        [0.2974, 0.2974, 0.2974, 0.2974, 0.2974],
        [0.4767, 0.4767, 0.4767, 0.4767, 0.4767],
        [0.8027, 0.8027, 0.8027, 0.8027, 0.8027],
        [0.3072, 0.3072, 0.3072, 0.3072, 0.3072],
        [0.9487, 0.9487, 0.9487, 0.9487, 0.9487],
        [0.3409, 0.3409, 0.3409, 0.3409, 0.3409],
        [0.3537, 0.3537, 0.3537, 0.3537, 0.3537],
        [0.6361, 0.6361, 0.6361, 0.6361, 0.6361],
        [0.1239, 

In [89]:
gap, torch.mean(gap)

(tensor([[1.1879e-04, 4.2439e-05, 1.5914e-05, 5.4240e-05, 3.0398e-05, 5.2977e-02,
          8.8334e-05, 2.7001e-05, 3.9816e-05, 9.7752e-05, 1.2314e-04, 3.0637e-05,
          3.9279e-05, 1.7881e-05, 2.1444e-01, 7.8082e-06],
         [1.6332e-05, 9.8467e-05, 5.1260e-05, 5.3883e-05, 7.1526e-06, 7.4148e-05,
          9.5367e-06, 2.3186e-05, 3.1710e-05, 7.1451e-06, 4.8339e-05, 2.3723e-05,
          7.8678e-06, 6.7830e-05, 5.1796e-05, 5.4479e-05],
         [2.7431e-01, 4.0557e-05, 5.3518e-05, 1.7285e-05, 6.4135e-05, 6.5625e-05,
          7.8321e-05, 5.2452e-05, 6.6102e-05, 5.7697e-05, 4.5013e-05, 8.6606e-05,
          4.1485e-05, 8.2433e-05, 6.1382e-05, 8.0049e-05],
         [1.8299e-05, 3.4809e-05, 1.0192e-05, 1.1761e-04, 8.9400e-05, 6.2466e-05,
          9.7428e-05, 4.4465e-05, 9.5367e-07, 4.7803e-05, 3.6299e-05, 4.6015e-05,
          3.6180e-05, 3.1769e-05, 3.9041e-05, 6.9678e-05],
         [1.1683e-05, 7.2658e-05, 1.1153e-04, 5.1260e-06, 7.2617e-05, 2.5649e-01,
          5.7817e-06, 6.69

In [97]:
a = torch.tensor(-10.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a % b
c.backward()
print(c)

tensor(2., grad_fn=<RemainderBackward1>)
