<a href="https://colab.research.google.com/github/gaspartino/2D-LipCNNs/blob/main/ode_alternative.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import math
import numpy as np
from IPython.display import clear_output
from tqdm import tqdm_notebook as tqdm

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.color_palette("bright")
import matplotlib as mpl
import matplotlib.cm as cm

import torch
from torch import Tensor
from torch import nn
from torch.nn  import functional as F
from torch.autograd import Variable

use_cuda = torch.cuda.is_available()

In [2]:
def norm(dim):
    return nn.BatchNorm2d(dim)

def conv3x3(in_feats, out_feats, stride=1):
    return nn.Conv2d(in_feats, out_feats, kernel_size=3, stride=stride, padding=1, bias=False)

def add_time(in_tensor, t):
    bs, c, w, h = in_tensor.shape
    return torch.cat((in_tensor, t.expand(bs, 1, w, h)), dim=1)

In [3]:
class ConcatConv2d(nn.Module):

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )

    def forward(self, t, x):
        tt = torch.ones_like(x[:, :1, :, :]) * t
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)

In [4]:
class ODEF(nn.Module):
    def forward_with_grad(self, z, t, grad_outputs):
        """Compute f and a df/dz, a df/dp, a df/dt"""
        batch_size = z.shape[0]

        out = self.forward(z, t)

        a = grad_outputs
        adfdz, adfdt, *adfdp = torch.autograd.grad(
            (out,), (z, t) + tuple(self.parameters()), grad_outputs=(a),
            allow_unused=True, retain_graph=True
        )
        # grad method automatically sums gradients for batch items, we have to expand them back
        if adfdp is not None:
            adfdp = torch.cat([p_grad.flatten() for p_grad in adfdp]).unsqueeze(0)
            adfdp = adfdp.expand(batch_size, -1) / batch_size
        if adfdt is not None:
            adfdt = adfdt.expand(batch_size, 1) / batch_size
        return out, adfdz, adfdt, adfdp

    def flatten_parameters(self):
        p_shapes = []
        flat_parameters = []
        for p in self.parameters():
            p_shapes.append(p.size())
            flat_parameters.append(p.flatten())
        return torch.cat(flat_parameters)

In [5]:
class ConvODEF(ODEF):
    def __init__(self, dim):
        super(ConvODEF, self).__init__()
        self.conv1 = conv3x3(dim + 1, dim)
        self.norm1 = norm(dim)
        self.conv2 = conv3x3(dim + 1, dim)
        self.norm2 = norm(dim)

    def forward(self, x, t):
        xt = add_time(x, t)
        h = self.norm1(torch.relu(self.conv1(xt)))
        ht = add_time(h, t)
        dxdt = self.norm2(torch.relu(self.conv2(ht)))
        return dxdt
'''
#ODE Original
    def __init__(self, dim):
        super(ConvODEF, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out
        '''

'\n#ODE Original\n    def __init__(self, dim):\n        super(ConvODEF, self).__init__()\n        self.norm1 = norm(dim)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)\n        self.norm2 = norm(dim)\n        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)\n        self.norm3 = norm(dim)\n        self.nfe = 0\n\n    def forward(self, t, x):\n        self.nfe += 1\n        out = self.norm1(x)\n        out = self.relu(out)\n        out = self.conv1(t, out)\n        out = self.norm2(out)\n        out = self.relu(out)\n        out = self.conv2(t, out)\n        out = self.norm3(out)\n        return out\n        '

In [6]:
class ContinuousNeuralMNISTClassifier(nn.Module):
    def __init__(self, ode):
        super(ContinuousNeuralMNISTClassifier, self).__init__()
        self.downsampling = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
        )
        self.feature = ode
        self.norm = norm(64)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        x = self.downsampling(x)
        x = self.feature(x)
        x = self.norm(x)
        x = self.avg_pool(x)
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        x = x.view(-1, shape)
        out = self.fc(x)
        return out

In [7]:
def ode_solve(z0, t0, t1, f):
    """
    Simplest Euler ODE initial value solver
    """
    h_max = 0.05
    n_steps = math.ceil((abs(t1 - t0)/h_max).max().item())

    h = (t1 - t0)/n_steps
    t = t0
    z = z0

    for i_step in range(n_steps):
        z = z + h * f(z, t)
        t = t + h
    return z

In [8]:
class ODEAdjoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, z0, t, flat_parameters, func):
        assert isinstance(func, ODEF)
        bs, *z_shape = z0.size()
        time_len = t.size(0)

        with torch.no_grad():
            z = torch.zeros(time_len, bs, *z_shape).to(z0)
            z[0] = z0
            for i_t in range(time_len - 1):
                z0 = ode_solve(z0, t[i_t], t[i_t+1], func)
                z[i_t+1] = z0

        ctx.func = func
        ctx.save_for_backward(t, z.clone(), flat_parameters)
        return z

    @staticmethod
    def backward(ctx, dLdz):
        """
        dLdz shape: time_len, batch_size, *z_shape
        """
        func = ctx.func
        t, z, flat_parameters = ctx.saved_tensors
        time_len, bs, *z_shape = z.size()
        n_dim = np.prod(z_shape)
        n_params = flat_parameters.size(0)

        # Dynamics of augmented system to be calculated backwards in time
        def augmented_dynamics(aug_z_i, t_i):
            """
            tensors here are temporal slices
            t_i - is tensor with size: bs, 1
            aug_z_i - is tensor with size: bs, n_dim*2 + n_params + 1
            """
            z_i, a = aug_z_i[:, :n_dim], aug_z_i[:, n_dim:2*n_dim]  # ignore parameters and time

            # Unflatten z and a
            z_i = z_i.view(bs, *z_shape)
            a = a.view(bs, *z_shape)
            with torch.set_grad_enabled(True):
                t_i = t_i.detach().requires_grad_(True)
                z_i = z_i.detach().requires_grad_(True)
                func_eval, adfdz, adfdt, adfdp = func.forward_with_grad(z_i, t_i, grad_outputs=a)  # bs, *z_shape
                adfdz = adfdz.to(z_i) if adfdz is not None else torch.zeros(bs, *z_shape).to(z_i)
                adfdp = adfdp.to(z_i) if adfdp is not None else torch.zeros(bs, n_params).to(z_i)
                adfdt = adfdt.to(z_i) if adfdt is not None else torch.zeros(bs, 1).to(z_i)

            # Flatten f and adfdz
            func_eval = func_eval.view(bs, n_dim)
            adfdz = adfdz.view(bs, n_dim)
            return torch.cat((func_eval, -adfdz, -adfdp, -adfdt), dim=1)

        dLdz = dLdz.view(time_len, bs, n_dim)  # flatten dLdz for convenience
        with torch.no_grad():
            ## Create placeholders for output gradients
            # Prev computed backwards adjoints to be adjusted by direct gradients
            adj_z = torch.zeros(bs, n_dim).to(dLdz)
            adj_p = torch.zeros(bs, n_params).to(dLdz)
            # In contrast to z and p we need to return gradients for all times
            adj_t = torch.zeros(time_len, bs, 1).to(dLdz)

            for i_t in range(time_len-1, 0, -1):
                z_i = z[i_t]
                t_i = t[i_t]
                f_i = func(z_i, t_i).view(bs, n_dim)

                # Compute direct gradients
                dLdz_i = dLdz[i_t]
                dLdt_i = torch.bmm(torch.transpose(dLdz_i.unsqueeze(-1), 1, 2), f_i.unsqueeze(-1))[:, 0]

                # Adjusting adjoints with direct gradients
                adj_z += dLdz_i
                adj_t[i_t] = adj_t[i_t] - dLdt_i

                # Pack augmented variable
                aug_z = torch.cat((z_i.view(bs, n_dim), adj_z, torch.zeros(bs, n_params).to(z), adj_t[i_t]), dim=-1)

                # Solve augmented system backwards
                aug_ans = ode_solve(aug_z, t_i, t[i_t-1], augmented_dynamics)

                # Unpack solved backwards augmented system
                adj_z[:] = aug_ans[:, n_dim:2*n_dim]
                adj_p[:] += aug_ans[:, 2*n_dim:2*n_dim + n_params]
                adj_t[i_t-1] = aug_ans[:, 2*n_dim + n_params:]

                del aug_z, aug_ans

            ## Adjust 0 time adjoint with direct gradients
            # Compute direct gradients
            dLdz_0 = dLdz[0]
            dLdt_0 = torch.bmm(torch.transpose(dLdz_0.unsqueeze(-1), 1, 2), f_i.unsqueeze(-1))[:, 0]

            # Adjust adjoints
            adj_z += dLdz_0
            adj_t[0] = adj_t[0] - dLdt_0
        return adj_z.view(bs, *z_shape), adj_t, adj_p, None

In [9]:
class NeuralODE(nn.Module):
    def __init__(self, func):
        super(NeuralODE, self).__init__()
        assert isinstance(func, ODEF)
        self.func = func

    def forward(self, z0, t=Tensor([0., 1.]), return_whole_sequence=False):
        t = t.to(z0)
        z = ODEAdjoint.apply(z0, t, self.func.flatten_parameters(), self.func)
        if return_whole_sequence:
            return z
        else:
            return z[-1]

In [10]:
def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates):
    initial_learning_rate = 0.1 * batch_size / batch_denom

    boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
    vals = [initial_learning_rate * decay for decay in decay_rates]

    def learning_rate_fn(itr):
        lt = [itr < b for b in boundaries] + [True]
        i = np.argmax(lt)
        return vals[i]

    return learning_rate_fn

In [11]:
import torchvision
import torchvision.transforms as transforms

batch_size = 128
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("data/mnist", train=True, download=True,
                             transform=torchvision.transforms.Compose([
                                 torchvision.transforms.ToTensor(),
                                 transforms.RandomCrop(28, padding=4)
                             ])
    ),
    batch_size=batch_size, shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("data/mnist", train=False, download=True,
                             transform=torchvision.transforms.Compose([
                                 torchvision.transforms.ToTensor(),
                             ])
    ),
    batch_size=128, shuffle=True
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.2MB/s]


Extracting data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 491kB/s]


Extracting data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.52MB/s]


Extracting data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 13.0MB/s]

Extracting data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/mnist/MNIST/raw






In [12]:
func = ConvODEF(64)
ode = NeuralODE(func)
model = ContinuousNeuralMNISTClassifier(ode)
if use_cuda:
    model = model.cuda()

batches_per_epoch = len(train_loader)

lr_fn = learning_rate_with_decay(
        batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140],
        decay_rates=[1, 0.1, 0.01, 0.001]
    )

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)  # ODE original
#optimizer = torch.optim.Adam(model.parameters())

In [13]:
def train(epoch):
    num_items = 0
    train_losses = []

    model.train()
    criterion = nn.CrossEntropyLoss()
    print(f"Training Epoch {epoch}...")
    for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
        if use_cuda:
            data = data.cuda()
            target = target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        train_losses += [loss.item()]
        num_items += data.shape[0]
    print('Train loss: {:.5f}'.format(np.mean(train_losses)))
    return train_losses

In [14]:
def test():
    accuracy = 0.0
    num_items = 0

    model.eval()
    criterion = nn.CrossEntropyLoss()
    print(f"Testing...")
    with torch.no_grad():
        for batch_idx, (data, target) in tqdm(enumerate(test_loader),  total=len(test_loader)):
            if use_cuda:
                data = data.cuda()
                target = target.cuda()
            output = model(data)
            accuracy += torch.sum(torch.argmax(output, dim=1) == target).item()
            num_items += data.shape[0]
    accuracy = accuracy * 100 / num_items
    print("Test Accuracy: {:.3f}%".format(accuracy))

In [15]:
n_epochs = 40
test()
train_losses = []
for epoch in range(1, n_epochs + 1):

    for param_group in optimizer.param_groups:
            param_group['lr'] = lr_fn(epoch)

    train_losses += train(epoch)
    test()

Testing...


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch_idx, (data, target) in tqdm(enumerate(test_loader),  total=len(test_loader)):


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 9.800%
Training Epoch 1...


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.19499
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 98.650%
Training Epoch 2...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.05066
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.050%
Training Epoch 3...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.03630
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.090%
Training Epoch 4...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.02965
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.270%
Training Epoch 5...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.02518
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.010%
Training Epoch 6...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.02273
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.130%
Training Epoch 7...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.01966
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.060%
Training Epoch 8...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.01900
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.000%
Training Epoch 9...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.01655
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.290%
Training Epoch 10...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.01519
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.400%
Training Epoch 11...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.01558
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.420%
Training Epoch 12...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.01401
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.530%
Training Epoch 13...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.01335
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.330%
Training Epoch 14...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.01263
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.360%
Training Epoch 15...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.01223
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.300%
Training Epoch 16...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.01159
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.370%
Training Epoch 17...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00994
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.380%
Training Epoch 18...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.01008
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.390%
Training Epoch 19...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00941
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.510%
Training Epoch 20...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00845
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.280%
Training Epoch 21...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00937
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.330%
Training Epoch 22...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00794
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.420%
Training Epoch 23...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00698
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.500%
Training Epoch 24...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00734
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.510%
Training Epoch 25...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00812
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.380%
Training Epoch 26...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00706
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.360%
Training Epoch 27...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00658
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.410%
Training Epoch 28...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00706
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.350%
Training Epoch 29...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00569
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.560%
Training Epoch 30...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00622
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.530%
Training Epoch 31...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00680
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.510%
Training Epoch 32...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00449
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.420%
Training Epoch 33...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00432
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.560%
Training Epoch 34...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00521
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.420%
Training Epoch 35...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00663
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.310%
Training Epoch 36...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00684
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.580%
Training Epoch 37...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00547
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.430%
Training Epoch 38...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00498
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.450%
Training Epoch 39...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00509
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.500%
Training Epoch 40...


  0%|          | 0/469 [00:00<?, ?it/s]

Train loss: 0.00467
Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

Test Accuracy: 99.450%


In [16]:
!pip install torchattacks

Collecting torchattacks
  Downloading torchattacks-3.5.1-py3-none-any.whl.metadata (927 bytes)
Collecting requests~=2.25.1 (from torchattacks)
  Downloading requests-2.25.1-py2.py3-none-any.whl.metadata (4.2 kB)
Collecting chardet<5,>=3.0.2 (from requests~=2.25.1->torchattacks)
  Downloading chardet-4.0.0-py2.py3-none-any.whl.metadata (3.5 kB)
Collecting idna<3,>=2.5 (from requests~=2.25.1->torchattacks)
  Downloading idna-2.10-py2.py3-none-any.whl.metadata (9.1 kB)
Collecting urllib3<1.27,>=1.21.1 (from requests~=2.25.1->torchattacks)
  Downloading urllib3-1.26.20-py2.py3-none-any.whl.metadata (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.1/50.1 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
Downloading torchattacks-3.5.1-py3-none-any.whl (142 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m142.0/142.0 kB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading requests-2.25.1-py2.py3-none-any.whl (61 kB)
[2K   [90m━━━━━━━━━━━━━━━━━

In [17]:
import torch
import torchattacks
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets
import numpy as np
import matplotlib.pyplot as plt

# Função para calcular a acurácia em um conjunto de dados
def accuracy(model, dataset_loader, device):
    total_correct = 0
    total_samples = 0
    for images, labels in dataset_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total_samples += labels.size(0)
        total_correct += (predicted == labels).sum().item()

    return total_correct / total_samples

# Função para testar o ataque adversarial
def test_attacks(model, test_loader, epsilon, device):
    # Instanciar o ataque FGSM com o torchattacks
    attack = torchattacks.FGSM(model, eps=epsilon)

    correct = 0
    total = 0

    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)

        # Realizar o ataque FGSM
        perturbed_images = attack(images, labels)

        # Avaliar o modelo nas imagens perturbadas
        outputs = model(perturbed_images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    accuracy_adv = correct / total
    print(f"Accuracy under FGSM attack with epsilon={epsilon}: {accuracy_adv * 100:.2f}%")

    # Instanciar o ataque FGSM com o torchattacks
    attack = torchattacks.PGD(model, eps=epsilon)

    correct = 0
    total = 0

    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)

        # Realizar o ataque FGSM
        perturbed_images = attack(images, labels)

        # Avaliar o modelo nas imagens perturbadas
        outputs = model(perturbed_images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    accuracy_adv = correct / total
    print(f"Accuracy under PGD attack with epsilon={epsilon}: {accuracy_adv * 100:.2f}%")


# Configurações
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epsilon = 0.3  # Tamanho da perturbação (ajuste conforme necessário)

acc = accuracy(model, test_loader, device)
print(f"Accuracy under PGD attack with epsilon={epsilon}: {acc * 100:.2f}%")

# Testar o ataque adversarial
test_attacks(model, test_loader, epsilon, device)

Accuracy under PGD attack with epsilon=0.3: 99.45%
Accuracy under FGSM attack with epsilon=0.3: 9.75%
Accuracy under PGD attack with epsilon=0.3: 0.89%
