# Prepare Data Set

First, the data set is loaded. The function `load_data_from_df` automatically saves the calculated features to the provided data directory (unless `use_data_saving` is set to `False`). Every subsequent run will reuse the saved features instead of recalculating them.

In [None]:
import os
import pandas as pd
import torch
# Make 'src' the working directory; the relative paths used further down
# (e.g. '../data/freesolv/freesolv.csv') are written relative to it.
# Presumably this is also what lets the project-local imports in the next
# cells be resolved — confirm. Note that re-running this cell changes the
# directory again, relative to the current one.
os.chdir('src')
import copy

In [None]:
from featurization.data_utils import load_data_from_df, construct_loader
from tqdm import tqdm

In [None]:
# Batch size handed to construct_loader, and the CSV file the data set is read from.
batch_size = 64
dataset_csv = '../data/freesolv/freesolv.csv'

# one_hot_formal_charge=True keeps the features compatible with the pre-trained weights.
# When the pre-trained weights are not going to be used, setting it to False is recommended.
X, y = load_data_from_df(dataset_csv, one_hot_formal_charge=True)

# Wrap the loaded features and labels in a loader that is iterated batch by batch below.
data_loader = construct_loader(X, y, batch_size)

You can use your own data, but the CSV file must contain two columns, as shown below:

In [None]:
# Display the first rows of the example CSV to illustrate the expected layout.
freesolv_table = pd.read_csv('../data/freesolv/freesolv.csv')
freesolv_table.head()

# Prepare Model

In [None]:
from transformer import make_model

In [None]:
# Second dimension of the first entry of the first sample in X.
# Its value depends on the featurization that produced X.
first_sample_features = X[0][0]
d_atom = first_sample_features.shape[1]

# Keyword arguments forwarded verbatim to make_model.
# Leave these values untouched if the pre-trained weights are going to be loaded.
model_params = {
    'd_atom': d_atom,
    'd_model': 1024,
    'N': 8,
    'h': 16,
    'N_dense': 1,
    'lambda_attention': 0.33,
    'lambda_distance': 0.33,
    'leaky_relu_slope': 0.1,
    'dense_output_nonlinearity': 'relu',
    'distance_matrix_kernel': 'exp',
    'dropout': 0.0,
    'aggregation_type': 'mean',
}

model = make_model(**model_params)

# Load Pretrained Weights (optional)

If you want to use the pre-trained weights to train your model, **you should not change model parameters in the cell above**.

In [None]:
# pretrained_name = '../pretrained_weights.pt'  # This file should be downloaded first (See README.md).
# pretrained_state_dict = torch.load(pretrained_name)

In [None]:
# model_state_dict = model.state_dict()
# for name, param in pretrained_state_dict.items():
#     if 'generator' in name:
#          continue
#     if isinstance(param, torch.nn.Parameter):
#         param = param.data
#     model_state_dict[name].copy_(param)

# Run Training/Evaluation Loop

In [None]:
# Adam optimizer over all parameters of the model.
learning_rate = 3e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Loss histories. training_losses is appended to by the training loop;
# total_losses is not written to by any of the visible cells.
training_losses, total_losses = [], []

# Lowest evaluation MSE seen so far; starts at a large sentinel value.
best_so_far = 1e10

In [None]:
# Train for a fixed number of epochs. After every epoch the model is evaluated
# on the same data_loader, and a copy of the best-scoring model is kept in
# best_model (with its epoch index in best_epoch).
# model.cuda()
n_epochs = 20
for e in range(n_epochs):
    # The evaluation pass below switches the model to eval mode, so training
    # mode has to be re-enabled at the start of every epoch.
    model.train()
    for batch in tqdm(data_loader):
        optimizer.zero_grad()
        # Unpack into `targets` rather than `y`, so the labels loaded earlier
        # in the notebook are not overwritten by a single batch.
        adjacency_matrix, node_features, distance_matrix, targets = batch
        # True wherever a feature vector has at least one non-zero entry
        # (all-zero rows are presumably padding — confirm against the loader).
        batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
        output = model(node_features, batch_mask, adjacency_matrix, distance_matrix, None)
        loss = torch.mean((targets - output) ** 2)
        loss.backward()
        optimizer.step()
        # Store a plain float: it keeps no tensor alive and can be plotted
        # regardless of the device the loss was computed on.
        training_losses.append(loss.item())

    # Evaluation pass: disable training-only behaviour such as dropout.
    model.eval()
    with torch.no_grad():
        n_samples = 0
        sum_squared_error = 0.0
        print("Evaluating")
        for batch in tqdm(data_loader):
            adjacency_matrix, node_features, distance_matrix, targets = batch
            batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
            output = model(node_features, batch_mask, adjacency_matrix, distance_matrix, None)
            squared_error = (targets - output) ** 2
            # Accumulate per element instead of averaging per-batch means, so a
            # smaller final batch does not get a disproportionate weight.
            sum_squared_error += squared_error.sum().item()
            n_samples += squared_error.numel()
        cum_mse = sum_squared_error / n_samples
        print("MSE for epoch: ", cum_mse)
        if cum_mse < best_so_far:
            best_so_far = cum_mse
            # Snapshot the whole model so that later epochs cannot modify it.
            best_model = copy.deepcopy(model)
            best_epoch = e
            
            

In [None]:
import matplotlib.pyplot as plt

# Number of initial training steps left out of the plot (presumably because
# the earliest losses are large enough to flatten the rest of the curve — confirm).
skipped_steps = 14

# Convert every entry to a plain Python float before plotting: the recorded
# losses may be 0-d tensors, and tensors living on a GPU cannot be converted
# to NumPy arrays implicitly by matplotlib.
plt.plot([float(step_loss) for step_loss in training_losses[skipped_steps:]])

In [None]:
# Inspect the shape of the node-feature tensor left over from the last batch
# processed by the loop cell above (the variable leaks out of that loop).
node_features.shape

In [None]:
# Inspect the mask of the first sample in the last processed batch; entries are
# True where the corresponding node-feature vector has a non-zero element.
batch_mask[0]

In [None]:
# For every batch, print the targets (first row) next to the model outputs
# (second row) for the first five samples. Gradients are not tracked.
with torch.no_grad():
    for batch in tqdm(data_loader):
        adjacency_matrix, node_features, distance_matrix, y = batch
        # True for every position whose feature vector has a non-zero entry.
        batch_mask = node_features.abs().sum(dim=-1) != 0
        output = model(node_features, batch_mask, adjacency_matrix, distance_matrix, None)
        shown_targets = y[:5].squeeze(1)
        shown_predictions = output[:5].squeeze(1)
        print(torch.stack([shown_targets, shown_predictions]))