In [1]:
# Necessary
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split
import torchvision
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torchdiffeq import odeint_adjoint as odeint
# from jupyterthemes import jtplot
# from neural_ode.utils import *
# jtplot.style(theme="chesterish")
 # CONSTANT 
device = "cuda"
EPOCHS=1
BATCH_SIZE=32
IMG_SIZE=(32,32,3)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchdiffeq import odeint_adjoint as odeint

import sys
import os

# sys.path.insert(0,os.path.abspath(__file__))

class ODEBlock(nn.Module):
    def __init__(self, parallel=None):
        super(ODEBlock,self).__init__()
        self.parallel = parallel
        self.conv1 = nn.Conv2d(64+1,64,3,1, padding=1)
        self.norm1 = nn.GroupNorm(32,64)
        self.relu  = nn.ReLU()
        self.conv2 = nn.Conv2d(64+1,64,3,1, padding=1)
        self.norm2 = nn.GroupNorm(32,64)
        
    def forward(self,t,x): 
        tt = torch.ones_like(x[:, :1, :, :]) * t
        out = torch.cat([tt, x], 1)
        out = self.conv1(out)
        out = self.norm1(out)
        out = self.relu(out)
        out = torch.cat([tt, out], 1)
        out = self.conv2(out)
        out = self.norm2(out)
        out = self.relu(out)
        
        return out
     
class ODENet(nn.Module):
    def __init__(self, func, parallel=False, device="cpu"):
        super(ODENet, self).__init__()
        assert isinstance(func, ODEBlock) or isinstance(func.module,ODEBlock), f"argument function is not NeuralODEs model"
        self.fe = nn.Sequential(*[nn.Conv2d(3,64,3,1),
                                  nn.GroupNorm(32,64),
                                  nn.ReLU(),
                                  # nn.Conv2d(64,64,3,1),
                                  # nn.GroupNorm(32,64),
                                  # nn.ReLU(),
                                  nn.Conv2d(64,64,4,2),
                                  nn.GroupNorm(32,64),
                                  #1x64x12x12
                                  nn.ReLU()])
        self.rm = func
        self.fcc = nn.Sequential(*[nn.AdaptiveAvgPool2d(1),
                                   # 1 x 64 x 1 x 1
                                   nn.Flatten(),
                                   nn.Linear(64,10),
                                   nn.Softmax()])
        self.intergrated_time = torch.Tensor([0.,1.]).float().to(device)
        self.parallel = parallel
    def forward(self,x):
        out = self.fe(x)
        self.intergrated_time = self.intergrated_time.to(out.device)
        if self.parallel:
            out = odeint(self.rm.module, out, self.intergrated_time, method="euler",options=dict(step_size=0.1), rtol=1e-3, atol=1e-3)[1]
        else:
            out = odeint(self.rm, out, self.intergrated_time, method="euler",options=dict(step_size=0.1), rtol=1e-3, atol=1e-3)[1]
        
        #out = self.rm(out)
        out = self.fcc(out)
        return out
    def evaluate(self, test_loader):
        correct = 0
        total = 0 
        running_loss = 0
        count = 0        
        with torch.no_grad():
            for batch_id , test_data in enumerate(test_loader,0):
                count += 1
                data, label = test_data
                outputs = self.forward(data)
                _, correct_labels = torch.max(label, 1) 
                _, predicted = torch.max(outputs.data, 1)
                total += label.size(0)
                correct += (predicted == correct_labels).sum().item()
                running_loss += F.torch.nn.functional.binary_cross_entropy_with_logits(
                    outputs.float(), label.float()).item()
        #        print(f"--> Total {total}\n-->batch_id: {batch_id + 1}")
        acc = correct/total
        running_loss /= count 
        return running_loss,acc

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.fe = nn.Sequential(*[
            nn.Conv2d(3,64,3,1),
            nn.GroupNorm(32,64),
            nn.ReLU(),
            # nn.Conv2d(64,64,3,1),
            # nn.GroupNorm(32,64),
            # nn.ReLU(),
            nn.Conv2d(64,64,4,2),
            nn.GroupNorm(32,64),
            nn.ReLU()
             
        ])
        self.rm = nn.Sequential(*[
            nn.Conv2d(64,64,3,1, padding=1),
            nn.GroupNorm(32,64),
            nn.ReLU(), 
            nn.Conv2d(64,64,3,1, padding=1),
            nn.GroupNorm(32,64),
            nn.ReLU(),
        ])
        self.fcc = nn.Sequential(*[
            #nn.Conv2d(64,1,3,1,padding=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64,10),
            nn.Softmax()
        ])
    def forward(self,x):
        out = self.fe(x)
        out = out + self.rm(out)
        out = self.fcc(out)
        return out

        return self.net(x)
    def evaluate(self, test_loader):
        correct = 0
        total = 0 
        running_loss = 0
        count = 0 
        with torch.no_grad():
            for test_data in test_loader:
                count += 1
                data, label = test_data
                outputs = self.forward(data)
                _, correct_labels = torch.max(label, 1) 
                _, predicted = torch.max(outputs.data, 1)
                total += label.size(0)
                correct += (predicted == correct_labels).sum().item()
                running_loss += F.torch.nn.functional.binary_cross_entropy_with_logits(
                    outputs.float(), label.float()).item()
        acc = correct / total
        running_loss /= count
        
        return running_loss,acc


In [3]:
def train_model(model, optimizer, train_loader, val_loader,loss_fn, lr_scheduler=None, epochs=100, parallel=None):
    #print(model.eval())
    model.train()
    print(f"Numbers of parameters in model: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    best_model, best_acc, best_epoch = None, 0, 0
    history = {"loss": [], "acc": [], "val_loss": [], "val_acc": []}
    for epoch_id in tqdm(range(epochs)):
        total = 0
        correct = 0
        running_loss = 0
        print(f"Start epoch number: {epoch_id + 1}")
#        print(next(enumerate(train_loader,0)))
        loads = list(enumerate(train_loader,0))
        for batch_id, data in loads:
#            print("Go here please")
            # get the inputs; data is a list of [inputs, labels]
            #print(f"Start batch number: {batch_id + 1} in epoch number: {epoch_id + 1}")
            inputs, labels = data
#            print(f"This is labels: {labels}\n\n\n\n")
            #print(f"Get data done")
            # zero the parameter gradients
            optimizer.zero_grad()
            #print(f"Reset the optimizer backward, grad to 0") 
            # forward + backward + optimize
            outputs = model(inputs)
            #print(f"forward data through model")
            _, predicted = torch.max(outputs, 1)
            #print(f"Get predicted class")
            _, correct_labels = torch.max(labels, 1)
            #print(f"Get label class")
            #print(labels)
            total += labels.size(0)
            correct += (predicted == correct_labels).sum().item()
            #print("Calculate the number of correct predictions")
            #print(labels.shape, outputs.shape)
            loss = loss_fn(outputs.float(), labels.float())
            loss.backward()
            #print("Backward loss")
            optimizer.step()
            #print("Step")
            running_loss += loss.item() 
            #print("End batch number: {batch_id + 1} in epoch number {epoch_id + 1}")
        #acc = round(correct/total * 1.0, 5)
        if lr_scheduler is not None:
            lr_scheduler.step()
        acc = correct / total

        #print("Accuracy was calculated")
        history["acc"].append(acc)
        history["loss"].append(running_loss)
        if parallel is not None:
            val_loss, val_acc = model.module.evaluate(val_loader)
        else:
            val_loss, val_acc = model.evaluate(val_loader)
        if acc > best_acc:
            best_acc = acc
            best_epoch = epoch_id + 1
            best_model = model
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        running_loss /= len(loads)
        #print(f"Epoch(s) {epoch_id + 1} | loss: {loss} | acc: {acc} | val_loss: {val_loss} | val_acc: {val_acc}")
        # checkpoint = {
        #     'epoch': epoch_id + 1,
        #     'model': model,
        #     'best_epoch': best_epoch,
        #     'optimizer': optimizer.state_dict()
        # }
        # torch.save(checkpoint, "./checkpoints/checkpoint.pt")
        print("Epoch(s) {:04d}/{:04d} | acc: {:.05f} | loss: {:.09f} | val_acc: {:.05f} | val_loss: {:.09f} | Best epochs: {:04d} | Best acc: {:09f}".format(
            epoch_id + 1, epochs, acc, running_loss, val_acc, val_loss, best_epoch, best_acc
            ))

    return history, best_model, best_epoch, best_acc



def main(ds_len, train_ds, valid_ds,model_type = "ode",data_name = "mnist_50",batch_size=32,epochs=100, lr=1e-3,train_num = 0, valid_num = 0, test_num = 0, weight_decay=None, device="cpu", result_dir="./result", model_dir="./model", parallel=None):
    print(f"Number of train: {train_num}\nNumber of validation: {valid_num}")
    #train_set = torch.utils.data.random_split(ds)
    #print(type(train_set))
    #assert isinstance(train_set,torch.utils.data.Dataset)
    #train_ds, _ = torch.utils.data.random_split(train_ds, lengths=[TRAIN_NUM, ds_len - TRAIN_NUM])
    #valid_ds, _ = torch.utils.data.random_split(valid_ds, lengths=[VALID_NUM, ds_len - VALID_NUM])
    print(len(train_ds))
    train_loader = DataLoader(train_ds, shuffle=True, batch_size=batch_size, drop_last=True)
    val_loader  = DataLoader(valid_ds, shuffle=True, batch_size= batch_size * 16, drop_last=True)
    loss_fn = torch.nn.functional.binary_cross_entropy_with_logits
    if parallel is not None:
        if model_type == "ode": 
            ode_func = ODEBlock(parallel=parallel)
            ode_func = nn.DataParallel(ode_func).to(device)
            model = ODENet(ode_func, parallel, device=device)
            model = nn.DataParallel(model).to(device)
#    ode_func = DDP(ODEBlock().to(device), output_device=device)
#    ode_model = DDP(ODENet(ode_func,device=device).to(device),output_device=device)
        elif model_type == "cnn":
#            epochs= int(epochs * 1.5)
            model = Network()
            model = nn.DataParallel(model).to(device)
    else:
        if model_type == "ode": 
            ode_func = ODEBlock().to(device)
            ode_func = ode_func.to(device)
            model = ODENet(ode_func, device=device)
            model = model.to(device)
#    ode_func = DDP(ODEBlock().to(device), output_device=device)
#    ode_model = DDP(ODENet(ode_func,device=device).to(device),output_device=device)
        elif model_type == "cnn":
            #epochs= int(epochs * 1.5)
            model = Network().to(device)
            #model = nn.DataParallel(model).to(device)
        

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    lr_scheduler = None
    if weight_decay is not None:
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=weight_decay, patience=5)
    his, model, epoch, acc = train_model(model, 
                      optimizer, 
                      train_loader,
                      val_loader,
                      lr_scheduler=lr_scheduler,
                      loss_fn=loss_fn, 
                      epochs=epochs,
                      parallel=parallel)
     
    # save_result(his,model_name=model_type,ds_name=data_name, result_dir=result_dir)
    # if not os.path.exists(f"{MODEL_DIR}/{model_type}_origin"):
    #     os.mkdir(f"{MODEL_DIR}/{model_type}_origin")
    # print("Save original data modeling...")
    # torch.save(model.state_dict(), f"{MODEL_DIR}/{model_type}_origin/{data_name}_origin.pt" ) 
    return model


In [4]:
def add_noise(converted_data, sigma = 10,device="cpu"):
    pertubed_data = converted_data + torch.normal(torch.zeros(converted_data.shape),
                                                  torch.ones(converted_data.shape) * sigma).to(device)
    #pertubed_data = torch.tensor(random_noise(converted_data.cpu(), mode='gaussian', mean=0, var=sigma**2, clip=False)).float().to(device)
    return pertubed_data
def preprocess_data(data, shape = (28,28), sigma=None,device="cpu", train=False):
    if not train:
        #assert type(sigma) == type(list()) or type(sigma) == type(None), f"if train=False, the type(sigma) must be return a list object or NoneType object, but return {type(sigma)}"
        X = []
        Y = []
        ds = {}
        sigma_noise = [50.,75.,100.]
        for data_idx, (x,y) in list(enumerate(data)):
            #X.append(np.array(x).reshape((3,shape[0],shape[0])))
            X.append(np.array(x).transpose(2,0,1)) # Change the shape from (H,W,C) -> (C,H,W)
            Y.append(y)
        y_data = F.one_hot(torch.Tensor(Y).to(torch.int64), num_classes=10)
        y_data = y_data.to(device)
        x_data = torch.Tensor(X)
        x_data = x_data.to(device)
        if sigma:
            x_noise_data = add_noise(x_data, sigma=sigma, device=device) / 255.0
            print(f"Generating {sigma}-pertubed-dataset")
        else:
            x_noise_data = x_data
            print(f"Generating {sigma}-pertubed-dataset")

        pertubed_ds = TensorDataset(x_noise_data,y_data)
        #ds.update({"original": TensorDataset(x_data / 255.0, y_data)})
        ds_len = len(Y)
        return ds_len, pertubed_ds
    else:
        import random
        X = []
        Y = []
        for data_idx, (x, y) in list(enumerate(data)):
            std = random.choice(sigma)
            noise_x = (np.array(x) + np.random.normal(np.zeros_like(np.array(x)), np.ones_like(np.array(x)) * std))
            X.append(noise_x.transpose(2,0,1))
            Y.append(y)
        y_data = F.one_hot(torch.Tensor(Y).to(torch.int64), num_classes=10)
        y_data = y_data.to(device)
        x_data = torch.Tensor(X)
        x_data = x_data.to(device)
        return len(Y), TensorDataset(x_data,y_data) 


In [5]:
BATCH_SIZE = 128
EPOCHS = 100
TRAIN_NUM = VALID_NUM = TEST_NUM = 0
MNIST = torchvision.datasets.CIFAR10('./data',
                                   train=True,
                                   transform=None,
                                   target_transform=None, download=True)

ds_len_, ds_ = preprocess_data(MNIST, sigma=None, device=device)
ds_len_, pertubed_ds_ = preprocess_data(MNIST, sigma=[25.0,30.0,40.0], device=device, train=True)
print(type(ds_))
    
sigma = [None, 1e-7, 50.0, 75.0, 100.0]
loaders = [(key,DataLoader(preprocess_data(MNIST, sigma=key, device=device, train=False)[1], batch_size=12000)) for key in sigma]


Files already downloaded and verified
Generating None-pertubed-dataset
<class 'torch.utils.data.dataset.TensorDataset'>
Generating None-pertubed-dataset
Generating 1e-07-pertubed-dataset
Generating 50.0-pertubed-dataset
Generating 75.0-pertubed-dataset
Generating 100.0-pertubed-dataset


In [6]:
evaluation = {
    "ode": {
        
    },
    "cnn": {

    }
}
for k in sigma:
    evaluation["ode"].update({k: []})
    evaluation["cnn"].update({k: []})


In [7]:
print(device)
EPOCHS = 100

cuda


In [8]:
for i in range(5):
    cnn_model = main(ds_len_,ds_, pertubed_ds_, device=device, model_type="cnn", data_name=f"mnist_origin",batch_size=BATCH_SIZE, epochs=EPOCHS, train_num=TRAIN_NUM, valid_num=VALID_NUM, test_num=TEST_NUM, parallel=None) 
    ode_model = main(ds_len_,ds_, pertubed_ds_, device=device, model_type="ode", data_name=f"mnist_origin",batch_size=BATCH_SIZE, epochs=EPOCHS, train_num=TRAIN_NUM, valid_num=VALID_NUM, test_num=TEST_NUM, parallel=None) 
    for k,l in loaders:
        if isinstance(cnn_model, nn.DataParallel): cnn_model = cnn_model.module
        if isinstance(ode_model, nn.DataParallel): ode_model = ode_model.module
        _, cnn_acc = cnn_model.evaluate(l) 
        _, ode_acc = ode_model.evaluate(l) 
        
        print(f"CNNs for {k}-gaussian-pertubed CIFAR10 = {cnn_acc}")
        print(f"ODEs for {k}-gaussian-pertubed CIFAR10 = {ode_acc}")
        

        evaluation["ode"][k].append(ode_acc)
        evaluation["cnn"][k].append(cnn_acc)


Number of train: 0
Number of validation: 0
50000
Numbers of parameters in model: 142410


  0%|          | 0/100 [00:00<?, ?it/s]

Start epoch number: 1


  input = module(input)


Epoch(s) 0001/0100 | acc: 0.27664 | loss: 0.724936743 | val_acc: 0.22618 | val_loss: 0.728563348 | Best epochs: 0001 | Best acc: 00.276643
Start epoch number: 2
Epoch(s) 0002/0100 | acc: 0.38704 | loss: 0.716102109 | val_acc: 0.28550 | val_loss: 0.724191554 | Best epochs: 0002 | Best acc: 00.387039
Start epoch number: 3
Epoch(s) 0003/0100 | acc: 0.44343 | loss: 0.710656770 | val_acc: 0.29159 | val_loss: 0.722991653 | Best epochs: 0003 | Best acc: 00.443429
Start epoch number: 4
Epoch(s) 0004/0100 | acc: 0.47316 | loss: 0.707569427 | val_acc: 0.39001 | val_loss: 0.714753484 | Best epochs: 0004 | Best acc: 00.473157
Start epoch number: 5
Epoch(s) 0005/0100 | acc: 0.50040 | loss: 0.704928773 | val_acc: 0.42460 | val_loss: 0.711437561 | Best epochs: 0005 | Best acc: 00.500401
Start epoch number: 6
Epoch(s) 0006/0100 | acc: 0.52310 | loss: 0.702952515 | val_acc: 0.36289 | val_loss: 0.717126384 | Best epochs: 0006 | Best acc: 00.523097
Start epoch number: 7
Epoch(s) 0007/0100 | acc: 0.54419 

  0%|          | 0/100 [00:00<?, ?it/s]

Start epoch number: 1
Epoch(s) 0001/0100 | acc: 0.26492 | loss: 0.725372621 | val_acc: 0.23433 | val_loss: 0.726754526 | Best epochs: 0001 | Best acc: 00.264924
Start epoch number: 2
Epoch(s) 0002/0100 | acc: 0.35024 | loss: 0.718871609 | val_acc: 0.30302 | val_loss: 0.722948102 | Best epochs: 0002 | Best acc: 00.350240
Start epoch number: 3
Epoch(s) 0003/0100 | acc: 0.39609 | loss: 0.714957931 | val_acc: 0.22278 | val_loss: 0.727819997 | Best epochs: 0003 | Best acc: 00.396094
Start epoch number: 4
Epoch(s) 0004/0100 | acc: 0.44002 | loss: 0.710665456 | val_acc: 0.39701 | val_loss: 0.715234376 | Best epochs: 0004 | Best acc: 00.440024
Start epoch number: 5
Epoch(s) 0005/0100 | acc: 0.46264 | loss: 0.708463469 | val_acc: 0.36092 | val_loss: 0.717373741 | Best epochs: 0005 | Best acc: 00.462640
Start epoch number: 6
Epoch(s) 0006/0100 | acc: 0.48446 | loss: 0.706133474 | val_acc: 0.39823 | val_loss: 0.714412173 | Best epochs: 0006 | Best acc: 00.484455
Start epoch number: 7
Epoch(s) 000

KeyboardInterrupt: 

In [8]:
for i in range(1):
    cnn_model = main(ds_len_,ds_, ds_, device=device, model_type="cnn", data_name=f"mnist_origin",batch_size=BATCH_SIZE, epochs=EPOCHS, train_num=TRAIN_NUM, valid_num=VALID_NUM, test_num=TEST_NUM, parallel=None) 
    ode_model = main(ds_len_,ds_, ds_, device=device, model_type="ode", data_name=f"mnist_origin",batch_size=BATCH_SIZE, epochs=EPOCHS, train_num=TRAIN_NUM, valid_num=VALID_NUM, test_num=TEST_NUM, parallel=None) 
    for k,l in loaders:
        if isinstance(cnn_model, nn.DataParallel): cnn_model = cnn_model.module
        if isinstance(ode_model, nn.DataParallel): ode_model = ode_model.module
        _, cnn_acc = cnn_model.evaluate(l) 
        _, ode_acc = ode_model.evaluate(l) 
        
        print(f"CNNs for {k}-gaussian-pertubed CIFAR10 = {cnn_acc}")
        print(f"ODEs for {k}-gaussian-pertubed CIFAR10 = {ode_acc}")
        

        evaluation["ode"][k].append(ode_acc)
        evaluation["cnn"][k].append(cnn_acc)


Number of train: 0
Number of validation: 0
50000
Numbers of parameters in model: 142410


  0%|          | 0/100 [00:00<?, ?it/s]

Start epoch number: 1
Epoch(s) 0001/0100 | acc: 0.27839 | loss: 0.724676020 | val_acc: 0.30180 | val_loss: 0.723055390 | Best epochs: 0001 | Best acc: 00.278385
Start epoch number: 2
Epoch(s) 0002/0100 | acc: 0.38566 | loss: 0.716246858 | val_acc: 0.42969 | val_loss: 0.712295259 | Best epochs: 0002 | Best acc: 00.385657
Start epoch number: 3
Epoch(s) 0003/0100 | acc: 0.44507 | loss: 0.710588137 | val_acc: 0.49654 | val_loss: 0.706510117 | Best epochs: 0003 | Best acc: 00.445072
Start epoch number: 4
Epoch(s) 0004/0100 | acc: 0.47394 | loss: 0.707569284 | val_acc: 0.52057 | val_loss: 0.704074000 | Best epochs: 0004 | Best acc: 00.473938
Start epoch number: 5
Epoch(s) 0005/0100 | acc: 0.50541 | loss: 0.704674325 | val_acc: 0.53733 | val_loss: 0.701633309 | Best epochs: 0005 | Best acc: 00.505409
Start epoch number: 6
Epoch(s) 0006/0100 | acc: 0.52007 | loss: 0.703378335 | val_acc: 0.52541 | val_loss: 0.702912383 | Best epochs: 0006 | Best acc: 00.520072
Start epoch number: 7
Epoch(s) 000

  0%|          | 0/100 [00:00<?, ?it/s]

Start epoch number: 1
Epoch(s) 0001/0100 | acc: 0.26148 | loss: 0.725745368 | val_acc: 0.32961 | val_loss: 0.720841157 | Best epochs: 0001 | Best acc: 00.261478
Start epoch number: 2
Epoch(s) 0002/0100 | acc: 0.35284 | loss: 0.718657690 | val_acc: 0.38639 | val_loss: 0.715924082 | Best epochs: 0002 | Best acc: 00.352845
Start epoch number: 3
Epoch(s) 0003/0100 | acc: 0.39762 | loss: 0.714524307 | val_acc: 0.44242 | val_loss: 0.711367168 | Best epochs: 0003 | Best acc: 00.397616
Start epoch number: 4
Epoch(s) 0004/0100 | acc: 0.43215 | loss: 0.711247290 | val_acc: 0.43481 | val_loss: 0.710580808 | Best epochs: 0004 | Best acc: 00.432151
Start epoch number: 5
Epoch(s) 0005/0100 | acc: 0.46637 | loss: 0.708067402 | val_acc: 0.48749 | val_loss: 0.706607848 | Best epochs: 0005 | Best acc: 00.466366
Start epoch number: 6
Epoch(s) 0006/0100 | acc: 0.48608 | loss: 0.706357270 | val_acc: 0.50120 | val_loss: 0.705515613 | Best epochs: 0006 | Best acc: 00.486078
Start epoch number: 7
Epoch(s) 000