In [None]:
import copy
import itertools
import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# Datasets

In [None]:
# CIFAR-10 training preprocessing: random cropping, random horizontal flips and
# random rotations by 15k degrees for k drawn uniformly from {1, ..., 24};
# every image is then standardized per channel.
# ToTensor comes first, so the augmentations below operate on tensors.

# Per-channel mean / std of the CIFAR-10 training images.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

transform = transforms.Compose(
    [transforms.ToTensor(),
     # CIFAR images are 32x32: pad by 4 and crop back to 32 (a 224 crop cannot fit).
     transforms.RandomCrop(32, padding=4),
     transforms.RandomHorizontalFlip(p=0.5),
     # Exactly one of the 24 rotations 15, 30, ..., 360 degrees, chosen uniformly.
     transforms.RandomChoice(
         [transforms.RandomRotation((15 * k, 15 * k)) for k in range(1, 25)]),
     transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)])
batch_size = 128

#### CIFAR-10

In [None]:
# Validation images must not be augmented: keep only the tensor conversion and
# normalization steps of the training transform.
eval_transform = transforms.Compose(
    [t for t in transform.transforms
     if isinstance(t, (transforms.ToTensor, transforms.Normalize))])

trainset_10 = torchvision.datasets.CIFAR10(root='./data', train=True,
                                           download=True, transform=transform)
trainloader_10 = torch.utils.data.DataLoader(trainset_10, batch_size=batch_size,
                                             shuffle=True, num_workers=2)

valset_10 = torchvision.datasets.CIFAR10(root='./data', train=False,
                                         download=True, transform=eval_transform)
valloader_10 = torch.utils.data.DataLoader(valset_10, batch_size=batch_size,
                                           shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


# Models

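### MLP-L-H

MLP with `L` linear layers of width `H` (each followed by BatchNorm and ReLU) and a linear output layer.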

In [None]:
class MLP(nn.Module):
  """MLP-L-H: `L` linear layers of width `H`, each followed by BatchNorm + ReLU,
  then a linear classifier `fc_out`.

  Args:
    L: number of linear layers before `fc_out` (must be >= 2).
    H: width of the hidden layers.
    fixdim: if True the last hidden layer outputs `num_classes` features, so
      `fc_out` is square (`num_classes x num_classes`); otherwise it outputs `H`.
    simplex: if True, `fc_out.weight` is frozen to a simplex equiangular tight
      frame, rescaled so its mean squared row norm is 1.
    in_features: size of one flattened input sample (3072 = 3*32*32 for CIFAR-10).
    num_classes: number of output classes.
  """
  def __init__(self, L, H, fixdim, simplex, in_features=3072, num_classes=10):
    super().__init__()
    if L < 2:
      # range(L - 2) would be empty and silently build 2 layers instead of L.
      raise ValueError(f'MLP needs L >= 2 layers, got L={L}')

    self.relu = nn.ReLU()

    # Feature sizes along the network: in_features -> H -> ... -> H -> last_dim.
    last_dim = num_classes if fixdim else H
    dims = [in_features] + [H] * (L - 1) + [last_dim]
    self.fcs = nn.ModuleList(
        nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:]))
    # Each BatchNorm matches its layer's output size (the last one is
    # num_classes wide when fixdim is set).
    self.bns = nn.ModuleList(nn.BatchNorm1d(d) for d in dims[1:])
    self.fc_out = nn.Linear(last_dim, num_classes)

    if simplex:
      K = num_classes
      etf = torch.sqrt(torch.tensor(K / (K - 1))) * (torch.eye(K) - torch.ones(K, K) / K)
      etf /= torch.sqrt(torch.norm(etf, 'fro') ** 2 / K)
      # Without fixdim, fc_out takes H inputs: place the ETF in the first
      # input columns and leave the rest zero.
      weight = etf if fixdim else torch.mm(etf, torch.eye(K, H))
      self.fc_out.weight = nn.Parameter(weight, requires_grad=False)

  def forward(self, X):
    """Run the network.

    Args:
      X: batch of inputs; every dimension after the first is flattened, so both
        (B, C, H, W) images and (B, in_features) vectors are accepted.

    Returns:
      (ret, L): `ret` maps '1'..str(L) to each linear layer's pre-BatchNorm output
      and 'final' to the logits of `fc_out`; `L` is the number of hidden layers.
    """
    ret = {}
    h = torch.flatten(X, 1)
    for i, (fc, bn) in enumerate(zip(self.fcs, self.bns), start=1):
      h = fc(h)
      ret[str(i)] = h
      h = self.relu(bn(h))
    ret['final'] = self.fc_out(h)
    return ret, len(self.fcs)

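### CONV-L-H

CNN with two 2×2 stride-2 stem convolutions followed by `L` 3×3 convolutions with `H` channels, then linear layers producing the 10 class scores.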

In [None]:
class CONV(nn.Module):
  """CONV-L-H: convolutional network for small images.

  Layout:
    - Stem: two 2x2 stride-2 convolutions, each followed by BatchNorm, then one ReLU.
    - Body: `L` 3x3 convolutions with `H` channels, each followed by BatchNorm + ReLU.
    - Head: a linear layer `fc_out` to `num_classes` features, then a square linear
      classifier `classify`.

  Args:
    L: number of 3x3 convolution layers after the stem.
    H: number of channels.
    simplex: if True, `classify.weight` is frozen to a simplex equiangular tight
      frame, rescaled so its mean squared row norm is 1.
    in_channels: channels of the input images (3 for CIFAR-10).
    num_classes: number of output classes.
    input_size: side length of the (square) input images; the stem divides it by 4.
  """
  def __init__(self, L, H, simplex, in_channels=3, num_classes=10, input_size=32):
    super().__init__()

    self.relu = nn.ReLU()

    # Two stride-2 convolutions: a 32x32 input becomes 8x8.
    self.conv_in = nn.ModuleList([
        nn.Conv2d(in_channels, H, kernel_size=2, stride=2),
        nn.Conv2d(H, H, kernel_size=2, stride=2)])
    self.bn_in = nn.ModuleList([nn.BatchNorm2d(H), nn.BatchNorm2d(H)])

    self.convs = nn.ModuleList(
        nn.Conv2d(H, H, kernel_size=3, stride=1, padding=1) for _ in range(L))
    self.bns = nn.ModuleList(nn.BatchNorm2d(H) for _ in range(L))

    spatial = input_size // 4
    self.fc_out = nn.Linear(spatial * spatial * H, num_classes)
    self.classify = nn.Linear(num_classes, num_classes)

    if simplex:
      K = num_classes
      etf = torch.sqrt(torch.tensor(K / (K - 1))) * (torch.eye(K) - torch.ones(K, K) / K)
      etf /= torch.sqrt(torch.norm(etf, 'fro') ** 2 / K)
      # The K x K ETF only fits the square classifier, not fc_out.
      self.classify.weight = nn.Parameter(etf, requires_grad=False)

  def _trunk(self, X, ret=None):
    """Stem, conv body and fc_out; stores each body conv's pre-BatchNorm output in `ret` ('1'..'L') if given."""
    h = X
    for conv, bn in zip(self.conv_in, self.bn_in):
      h = bn(conv(h))
    h = self.relu(h)
    for i, (conv, bn) in enumerate(zip(self.convs, self.bns), start=1):
      h = conv(h)
      if ret is not None:
        ret[str(i)] = h
      h = self.relu(bn(h))
    # (B, H, s, s) -> (B, H*s*s) for the linear layer.
    return self.fc_out(torch.flatten(h, 1))

  def forward(self, X):
    """Return (logits, predicted class index per sample) for a batch of images X."""
    logits = self.classify(self._trunk(X))
    return logits, logits.argmax(dim=1)

  def get_layer_info(self, X):
    """Return (ret, L): `ret` maps '1'..str(L) to the body conv outputs and
    'final' to the output of `fc_out` (the input of `classify`)."""
    ret = {}
    ret['final'] = self._trunk(X, ret)
    return ret, len(self.convs)

# Evaluations

In [None]:
from sklearn.neighbors import NearestCentroid

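### Evaluation utilities

Per-layer metrics: nearest-class-center (NCC) accuracy and class-distance normalized variance (CDNV).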

In [None]:
# NCC accuracy
def ncc_accuracy(features, classes):
  """Nearest-class-center accuracy of `features` with respect to `classes`.

  Fits sklearn's NearestCentroid on the samples and scores its predictions on
  the same samples.

  Args:
    features: array or tensor of shape (N, ...); trailing dims are flattened.
    classes: N labels, shape (N,) or (N, 1).

  Returns:
    Fraction of samples assigned to their own class, as a float in [0, 1].

  Raises:
    ValueError: if features and classes have different numbers of samples.
  """
  if hasattr(features, 'detach'):  # torch tensor -> CPU numpy
    features = features.detach().cpu().numpy()
  if hasattr(classes, 'detach'):
    classes = classes.detach().cpu().numpy()
  features = np.asarray(features)
  features = features.reshape(features.shape[0], -1)
  classes = np.asarray(classes).ravel()
  if features.shape[0] != classes.shape[0]:
    raise ValueError(f'{features.shape[0]} feature rows but {classes.shape[0]} labels')

  clf = NearestCentroid()
  clf.fit(features, classes)
  preds = clf.predict(features)
  return float(np.mean(preds == classes))

# Class-distance normalized variance (CDNV) between two classes Q1, Q2:
#   (Var(Q1) + Var(Q2)) / (2 ||mu(Q1) - mu(Q2)||^2)
# where mu(Q) is the mean feature vector of class Q and Var(Q) is the mean
# squared Euclidean distance of its features to mu(Q).
# cdnv_tot averages the CDNV over all pairs of nonequal classes.
def _to_numpy(x):
  """Convert a torch tensor (any device) or array-like to a numpy array."""
  if hasattr(x, 'detach'):
    x = x.detach().cpu().numpy()
  return np.asarray(x)

def _class_stats(Q):
  """Return (mean feature vector, mean squared distance to it) for samples Q of shape (n, ...)."""
  Q = _to_numpy(Q).astype(float)
  Q = Q.reshape(Q.shape[0], -1)  # 1-D input is treated as n scalar features
  mu = Q.mean(axis=0)
  var = np.mean(np.sum((Q - mu) ** 2, axis=1))
  return mu, var

def _cdnv_from_stats(stats1, stats2):
  """CDNV from two (mean, variance) pairs produced by _class_stats."""
  (mu1, var1), (mu2, var2) = stats1, stats2
  return (var1 + var2) / (2 * np.sum((mu1 - mu2) ** 2))

def cdnv(Q1, Q2):
  """CDNV between two sets of features, each of shape (n_i, ...)."""
  return _cdnv_from_stats(_class_stats(Q1), _class_stats(Q2))

def cdnv_tot(features, classes):
  """Average CDNV over all pairs of distinct classes.

  Args:
    features: array or tensor of shape (N, ...).
    classes: N labels, shape (N,) or (N, 1).

  Returns:
    The mean pairwise CDNV as a float, or NaN when fewer than two classes are present.
  """
  features = _to_numpy(features)
  classes = _to_numpy(classes).ravel()
  # Compute each class's statistics once instead of once per pair.
  stats = [_class_stats(features[classes == c]) for c in np.unique(classes)]
  pair_values = [_cdnv_from_stats(s1, s2) for s1, s2 in itertools.combinations(stats, 2)]
  return float(np.mean(pair_values)) if pair_values else float('nan')
  
def get_evaluations(outs, labels, n_layers):
  """Per-layer NCC accuracy and CDNV.

  Args:
    outs: dict mapping '1'..str(n_layers) to that layer's features of shape
      (N, ...), e.g. the dict returned by `CONV.get_layer_info` or `MLP.forward`.
    labels: N class labels.
    n_layers: number of layers to evaluate (keys '1'..str(n_layers)).

  Returns:
    (cols, vals): column names 'acc_<layer>' / 'cdnv_<layer>' and the matching
    NCC accuracy / CDNV values, in layer order.
  """
  if hasattr(labels, 'detach'):  # torch tensor -> CPU numpy
    labels = labels.detach().cpu().numpy()
  labels = np.asarray(labels).ravel()

  cols, vals = [], []
  for layer in range(1, n_layers + 1):
    feats = outs[str(layer)]
    if hasattr(feats, 'detach'):
      feats = feats.detach().cpu().numpy()
    feats = np.asarray(feats)
    feats = feats.reshape(feats.shape[0], -1)  # flatten conv maps to vectors
    cols += [f'acc_{layer}', f'cdnv_{layer}']
    vals += [ncc_accuracy(feats, labels), cdnv_tot(feats, labels)]
  return cols, vals




# Training code

In [None]:
def _logits_of(output):
  """Extract logits from a forward output.

  Handles CONV-style (logits, preds) and MLP-style (dict with 'final' logits, n_layers).
  """
  first = output[0] if isinstance(output, tuple) else output
  return first['final'] if isinstance(first, dict) else first


def _layer_outputs(model, X):
  """Per-layer outputs dict, from `get_layer_info` when available, otherwise from an MLP-style forward."""
  if hasattr(model, 'get_layer_info'):
    return model.get_layer_info(X)[0]
  return model(X)[0]


@torch.no_grad()
def _collect_layer_features(model, loader, device, max_samples):
  """Stack flattened per-layer outputs (eval mode) and labels for up to `max_samples` examples (None = all)."""
  was_training = model.training
  model.eval()
  feats, labels, seen = {}, [], 0
  for X, y in loader:
    if max_samples is not None and seen >= max_samples:
      break
    take = len(y) if max_samples is None else min(len(y), max_samples - seen)
    outs = _layer_outputs(model, X[:take].to(device))
    for key, val in outs.items():
      feats.setdefault(key, []).append(torch.flatten(val, 1).cpu())
    labels.append(y[:take].cpu())
    seen += take
  model.train(was_training)
  return ({key: torch.cat(vals).numpy() for key, vals in feats.items()},
          torch.cat(labels).numpy())


def train(model, device, optimizer, scheduler, criterion, dataloaders, savedir, n_layers,
          num_epochs=300, eval_samples=1000):
  """Train `model`, restore the weights with the lowest validation loss, and save results.

  Each epoch runs a 'train' pass and a 'val' pass over `dataloaders[phase]`.
  - After each pass, per-layer metrics are computed via `get_evaluations`, in eval
    mode, on up to `eval_samples` examples of that phase's loader.
  - The best weights are written to `<savedir>/model.pth`.
  - One row per (epoch, phase), holding the mean loss and the metrics, is written to
    `<savedir>/losses.csv`.

  Args:
    model: module whose forward returns (logits, preds), or, like MLP,
      (dict with a 'final' logits entry, n_layers).
    device: device to train on; the model is moved there.
    optimizer: optimizer over `model.parameters()`.
    scheduler: learning-rate scheduler stepped once per epoch, or None.
    criterion: loss taking (logits, targets); assumed to return the batch mean.
    dataloaders: dict with 'train' and 'val' loaders yielding (X, y) batches.
    savedir: output directory (created if missing).
    n_layers: number of layers passed to `get_evaluations`.
    num_epochs: number of epochs.
    eval_samples: examples per phase used for the per-layer metrics. None uses
      the whole loader; 0 disables the metrics.

  Returns:
    The model, loaded with the best validation weights and set to eval mode.
  """
  try:
    from tqdm.auto import tqdm
  except ImportError:  # progress bars are optional
    tqdm = None

  model.to(device)
  Path(savedir).mkdir(parents=True, exist_ok=True)

  best_loss = np.inf
  # Start from the initial weights so there is always something to restore.
  best_model_wts = copy.deepcopy(model.state_dict())
  records = []

  for epoch in range(num_epochs):
    for phase in ['train', 'val']:
      is_train = phase == 'train'
      model.train(is_train)
      loader = dataloaders[phase]

      running_loss = 0.0
      num_preds = 0
      bar = loader
      if tqdm is not None:
        bar = tqdm(loader, desc='Epoch {} {}'.format(epoch, phase).ljust(20))
      for i, (X, y) in enumerate(bar):
        X, y = X.to(device), y.to(device)
        with torch.set_grad_enabled(is_train):
          loss = criterion(_logits_of(model(X)), y)
        if is_train:
          optimizer.zero_grad()  # otherwise gradients accumulate across batches
          loss.backward()
          optimizer.step()
        # Weight the batch-mean loss by the actual batch size (the last batch may be smaller).
        running_loss += loss.item() * y.size(0)
        num_preds += y.size(0)
        if tqdm is not None and i % 10 == 0:
          bar.set_postfix(loss='{:.2f}'.format(running_loss / num_preds))

      epoch_loss = running_loss / num_preds if num_preds else float('nan')
      if not is_train and epoch_loss < best_loss:
        best_loss = epoch_loss
        best_model_wts = copy.deepcopy(model.state_dict())

      record = {'epoch': epoch, 'phase': phase, 'loss': epoch_loss}
      if num_preds and (eval_samples is None or eval_samples > 0):
        outs, labels = _collect_layer_features(model, loader, device, eval_samples)
        cols, vals = get_evaluations(outs, labels, n_layers)
        record.update(zip(cols, vals))
      records.append(record)

    if scheduler is not None:
      scheduler.step()

  model.load_state_dict(best_model_wts)
  model.eval()

  torch.save(model.state_dict(), os.path.join(savedir, 'model.pth'))
  pd.DataFrame(records).to_csv(os.path.join(savedir, 'losses.csv'), index=False)

  return model

# Set up model parameters

In [None]:
n_layers = 5
etf = False  # passed as CONV's `simplex` flag (frozen simplex-ETF classifier)
SEED = 0
torch.manual_seed(SEED)

# Fall back to the CPU when no GPU is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = CONV(n_layers, 100, etf).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.1)
criterion = torch.nn.CrossEntropyLoss()
savedir = './hehe'
dataloaders = {
    'train': trainloader_10,
    'val': valloader_10
}

# Assign the result so the cell does not echo the model repr as output.
model = train(model, device, optimizer, scheduler, criterion, dataloaders, savedir, n_layers, num_epochs=300)