In [1]:
import math
import torch
import torch.nn as nn

if torch.cuda.is_available():
    dev = 'cuda:0'
else:
    dev = 'cpu'
device = torch.device(dev)

def get_generators(k: int, l: int, m: int, D: int=50) -> list[torch.nn.Module]:
    return [build_generator(l, m, D) for _ in range(k)]


@torch.no_grad()
def get_regression_targets(n:int, k: int, l: int, generators: list[torch.nn.Module], sample_mode: str='random') -> tuple[torch.Tensor, torch.Tensor]:
    if sample_mode == 'random':
        z = torch.rand(n, k, l)
    elif sample_mode == 'diagonal':
        z = torch.repeat_interleave(torch.rand(n, l), k, dim=0)
        z = torch.reshape(z, (n, k, l))
    elif sample_mode == 'orthogonal':
        _z = torch.rand(n, l)
        mask = torch.stack([torch.arange(n), torch.randint(k, (n, 1)).squeeze(dim=1)], dim=1).long()
        z = torch.zeros(n, k, l)
        z[mask.chunk(chunks=2, dim=1)] = _z.unsqueeze(1)
    
    x = [torch.stack([generators[j](z[i][j]) for j in range(k)]) for i in range(n)]
    x = torch.stack(x)

    return z, x

def build_generator(l: int, m: int, D: int, slope: float=0.2) -> nn.Sequential:
    g = nn.Sequential(
        nn.Linear(l, D),
        nn.LeakyReLU(slope),
        nn.Linear(D, m),
        nn.LeakyReLU(slope)
    )
    g.apply(init_min_cond)
    return g


# class Generator(torch.nn.Module):
#     def __init__(self, l: int, m: int, D: int):
#         super(Generator, self).__init__()
#         self.fc1 = nn.Linear(l, D)
#         self.relu1 = nn.LeakyReLU(0.2)
#         self.fc2 = nn.Linear(D, m)
#         self.relu2 = nn.LeakyReLU(0.2)
#         self.apply(init_min_cond)
    
#     def forward(self, x):
#         x = self.relu1(self.fc1(x))
#         x = self.relu2(self.fc2(x))
#         return x


def init_min_cond(m: torch.nn.Module, n_samples: int=7500) -> torch.Tensor:
    if isinstance(m, nn.Linear):
        w = m.weight.data
        k = 1 / w.size(0)

        w = torch.nn.functional.normalize(w, p=2)
        cond = condition_number(w)

        for _ in range(n_samples):
            _w = 2 * math.sqrt(k) * torch.rand(w.size()) - math.sqrt(k)
            _w = nn.functional.normalize(_w, p=2)
            _cond = condition_number(_w)

            if _cond < cond:
                w = _w
                cond = _cond
        
        m.weight.data = w


def condition_number(t: torch.Tensor) -> float:
    return torch.norm(t, p=2) / torch.norm(torch.pinverse(t), p=2)


  from .autonotebook import tqdm as notebook_tqdm


# Models
We consider 2 models:
- an autoencoder, where we can directly impose regularizations on the decoder
- an MLP, where we can only impose regularization on the encoder

In [40]:
def build_MLP(d_in: int, d_out: int, D: int=120, slope: float=0.2, **kwargs) -> nn.Sequential:
    return nn.Sequential(
        nn.Linear(d_in, D),
        nn.LeakyReLU(slope),
        nn.Linear(D, d_out),
        nn.LeakyReLU(slope)
    )


def MLP(k: int, l: int, m: int, D: int=120, **kwargs):
    return build_MLP(k * m, k * l, D, **kwargs)


def MLP3(k: int, l: int, m: int, D: int=120, slope: float=0.2, **kwargs):
    return nn.Sequential(
        nn.Linear(k * m, D),
        nn.LeakyReLU(slope),
        nn.Linear(D, D),
        nn.LeakyReLU(slope),
        nn.Linear(D, k * l),
        nn.LeakyReLU(slope)
    )


class CompositionalMLP(torch.nn.Module):
    def __init__(self, k: int, l: int, m: int, D: int=120, **kwargs):
        super(CompositionalMLP, self).__init__()
        self.k = k
        self.models = nn.ModuleList([build_MLP(k * m, l, round(D / k), **kwargs) for _ in range(k)])
    
    def forward(self, x):
        x = x.reshape(x.size(0), self.k, -1)
        out = []
        for i in range(len(self.models)):
            x_i = torch.zeros_like(x)
            x_i[:, i, :] = x[:, i, :]
            x_i = torch.flatten(x_i, start_dim = 1)
            out.append(self.models[i](x_i))
        return torch.cat(out, dim=1)


class Autoencoder(torch.nn.Module):
    def __init__(self, k: int, l: int, m: int, D: int=120, **kwargs):
        super(Autoencoder, self).__init__()
        self.f = build_MLP(k * m, k * l, D, **kwargs)
        self.g = build_MLP(k * l, k * m, D, **kwargs)
    
    def forward(self, x):
        z = self.f(x)
        out = self.g(z)
        return out, z

# Training
2 regularizations
- compositional contrast from Provably Learning Object-Centric Representations#training a model
- regularize Hessian

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchmetrics import R2Score
from tqdm import tqdm


# could be more efficient with torch.utils.data.TensorDataset, but I couldn't be assed to look up the documentation
class Dataset(torch.utils.data.Dataset):
    def __init__(self, n: int, k: int, l: int, generators: list[torch.nn.Module], sample_mode: str='random'):
        super(Dataset, self).__init__()
        self.n = n
        self.z, self.x = get_regression_targets(n, k, l, generators, sample_mode)
    
    def __len__(self):
        return self.n
    
    def __getitem__(self, idx):
        return self.x[idx], self.z[idx]


def train(model: torch.nn.Module, trainloader: torch.utils.data.DataLoader, lr: float=0.001, epochs: int=10):
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # for epoch in tqdm(range(epochs)):
    for epoch in range(epochs):
        cum_loss = 0

        for batch, data in enumerate(trainloader, 0):
            x, z = data

            optimizer.zero_grad()

            out = model(torch.flatten(x, start_dim=1))
            loss = criterion(out, torch.flatten(z, start_dim=1))
            cum_loss += loss
            loss.backward()
            optimizer.step()
        
        cum_loss /= (batch + 1)
    
    return cum_loss


# train for the same number of iterations (batches) independent of dataset size (i.e. without epochs)
def train_iter(model: torch.nn.Module, trainloader: torch.utils.data.DataLoader, lr: float=0.001):
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    cum_loss = 0

    for batch, data in enumerate(trainloader, 0):
        x, z = data

        optimizer.zero_grad()

        out = model(torch.flatten(x.to(dev), start_dim=1))
        loss = criterion(out, torch.flatten(z.to(dev), start_dim=1))
        cum_loss += loss
        loss.backward()
        optimizer.step()
    
    cum_loss /= (batch + 1)
    return cum_loss.to(torch.device('cpu')).item()


# train for the same number of iterations (batches) independent of dataset size (i.e. without epochs)
def train_iter_reg(model: torch.nn.Module, trainloader: torch.utils.data.DataLoader, regularization, lamda: float=0.5, lr: float=0.001):
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    cum_loss = 0

    for batch, data in enumerate(trainloader, 0):
        x, z = data

        optimizer.zero_grad()

        x.requires_grad = True
        x = x.flatten(1).to(dev)
        z = z.flatten(1).to(dev)

        z_hat = model(x)
        loss = criterion(z_hat, z) + lamda * regularization(model, x)
        cum_loss += loss
        loss.backward()
        optimizer.step()
    
    cum_loss /= (batch + 1)
    return cum_loss.to(torch.device('cpu')).item()


# class CompContrast(torch.nn.Module):
#     def __init__(self) -> None:
#         super(CompContrast, self).__init__()
    
#     def forward(self, output, input):
#         output.backward(inputs=input)
#         return 0


def comp_contrast(func: torch.nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    """Calculate the compositional contrast for a function `func` with respect to `inputs`.

    The output is calculated as the mean over the batch dimension.
    `inputs` needs to be flattened except for the batch dimension and `requires_grad` needs to be set to `True`.
    """
    assert inputs.requires_grad == True, 'To calculate the derivative by `inputs` `requires_grad` needs to be set to `True`.'

    # compute the jacobian with respect to the inputs
    # because this is done for the whole batch, the output has dimensions [batch, out, batch, in]
    # but Jacobian[i, :, j, :] = 0 because there is no interaction between batches, so dim 2 can be removed
    # after indexing the Jacobian has shape [batch, out, in]
    jac = torch.autograd.functional.jacobian(func, inputs)
    index = torch.arange(jac.shape[0]).reshape(-1, 1, 1, 1).expand(jac.shape[0], jac.shape[1], 1, jac.shape[3]).to(dev)
    jac = torch.gather(jac, 2, index).squeeze()

    # compute the compositional contrast as the sum of all pairs of partial derivatives for all outputs
    # average over the batch dimension
    cc = torch.mean(torch.sum(torch.triu(jac.unsqueeze(2).repeat(1, 1, 40, 1) * jac.unsqueeze(2).repeat(1, 1, 40, 1).transpose(2, 3), diagonal=1), dim=(1, 2, 3)))

    return cc


@torch.no_grad()
def test(model: torch.nn.Module, testloader: torch.utils.data.DataLoader):
    cum_score = 0

    for batch, data in enumerate(testloader, 0):
        x, z = data
        out = model(torch.flatten(x.to(dev), start_dim=1))
        r2score = R2Score(out.size(1)).to(dev)
        score = r2score(out, torch.flatten(z.to(dev), start_dim=1))
        cum_score += score
    
    cum_score /= (batch + 1)
    return cum_score.to(torch.device('cpu')).item()

In [4]:
import copy

k = 4
l = 2
m = 10

torch.manual_seed(0)

print('Build generators...')
g = get_generators(k, l, m)

print('Build test data...')
te_ds = Dataset(1000, k, l, g, 'random')
te_ldr = torch.utils.data.DataLoader(te_ds, batch_size=1000, shuffle=True)

Build generators...
Build test data...


In [88]:
bs = 64
nb = int(2**13 / bs)

res = []
for log_n in range(4, 13):
    n = 2**log_n
    print(f'n={n:4d}')

    print('Build train data...')
    tr_ds_rand = Dataset(n, k, l, g, 'random')
    tr_ds_diag = Dataset(n, k, l, g, 'diagonal')
    tr_ds_orth = Dataset(n, k, l, g, 'orthogonal')

    tr_ldr_rand = torch.utils.data.DataLoader(tr_ds_rand, batch_sampler=torch.utils.data.BatchSampler(torch.utils.data.RandomSampler(tr_ds_rand, num_samples=nb), bs, False))
    tr_ldr_diag = torch.utils.data.DataLoader(tr_ds_diag, batch_sampler=torch.utils.data.BatchSampler(torch.utils.data.RandomSampler(tr_ds_diag, num_samples=nb), bs, False))
    tr_ldr_orth = torch.utils.data.DataLoader(tr_ds_orth, batch_sampler=torch.utils.data.BatchSampler(torch.utils.data.RandomSampler(tr_ds_orth, num_samples=nb), bs, False))

    print('Build models...')
    mlp_rand = MLP(k, l, m).to(dev)
    mlp_diag = copy.deepcopy(mlp_rand)
    mlp_orth = copy.deepcopy(mlp_rand)
    # cmlp_rand = CompositionalMLP(k, l, m).to(dev)
    # cmlp_diag = copy.deepcopy(cmlp_rand)
    # cmlp_orth = copy.deepcopy(cmlp_rand)

    print('Train models...')
    for i in tqdm(range(500)):
        res.append({'metric': 'train loss', 'n samples': n, 'n batches': (i+1)*nb, 'model': 'normal', 'sampling': 'random', 'val': train_iter_reg(mlp_rand, tr_ldr_rand, comp_contrast)})
        res.append({'metric': 'test R²', 'n samples': n, 'n batches': (i+1)*nb, 'model': 'normal', 'sampling': 'random', 'val': test(mlp_rand, te_ldr)})
        res.append({'metric': 'train loss', 'n samples': n, 'n batches': (i+1)*nb, 'model': 'normal', 'sampling': 'diagonal', 'val': train_iter_reg(mlp_diag, tr_ldr_diag, comp_contrast)})
        res.append({'metric': 'test R²', 'n samples': n, 'n batches': (i+1)*nb, 'model': 'normal', 'sampling': 'diagonal', 'val': test(mlp_diag, te_ldr)})
        res.append({'metric': 'train loss', 'n samples': n, 'n batches': (i+1)*nb, 'model': 'normal', 'sampling': 'orthogonal', 'val': train_iter_reg(mlp_orth, tr_ldr_orth, comp_contrast)})
        res.append({'metric': 'test R²', 'n samples': n, 'n batches': (i+1)*nb, 'model': 'normal', 'sampling': 'orthogonal', 'val': test(mlp_orth, te_ldr)})

        # res.append({'metric': 'train loss', 'n samples': n, 'n batches': (i+1)*nb, 'model': 'compositional', 'sampling': 'random', 'val': train_iter(cmlp_rand, tr_ldr_rand)})
        # res.append({'metric': 'test R²', 'n samples': n, 'n batches': (i+1)*nb, 'model': 'compositional', 'sampling': 'random', 'val': test(cmlp_rand, te_ldr)})
        # res.append({'metric': 'train loss', 'n samples': n, 'n batches': (i+1)*nb, 'model': 'compositional', 'sampling': 'diagonal', 'val': train_iter(cmlp_diag, tr_ldr_diag)})
        # res.append({'metric': 'test R²', 'n samples': n, 'n batches': (i+1)*nb, 'model': 'compositional', 'sampling': 'diagonal', 'val': test(cmlp_diag, te_ldr)})
        # res.append({'metric': 'train loss', 'n samples': n, 'n batches': (i+1)*nb, 'model': 'compositional', 'sampling': 'orthogonal', 'val': train_iter(cmlp_orth, tr_ldr_orth)})
        # res.append({'metric': 'test R²', 'n samples': n, 'n batches': (i+1)*nb, 'model': 'compositional', 'sampling': 'orthogonal', 'val': test(cmlp_orth, te_ldr)})

import pandas as pd
res_df = pd.DataFrame.from_dict(res)

import pickle as pk
with open(r'res_i500_comp.pkl', 'wb') as f:
    pk.dump(res_df, f)

n=  16
Build train data...
Build models...
Train models...


100%|██████████| 500/500 [09:40<00:00,  1.16s/it]


n=  32
Build train data...
Build models...
Train models...


100%|██████████| 500/500 [08:42<00:00,  1.05s/it]


n=  64
Build train data...
Build models...
Train models...


 93%|█████████▎| 467/500 [08:09<00:42,  1.29s/it]

In [5]:
tr_ds_rand = Dataset(1000, k, l, g, 'random')
tr_ldr_rand = torch.utils.data.DataLoader(tr_ds_rand, batch_sampler=torch.utils.data.BatchSampler(torch.utils.data.RandomSampler(tr_ds_rand, num_samples=64), 64, False))

In [67]:
mlp_rand = MLP3(4, 2, 10).to(dev)

x, z = tr_ldr_rand._get_iterator().next()

x.requires_grad = True
x = x.flatten(1).to(dev)

out = mlp_rand(torch.flatten(x.to(dev), start_dim=1))


In [29]:
grads = []
for batch in range(out.shape[0]):
    grads.append([])
    for i in range(out.shape[1]):
        grads[batch].append(torch.autograd.grad(out[batch, i], x, retain_graph=True, allow_unused=True)[0][batch])
grads = torch.stack([torch.stack(_grad) for _grad in grads]).flatten(2)


In [53]:
cc = []
for b in range(jac.shape[0]):
    cc.append([])
    for o in range(jac.shape[1]):
        _jac = jac[b, o]
        cc[b].append(torch.sum(torch.triu(_jac.repeat(40, 1) * _jac.repeat(40, 1).t(), diagonal=1)))
cc = torch.stack([torch.stack(_cc) for _cc in cc])

In [13]:
# compute the jacobian with respect to the inputs
# because this is done for the whole batch, the output has dimensions [batch, out, batch, in]
# but Jacobian[i, :, j, :] = 0 because there is no interaction between batches, so dim 2 can be removed
# after indexing the Jacobian has shape [batch, out, in]
jac = torch.autograd.functional.jacobian(mlp_rand, x)
index = torch.arange(jac.shape[0]).reshape(-1, 1, 1, 1).expand(jac.shape[0], jac.shape[1], 1, jac.shape[3]).to(dev)
jac = torch.gather(jac, 2, index).squeeze()

# compute the compositional contrast as the sum of all pairs of partial derivatives for all outputs
# average over the batch dimension
# cc = torch.mean(torch.sum(torch.triu(jac.unsqueeze(2).repeat(1, 1, 40, 1) * jac.unsqueeze(2).repeat(1, 1, 40, 1).transpose(2, 3), diagonal=1), dim=(1, 2, 3)))

In [49]:
def select_output(func: callable, idx) -> callable:
    def _func(t: torch.Tensor) -> torch.Tensor:
        return func(t)[idx]
    
    return _func

In [29]:
def pow_reducer(x):
    return x.pow(3).sum(1)

In [56]:
mlp_rand(x).mean()

tensor(0.0130, device='cuda:0', grad_fn=<MeanBackward0>)

In [61]:
x.shape

torch.Size([64, 40])

In [69]:
hess = torch.autograd.functional.hessian(select_output(mlp_rand, (0, 0)), x)
# index = torch.arange(hess.shape[0]).reshape(-1, 1, 1, 1).expand(hess.shape[0], hess.shape[1], 1, hess.shape[3]).to(dev)
# hess = torch.gather(hess, 2, index).squeeze()

In [71]:
hess.mean()

tensor(0., device='cuda:0')

In [26]:
hessian = []
for batch in range(out.shape[0]):
    hessian.append([])
    for i in range(out.shape[1]):
        hessian[batch].append(torch.autograd.functional.hessian(select_output(mlp_rand, (batch, i)), x)[0][batch])
hessian = torch.stack([torch.stack(_grad) for _grad in grads]).flatten(2)