# Move reused code into Python script files
Jupyter is a great place to test ideas and get code working. It can also be a good place to run and document your experiments. However, once your code works, it is usually a good idea to move the reusable parts into Python script files. This lets you create copies of your notebook without duplicating the logic for loading data, defining models, and running training. In the end, the notebook should simply document the experiment that you performed.

By default, Jupyter does not reload imported modules: once a module has been imported, later `import` statements reuse the cached copy even if the file on disk has changed. If you are editing local .py files, use the `autoreload` extension to reload them automatically.

In [None]:
# use autoreload because, by default, Python will not re-import a module that is already loaded
%load_ext autoreload
%autoreload 2
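
To see why this matters, here is a minimal sketch of the stale-import behavior using only the standard library. The module name `scratch_mod` and its `VALUE` attribute are hypothetical, invented for this illustration. The explicit `importlib.reload` call is the manual counterpart of what `%autoreload 2` does for you before each cell runs.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Skip writing .pyc files so a cached bytecode file cannot mask the edit below.
sys.dont_write_bytecode = True

# Write a throwaway module (hypothetical name) into a temporary directory.
tmp_dir = tempfile.mkdtemp()
module_file = Path(tmp_dir) / "scratch_mod.py"
module_file.write_text("VALUE = 'original'\n")
sys.path.insert(0, tmp_dir)

import scratch_mod
first = scratch_mod.VALUE

# Edit the file on disk, as you would in an editor.
module_file.write_text("VALUE = 'edited'\n")

# A second import statement returns the cached module, so the edit is not seen.
import scratch_mod
stale = scratch_mod.VALUE

# An explicit reload re-executes the file and picks up the change.
importlib.reload(scratch_mod)
fresh = scratch_mod.VALUE

print(first, stale, fresh)
```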

In [None]:
import os
import torch
from torchvision import transforms
import matplotlib.pyplot as plt

## Settings 
These parameters are inputs to the fitting process. We leave them in the notebook because we might change them from one experiment to the next.

In [None]:
data_dir = f"/scratch/{os.environ['USER']}/data"
model_path = f"/scratch/{os.environ['USER']}/model.pt"

# Model and Training
epochs = 5  # number of training epochs
batch_size = 128  # input batch size for training (default: 64)
test_batch_size = 1000  # input batch size for testing (default: 1000)
num_workers = 10  # parallel data loading to speed things up
lr = 1.0  # learning rate (default: 1.0)
gamma = 0.7  # learning rate step gamma (default: 0.7)
no_cuda = False  # disables CUDA training (default: False)
seed = 42  # random seed (default: 42)
log_interval = 10  # how many batches to wait before logging training status (default: 10)
save_model = False  # save the trained model (default: False)

# additional derived settings
use_cuda = not no_cuda and torch.cuda.is_available()
torch.manual_seed(seed)
device = torch.device("cuda" if use_cuda else "cpu")

print("Device:", device)

## Dataset
The logic for loading data will be repeated across several experiments. To avoid code duplication, we move code out of the notebook and into [a separate .py file](https://github.com/clemsonciti/rcde_workshops/blob/b0f64e2a95447a69906f5f95c9c1af2d339a6c9e/pytorch_advanced/utils/data.py). This also reduces the number of import statements needed in the notebook itself. 

In [None]:
from utils import data

# transforms (we may wish to experiment with these so leave as inputs)
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
test_transforms = train_transforms

train_loader = data.get_train_dataloader(data_dir, train_transforms, batch_size, num_workers)
test_loader = data.get_test_dataloader(data_dir, test_transforms, test_batch_size, num_workers)

# save a test batch for later testing
image_gen = iter(test_loader)
test_img, test_trg = next(image_gen)

In [None]:
print("Training dataset:", train_loader.dataset)
print("Testing dataset:", test_loader.dataset)

## Model definition
We move the model definitions to [a models.py file](https://github.com/clemsonciti/rcde_workshops/blob/b0f64e2a95447a69906f5f95c9c1af2d339a6c9e/pytorch_advanced/utils/models.py). This file also contains test code for developing the model. In the future we may place several different model definitions into this file, so that we can compare different architecture choices.

In [None]:
from utils import models

# Create the model
model = models.Classifier()

In [None]:
# let's make sure we can run a batch of data through the model
with torch.no_grad():
    x, y = next(iter(train_loader))
    y_hat = model(x)
    
y_hat.shape, y_hat, y_hat.sum(axis=-1)

In [None]:
model

In [None]:
print("Number of parameters:", model.num_params())
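
`num_params()` is not a built-in `torch.nn.Module` method, so the `Classifier` class in `models.py` must define it. One common way to write such a method is sketched below. The `TinyClassifier` class and its layer sizes are hypothetical stand-ins chosen for brevity; the real `Classifier` may count parameters differently.

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for models.Classifier, for illustration only."""

    def __init__(self, in_features: int = 784, num_classes: int = 10):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(self.flatten(x))

    def num_params(self) -> int:
        # Count only trainable parameters.
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


tiny = TinyClassifier()
print("Number of parameters:", tiny.num_params())
```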

## Training and testing
We also move our training logic into [its own .py file](https://github.com/clemsonciti/rcde_workshops/blob/b0f64e2a95447a69906f5f95c9c1af2d339a6c9e/pytorch_advanced/utils/training.py).

In [None]:
from utils import training

In [None]:
model = models.Classifier().to(device)
model

In [None]:
training.train_and_test(model, train_loader, test_loader, epochs, lr, gamma, device)
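
For readers without access to the linked file, a minimal sketch of a function with the same call signature follows. The optimizer (Adadelta), the per-epoch `StepLR` schedule, the cross-entropy loss on raw logits, and the returned accuracy list are all assumptions, loosely modeled on the standard PyTorch MNIST example, which uses the same `lr=1.0` and `gamma=0.7` defaults. The workshop's `training.py` may differ.

```python
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR


def train_and_test(model, train_loader, test_loader, epochs, lr, gamma, device):
    """Sketch only: train for `epochs` epochs, evaluating after each one."""
    optimizer = torch.optim.Adadelta(model.parameters(), lr=lr)  # assumed optimizer
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)  # multiply lr by gamma each epoch
    history = []
    for epoch in range(1, epochs + 1):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = F.cross_entropy(model(x), y)  # assumes the model returns raw logits
            loss.backward()
            optimizer.step()

        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                correct += (model(x).argmax(dim=1) == y).sum().item()
                total += y.size(0)
        accuracy = correct / total
        history.append(accuracy)
        print(f"Epoch {epoch}: test accuracy {accuracy:.4f}")
        scheduler.step()
    return history
```

Returning the per-epoch accuracies is a choice made for this sketch; it lets the notebook plot or compare runs without parsing printed output.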