In [None]:
import copy
import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# Datasets

In [None]:
# Convert images to tensors, then map every RGB channel via (x - 0.5) / 0.5,
# i.e. from the [0, 1] range produced by ToTensor to [-1, 1].
transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
# Mini-batch size used by the DataLoaders.
batch_size = 4

#### CIFAR-10

In [None]:
trainset_10 = torchvision.datasets.CIFAR10(root='./data', train=True,
                                           download=True, transform=transform)
trainloader_10 = torch.utils.data.DataLoader(trainset_10, batch_size=batch_size,
                                             shuffle=True, num_workers=2)

# The official CIFAR-10 test split (train=False) serves as the validation set.
valset_10 = torchvision.datasets.CIFAR10(root='./data', train=False,
                                         download=True, transform=transform)
# The validation loader must wrap valset_10 (the dataset defined just above).
valloader_10 = torch.utils.data.DataLoader(valset_10, batch_size=batch_size,
                                           shuffle=False, num_workers=2)

loaders = {
    'train': trainloader_10,
    'val': valloader_10,
}

Files already downloaded and verified
Files already downloaded and verified


# Models

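MLP-L-H: fully connected network with `L` Linear → BatchNorm → ReLU blocks of width `H`, followed by a linear classifier.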

In [None]:
class MLP(nn.Module):
  """Fully connected classifier for images flattened to 3072 features (e.g. 3x32x32).

  Architecture: `L` Linear -> BatchNorm1d -> ReLU blocks followed by the classifier
  `fc_out`. The first block maps 3072 -> H, the middle blocks H -> H, and the last
  block H -> H (or H -> num_classes when `fixdim` is set). Values of L below 2 still
  build the first and last blocks.

  Args:
    L: number of Linear/BatchNorm/ReLU blocks before `fc_out`.
    H: hidden width.
    fixdim: if True, the last block projects to `num_classes` features, so `fc_out`
      is a square num_classes x num_classes layer.
    simplex: if True, `fc_out.weight` is frozen (requires_grad=False) to a simplex
      equiangular tight frame scaled so the mean squared row norm is 1. When
      `fixdim` is False, the frame is zero-padded to width H.
  """
  def __init__(self, L, H, fixdim, simplex):
    super().__init__()

    self.relu = nn.ReLU()

    num_classes = 10

    layers = [nn.Linear(3072, H)]
    bns = [nn.BatchNorm1d(H)]

    for _ in range(L - 2):
      layers.append(nn.Linear(H, H))
      bns.append(nn.BatchNorm1d(H))

    # The last BatchNorm must match the width of the last Linear layer's output.
    last_dim = num_classes if fixdim else H
    layers.append(nn.Linear(H, last_dim))
    bns.append(nn.BatchNorm1d(last_dim))
    self.fc_out = nn.Linear(last_dim, num_classes)

    self.fcs = nn.ModuleList(layers)
    self.bns = nn.ModuleList(bns)

    if simplex:
      # sqrt(K/(K-1)) * (I - 11^T/K): a simplex ETF, then rescaled so ||W||_F^2 / K == 1.
      weight = torch.sqrt(torch.tensor(num_classes/(num_classes-1)))*(torch.eye(num_classes)-(1/num_classes)*torch.ones((num_classes, num_classes)))
      weight /= torch.sqrt((1/num_classes*torch.norm(weight, 'fro')**2))
      if fixdim:
        self.fc_out.weight = nn.Parameter(weight)
      else:
        # Right-multiplying by eye(K, H) pads the K x K frame with zero columns to K x H.
        self.fc_out.weight = nn.Parameter(torch.mm(weight, torch.eye(num_classes, H)))
      self.fc_out.weight.requires_grad_(False)

  def forward(self, X):
    """Return class logits of shape (N, num_classes) for a batch `X` with 3072 features per sample."""
    # Flatten (N, C, H, W) images to (N, C*H*W) before the first Linear layer.
    h = torch.flatten(X, start_dim=1)
    for layer, bn in zip(self.fcs, self.bns):
      h = self.relu(bn(layer(h)))
    return self.fc_out(h)

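CONV-L-H: convolutional network with two stride-2 input convolutions, then `L` 3×3 convolutions with `H` channels, followed by a linear classifier.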

In [None]:
class CONV(nn.Module):
  """Convolutional classifier.

  Architecture:
    - two kernel-2/stride-2 convolutions (in_channels -> H -> H), each followed by
      BatchNorm2d, with a ReLU after the pair;
    - `L` padded 3x3 convolutions (H -> H), each followed by BatchNorm2d and ReLU;
    - optionally (fixdim) a Linear projection of the flattened features to
      num_classes, followed by BatchNorm1d and ReLU;
    - the linear classifier `fc_out`.

  Args:
    L: number of 3x3 convolution blocks.
    H: number of channels.
    fixdim: if True, flattened features are projected to `num_classes` before a
      square num_classes x num_classes `fc_out`.
    simplex: if True, `fc_out.weight` is frozen (requires_grad=False) to a simplex
      equiangular tight frame. When `fixdim` is False, the frame is zero-padded to
      the flattened feature width.
    in_channels: input image channels (default 1).
    in_size: input image height/width (default 28). The flattened feature width is
      H * (in_size // 4) ** 2, i.e. 49 * H for the defaults.
  """
  def __init__(self, L, H, fixdim, simplex, in_channels=1, in_size=28):
    super().__init__()

    self.relu = nn.ReLU()

    # The second input conv consumes the H channels produced by the first one.
    starting_layers = [nn.Conv2d(in_channels, H, kernel_size=2, stride=2),
                       nn.Conv2d(H, H, kernel_size=2, stride=2)]
    starting_bns = [nn.BatchNorm2d(H), nn.BatchNorm2d(H)]

    self.conv_in = nn.ModuleList(starting_layers)
    self.bn_in = nn.ModuleList(starting_bns)

    # Each kernel-2/stride-2 conv floors the spatial size by half, and the padded 3x3
    # convs keep it, so the flattened feature width is H * (in_size // 4) ** 2.
    feat_dim = H * (in_size // 4) ** 2

    layers = []
    bns = []

    num_classes = 10

    for _ in range(L):
      layers.append(nn.Conv2d(H, H, kernel_size=3, stride=1, padding=1))
      bns.append(nn.BatchNorm2d(H))

    if fixdim:
      # Linear projection applied after flattening (see forward); its BatchNorm1d
      # must match the num_classes output width.
      layers.append(nn.Linear(feat_dim, num_classes))
      bns.append(nn.BatchNorm1d(num_classes))
      self.fc_out = nn.Linear(num_classes, num_classes)
    else:
      self.fc_out = nn.Linear(feat_dim, num_classes)

    self.convs = nn.ModuleList(layers)
    self.bns = nn.ModuleList(bns)

    if simplex:
      # sqrt(K/(K-1)) * (I - 11^T/K): a simplex ETF, then rescaled so ||W||_F^2 / K == 1.
      weight = torch.sqrt(torch.tensor(num_classes/(num_classes-1)))*(torch.eye(num_classes)-(1/num_classes)*torch.ones((num_classes, num_classes)))
      weight /= torch.sqrt((1/num_classes*torch.norm(weight, 'fro')**2))
      if fixdim:
        self.fc_out.weight = nn.Parameter(weight)
      else:
        # Zero-pad the K x K frame to K x feat_dim to match fc_out's input width.
        self.fc_out.weight = nn.Parameter(torch.mm(weight, torch.eye(num_classes, feat_dim)))
      self.fc_out.weight.requires_grad_(False)

  def forward(self, X):
    """Return class logits of shape (N, num_classes) for an image batch `X`."""
    h = X
    for layer, bn in zip(self.conv_in, self.bn_in):
      h = bn(layer(h))
    h = self.relu(h)
    for layer, bn in zip(self.convs, self.bns):
      # The optional fixdim projection is a Linear layer and needs flat features.
      if isinstance(layer, nn.Linear):
        h = torch.flatten(h, start_dim=1)
      h = self.relu(bn(layer(h)))
    # No-op when h is already 2-D (fixdim); flattens conv feature maps otherwise.
    return self.fc_out(torch.flatten(h, start_dim=1))

# Evaluations

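Evaluation utilities: feature means, within- and between-class covariances, and classifier-geometry (simplex ETF) checks.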

In [None]:
def compute_accuracy(output, target, k=1):
  """Top-k accuracy of one batch, in percent.

  Args:
    output: (batch, num_classes) scores; a higher score means a more likely class.
    target: (batch,) integer class labels.
    k: a sample counts as correct if its label is among its k highest scores
      (default 1, i.e. top-1 accuracy).

  Returns:
    0-d tensor holding the percentage of correctly classified samples.
  """
  batch_size = target.size(0)

  _, pred = output.topk(k, 1, True, True)  # (batch, k) indices, best first
  pred = pred.t()                           # (k, batch)
  correct = pred.eq(target.view(1, -1).expand_as(pred))

  correct_k = correct[:k].reshape(-1).float().sum(0)
  return correct_k.mul_(100.0 / batch_size)

def compute_info(args, model, fc_features, dataloader, isTrain=True):
  """Global feature mean, per-class feature means and top-1 accuracy over a loader.

  After each forward pass, features are read from `fc_features.outputs[0][0]` and
  `fc_features.clear()` is called. This presumably relies on a forward hook that
  captures the classifier input; confirm against the hook definition.

  Args:
    args: object whose `device` attribute batches are moved to.
    model: network to evaluate (run under torch.no_grad()).
    fc_features: feature-capturing object exposing `outputs` and `clear()`.
    dataloader: yields (inputs, targets) batches.
    isTrain: divide by CIFAR10_TRAIN_SAMPLES (True) or CIFAR10_TEST_SAMPLES (False),
      which are per-class sample counts defined elsewhere.

  Returns:
    (mu_G, mu_c_dict, top1_avg):
      - mu_G: the global feature mean;
      - mu_c_dict: maps class index -> class feature mean;
      - top1_avg: the average top-1 accuracy in percent.
  """
  mu_G = 0
  mu_c_dict = dict()
  top1 = AverageMeter()
  for inputs, targets in dataloader:
    inputs, targets = inputs.to(args.device), targets.to(args.device)
    with torch.no_grad():
      outputs = model(inputs)

    features = fc_features.outputs[0][0]
    fc_features.clear()

    mu_G += torch.sum(features, dim=0)

    # Per-class sums. Boolean-mask indexing copies, so stored sums never alias or
    # mutate the hooked feature tensor.
    for y in targets.unique().tolist():
      mask = (targets == y).to(features.device)
      mu_c_dict[y] = mu_c_dict.get(y, 0) + features[mask].sum(dim=0)

    # Models in this notebook return the logits tensor itself; indexing it with [0]
    # would select one sample's row. Tuple/list outputs are assumed to be (logits, ...).
    logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
    prec1 = compute_accuracy(logits.data, targets.data)
    top1.update(prec1.item(), inputs.size(0))

  # Normalise once, after every batch has been accumulated.
  class_counts = CIFAR10_TRAIN_SAMPLES if isTrain else CIFAR10_TEST_SAMPLES
  mu_G /= sum(class_counts)
  for i in range(len(class_counts)):
    mu_c_dict[i] /= class_counts[i]

  return mu_G, mu_c_dict, top1.avg

def compute_Sigma_W(args, model, fc_features, mu_c_dict, dataloader, isTrain=True):
  """Within-class covariance of the hooked features.

  Computes (1/N) * sum_i (h_i - mu_{y_i})(h_i - mu_{y_i})^T over every sample in the
  loader, where N is the total of the per-class counts CIFAR10_TRAIN_SAMPLES or
  CIFAR10_TEST_SAMPLES (defined elsewhere).

  Args:
    args: object whose `device` attribute batches are moved to.
    model: network whose forward pass fills `fc_features`.
    fc_features: feature-capturing object; features are read from `outputs[0][0]`.
    mu_c_dict: maps class index -> class feature mean (e.g. from compute_info).
    dataloader: yields (inputs, targets) batches.
    isTrain: choose the train (True) or test (False) sample counts.

  Returns:
    numpy array of shape (D, D), where D is the feature width.
  """
  Sigma_W = 0
  for inputs, targets in dataloader:
    if len(targets) == 0:
      continue

    inputs, targets = inputs.to(args.device), targets.to(args.device)

    # The forward pass is only needed to populate the feature hook.
    with torch.no_grad():
      model(inputs)

    features = fc_features.outputs[0][0]
    fc_features.clear()

    # Centre every sample on its class mean, then accumulate sum_i d_i d_i^T as one
    # (D, N) @ (N, D) product instead of N separate outer products.
    class_means = torch.stack([mu_c_dict[y] for y in targets.tolist()])
    centered = features - class_means
    Sigma_W += centered.T @ centered

  n_samples = sum(CIFAR10_TRAIN_SAMPLES if isTrain else CIFAR10_TEST_SAMPLES)
  Sigma_W /= n_samples

  return Sigma_W.cpu().numpy()

def compute_Sigma_B(mu_c_dict, mu_G):
  """Between-class covariance of the class means around the global mean.

  Args:
    mu_c_dict: maps class index -> class mean vector; keys must be 0..K-1 with
      K = len(mu_c_dict).
    mu_G: global mean vector.

  Returns:
    numpy array equal to (1/K) * sum_c (mu_c - mu_G)(mu_c - mu_G)^T.
  """
  num_classes = len(mu_c_dict)
  deviations = (mu_c_dict[c] - mu_G for c in range(num_classes))
  scatter = sum(d.unsqueeze(1) @ d.unsqueeze(0) for d in deviations)
  scatter /= num_classes
  return scatter.cpu().numpy()

def compute_ETF(W):
  """Distance of the classifier Gram matrix from a simplex equiangular tight frame.

  Both W W^T and the reference (I - 11^T/K) / sqrt(K-1) have unit Frobenius norm, so
  0 means the rows of W form a simplex ETF, up to scale.

  Args:
    W: 2-D weight matrix whose K rows are the class vectors; any device/dtype.

  Returns:
    Python float: ||W W^T / ||W W^T||_F - (I - 11^T/K)/sqrt(K-1)||_F.
  """
  K = W.shape[0]
  WWT = torch.mm(W, W.T)
  WWT = WWT / torch.norm(WWT, p='fro')

  # Build the reference on W's device (was hard-coded to CUDA, which fails on CPU).
  eye = torch.eye(K, device=W.device, dtype=WWT.dtype)
  ones = torch.ones((K, K), device=W.device, dtype=WWT.dtype)
  sub = (eye - ones / K) / pow(K - 1, 0.5)
  ETF_metric = torch.norm(WWT - sub, p='fro')
  return ETF_metric.detach().cpu().numpy().item()

def compute_W_H_relation(W, mu_c_dict, mu_G):
  """Distance of the normalised product W H from the normalised simplex ETF.

  H stacks the centred class means (mu_c - mu_G) as columns.

  Args:
    W: (K, D) classifier weight; any device/dtype.
    mu_c_dict: maps class index 0..K-1 -> class mean vector of length D.
    mu_G: global mean vector of length D.

  Returns:
    (res, H):
      - res: Python float ||W H / ||W H||_F - (I - 11^T/K)/sqrt(K-1)||_F;
      - H: a (D, K) CPU tensor.
  """
  K = len(mu_c_dict)
  # Column i is the centred mean of class i; H stays on the CPU for the caller.
  H = torch.stack([(mu_c_dict[i] - mu_G).cpu() for i in range(K)], dim=1)

  # Compute on W's device (was hard-coded to CUDA, which fails on CPU).
  WH = torch.mm(W, H.to(device=W.device, dtype=W.dtype))
  WH = WH / torch.norm(WH, p='fro')
  eye = torch.eye(K, device=W.device, dtype=WH.dtype)
  ones = torch.ones((K, K), device=W.device, dtype=WH.dtype)
  sub = (eye - ones / K) / pow(K - 1, 0.5)

  res = torch.norm(WH - sub, p='fro')
  return res.detach().cpu().numpy().item(), H

def compute_Wh_b_relation(W, mu_G, b):
  """Norm of the classifier response to the global mean, ||W mu_G + b||.

  Args:
    W: (K, D) classifier weight.
    mu_G: global mean vector of length D; moved to W's device.
    b: bias vector of length K, assumed to be on W's device.

  Returns:
    Python float.
  """
  # Use W's device rather than a hard-coded CUDA device.
  Wh = torch.mv(W, mu_G.to(W.device))
  res_b = torch.norm(Wh + b, p='fro')
  return res_b.detach().cpu().numpy().item()

Neural-collapse metrics (NC1–NC4) and per-layer evaluations — work in progress

In [None]:
# https://pypi.org/project/torch-intermediate-layer-getter/

def NC1():
  """Placeholder, presumably for the NC1 neural-collapse metric.

  TODO: not implemented. It currently returns the `np.trace` function object
  itself, not a metric value.
  """
  trace_fn = np.trace
  return trace_fn

def NC2():
  """Placeholder, presumably for the NC2 neural-collapse metric. TODO: implement; returns None."""
  return None

def NC3():
  """Placeholder, presumably for the NC3 neural-collapse metric. TODO: implement; returns None."""
  return None

def NC4():
  """Placeholder, presumably for the NC4 neural-collapse metric. TODO: implement; returns None."""
  return None

def accuracy(output, target):
  """Placeholder for NCC accuracy (presumably nearest-class-centre classification).

  TODO: implement. Both arguments are currently ignored and None is returned.
  """
  return None

def get_evaluations(model, loader, n_layers):
  """Column names and values of per-layer evaluation metrics, for logging.

  TODO: the per-layer metrics are not implemented yet. Every value is NaN, so callers
  can already build rows whose length matches the column list.

  Args:
    model: network to evaluate (currently unused).
    loader: data loader to evaluate on (currently unused).
    n_layers: either an int (layers 0..n_layers-1) or an iterable of layer identifiers.

  Returns:
    (cols, vals): two lists of equal length, with cols named 'acc_<layer>'.
  """
  layers = range(n_layers) if isinstance(n_layers, int) else n_layers
  cols = [f'acc_{layer}' for layer in layers]
  vals = [float('nan')] * len(cols)
  return cols, vals




# Training code

In [None]:
def train(model, device, optimizer, criterion, dataloaders, learning_rate, num_epochs, savedir, eval_layers=None):
  """Train `model`, keep the weights with the lowest validation loss, and save them.

  Each epoch runs a 'train' phase (with gradient updates) and a 'val' phase. After
  training, the best validation weights are loaded back into `model`. These weights
  are written to `<savedir>/model.pth` and the per-epoch log to `<savedir>/losses.csv`.

  Args:
    model: network to train; assumed to already be on `device` (it is not moved here).
    device: device every batch is moved to.
    optimizer: optimizer over the model's parameters.
    criterion: loss function called as criterion(preds, y).
    dataloaders: dict with 'train' and 'val' loaders yielding (X, y) batches.
    learning_rate: unused; kept for interface compatibility (the optimizer holds the LR).
    num_epochs: number of epochs.
    savedir: output directory, created if missing.
    eval_layers: optional. When given, get_evaluations(model, loader, eval_layers) is
      called after every phase and its (cols, vals) are added to the log row.

  Returns:
    `model`, with the best validation weights loaded and in eval mode.
  """
  best_loss = float('inf')
  # Start from the initial weights so the final load never sees an unbound name
  # (e.g. when every validation loss is NaN).
  best_model_wts = copy.deepcopy(model.state_dict())
  records = []

  for epoch in range(num_epochs):
    for phase in ['train', 'val']:
      is_train = phase == 'train'
      model.train(is_train)

      running_loss = 0.0
      num_batches = 0

      bar = tqdm(dataloaders[phase], desc='Epoch {} {}'.format(epoch, phase).ljust(20))
      for i, (X, y) in enumerate(bar):
        X, y = X.to(device), y.to(device)

        optimizer.zero_grad()

        with torch.set_grad_enabled(is_train):
          preds = model(X)
          loss = criterion(preds, y)
          if is_train:
            loss.backward()
            optimizer.step()

        # Accumulate every batch, not just the last one.
        running_loss += loss.item()
        num_batches += 1
        if i % 10 == 0:
          bar.set_postfix(loss='{:.2f}'.format(running_loss / num_batches))

      # NaN for an empty loader, so it can never become the "best" validation loss.
      epoch_loss = running_loss / num_batches if num_batches else float('nan')

      if phase == 'val' and epoch_loss < best_loss:
        best_loss = epoch_loss
        best_model_wts = copy.deepcopy(model.state_dict())

      record = {'epoch': epoch, 'phase': phase, 'loss': epoch_loss}
      if eval_layers is not None:
        cols, vals = get_evaluations(model, dataloaders[phase], eval_layers)
        record.update(zip(cols, vals))
      records.append(record)

  model.load_state_dict(best_model_wts)
  model.eval()

  # Save the best weights.
  save_path = Path(savedir)
  save_path.mkdir(parents=True, exist_ok=True)
  torch.save(model.state_dict(), save_path / 'model.pth')

  # Build the log once from the collected rows instead of concatenating per row.
  pd.DataFrame(records).to_csv(save_path / 'losses.csv', index=False)

  return model