# Chem 277B Spring 2024 Tutorial 7
---
# Outline

+ Convolutional Neural Network (CNN):
    + Hyperparameters in CNN: channels, padding, stride, dilation
    + Pooling
    + CNN in PyTorch
+ Residual Network
+ Batch Normalization

# HW6 - Helper function


You can use the following decorator to report time:

In [None]:
import functools
import time

def timeit(f):
    """Decorator that prints how long each call to ``f`` takes.

    The wrapped function is called with the original arguments and its
    return value is passed through unchanged. After the call returns, a line
    of the form ``func:<name> took: <seconds> sec`` is printed. If ``f``
    raises, the exception propagates and nothing is printed.

    Args:
        f: The callable to time.

    Returns:
        A wrapper with the same call signature and metadata as ``f``.
    """

    # functools.wraps keeps __name__/__doc__ of the decorated function so that
    # introspection (help(), repr, nested decorators) still sees the original.
    @functools.wraps(f)
    def timed(*args, **kw):

        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        ts = time.perf_counter()
        result = f(*args, **kw)
        te = time.perf_counter()

        print(f'func:{f.__name__} took: {te-ts:.4f} sec')
        return result

    return timed

# Example usage: decorating a function makes every call print its run time.
@timeit
def sleep(sec):
    """Pause for ``sec`` seconds via ``time.sleep`` and return its result."""
    return time.sleep(sec)

# Prints a line of the form "func:sleep took: <seconds> sec".
sleep(0.1)

In [None]:
class Trainer:
    """Mini-batch trainer for a classification model with cross-entropy loss.

    Attributes:
        model: The ``nn.Module`` being trained (modified in place).
        optimizer: The optimizer built from ``opt_method``.
        epoch: Number of passes over the training data.
        batch_size: Batch size used for both training and evaluation.
    """

    def __init__(self, model, opt_method, learning_rate, batch_size, epoch, l2):
        """Store the model and build the optimizer.

        Args:
            model: Model whose parameters are optimized.
            opt_method: Optimizer name; only ``"adam"`` is supported.
            learning_rate: Learning rate passed to the optimizer.
            batch_size: Number of samples per batch.
            epoch: Number of training epochs.
            l2: Weight-decay (L2 penalty) coefficient passed to the optimizer.

        Raises:
            NotImplementedError: If ``opt_method`` is not ``"adam"``.
        """
        self.model = model

        if opt_method == "adam":
            self.optimizer = torch.optim.Adam(model.parameters(), learning_rate, weight_decay=l2)
        else:
            raise NotImplementedError("This optimization is not supported")

        self.epoch = epoch
        self.batch_size = batch_size

    def _snapshot_weights(self):
        """Return an independent copy of the model's current state dict."""
        # state_dict() hands back tensors that share storage with the live
        # parameters; without clone() the "snapshot" would silently follow
        # every later optimizer step.
        return {
            name: tensor.detach().clone()
            for name, tensor in self.model.state_dict().items()
        }

    @timeit
    def train(self, train_data, val_data, early_stop=True, verbose=True, draw_curve=True):
        """Train the model and record per-epoch loss and accuracy.

        Args:
            train_data: Dataset yielding ``(X, y)`` pairs used for optimization.
            val_data: Dataset yielding ``(X, y)`` pairs used for validation.
            early_stop: If True, the weights from the epoch with the lowest
                validation loss are restored once training finishes.
            verbose: Accepted for interface compatibility; currently unused.
            draw_curve: If True, plot loss and accuracy curves with matplotlib.

        Returns:
            A dict with the keys ``train_loss_list``, ``train_acc_list``,
            ``val_loss_list`` and ``val_acc_list``; each value is a list with
            one entry per epoch.
        """
        train_loader = DataLoader(train_data, batch_size=self.batch_size, shuffle=True)

        train_loss_list, train_acc_list = [], []
        val_loss_list, val_acc_list = [], []
        # Start from the initial weights so there is always something to
        # restore, even if the validation loss never improves (e.g. NaN).
        best_weights = self._snapshot_weights() if early_stop else None
        lowest_val_loss = np.inf
        loss_func = nn.CrossEntropyLoss()
        for _ in tqdm(range(self.epoch), leave=False):
            # enable train mode
            self.model.train()
            for X_batch, y_batch in train_loader:
                y_pred = self.model(X_batch)
                batch_loss = loss_func(y_pred, y_batch)

                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

            # Metrics are recomputed in eval mode after the epoch: statistics
            # accumulated during the training pass can be pessimistic because
            # layers such as dropout are active.
            train_loss, train_acc = self.evaluate(train_data)
            train_loss_list.append(train_loss)
            train_acc_list.append(train_acc)

            val_loss, val_acc = self.evaluate(val_data)
            val_loss_list.append(val_loss)
            val_acc_list.append(val_acc)

            if early_stop and val_loss < lowest_val_loss:
                lowest_val_loss = val_loss
                best_weights = self._snapshot_weights()

        if draw_curve:
            x_axis = np.arange(self.epoch)
            fig, axes = plt.subplots(1, 2, figsize=(10, 4))
            axes[0].plot(x_axis, train_loss_list, label="Train")
            axes[0].plot(x_axis, val_loss_list, label="Validation")
            axes[0].set_title("Loss")
            axes[0].legend()
            axes[1].plot(x_axis, train_acc_list, label='Train')
            axes[1].plot(x_axis, val_acc_list, label='Validation')
            axes[1].set_title("Accuracy")
            axes[1].legend()

        if early_stop:
            self.model.load_state_dict(best_weights)

        return {
            "train_loss_list": train_loss_list,
            "train_acc_list": train_acc_list,
            "val_loss_list": val_loss_list,
            "val_acc_list": val_acc_list,
        }

    def evaluate(self, data, print_acc=False):
        """Compute the average loss and accuracy of the model on ``data``.

        The model is switched to evaluation mode and is left in that mode.

        Args:
            data: Dataset yielding ``(X, y)`` pairs.
            print_acc: If True, print the accuracy with three decimals.

        Returns:
            A ``(loss, acc)`` tuple of Python floats, each averaged over all
            samples in ``data``.
        """
        # enable evaluation mode
        self.model.eval()
        # Sample order does not affect the averaged metrics, so no shuffling.
        loader = DataLoader(data, batch_size=self.batch_size, shuffle=False)
        loss_func = nn.CrossEntropyLoss()
        acc, loss = 0.0, 0.0
        with torch.no_grad():
            for X_batch, y_batch in loader:
                # Weight each batch by its share of the data so that a smaller
                # final batch does not skew the average.
                batch_importance = y_batch.shape[0] / len(data)
                y_pred = self.model(X_batch)
                batch_loss = loss_func(y_pred, y_batch)
                batch_acc = torch.sum(torch.argmax(y_pred, dim=1) == y_batch) / y_batch.shape[0]
                acc += batch_acc.item() * batch_importance
                loss += batch_loss.item() * batch_importance
        if print_acc:
            print(f"Accuracy: {acc:.3f}")
        return loss, acc

## Convolutional Neural Network (CNN)

### CNN general architecture
![](https://cdn-images-1.medium.com/max/800/1*lvvWF48t7cyRWqct13eU0w.jpeg)  


### Convolution Filters help extract features
![](https://qph.fs.quoracdn.net/main-qimg-50915e66f98186a786b3d0344eea9aba-pjlq)  

### Calculating convolution output shape
Here is a [visualization](https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md) for padding, stride and dilation

$$H_{\text{out}}=\left\lfloor\frac{H_{\text{in}}+2 \times \text{padding}-\text{dilation} \times(\text{kernel size}-1)-1}{\text{stride}}+1\right\rfloor$$


In [None]:
import pickle
import torch
import torch.nn as nn

In [None]:
# init a Conv2d layer
# NOTE: `...` is a fill-in placeholder; replace it with a convolution module
# (e.g. an nn.Conv2d instance) before `conv` is called on data further down.
conv = ...
conv

![](https://production-media.paperswithcode.com/methods/MaxpoolSample2.png)  

In [None]:
# init a MaxPool layer
# NOTE: `...` is a fill-in placeholder (presumably for an nn.MaxPool2d instance).
max_pool = ...
max_pool

In [None]:
# init an Average Pool layer
# NOTE: `...` is a fill-in placeholder (presumably for an nn.AvgPool2d instance).
avg_pool = ...
avg_pool

In [None]:
def out_dim(in_dim, kernel_size, padding=0, stride=1, dilation=1):
    """Return the output length of a convolution/pooling along one dimension.

    Implements
    ``floor((in_dim + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1``.
    The defaults (no padding, unit stride, no dilation) let the common case
    be written as ``out_dim(in_dim, kernel_size)``.

    Args:
        in_dim: Input size along the dimension.
        kernel_size: Kernel size along the dimension.
        padding: Padding added to both sides of the input.
        stride: Step between successive kernel positions.
        dilation: Spacing between kernel elements.

    Returns:
        The output size along the dimension, as an int.
    """
    # The effective kernel extent grows with dilation: dilation*(k-1) + 1.
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (in_dim + 2 * padding - effective_kernel) // stride + 1


# data shape: (N, C, H, W) -- batch, channels, height, width, following the
# layout PyTorch documents for Conv2d inputs; here one 1-channel 2x2 sample
data = torch.rand(1, 1, 2, 2)
conv(data)

### LeNet architecture
LeCun, Y.; Bottou, L.; Bengio, Y. & Haffner, P. (1998). Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11): 2278–2324.

|Layer No.|Layer type|#channels/#features|Kernel size|Stride|Activation|
|---|---|---|---|---|---|
|1|2D Convolution|6|5|1|tanh|
|2|Average pooling|6|2|2|\\|
|3|2D Convolution|16|5|1|tanh|
|4|Average pooling|16|2|2|\\|
|5|2D Convolution|120|5|1|tanh|
|6|Flatten|\\|\\|\\|\\|
|7|Fully connected|84|\\|\\|tanh|
|8|Fully connected|10|\\|\\|softmax|

In [None]:
def load_dataset(path):
    """Load a pickled ``(train, test)`` pair and convert it to tensors.

    Each of the two pickled items is indexed as ``item[0]`` (inputs) and
    ``item[1]`` (labels). Inputs become float tensors with a new axis
    inserted at position 1; labels become long tensors.

    Args:
        path: Location of the pickle file.

    Returns:
        The tuple ``(X_train, y_train, X_test, y_test)``.
    """
    def _as_tensors(split):
        # inputs get an extra axis at dim 1; labels keep their shape
        inputs = torch.unsqueeze(torch.tensor(split[0], dtype=torch.float), 1)
        labels = torch.tensor(split[1], dtype=torch.long)
        return inputs, labels

    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    with open(path, 'rb') as handle:
        train_split, test_split = pickle.load(handle)

    X_train, y_train = _as_tensors(train_split)
    X_test, y_test = _as_tensors(test_split)
    return X_train, y_train, X_test, y_test

# The path is relative to the current working directory; adjust it if the
# notebook is run from a different location.
X_train, y_train, X_test, y_test = load_dataset("../../Datasets/mnist.pkl")

In [None]:
class LeNet(nn.Module):
    """LeNet-style CNN skeleton following the layer table above.

    The convolution stack (``self.conv``) and the pooling layer
    (``self.pool``) are left as ``...`` placeholders to be filled in; the
    class is not expected to be usable until they are replaced with real
    modules.
    """

    def __init__(self, in_channels=1):
        super().__init__()
        # Placeholder: forward() indexes conv[0], conv[1] and conv[2], so
        # three convolution layers are expected here.
        self.conv = nn.ModuleList([
            ...
        ])
        # Placeholder: a single pooling module shared by the first two stages.
        self.pool = ...
        self.activation = nn.Tanh()
        # fc[0] takes 120 input features, so the flattened output of conv[2]
        # must have 120 values per sample.
        self.fc = nn.ModuleList([
            nn.Linear(120, 84),
            nn.Linear(84, 10)
        ])
    
    def forward(self, x):
        # Stages 1-2: convolution -> tanh -> pooling.
        for i in range(2):
            x = self.pool(self.activation(self.conv[i](x)))
        # Stage 3: convolution -> tanh, then flatten to (batch, features).
        x = nn.Flatten()(self.activation(self.conv[2](x)))
        x = self.activation(self.fc[0](x))
        # NOTE(review): the output is softmax-normalized over the last axis.
        # The Trainer in this file uses nn.CrossEntropyLoss, which PyTorch
        # documents as expecting unnormalized logits, so training on this
        # output applies softmax twice -- confirm whether that is intended.
        x = nn.Softmax(dim=-1)(self.fc[1](x))
        return x

net = LeNet()
net

In [None]:
# Use torchsummary to print the architecture
# ! pip install torchsummary
from torchsummary import summary

# (1, 32, 32) is presumably the per-sample input size (channels, height,
# width) without the batch dimension -- confirm against torchsummary's docs.
s = summary(net, (1, 32, 32))

In [None]:
# Forward pass on the first 10 training samples.
net(X_train[:10])

## Residual Network (ResNet)


An example of residual block:

<img src="https://miro.medium.com/v2/resize:fit:868/format:webp/0*sGlmENAXIZhSqyFZ" width="400" />

In [None]:
class ResBlock(nn.Module):
    """Fully connected residual block.

    Computes ``ReLU(fc[1](ReLU(fc[0](x))) + x)``. Both linear layers map
    ``dim`` features to ``dim`` features, so the input can be added to the
    branch output without any reshaping.
    """

    def __init__(self, dim):
        super().__init__()
        # two square linear layers forming the residual branch
        self.fc = nn.ModuleList([nn.Linear(dim, dim) for _ in range(2)])
        self.activation = nn.ReLU()

    def forward(self, x):
        first, second = self.fc
        hidden = self.activation(first(x))
        branch = second(hidden)
        # skip connection: add the untouched input before the final ReLU
        return self.activation(branch + x)
    

In [None]:
# Let's modify the LeNet by adding a skip connection at the first fc layer
# (the class body is left as a `...` placeholder to be filled in)
class LeNetRes(nn.Module):
    ...

## Batch Normalization (BN)

For a 4-D input $X$ with shape $(N,C,W,H)$, each channel $j$ is normalized over the batch and spatial dimensions by:

$$\hat{X}_{ijkl}=\frac{X_{ijkl}-\mathrm{mean}(X_j)}{\sqrt{\mathrm{var}(X_j)+\epsilon}} * \gamma_j + \beta_j$$

where

$$\mathrm{mean}(X_j)=\frac{1}{NWH}\sum_{i=1}^{N}\sum_{k=1}^{W}\sum_{l=1}^{H} X_{ijkl}$$
$$\mathrm{var}(X_j)=\frac{1}{NWH}\sum_{i=1}^{N}\sum_{k=1}^{W}\sum_{l=1}^{H} (X_{ijkl}-\mathrm{mean}(X_j))^2$$

$\epsilon$ is a small number (say, $10^{-5}$) to avoid numerical instability. $\boldsymbol{\gamma, \beta}$ are learnable parameters

In [None]:
# `...` is a fill-in placeholder for a batch-normalization layer
# (for 4-D inputs presumably an nn.BatchNorm2d instance -- confirm).
batch_norm = ...