In [1]:

import sys 
sys.path.append('../')

import cProfile
from pstats import Stats, SortKey

import torch
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, Grayscale
from torchvision.datasets import MNIST, CIFAR10 
from torch.utils.data import DataLoader, Subset, TensorDataset
from sklearn.model_selection import StratifiedShuffleSplit

from networks import *
from utils import *
from losses import *
from landscape import *
from datasets import *

import matplotlib.pyplot as pltimport

In [2]:

def run_script(criterion, dims_latent=32, learning_rate=1e-3, batch_size=8,
               device="cuda:0", train_size=0.00108, num_workers=4):
    """Train a freshly built ``Autoencoder`` for one epoch on a small MNIST subset.

    Every call rebuilds the dataset, the stratified split, the data loader,
    the model and the optimizer from scratch, so the cost of that setup is
    part of whatever the caller measures around this function.

    Parameters
    ----------
    criterion : callable
        Loss object handed unchanged to ``train``.
    dims_latent : int, default 32
        Forwarded to ``Autoencoder(dims_latent=...)``.
    learning_rate : float, default 1e-3
        Learning rate of the Adam optimizer.
    batch_size : int, default 8
        Batch size of the training ``DataLoader``.
    device : str or torch.device, default "cuda:0"
        Device the model is moved to; also forwarded to ``train``.
        The default requires a CUDA-capable machine.
    train_size : float, default 0.00108
        Fraction of the MNIST training set kept for training; the remainder
        is assigned to the (unused) other side of the stratified split.
    num_workers : int, default 4
        Worker processes used by the ``DataLoader``.
        NOTE(review): the profiles recorded in this notebook show noticeable
        cumulative time under ``DataLoader.__next__``; ``num_workers=0`` may
        be worth trying for such a small in-memory dataset -- confirm by
        measuring.

    Returns
    -------
    None
        The value returned by ``train`` is discarded.
    """
    # Datasets and dataloaders
    # NOTE(review): samples are read below straight from ``ds.data`` /
    # ``ds.targets`` rather than through ``ds[i]``, so this transform is
    # presumably never applied here -- confirm before relying on it.
    train_transform = Compose([
        Resize(28),
        ToTensor(),
    ])
    ds = MNIST("../notebooks/mnist_example", download=False, train=True, transform=train_transform)

    # A single stratified split with a fixed seed: the training indices are
    # therefore the same on every call. Only the training side is used.
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=1 - train_size, random_state=42)
    train_idx, _ = next(splitter.split(ds.data, ds.targets))

    # Scale the raw pixel values by 1/255 and add a channel axis at dim 1.
    X_train = ds.data[train_idx] / 255.
    y_train = ds.targets[train_idx]
    trainds = TensorDataset(X_train.unsqueeze(1).float(), y_train.float())

    train_loader = DataLoader(trainds, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    # One sample with a leading batch axis; only its channel count
    # (``shape[1]``) is used, to size the model input.
    X = trainds[0][0].unsqueeze(0).to(device)

    # Model and optimizer
    model = Autoencoder(dims_latent=dims_latent, nc=X.shape[1]).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Train one epoch; the result of ``train`` is intentionally ignored.
    train(model, train_loader, optimizer, criterion, device=device)

In [3]:

def run_profiler(criterion, ntimes=10, order="cumtime"):
    """Profile ``ntimes`` consecutive calls of ``run_script(criterion)`` with cProfile.

    Parameters
    ----------
    criterion : callable
        Loss object forwarded unchanged to ``run_script``.
    ntimes : int, default 10
        Number of times ``run_script`` is executed under the profiler.
    order : str or SortKey, default "cumtime"
        Sort key passed to ``Stats.sort_stats``.

    Returns
    -------
    pstats.Stats
        Statistics collected over all runs, sorted by ``order``.

    Notes
    -----
    The progress bar runs inside the profiled region, so its (small) overhead
    is included in the collected statistics.
    """
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        # ``max_value`` is the keyword progressbar2 uses for the bar length;
        # with the previous ``max_val`` spelling the bar rendered as an
        # unknown-length bar.
        with progressbar.ProgressBar(max_value=ntimes) as bar:
            for i in range(ntimes):
                run_script(criterion)
                # Report completed runs (1..ntimes) so the bar ends full
                # instead of stopping at ntimes - 1.
                bar.update(i + 1)
    finally:
        # Always switch the profiler off, even when a run raises, so that
        # profiling does not stay active after a failure.
        profiler.disable()
    return Stats(profiler).sort_stats(order)

In [4]:
# Profile 50 runs of `run_script` using the AWLoss1DRoll criterion.
# NOTE: the output recorded for this cell is a RuntimeError about mismatched
# tensor sizes (1567 vs 784), so no profile was produced for it.
criterion = AWLoss1DRoll(
    reduction="sum",
    std=1e-4,
    store_filters=True,
    alpha=0.02,
)
stats = run_profiler(criterion, ntimes=50)
# Last expression of the cell: prints the table and echoes the Stats object.
stats.print_stats()




RuntimeError: The size of tensor a (1567) must match the size of tensor b (784) at non-singleton dimension 0

In [10]:
# Profile 50 runs of `run_script` using the AWLoss1D criterion.
criterion = AWLoss1D(
    reduction="sum",
    std=1e-4,
    store_filters=True,
    alpha=0.02,
)
stats = run_profiler(criterion, ntimes=50)
# Last expression of the cell: prints the table and echoes the Stats object.
stats.print_stats()

| |                                             #    | 49 Elapsed Time: 0:05:25


         1343057 function calls (1330305 primitive calls) in 331.427 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       50    0.027    0.001  331.372    6.627 <ipython-input-4-187eb2ae3652>:1(run_script)
       50    0.102    0.002  327.518    6.550 ..\utils.py:76(train)
 8800/800    0.101    0.000  242.323    0.303 C:\Program Files (x86)\Anaconda3\lib\site-packages\torch\nn\modules\module.py:715(_call_impl)
      400    2.120    0.005  241.705    0.604 ..\losses.py:68(forward)
     3200  133.407    0.042  133.586    0.042 ..\losses.py:42(make_toeplitz)
     3200  103.777    0.032  103.777    0.032 {built-in method inverse}
      450    0.002    0.000   78.880    0.175 C:\Program Files (x86)\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py:432(__next__)
      450    0.006    0.000   78.878    0.175 C:\Program Files (x86)\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py:1038(_next_data)
      400   

<pstats.Stats at 0x2764563ca08>

In [15]:
# Profile 5 runs of `run_script` using the AWLoss2D criterion.
criterion = AWLoss2D(
    reduction="sum",
    std=1e-4,
    store_filters=True,
    alpha=0.02,
)
stats = run_profiler(criterion, ntimes=5)
# Last expression of the cell: prints the table and echoes the Stats object.
stats.print_stats()

| |             #                                     | 4 Elapsed Time: 0:01:12


         170946 function calls (169669 primitive calls) in 87.247 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        5    0.002    0.000   87.234   17.447 <ipython-input-2-187eb2ae3652>:1(run_script)
        5    0.010    0.002   86.897   17.379 ..\utils.py:76(train)
   880/80    0.011    0.000   78.723    0.984 C:\Program Files (x86)\Anaconda3\lib\site-packages\torch\nn\modules\module.py:715(_call_impl)
       40    0.332    0.008   78.649    1.966 ..\losses.py:164(forward)
      320   47.410    0.148   47.410    0.148 {built-in method inverse}
      320   15.983    0.050   30.637    0.096 ..\losses.py:117(make_doubly_block)
     8960   14.142    0.002   14.611    0.002 ..\losses.py:108(make_toeplitz)
       45    0.000    0.000    7.376    0.164 C:\Program Files (x86)\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py:432(__next__)
       45    0.001    0.000    7.375    0.164 C:\Program Files (x86)\Anacond

<pstats.Stats at 0x155c88b0e08>