# Wide & Deep model using PyTorch
---

In [2]:
import torch
import numpy as np
from tqdm import tqdm
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

## Load California housing data

In [3]:
housing = fetch_california_housing()

# Fixed seed so the train/valid/test split is identical after every kernel restart.
SEED = 42

x_train_full, x_test, y_train_full, y_test = train_test_split(
    housing.data, housing.target, random_state=SEED
)
x_train, x_valid, y_train, y_valid = train_test_split(
    x_train_full, y_train_full, random_state=SEED
)

# Targets become float column vectors of shape (n, 1) so they match the (batch, 1)
# output of the network; a 1-D target would silently broadcast inside mse_loss.
y_train, y_valid, y_test = (
    torch.as_tensor(y, dtype=torch.float).reshape(-1, 1) for y in (y_train, y_valid, y_test)
)

# Fit the scaler on the training split only, then reuse it for valid/test (no leakage).
scaler = StandardScaler()
x_train = torch.as_tensor(scaler.fit_transform(x_train), dtype=torch.float)
x_valid = torch.as_tensor(scaler.transform(x_valid), dtype=torch.float)
x_test = torch.as_tensor(scaler.transform(x_test), dtype=torch.float)

## Model Structure 
---
The model structure to be built is:

 - Takes as input the 8-dimensional California housing feature vector
 - Splits the input into its first 5 features (input_1) and its last 6 features (input_2); the two slices overlap on features 2–4
 - input_1 passes through a dense (fully connected) layer with 30 neurons and ReLU activation, and then through a second dense layer with the same setup
 - input_2 is then concatenated onto the output of the second hidden layer, giving a 36-dimensional (30 + 6) vector
 - the concatenated vector is passed through a final dense layer with a single linear output unit, which is the prediction

 *How to do it:* Very similar to the `tf.keras` subclassing API: subclass `torch.nn.Module`, create the layers in `__init__`, then specify the network architecture in the `.forward()` method (autograd takes care of the backward pass).


In [4]:
# class defining the network architecture
class DeepWideNet(torch.nn.Module):
    """Wide & deep regression network with a single output unit.

    The deep path feeds ``input_1`` through two fully connected hidden layers
    of ``units`` neurons each. The wide path passes ``input_2`` through
    unchanged. The two are concatenated and mapped to one linear output.

    Args:
        units: width of each hidden layer (default 30).
        activation: module applied after each hidden layer; a fresh
            ``torch.nn.ReLU()`` is created when omitted, so instances never
            share an activation module.
        deep_features: number of columns in ``input_1`` (default 5).
        wide_features: number of columns in ``input_2`` (default 6).
        **kwargs: forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, units=30, activation=None, deep_features=5, wide_features=6, **kwargs):
        super().__init__(**kwargs)
        self.lin1 = torch.nn.Linear(deep_features, units)
        self.lin2 = torch.nn.Linear(units, units)
        # A module instance as a default argument would be shared by every model.
        self.activation = activation if activation is not None else torch.nn.ReLU()
        # The output layer sees the last hidden layer concatenated with the wide input.
        self.out = torch.nn.Linear(units + wide_features, 1)

    def forward(self, inputs):
        """Run the forward pass (autograd derives the backward pass).

        Args:
            inputs: pair ``(input_1, input_2)`` of 2-D tensors with shapes
                ``(batch, deep_features)`` and ``(batch, wide_features)``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        input_1, input_2 = inputs[0], inputs[1]
        hidden = self.activation(self.lin1(input_1))   # first hidden layer
        hidden = self.activation(self.lin2(hidden))    # second hidden layer
        return self.out(torch.cat((hidden, input_2), dim=1))

# Mean-squared-error loss for the regression target; torch's convention is
# loss(prediction, target), and both should have the same shape.
loss = torch.nn.functional.mse_loss

# helper function to initialise model and optimiser
def get_model(lr=0.1):
    """Build a fresh DeepWideNet and a plain SGD optimiser over its parameters.

    Args:
        lr: SGD learning rate (default 0.1).

    Returns:
        ``(model, optimiser)`` tuple.
    """
    model = DeepWideNet()
    return model, torch.optim.SGD(model.parameters(), lr=lr)

# helper function to generate batches
# NB: torch has a module for this, but doing directly for clarity
def data_generator(x, y, batch_size, shuffle=True):
    """Yield mini-batches formatted for DeepWideNet.

    Each item is ``((x_deep, x_wide), y_batch)``, where ``x_deep`` holds the
    first 5 columns and ``x_wide`` the last 6 columns of the batch's rows.
    Only full batches are produced; a trailing remainder smaller than
    ``batch_size`` is dropped.

    Args:
        x: 2-D feature tensor whose rows are samples.
        y: target tensor with the same number of rows as ``x``.
        batch_size: positive number of rows per batch.
        shuffle: if True (default), rows are randomly permuted on every call.

    Raises:
        ValueError: if ``batch_size`` is not positive (on first iteration).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    n = x.shape[0]
    # One permutation is applied to both x and y so features and targets stay aligned.
    order = torch.randperm(n) if shuffle else torch.arange(n)
    x_shuffled, y_shuffled = x[order], y[order]
    usable = (n // batch_size) * batch_size
    for start in range(0, usable, batch_size):
        end = start + batch_size
        x_batch = x_shuffled[start:end]
        yield (x_batch[:, :5], x_batch[:, -6:]), y_shuffled[start:end]

# define the training loop
def train_loop(model, optimiser, x, y, batch_size, epochs):
    """Train ``model`` with mini-batch gradient descent and record the loss.

    Each epoch draws a fresh set of batches from ``data_generator`` and takes
    one optimiser step per batch, using the module-level ``loss``.

    Args:
        model: network called as ``model((x_deep, x_wide))``.
        optimiser: torch optimiser over ``model``'s parameters.
        x: feature tensor whose rows are samples.
        y: target tensor with one target per row of ``x``.
        batch_size: rows per mini-batch.
        epochs: number of passes over the data.

    Returns:
        list[float]: the training loss of every batch, in order.
    """
    history = []
    num_batches = x.shape[0] // batch_size
    model.train()
    for epoch in range(epochs):
        batches = data_generator(x, y, batch_size)
        for (x_1, x_2), y_batch in tqdm(batches, total=num_batches,
                                        desc=f"epoch {epoch + 1}/{epochs}"):
            y_hat = model((x_1, x_2))
            # Prediction first, target second, with matching shapes: a 1-D target
            # against the (batch, 1) prediction would broadcast to (batch, batch).
            batch_loss = loss(y_hat, y_batch.reshape(y_hat.shape))
            optimiser.zero_grad()
            batch_loss.backward()
            optimiser.step()
            # .item() stores the scalar value (appending `.numpy` stored the method itself).
            history.append(batch_loss.item())
    return history


In [5]:
# Seed both torch and numpy RNGs so weight initialisation and any random batching
# are reproducible on a fresh kernel.
torch.manual_seed(42)
np.random.seed(42)

BATCH_SIZE = 32
EPOCHS = 10

mod, opt = get_model()
history = train_loop(mod, opt, x_train, y_train, BATCH_SIZE, EPOCHS)

100%|██████████| 362/362 [00:01<00:00, 263.15it/s]
100%|██████████| 362/362 [00:00<00:00, 363.02it/s]
100%|██████████| 362/362 [00:01<00:00, 204.12it/s]
100%|██████████| 362/362 [00:01<00:00, 198.16it/s]
100%|██████████| 362/362 [00:01<00:00, 302.53it/s]
100%|██████████| 362/362 [00:01<00:00, 309.72it/s]
100%|██████████| 362/362 [00:01<00:00, 331.33it/s]
100%|██████████| 362/362 [00:01<00:00, 318.77it/s]
100%|██████████| 362/362 [00:01<00:00, 326.85it/s]
100%|██████████| 362/362 [00:01<00:00, 295.80it/s]
