In [29]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
cd '/content/drive/My Drive/BAI'

/content/drive/My Drive/BAI


In [2]:
import os
os.getcwd()

'/content/drive/My Drive/BAI'

In [0]:
import os
import argparse
import logging
import time
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm_notebook as tqdm


In [0]:
parser = argparse.ArgumentParser()
parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')
parser.add_argument('--tol', type=float, default=1e-3)
parser.add_argument('--adjoint', type=eval, default=True, choices=[True, False])
parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])
parser.add_argument('--nepochs', type=int, default=50)
parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--test_batch_size', type=int, default=1000)

parser.add_argument('--save', type=str, default='./mnist_ode')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--gpu', type=int, default=0)
args = parser.parse_args(args=[])

In [0]:

from torchdiffeq import odeint_adjoint 

from torchdiffeq import odeint


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def norm(dim):
    return nn.GroupNorm(min(32, dim), dim)


class ResBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.norm1 = norm(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = norm(planes)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        shortcut = x

        out = self.relu(self.norm1(x))

        if self.downsample is not None:
            shortcut = self.downsample(out)

        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + shortcut


class ConcatConv2d(nn.Module):

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )

    def forward(self, t, x):
        tt = torch.ones_like(x[:, :1, :, :]) * t
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)


class ODEfunc(nn.Module):

    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out


class ODEBlock(nn.Module):

    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        out = odeint_adjoint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol,method = 'dopri5')
        return out[1]

    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value


class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


class RunningAverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val

class Multipy(object):
    def __init__(self, scale = 255.0):
        self.scale = scale
        
    def __call__(self, tensor):
        return tensor*self.scale
  




In [0]:
def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0):
    img_std = 0.3081
    img_mean = 0.1307
    if data_aug:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            #Multipy(255.0),
            transforms.Normalize((img_mean,), (img_std,))

        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            #Multipy(255.0),
            transforms.Normalize((img_mean,), (img_std,))
        ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        #Multipy(255.0),
        transforms.Normalize((img_mean,), (img_std,))
    ])

    train_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_train), batch_size=batch_size,
        shuffle=True, num_workers=2, drop_last=True
    )

    train_eval_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
    )

    test_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=False, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
    )

    return train_loader, test_loader, train_eval_loader
    
train_loader, test_loader, train_eval_loader = get_mnist_loaders(
        args.data_aug, args.batch_size, args.test_batch_size
    )

In [1]:
import torchvision
import torchvision.transforms as transforms
import torch

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False, num_workers=2)

valset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_test)
train_eval_loader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False, num_workers=2)


classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Files already downloaded and verified


In [0]:



def inf_generator(iterable):
    """Allows training with DataLoaders in a single infinite loop:
        for i, (x, y) in enumerate(inf_generator(train_loader)):
    """
    iterator = iterable.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = iterable.__iter__()


def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates):
    initial_learning_rate = args.lr * batch_size / batch_denom

    boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
    vals = [initial_learning_rate * decay for decay in decay_rates]

    def learning_rate_fn(itr):
        lt = [itr < b for b in boundaries] + [True]
        i = np.argmax(lt)
        return vals[i]

    return learning_rate_fn


def one_hot(x, K):
    return np.array(x[:, None] == np.arange(K)[None, :], dtype=int)


def accuracy(model, dataset_loader):
    total_correct = 0
    for x, y in dataset_loader:
        x = x.to(device)
        y = one_hot(np.array(y.numpy()), 10)

        target_class = np.argmax(y, axis=1)
        predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)
        total_correct += np.sum(predicted_class == target_class)
    return total_correct / len(dataset_loader.dataset)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def makedirs(dirname):
    if not os.path.exists(dirname):
        os.makedirs(dirname)


def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False):
    logger = logging.getLogger()
    if debug:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logger.setLevel(level)
    if saving:
        info_file_handler = logging.FileHandler(logpath, mode="a")
        info_file_handler.setLevel(level)
        logger.addHandler(info_file_handler)
    if displaying:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level)
        logger.addHandler(console_handler)
    logger.info(filepath)
    with open(filepath, "r") as f:
        logger.info(f.read())

    for f in package_files:
        logger.info(f)
        with open(f, "r") as package_f:
            logger.info(package_f.read())

    return logger


In [8]:
if __name__ == '__main__':

    makedirs(args.save)
    file_path= '/content/drive/My Drive/BAI/torchdiffeq_ode.ipynb'
    logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(file_path))
    logger.info(args)

    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    is_odenet = args.network == 'odenet'

    if args.downsampling_method == 'conv':
        downsampling_layers = [
            nn.Conv2d(1, 64, 3, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
        ]
    elif args.downsampling_method == 'res':
        downsampling_layers = [
            nn.Conv2d(1, 64, 3, 1),
            ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
            ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
        ]

    feature_layers = [ODEBlock(ODEfunc(64))]
    fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]

    model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)

    logger.info(model)
    logger.info('Number of parameters: {}'.format(count_parameters(model)))

    criterion = nn.CrossEntropyLoss().to(device)

    

    data_gen = inf_generator(train_loader)
    batches_per_epoch = len(train_loader)

    lr_fn = learning_rate_with_decay(
        args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[25, 35, 45],
        decay_rates=[1, 0.1, 0.01, 0.001]
    )

/content/drive/My Drive/BAI/torchdiffeq_ode.ipynb
IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Namespace(adjoint=True, batch_size=128, data_aug=True, debug=False, downsampling_method='conv', gpu=0, lr=0.1, nepochs=50, network='odenet', save='./mnist_ode', test_batch_size=1000, tol=0.001)
Sequential(
  (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
  (1): GroupNorm(32, 64, eps=1e-05, affine=True)
  (2): ReLU(inplace=True)
  (3): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (4): GroupNorm(32, 64, eps=1e-05, affine=True)
  (5): ReLU(inplace=True)
  (6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (7): ODEBlock(
    (odefunc): ODEfunc(
      (norm1): GroupNor

In [0]:
for data,target in test_loader:
    data=data
    target=target
print(data.shape)
print(torch.max(data))
print(torch.min(data))
print(torch.mean(data))
print(torch.std(data))

torch.Size([1000, 3, 32, 32])
tensor(2.7537)
tensor(-2.4291)
tensor(0.0066)
tensor(1.2581)


In [9]:
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, )

    best_acc = 0
    batch_time_meter = RunningAverageMeter()
    f_nfe_meter = RunningAverageMeter()
    b_nfe_meter = RunningAverageMeter()
    end = time.time()

    for itr in range(args.nepochs * batches_per_epoch):

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_fn(itr)

        optimizer.zero_grad()
        x, y = data_gen.__next__()
        x = x.to(device)
        y = y.to(device)
        logits = model(x)
        loss = criterion(logits, y)

        if is_odenet:
            nfe_forward = feature_layers[0].nfe
            feature_layers[0].nfe = 0

        loss.backward()
        optimizer.step()

        if is_odenet:
            nfe_backward = feature_layers[0].nfe
            feature_layers[0].nfe = 0

        batch_time_meter.update(time.time() - end)
        if is_odenet:
            f_nfe_meter.update(nfe_forward)
            b_nfe_meter.update(nfe_backward)
        end = time.time()

        if itr % batches_per_epoch == 0:
            with torch.no_grad():
                train_acc = accuracy(model, train_eval_loader)
                val_acc = accuracy(model, test_loader)
                if val_acc > best_acc:
                    torch.save( model.state_dict(), os.path.join(args.save, 'model_ode.pth'))
                    best_acc = val_acc
                logger.info(
                    "Epoch {:04d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | "
                    "Train Acc {:.4f} | Test Acc {:.4f}".format(
                        itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,
                        b_nfe_meter.avg, train_acc, val_acc
                    )
                )

Epoch 0000 | Time 0.442 (0.442) | NFE-F 26.0 | NFE-B 39.0 | Train Acc 0.0974 | Test Acc 0.0984
Epoch 0001 | Time 0.320 (0.140) | NFE-F 20.2 | NFE-B 24.1 | Train Acc 0.9708 | Test Acc 0.9707
Epoch 0002 | Time 0.305 (0.143) | NFE-F 20.1 | NFE-B 25.6 | Train Acc 0.9888 | Test Acc 0.9886
Epoch 0003 | Time 0.306 (0.142) | NFE-F 20.1 | NFE-B 25.3 | Train Acc 0.9877 | Test Acc 0.9823
Epoch 0004 | Time 0.329 (0.147) | NFE-F 20.5 | NFE-B 26.5 | Train Acc 0.9913 | Test Acc 0.9893
Epoch 0005 | Time 0.325 (0.147) | NFE-F 20.3 | NFE-B 26.7 | Train Acc 0.9952 | Test Acc 0.9917
Epoch 0006 | Time 0.341 (0.149) | NFE-F 22.7 | NFE-B 26.7 | Train Acc 0.9922 | Test Acc 0.9890
Epoch 0007 | Time 0.314 (0.150) | NFE-F 22.6 | NFE-B 27.0 | Train Acc 0.9966 | Test Acc 0.9930
Epoch 0008 | Time 0.328 (0.153) | NFE-F 24.8 | NFE-B 26.8 | Train Acc 0.9973 | Test Acc 0.9906
Epoch 0009 | Time 0.328 (0.156) | NFE-F 26.2 | NFE-B 27.0 | Train Acc 0.9983 | Test Acc 0.9923
Epoch 0010 | Time 0.325 (0.155) | NFE-F 25.8 | NFE

KeyboardInterrupt: ignored

In [0]:
torch.save(model.state_dict(), '/content/drive/My Drive/BAI/experiment1/model_ode.pth')
torch.save(optimizer.state_dict(), '/content/drive/My Drive/BAI/experiment1/optimizer_ode.pth')

In [0]:
model_ode = model
#continued_optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,weight_decay=0.0005)

network_state_dict = torch.load('/content/drive/My Drive/BAI/ciffar_ode/model_ode.pth')
model_ode.load_state_dict(network_state_dict)




<All keys matched successfully>

In [11]:
import time
t1=time.time()
with torch.no_grad():
  val_acc_ode = accuracy(model, test_loader)
  #val_acc_ode = accuracy(continued_model_ode, test_loader)

t2=time.time()

print(val_acc_ode)

0.9942


In [0]:
import torchvision



img_std1 = 0.3081
img_mean1 = 0.1307

class AddGaussianNoise(object):
    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean
        
    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)



test_loader1 = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("data/mnist", train=False, download=True,
                             transform=torchvision.transforms.Compose([                                                            
                                 torchvision.transforms.ToTensor(),
                                 AddGaussianNoise(0., 0.4),
                                 torchvision.transforms.Normalize((0.1325,), (0.5063,))
                                 
                             ])
    ),
    batch_size=1000, shuffle=True
)

In [39]:
for data,target in test_loader:
    data1=data
    target1=target
print(torch.max(data1))
print(torch.min(data1))
print(torch.mean(data1))
print(torch.std(data1))

tensor(2.8215)
tensor(-0.4242)
tensor(0.0176)
tensor(1.0191)


In [41]:
import time
t1=time.time()
with torch.no_grad():
  val_acc_ode = accuracy(model, test_loader)
  #val_acc_ode = accuracy(continued_model_ode, test_loader)

t2=time.time()

print(val_acc_ode)

0.9942


In [0]:
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean
        
    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


transform_test1 = transforms.Compose([
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.09),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test1)
test_loader1 = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False, num_workers=2)


Files already downloaded and verified


In [0]:
with torch.no_grad():
  acc_cnn = accuracy(model_ode, test_loader1)
print(acc_cnn)

0.6214


In [42]:
class ODEfunc(nn.Module):

    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out


class ODEBlock(nn.Module):

    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        out = odeint_adjoint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol,method = 'dopri5')
        return out[1]

    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')       
feature_layers_ode = [ODEBlock(ODEfunc(64))]
if args.downsampling_method == 'conv':
        downsampling_layers_ode = [
            nn.Conv2d(3, 64, 3, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
        ]
fc_layers_ode = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)] 
model_ode = nn.Sequential(*downsampling_layers_ode, *feature_layers_ode, *fc_layers_ode).to(device)

model_ode_state_dict = torch.load('/content/drive/My Drive/BAI/ciffar_ode/model_ode.pth')
model_ode.load_state_dict(model_ode_state_dict)


<All keys matched successfully>

In [43]:
feature_layers_cnn = [ResBlock(64, 64) for _ in range(1)]
if args.downsampling_method == 'conv':
        downsampling_layers_cnn = [
            nn.Conv2d(3, 64, 3, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
        ]
fc_layers_cnn = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)] 
model_cnn = nn.Sequential(*downsampling_layers_cnn, *feature_layers_cnn, *fc_layers_cnn).to(device)

model_cnn_state_dict = torch.load('/content/drive/My Drive/BAI/ciffar_cnn/model_cnn.pth')
model_cnn.load_state_dict(model_cnn_state_dict)


<All keys matched successfully>

In [44]:
class SODEfunc(nn.Module):

    def __init__(self, dim):
        super(SODEfunc, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.norm3(out)
        return out


class SODEBlock(nn.Module):

    def __init__(self, odefunc):
        super(SODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        out = odeint_adjoint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol,method = 'dopri5')
        return out[1]

    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value


feature_layers_tis_ode = [SODEBlock(SODEfunc(64))]
if args.downsampling_method == 'conv':
        downsampling_layers_tis_ode = [
            nn.Conv2d(3, 64, 3, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
        ]
fc_layers_tis_ode = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)] 
model_ti_ode = nn.Sequential(*downsampling_layers_tis_ode, *feature_layers_tis_ode, *fc_layers_tis_ode).to(device)

model_ti_ode_state_dict = torch.load('/content/drive/My Drive/BAI/ciffar_tis_ode/model_ode.pth')
model_ti_ode.load_state_dict(model_ti_ode_state_dict)


<All keys matched successfully>

In [66]:
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean
        
    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

transform_test = transforms.Compose([
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.08),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
test_loader1 = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False, num_workers=2)


Files already downloaded and verified


In [67]:
with torch.no_grad():
  train_acc_cnn = accuracy(model_cnn, test_loader1)
  train_acc_ode = accuracy(model_ode, test_loader1)
  train_acc_ti_ode = accuracy(model_ti_ode, test_loader1)
print(train_acc_ode,train_acc_cnn,train_acc_ti_ode)

0.667 0.5816 0.6437


In [25]:
for data,target in test_loader1:
    data=data
    target=target
print(torch.max(data))
print(torch.min(data))
print(torch.mean(data))
print(torch.std(data))
print(torch.mean(data[:,0,:,:]),torch.mean(data[:,1,:,:]),torch.mean(data[:,2,:,:]))
print(torch.std(data[:,0,:,:]),torch.std(data[:,1,:,:]),torch.std(data[:,2,:,:]))


tensor(2.1107)
tensor(-1.9759)
tensor(5.4929e-05)
tensor(0.9999)
tensor(0.0001) tensor(-0.0002) tensor(0.0002)
tensor(0.9999) tensor(0.9999) tensor(1.0000)


In [0]:
for data,target in test_loader:
    data=data
    target=target
print(torch.std(data))
print(torch.mean(data))

tensor(0.3140)
tensor(0.1361)


In [0]:
print(torch.max(data))


tensor(1.)
torch.Size([128, 1, 28, 28])
