In [18]:
from VisionTransformer import ViT
import functorch as ft
import torch
import torch.func as fc
from functools import partial
import copy 

model = ViT(num_classes=10, imgsize=32, patch_dim=4, num_layers=7, d_model=256, nhead=4, d_ff_ratio=4, dropout=0.1, activation="gelu")


In [19]:
def ad_func(params, buffers, names, model, x, y):
    #h, z = model(x)
    #loss = F.cross_entropy(z, y)
    # x = torch.cat([xi, xj], dim=0)
    #h, z = model.forward(x)
    z = fc.functional_call(model, ({k: v for k, v in zip(names, params)}, buffers), (x,))
    # ((2N, g) @ (g, 2N)) / (2N,1) @ (1,2N) -> (2N, 2N) / (2N,2N)
    sim_matrix = (z @ z.T) / (z.norm(p=2, dim=1, keepdim=True) @ z.norm(p=2, dim=1, keepdim=True).T)
    mask = torch.eye(z.shape[0], dtype=torch.bool, device=z.device)
    pos_mask = mask.roll(shifts=sim_matrix.shape[0]//2, dims=1).bool()  # find pos-pair N away
    pos = torch.exp(sim_matrix[pos_mask] / 0.1)
    neg = torch.exp(sim_matrix.masked_fill(mask, value=float("-inf")) / 0.1)
    loss = -torch.log(pos / torch.sum(neg))
    #loss = - (sim_matrix[pos_mask] / self.hparams.temp / 2) + (torch.logsumexp(sim_matrix.masked_fill(mask, value=float("-inf")) / self.hparams.temp, dim=1) / 2)
    # Find the rank for the positive pair
    sim_matrix = torch.cat([sim_matrix[pos_mask].unsqueeze(1), sim_matrix.masked_fill(pos_mask,float("-inf"))], dim=1)
    pos_pair_pos = torch.argsort(sim_matrix, descending=True, dim=1).argmin(dim=1)
    top1 = torch.mean((pos_pair_pos == 0).float())
    top5 = torch.mean((pos_pair_pos < 5).float())
    mean_pos = torch.mean(pos_pair_pos.float())
    return torch.mean(loss)# , top1, top5, mean_pos

In [20]:
named_buffers = dict(model.named_buffers())
named_params = dict(model.named_parameters())
names = named_params.keys()
params = named_params.values()

In [21]:
x = torch.rand(32, 3, 32, 32)
y = torch.rand(32, 10)

In [22]:
v_params = tuple([torch.randn_like(param) for param in params])
v_params_copy = copy.deepcopy(v_params)
foo = partial(
    ad_func,
    model=model,
    names=names,
    buffers=named_buffers,
    x=x,
    y=y
)
loss, jvp = fc.jvp(foo, (tuple(params),), (v_params,))

In [23]:
jvp

tensor(-18.8820, grad_fn=<MeanBackward0>)

In [29]:
v_params[0]

tensor([[[ 6.0109e-01, -2.9706e-01, -4.9550e-02, -1.0207e+00, -4.9993e-01,
          -3.2764e-01, -1.5874e-01, -4.6128e-01,  1.8962e-01,  8.0452e-01,
          -9.9625e-01,  1.6541e+00,  1.2493e+00,  1.1574e+00,  1.0616e-02,
           5.1207e-01,  8.1531e-01,  9.6493e-01,  9.8747e-01, -1.3913e-01,
           1.4757e+00,  2.2006e+00,  7.9197e-01, -9.0837e-01, -3.2935e-01,
           8.4876e-02,  7.5794e-01,  7.0846e-02, -1.2059e+00, -2.1368e-01,
          -8.5498e-01, -8.4946e-01,  8.4499e-01,  9.0548e-01,  2.7240e-01,
           4.9713e-02, -7.2316e-01,  1.1908e+00,  6.7316e-01,  6.0828e-03,
           2.2433e+00,  1.2006e+00,  1.6620e-01, -1.3031e-01, -6.9910e-01,
           1.7291e-01, -2.4626e+00, -8.2512e-01, -8.6691e-01, -2.0017e-01,
           4.7458e-03, -7.5629e-03, -3.9518e-01,  1.1264e+00, -1.0212e+00,
          -6.7782e-01,  2.4335e-01, -3.6254e-01,  6.8250e-01,  2.5501e-01,
           1.9431e+00,  1.2812e+00,  4.4676e-02,  2.1638e+00,  5.2783e-01,
          -1.6409e-01,  1

In [32]:
(v_params[0] == v_params_copy[0]).sum()

tensor(256)

In [25]:
for v, p in zip(v_params, params):
    p.grad = v * jvp

In [26]:
import torch.autograd.forward_ad as fwAD

input = x
params = {name: p for name, p in model.named_parameters()}
tangents = {name: torch.rand_like(p) for name, p in params.items()}

with fwAD.dual_level():
    for name, p in params.items():
        delattr(model, name)
        setattr(model, name, fwAD.make_dual(p, tangents[name]))

    out = model(input)
    jvp = fwAD.unpack_dual(out).tangent

AttributeError: 'ViT' object has no attribute 'embed.convembed.weight'

In [None]:
from torch.func import functional_call

# We need a fresh module because the functional call requires the
# the model to have parameters registered.

dual_params = {}
with fwAD.dual_level():
    for name, p in params.items():
        # Using the same ``tangents`` from the above section
        dual_params[name] = fwAD.make_dual(p, tangents[name])
    out = functional_call(model, dual_params, input)
    jvp2 = fwAD.unpack_dual(out).tangent

# Check our results
#assert torch.allclose(jvp, jvp2)

In [None]:
jvp2.shape