# ResNet (ImageNet-pretrained): NC2 margin of the classifier weights

In [50]:
import torchvision.models as models

# Architecture to evaluate. Change this (e.g. to "resnet18") to reproduce the
# other recorded result instead of editing the constructor call.
ARCH = "resnet34"
# NOTE: torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`.
model = getattr(models, ARCH)(pretrained=True)
# Downstream cells read `resnet18`. It is kept as an alias, although it holds
# whatever network ARCH selects (not necessarily ResNet-18).
resnet18 = model

Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /home/jiang.2880/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|███████████████████████████████████████| 83.3M/83.3M [00:00<00:00, 409MB/s]


In [51]:
# Shape of the classifier-head weight, expected (num_classes, feature_dim).
print(resnet18.fc.weight.size())

torch.Size([1000, 512])


In [52]:
weights = resnet18.fc.weight
# Rescale every row of the head weight to unit L2 norm and export it as numpy.
# Detaching first yields the same values without recording an autograd graph.
W = nn.functional.normalize(weights.detach(), p=2, dim=1).cpu().numpy()

In [53]:
print(W.shape)
# Sanity check: every row of W should now have unit L2 norm, so the margins
# computed below are scale-free. A zero row (normalized to zeros) would fail here.
assert abs((W ** 2).sum(axis=1) - 1).max() < 1e-4, "rows of W are not unit-norm"

(1000, 512)


In [54]:

def loss_func(W, H):
    """
    Per-column worst competing margin.

    With logits[i, j] = w_i . h_j, returns for every column j

        max_{i != j} (w_i . h_j - w_j . h_j)

    as a (1, K) tensor. A negative entry means h_j scores its own row w_j
    strictly higher than every other row. With K == 1 there is no competitor,
    and the entry is -inf.

    Args:
        W: (K, d) tensor of class weights.
        H: (K, d) tensor; row j is the candidate feature for class j.

    Returns:
        (1, K) tensor, differentiable with respect to W and H.
    """
    logits = W @ H.T
    # Subtract each column's own-class score w_j . h_j (broadcast across rows).
    margins = logits - torch.diagonal(logits).unsqueeze(0)
    # Masking with -inf excludes i == j for any weight scale. A finite offset
    # such as -100 breaks once the logit gaps approach that offset.
    own_class = torch.eye(W.shape[0], dtype=torch.bool, device=W.device)
    margins = margins.masked_fill(own_class, float("-inf"))
    return torch.max(margins, dim=0, keepdim=True)[0]

def minimize(W, H, lr=0.01, max_iter=10000):
    """
    Use projected gradient descent to minimize ``loss_func(W, H).sum()`` over H.

    After every gradient step, each row of H is rescaled to unit L2 norm, so the
    search stays on the unit sphere. The step size is multiplied by 0.1 every
    ``max_iter // 5`` iterations (every iteration when ``max_iter < 5``).

    Args:
        W: (K, d) tensor; held fixed.
        H: (K, d) tensor of starting rows. It is copied to ``W.device`` and not
            modified in place.
        lr: initial step size.
        max_iter: number of gradient steps. With 0, a copy of the starting H
            is returned.

    Returns:
        (f, H): ``f`` is ``loss_func(W, H)`` evaluated at the returned ``H``
        (detached, shape (1, K)), and ``H`` is the optimized leaf tensor.
    """
    # max(1, ...) prevents a zero decay period (ZeroDivisionError in the
    # modulo below) when max_iter < 5.
    lr_step_sched = max(1, max_iter // 5)
    # Fresh leaf copy. This replaces the deprecated torch.autograd.Variable
    # wrapper and keeps the caller's tensor untouched by the in-place updates.
    H = H.to(W.device).detach().clone().requires_grad_(True)
    for i in tqdm(range(max_iter)):
        if (i + 1) % lr_step_sched == 0:
            lr *= 0.1
        loss_func(W, H).sum().backward()
        with torch.no_grad():
            H -= lr * H.grad
            H /= torch.norm(H, dim=1, keepdim=True)
            H.grad.zero_()
    # Evaluate at the final H. The last in-loop value predates the final
    # update, and no in-loop value exists at all when max_iter == 0.
    with torch.no_grad():
        f = loss_func(W, H)
    return f, H

def compute_NC2_matrix_form(W, lr=0.01, max_iter=10000, seed=None, device=None):
    """
    Smallest per-class margin attainable by unit-norm features for weights W.

    For each class j, ``minimize`` searches the unit sphere for an h_j that
    maximizes w_j . h_j - max_{i != j} w_i . h_j. The result is the minimum of
    that margin over all classes.

    Args:
        W: (K, d) numpy array or tensor of class weights, used as given (the
            notebook L2-normalizes the rows beforehand).
        lr, max_iter: forwarded to ``minimize``.
        seed: optional int that makes the random initialization of H
            reproducible. ``None`` draws from torch's global RNG.
        device: torch device. Defaults to CUDA when available, otherwise CPU.

    Returns:
        0-dim detached tensor holding the minimum margin.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    W = torch.as_tensor(W, device=device).detach()
    if not W.is_floating_point():
        W = W.float()
    K, d = W.shape
    generator = None
    if seed is not None:
        generator = torch.Generator(device=W.device).manual_seed(seed)
    # Match W's dtype so that W @ H.T also works for float64 inputs.
    H = torch.randn(K, d, device=W.device, dtype=W.dtype, generator=generator)
    H /= torch.norm(H, dim=1, keepdim=True)
    fs, H = minimize(W, H, lr=lr, max_iter=max_iter)
    return torch.min(-fs).detach()

In [55]:

# Fix torch's global RNG (all devices) so that the random starting point drawn
# inside compute_NC2_matrix_form, and hence the reported value, is reproducible.
SEED = 0
torch.manual_seed(SEED)
NC2 = compute_NC2_matrix_form(W)
print(NC2)

100%|████████████████████████████████████| 10000/10000 [00:15<00:00, 658.14it/s]

tensor(0.4902, device='cuda:0', grad_fn=<MinBackward1>)





In [None]:
# Recorded NC2 values (min over classes of the best achievable margin):
# ResNet-18: 0.4433
# ResNet-34: 0.4902
# Optimal:   0.9150