# **Initializations**

In [None]:
import pickle,gzip,math,os,time,shutil,torch
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sys,gc,traceback
import fastcore.all as fc
from collections.abc import Mapping
from pathlib import Path
from operator import attrgetter,itemgetter
from functools import partial
from copy import copy
from contextlib import contextmanager

import torchvision.transforms.functional as TF
import torch.nn.functional as F
from torch import tensor,nn,optim
from torch.utils.data import DataLoader,default_collate
from torch.nn import init
from torcheval.metrics import MulticlassAccuracy
from datasets import load_dataset,load_dataset_builder

from miniai.datasets import *
from miniai.conv import *
from miniai.learner import *
from miniai.activations import *

In [None]:
from fastcore.test import test_close

# Print tensors with 2 decimal places, up to 140 characters per line, and no scientific notation
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
# Seed PyTorch's random number generator with a fixed value
torch.manual_seed(1)

import logging
# Suppress every log record at WARNING severity or lower
logging.disable(logging.WARNING)

# NOTE(review): this call is commented out, so torch.manual_seed(1) above is the only seeding in this cell
#set_seed(42)

## **Recreating the Case of Faulty Initializations**

Recall that in the previous notebook, `10_activations`, we used Hooks to visualize the layerwise activations of a simple baseline CNN trained with a `MomentumLearner`. This proved to be a powerful technique for identifying and diagnosing issues during the early stages of training.

We also established that while the baseline model had improved loss and accuracy metrics when compared to the standard `Learner`, it still began training incorrectly. This could be observed in the poor activation stats for its different layers. 

Let's proceed to recreate the faulty CNN to see how the issue can be fixed.

In [None]:
# Column names for inputs and targets, plus the dataset to fetch
xl = 'image'
yl = 'label'
name = "fashion_mnist"
dsd = load_dataset(name)

@inplace
def transformi(b):
    """Overwrite the `xl` entry of batch `b` with a list holding one tensor per item."""
    b[xl] = list(map(TF.to_tensor, b[xl]))

# Batch size, and the dataset with the transform above attached
bs = 1024
tds = dsd.with_transform(transformi)

In [None]:
# Load the data onto Dataloader and create batches
# Built from the transformed dataset `tds` with batch size `bs` and 4 worker processes
dls = DataLoaders.from_dd(tds, bs, num_workers=4)
# Training-split loader
dt = dls.train
# Pull the first batch; presumably (inputs, labels) given the xl/yl column names — verify against DataLoaders
xb, yb = next(iter(dt))
# Cell output: shape of the input batch and the first ten entries of yb
xb.shape, yb[:10]

In [None]:
# Setup model
def get_model():
    """Return the baseline CNN, moved to `def_device`.

    The network is four `conv` blocks (1->8->16->32->64 channels), a final
    `conv(64, 10, act=False)`, and an `nn.Flatten`.
    """
    widths = [1, 8, 16, 32, 64]
    # Consecutive width pairs give the in/out channels of each block
    layers = [conv(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]
    layers.append(conv(64, 10, act=False))
    layers.append(nn.Flatten())
    return nn.Sequential(*layers).to(def_device)

In [None]:
# Find a good learning rate using the LR Finder
# A fresh model is built just for this search, with DeviceCB as the only callback.
# NOTE(review): gamma/start_lr are presumably the per-step LR multiplier and the initial LR — confirm against miniai.learner
MomentumLearner(get_model(), dls, F.cross_entropy, cbs=[DeviceCB()]).lr_find(gamma=1.1, start_lr=1e-2)

In [None]:
# Setup metrics, activation stats, callbacks and learner
# Metrics callback reporting a MulticlassAccuracy under the name `accuracy`
metrics = MetricsCB(accuracy=MulticlassAccuracy())
# Activation statistics, with a filter built from nn.ReLU (presumably restricts tracking to ReLU modules — verify)
astats = ActivationStats(fc.risinstance(nn.ReLU))
# Callbacks handed to the learner: device placement, metrics, live progress plot, activation stats
cbs = [DeviceCB(), metrics, ProgressCB(plot=True), astats]
# New model with cross-entropy loss and lr=0.2
learn = MomentumLearner(get_model(), dls, F.cross_entropy, lr=0.2, cbs=cbs)

In [None]:
# Fit one epoch; `learn` was created with the metrics, progress and activation-stats callbacks
learn.fit(1)

In [None]:
# Activation stats for the model's layers following one epoch of training
# The trailing `;` suppresses the cell's return-value repr so only the plot is shown
astats.color_dim();

In [None]:
# Means and Standard Deviations of activations
# The trailing `;` suppresses the cell's return-value repr so only the plot is shown
astats.plot_stats();

### **Some Extra Functionality**  

The following functions allow us to clean and recover both GPU and system memory without having to restart the notebook kernel.

`clean_ipython_hist()` is designed to clean the history of commands in an active IPython session.

In [1]:
def clean_ipython_hist():
    """Blank out the input history held by the running IPython session.

    Adapted from IPython's own source. Does nothing (returns None) when
    `get_ipython` is not present in this module's globals.
    """
    if 'get_ipython' not in globals():
        return
    shell = get_ipython()
    namespace = shell.user_ns  # dict holding the user's variables
    # One past the display hook's prompt counter
    n_slots = shell.displayhook.prompt_count + 1
    # Remove the per-cell input variables `_i1`, `_i2`, ...; absent keys are ignored
    for idx in range(1, n_slots):
        namespace.pop('_i' + repr(idx), None)
    # Empty the shortcuts for the three most recent inputs
    namespace.update(_i='', _ii='', _iii='')
    history = shell.history_manager
    blanks = [''] * n_slots
    # Overwrite both history lists in place so existing references see the cleared contents
    history.input_hist_parsed[:] = blanks
    history.input_hist_raw[:] = blanks
    # Reset the history manager's own record of recent inputs
    for attr in ('_i', '_ii', '_iii', '_i00'):
        setattr(history, attr, '')

Next, `clean_tb()` frees resources and prevents memory leaks after an exception has occurred: it clears the frames of the last traceback, releasing the local variables they still reference.

In [3]:
def clean_tb():
    """Release everything the interpreter keeps about the last unhandled exception.

    The frames of `sys.last_traceback` are cleared so the local variables they
    reference can be freed, then the `sys.last_*` attributes are removed.
    Safe to call when no exception has been recorded.
    """
    if hasattr(sys, 'last_traceback'):
        # Drop references to local variables held by each frame in the traceback
        traceback.clear_frames(sys.last_traceback)
    # `last_exc` is set by Python 3.12+ alongside the older three attributes; it
    # holds the exception object (and through it the traceback), so remove it too.
    for attr in ('last_traceback', 'last_type', 'last_value', 'last_exc'):
        if hasattr(sys, attr):
            delattr(sys, attr)

In [None]:
# Function to run memory cleaning ops
def clean_mem():
    """Free as much system and GPU memory as possible without restarting the kernel.

    Clears the last traceback, clears the IPython input history, runs the
    garbage collector, then asks PyTorch to release its cached CUDA memory.
    """
    clean_tb()
    clean_ipython_hist()
    # Collect first so tensors that just became unreachable are freed before the cache is released
    gc.collect()
    # The parentheses matter: referencing `torch.cuda.empty_cache` without calling it does nothing
    torch.cuda.empty_cache()