In [None]:
# Load data (MNIST)

import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np

# --- Notebook-wide configuration -------------------------------------------
# Batch sizes handed to the DataLoaders below; they are also passed to
# filter_digits as the maximum number of candidates it may scan.
batch_size_train = 10000
batch_size_test = 20000
# NOTE(review): log_interval is not referenced anywhere in the visible cells.
log_interval = 10


# Sizes of the balanced two-class subsets built by filter_digits.
num_samples = 128 # number of training samples
num_samples_test = 100 # number of test samples

# Height / width of the image tensors fed to the network. filter_digits
# allocates zero tensors of this size and Net sizes its linear layer from it.
new_dim1 = 28 * 1 # first dimension
new_dim2 = 28 * 1 # second dimension


# Side length of the square region copied from each source image.
old_dim = 28 # MNIST original dimension


# Seed torch's global RNG so that runs are repeatable; cuDNN is switched off.
random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

def _make_shuffled_loader(dataset_cls, train, mean, std, batch_size):
  """Return a shuffled DataLoader over a torchvision dataset rooted at '/files/'.

  Images are converted with ToTensor and then normalised with the
  single-channel statistics ``(mean,)`` / ``(std,)``. The dataset is
  downloaded if it is not already present.
  """
  preprocessing = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((mean,), (std,)),
  ])
  dataset = dataset_cls('/files/', train=train, download=True,
                        transform=preprocessing)
  return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)


# MNIST training split
train_loader = _make_shuffled_loader(
  torchvision.datasets.MNIST, True, 0.1307, 0.3081, batch_size_train)

# Fashion-MNIST training split
train_loader_fashion = _make_shuffled_loader(
  torchvision.datasets.FashionMNIST, True, 0.5, 0.5, batch_size_train)

# MNIST test split
test_loader = _make_shuffled_loader(
  torchvision.datasets.MNIST, False, 0.1307, 0.3081, batch_size_test)

# Fashion-MNIST test split. This loader previously wrapped
# torchvision.datasets.MNIST, so the "fashion" test loader actually served
# MNIST digits (with Fashion-style normalisation). It now uses FashionMNIST
# with the same (0.5,), (0.5,) normalisation as train_loader_fashion.
test_loader_fashion = torch.utils.data.DataLoader(
  torchvision.datasets.FashionMNIST('/files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.5,), (0.5,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

# Sanity check: pull one batch from the MNIST test loader and show its shape.
# NOTE(review): test_loader shuffles, so creating this iterator presumably
# draws from the torch RNG seeded above; removing or moving this peek may
# change which samples later cells receive — confirm before editing.
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

print(example_data.shape)



In [None]:
# Create training data and test data for binary classification task

from tqdm import tqdm 
# Take only the FIRST batch of each MNIST loader; those two batches are the
# candidate pools from which filter_digits selects the binary-task samples.
# NOTE(review): both loaders shuffle, so the order of these statements (and of
# any earlier loader iteration) presumably affects which images end up in the
# pools — keep the order fixed for reproducibility.
training_data = enumerate(train_loader)
test_data = enumerate(test_loader)
(batch_id1, (data_tr_old, target_tr_old)) = next(training_data)
(batch_id2, (data_test_old, target_test_old)) = next(test_data)
print("loaded")


# Classes
# Labels in class1 are mapped to +1 by filter_digits; the labels it keeps from
# class2 are mapped to -1.
class1 = [2]
class2 = [9]

# Select a balanced two-class subset; centre each image and pad it on the
# right and on the bottom
def filter_digits(data_old, target_old, num, threshold):
  """Build a class-balanced, mean-centred, zero-padded binary dataset.

  Candidates are scanned in order. The first ``num // 2`` images whose label
  is in ``class1`` and the first ``num // 2`` whose label is in ``class2`` are
  kept. Each kept image has the mean of its channel subtracted and is written
  into the top-left ``old_dim x old_dim`` corner of a zero
  ``new_dim1 x new_dim2`` canvas, so any extra rows/columns (bottom/right)
  stay zero.

  Args:
    data_old: candidate images, indexed as ``[sample][0][row][col]``.
    target_old: labels aligned with ``data_old``.
    num: total number of samples to produce (half per class; if ``num`` is
      odd the last slot is never filled).
    threshold: maximum number of candidates that may be inspected.

  Returns:
    ``(data, target)`` as float tensors of shape ``[num, 1, new_dim1,
    new_dim2]`` and ``[num]``. On success ``target`` holds +1 for labels in
    ``class1`` and -1 otherwise. If the candidates run out first, a message is
    printed and the partially filled tensors are returned as they are, with
    the raw labels still in ``target``.
  """
  data = torch.zeros([num, 1, new_dim1, new_dim2])
  target = torch.zeros([num])

  num_samples_per_class = num // 2
  # Never scan past the candidates that actually exist: the caller may pass a
  # threshold (e.g. the requested batch size) that is larger than the batch
  # the loader really delivered, which used to end in an IndexError.
  scan_limit = min(threshold, len(target_old))

  sample1 = 0  # kept so far from class1
  sample2 = 0  # kept so far from class2
  sample = 0   # next free slot in data / target
  attempt = 0  # index of the candidate under inspection

  while (sample1 < num_samples_per_class) or (sample2 < num_samples_per_class):
    if attempt >= scan_limit:
      print("FAILED: need more samples in batch")
      return (data, target)
    label = target_old[attempt]
    # Balance classes: accept a candidate only while its class still has room
    is_sample_for_class1 = (label in class1) and (sample1 < num_samples_per_class)
    is_sample_for_class2 = (label in class2) and (sample2 < num_samples_per_class)
    if is_sample_for_class1 or is_sample_for_class2:
      image = data_old[attempt][0]
      # Centre with the mean of the whole channel, then copy the
      # old_dim x old_dim region in one slice instead of pixel by pixel.
      data[sample, 0, :old_dim, :old_dim] = image[:old_dim, :old_dim] - torch.mean(image)
      target[sample] = label
      if is_sample_for_class1:
        sample1 += 1
      if is_sample_for_class2:
        sample2 += 1
      sample += 1
    attempt += 1

  # Map raw digit labels to the +1 / -1 labels of the binary task
  target.apply_(lambda x: 1 if (x in class1) else -1)

  data = data.float()
  target = target.float()
  return (data, target)


# Build the balanced binary-task datasets from the first batch of each loader.
# The batch size doubles as the cap on how many candidates may be scanned.
data_tr, target_tr = filter_digits(
  data_tr_old, target_tr_old, num=num_samples, threshold=batch_size_train)
print("Created training data")

data_test, target_test = filter_digits(
  data_test_old, target_test_old, num=num_samples_test, threshold=batch_size_test)
print("Created test data")




loaded
Created training data
Created test data


In [None]:
# Check that classes are balanced: print the fraction of +1 labels per split
counter = sum(1 for idx in range(num_samples) if target_tr[idx] == 1)
print(counter/num_samples)

counter = sum(1 for idx in range(num_samples_test) if target_test[idx] == 1)
print(counter/num_samples_test)


0.5
0.5


In [None]:
import random
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
from matplotlib.colors import LogNorm


# Two-layer linear convolutional neural network
output_channels = 1


class Net(nn.Module):
    """Bias-free convolution with circular padding followed by a bias-free
    fully connected layer that emits one score per image.

    No non-linearity is applied between the two layers, so the whole model is
    a linear function of its input.
    """

    def __init__(self, ker_size1, ker_size2, output_channels):
        """Create the conv layer with a (ker_size1, ker_size2) kernel and the
        linear read-out over new_dim1 * new_dim2 * output_channels features."""
        super().__init__()
        self.ker_size1 = ker_size1
        self.ker_size2 = ker_size2
        self.output_channels = output_channels
        kernel_shape = (self.ker_size1, self.ker_size2)
        self.conv1 = nn.Conv2d(1, output_channels, kernel_size=kernel_shape, bias=False)
        flat_features = int(new_dim1 * new_dim2 * output_channels)
        self.fc1 = nn.Linear(flat_features, 1, bias=False)

    def forward(self, x):
        """Return one score per input image, shape (batch, 1)."""
        # Wrap the image around on the right and bottom so that the
        # convolution output keeps the spatial size of the input.
        wrap_right = self.ker_size2 - 1
        wrap_bottom = self.ker_size1 - 1
        padded = F.pad(x, (0, wrap_right, 0, wrap_bottom), mode='circular')
        feature_maps = self.conv1(padded)
        # (no ReLU here: the network is kept linear)
        flattened = feature_maps.reshape(feature_maps.size(0), -1)
        return self.fc1(flattened)

    def initialize(self, initialization_scale, ker_size1):
        """Re-draw both weight tensors from zero-mean normal distributions.

        The linear layer uses std initialization_scale / sqrt(new_dim1) and
        the convolution uses std initialization_scale / sqrt(ker_size1).
        """
        fc_std = initialization_scale/np.sqrt(new_dim1)
        conv_std = initialization_scale/np.sqrt(ker_size1)
        nn.init.normal_(self.fc1.weight, mean=0.0, std=fc_std)
        nn.init.normal_(self.conv1.weight, mean=0.0, std=conv_std)


# Float placeholders for the network outputs on the training and test sets.
# The test placeholder used to be overwritten with `output.float()`, which
# gave it the training shape (num_samples, 1); each tensor now keeps its own
# shape.
output = torch.zeros((num_samples, 1)).float()
output_test = torch.zeros((num_samples_test, 1)).float()


# Mini-batch gradient descent (one pass over the training set)
def train_minibatch(network, optimizer, minibatch_size=32):
  """Run one epoch of mini-batch updates on the exponential loss.

  The first ``num_samples`` rows of ``data_tr`` / ``target_tr`` are visited in
  order, in consecutive slices of ``minibatch_size``. For every slice the loss
  is the mean of ``exp(-target * output)`` over that slice, followed by one
  optimizer step.

  A trailing slice shorter than ``minibatch_size`` is trained on as well (and
  averaged over its real length); previously such samples were silently
  skipped, and nothing was trained at all when ``num_samples`` was smaller
  than the mini-batch size. When ``num_samples`` is a multiple of
  ``minibatch_size`` the updates are the same as before.

  Args:
    network: model to train; called on a slice of ``data_tr``.
    optimizer: object providing ``zero_grad()`` and ``step()``.
    minibatch_size: number of samples per update (default 32).
  """
  network.train()
  for start_index in range(0, num_samples, minibatch_size):
    end_index = min(start_index + minibatch_size, num_samples)
    inputs = data_tr[start_index:end_index]
    labels = target_tr[start_index:end_index]
    optimizer.zero_grad()
    output = network(inputs)
    # Mean exponential loss over the samples actually in this slice
    loss = torch.sum(torch.exp(-1 * torch.mul(output.flatten(), labels))) / (end_index - start_index)
    loss.backward()
    optimizer.step()

# Evaluate training data loss
def train_eval(network):
  """Print and return the mean exponential loss on the training set.

  The network is put in eval mode and run on all of ``data_tr`` without
  gradient tracking. A strictly positive score counts as predicting +1 and
  anything else as -1. The accuracy is only printed; the returned value is
  the loss averaged over ``num_samples``.
  """
  network.eval()
  with torch.no_grad():
    scores = network(data_tr)
    signed_scores = torch.mul(scores.flatten(), target_tr)
    train_loss = torch.sum(torch.exp(-1 * signed_scores))
    plus_one = torch.ones_like(scores)
    pred = torch.where(scores > 0, plus_one, -plus_one)
    correct = pred.eq(target_tr.view_as(pred)).sum()
  train_loss = train_loss / num_samples
  accuracy_percent = 100. * correct / num_samples
  print('\nTrain set: Avg. loss: {:.9f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    train_loss, correct, num_samples, accuracy_percent))
  return train_loss

def test(network):
  network.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    output_test = network(data_test)
    test_loss = torch.sum(torch.exp(-1 * torch.mul(output_test.flatten(), target_test)))
    pred = output_test.apply_(lambda x: 1 if x > 0 else -1)
    correct += pred.eq(target_test.data.view_as(pred)).sum()
  test_loss /= num_samples_test
  accuracy = 100. * correct / num_samples_test
  losses = test_loss
  return (accuracy, losses)


# Get the information about beta
def extract_info(network, show_photo):
  """Recover the linear predictor of ``network`` and report its norms.

  Steps:
    1. Probe the network with every one-hot image to read off ``beta``, the
       ``new_dim1 x new_dim2`` array with ``beta[i, j] = network(e_ij)``.
    2. Compute the smallest ``target * output`` over the first ``num_samples``
       training points (the minimum margin).
    3. Compute ``Rbeta`` = (squared norm of the conv weights + squared norm of
       the fc weights) * sqrt(new_dim1 * new_dim2).
    4. Divide ``beta`` and ``Rbeta`` by the minimum margin, take the magnitude
       of the orthonormal 2-D FFT of ``beta``, and print the resulting norms.

  All forward passes run in eval mode under ``torch.no_grad()`` and the probe
  result is read with ``.item()``. The probes used to run with autograd
  enabled and a grad-tracking tensor was assigned straight into a NumPy
  array, which built 784 unnecessary graphs and depends on an implicit
  tensor-to-NumPy conversion.

  NOTE(review): if the minimum margin is zero or negative (training data not
  yet separated) the normalisation divides by a non-positive number; no guard
  is applied here.

  Args:
    network: model exposing ``conv1.weight`` and ``fc1.weight`` and returning
      a single value for a single image.
    show_photo: when true, also plot ``|beta|`` and its spectrum.

  Returns:
    ``(Rbeta, beta_test)`` after normalisation by the minimum margin.
  """
  network.eval()
  with torch.no_grad():
    # Compute beta for linear CNNs: response to each unit (one-hot) image
    beta_test = np.zeros((new_dim1, new_dim2))
    for i in range(new_dim1):
      for j in range(new_dim2):
        tempimg = torch.zeros((1, 1, new_dim1, new_dim2))
        tempimg[0, 0, i, j] = 1
        beta_test[i, j] = network(tempimg).item()

    # Compute margin: the minimum of target * output over the training points
    output_np = network(data_tr).numpy().flatten()
    target_np = target_tr.numpy().flatten()
    margins = target_np[:num_samples] * output_np[:num_samples]
    min_margin = margins.min()

  # Compute R(beta)
  w1 = network.conv1.weight.detach().numpy()
  w2 = network.fc1.weight.detach().numpy()
  w1_norm_sq = np.sum(np.square(w1))
  w2_norm_sq = np.sum(np.square(w2))
  print(w1_norm_sq, w2_norm_sq)
  Rbeta = (w1_norm_sq + w2_norm_sq) * np.sqrt(new_dim1 * new_dim2)

  # Normalize by margin
  beta_test = beta_test / min_margin # normalize to have margin 1
  hat_beta = np.absolute(np.fft.fft2(beta_test, norm='ortho'))
  Rbeta = Rbeta / min_margin

  print("l2 norm: " + str(2 * np.sqrt(new_dim1 * new_dim2)* np.linalg.norm(beta_test, ord="fro")))
  print("l1 norm: " + str(2 * np.sum(hat_beta)))
  print("Rbeta: " + str(Rbeta))

  if show_photo:
    print("Time domain:")
    plt.imshow(np.absolute(beta_test), cmap='gray')
    plt.show()
    print("Frequency domain:")
    # Log colour scale with fixed limits for the spectrum plot
    plt.imshow(np.absolute(hat_beta), cmap='gray', norm=LogNorm(vmin=0.0001, vmax=0.08))
    plt.show()

  return (Rbeta, beta_test)


In [None]:
# Train and extract info about beta
# NOTE(review): `sns` is not used anywhere in the visible cells.
import seaborn as sns
# Upper bound on training epochs; experiment() stops earlier once the training
# loss it checks every 100 epochs is small enough.
n_epochs = 100000
# Learning rate of the first SGD optimizer; experiment() replaces it at fixed
# epochs with larger values.
learning_rate_start = 0.005
# Momentum passed to every SGD optimizer created in experiment().
momentum = 0.3
# Scale factor handed to Net.initialize for the normal weight initialisation.
initialization_scale = 0.01


# Rebinds the name `tqdm` (imported earlier from the tqdm package) to the
# notebook progress bar.
# NOTE(review): newer tqdm releases reportedly deprecate `tqdm_notebook` in
# favour of `tqdm.notebook.tqdm` / `tqdm.auto.tqdm` — confirm against the
# installed version before changing.
from tqdm import tqdm_notebook as tqdm
def experiment(ker_size1, ker_size2, output_channels):
  """Train one linear CNN and report the properties of its predictor.

  A ``Net`` with the given kernel size and channel count is initialised with
  ``initialization_scale`` and trained with SGD for at most ``n_epochs``
  epochs. Every 100 epochs the training loss is evaluated and training stops
  once it is at most 1e-6. At fixed epochs the optimizer is replaced by a new
  SGD instance with a larger learning rate to speed up convergence
  (NOTE(review): a new optimizer object presumably starts without the
  previous momentum state — confirm this is intended). ``extract_info`` is
  called for logging at epochs 200 and 500.

  Returns:
    The ``(Rbeta, beta)`` pair produced by the final ``extract_info`` call.
  """
  # epoch -> learning rate used from the following epoch onwards
  lr_after_epoch = {
    200: 0.05,
    250: 0.1,
    300: 0.5,
    350: 1.3,
    400: 2.0,
    450: 3.0,
    500: 5.0,
  }

  network = Net(ker_size1, ker_size2, output_channels)
  network.initialize(initialization_scale, ker_size1)

  def fresh_sgd(lr):
    # Brand-new optimizer over the same parameters with the shared momentum
    return optim.SGD(network.parameters(), lr=lr, momentum=momentum)

  optimizer = fresh_sgd(learning_rate_start)
  print("Before training:")
  train_eval(network)
  print("Start training:")
  for epoch in tqdm(range(1, n_epochs + 1)):
    train_minibatch(network, optimizer)
    # Periodic convergence check: stop at 10^-6 loss
    if epoch % 100 == 0 and train_eval(network) <= 0.000001:
      break
    if epoch == 200:
      extract_info(network, False)
    next_lr = lr_after_epoch.get(epoch)
    if next_lr is not None:
      optimizer = fresh_sgd(next_lr)
    if epoch == 500:
      extract_info(network, False)
  print("After training:")
  train_eval(network)
  accuracy, losses = test(network)
  print(accuracy, losses)

  rk, beta = extract_info(network, True)
  return (rk, beta)


# Launch a single run with a square (k, k) convolution kernel.
k = 3
# Number of output channels of the convolution layer
output_channels = 2
Rbeta, beta = experiment(ker_size1=k, ker_size2=k, output_channels=output_channels)


