In [1]:
from SFW import *
from DFW import *
from utils.utils import *
from utils.loader import *
from MultiClassHingeLoss import *
from constraints.constraints import *
import torchvision
import torch.optim as optim

In [2]:
dtype = torch.float
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
mnist_trainset, mnist_testset = load_data("CIFAR10")

In [4]:
batch_size = 1000

dataset_test = torch.utils.data.DataLoader(
  torchvision.datasets.CIFAR10('../data', train=False, download=True, transform=torchvision.transforms.ToTensor()), 
  batch_size=100,
  shuffle=True
)
dataset_train = torch.utils.data.DataLoader(
  torchvision.datasets.CIFAR10('../data', train=True, download=True, transform=torchvision.transforms.ToTensor()),
  batch_size=batch_size,
  shuffle=True
)

In [5]:
input_size = 28*28
hidden_size = 1500
output_size = 10

In [6]:
class MultiLayerPerceptron(torch.nn.Module):
    def __init__(self):
        super(MultiLayerPerceptron,self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.relu = torch.nn.ReLU()
        self.fc_in = torch.nn.Linear(self.input_size,self.hidden_size,bias=True) #fully connected input_layer
        self.fc_hid_1 = torch.nn.Linear(self.hidden_size,self.hidden_size,bias=True) #fully connected hidden_layer_1
        self.fc_hid_2 = torch.nn.Linear(self.hidden_size,self.hidden_size,bias=True) #dully connected hidden_layer_2
        self.fc_out = torch.nn.Linear(self.hidden_size,self.output_size,bias=True)

    def forward(self,x):
        batch_size = x.shape[0]
        x = x.view(batch_size, self.input_size)
        x = self.relu(self.fc_in(x))
        x = self.relu(self.fc_hid_1(x))
        x = self.relu(self.fc_hid_2(x))
        x = self.fc_out(x)
        return x

In [7]:
mlp = MultiLayerPerceptron().to(device)

In [8]:
optimizer = DFW(mlp.parameters(), eta=0.1)
stats = train_model(mlp, dataset_train, dataset_test, optimizer, torch.nn.CrossEntropyLoss(), 10)

 10%|████████▎                                                                          | 1/10 [01:01<09:12, 61.34s/it]

Epoch 0 | Test accuracy: 0.56900


 20%|████████████████▌                                                                  | 2/10 [02:03<08:12, 61.54s/it]

Epoch 1 | Test accuracy: 0.80210


 30%|████████████████████████▉                                                          | 3/10 [03:04<07:11, 61.60s/it]

Epoch 2 | Test accuracy: 0.86700


 40%|█████████████████████████████████▏                                                 | 4/10 [04:04<06:04, 60.81s/it]

Epoch 3 | Test accuracy: 0.89650


 50%|█████████████████████████████████████████▌                                         | 5/10 [05:02<04:58, 59.71s/it]

Epoch 4 | Test accuracy: 0.90810


 60%|█████████████████████████████████████████████████▊                                 | 6/10 [06:01<03:58, 59.54s/it]

Epoch 5 | Test accuracy: 0.91930


 70%|██████████████████████████████████████████████████████████                         | 7/10 [07:03<03:00, 60.32s/it]

Epoch 6 | Test accuracy: 0.92110


 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [08:04<02:01, 60.73s/it]

Epoch 7 | Test accuracy: 0.92840


 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [09:03<01:00, 60.10s/it]

Epoch 8 | Test accuracy: 0.93650


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [10:05<00:00, 60.53s/it]

Epoch 9 | Test accuracy: 0.93970



