# Checkpoint 2: Network training
1. Append the code in this notebook to your Checkpoint 1 notebook
2. Fill out the ANITrainer class
3. Train a model on each of two small datasets (molecules with 1 heavy atom, and molecules with more than 1 heavy atom)
- Experiment with the batch size and number of epochs
- Aim for an RMSE (root-mean-square error) below 5 kcal/mol
4. Compare the training time on CPU and on GPU. Suggested workflow:
- Complete all the code and run it on your laptop
- Rerun the notebook on Savio with a GPU to see whether the training time improves

In [None]:
import copy
import functools
import time

import matplotlib.pyplot as plt

def timeit(f):
    """Decorate *f* so that every call prints its elapsed wall-clock time.

    The message also reports the module-level ``device`` global. If no
    ``device`` is defined yet, "unknown device" is printed instead of
    raising ``NameError`` after the timed call has already finished.

    Parameters
    ----------
    f : callable
        Function to time.

    Returns
    -------
    callable
        Wrapper that forwards all arguments to *f*, returns its result, and
        keeps *f*'s metadata (``__name__``, ``__doc__``) via ``functools.wraps``.
    """
    @functools.wraps(f)
    def timed(*args, **kw):
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # so it is the right clock for measuring an interval.
        ts = time.perf_counter()
        result = f(*args, **kw)
        te = time.perf_counter()
        dev = globals().get('device', 'unknown device')
        print(f'func: {f.__name__} took: {te-ts:.4f} sec on {dev}')
        return result
    return timed

class ANITrainer:
    """Train and evaluate an energy-prediction model (e.g. an ANI network).

    Parameters
    ----------
    model : torch.nn.Module
        Model to train; its parameters are updated in place by ``train``.
    batch_size : int
        Mini-batch size for the data loaders.
    learning_rate : float
        Optimizer learning rate.
    epoch : int
        Number of training epochs.
    l2 : float
        L2 regularization strength (presumably the optimizer's
        ``weight_decay`` -- TODO wire it in when filling out ``self.optimizer``).
    """

    def __init__(self, model, batch_size, learning_rate, epoch, l2):
        self.model = model

        num_params = sum(item.numel() for item in model.parameters())
        print(f"{model.__class__.__name__} - Number of parameters: {num_params}")

        self.batch_size = ...
        self.optimizer = ...
        self.epoch = ...
        # definition of loss function: MSE is a good choice!
        self.loss_function = ...

    @timeit
    def train(self, train_data, val_data,
              early_stop=True, draw_curve=True, verbose=True):
        """Train the model, optionally restoring the best-validation weights.

        Parameters
        ----------
        train_data, val_data
            Training and validation datasets (format defined in Checkpoint 1).
        early_stop : bool
            If True, reload the parameters from the epoch with the lowest
            validation loss once training finishes.
        draw_curve : bool
            If True, plot training and validation RMSE against epoch.
        verbose : bool
            If True, iterate over epochs with a plain ``range``; otherwise
            wrap the epochs in a ``tqdm`` progress bar.

        Returns
        -------
        tuple of (list, list)
            Per-epoch training losses and validation losses.
        """
        self.model.train()

        # init data loader
        print("Initialize training data...")
        train_data_loader = ...

        # record epoch losses
        train_loss_list = []
        val_loss_list = []
        lowest_val_loss = np.inf
        # Snapshot of the best parameters seen so far. It stays None if no
        # epoch beats the initial infinite loss (e.g. every validation loss
        # is NaN, or epoch == 0), so we never load an undefined snapshot.
        best_weights = None

        if verbose:
            iterator = range(self.epoch)
        else:
            iterator = tqdm(range(self.epoch), leave=True)

        for i in iterator:
            train_epoch_loss = 0.0
            for train_data_batch in train_data_loader:

                # compute energies
                ...

                # compute loss
                batch_loss = ...

                # do a step
                ...

                batch_importance = ...
                train_epoch_loss += ...

            # use the self.evaluate to get loss/MAE/RMSE on the validation set
            val_epoch_loss, mae, rmse = ...

            # append the losses
            ...

            if early_stop:
                if val_epoch_loss < lowest_val_loss:
                    lowest_val_loss = val_epoch_loss
                    # state_dict() returns references to the live parameter
                    # tensors, which later optimizer steps overwrite in place.
                    # Deep-copy so this really holds the best epoch's weights.
                    best_weights = copy.deepcopy(self.model.state_dict())

        if draw_curve:
            # Plot train loss and validation loss
            fig, ax = plt.subplots(1, 1, figsize=(5, 4), constrained_layout=True)
            # If you used MSELoss above to compute the loss
            # Calculate the RMSE for plotting
            ...
            ax.plot(..., ..., label='Train')
            ax.plot(..., ..., label='Validation')
            ax.legend()
            ax.set_xlabel("Epoch")
            ax.set_ylabel("RMSE")

        if early_stop:
            if best_weights is not None:
                self.model.load_state_dict(best_weights)
            else:
                print("Early stopping: no validation improvement was recorded; keeping final weights.")

        return train_loss_list, val_loss_list


    def evaluate(self, data, draw_plot=False):
        """Compute the loss, MAE and RMSE of the model on *data*.

        Parameters
        ----------
        data
            Dataset to evaluate (format defined in Checkpoint 1).
        draw_plot : bool
            If True, scatter predicted against true energies with a y = x line.

        Returns
        -------
        tuple
            ``(total_loss, mae, rmse)``; MAE and RMSE are to be reported in
            kcal/mol (converted from hartree below).

        Raises
        ------
        ValueError
            If the data loader yields no batches.
        """
        # Evaluate in inference mode (affects layers such as dropout or batch
        # norm), then restore the caller's mode so that evaluating from inside
        # train() does not leave the model in eval mode for later epochs.
        was_training = self.model.training
        self.model.eval()

        # init data loader
        data_loader = ...
        total_loss = 0.0

        # init energies containers
        true_energies_all = []
        pred_energies_all = []

        try:
            with torch.no_grad():
                for batch_data in data_loader:

                    # compute energies
                    ...

                    # compute loss
                    batch_loss = ...

                    batch_importance = ...
                    total_loss += ...

                    # store true and predicted energies
                    true_energies_all.append(true_energies.detach().cpu().numpy().flatten())
                    pred_energies_all.append(pred_energies.detach().cpu().numpy().flatten())
        finally:
            self.model.train(was_training)

        # np.concatenate([]) would fail with an opaque message; say why instead.
        if not true_energies_all:
            raise ValueError("evaluate() got no batches; the dataset is empty")
        true_energies_all = np.concatenate(true_energies_all)
        pred_energies_all = np.concatenate(pred_energies_all)

        # Report the mean absolute error (MAE) and root mean square error (RMSE)
        # The unit of energies in the dataset is hartree
        # please convert it to kcal/mol when reporting
        # 1 hartree = 627.5094738898777 kcal/mol
        # MAE = mean(|true - pred|)
        # RMSE = sqrt(mean( (true-pred)^2 ))
        hartree2kcalmol = ...
        mae = ...
        rmse = ...

        if draw_plot:
            fig, ax = plt.subplots(1, 1, figsize=(5, 4), constrained_layout=True)
            ax.scatter(true_energies_all, pred_energies_all, label=f"MAE: {mae:.2f} kcal/mol, RMSE: {rmse:.2f} kcal/mol", s=2)
            ax.set_xlabel("Ground Truth")
            ax.set_ylabel("Predicted")
            xmin, xmax = ax.get_xlim()
            ymin, ymax = ax.get_ylim()
            vmin, vmax = min(xmin, ymin), max(xmax, ymax)
            ax.set_xlim(vmin, vmax)
            ax.set_ylim(vmin, vmax)
            ax.plot([vmin, vmax], [vmin, vmax], color='red')
            ax.legend()

        return total_loss, mae, rmse

## 1 heavy atom

In [None]:
# Load dataset with 1 heavy atom
# Then do a train/val/test = 80/10/10 split
dataset = load_ani_dataset("...")
train_data, val_data, test_data = ...
print(f'Train/Total: {len(train_data)}/{len(dataset)}')

# Define the model.
# Deep-copy ani_net: training updates module parameters in place, so reusing
# the shared `ani_net` object here would make any later
# `nn.Sequential(aev_computer, ani_net)` start from these trained weights.
model = nn.Sequential(
    aev_computer,
    copy.deepcopy(ani_net)
).to(device)

# Initialize the trainer and evaluate on test_data with draw_plot=True
trainer = ANITrainer(model, ..., 1e-3, ..., 1e-5)
loss, mae, rmse = ...

In [None]:
# Run on CPU: train the model, then re-evaluate it on test_data with draw_plot=True
(train_losses, val_losses) = trainer.train(train_data=train_data, val_data=val_data, verbose=True)
loss, mae, rmse = ...

In [None]:
# Run on GPU
# NOTE: this cell trains on whatever `device` the model was moved to when it
# was defined, and it continues training the same `trainer` (and weights)
# used above. Warn when that device is not CUDA, because the reported time
# is then not a GPU measurement.
if torch.device(device).type != 'cuda':
    print(f"Warning: device is {device}, not CUDA; this run does not measure GPU time.")
# Perform training and re-evaluate on test_data with draw_plot=True
train_losses, val_losses = trainer.train(train_data, val_data, verbose=True)
loss, mae, rmse = ...

## n heavy atoms

In [None]:
# Load dataset with more than one heavy atom
# Then do a train/val/test = 80/10/10 split
dataset = load_ani_dataset("...")
train_data, val_data, test_data = ...
print(f'Train/Total: {len(train_data)}/{len(dataset)}')

# Define the model.
# Deep-copy ani_net so this model gets its own parameters instead of sharing
# (and further modifying) the module object used by other cells.
model = nn.Sequential(
    aev_computer,
    copy.deepcopy(ani_net)
).to(device)

# Initialize the trainer and evaluate on test_data with draw_plot=True
trainer = ANITrainer(model, ..., 1e-3, ..., 1e-5)
loss, mae, rmse = ...

In [None]:
# Run on CPU: train the model, then re-evaluate it on test_data with draw_plot=True
(train_losses, val_losses) = trainer.train(train_data=train_data, val_data=val_data, verbose=True)
loss, mae, rmse = ...

In [None]:
# Run on GPU
# NOTE: this cell trains on whatever `device` the model was moved to when it
# was defined, and it continues training the same `trainer` (and weights)
# used above. Warn when that device is not CUDA, because the reported time
# is then not a GPU measurement.
if torch.device(device).type != 'cuda':
    print(f"Warning: device is {device}, not CUDA; this run does not measure GPU time.")
# Perform training and re-evaluate on test_data with draw_plot=True
train_losses, val_losses = trainer.train(train_data, val_data, verbose=True)
loss, mae, rmse = ...