In [4]:
import sys
# Load the Iris dataset
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from sklearn.manifold import TSNE

from maraboupy import Marabou
import numpy as np
import time

import torch
import torch.nn as nn
import torch.optim as optim
import onnx
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt

import multiprocessing
import time


### Dataset

In [5]:
# Dataset
# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load Fashion-MNIST dataset
train_dataset = datasets.FashionMNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.FashionMNIST(root='./data', train=False, transform=transform, download=True)

# DataLoader
batch_size = 64
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

### Model

In [6]:
model = nn.Sequential(
    nn.Flatten(start_dim=1),
    nn.Linear(28*28, 32), # input shape and number of neurons
    nn.ReLU(),
    nn.Linear(32, 32),
    nn.ReLU(),
    nn.Linear(32, 10)  # 10 classes output
)

# Test the model with dummy input
dummy_input = torch.randn(1, 1, 28, 28)  # Batch size = 1, 1 channel, 28x28
output = model(dummy_input)
print("Output shape:", output.shape)  # Expected: [1, 10]

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

Output shape: torch.Size([1, 10])


### Training

In [7]:
# Train model
num_epoch = 50
device = torch.device('cpu')
for epoch in range(num_epoch):
    # print(epoch)
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        
        # get model output
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # backward propagation
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f'Epoch [{epoch+1}/{num_epoch}], Loss: {running_loss/len(train_loader):.4f}')

Epoch [1/50], Loss: 0.5747
Epoch [2/50], Loss: 0.4124
Epoch [3/50], Loss: 0.3813
Epoch [4/50], Loss: 0.3592
Epoch [5/50], Loss: 0.3448
Epoch [6/50], Loss: 0.3329
Epoch [7/50], Loss: 0.3205
Epoch [8/50], Loss: 0.3121
Epoch [9/50], Loss: 0.3024
Epoch [10/50], Loss: 0.2990
Epoch [11/50], Loss: 0.2909
Epoch [12/50], Loss: 0.2828
Epoch [13/50], Loss: 0.2782
Epoch [14/50], Loss: 0.2731
Epoch [15/50], Loss: 0.2694
Epoch [16/50], Loss: 0.2625
Epoch [17/50], Loss: 0.2578
Epoch [18/50], Loss: 0.2546
Epoch [19/50], Loss: 0.2499
Epoch [20/50], Loss: 0.2491
Epoch [21/50], Loss: 0.2421
Epoch [22/50], Loss: 0.2404
Epoch [23/50], Loss: 0.2382
Epoch [24/50], Loss: 0.2335
Epoch [25/50], Loss: 0.2316
Epoch [26/50], Loss: 0.2274
Epoch [27/50], Loss: 0.2267
Epoch [28/50], Loss: 0.2220
Epoch [29/50], Loss: 0.2218
Epoch [30/50], Loss: 0.2172
Epoch [31/50], Loss: 0.2181
Epoch [32/50], Loss: 0.2147
Epoch [33/50], Loss: 0.2110
Epoch [34/50], Loss: 0.2098
Epoch [35/50], Loss: 0.2085
Epoch [36/50], Loss: 0.2047
E

In [8]:
# Export the Sequential model to ONNX
onnx_file_path = "simple_nn_fashion_mnist_sequential_50e.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_file_path,
    export_params=True,
    opset_version=12,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)
print(f"Model exported to {onnx_file_path}")

model_scripted = torch.jit.script(model, 'simple_nn_fashion_mnist_sequential_50e.pt')
model_scripted.save("simple_nn_fashion_mnist_sequential_50e.pt")
torch.save(model.state_dict(), 'simple_nn_fashion_mnist_sequential_50e.pth')

Model exported to simple_nn_fashion_mnist_sequential_50e.onnx


  model_scripted = torch.jit.script(model, 'simple_nn_fashion_mnist_sequential_50e.pt')
