In [1]:
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [2]:
cd '/content/drive/My Drive/BAI'

/content/drive/My Drive/BAI


In [3]:
import os
os.getcwd()

'/content/drive/My Drive/BAI'

In [0]:
import os
import argparse
import logging
import time
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm_notebook as tqdm


In [0]:
parser = argparse.ArgumentParser()
parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')
parser.add_argument('--tol', type=float, default=1e-3)
parser.add_argument('--adjoint', type=eval, default=True, choices=[True, False])
parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])
parser.add_argument('--nepochs', type=int, default=50)
parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--test_batch_size', type=int, default=1000)

parser.add_argument('--save', type=str, default='/content/drive/My Drive/BAI/experiment2') #change1
parser.add_argument('--debug', action='store_true')
parser.add_argument('--gpu', type=int, default=0)
args = parser.parse_args(args=[])
from torchdiffeq import odeint_adjoint
from torchdiffeq import odeint


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def norm(dim):
    return nn.GroupNorm(min(32, dim), dim)


class ResBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.norm1 = norm(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = norm(planes)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        shortcut = x

        out = self.relu(self.norm1(x))

        if self.downsample is not None:
            shortcut = self.downsample(out)

        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + shortcut


class ConcatConv2d(nn.Module):

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )

    def forward(self, t, x):
        tt = torch.ones_like(x[:, :1, :, :]) * t
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)


class ODEfunc(nn.Module):

    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        # dim is 64
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out



class ODEBlock(nn.Module):

    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        out = odeint_adjoint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol,method = 'euler')
        return out[1]

    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value


class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


class RunningAverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val




def inf_generator(iterable):
    """Allows training with DataLoaders in a single infinite loop:
        for i, (x, y) in enumerate(inf_generator(train_loader)):
    """
    iterator = iterable.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = iterable.__iter__()


def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates):
    initial_learning_rate = args.lr * batch_size / batch_denom

    boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
    vals = [initial_learning_rate * decay for decay in decay_rates]

    def learning_rate_fn(itr):
        lt = [itr < b for b in boundaries] + [True]
        i = np.argmax(lt)
        return vals[i]

    return learning_rate_fn


def one_hot(x, K):
    return np.array(x[:, None] == np.arange(K)[None, :], dtype=int)


def accuracy(model, dataset_loader):
    total_correct = 0
    for x, y in dataset_loader:
        x = x.to(device)
        y = one_hot(np.array(y.numpy()), 10)

        target_class = np.argmax(y, axis=1)
        predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)
        total_correct += np.sum(predicted_class == target_class)
    return total_correct / len(dataset_loader.dataset)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def makedirs(dirname):
    if not os.path.exists(dirname):
        os.makedirs(dirname)


def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False):
    logger = logging.getLogger()
    if debug:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logger.setLevel(level)
    if saving:
        info_file_handler = logging.FileHandler(logpath, mode="a")
        info_file_handler.setLevel(level)
        logger.addHandler(info_file_handler)
    if displaying:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level)
        logger.addHandler(console_handler)
    logger.info(filepath)
    with open(filepath, "r") as f:
        logger.info(f.read())

    for f in package_files:
        logger.info(f)
        with open(f, "r") as package_f:
            logger.info(package_f.read())

    return logger



In [0]:
import torchvision

img_std = 0.3081
img_mean = 0.1307


batch_size = 128
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("data/mnist", train=True, download=True,
                             transform=torchvision.transforms.Compose([
                                 torchvision.transforms.ToTensor(),
                                 torchvision.transforms.Normalize((img_mean,), (img_std,))
                             ])
    ),
    batch_size=batch_size, shuffle=True
)

train_eval_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("data/mnist", train=True, download=True,
                             transform=torchvision.transforms.Compose([
                                 torchvision.transforms.ToTensor(),
                                 torchvision.transforms.Normalize((img_mean,), (img_std,))
                             ])
    ),
    batch_size=1000, shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("data/mnist", train=False, download=True,
                             transform=torchvision.transforms.Compose([
                                 torchvision.transforms.ToTensor(),
                                 torchvision.transforms.Normalize((img_mean,), (img_std,))
                             ])
    ),
    batch_size=1000, shuffle=True
)

data_gen = inf_generator(train_loader)
batches_per_epoch = len(train_loader)

In [27]:
class SODEfunc(nn.Module):

    def __init__(self, dim):
        super(SODEfunc, self).__init__()
        # dim is 64
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.norm3(out)
        return out

        
class SODEBlock(nn.Module):

    def __init__(self, odefunc, solver):
        super(SODEBlock, self).__init__()
        self.odefunc = odefunc
        self.solver = solver
        if self.solver== 'euler':
            self.integration_time = torch.tensor([0,1]).float()
        
        
    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        out = odeint_adjoint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol, method = self.solver)
        return out[1] 

    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value


if __name__ == '__main__':

    makedirs(args.save)
    file_path= '/content/drive/My Drive/BAI/torchdiffeq_tis_ode.ipynb' 
    logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(file_path))
    logger.info(args)

    device = torch.device('cuda:' + str(args.gpu) )

    is_odenet = args.network == 'odenet'

    if args.downsampling_method == 'conv':
        downsampling_layers = [
            nn.Conv2d(1, 64, 3, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
        ]
    elif args.downsampling_method == 'res':
        downsampling_layers = [
            nn.Conv2d(1, 64, 3, 1),
            ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
            ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
        ]

    feature_layers = [SODEBlock(SODEfunc(64), 'euler')] 
    fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]

    fullmodel = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)

    modelFE = nn.Sequential(*list(fullmodel.children())[0:7]).to(device)
    modelODE = nn.Sequential(*list(fullmodel.children())[7:8]).to(device)
    modelFC = nn.Sequential(*fc_layers).to(device) 

    logger.info(fullmodel)
    logger.info('Number of parameters: {}'.format(count_parameters(fullmodel)))

    criterion = nn.CrossEntropyLoss().to(device)

    

    

    lr_fn = learning_rate_with_decay(
        args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[20, 30, 45],
        decay_rates=[1, 0.1, 0.01, 0.001]
    )

    optimizer = torch.optim.SGD(fullmodel.parameters(), lr=args.lr, momentum=0.9)

    

/content/drive/My Drive/BAI/torchdiffeq_tis_ode.ipynb
/content/drive/My Drive/BAI/torchdiffeq_tis_ode.ipynb
{"nbformat":4,"nbformat_minor":0,"metadata":{"anaconda-cloud":{},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.8.1"},"colab":{"name":"torchdiffeq_tis_ode.ipynb","provenance":[],"collapsed_sections":[],"toc_visible":true},"accelerator":"GPU","widgets":{"application/vnd.jupyter.widget-state+json":{"29201ef340154a63b34c29801c932b77":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","state":{"_view_name":"HBoxView","_dom_classes":[],"_model_name":"HBoxModel","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.5.0","box_style":"","layout":"IPY_MODEL_fb3031a5a7a44faa8e8789da6fa6eaf9","_mod

In [7]:

if __name__ == '__main__':

    makedirs(args.save)
    file_path= '/content/drive/My Drive/BAI/torchdiffeq_ode.ipynb' 
    logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(file_path))
    logger.info(args)

    device = torch.device('cuda:' + str(args.gpu) )

    is_odenet = args.network == 'odenet'

    if args.downsampling_method == 'conv':
        downsampling_layers = [
            nn.Conv2d(1, 64, 3, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
        ]
    elif args.downsampling_method == 'res':
        downsampling_layers = [
            nn.Conv2d(1, 64, 3, 1),
            ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
            ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
        ]

    feature_layers = [ODEBlock(ODEfunc(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]
    #feature_layers = [ResBlock(64, 64) for _ in range(1)]
    fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]

    model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)

    logger.info(model)
    logger.info('Number of parameters: {}'.format(count_parameters(model)))

    criterion = nn.CrossEntropyLoss().to(device)

    

    

    lr_fn = learning_rate_with_decay(
        args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[20, 30, 40],
        decay_rates=[1, 0.1, 0.01, 0.001]
    )

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)

    

/content/drive/My Drive/BAI/torchdiffeq_ode.ipynb
{"nbformat":4,"nbformat_minor":0,"metadata":{"anaconda-cloud":{},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.8.1"},"colab":{"name":"torchdiffeq_ode.ipynb","provenance":[],"collapsed_sections":[],"toc_visible":true},"accelerator":"GPU"},"cells":[{"cell_type":"code","metadata":{"id":"CHbdMT-UL2zG","colab_type":"code","outputId":"cb6dc096-37ab-483c-80df-c6748f867e53","colab":{"base_uri":"https://localhost:8080/","height":107}},"source":["from google.colab import drive\n","drive.mount('/content/drive')"],"execution_count":0,"outputs":[{"output_type":"stream","text":["Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_

In [8]:
    best_acc = 0
    batch_time_meter = RunningAverageMeter()
    f_nfe_meter = RunningAverageMeter()
    b_nfe_meter = RunningAverageMeter()
    end = time.time()
    epoch = 0
    for itr in range(args.nepochs * batches_per_epoch):

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_fn(itr)

        optimizer.zero_grad()
        x, y = data_gen.__next__()
        x = x.to(device)
        y = y.to(device)
        logits = model(x)
        loss = criterion(logits, y)
        #print(loss)
        

        loss.backward()
        optimizer.step()

        
        batch_time_meter.update(time.time() - end)
        
        end = time.time()

        if itr % batches_per_epoch == 0:
            if 1 :
                #model.eval
                #train_acc = accuracy(model, train_eval_loader)
                #val_acc = accuracy(model, test_loader)
                #if val_acc > best_acc:
                    #torch.save({'state_dict': model.state_dict(), 'args': args}, os.path.join(args.save, 'model.pth'))
                    #best_acc = val_acc
                epoch=epoch+1
                for param_group in optimizer.param_groups:
                  print(param_group['lr'],':-lr',  'epoch#:-',epoch )
                print((loss),':-loss')
                '''logger.info(
                    "Epoch {:04d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | "
                    "Train Acc {:.4f} | Test Acc {:.4f}".format(
                        itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,
                        b_nfe_meter.avg, train_acc, val_acc
                    )
                )'''

0.1 :-lr epoch#:- 1
tensor(2.3524, device='cuda:0', grad_fn=<NllLossBackward>) :-loss
0.1 :-lr epoch#:- 2
tensor(0.1960, device='cuda:0', grad_fn=<NllLossBackward>) :-loss
0.1 :-lr epoch#:- 3
tensor(0.0184, device='cuda:0', grad_fn=<NllLossBackward>) :-loss
0.1 :-lr epoch#:- 4
tensor(0.0741, device='cuda:0', grad_fn=<NllLossBackward>) :-loss
0.1 :-lr epoch#:- 5
tensor(0.0116, device='cuda:0', grad_fn=<NllLossBackward>) :-loss
0.1 :-lr epoch#:- 6
tensor(0.0189, device='cuda:0', grad_fn=<NllLossBackward>) :-loss
0.1 :-lr epoch#:- 7
tensor(0.0409, device='cuda:0', grad_fn=<NllLossBackward>) :-loss
0.1 :-lr epoch#:- 8
tensor(0.0167, device='cuda:0', grad_fn=<NllLossBackward>) :-loss
0.1 :-lr epoch#:- 9
tensor(0.0089, device='cuda:0', grad_fn=<NllLossBackward>) :-loss
0.1 :-lr epoch#:- 10
tensor(0.0483, device='cuda:0', grad_fn=<NllLossBackward>) :-loss
0.1 :-lr epoch#:- 11
tensor(0.0486, device='cuda:0', grad_fn=<NllLossBackward>) :-loss
0.1 :-lr epoch#:- 12
tensor(0.0325, device='cuda:0',

In [0]:
torch.save({'state_dict': model.state_dict(), 'args': args}, os.path.join(args.save, 'ode_euler_mnist.pth'))

In [1]:
modeltrained_ode_euler = torch.load('/content/drive/My Drive/BAI/experiment2/ode_euler_mnist.pth')
#modeltrained_ode_tis_euler = torch.load('/content/drive/My Drive/BAI/experiment2/ode_tis_euler.pth')


NameError: ignored

In [30]:
model.load_state_dict(modeltrained_ode_euler['state_dict'])
#fullmodel.load_state_dict(modeltrained_ode_tis_euler['state_dict'])


<All keys matched successfully>

In [32]:
import time
t1=time.time()
with torch.no_grad():
  val_acc_ode = accuracy(model, test_loader)
  #val_acc_ode_tis = accuracy(fullmodel, test_loader)

t2=time.time()

print(val_acc_ode,val_acc_ode_tis)

0.9925 0.9921


In [0]:
for data,target in test_loader:
    data=data
    target=target

In [65]:
def accuracy1(model, dataset_loader,noise):
    total_correct = 0
  
    for x, y in dataset_loader:
        x = x.to(device)
        

        x = x + noise.to(device)
        y = one_hot(np.array(y.numpy()), 10)

        target_class = np.argmax(y, axis=1)
        predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)
        total_correct += np.sum(predicted_class == target_class)
    return total_correct / len(dataset_loader.dataset)

noise=torch.randn(data[0].shape)*0.9
with torch.no_grad():
  val_acc_ode = accuracy1(model, test_loader,noise)
  #val_acc_ode_tis = accuracy1(fullmodel, test_loader,noise)

print(val_acc_ode,val_acc_ode_tis)

0.8873 0.6227
