In [0]:
import math
import numpy as np
from IPython.display import clear_output
from tqdm import tqdm_notebook as tqdm

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.color_palette("bright")
import matplotlib as mpl
import matplotlib.cm as cm

import torch
from torch import Tensor
from torch import nn
from torch.nn  import functional as F 
from torch.autograd import Variable

use_cuda = torch.cuda.is_available()

In [0]:
def ode_solve(z0, t0, t1, f,absolute):
    """
    Simplest Euler ODE initial value solver
    """
    h_max = 0.05
    n_steps = math.ceil((abs(t1 - t0)/h_max).max().item())

    h = (t1 - t0)/n_steps
    t = t0
    z = z0

    for i_step in range(n_steps):
        if absolute:
          z = z + h * f(z, t).abs()
        else:
          z = z + h * f(z, t)
        t = t + h
    return z

In [0]:
class ODEF(nn.Module):
    def forward_with_grad(self, z, t, grad_outputs):
        """Compute f and a df/dz, a df/dp, a df/dt"""
        batch_size = z.shape[0]

        out = self.forward(z, t)

        a = grad_outputs
        adfdz, adfdt, *adfdp = torch.autograd.grad(
            (out,), (z, t) + tuple(self.parameters()), grad_outputs=(a),
            allow_unused=True, retain_graph=True
        )
        # grad method automatically sums gradients for batch items, we have to expand them back 
        if adfdp is not None:
            adfdp = torch.cat([p_grad.flatten() for p_grad in adfdp]).unsqueeze(0)
            adfdp = adfdp.expand(batch_size, -1) / batch_size
        if adfdt is not None:
            adfdt = adfdt.expand(batch_size, 1) / batch_size
        return out, adfdz, adfdt, adfdp

    def flatten_parameters(self):
        p_shapes = []
        flat_parameters = []
        for p in self.parameters():
            p_shapes.append(p.size())
            flat_parameters.append(p.flatten())
        return torch.cat(flat_parameters)

In [0]:
class ODEAdjoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, z0, t, flat_parameters, func,abs_solve):
        assert isinstance(func, ODEF)
        bs, *z_shape = z0.size()
        time_len = t.size(0)

        with torch.no_grad():
            z = torch.zeros(time_len, bs, *z_shape).to(z0)
            z[0] = z0
            for i_t in range(time_len - 1):
                if abs_solve:
                  z0 = ode_solve(z0, t[i_t], t[i_t+1], func,absolute=True)
                else:
                  z0 = ode_solve(z0, t[i_t], t[i_t+1], func,absolute=False)
                z[i_t+1] = z0

        ctx.func = func
        ctx.abs_solve = abs_solve
        ctx.save_for_backward(t, z.clone(), flat_parameters)
        return z

    @staticmethod
    def backward(ctx, dLdz):
        """
        dLdz shape: time_len, batch_size, *z_shape
        """
        func = ctx.func
        abs_solve = ctx.abs_solve
        t, z, flat_parameters = ctx.saved_tensors
        time_len, bs, *z_shape = z.size()
        n_dim = np.prod(z_shape)
        n_params = flat_parameters.size(0)

        # Dynamics of augmented system to be calculated backwards in time
        def augmented_dynamics(aug_z_i, t_i):
            """
            tensors here are temporal slices
            t_i - is tensor with size: bs, 1
            aug_z_i - is tensor with size: bs, n_dim*2 + n_params + 1
            """
            z_i, a = aug_z_i[:, :n_dim], aug_z_i[:, n_dim:2*n_dim]  # ignore parameters and time

            # Unflatten z and a
            z_i = z_i.view(bs, *z_shape)
            a = a.view(bs, *z_shape)
            with torch.set_grad_enabled(True):
                t_i = t_i.detach().requires_grad_(True)
                z_i = z_i.detach().requires_grad_(True)
                func_eval, adfdz, adfdt, adfdp = func.forward_with_grad(z_i, t_i, grad_outputs=a)  # bs, *z_shape
                adfdz = adfdz.to(z_i) if adfdz is not None else torch.zeros(bs, *z_shape).to(z_i)
                adfdp = adfdp.to(z_i) if adfdp is not None else torch.zeros(bs, n_params).to(z_i)
                adfdt = adfdt.to(z_i) if adfdt is not None else torch.zeros(bs, 1).to(z_i)

            # Flatten f and adfdz
            func_eval = func_eval.view(bs, n_dim)
            adfdz = adfdz.view(bs, n_dim) 
            return torch.cat((func_eval, -adfdz, -adfdp, -adfdt), dim=1)

        dLdz = dLdz.view(time_len, bs, n_dim)  # flatten dLdz for convenience
        with torch.no_grad():
            ## Create placeholders for output gradients
            # Prev computed backwards adjoints to be adjusted by direct gradients
            adj_z = torch.zeros(bs, n_dim).to(dLdz)
            adj_p = torch.zeros(bs, n_params).to(dLdz)
            # In contrast to z and p we need to return gradients for all times
            adj_t = torch.zeros(time_len, bs, 1).to(dLdz)

            for i_t in range(time_len-1, 0, -1):
                z_i = z[i_t]
                t_i = t[i_t]
                f_i = func(z_i, t_i).view(bs, n_dim)

                # Compute direct gradients
                dLdz_i = dLdz[i_t]
                dLdt_i = torch.bmm(torch.transpose(dLdz_i.unsqueeze(-1), 1, 2), f_i.unsqueeze(-1))[:, 0]

                # Adjusting adjoints with direct gradients
                adj_z += dLdz_i
                adj_t[i_t] = adj_t[i_t] - dLdt_i

                # Pack augmented variable
                aug_z = torch.cat((z_i.view(bs, n_dim), adj_z, torch.zeros(bs, n_params).to(z), adj_t[i_t]), dim=-1)

                # Solve augmented system backwards
                
                if abs_solve:
                  aug_ans = ode_solve(aug_z, t_i, t[i_t-1], augmented_dynamics,absolute=True)
                else:
                  aug_ans = ode_solve(aug_z, t_i, t[i_t-1], augmented_dynamics,absolute=False)

                # Unpack solved backwards augmented system
                adj_z[:] = aug_ans[:, n_dim:2*n_dim]
                adj_p[:] += aug_ans[:, 2*n_dim:2*n_dim + n_params]
                adj_t[i_t-1] = aug_ans[:, 2*n_dim + n_params:]

                del aug_z, aug_ans

            ## Adjust 0 time adjoint with direct gradients
            # Compute direct gradients 
            dLdz_0 = dLdz[0]
            dLdt_0 = torch.bmm(torch.transpose(dLdz_0.unsqueeze(-1), 1, 2), f_i.unsqueeze(-1))[:, 0]

            # Adjust adjoints
            adj_z += dLdz_0
            adj_t[0] = adj_t[0] - dLdt_0
        return adj_z.view(bs, *z_shape), adj_t, adj_p, None , None

In [0]:
class NeuralODE(nn.Module):
    def __init__(self, func,neural_abs):
        super(NeuralODE, self).__init__()
        assert isinstance(func, ODEF)
        self.func = func
        self.neural_abs = neural_abs

    def forward(self, z0, t1=Tensor([0., 1.]),t2=Tensor([1., 2.]), return_whole_sequence=False):
        t1 = t1.to(z0)
        t2 = t2.to(z0)
        if self.neural_abs:
          z = ODEAdjoint.apply(z0, t2, self.func.flatten_parameters(), self.func,True)
        else:
          z = ODEAdjoint.apply(z0, t1, self.func.flatten_parameters(), self.func,False)
        z_abs= ODEAdjoint.apply(z0, t2, self.func.flatten_parameters(), self.func,True)
        if return_whole_sequence:
            return z,z_abs
        else:
            return z[-1],z_abs[-1]

In [0]:
class ConvODEF(ODEF):
    def __init__(self, dim):
        super(ConvODEF, self).__init__()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.norm1 = nn.BatchNorm2d(dim)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.norm2 = nn.BatchNorm2d(dim)

    def forward(self, x, t):
        #xt = add_time(x, t)
        h = self.norm1(torch.relu(self.conv1(x)))
        #ht = add_time(h, t)
        dxdt = self.norm2(torch.relu(self.conv2(h)))
        return dxdt

In [0]:
class ContinuousNeuralMNISTClassifier(nn.Module):
    def __init__(self, ode):
        super(ContinuousNeuralMNISTClassifier, self).__init__()
        self.downsampling = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
        )
        self.feature = ode
        #self.feature_tis = ode_tis
        self.norm = nn.BatchNorm2d(64)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        x = self.downsampling(x)
        x ,x1 = self.feature(x)
        output2=x
        x1 ,output1 = self.feature(x)
        #print(output1.shape)
        #output2 = self.feature_tis(output1)
        #print(output1.shape)
        x = self.norm(x)
        x = self.avg_pool(x)
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        x = x.view(-1, shape)
        out = self.fc(x)

        
        out1 = output1-output2
        return out,out1

In [0]:
func = ConvODEF(64)
ode = NeuralODE(func, neural_abs=False)
#ode_tis= NeuralODE(func,neural_abs=True)
model = ContinuousNeuralMNISTClassifier(ode)
if use_cuda:
    model = model.cuda()

In [0]:
import torchvision

img_std = 0.3081
img_mean = 0.1307


batch_size = 32
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("data/mnist", train=True, download=True,
                             transform=torchvision.transforms.Compose([
                                 torchvision.transforms.ToTensor(),
                                 torchvision.transforms.Normalize((img_mean,), (img_std,))
                             ])
    ),
    batch_size=batch_size, shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("data/mnist", train=False, download=True,
                             transform=torchvision.transforms.Compose([
                                 torchvision.transforms.ToTensor(),
                                 torchvision.transforms.Normalize((img_mean,), (img_std,))
                             ])
    ),
    batch_size=128, shuffle=True
)

In [0]:
optimizer = torch.optim.Adam(model.parameters())

In [0]:
def calc_loss_new(output, target,output1 ):
    criterion = nn.CrossEntropyLoss()
    
    loss1= criterion(output, target)
    
    
    #loss2=((output1.norm(dim=-1)).norm(dim=-1)).norm(dim=-1).cuda()
    #print(torch.min(output1),torch.max(output1))
    b= output1.view(output1.shape[0], -1).cuda()
    loss2_1 = torch.norm(b, p=2, dim=1).cuda()
    
    loss2=torch.mean(loss2_1)
    
    loss = loss1 + loss2*0.1
    loss2.detach()
    print(loss1,loss2, 'output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2' )
    return loss

In [0]:
def train(epoch):
    num_items = 0
    train_losses = []

    model.train()
    criterion = nn.CrossEntropyLoss()
    print(f"Training Epoch {epoch}...")
    for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
        if use_cuda:
            data = data.cuda()
            target = target.cuda()
        optimizer.zero_grad()
        output,output1 = model(data)
        #output = model(data)
        loss = calc_loss_new(output, target,output1 )
        #loss = calc_loss_new(output, target )
        loss.backward()
        optimizer.step()

        train_losses += [loss.item()]
        num_items += data.shape[0]
    print('Train loss: {:.5f}'.format(np.mean(train_losses)))
    return train_losses

In [0]:
def test():
    accuracy = 0.0
    num_items = 0

    model.eval()
    criterion = nn.CrossEntropyLoss()
    print(f"Testing...")
    with torch.no_grad():
        for batch_idx, (data, target) in tqdm(enumerate(test_loader),  total=len(test_loader)):
            if use_cuda:
                data = data.cuda()
                target = target.cuda()
            output,output1 = model(data)
            #output = model(data)
            accuracy += torch.sum(torch.argmax(output, dim=1) == target).item()
            num_items += data.shape[0]
    accuracy = accuracy * 100 / num_items
    print("Test Accuracy: {:.3f}%".format(accuracy))

In [118]:
n_epochs = 5

train_losses = []
for epoch in range(1, n_epochs + 1):
    train_losses += train(epoch)
    test()

Training Epoch 1...


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

tensor(2.3050, device='cuda:0', grad_fn=<NllLossBackward>) tensor(46.9889, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(2.2839, device='cuda:0', grad_fn=<NllLossBackward>) tensor(46.8530, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(2.3210, device='cuda:0', grad_fn=<NllLossBackward>) tensor(45.8612, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(2.3484, device='cuda:0', grad_fn=<NllLossBackward>) tensor(45.5561, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(2.2390, device='cuda:0', grad_fn=<NllLossBackward>) tensor(44.0038, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(2.4284, device='cuda:0', grad_fn=<NllLossBackward>) tensor(42.3307, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shap

HBox(children=(IntProgress(value=0, max=79), HTML(value='')))

Test Accuracy: 31.750%
Training Epoch 2...


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

tensor(1.7668, device='cuda:0', grad_fn=<NllLossBackward>) tensor(160.9199, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.8626, device='cuda:0', grad_fn=<NllLossBackward>) tensor(162.8104, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.6951, device='cuda:0', grad_fn=<NllLossBackward>) tensor(161.9644, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.5973, device='cuda:0', grad_fn=<NllLossBackward>) tensor(163.1819, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.9252, device='cuda:0', grad_fn=<NllLossBackward>) tensor(165.1269, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.9097, device='cuda:0', grad_fn=<NllLossBackward>) tensor(163.6934, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_

HBox(children=(IntProgress(value=0, max=79), HTML(value='')))

Test Accuracy: 34.040%
Training Epoch 3...


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

tensor(1.6441, device='cuda:0', grad_fn=<NllLossBackward>) tensor(295.0612, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.5786, device='cuda:0', grad_fn=<NllLossBackward>) tensor(294.8637, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.8884, device='cuda:0', grad_fn=<NllLossBackward>) tensor(294.6985, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.8141, device='cuda:0', grad_fn=<NllLossBackward>) tensor(295.0254, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.6749, device='cuda:0', grad_fn=<NllLossBackward>) tensor(292.7831, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.7140, device='cuda:0', grad_fn=<NllLossBackward>) tensor(296.5703, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_

HBox(children=(IntProgress(value=0, max=79), HTML(value='')))

Test Accuracy: 25.350%
Training Epoch 4...


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

tensor(1.3466, device='cuda:0', grad_fn=<NllLossBackward>) tensor(415.4891, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.8273, device='cuda:0', grad_fn=<NllLossBackward>) tensor(419.7753, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.5233, device='cuda:0', grad_fn=<NllLossBackward>) tensor(419.0929, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.9534, device='cuda:0', grad_fn=<NllLossBackward>) tensor(418.6275, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.5865, device='cuda:0', grad_fn=<NllLossBackward>) tensor(418.9594, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(2.0047, device='cuda:0', grad_fn=<NllLossBackward>) tensor(419.4115, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_

HBox(children=(IntProgress(value=0, max=79), HTML(value='')))

Test Accuracy: 13.460%
Training Epoch 5...


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

tensor(1.7756, device='cuda:0', grad_fn=<NllLossBackward>) tensor(546.0722, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.7993, device='cuda:0', grad_fn=<NllLossBackward>) tensor(545.2141, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.8769, device='cuda:0', grad_fn=<NllLossBackward>) tensor(545.0205, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.7797, device='cuda:0', grad_fn=<NllLossBackward>) tensor(544.8263, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.7302, device='cuda:0', grad_fn=<NllLossBackward>) tensor(541.9186, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.6115, device='cuda:0', grad_fn=<NllLossBackward>) tensor(541.2341, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_

HBox(children=(IntProgress(value=0, max=79), HTML(value='')))

Test Accuracy: 27.960%


In [101]:
n_epochs = 5

train_losses = []
for epoch in range(1, n_epochs + 1):
    train_losses += train(epoch)
    test()

Training Epoch 1...


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

tensor(2.2851, device='cuda:0', grad_fn=<NllLossBackward>) tensor(85.4905, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(2.3223, device='cuda:0', grad_fn=<NllLossBackward>) tensor(81.2726, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(2.2923, device='cuda:0', grad_fn=<NllLossBackward>) tensor(77.3955, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(2.2953, device='cuda:0', grad_fn=<NllLossBackward>) tensor(72.1213, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(2.3316, device='cuda:0', grad_fn=<NllLossBackward>) tensor(72.6233, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(2.3132, device='cuda:0', grad_fn=<NllLossBackward>) tensor(74.6322, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shap

HBox(children=(IntProgress(value=0, max=79), HTML(value='')))

Test Accuracy: 33.930%
Training Epoch 2...


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

tensor(1.8781, device='cuda:0', grad_fn=<NllLossBackward>) tensor(211.6924, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.8074, device='cuda:0', grad_fn=<NllLossBackward>) tensor(208.4545, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.6146, device='cuda:0', grad_fn=<NllLossBackward>) tensor(209.5448, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.5012, device='cuda:0', grad_fn=<NllLossBackward>) tensor(202.5643, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.9065, device='cuda:0', grad_fn=<NllLossBackward>) tensor(204.9729, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.7160, device='cuda:0', grad_fn=<NllLossBackward>) tensor(207.4841, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_

HBox(children=(IntProgress(value=0, max=79), HTML(value='')))

Test Accuracy: 27.050%
Training Epoch 3...


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

tensor(1.6019, device='cuda:0', grad_fn=<NllLossBackward>) tensor(351.9213, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.4208, device='cuda:0', grad_fn=<NllLossBackward>) tensor(351.8893, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.6414, device='cuda:0', grad_fn=<NllLossBackward>) tensor(350.9954, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.6702, device='cuda:0', grad_fn=<NllLossBackward>) tensor(352.1866, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.4986, device='cuda:0', grad_fn=<NllLossBackward>) tensor(352.3868, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.5979, device='cuda:0', grad_fn=<NllLossBackward>) tensor(352.0697, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_

HBox(children=(IntProgress(value=0, max=79), HTML(value='')))

Test Accuracy: 33.070%
Training Epoch 4...


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

tensor(1.7507, device='cuda:0', grad_fn=<NllLossBackward>) tensor(486.2399, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.5387, device='cuda:0', grad_fn=<NllLossBackward>) tensor(485.5952, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.7197, device='cuda:0', grad_fn=<NllLossBackward>) tensor(486.5784, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.4963, device='cuda:0', grad_fn=<NllLossBackward>) tensor(485.9507, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.4914, device='cuda:0', grad_fn=<NllLossBackward>) tensor(487.3737, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_1.shape,loss1,loss2
tensor(1.5257, device='cuda:0', grad_fn=<NllLossBackward>) tensor(487.4764, device='cuda:0', grad_fn=<MeanBackward0>) output1.shape[0]),b.shape,loss2_

KeyboardInterrupt: ignored

In [0]:
test()

In [0]:
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean
        
    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
img_std = 0.3081
img_mean = 0.1307
test_loader_guassian = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("data/mnist", train=False, download=True,
                             transform=torchvision.transforms.Compose([
                                 torchvision.transforms.ToTensor(),AddGaussianNoise(0., 0.3),
                                 torchvision.transforms.Normalize((img_mean,), (img_std,))
                             ])
    ),
    batch_size=128, shuffle=True
)

In [0]:
def test_guassian():
    accuracy = 0.0
    num_items = 0

    model.eval()
    criterion = nn.CrossEntropyLoss()
    print(f"Testing...")
    with torch.no_grad():
        for batch_idx, (data, target) in tqdm(enumerate(test_loader_guassian),  total=len(test_loader_guassian)):
            if use_cuda:
                data = data.cuda()
                target = target.cuda()
            output,output1 = model(data)
            accuracy += torch.sum(torch.argmax(output, dim=1) == target).item()
            num_items += data.shape[0]
    accuracy = accuracy * 100 / num_items
    print("Test Accuracy: {:.3f}%".format(accuracy))

In [125]:
test_guassian()

Testing...


HBox(children=(IntProgress(value=0, max=79), HTML(value='')))

Test Accuracy: 19.680%
