In [4]:
import os
import sys
import pdb
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, Adam, lr_scheduler
import torchvision
import torchvision.transforms as T

from sys import platform

DEVICE = 'mps' if platform == 'darwin' else 'cuda'
if DEVICE == 'mps':
    DOWNLOAD_PATH = '/Users/georgestoica/Downloads' 
else:
    DOWNLOAD_PATH = '/srv/share/gstoica3/checkpoints/REPAIR/'

In [5]:
DEVICE

'mps'

In [6]:
def save_model(model, i):
    sd = model.state_dict()
    path = os.path.join(
        DOWNLOAD_PATH,
        '%s.pth.tar' % i
    )
    torch.save(model.state_dict(), path)

def load_model(model, i):
    path = os.path.join(
        DOWNLOAD_PATH,
        '%s.pth.tar' % i
    )
    sd = torch.load(path, map_location=torch.device(DEVICE))
    model.load_state_dict(sd)

In [7]:
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
normalize = T.Normalize(np.array(CIFAR_MEAN)/255, np.array(CIFAR_STD)/255)
denormalize = T.Normalize(-np.array(CIFAR_MEAN)/np.array(CIFAR_STD), 255/np.array(CIFAR_STD))

train_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.ToTensor(),
    normalize,
])
test_transform = T.Compose([
    T.ToTensor(),
    normalize,
])
train_dset = torchvision.datasets.CIFAR10(root='/tmp', train=True,
                                        download=True, transform=train_transform)
test_dset = torchvision.datasets.CIFAR10(root='/tmp', train=False,
                                        download=True, transform=test_transform)

train_aug_loader = torch.utils.data.DataLoader(train_dset, batch_size=500, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=500, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

def _weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
#             self.shortcut = LambdaLayer(lambda x:
#                                         F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(planes)
            )


    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, w=1, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = w*16

        self.conv1 = nn.Conv2d(3, w*16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(w*16)
        self.layer1 = self._make_layer(block, w*16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, w*32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, w*64, num_blocks[2], stride=2)
        self.linear = nn.Linear(w*64, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def resnet20(w=1):
    return ResNet(BasicBlock, [3, 3, 3], w=w)

In [9]:
# evaluates accuracy
def evaluate(model, loader=test_loader):
    model.eval()
    correct = 0
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            outputs = model(inputs.to('mps'))
            pred = outputs.argmax(dim=1)
            correct += (labels.to('mps') == pred).sum().item()
    return correct

# evaluates loss
def evaluate1(model, loader=test_loader):
    model.eval()
    losses = []
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            outputs = model(inputs.to('mps'))
            loss = F.cross_entropy(outputs, labels.to('mps'))
            losses.append(loss.item())
    return np.array(losses).mean()

In [10]:
# modifies the weight matrices of a convolution and batchnorm
# layer given a permutation of the output channels
def permute_output(perm_map, conv, bn):
    pre_weights = [
        conv.weight,
        bn.weight,
        bn.bias,
        bn.running_mean,
        bn.running_var,
    ]
    for w in pre_weights:
        if len(w.shape) == 4:
            transform = torch.einsum('ab,bcde->acde', perm_map, w)
        elif len(w.shape) == 2:
            transform = perm_map @ w
        else:
            transform = w @ perm_map.t()
#         assert torch.allclose(w[perm_map.argmax(-1)], transform)
        w.data = transform
#         w.data = w[perm_map]

# modifies the weight matrix of a convolution layer for a given
# permutation of the input channels
def permute_input(perm_map, after_convs):
    if not isinstance(after_convs, list):
        after_convs = [after_convs]
    post_weights = [c.weight for c in after_convs]
    for w in post_weights:
        if len(w.shape) == 4:
            transform = torch.einsum('abcd,be->aecd', w, perm_map.t())
        elif len(w.shape) == 2:
            transform = w @ perm_map.t()
    #     assert torch.allclose(w[:, perm_map.argmax(-1)], transform)
        w.data = transform
#         w.data = w[:, perm_map, :, :]

def permute_cls_output(perm_map, linear):
    for w in [linear.weight, linear.bias]:
        w.data = perm_map @ w

In [11]:
modela = resnet20(w=4).to('mps')
modelb = resnet20(w=4).to('mps')
load_model(modela, 'resnet20x4_v2')
load_model(modelb, 'resnet20x4_v1')

evaluate(modela), evaluate(modelb)



(9540, 9521)

In [15]:
conv1_a = modela.conv1
conv1_b = modelb.conv1

bn1_a = modela.bn1
bn1_b = modelb.bn1
_ = modela.eval()
_ = modelb.eval()

In [16]:
output_prep = lambda x: x.flatten(1)
input_prep = lambda x: x.transpose(1, 0).flatten(1)

In [58]:
w_a = output_prep(conv1_a.weight).cpu().detach()
w_b = output_prep(conv1_b.weight).cpu().detach()
gamma_a = bn1_a.bias.cpu().detach()
gamma_b = bn1_b.bias.cpu().detach()
beta_a = bn1_a.weight.cpu().detach()
beta_b = bn1_b.weight.cpu().detach()
mu_a = bn1_a.running_mean.cpu().detach()
mu_b = bn1_b.running_mean.cpu().detach()
var_a = bn1_a.running_var.cpu().detach()
var_b = bn1_b.running_var.cpu().detach()

In [59]:
A = (beta_a.reshape(-1, 1) * w_a) / var_a.reshape(-1, 1)
B = (beta_b.reshape(-1, 1) * w_b) / var_b.reshape(-1, 1)

In [60]:
import geotorch
import torch
import torch.nn as nn

In [61]:
class P_Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.P = nn.Linear(64,64, bias=False)
        geotorch.orthogonal(self.P, 'weight')
    
    def other(self):
        obj_gamma = torch.norm(self.P(gamma_a) - gamma_b)
        obj_C = torch.norm(self.P(A.T).T - B)
        obj_mu = torch.norm(self.P(mu_a) - mu_b)
        return obj_gamma + obj_C + obj_mu
    
    def forward(self, x):
        gamma_a, gamma_b, A, B, mu_a, mu_b = x
        obj_gamma = torch.norm(self.P(gamma_a) - gamma_b)
        obj_C = torch.norm(self.P(A.T).T - B)
        obj_mu = torch.norm(self.P(mu_a) - mu_b)
        return obj_gamma + obj_C + obj_mu

In [66]:
model = P_Model()
optimizer = torch.optim.Adam(model.parameters(), lr=.003)

In [67]:
for i in range(1, 1001):
    optimizer.zero_grad()
    loss = model((gamma_a, gamma_b, A, B, mu_a, mu_b))
    loss.backward(retain_graph=True)
    optimizer.step()
    if i == 0 or i % 100 == 1:
        print('step {}: {:.3f}'.format(i, loss))

step 1: 5.683
step 101: 1.662
step 201: 1.579
step 301: 1.557
step 401: 1.548
step 501: 1.543
step 601: 1.541
step 701: 1.539
step 801: 1.538
step 901: 1.537


In [69]:
P = model.P.weight