We test a convolutional neural network (CNN) for approximating the Wasserstein distance. The architecture is borrowed, with some modifications, from https://www.kaggle.com/code/shadabhussain/cifar-10-cnn-using-pytorch.

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.autograd import grad
import numpy as np
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
import random
from CF_NeuralNetwork import FFNN
from CF_NeuralNetwork_PC import FFNNPC

# CIFAR-10 train/test splits, downloaded into ./data when not already present.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transforms.ToTensor(), download=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transforms.ToTensor(), download=True)

# Mini-batch loaders over both splits. The training code visible in this
# section batches manually and does not iterate over these loaders.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

# Extract the image arrays stored on the dataset objects
# (4-dimensional: one leading sample axis plus three per-image axes).
X_train = trainset.data
X_test = testset.data

# Rescale every image so that its entries sum to 32*32 = 1024:
# divide by the per-image total, then multiply by 32*32.
X_train_normalized = X_train / X_train.sum(axis=(1, 2, 3))[:, None, None, None] * 32 * 32
X_test_normalized = X_test / X_test.sum(axis=(1, 2, 3))[:, None, None, None] * 32 * 32

# Position of the reference image inside the training split. The image is
# kept separately and removed from the training inputs.
REFERENCE_INDEX = 34
X_reference = X_train_normalized[REFERENCE_INDEX]
X_train_normalized = np.delete(X_train_normalized, REFERENCE_INDEX, axis=0)

# Convert the normalized images to float32 tensors and move the last axis to
# position 1, i.e. (N, H, W, C) -> (N, C, H, W) as expected by nn.Conv2d.
x_Train_t = torch.from_numpy(X_train_normalized).float().permute(0, 3, 1, 2)
x_Test_t = torch.from_numpy(X_test_normalized).float().permute(0, 3, 1, 2)

# Regression targets, read from comma-separated text files.
# NOTE(review): presumably the precomputed Wasserstein distances of each image
# to the reference image -- confirm against the script that wrote these files.
Y_train = np.loadtxt('CIFAR_D_train', delimiter=',')
Y_test = np.loadtxt('CIFAR_D_test', delimiter=',')
# Drop the reference image's own target so that targets stay aligned with
# the training inputs after the deletion above.
Y_train = np.delete(Y_train, REFERENCE_INDEX, 0)

# Targets as float32 tensors. Plain tensors are used instead of the deprecated
# torch.autograd.Variable wrapper; tensors do not require gradients by default.
y_Train_t = torch.from_numpy(Y_train).float()
y_Test_t = torch.from_numpy(Y_test).float()

def my_rel_loss(output, target):
    """Return the mean relative error ``mean(|output - target| / target)``.

    Both arguments are tensors that broadcast against each other; the result
    is a zero-dimensional tensor. The division uses ``target`` as-is, so
    zero entries in ``target`` are not guarded against.
    """
    relative_error = (output - target).abs() / target
    return relative_error.mean()

def my_abs_loss(output, target):
    """Return the mean absolute error ``mean(|output - target|)``.

    The mean is returned as a one-dimensional tensor holding a single
    element (shape ``(1,)``), not as a zero-dimensional scalar tensor.
    """
    mean_abs_error = (output - target).abs().mean()
    return mean_abs_error.reshape(1)

We construct a CNN with 6,896,321 parameters.

In [None]:
class Cifar10Model(nn.Module):
    """CNN regressor mapping a 3-channel 32x32 image to a single scalar.

    The layers live in ``self.network`` (an ``nn.Sequential``):

    * three convolutional stages, each made of two 3x3 convolutions
      (padding 1, each followed by ReLU) and a 2x2 max-pooling that halves
      the spatial size: 64x16x16, then 128x8x8, then 256x4x4;
    * a flatten layer;
    * a fully connected head 4096 -> 1024 -> 1024 -> 512 -> 1 with ReLU
      between the linear layers and no activation on the output.
    """

    def __init__(self):
        super().__init__()

        # (input, intermediate, output) channel counts of each conv stage
        conv_stages = [(3, 32, 64), (64, 128, 128), (128, 256, 256)]
        # layer widths of the fully connected head, before the scalar output
        head_widths = [256*4*4, 1024, 1024, 512]

        # Layers are created strictly in network order so that the module
        # indices inside the Sequential match the layer positions.
        layers = []
        for c_in, c_mid, c_out in conv_stages:
            layers.append(nn.Conv2d(c_in, c_mid, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU())
            layers.append(nn.Conv2d(c_mid, c_out, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU())
            layers.append(nn.MaxPool2d(2, 2))

        layers.append(nn.Flatten())

        for n_in, n_out in zip(head_widths[:-1], head_widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(head_widths[-1], 1))

        self.network = nn.Sequential(*layers)

    def forward(self, xb):
        """Run the batch ``xb`` through the network; output has shape (N, 1)."""
        return self.network(xb)


model = Cifar10Model()

In [None]:
# Optimizer (Adam with its default settings) and training hyper-parameters
optimizer = torch.optim.Adam(model.parameters())
epochs = 300
batch_size = 64

# Per-epoch histories, stored as plain Python numbers:
# mean training loss over the mini-batches of the epoch
epoch_Loss = []
# 1-based epoch numbers
epoch_list = []
# loss on the full test set after the epoch
epoch_Loss_Val = []

num_train = x_Train_t.size(0)

for epoch in range(epochs):
    # train()/eval() switching is a no-op for a model without dropout or
    # batch-norm layers, but keeps the loop correct if such layers are added.
    model.train()
    acc_loss = 0.0
    counter = 0

    # Shuffle the dataset manually: a fresh random sample order every epoch
    permutation = torch.randperm(num_train)

    for i in range(0, num_train, batch_size):
        optimizer.zero_grad()
        # Batch selection (the last batch may be smaller than batch_size)
        indices = permutation[i:i + batch_size]
        batch_x, batch_y = x_Train_t[indices], y_Train_t[indices]

        # Forward pass; flatten (N, 1) predictions to (N,) to match the targets
        y_pred = model(batch_x).reshape(-1)

        # Mean relative error on this batch
        loss = my_rel_loss(y_pred, batch_y)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Accumulate the Python float, not the tensor: adding the tensor
        # itself would keep autograd history alive for the whole epoch.
        acc_loss += loss.item()
        counter += 1

    # Average of the per-batch losses of this epoch
    loss_app = acc_loss / counter
    print(loss_app)

    # Evaluate on the whole test set in one pass. No gradients are needed
    # here, so graph construction is disabled to save memory and time.
    model.eval()
    with torch.no_grad():
        y_pred_val = model(x_Test_t).reshape(-1)
        loss_Val = my_rel_loss(y_pred_val, y_Test_t)

    epoch_Loss_Val.append(loss_Val.item())
    epoch_Loss.append(loss_app)
    epoch_list.append(epoch + 1)


# Turn the recorded per-epoch histories into numpy arrays for plotting
counter = np.array(epoch_list)
mse_loss = np.array(epoch_Loss)
mse_loss_val = np.array(epoch_Loss_Val)

# Learning curves: training and test loss per epoch on a logarithmic y-axis
fig, ax = plt.subplots()
ax.set_ylim([0.001, 1])
ax.set_xlabel('number of epochs', fontsize=16)
ax.set_ylabel('mean relative error', fontsize=16)
plot_1a, = ax.semilogy(counter, mse_loss, color='blue', label='training set')
plot_1b, = ax.semilogy(counter, mse_loss_val, '--', color='green', label='test set')

# Dotted red horizontal reference line.
# NOTE(review): the origin of this constant is not shown here; presumably a
# baseline error level to compare against -- confirm and name it.
ax.axhline(y=0.037183734101399565, color='r', linestyle=':')

# Only the two curves appear in the legend; the reference line is left out
ax.legend(handles=[plot_1a, plot_1b], loc='upper right', fontsize=16)
plt.yticks(fontsize=16)
plt.xticks(fontsize=16)

# Major grid at full strength, minor grid faded
ax.grid()
ax.grid(which='minor', alpha=0.2)

plt.show()