# Final project for Computer Vision: WaveMix

Tung Lun (Tony) NGOK

In [1]:
!pip install einops
!pip install torchsummary
!pip install dualopt

Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0
Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl.metadata (296 bytes)
Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1
Collecting dualopt
  Downloading dualopt-0.1.8-py3-none-any.whl.metadata (5.1 kB)
Collecting lion-pytorch (from dualopt)
  Downloading lion_pytorch-0.1.4-py3-none-any.whl.metadata (618 bytes)
Downloading dualopt-0.1.8-py3-none-any.whl (5.0 kB)
Downloading lion_pytorch-0.1.4-py3-none-any.whl (4.3 kB)
Installing collected packages: lion-pytorch, dualopt
Successfully installed dualopt-0.1.8 lion-pytorch-0.1.4


In [2]:
import torch
import torch.nn.functional as F
import numpy as np
from torch.autograd import Function
import torch.nn as nn
import pywt
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

import torchvision
import torchvision.transforms as transforms
from torchsummary import summary

# https://pypi.org/project/dualopt/
import dualopt
from dualopt import classification, post_train

from PIL import Image

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import os, glob

# The cuDNN autotuner flag is `benchmark`; assigning to the misspelled
# `benchmarks` has no effect on cuDNN.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True

In [3]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Device properties can only be queried for CUDA devices.
    print(torch.cuda.get_device_properties(device))
else:
    print("CUDA is not available; running on CPU.")

_CudaDeviceProperties(name='Tesla P100-PCIE-16GB', major=6, minor=0, total_memory=16276MB, multi_processor_count=56)


## Loading the Tiny ImageNet dataset

In [4]:
!wget http://cs231n.stanford.edu/tiny-imagenet-200.zip

--2024-05-05 23:01:21--  http://cs231n.stanford.edu/tiny-imagenet-200.zip
Resolving cs231n.stanford.edu (cs231n.stanford.edu)... 171.64.64.64
Connecting to cs231n.stanford.edu (cs231n.stanford.edu)|171.64.64.64|:80... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://cs231n.stanford.edu/tiny-imagenet-200.zip [following]
--2024-05-05 23:01:21--  https://cs231n.stanford.edu/tiny-imagenet-200.zip
Connecting to cs231n.stanford.edu (cs231n.stanford.edu)|171.64.64.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 248100043 (237M) [application/zip]
Saving to: 'tiny-imagenet-200.zip'


2024-05-05 23:01:33 (20.8 MB/s) - 'tiny-imagenet-200.zip' saved [248100043/248100043]



In [5]:
!unzip -q tiny-imagenet-200.zip
!rm tiny-imagenet-200.zip

In [6]:
# Map every WordNet id listed in wnids.txt to its 0-based line index, which is
# used as the integer class label. The context manager closes the file handle.
with open('/kaggle/working/tiny-imagenet-200/wnids.txt', 'r') as wnid_file:
    id_dict = {line.replace('\n', ''): i for i, line in enumerate(wnid_file)}

In [7]:
# https://github.com/pranavphoenix/TinyImageNetLoader/blob/main/tinyimagenetloader.py
# dataset loader provided by the author

class TrainTinyImageNetDataset(Dataset):
    """Training images of Tiny ImageNet, read from the extracted archive.

    Args:
        id (dict): maps a class-folder name (WordNet id) to its integer label.
        transform (callable, optional): applied to each PIL image before it is
            returned.
        root (str): directory of the extracted dataset. Images are collected
            from ``<root>/train/*/*/*.JPEG``.
    """
    def __init__(self, id, transform=None, root="/kaggle/working/tiny-imagenet-200"):
        # sorted() makes the index -> file mapping independent of the order in
        # which the filesystem happens to list the files.
        self.filenames = sorted(glob.glob(os.path.join(root, "train", "*", "*", "*.JPEG")))
        self.transform = transform
        self.id_dict = id

    def __len__(self):
        """Number of image files found under the training folder."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the idx-th file."""
        img_path = self.filenames[idx]
        # convert() returns a fully loaded RGB copy, so every source mode
        # (not just grayscale "L") ends up with three channels.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')

        # The class folder is the first wildcard level below train/, i.e. two
        # directories above the file; derive it relative to the file instead
        # of counting components from the filesystem root.
        class_folder = os.path.basename(os.path.dirname(os.path.dirname(img_path)))
        label = self.id_dict[class_folder]
        if self.transform:
            image = self.transform(image)
        return image, label

In [8]:
class TestTinyImageNetDataset(Dataset):
    """Validation images of Tiny ImageNet with labels from val_annotations.txt.

    Args:
        id (dict): maps a WordNet id string to its integer label.
        transform (callable, optional): applied to each PIL image before it is
            returned.
        root (str): directory of the extracted dataset. Images are read from
            ``<root>/val/images`` and labels from
            ``<root>/val/val_annotations.txt``.
    """
    def __init__(self, id, transform=None, root="/kaggle/working/tiny-imagenet-200"):
        self.filenames = sorted(glob.glob(os.path.join(root, "val", "images", "*.JPEG")))
        self.transform = transform
        self.id_dict = id
        # file name -> integer label; the first two tab-separated columns of
        # each annotation line are the file name and the class id.
        self.cls_dic = {}
        with open(os.path.join(root, "val", "val_annotations.txt"), 'r') as annotations:
            for line in annotations:
                a = line.split('\t')
                if len(a) < 2:
                    # Blank or malformed line: nothing to map.
                    continue
                img, cls_id = a[0], a[1]
                self.cls_dic[img] = self.id_dict[cls_id]

    def __len__(self):
        """Number of image files found in the validation image folder."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the idx-th file."""
        img_path = self.filenames[idx]
        # convert() returns a fully loaded RGB copy for any source mode.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        label = self.cls_dic[os.path.basename(img_path)]
        if self.transform:
            image = self.transform(image)
        return image, label

In [9]:
# transforms
# Training pipeline with augmentation: random flip + TrivialAugmentWide, then
# tensor conversion and channel-wise normalization.
transform_train_1 = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4803, 0.4481, 0.3975), std=(0.2764, 0.2689, 0.2816)),
])

# Training pipeline without augmentation: tensor conversion and normalization only.
transform_train_2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4803, 0.4481, 0.3975), std=(0.2764, 0.2689, 0.2816)),
])

# Evaluation pipeline; note that it uses its own normalization statistics,
# which differ slightly from the training ones.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4825, 0.4499, 0.3984), std=(0.2764, 0.2691, 0.2825)),
])

In [10]:
# Mini-batch size passed to every DataLoader in this notebook.
batch_size = 304

In [11]:
# dataset
# persistent_workers is a boolean flag, so it is passed as True rather than as
# a (merely truthy) integer.

# Augmented training data.
trainset_1 = TrainTinyImageNetDataset(id=id_dict, transform=transform_train_1)
trainloader_1 = torch.utils.data.DataLoader(trainset_1, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)

# Un-augmented training data.
trainset_2 = TrainTinyImageNetDataset(id=id_dict, transform=transform_train_2)
trainloader_2 = torch.utils.data.DataLoader(trainset_2, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)

# Validation data, iterated in a fixed order.
testset = TestTinyImageNetDataset(id=id_dict, transform = transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)

## Processing the validation images

In [12]:
import cv2
from PIL import Image

In [13]:
# Gaussian blur function
def blur_image(image_array, blur_radius=5):
    """Return a Gaussian-blurred copy of an image array.

    Args:
        image_array: image as accepted by ``cv2.GaussianBlur``.
        blur_radius (int): used as both the width and the height of the
            Gaussian kernel passed to OpenCV.

    Returns:
        The blurred image produced by ``cv2.GaussianBlur``; the third argument
        (sigmaX) is 0.
    """
    kernel_size = (blur_radius, blur_radius)
    return cv2.GaussianBlur(image_array, kernel_size, 0)

In [14]:
# Noise-adding function
def add_gaussian_noise(image_tensor, mean=0., std=0.1):
    """Add Gaussian noise to a tensor and clamp the result to [0, 1].

    Args:
        image_tensor (torch.Tensor): input values; any shape.
        mean (float): mean of the added noise.
        std (float): standard deviation of the added noise.

    Returns:
        torch.Tensor: ``image_tensor + noise`` clamped to the range [0, 1].
    """
    # Sample the noise on the input's device so the addition also works for
    # tensors that do not live on the CPU.
    noise = torch.randn(image_tensor.shape, device=image_tensor.device) * std + mean
    return torch.clamp(image_tensor + noise, 0., 1.)

In [15]:
# process every image
def process_images(input_folder_path, output_folder_blur, output_folder_noise, blur_radius=5, mean=0., std=0.1):
    """Write a blurred and a noisy copy of every ``.JPEG`` file in a folder.

    Args:
        input_folder_path (str): folder containing the source images.
        output_folder_blur (str): folder receiving the blurred copies
            (created if missing).
        output_folder_noise (str): folder receiving the noisy copies
            (created if missing).
        blur_radius (int): forwarded to ``blur_image``.
        mean (float): forwarded to ``add_gaussian_noise``.
        std (float): forwarded to ``add_gaussian_noise``.

    Raises:
        OSError: if OpenCV cannot read one of the ``.JPEG`` files.
    """
    # Make sure both output directories exist.
    os.makedirs(output_folder_blur, exist_ok=True)
    os.makedirs(output_folder_noise, exist_ok=True)

    transform_to_tensor = transforms.ToTensor()
    transform_to_pil = transforms.ToPILImage()

    # Visit the files in a stable order; output files keep the input file name.
    for filename in sorted(os.listdir(input_folder_path)):
        if not filename.endswith(".JPEG"):
            continue
        image_path = os.path.join(input_folder_path, filename)

        # Blurred copy (read and written with OpenCV).
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError("Could not read image: {}".format(image_path))
        blurred_image = blur_image(image, blur_radius)
        cv2.imwrite(os.path.join(output_folder_blur, filename), blurred_image)

        # Noisy copy (read and written with PIL); the tensor conversion loads
        # the pixels, so the file can be closed right afterwards.
        with Image.open(image_path) as image_pil:
            image_tensor = transform_to_tensor(image_pil)
        noisy_image_tensor = add_gaussian_noise(image_tensor, mean, std)
        noisy_image_pil = transform_to_pil(noisy_image_tensor)
        noisy_image_pil.save(os.path.join(output_folder_noise, filename))

In [16]:
# Input and output folders for the corrupted validation sets. The outputs sit
# next to val/ under val_blur/ and val_noise/.
input_folder_path = "/kaggle/working/tiny-imagenet-200/val/images"  # original validation images
output_folder_blur = "/kaggle/working/tiny-imagenet-200/val_blur/images"  # where the blurred copies are written
output_folder_noise = "/kaggle/working/tiny-imagenet-200/val_noise/images"  # where the noisy copies are written

# Uses the default blur_radius, mean and std of process_images.
process_images(input_folder_path, output_folder_blur, output_folder_noise)

In [17]:
class TestBlurredTinyImageNetDataset(Dataset):
    """Blurred validation images with labels from the original annotations.

    Args:
        id (dict): maps a WordNet id string to its integer label.
        transform (callable, optional): applied to each PIL image before it is
            returned.
        root (str): directory of the extracted dataset. Images are read from
            ``<root>/val_blur/images`` and labels from
            ``<root>/val/val_annotations.txt``.
    """
    def __init__(self, id, transform=None, root="/kaggle/working/tiny-imagenet-200"):
        self.filenames = sorted(glob.glob(os.path.join(root, "val_blur", "images", "*.JPEG")))
        self.transform = transform
        self.id_dict = id
        # file name -> integer label; the first two tab-separated columns of
        # each annotation line are the file name and the class id.
        self.cls_dic = {}
        with open(os.path.join(root, "val", "val_annotations.txt"), 'r') as annotations:
            for line in annotations:
                a = line.split('\t')
                if len(a) < 2:
                    # Blank or malformed line: nothing to map.
                    continue
                img, cls_id = a[0], a[1]
                self.cls_dic[img] = self.id_dict[cls_id]

    def __len__(self):
        """Number of image files found in the blurred image folder."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the idx-th file."""
        img_path = self.filenames[idx]
        # convert() returns a fully loaded RGB copy for any source mode.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        label = self.cls_dic[os.path.basename(img_path)]
        if self.transform:
            image = self.transform(image)
        return image, label

In [18]:
# blurred dataset
# Each loader wraps the dataset created on the line above it (previously the
# training loaders were built from trainset_1 / trainset_2 and the new datasets
# went unused). persistent_workers is passed as a proper boolean.

trainset_blur_1 = TrainTinyImageNetDataset(id=id_dict, transform=transform_train_1)
trainloader_blur_1 = torch.utils.data.DataLoader(trainset_blur_1, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)

trainset_blur_2 = TrainTinyImageNetDataset(id=id_dict, transform=transform_train_2)
trainloader_blur_2 = torch.utils.data.DataLoader(trainset_blur_2, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)

# Blurred validation images, iterated in a fixed order.
testset_blur = TestBlurredTinyImageNetDataset(id=id_dict, transform=transform_test)
testloader_blur = torch.utils.data.DataLoader(testset_blur, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)

In [19]:
class TestNoisedTinyImageNetDataset(Dataset):
    """Noisy validation images with labels from the original annotations.

    Args:
        id (dict): maps a WordNet id string to its integer label.
        transform (callable, optional): applied to each PIL image before it is
            returned.
        root (str): directory of the extracted dataset. Images are read from
            ``<root>/val_noise/images`` and labels from
            ``<root>/val/val_annotations.txt``.
    """
    def __init__(self, id, transform=None, root="/kaggle/working/tiny-imagenet-200"):
        self.filenames = sorted(glob.glob(os.path.join(root, "val_noise", "images", "*.JPEG")))
        self.transform = transform
        self.id_dict = id
        # file name -> integer label; the first two tab-separated columns of
        # each annotation line are the file name and the class id.
        self.cls_dic = {}
        with open(os.path.join(root, "val", "val_annotations.txt"), 'r') as annotations:
            for line in annotations:
                a = line.split('\t')
                if len(a) < 2:
                    # Blank or malformed line: nothing to map.
                    continue
                img, cls_id = a[0], a[1]
                self.cls_dic[img] = self.id_dict[cls_id]

    def __len__(self):
        """Number of image files found in the noisy image folder."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the idx-th file."""
        img_path = self.filenames[idx]
        # convert() returns a fully loaded RGB copy for any source mode.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        label = self.cls_dic[os.path.basename(img_path)]
        if self.transform:
            image = self.transform(image)
        return image, label

In [20]:
# noised dataset
# Each loader wraps the dataset created on the line above it (previously the
# training loaders were built from trainset_1 / trainset_2 and the new datasets
# went unused). persistent_workers is passed as a proper boolean.

trainset_noise_1 = TrainTinyImageNetDataset(id=id_dict, transform=transform_train_1)
trainloader_noise_1 = torch.utils.data.DataLoader(trainset_noise_1, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)

trainset_noise_2 = TrainTinyImageNetDataset(id=id_dict, transform=transform_train_2)
trainloader_noise_2 = torch.utils.data.DataLoader(trainset_noise_2, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)

# Noisy validation images, iterated in a fixed order.
testset_noise = TestNoisedTinyImageNetDataset(id=id_dict, transform=transform_test)
testloader_noise = torch.utils.data.DataLoader(testset_noise, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True, prefetch_factor=2, persistent_workers=True)

## The model

In [21]:
def sfb1d(lo, hi, g0, g1, mode='zero', dim=-1):
    """ 1D synthesis filter bank of an image tensor

    Upsamples the two subbands by two along one spatial dimension with a
    grouped transposed convolution and adds the results.

    Inputs:
        lo (tensor): 4D lowpass subband
        hi (tensor): 4D highpass subband
        g0 (tensor or array-like): lowpass synthesis filter. Non-tensors are
            flattened and converted to a float tensor on lo's device.
        g1 (tensor or array-like): highpass synthesis filter, handled like g0
        mode (str): padding method
        dim (int): dimension of filtering; taken modulo 4, so 2/-2 selects the
            second-to-last axis and 3/-1 the last one
    Returns:
        y: the reconstructed tensor
    Raises:
        ValueError: if mode is not a supported padding method
    """
    C = lo.shape[1]
    d = dim % 4
    # If g0, g1 are not tensors, make them. If they are, then assume that they
    # are in the right order
    if not isinstance(g0, torch.Tensor):
        g0 = torch.tensor(np.copy(np.array(g0).ravel()),
                          dtype=torch.float, device=lo.device)
    if not isinstance(g1, torch.Tensor):
        g1 = torch.tensor(np.copy(np.array(g1).ravel()),
                          dtype=torch.float, device=lo.device)
    L = g0.numel()
    shape = [1,1,1,1]
    shape[d] = L
    N = 2*lo.shape[d]
    # If g aren't in the right shape, make them so
    if g0.shape != tuple(shape):
        g0 = g0.reshape(*shape)
    if g1.shape != tuple(shape):
        g1 = g1.reshape(*shape)

    # Stride 2 only along the filtered dimension; one filter copy per channel
    # because the transposed convolution is grouped by channel.
    s = (2, 1) if d == 2 else (1,2)
    g0 = torch.cat([g0]*C,dim=0)
    g1 = torch.cat([g1]*C,dim=0)
    if mode == 'per' or mode == 'periodization':
        y = F.conv_transpose2d(lo, g0, stride=s, groups=C) + \
            F.conv_transpose2d(hi, g1, stride=s, groups=C)
        # Wrap the tail that extends past N samples back onto the start, then
        # keep the first N samples.
        if d == 2:
            y[:,:,:L-2] = y[:,:,:L-2] + y[:,:,N:N+L-2]
            y = y[:,:,:N]
        else:
            y[:,:,:,:L-2] = y[:,:,:,:L-2] + y[:,:,:,N:N+L-2]
            y = y[:,:,:,:N]
        # Circular shift along the filtered dimension. torch.roll is used so
        # this branch does not depend on a separate `roll` helper.
        y = torch.roll(y, 1-L//2, dims=d)
    else:
        if mode == 'zero' or mode == 'symmetric' or mode == 'reflect' or \
                mode == 'periodic':
            pad = (L-2, 0) if d == 2 else (0, L-2)
            y = F.conv_transpose2d(lo, g0, stride=s, padding=pad, groups=C) + \
                F.conv_transpose2d(hi, g1, stride=s, padding=pad, groups=C)
        else:
            raise ValueError("Unkown pad type: {}".format(mode))

    return y

In [22]:
def reflect(x, minx, maxx):
    """Fold the values of *x* back and forth between *minx* and *maxx*.

    A long linearly increasing series therefore becomes a waveform that ramps
    linearly up and down between the two bounds. When *x* holds integers and
    the bounds are (integer + 0.5), the extreme samples of each ramp are
    repeated.

    The result has the same dtype as *x* (after conversion to an array).

    .. codeauthor:: Rich Wareham <rjw57@cantab.net>, Aug 2013
    .. codeauthor:: Nick Kingsbury, Cambridge University, January 1999.
    """
    values = np.asanyarray(x)
    span = maxx - minx
    period = 2 * span
    # Position inside one up-and-down period, shifted into [0, period).
    phase = np.fmod(values - minx, period)
    phase = np.where(phase < 0, phase + period, phase)
    # The second half of the period runs back down towards minx.
    folded = np.where(phase >= span, period - phase, phase)
    return np.array(folded + minx, dtype=values.dtype)

In [23]:
def mode_to_int(mode):
    """Translate a padding-mode name into its integer code.

    'per' and 'periodization' share the code 2.

    Raises:
        ValueError: if the name is not one of the known padding modes.
    """
    codes = (
        ('zero', 0),
        ('symmetric', 1),
        ('per', 2),
        ('periodization', 2),
        ('constant', 3),
        ('reflect', 4),
        ('replicate', 5),
        ('periodic', 6),
    )
    for name, code in codes:
        if mode == name:
            return code
    raise ValueError("Unkown pad type: {}".format(mode))

In [24]:
def int_to_mode(mode):
    """Translate an integer padding-mode code into its name.

    Code 2 is returned as 'periodization'.

    Raises:
        ValueError: if the code is not in the range 0..6.
    """
    names = ('zero', 'symmetric', 'periodization', 'constant',
             'reflect', 'replicate', 'periodic')
    for code, name in enumerate(names):
        if mode == code:
            return name
    raise ValueError("Unkown pad type: {}".format(mode))

In [25]:
def afb1d(x, h0, h1, mode='zero', dim=-1):
    """ 1D analysis filter bank (along one dimension only) of an image
    Inputs:
        x (tensor): 4D input with the last two dimensions the spatial input
        h0 (tensor): 4D input for the lowpass filter. Should have shape (1, 1,
            h, 1) or (1, 1, 1, w). Non-tensors are flattened, reversed and
            converted to a float tensor on x's device.
        h1 (tensor): 4D input for the highpass filter. Should have shape (1, 1,
            h, 1) or (1, 1, 1, w). Non-tensors are handled like h0.
        mode (str): padding method
        dim (int) - dimension of filtering. d=2 is for a vertical filter (called
            column filtering but filters across the rows). d=3 is for a
            horizontal filter, (called row filtering but filters across the
            columns).
    Returns:
        lohi: lowpass and highpass subbands concatenated along the channel
            dimension
    Raises:
        ValueError: if mode is not a supported padding method
    """
    C = x.shape[1]
    # Convert the dim to positive
    d = dim % 4
    s = (2, 1) if d == 2 else (1, 2)
    N = x.shape[d]
    # If h0, h1 are not tensors, make them. If they are, then assume that they
    # are in the right order
    if not isinstance(h0, torch.Tensor):
        h0 = torch.tensor(np.copy(np.array(h0).ravel()[::-1]),
                          dtype=torch.float, device=x.device)
    if not isinstance(h1, torch.Tensor):
        h1 = torch.tensor(np.copy(np.array(h1).ravel()[::-1]),
                          dtype=torch.float, device=x.device)
    L = h0.numel()
    L2 = L // 2
    shape = [1,1,1,1]
    shape[d] = L
    # If h aren't in the right shape, make them so
    if h0.shape != tuple(shape):
        h0 = h0.reshape(*shape)
    if h1.shape != tuple(shape):
        h1 = h1.reshape(*shape)
    # Interleave (lowpass, highpass) once per input channel for the grouped conv.
    h = torch.cat([h0, h1] * C, dim=0)

    if mode == 'per' or mode == 'periodization':
        # Odd-length signals are extended by repeating their last sample.
        if x.shape[dim] % 2 == 1:
            if d == 2:
                x = torch.cat((x, x[:,:,-1:]), dim=2)
            else:
                x = torch.cat((x, x[:,:,:,-1:]), dim=3)
            N += 1
        # Circular shift along the filtered dimension. torch.roll is used so
        # this branch does not depend on a separate `roll` helper.
        x = torch.roll(x, -L2, dims=d)
        pad = (L-1, 0) if d == 2 else (0, L-1)
        lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
        N2 = N//2
        # Wrap the samples beyond N2 back onto the start, then keep N2 samples.
        if d == 2:
            lohi[:,:,:L2] = lohi[:,:,:L2] + lohi[:,:,N2:N2+L2]
            lohi = lohi[:,:,:N2]
        else:
            lohi[:,:,:,:L2] = lohi[:,:,:,:L2] + lohi[:,:,:,N2:N2+L2]
            lohi = lohi[:,:,:,:N2]
    else:
        # Calculate the pad size
        outsize = pywt.dwt_coeff_len(N, L, mode=mode)
        p = 2 * (outsize - 1) - N + L
        if mode == 'zero':
            # Sadly, pytorch only allows for same padding before and after, if
            # we need to do more padding after for odd length signals, have to
            # prepad
            if p % 2 == 1:
                pad = (0, 0, 0, 1) if d == 2 else (0, 1, 0, 0)
                x = F.pad(x, pad)
            pad = (p//2, 0) if d == 2 else (0, p//2)
            # Calculate the high and lowpass
            lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
        elif mode == 'symmetric' or mode == 'reflect' or mode == 'periodic':
            pad = (0, 0, p//2, (p+1)//2) if d == 2 else (p//2, (p+1)//2, 0, 0)
            # NOTE(review): `mypad` is not defined in the visible part of this
            # notebook - confirm it is available before using these modes.
            x = mypad(x, pad=pad, mode=mode)
            lohi = F.conv2d(x, h, stride=s, groups=C)
        else:
            raise ValueError("Unkown pad type: {}".format(mode))

    return lohi

In [26]:
class AFB2D(Function):
    """ Does a single level 2d wavelet decomposition of an input. Does separate
    row and column filtering by two calls to afb1d (first along dim 3, then
    along dim 2).
    Needs to have the tensors in the right form. Because this function defines
    its own backward pass, saves on memory by not having to save the input
    tensors: only the four filters and the input's spatial size are kept.
    Inputs:
        x (torch.Tensor): Input to decompose
        h0_row: row lowpass
        h1_row: row highpass
        h0_col: col lowpass
        h1_col: col highpass
        mode (int): use mode_to_int to get the int code here
    We encode the mode as an integer rather than a string as gradcheck causes an
    error when a string is provided.
    Returns:
        (low, highs): the filtered output is regrouped so that its channel
            axis is split into groups of 4 subbands; `low` is subband 0 of
            every group and `highs` keeps subbands 1..3 on an extra axis at
            position 2.
    """
    @staticmethod
    def forward(ctx, x, h0_row, h1_row, h0_col, h1_col, mode):
        # Only the filters are saved for backward; the input itself is not.
        ctx.save_for_backward(h0_row, h1_row, h0_col, h1_col)
        # Spatial size of the input, used in backward to crop the gradient.
        ctx.shape = x.shape[-2:]
        mode = int_to_mode(mode)
        ctx.mode = mode
        # Filter along the last axis first, then along the second-to-last one.
        lohi = afb1d(x, h0_row, h1_row, mode=mode, dim=3)
        y = afb1d(lohi, h0_col, h1_col, mode=mode, dim=2)
        s = y.shape
        # Split the channel axis into (groups, 4 subbands).
        y = y.reshape(s[0], -1, 4, s[-2], s[-1])
        low = y[:,:,0].contiguous()
        highs = y[:,:,1:].contiguous()
        return low, highs

    @staticmethod
    def backward(ctx, low, highs):
        # `low` and `highs` are the gradients w.r.t. the two forward outputs.
        dx = None
        if ctx.needs_input_grad[0]:
            mode = ctx.mode
            h0_row, h1_row, h0_col, h1_col = ctx.saved_tensors
            lh, hl, hh = torch.unbind(highs, dim=2)
            # Undo the filtering in reverse order: dim 2 first, then dim 3.
            lo = sfb1d(low, lh, h0_col, h1_col, mode=mode, dim=2)
            hi = sfb1d(hl, hh, h0_col, h1_col, mode=mode, dim=2)
            dx = sfb1d(lo, hi, h0_row, h1_row, mode=mode, dim=3)
            # Crop the gradient back to the spatial size of the forward input.
            if dx.shape[-2] > ctx.shape[-2] and dx.shape[-1] > ctx.shape[-1]:
                dx = dx[:,:,:ctx.shape[-2], :ctx.shape[-1]]
            elif dx.shape[-2] > ctx.shape[-2]:
                dx = dx[:,:,:ctx.shape[-2]]
            elif dx.shape[-1] > ctx.shape[-1]:
                dx = dx[:,:,:,:ctx.shape[-1]]
        # One gradient slot per forward argument; only x receives a gradient.
        return dx, None, None, None, None, None

In [27]:
def prep_filt_afb1d(h0, h1, device=device):
    """
    Turns a pair of analysis filters into tensors usable by the filter bank.

    Each filter is reversed (conv2d acts like a correlation, so the mirror
    image is needed), flattened, and stored as a tensor of shape (1, 1, L) in
    torch's default dtype.
    Inputs:
        h0 (array-like): low pass column filter bank
        h1 (array-like): high pass column filter bank
        device: which device to put the tensors on to
    Returns:
        (h0, h1)
    """
    dtype = torch.get_default_dtype()
    prepared = []
    for taps in (h0, h1):
        mirrored = np.array(taps[::-1]).ravel()
        prepared.append(
            torch.tensor(mirrored, device=device, dtype=dtype).reshape((1, 1, -1)))
    lowpass, highpass = prepared
    return lowpass, highpass

In [28]:
def prep_filt_afb2d(h0_col, h1_col, h0_row=None, h1_row=None, device=device):
    """
    Prepares the filters to be of the right form for the afb2d function.  In
    particular, makes the tensors the right shape. It takes mirror images of
    them as as afb2d uses conv2d which acts like normal correlation.
    Inputs:
        h0_col (array-like): low pass column filter bank
        h1_col (array-like): high pass column filter bank
        h0_row (array-like): low pass row filter bank. If none, will assume the
            same as column filter
        h1_row (array-like): high pass row filter bank. If none, will assume the
            same as column filter
        device: which device to put the tensors on to
    Returns:
        (h0_col, h1_col, h0_row, h1_row): column filters shaped (1, 1, L, 1)
        and row filters shaped (1, 1, 1, L)
    """
    h0_col, h1_col = prep_filt_afb1d(h0_col, h1_col, device)
    if h0_row is None:
        # Reuse the column filters for the rows. The highpass target must be
        # h1_row: assigning to h1_col instead left h1_row as None and made the
        # reshape below fail.
        h0_row, h1_row = h0_col, h1_col
    else:
        h0_row, h1_row = prep_filt_afb1d(h0_row, h1_row, device)

    h0_col = h0_col.reshape((1, 1, -1, 1))
    h1_col = h1_col.reshape((1, 1, -1, 1))
    h0_row = h0_row.reshape((1, 1, 1, -1))
    h1_row = h1_row.reshape((1, 1, 1, -1))
    return h0_col, h1_col, h0_row, h1_row

In [29]:
class DWTForward(nn.Module):
    """ Performs a 2d DWT Forward decomposition of an image
    Args:
        J (int): Number of levels of decomposition
        wave (str or pywt.Wavelet or tuple(ndarray)): Which wavelet to use.
            Can be:
            1) a string to pass to pywt.Wavelet constructor
            2) a pywt.Wavelet class
            3) a tuple of numpy arrays, either (h0, h1) or (h0_col, h1_col, h0_row, h1_row)
        mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The
            padding scheme
    Raises:
        ValueError: if `wave` is a sequence whose length is neither 2 nor 4
        """
    def __init__(self, J=1, wave='db1', mode='zero'):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            h0_col, h1_col = wave.dec_lo, wave.dec_hi
            h0_row, h1_row = h0_col, h1_col
        elif len(wave) == 2:
            h0_col, h1_col = wave[0], wave[1]
            h0_row, h1_row = h0_col, h1_col
        elif len(wave) == 4:
            h0_col, h1_col = wave[0], wave[1]
            h0_row, h1_row = wave[2], wave[3]
        else:
            # Fail with a clear message instead of an unbound-variable error
            # further down.
            raise ValueError(
                "wave must contain 2 or 4 filters, got {}".format(len(wave)))

        # Prepare the filters; buffers follow the module across devices.
        filts = prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row)
        self.register_buffer('h0_col', filts[0])
        self.register_buffer('h1_col', filts[1])
        self.register_buffer('h0_row', filts[2])
        self.register_buffer('h1_row', filts[3])
        self.J = J
        self.mode = mode

    def forward(self, x):
        """ Forward pass of the DWT.
        Args:
            x (tensor): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})`
        Returns:
            (yl, yh)
                tuple of lowpass (yl) and bandpass (yh) coefficients.
                yh is a list of length J with the first entry
                being the finest scale coefficients. yl has shape
                :math:`(N, C_{in}, H_{in}', W_{in}')` and yh has shape
                :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. The new
                dimension in yh iterates over the LH, HL and HH coefficients.
        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
            downsampled shapes of the DWT pyramid.
        """
        yh = []
        ll = x
        mode = mode_to_int(self.mode)

        # Do a multilevel transform: each level decomposes the previous
        # level's lowpass output.
        for _ in range(self.J):
            # Do 1 level of the transform
            # NOTE(review): AFB2D.forward names its filter parameters row-first
            # while this call passes the column filters first. That makes no
            # difference when row and column filters are the same (string,
            # pywt.Wavelet or 2-tuple `wave`) - confirm the intent for 4-tuples.
            ll, high = AFB2D.apply(
                ll, self.h0_col, self.h1_col, self.h0_row, self.h1_row, mode)
            yh.append(high)

        return ll, yh

In [30]:
# Shared DWT modules available as globals to the wave blocks: xfJ runs J
# decomposition levels with the 'db1' wavelet and zero padding.
xf1 = DWTForward(J=1, mode='zero', wave='db1').to(device)
xf2 = DWTForward(J=2, mode='zero', wave='db1').to(device)
xf3 = DWTForward(J=3, mode='zero', wave='db1').to(device)
xf4 = DWTForward(J=4, mode='zero', wave='db1').to(device)

In [31]:
class Level1Waveblock(nn.Module):
    """Token-mixing block built on a single-level 2D DWT.

    The input is reduced to a quarter of its channels, decomposed with the
    module-level ``xf1`` transform, and the four subbands are stacked on the
    channel axis before passing through a feed-forward stack whose transposed
    convolution doubles the spatial size.

    Args:
        mult (int): width multiplier of the hidden 1x1 convolution.
        ff_channel (int): channels fed into the transposed convolution.
        final_dim (int): channels of the block's input and output.
        dropout (float): dropout probability inside the feed-forward stack.
    """
    def __init__(self, *, mult = 2, ff_channel = 16, final_dim = 16, dropout = 0.5):
        super().__init__()

        hidden_dim = final_dim * mult
        self.feedforward = nn.Sequential(
            nn.Conv2d(final_dim, hidden_dim, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden_dim, ff_channel, 1),
            nn.ConvTranspose2d(ff_channel, final_dim, 4, stride=2, padding=1),
            nn.BatchNorm2d(final_dim),
        )

        # 1x1 convolution that keeps a quarter of the channels.
        self.reduction = nn.Conv2d(final_dim, final_dim // 4, 1)

    def forward(self, x):
        batch, channels, height, width = x.shape

        reduced = self.reduction(x)
        approx, details = xf1(reduced)

        # Fold the three detail subbands into the channel axis.
        detail_maps = details[0].reshape(
            batch, channels * 3 // 4, height // 2, width // 2)
        stacked = torch.cat((approx, detail_maps), dim=1)

        return self.feedforward(stacked)

In [32]:
class Level2Waveblock(nn.Module):
    """Token-mixing block built on a two-level 2D DWT pyramid.

    The input is reduced to a quarter of its channels and decomposed twice.
    The coarser level is processed first and upsampled by its feed-forward
    stack, then concatenated with the finer level's subbands and processed
    again, restoring the input's channel count and spatial size.

    Args:
        mult (int): width multiplier of the hidden 1x1 convolutions.
        ff_channel (int): channels fed into each transposed convolution.
        final_dim (int): channels of the block's input and output.
        dropout (float): dropout probability inside the feed-forward stacks.
    """
    def __init__(
        self,
        *,
        mult = 2,
        ff_channel = 16,
        final_dim = 16,
        dropout = 0.5,
    ):
        super().__init__()

        half_dim = final_dim // 2
        hidden_dim = final_dim * mult

        # Finest level: own subbands (final_dim) + upsampled coarser features.
        self.feedforward1 = self._make_feedforward(
            final_dim + half_dim, hidden_dim, ff_channel, final_dim, dropout)
        # Coarsest level: only its own subbands.
        self.feedforward2 = self._make_feedforward(
            final_dim, hidden_dim, ff_channel, half_dim, dropout)

        # 1x1 convolution that keeps a quarter of the channels.
        self.reduction = nn.Conv2d(final_dim, final_dim // 4, 1)

    @staticmethod
    def _make_feedforward(in_channels, hidden_channels, ff_channel, out_channels, dropout):
        """Build one 1x1-conv MLP followed by a 2x transposed-conv upsampling."""
        return nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden_channels, ff_channel, 1),
            nn.ConvTranspose2d(ff_channel, out_channels, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        b, c, h, w = x.shape

        x = self.reduction(x)

        # Cascade the single-level transform instead of running the full
        # one- and two-level transforms separately: a J-level DWTForward feeds
        # each level's lowpass output into the next level, so this yields the
        # same tensors with two transforms instead of three.
        # Assumes the module-level xf1 is the single-level DWTForward with the
        # same wavelet and padding as the multi-level instances.
        Y1, Yh1 = xf1(x)
        Y2, Yh2 = xf1(Y1)

        # Fold the three detail subbands of each level into the channel axis.
        x1 = torch.reshape(Yh1[0], (b, c*3//4, h//2, w//2))
        x2 = torch.reshape(Yh2[0], (b, c*3//4, h//4, w//4))

        x1 = torch.cat((Y1,x1), dim = 1)
        x2 = torch.cat((Y2,x2), dim = 1)

        x2 = self.feedforward2(x2)

        x1 = torch.cat((x1,x2), dim = 1)
        x = self.feedforward1(x1)

        return x

In [33]:
class Level3Waveblock(nn.Module):
    """Token-mixing block built on a three-level 2D DWT pyramid.

    The input is reduced to a quarter of its channels and decomposed three
    times. Starting from the coarsest level, each level's subbands are
    processed and upsampled by a feed-forward stack, then concatenated with
    the next finer level, until the input's channel count and spatial size
    are restored.

    Args:
        mult (int): width multiplier of the hidden 1x1 convolutions.
        ff_channel (int): channels fed into each transposed convolution.
        final_dim (int): channels of the block's input and output.
        dropout (float): dropout probability inside the feed-forward stacks.
    """
    def __init__(
        self,
        *,
        mult = 2,
        ff_channel = 16,
        final_dim = 16,
        dropout = 0.5,
    ):
        super().__init__()

        half_dim = final_dim // 2
        hidden_dim = final_dim * mult

        # Levels 1 and 2: own subbands (final_dim) + upsampled coarser features.
        self.feedforward1 = self._make_feedforward(
            final_dim + half_dim, hidden_dim, ff_channel, final_dim, dropout)
        self.feedforward2 = self._make_feedforward(
            final_dim + half_dim, hidden_dim, ff_channel, half_dim, dropout)
        # Coarsest level: only its own subbands.
        self.feedforward3 = self._make_feedforward(
            final_dim, hidden_dim, ff_channel, half_dim, dropout)

        # 1x1 convolution that keeps a quarter of the channels.
        self.reduction = nn.Conv2d(final_dim, final_dim // 4, 1)

    @staticmethod
    def _make_feedforward(in_channels, hidden_channels, ff_channel, out_channels, dropout):
        """Build one 1x1-conv MLP followed by a 2x transposed-conv upsampling."""
        return nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden_channels, ff_channel, 1),
            nn.ConvTranspose2d(ff_channel, out_channels, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        b, c, h, w = x.shape

        x = self.reduction(x)

        # Cascade the single-level transform instead of running the one-, two-
        # and three-level transforms separately: a J-level DWTForward feeds
        # each level's lowpass output into the next level, so this yields the
        # same tensors with three transforms instead of six.
        # Assumes the module-level xf1 is the single-level DWTForward with the
        # same wavelet and padding as the multi-level instances.
        Y1, Yh1 = xf1(x)
        Y2, Yh2 = xf1(Y1)
        Y3, Yh3 = xf1(Y2)

        # Fold the three detail subbands of each level into the channel axis.
        x1 = torch.reshape(Yh1[0], (b, c*3//4, h//2, w//2))
        x2 = torch.reshape(Yh2[0], (b, c*3//4, h//4, w//4))
        x3 = torch.reshape(Yh3[0], (b, c*3//4, h//8, w//8))

        x1 = torch.cat((Y1,x1), dim = 1)
        x2 = torch.cat((Y2,x2), dim = 1)
        x3 = torch.cat((Y3,x3), dim = 1)

        # Coarse-to-fine: process a level, upsample, merge with the next one.
        x3 = self.feedforward3(x3)

        x2 = torch.cat((x2,x3), dim = 1)
        x2 = self.feedforward2(x2)

        x1 = torch.cat((x1,x2), dim = 1)
        x = self.feedforward1(x1)

        return x

In [34]:
class Level4Waveblock(nn.Module):
    """Token-mixing block built on a four-level 2D DWT pyramid.

    The input is reduced to a quarter of its channels and decomposed four
    times. Starting from the coarsest level, each level's subbands are
    processed and upsampled by a feed-forward stack, then concatenated with
    the next finer level, until the input's channel count and spatial size
    are restored.

    Args:
        mult (int): width multiplier of the hidden 1x1 convolutions.
        ff_channel (int): channels fed into each transposed convolution.
        final_dim (int): channels of the block's input and output.
        dropout (float): dropout probability inside the feed-forward stacks.
    """
    def __init__(
        self,
        *,
        mult = 2,
        ff_channel = 16,
        final_dim = 16,
        dropout = 0.5,
    ):
        super().__init__()

        half_dim = final_dim // 2
        hidden_dim = final_dim * mult

        # Levels 1-3: own subbands (final_dim) + upsampled coarser features.
        self.feedforward1 = self._make_feedforward(
            final_dim + half_dim, hidden_dim, ff_channel, final_dim, dropout)
        self.feedforward2 = self._make_feedforward(
            final_dim + half_dim, hidden_dim, ff_channel, half_dim, dropout)
        self.feedforward3 = self._make_feedforward(
            final_dim + half_dim, hidden_dim, ff_channel, half_dim, dropout)
        # Coarsest level: only its own subbands.
        self.feedforward4 = self._make_feedforward(
            final_dim, hidden_dim, ff_channel, half_dim, dropout)

        # 1x1 convolution that keeps a quarter of the channels.
        self.reduction = nn.Conv2d(final_dim, final_dim // 4, 1)

    @staticmethod
    def _make_feedforward(in_channels, hidden_channels, ff_channel, out_channels, dropout):
        """Build one 1x1-conv MLP followed by a 2x transposed-conv upsampling."""
        return nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden_channels, ff_channel, 1),
            nn.ConvTranspose2d(ff_channel, out_channels, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        b, c, h, w = x.shape

        x = self.reduction(x)

        # Cascade the single-level transform instead of running the one- to
        # four-level transforms separately: a J-level DWTForward feeds each
        # level's lowpass output into the next level, so this yields the same
        # tensors with four transforms instead of ten.
        # Assumes the module-level xf1 is the single-level DWTForward with the
        # same wavelet and padding as the multi-level instances.
        Y1, Yh1 = xf1(x)
        Y2, Yh2 = xf1(Y1)
        Y3, Yh3 = xf1(Y2)
        Y4, Yh4 = xf1(Y3)

        # Fold the three detail subbands of each level into the channel axis.
        x1 = torch.reshape(Yh1[0], (b, c*3//4, h//2, w//2))
        x2 = torch.reshape(Yh2[0], (b, c*3//4, h//4, w//4))
        x3 = torch.reshape(Yh3[0], (b, c*3//4, h//8, w//8))
        x4 = torch.reshape(Yh4[0], (b, c*3//4, h//16, w//16))

        x1 = torch.cat((Y1,x1), dim = 1)
        x2 = torch.cat((Y2,x2), dim = 1)
        x3 = torch.cat((Y3,x3), dim = 1)
        x4 = torch.cat((Y4,x4), dim = 1)

        # Coarse-to-fine: process a level, upsample, merge with the next one.
        x4 = self.feedforward4(x4)

        x3 = torch.cat((x3,x4), dim = 1)
        x3 = self.feedforward3(x3)

        x2 = torch.cat((x2,x3), dim = 1)
        x2 = self.feedforward2(x2)

        x1 = torch.cat((x1,x2), dim = 1)
        x = self.feedforward1(x1)

        return x

In [35]:
class WaveMix(nn.Module):
    """Classifier made of a convolutional stem, residual wave blocks and a pooled linear head.

    Keyword-only arguments:
        num_classes: output width of the final linear layer.
        depth: number of wave blocks stored in ``self.layers``.
        mult, ff_channel, final_dim, dropout: forwarded unchanged to every wave block;
            ``final_dim`` is also the stem's output channel count and the head's input width.
        level: 4, 3 or 2 select Level4/Level3/Level2Waveblock; any other value
            selects Level1Waveblock.
        initial_conv: 'strided' builds two strided 3x3 convolutions; any other value
            (the default is spelled 'pachify') builds the patch-embedding stem.
        patch_size: kernel size and stride of the last convolution in the patch stem.
        stride: stride of both convolutions in the 'strided' stem.
    """

    def __init__(
        self,
        *,
        num_classes=1000,
        depth=16,
        mult=2,
        ff_channel=192,
        final_dim=192,
        dropout=0.5,
        level=3,
        initial_conv='pachify',  # or 'strided'
        patch_size=4,
        stride=2,
    ):
        super().__init__()

        # Resolve the block class once; every layer of the stack uses the same one.
        if level == 4:
            block_cls = Level4Waveblock
        elif level == 3:
            block_cls = Level3Waveblock
        elif level == 2:
            block_cls = Level2Waveblock
        else:
            block_cls = Level1Waveblock

        self.layers = nn.ModuleList([
            block_cls(mult=mult, ff_channel=ff_channel, final_dim=final_dim, dropout=dropout)
            for _ in range(depth)
        ])

        # Head: global average pool, drop the two singleton spatial axes, classify.
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...'),
            nn.Linear(final_dim, num_classes),
        )

        half_dim = int(final_dim / 2)
        if initial_conv == 'strided':
            stem = [
                nn.Conv2d(3, half_dim, kernel_size=3, stride=stride, padding=1),
                nn.Conv2d(half_dim, final_dim, kernel_size=3, stride=stride, padding=1),
            ]
        else:
            quarter_dim = int(final_dim / 4)
            stem = [
                nn.Conv2d(3, quarter_dim, kernel_size=3, stride=1, padding=1),
                nn.Conv2d(quarter_dim, half_dim, kernel_size=3, stride=1, padding=1),
                # non-overlapping patches: kernel size equals stride
                nn.Conv2d(half_dim, final_dim, kernel_size=patch_size, stride=patch_size),
                nn.GELU(),
                nn.BatchNorm2d(final_dim),
            ]
        self.conv = nn.Sequential(*stem)

    def forward(self, img):
        """Return the head's output for ``img`` after the stem and the residual block stack."""
        features = self.conv(img)

        # Each block's output is added back onto its own input.
        for block in self.layers:
            features = block(features) + features

        return self.pool(features)

In [36]:
# Configuration taken from
# https://github.com/pranavphoenix/WaveMix/blob/main/Image_Classification/tinyimagenet.py
model = WaveMix(
    num_classes=200,          # width of the final linear layer
    level=3,                  # selects Level3Waveblock for every layer
    depth=16,                 # number of stacked wave blocks
    final_dim=192,
    ff_channel=192,
    mult=2,
    dropout=0.5,
    initial_conv='pachify',   # any value other than 'strided' selects the patch stem
    patch_size=4,
)

In [37]:
# Load the pre-trained Tiny ImageNet weights provided by the author.
url = 'https://huggingface.co/cloudwalker/wavemix/resolve/main/Saved_Models_Weights/tinyimagenet/tiny_71.69.pth'
# map_location puts the checkpoint tensors on the device selected at the top of the
# notebook, whatever device they were saved from.
model.load_state_dict(torch.hub.load_state_dict_from_url(url, map_location=device))
model.to(device)

# summary() prints the layer table itself and returns None, so it is called directly;
# wrapping it in print() would append a stray "None" line to the output.
summary(model, (3, 64, 64))

PATH = 'tiny_71.69.pth' # path to save the model

print("ImageNet Weights Loaded")

Downloading: "https://huggingface.co/cloudwalker/wavemix/resolve/main/Saved_Models_Weights/tinyimagenet/tiny_71.69.pth" to /root/.cache/torch/hub/checkpoints/tiny_71.69.pth
100%|██████████| 106M/106M [00:02<00:00, 52.3MB/s] 


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 48, 64, 64]           1,344
            Conv2d-2           [-1, 96, 64, 64]          41,568
            Conv2d-3          [-1, 192, 16, 16]         295,104
              GELU-4          [-1, 192, 16, 16]               0
       BatchNorm2d-5          [-1, 192, 16, 16]             384
            Conv2d-6           [-1, 48, 16, 16]           9,264
            Conv2d-7            [-1, 384, 2, 2]          74,112
              GELU-8            [-1, 384, 2, 2]               0
           Dropout-9            [-1, 384, 2, 2]               0
           Conv2d-10            [-1, 192, 2, 2]          73,920
  ConvTranspose2d-11             [-1, 96, 4, 4]         295,008
      BatchNorm2d-12             [-1, 96, 4, 4]             192
           Conv2d-13            [-1, 384, 4, 4]         110,976
             GELU-14            [-1, 38

## Evaluation

In [38]:
# Width of the classifier output, passed as num_classes= to the training calls below
num_classes = 200
# Patience: epochs without any accuracy improvement before training stops, per optimiser
counter = 3

### Baseline

In [39]:
# Empty histories handed to the training calls in the following cells:
# top-1 accuracy, top-5 accuracy, train time and test time.
top1, top5, traintime, testtime = [], [], [], []

In [40]:
# Baseline fine-tuning with dualopt's classification();
# the log below shows an AdamW phase followed by an SGD phase.
classification(
    model,
    trainloader_1,
    testloader,
    device,
    PATH,
    top1,
    top5,
    traintime,
    testtime,
    num_classes=num_classes,
    set_counter=counter,
)

print('Finished Training')

Training with AdamW


Epoch 1: 100%|██████████| 329/329 [04:38<00:00,  1.18batch/s,  loss : 1.2392 - acc: 0.6721]


Epoch : 1 - Top 1: 70.45 - Top 5: 89.34 -  Train Time: 278.91 - Test Time: 13.04

1


Epoch 2: 100%|██████████| 329/329 [04:38<00:00,  1.18batch/s,  loss : 1.2317 - acc: 0.6722]


Epoch : 2 - Top 1: 69.92 - Top 5: 89.50 -  Train Time: 278.37 - Test Time: 12.79



Epoch 3: 100%|██████████| 329/329 [04:38<00:00,  1.18batch/s,  loss : 1.2438 - acc: 0.6727]


Epoch : 3 - Top 1: 69.98 - Top 5: 89.46 -  Train Time: 278.30 - Test Time: 12.80



Epoch 4: 100%|██████████| 329/329 [04:37<00:00,  1.18batch/s,  loss : 1.2139 - acc: 0.6770]


Epoch : 4 - Top 1: 70.62 - Top 5: 89.39 -  Train Time: 278.19 - Test Time: 12.78

1


Epoch 5: 100%|██████████| 329/329 [04:37<00:00,  1.18batch/s,  loss : 1.2124 - acc: 0.6788]


Epoch : 5 - Top 1: 69.83 - Top 5: 89.50 -  Train Time: 278.13 - Test Time: 12.79



Epoch 6: 100%|██████████| 329/329 [04:38<00:00,  1.18batch/s,  loss : 1.2063 - acc: 0.6789]


Epoch : 6 - Top 1: 70.44 - Top 5: 89.27 -  Train Time: 278.43 - Test Time: 12.82



Epoch 7: 100%|██████████| 329/329 [04:38<00:00,  1.18batch/s,  loss : 1.1859 - acc: 0.6834]


Epoch : 7 - Top 1: 69.99 - Top 5: 89.11 -  Train Time: 278.37 - Test Time: 12.90

Finished Training
Training with SGD


Epoch 1: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 1.1587 - acc: 0.6931]


Epoch : 1 - Top 1: 71.04 - Top 5: 89.70 -  Train Time: 276.20 - Test Time: 12.80

1


Epoch 2: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 1.1424 - acc: 0.6957]


Epoch : 2 - Top 1: 71.03 - Top 5: 89.67 -  Train Time: 276.13 - Test Time: 12.80



Epoch 3: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 1.1356 - acc: 0.6993]


Epoch : 3 - Top 1: 71.17 - Top 5: 89.77 -  Train Time: 276.11 - Test Time: 12.81

1


Epoch 4: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 1.1384 - acc: 0.6981]


Epoch : 4 - Top 1: 71.06 - Top 5: 89.74 -  Train Time: 275.99 - Test Time: 12.89



Epoch 5: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 1.1369 - acc: 0.6969]


Epoch : 5 - Top 1: 71.21 - Top 5: 89.71 -  Train Time: 276.03 - Test Time: 12.88

1


Epoch 6: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 1.1335 - acc: 0.6988]


Epoch : 6 - Top 1: 71.49 - Top 5: 89.93 -  Train Time: 275.78 - Test Time: 12.77

1


Epoch 7: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 1.1321 - acc: 0.6996]


Epoch : 7 - Top 1: 71.57 - Top 5: 89.93 -  Train Time: 275.56 - Test Time: 12.85

1


Epoch 8: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 1.1303 - acc: 0.6992]


Epoch : 8 - Top 1: 71.58 - Top 5: 89.89 -  Train Time: 275.63 - Test Time: 12.84

1


Epoch 9: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 1.1311 - acc: 0.6993]


Epoch : 9 - Top 1: 71.62 - Top 5: 89.80 -  Train Time: 275.70 - Test Time: 12.90

1


Epoch 10: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 1.1285 - acc: 0.6995]


Epoch : 10 - Top 1: 71.66 - Top 5: 89.85 -  Train Time: 275.70 - Test Time: 12.77

1


Epoch 11: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 1.1227 - acc: 0.7017]


Epoch : 11 - Top 1: 71.52 - Top 5: 89.89 -  Train Time: 275.63 - Test Time: 12.85



Epoch 12: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 1.1200 - acc: 0.7007]


Epoch : 12 - Top 1: 71.56 - Top 5: 89.81 -  Train Time: 275.66 - Test Time: 12.76



Epoch 13: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 1.1215 - acc: 0.7018]


Epoch : 13 - Top 1: 71.36 - Top 5: 89.81 -  Train Time: 275.69 - Test Time: 12.76

Finished Training


In [41]:
# Reload the weights stored at PATH before post-training. Presumably this is the best
# checkpoint written by classification() above -- confirm against dualopt.
model.load_state_dict(torch.load(PATH))

<All keys matched successfully>

In [42]:
# Post-training pass on the second training loader (the log below reports SGD);
# it appends to the same metric histories as the classification() run.
post_train(
    model,
    trainloader_2,
    testloader,
    device,
    PATH,
    top1,
    top5,
    traintime,
    testtime,
    num_classes=num_classes,
    set_counter=counter,
)
print('Finished Training')

Post-training with SGD


Epoch 1: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 0.6575 - acc: 0.8136]


Epoch : 1 - Top 1: 71.27 - Top 5: 89.80 -  Train Time: 275.70 - Test Time: 12.87



Epoch 2: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 0.6467 - acc: 0.8160]


Epoch : 2 - Top 1: 71.40 - Top 5: 89.86 -  Train Time: 275.60 - Test Time: 12.81



Epoch 3: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 0.6436 - acc: 0.8176]


Epoch : 3 - Top 1: 71.34 - Top 5: 89.75 -  Train Time: 275.58 - Test Time: 12.78

Finished Training


In [43]:
# Baseline summary. Top-1 and top-5 are each maximised on their own, so they may come
# from different epochs; the times are the fastest single entries, not totals.
print("Results")
print(f"Top 1 Accuracy: {max(top1):.2f} -Top 5 Accuracy : {max(top5):.2f} - Train Time: {min(traintime):.0f} -Test Time: {min(testtime):.0f}")

Results
Top 1 Accuracy: 71.66 -Top 5 Accuracy : 89.93 - Train Time: 276 -Test Time: 13


### Blurred validation images

In [44]:
# Start the blurred-image experiment with empty histories:
# top-1 accuracy, top-5 accuracy, train time and test time.
top1, top5, traintime, testtime = [], [], [], []

In [45]:
# Same dual-optimiser routine, now with the blurred train / test loaders.
# NOTE(review): `model` is not re-initialised between experiments in the cells shown
# here, so this run starts from the previous section's weights -- confirm this is intended.
classification(
    model,
    trainloader_blur_1,
    testloader_blur,
    device,
    PATH,
    top1,
    top5,
    traintime,
    testtime,
    num_classes=num_classes,
    set_counter=counter,
)
print('Finished Training')

Training with AdamW


Epoch 1: 100%|██████████| 329/329 [04:37<00:00,  1.18batch/s,  loss : 1.2111 - acc: 0.6789]


Epoch : 1 - Top 1: 51.35 - Top 5: 77.03 -  Train Time: 278.17 - Test Time: 13.14

1


Epoch 2: 100%|██████████| 329/329 [04:37<00:00,  1.18batch/s,  loss : 1.1981 - acc: 0.6800]


Epoch : 2 - Top 1: 51.60 - Top 5: 76.80 -  Train Time: 278.06 - Test Time: 12.77

1


Epoch 3: 100%|██████████| 329/329 [04:37<00:00,  1.18batch/s,  loss : 1.1893 - acc: 0.6842]


Epoch : 3 - Top 1: 52.23 - Top 5: 77.59 -  Train Time: 277.99 - Test Time: 12.78

1


Epoch 4: 100%|██████████| 329/329 [04:37<00:00,  1.18batch/s,  loss : 1.1763 - acc: 0.6866]


Epoch : 4 - Top 1: 52.06 - Top 5: 77.40 -  Train Time: 277.96 - Test Time: 12.78



Epoch 5: 100%|██████████| 329/329 [04:37<00:00,  1.18batch/s,  loss : 1.1712 - acc: 0.6880]


Epoch : 5 - Top 1: 50.09 - Top 5: 75.61 -  Train Time: 277.91 - Test Time: 12.84



Epoch 6: 100%|██████████| 329/329 [04:37<00:00,  1.18batch/s,  loss : 1.1549 - acc: 0.6898]


Epoch : 6 - Top 1: 51.63 - Top 5: 77.01 -  Train Time: 277.99 - Test Time: 12.80

Finished Training
Training with SGD


Epoch 1: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 1.1303 - acc: 0.6991]


Epoch : 1 - Top 1: 52.10 - Top 5: 77.07 -  Train Time: 275.70 - Test Time: 12.86



Epoch 2: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 1.1109 - acc: 0.7031]


Epoch : 2 - Top 1: 52.57 - Top 5: 77.29 -  Train Time: 275.69 - Test Time: 12.88

1


Epoch 3: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 1.1094 - acc: 0.7048]


Epoch : 3 - Top 1: 52.63 - Top 5: 77.58 -  Train Time: 275.97 - Test Time: 12.90

1


Epoch 4: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 1.1075 - acc: 0.7049]


Epoch : 4 - Top 1: 51.37 - Top 5: 76.53 -  Train Time: 275.81 - Test Time: 12.87



Epoch 5: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 1.1038 - acc: 0.7051]


Epoch : 5 - Top 1: 52.13 - Top 5: 77.16 -  Train Time: 275.84 - Test Time: 12.79



Epoch 6: 100%|██████████| 329/329 [04:35<00:00,  1.19batch/s,  loss : 1.1033 - acc: 0.7070]


Epoch : 6 - Top 1: 51.86 - Top 5: 76.89 -  Train Time: 275.86 - Test Time: 12.92

Finished Training


In [46]:
# Reload the weights stored at PATH before post-training on the blurred data. Presumably
# this is the best checkpoint from the classification() run above -- confirm against dualopt.
model.load_state_dict(torch.load(PATH))

<All keys matched successfully>

In [47]:
# Post-training pass on the second blurred training loader (the log below reports SGD).
post_train(
    model,
    trainloader_blur_2,
    testloader_blur,
    device,
    PATH,
    top1,
    top5,
    traintime,
    testtime,
    num_classes=num_classes,
    set_counter=counter,
)
print('Finished Training')

Post-training with SGD


Epoch 1: 100%|██████████| 329/329 [04:36<00:00,  1.19batch/s,  loss : 0.6303 - acc: 0.8215]


Epoch : 1 - Top 1: 51.25 - Top 5: 76.55 -  Train Time: 276.55 - Test Time: 12.91



Epoch 2: 100%|██████████| 329/329 [04:36<00:00,  1.19batch/s,  loss : 0.6171 - acc: 0.8251]


Epoch : 2 - Top 1: 50.88 - Top 5: 76.49 -  Train Time: 276.38 - Test Time: 12.93



Epoch 3: 100%|██████████| 329/329 [04:36<00:00,  1.19batch/s,  loss : 0.6111 - acc: 0.8266]


Epoch : 3 - Top 1: 50.51 - Top 5: 75.95 -  Train Time: 276.36 - Test Time: 12.91

Finished Training


In [48]:
# Blurred-image summary. Top-1 and top-5 are each maximised on their own, so they may
# come from different epochs; the times are the fastest single entries, not totals.
print("Results")
print(f"Top 1 Accuracy: {max(top1):.2f} -Top 5 Accuracy : {max(top5):.2f} - Train Time: {min(traintime):.0f} -Test Time: {min(testtime):.0f}")

Results
Top 1 Accuracy: 52.63 -Top 5 Accuracy : 77.59 - Train Time: 276 -Test Time: 13


### Validation images with Gaussian noise

In [49]:
# Start the Gaussian-noise experiment with empty histories:
# top-1 accuracy, top-5 accuracy, train time and test time.
top1, top5, traintime, testtime = [], [], [], []

In [50]:
# Same dual-optimiser routine, now with the noisy train / test loaders.
# NOTE(review): `model` is not re-initialised between experiments in the cells shown
# here, so this run starts from the previous section's weights -- confirm this is intended.
classification(
    model,
    trainloader_noise_1,
    testloader_noise,
    device,
    PATH,
    top1,
    top5,
    traintime,
    testtime,
    num_classes=num_classes,
    set_counter=counter,
)
print('Finished Training')

Training with AdamW


Epoch 1: 100%|██████████| 329/329 [04:39<00:00,  1.18batch/s,  loss : 1.1752 - acc: 0.6885]


Epoch : 1 - Top 1: 43.22 - Top 5: 68.71 -  Train Time: 279.34 - Test Time: 13.02

1


Epoch 2: 100%|██████████| 329/329 [04:38<00:00,  1.18batch/s,  loss : 1.1726 - acc: 0.6875]


Epoch : 2 - Top 1: 43.22 - Top 5: 67.36 -  Train Time: 279.14 - Test Time: 12.85

1


Epoch 3: 100%|██████████| 329/329 [04:38<00:00,  1.18batch/s,  loss : 1.1557 - acc: 0.6935]


Epoch : 3 - Top 1: 44.57 - Top 5: 69.53 -  Train Time: 278.92 - Test Time: 12.90

1


Epoch 4: 100%|██████████| 329/329 [04:38<00:00,  1.18batch/s,  loss : 1.1560 - acc: 0.6928]


Epoch : 4 - Top 1: 42.50 - Top 5: 67.55 -  Train Time: 279.17 - Test Time: 12.85



Epoch 5: 100%|██████████| 329/329 [04:38<00:00,  1.18batch/s,  loss : 1.1325 - acc: 0.6970]


Epoch : 5 - Top 1: 44.47 - Top 5: 69.24 -  Train Time: 278.91 - Test Time: 12.92



Epoch 6: 100%|██████████| 329/329 [04:38<00:00,  1.18batch/s,  loss : 1.1333 - acc: 0.6975]


Epoch : 6 - Top 1: 43.83 - Top 5: 68.81 -  Train Time: 278.91 - Test Time: 12.92

Finished Training
Training with SGD


Epoch 1: 100%|██████████| 329/329 [04:36<00:00,  1.19batch/s,  loss : 1.1004 - acc: 0.7060]


Epoch : 1 - Top 1: 44.05 - Top 5: 68.67 -  Train Time: 276.59 - Test Time: 12.95



Epoch 2: 100%|██████████| 329/329 [04:36<00:00,  1.19batch/s,  loss : 1.0836 - acc: 0.7096]


Epoch : 2 - Top 1: 44.21 - Top 5: 69.24 -  Train Time: 276.68 - Test Time: 12.86



Epoch 3: 100%|██████████| 329/329 [04:36<00:00,  1.19batch/s,  loss : 1.0824 - acc: 0.7115]


Epoch : 3 - Top 1: 44.38 - Top 5: 69.23 -  Train Time: 276.59 - Test Time: 12.92

Finished Training


In [51]:
# Reload the weights stored at PATH before post-training on the noisy data. Presumably
# this is the best checkpoint from the classification() run above -- confirm against dualopt.
model.load_state_dict(torch.load(PATH))

<All keys matched successfully>

In [52]:
# Post-training pass on the second noisy training loader (the log below reports SGD).
post_train(
    model,
    trainloader_noise_2,
    testloader_noise,
    device,
    PATH,
    top1,
    top5,
    traintime,
    testtime,
    num_classes=num_classes,
    set_counter=counter,
)
print('Finished Training')

Post-training with SGD


Epoch 1: 100%|██████████| 329/329 [04:36<00:00,  1.19batch/s,  loss : 0.6149 - acc: 0.8236]


Epoch : 1 - Top 1: 43.70 - Top 5: 68.83 -  Train Time: 276.71 - Test Time: 12.82



Epoch 2: 100%|██████████| 329/329 [04:36<00:00,  1.19batch/s,  loss : 0.5964 - acc: 0.8291]


Epoch : 2 - Top 1: 44.35 - Top 5: 69.31 -  Train Time: 276.38 - Test Time: 12.92



Epoch 3: 100%|██████████| 329/329 [04:36<00:00,  1.19batch/s,  loss : 0.5911 - acc: 0.8315]


Epoch : 3 - Top 1: 44.27 - Top 5: 69.24 -  Train Time: 276.39 - Test Time: 12.92

Finished Training


In [53]:
# Gaussian-noise summary. Top-1 and top-5 are each maximised on their own, so they may
# come from different epochs; the times are the fastest single entries, not totals.
print("Results")
print(f"Top 1 Accuracy: {max(top1):.2f} -Top 5 Accuracy : {max(top5):.2f} - Train Time: {min(traintime):.0f} -Test Time: {min(testtime):.0f}")

Results
Top 1 Accuracy: 44.57 -Top 5 Accuracy : 69.53 - Train Time: 276 -Test Time: 13
