# Data loading & module imports

In [None]:
import torch
import numpy as np

from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Legacy flag, kept because later cells may still branch on it; `device`
# (defined at the end of this cell) is what the tensors are moved to.
train_on_gpu = torch.cuda.is_available()

# Training images get light augmentation (small random rotation and a random
# horizontal flip) before conversion to a normalized tensor.
train_transform = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((.5, .5, .5), (.5, .5, .5)),
])

# Evaluation images are only converted and normalized: no random augmentation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((.5, .5, .5), (.5, .5, .5)),
])

train_data = datasets.CIFAR10('data', train=True,
                              download=True, transform=train_transform)
# Same images as `train_data`, but read through the deterministic evaluation
# transform. The validation loader samples from this view so that validation
# loss and accuracy are not perturbed by random rotations/flips.
valid_data = datasets.CIFAR10('data', train=True,
                              download=True, transform=test_transform)
test_data = datasets.CIFAR10('data', train=False,
                             download=True, transform=test_transform)

# Hold out a random 10% of the training indices for validation.
num_train = len(train_data)
train_indeces = list(range(num_train))
np.random.shuffle(train_indeces)
split = int(np.floor(.1 * num_train))
train_idx, valid_idx = train_indeces[split:], train_indeces[:split]

train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

num_workers = 0
train_loader = torch.utils.data.DataLoader(train_data, batch_size=50,
    sampler=train_sampler, num_workers=num_workers)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=50,
    sampler=valid_sampler, num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=50,
    num_workers=num_workers)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Defining distance functions for various manifolds

In [None]:
# The torus distance function for a torus of arbitrary dimension.
# When applied to tensors of shape a.shape = [d, 1, n] and b.shape = [d, k, 1], it broadcasts
# to a [d, k, n] tensor holding, for every pair of points on the d-dimensional torus (S^1)^d,
# the wrap-around distance along each coordinate circle. Taking the norm over the first axis
# then gives the k x n matrix of distances between all pairs of points.
# The parameter "10" here is an arbitrary size choice for the torus. Changing it (in both
# terms) is equivalent to changing the coefficient of the regularizer. Changing them
# independently is not an experiment we conducted.
def torus_distance(a, b, period=10):
    """Return the element-wise wrap-around distance between `a` and `b`.

    Each coordinate is treated as a point on a circle of circumference
    `period`; the result is the shorter of the two arcs between the
    coordinates. The inputs are combined with ordinary broadcasting, so
    tensors shaped [d, 1, n] and [d, k, 1] yield a [d, k, n] tensor of
    per-coordinate distances (reduce over the first axis with a norm to get
    point-to-point distances on the d-dimensional torus).

    Args:
        a: Tensor of coordinates.
        b: Tensor of coordinates, broadcastable against `a`.
        period: Circumference of each circle. Defaults to 10, the size used
            throughout this file.

    Returns:
        Tensor of the broadcast shape with values in [0, period / 2].
    """
    forward_arc = torch.remainder(a - b, period)
    backward_arc = torch.remainder(b - a, period)
    return torch.min(forward_arc, backward_arc)


# The Klein bottle distance is more complicated. The universal cover of a Klein bottle is
# R^2, so we compute the distance between two points p, p' on [0, 10] x [0, 10] with the
# Klein metric as the minimum distance between p and Dp', where D ranges over the set of
# deck transformations of R^2. For obvious reasons, we only need to consider the images of
# p' in the 9 squares centered on [0, 10] x [0, 10], which we compute by applying the
# appropriate matrices.
# Once again, the choice of 10 is arbitrary and could be changed.
embed_dim = 2

# Affine maps of the plane written as 3x3 matrices acting on homogeneous
# column vectors (x, y, 1):
#   M1: (x, y) -> (x + 10, 10 - y)
#   M2: (x, y) -> (x - 10, 10 - y)
#   M3: (x, y) -> (x, y + 10)
#   M4: (x, y) -> (x, y - 10)
M1 = torch.tensor(
    [[1, 0, 10],
     [0, -1, 10],
     [0, 0, 1]],
    dtype=torch.float32,
).to(device)
M2 = torch.tensor(
    [[1, 0, -10],
     [0, -1, 10],
     [0, 0, 1]],
    dtype=torch.float32,
).to(device)
M3 = torch.tensor(
    [[1, 0, 0],
     [0, 1, 10],
     [0, 0, 1]],
    dtype=torch.float32,
).to(device)
M4 = torch.tensor(
    [[1, 0, 0],
     [0, 1, -10],
     [0, 0, 1]],
    dtype=torch.float32,
).to(device)


def klein_distance(a, b):
    """Pairwise Klein-bottle distances between two sets of planar points.

    Both inputs are reduced modulo 10 into the fundamental square, lifted to
    homogeneous coordinates, and every point of `a` is compared against nine
    images of every point of `b`: the point itself plus its images under
    M1, M2, M3, M4, M1*M3, M1*M4, M2*M3 and M2*M4. The distance is the
    smallest Euclidean distance over those nine images.

    Args:
        a: Tensor of points; the callers in this file pass shape (n, 2).
        b: Tensor of points; the callers in this file pass shape (k, 2).

    Returns:
        The result of `torch.min(..., dim=0)`: a (values, indices) pair where
        `values` is the (n, k) distance matrix and `indices` records which of
        the nine images achieved the minimum.
    """
    # Appends a constant 1 as the third (homogeneous) coordinate.
    to_homogeneous = nn.ConstantPad1d((0, 1), 1)

    a_cols = to_homogeneous(torch.remainder(a, 10)).transpose(0, 1).view(3, -1, 1)
    b_cols = to_homogeneous(torch.remainder(b, 10)).transpose(0, 1)

    # Images of b in the fundamental square and its eight neighbours.
    images = [
        b_cols,
        torch.matmul(M1, b_cols),
        torch.matmul(M2, b_cols),
        torch.matmul(M3, b_cols),
        torch.matmul(M4, b_cols),
        torch.matmul(torch.matmul(M1, M3), b_cols),
        torch.matmul(torch.matmul(M1, M4), b_cols),
        torch.matmul(torch.matmul(M2, M3), b_cols),
        torch.matmul(torch.matmul(M2, M4), b_cols),
    ]

    # The homogeneous coordinate is 1 on both sides, so it cancels in the
    # difference and does not contribute to the norm.
    candidates = [
        torch.linalg.norm(a_cols - image.view(3, 1, -1), dim=0)
        for image in images
    ]
    return torch.min(torch.stack(candidates), dim=0)




# The sphere distance is less complicated than the others. We don't need to work on the
# universal covers; instead we just project onto a unit sphere and compute the angle between
# the two points to determine the geodesic distance between them.
# Technically there should be a factor of pi to accommodate that, but we absorb it into
# our constant multiple and avoid the multiplication.
def sphere_distance(a, b):
    """Return 10 times the angle between every pair of directions in `a`, `b`.

    The inputs are first projected onto the unit sphere (normalized along
    their last axis), so they do not need to be unit vectors already. The
    cosine is clamped to [-1, 1] before `acos`: without the clamp, rounding
    error on (anti)parallel vectors produces NaN, which `nan_to_num` would
    silently turn into distance 0 even for antipodal points.

    Args:
        a: Tensor whose last axis holds point coordinates.
        b: Tensor whose last axis holds point coordinates (same length as
            the last axis of `a`).

    Returns:
        Tensor of scaled geodesic distances, shaped like
        `torch.inner(a, b)`, with values in [0, 10 * pi].
    """
    a_unit = F.normalize(a, dim=-1)
    b_unit = F.normalize(b, dim=-1)
    cosine = torch.clamp(torch.inner(a_unit, b_unit), -1.0, 1.0)
    # nan_to_num is kept so NaNs already present in the inputs still map to 0,
    # as they did before.
    return 10 * torch.nan_to_num(torch.acos(cosine))

# Model implementation with learned embeddings
Here is a sample model which learns optimal embeddings on the Klein bottle

In [None]:
class FCKlein(nn.Module):
    """Small CNN whose first fully connected layer carries learned positions.

    The network is conv(3->64) -> pool -> conv(64->128) -> pool -> dropout ->
    fc1(128*8*8 -> 256) -> fc2(256 -> 10). Every input unit and every output
    unit of `fc1` is additionally assigned a learnable 2-D position
    (`moduli_embed1`, `moduli_embed2`), initialised uniformly in [0, 10),
    which the regularizers below interpret as points on a Klein bottle via
    `klein_distance`.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.maxpool1 = nn.MaxPool2d(2)

        # One 2-D position per fc1 input unit and per fc1 output unit.
        self.moduli_embed1 = nn.Parameter(torch.zeros((128*8*8, 2)).uniform_(0, 10))
        self.moduli_embed2 = nn.Parameter(torch.zeros((256, 2)).uniform_(0, 10))

        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.maxpool2 = nn.MaxPool2d(2)

        self.fc1 = nn.Linear(128*8*8, 256)
        self.fc2 = nn.Linear(256, 10)

        self.dropout = nn.Dropout(p=.5)

    def forward(self, x):
        """Return the 10 class logits for a batch of images.

        The flatten to 128*8*8 features after two 2x poolings means the
        network expects 3-channel 32x32 inputs.
        """
        features = self.maxpool1(F.relu(self.conv1(x)))
        features = self.maxpool2(F.relu(self.conv2(features)))
        features = features.view(-1, 128*8*8)
        features = self.dropout(features)
        hidden = F.relu(self.fc1(features))
        return self.fc2(hidden)

    # These are the two regularizing terms we want to add to the loss.
    # alpha should always be >= 0.
    def topological_regularizer(self, alpha):
        """Mean of (distance * weight) ** alpha over all fc1 connections.

        The distance is the Klein-bottle distance between the position of
        the connection's input unit and that of its output unit, so weights
        between far-apart units are penalised more.
        """
        distances, _ = klein_distance(self.moduli_embed1, self.moduli_embed2)
        # distances is (inputs, outputs); fc1.weight is (outputs, inputs).
        return torch.mean((distances * self.fc1.weight.t()) ** alpha)

    def dispersement_regularizer(self):
        """Sum of the mean log pairwise distances within each embedding.

        Larger values mean the points are more spread out; the training loop
        in this file multiplies the term by a negative coefficient.
        """
        return (self._mean_log_pairwise_distance(self.moduli_embed1)
                + self._mean_log_pairwise_distance(self.moduli_embed2))

    @staticmethod
    def _mean_log_pairwise_distance(points):
        """Mean of log(distance + 0.1) over all pairs, with the diagonal zeroed."""
        distances, _ = klein_distance(points, points)
        log_distances = torch.log(distances + 0.1)
        # Self-distances carry no information; zero them on whatever device
        # the tensor lives on instead of relying on a module-level `device`.
        diagonal = torch.arange(log_distances.shape[0], device=log_distances.device)
        log_distances[diagonal, diagonal] = 0
        return torch.mean(log_distances)

    def L2FC(self):
        """Return the sum of squared fc1 weights (plain L2 penalty)."""
        return torch.sum(self.fc1.weight**2)

In [None]:
klein_test = FCKlein().to(device)


# The embedding parameters and the ordinary network parameters are handed to
# separate optimizers. We tried using different optimizers and learning rates
# for the two groups, but ultimately found it had very little effect.
my_list = ['moduli_embed1', 'moduli_embed2']
moduli_params = [param for name, param in klein_test.named_parameters()
                 if name in my_list]
base_params = [param for name, param in klein_test.named_parameters()
               if name not in my_list]
criterion = nn.CrossEntropyLoss()
optimizer1 = optim.Adam([{'params': moduli_params, 'lr': 1e-3}], lr=1e-3)
optimizer2 = optim.Adam([{'params': base_params}], lr=1e-3)

In [None]:
n_epochs = 50

# Coefficients of the regularizers added to the cross-entropy loss.
topological_decay_lambda = .1
topological_decay_alpha = 2
moduli_dispersement_lambda = -.1
# Coefficient for the plain L2 penalty (klein_test.L2FC()); that term is
# currently not part of the loss.
L2factor = 0

# `np.inf` rather than the `np.Inf` alias, which NumPy 2.0 removed.
valid_loss_min = np.inf

for epoch in range(n_epochs):
    train_loss = 0.0
    test_CE = 0.0
    valid_loss = 0.0

    # Model training
    klein_test.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer1.zero_grad()
        optimizer2.zero_grad()
        output = klein_test(data)

        # Cross entropy is computed once and reused both inside the
        # regularized loss and for the reported training metric.
        cross_entropy = criterion(output, target)
        loss = (cross_entropy
                + topological_decay_lambda * klein_test.topological_regularizer(topological_decay_alpha)
                + moduli_dispersement_lambda * klein_test.dispersement_regularizer())

        loss.backward()

        optimizer1.step()
        optimizer2.step()

        train_loss += loss.item()*data.size(0)
        test_CE += cross_entropy.item()*data.size(0)

    # Model eval
    klein_test.eval()
    num_correct = 0
    # No gradients are needed for validation; skipping graph construction
    # saves memory and time.
    with torch.no_grad():
        for data, target in valid_loader:
            data, target = data.to(device), target.to(device)
            output = klein_test(data)
            _, pred = torch.max(output, 1)
            # Only the criterion loss is tracked here: this loss doesn't include
            # the regularization terms, only the cross entropy loss.
            loss = criterion(output, target)
            valid_loss += loss.item()*data.size(0)
            num_correct += torch.sum(pred == target).item()

    train_loss = train_loss/len(train_loader.sampler)
    test_CE = test_CE/len(train_loader.sampler)
    valid_loss = valid_loss/len(valid_loader.sampler)
    accuracy = num_correct/float(len(valid_loader.sampler))

    print('Epoch: {} \tAccuracy: {:.6f} \tTraining Cross Entropy: {:.6f} \tValidation Loss: {:.6f}'.format(
        epoch, accuracy, test_CE, valid_loss))

    # Checkpoint whenever the validation loss matches or improves on the best so far.
    if valid_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
        valid_loss_min,
        valid_loss))
        torch.save(klein_test.state_dict(), 'model_klein_test.pt')
        valid_loss_min = valid_loss

# Model implementation with fixed random embeddings
Very similar to the above, but now we no longer simultaneously train the manifold embedding. We use a higher dimensional torus, for variety. 

In [None]:
# embed_dim denotes what dimensional torus we work with. All of this could be
# implemented inside the FCTorus __init__, if desired. We've left it outside
# to emphasize the lack of modification during training. 

embed_dim = 9
# One random point on the embed_dim-dimensional torus per fc1 input unit (x)
# and per fc1 output unit (y); rows are points, columns are coordinates.
x = torch.zeros((128*8*8, embed_dim)).uniform_(0, 10).to(device)
y = torch.zeros((256, embed_dim)).uniform_(0, 10).to(device)

# Move the coordinate axis to the front with a real transpose. A bare
# `.view(embed_dim, ...)` on an (n, embed_dim) tensor only reinterprets the
# memory layout and mixes coordinates of different points, which is harmless
# for i.i.d. random values but wrong for any meaningful embedding.
x = x.t().reshape(embed_dim, 1, -1)
y = y.t().reshape(embed_dim, -1, 1)

# Per-coordinate wrap-around distances, reduced over the coordinate axis to a
# (256, 128*8*8) matrix that lines up with fc1.weight.
distance_matrix = torus_distance(x, y)
distance_matrix = torch.linalg.norm(distance_matrix, dim=0)

In [None]:
class FCTorus(nn.Module):
    """Small CNN regularized with a fixed, precomputed distance matrix.

    The architecture is conv(3->64) -> pool -> conv(64->128) -> pool ->
    dropout -> fc1(128*8*8 -> 256) -> fc2(256 -> 10). Unlike a model with
    learned embeddings, this class owns no position parameters: the distances
    used by `topological_regularizer` come from outside the model.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.maxpool1 = nn.MaxPool2d(2)

        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.maxpool2 = nn.MaxPool2d(2)

        self.fc1 = nn.Linear(128*8*8, 256)
        self.fc2 = nn.Linear(256, 10)

        self.dropout = nn.Dropout(p=.5)

    def forward(self, x):
        """Return the 10 class logits for a batch of images.

        The flatten to 128*8*8 features after two 2x poolings means the
        network expects 3-channel 32x32 inputs.
        """
        features = self.maxpool1(F.relu(self.conv1(x)))
        features = self.maxpool2(F.relu(self.conv2(features)))
        features = features.view(-1, 128*8*8)
        features = self.dropout(features)
        hidden = F.relu(self.fc1(features))
        return self.fc2(hidden)

    # These are the two regularizing terms we want to add to the loss.
    # alpha should always be >= 0, and beta always <= 0
    def topological_regularizer(self, alpha, distances=None):
        """Mean of (distance * weight) ** alpha over all fc1 connections.

        Args:
            alpha: Exponent applied to the distance-weighted weights.
            distances: Optional tensor broadcastable against `fc1.weight`
                (shape (256, 128*8*8)). When omitted, the module-level
                `distance_matrix` is used, as before.
        """
        if distances is None:
            distances = distance_matrix
        return torch.mean((distances * self.fc1.weight) ** alpha)

    def dispersement_regularizer(self, beta):
        """Sum over both embeddings of the mean (distance + 0.01) ** beta.

        This term only makes sense when learnable embeddings are attached to
        the model as `moduli_embed1` / `moduli_embed2` (objects exposing a
        `.weight` tensor). This class does not create them itself.

        Raises:
            AttributeError: if the model has no such embeddings.
        """
        embed1 = getattr(self, 'moduli_embed1', None)
        embed2 = getattr(self, 'moduli_embed2', None)
        if embed1 is None or embed2 is None:
            raise AttributeError(
                "FCTorus has no moduli_embed1/moduli_embed2; the dispersement "
                "regularizer requires learnable embeddings to be attached")
        return (self._mean_power_pairwise_distance(embed1.weight, beta)
                + self._mean_power_pairwise_distance(embed2.weight, beta))

    @staticmethod
    def _mean_power_pairwise_distance(points, beta):
        """Mean of (torus distance + 0.01) ** beta over pairs, diagonal zeroed."""
        # NOTE(review): assumes `points` is (n_points, n_coords) - confirm
        # against whatever attaches the embeddings.
        coords = points.t()
        per_axis = torus_distance(coords.unsqueeze(2), coords.unsqueeze(1))
        distances = torch.linalg.norm(per_axis, dim=0)
        powered = (distances + 0.01) ** beta
        # Self-distances would dominate for negative beta; zero them out on
        # the tensor's own device rather than assuming CUDA.
        diagonal = torch.arange(powered.shape[0], device=powered.device)
        powered[diagonal, diagonal] = 0
        return torch.mean(powered)

    def L2FC(self):
        """Return the sum of squared fc1 weights (plain L2 penalty)."""
        return torch.sum(self.fc1.weight**2)

In [None]:
torus_3_0 = FCTorus().to(device)

# Parameters are split by name into an embedding group and everything else.
# FCTorus as defined in this file registers no parameter under these names,
# so the embedding group comes out empty and optimizer2 does all the work.
my_list = ['moduli_embed1.weight', 'moduli_embed2.weight']
moduli_params = [param for name, param in torus_3_0.named_parameters()
                 if name in my_list]
base_params = [param for name, param in torus_3_0.named_parameters()
               if name not in my_list]

criterion = nn.CrossEntropyLoss()
optimizer1 = optim.Adam([{'params': moduli_params, 'lr': 1e-3}], lr=1e-3)
optimizer2 = optim.Adam([{'params': base_params}], lr=1e-3)

In [None]:
n_epochs = 50

# Coefficients of the regularizers. Only the topological term is part of the
# loss below; the dispersement and L2 settings are currently not used.
topological_decay_lambda = 1
topological_decay_alpha = 2
moduli_dispersement_lambda = .1
moduli_dispersement_beta = -2
L2factor = .001

# `np.inf` rather than the `np.Inf` alias, which NumPy 2.0 removed.
valid_loss_min = np.inf

for epoch in range(n_epochs):
    train_loss = 0.0
    test_CE = 0.0
    valid_loss = 0.0

    # Model training
    torus_3_0.train()
    for data, target in train_loader:
        # Same placement rule as the model itself (`torus_3_0.to(device)`).
        data, target = data.to(device), target.to(device)
        optimizer1.zero_grad()
        optimizer2.zero_grad()
        output = torus_3_0(data)

        # Cross entropy is computed once and reused both inside the
        # regularized loss and for the reported training metric.
        cross_entropy = criterion(output, target)
        loss = cross_entropy + topological_decay_lambda * torus_3_0.topological_regularizer(topological_decay_alpha)

        loss.backward()

        optimizer1.step()
        optimizer2.step()

        train_loss += loss.item()*data.size(0)
        test_CE += cross_entropy.item()*data.size(0)

    # Model eval
    torus_3_0.eval()
    num_correct = 0
    # No gradients are needed for validation; skipping graph construction
    # saves memory and time.
    with torch.no_grad():
        for data, target in valid_loader:
            data, target = data.to(device), target.to(device)
            output = torus_3_0(data)
            _, pred = torch.max(output, 1)
            # Only the criterion loss is tracked here: this loss doesn't include
            # the regularization terms, only the cross entropy loss.
            loss = criterion(output, target)
            valid_loss += loss.item()*data.size(0)
            num_correct += torch.sum(pred == target).item()

    train_loss = train_loss/len(train_loader.sampler)
    test_CE = test_CE/len(train_loader.sampler)
    valid_loss = valid_loss/len(valid_loader.sampler)
    accuracy = num_correct/float(len(valid_loader.sampler))

    print('Epoch: {} \tAccuracy: {:.6f} \tTraining Cross Entropy: {:.6f} \tValidation Loss: {:.6f}'.format(
        epoch, accuracy, test_CE, valid_loss))

    # Checkpoint whenever the validation loss matches or improves on the best so far.
    if valid_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
        valid_loss_min,
        valid_loss))
        torch.save(torus_3_0.state_dict(), 'model_hyperparameter_torus_3_0.pt')
        valid_loss_min = valid_loss