In [1]:
import sys
sys.path.append("..")
from plot_utils import *
#
import numpy as np
import time
import torch
import torch.nn as nn
from torch.distributions import Beta
from torch.distributions.dirichlet import Dirichlet
from tqdm import tqdm
#
from utils import *
from functions import *
from BTwins.utils import calc_lambda

In [2]:
def mean_var(x, axis=0):
    x = np.array(x)
    print("mu={:.3f} var={:.3f}".format(float(x.mean(axis=0)),
                                        float(x.var(axis=0))))

def cc_matrix(x1, x2):
    x1 = reshape_(x1)
    x2 = reshape_(x2)
    c = x1.T @ x2
    c.div_(x1.shape[0])
    return c

def reshape_(x):
    if len(x.shape) == 1:
        x = x.reshape((-1, 1))
    return x

def cc_bn(x1, x2, debug=False, eps=1e-5):
    x1 = reshape_(x1)
    x2 = reshape_(x2)
    
    bn = torch.nn.BatchNorm1d(x1.shape[1], affine=False, eps=eps)
    x1 = bn(x1)
    bn = torch.nn.BatchNorm1d(x1.shape[1], affine=False, eps=eps)
    x2 = bn(x2)
    if debug:
        print("bn(X1)", x1.mean(axis=0), x1.var(axis=0))
        print("bn(X2)", x2.mean(axis=0), x1.var(axis=0))
    #
    return cc_matrix(x1, x2)

def cc_norm(x1, x2, debug=False, eps=1e-5):
    x1 = reshape_(x1)
    x2 = reshape_(x2)
    
    # recenter
    #x1 = x1 - x1.mean()
    #x2 = x2 - x2.mean()
    
    # unit variance
    if eps > 0:
        x1 = (x1 - x1.mean(axis=0)) / torch.sqrt(x1.var(axis=0) + eps)
        x2 = (x2 - x2.mean(axis=0)) / torch.sqrt(x2.var(axis=0) + eps)
    else:
        x1 = (x1 - x1.mean(axis=0)) / x1.std(axis=0)
        x2 = (x2 - x2.mean(axis=0)) / x2.std(axis=0)
    
    if debug:
        print("mv(X1)", x1.mean(axis=0), x1.var(axis=0))
        print("mv(X2)", x2.mean(axis=0), x1.var(axis=0))
    #
    return cc_matrix(x1, x2)

In [3]:
n = 100
d = 3
x1 = 2 * torch.rand((n, d)) + 1
x2 = 10 * torch.rand((n, d)) + 4

In [4]:
cc_norm(x1, x1)

tensor([[0.9900, 0.0130, 0.0772],
        [0.0130, 0.9900, 0.0530],
        [0.0772, 0.0530, 0.9900]])

In [5]:
cc_bn(x1, x2)

tensor([[-0.1234, -0.0293,  0.0542],
        [-0.1266, -0.0431, -0.2481],
        [ 0.0503,  0.0676, -0.1017]])

In [6]:
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        off_diag = off_diagonal(c).pow_(2).sum()
        loss = self.scale_factor * (on_diag + self.lambd * off_diag)

NameError: name 'c' is not defined

In [7]:
def get_off_scale(d):
    return d/(d**2 - d)

In [8]:
dims = [2, 4, 8, 16, 32, 128, 256, 512, 1024, 2048, 4096, 8192]
for d in dims:
    l1 = get_off_scale(d)
    l2 = calc_lambda(d)
    print("{:.6f} : {:.6f} = 1 / {}".format(l1, l2, l1/l2))

1.000000 : 40.983607 = 1 / 0.0244
0.333333 : 13.661202 = 1 / 0.024399999999999998
0.142857 : 5.854801 = 1 / 0.0244
0.066667 : 2.732240 = 1 / 0.024400000000000005
0.032258 : 1.322052 = 1 / 0.0244
0.007874 : 0.322706 = 1 / 0.0244
0.003922 : 0.160720 = 1 / 0.0244
0.001957 : 0.080203 = 1 / 0.0244
0.000978 : 0.040062 = 1 / 0.0244
0.000489 : 0.020021 = 1 / 0.0244
0.000244 : 0.010008 = 1 / 0.0244
0.000122 : 0.005003 = 1 / 0.024399999999999998


In [9]:
print("{:>4} {:>8} {:>7} {:>6}".format("on", "off", "l", "l*off"))
print("#"*(5+9+8+7))
for d in dims:
    n_on = d
    n_of = d**2 - d
    lmbda = get_off_scale(d)
    print("{:>4d} {:>8d} {:7.5f} {:>6}".format(n_on, n_of, lmbda, lmbda * n_of))

  on      off       l  l*off
#############################
   2        2 1.00000    2.0
   4       12 0.33333    4.0
   8       56 0.14286    8.0
  16      240 0.06667   16.0
  32      992 0.03226   32.0
 128    16256 0.00787  128.0
 256    65280 0.00392  256.0
 512   261632 0.00196  512.0
1024  1047552 0.00098 1024.0
2048  4192256 0.00049 2048.0
4096 16773120 0.00024 4096.0
8192 67100672 0.00012 8192.0


In [10]:
print(w_off)
for d in dims:
    l1 = get_off_scale(d)
    l2 = calc_lambda(d)
    l3 = l1 * w_off
    print("{:.5f} {:.5f} {:.5f}".format(l1, l2, l3))

NameError: name 'w_off' is not defined

In [11]:
c = torch.diag(torch.ones(5))
c = torch.ones((4, 4))

In [None]:
def cc_loss(c):
    loss_on = torch.diagonal(c).add(-1).pow(2).sum()
    loss_off = off_diagonal(c).pow(2).sum()
    return loss_on + loss_off, loss_on, loss_off
    #return (torch.diag(c) - 1).mean()

In [None]:
d = 8
w_off = get_off_scale(d)
lamda = 1 / 0.0244
c = torch.ones((8, 8))
#
on_diag = torch.diagonal(c).mean()
off_diag = off_diagonal(c).mean()

In [None]:
on_diag

In [None]:
off_diag

In [None]:
on_diag 

In [None]:
off_diag * w_off

In [12]:
# PAPER
lambd = 0.0051
#
d = 8192
n_on = d
n_off = d**2 - d
n_on / (n_off * lambd)
#
c = torch.ones((d, d))

In [13]:
w_off = 41.774
(n_off * lambd) / n_on

41.774100000000004

In [14]:
# they use sum
def cc_loss_paper(c, lambd):
    on_diag = torch.diagonal(c).add(-1).pow(2).sum()
    off_diag = off_diagonal(c).pow(2).sum() * lambd
    loss = on_diag + off_diag
    return loss, on_diag, off_diag

def cc_loss_1(c, woff):
    on_diag = torch.diagonal(c).add(-1).pow(2).mean()
    off_diag = off_diagonal(c).pow(2).mean()
    loss = on_diag + off_diag
    return loss, on_diag, off_diag

In [15]:
c_rand = torch.rand((d, d))
c_best = torch.zeros((d, d))
torch.diagonal(c_best).add_(1)
c_worst = torch.ones((d, d))
torch.diagonal(c_worst).add_(-1)

tensor([0., 0., 0.,  ..., 0., 0., 0.])

In [16]:
for c in [c_rand, c_best, c_worst]:
    l, l_on, l_off = cc_loss_1(c)
    print(l, l_on, l_off)

TypeError: cc_loss_1() missing 1 required positional argument: 'woff'

In [17]:
for c in [c_rand, c_best, c_worst]:
    l, l_on, l_off = cc_loss_paper(c, lambd)
    print(l, l_on , l_off)

tensor(116729.5469) tensor(2673.0620) tensor(114056.4844)
tensor(0.) tensor(0.) tensor(0.)
tensor(350405.4375) tensor(8192.) tensor(342213.4375)


In [18]:
l_off / l_on

tensor(41.7741)

In [19]:
for d in [2, 4, 8, 16, 32, 64, 128, 512]:
    lmbda = calc_lambda(d)
    g = d / ((d**2 - d) * lmbda)
    h = ((d**2 - d) * lmbda) / d
    print("{:>4d}: {:8.4f} on/off={:.4f} off/on={:8.4f}".format(d,lmbda, g, h))

   2:  40.9836 on/off=0.0244 off/on= 40.9836
   4:  13.6612 on/off=0.0244 off/on= 40.9836
   8:   5.8548 on/off=0.0244 off/on= 40.9836
  16:   2.7322 on/off=0.0244 off/on= 40.9836
  32:   1.3221 on/off=0.0244 off/on= 40.9836
  64:   0.6505 on/off=0.0244 off/on= 40.9836
 128:   0.3227 on/off=0.0244 off/on= 40.9836
 512:   0.0802 on/off=0.0244 off/on= 40.9836
