In [None]:
import pickle
from data import rxn
from data_process import *
from torch.utils.data import Dataset, DataLoader
from models import *
import dgl

In [None]:
import random
import argparse

# Run configuration. `parse_args([])` ignores the real command line so the
# cell works inside a Jupyter kernel; change the defaults below to tune a run.
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=50,
                    help='Number of epochs to train.')
parser.add_argument('--lr', type=float, default=0.001,
                    help='learning rate.')
parser.add_argument('--model', type=str, default="GNN",
                    help='MLP, GNN')
parser.add_argument('--dev', type=int, default=7)  # CUDA device index
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--batch_size', type=int, default=256)
# NOTE(review): no default, so this is None unless edited here.
parser.add_argument('--data_path', type=str)
args = parser.parse_args([])

# Seed python, numpy and torch (CPU + CUDA) RNGs for repeatable runs.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

# Use the requested GPU when CUDA is present, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{args.dev}")
else:
    device = torch.device("cpu")

In [None]:
# Reaction-id splits are stored as pickles next to the notebook.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
def _read_pickle(path):
    """Return the object unpickled from the file at `path`."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


train_ids = _read_pickle('./normal_ids.pkl')
uncertain_ids = _read_pickle('./train_uncertain_ids.pkl')
test_ids = _read_pickle('./test_clean_ids.pkl')  # not used further in this notebook
test_u_ids = _read_pickle('./test_uncertain_ids.pkl')

data_processer = data_process(args.data_path)

# `train_ids` is rebound from the id list to the loaded data; later cells index
# it with 'yields', 'features' and 'graphs'.
train_ids = data_processer.load_data(train_ids)
uncertain_data = data_processer.load_data(uncertain_ids)  # not used further in this notebook
test_u_data = data_processer.load_data(test_u_ids)

In [None]:
# Label standardisation statistics, taken from the training split only.
mean = train_ids['yields'].mean()
std = train_ids['yields'].std()
# Second dimension of the feature matrix; passed to MLP(...) as its input size.
input_dim = train_ids['features'].shape[1]


def norm_label(labels):
    """Z-score raw yields using the training-split `mean` and `std`."""
    centered = labels - mean
    return centered / std

In [None]:
class feas_ds(Dataset):
    """Map-style dataset pairing per-reaction features with yields.

    Item ``i`` is ``(X_f[i][0], Y[i])``.
    """

    def __init__(self, X_f, Y):
        """
        X_f: indexable per-reaction features
        Y: yields; must expose ``.shape`` whose first entry is the dataset length
        """
        self.X_f, self.Y = X_f, Y

    def __len__(self):
        rows = self.Y.shape[0]
        return rows

    def __getitem__(self, idx):
        # NOTE(review): only element 0 of each feature entry is returned --
        # confirm this matches the stored feature layout.
        features = self.X_f[idx][0]
        target = self.Y[idx]
        return features, target

class graph_ds(Dataset):
    """Map-style dataset yielding three graphs per reaction plus its yield.

    Item ``i`` is ``(X_g[i][0], X_g[i][1], X_g[i][2], Y[i])``.
    """

    def __init__(self, X_g, Y):
        """
        X_g: sequence of per-reaction entries, each holding at least three graphs;
             stored as ``np.array(X_g)``
        Y: yields; must expose ``.shape`` whose first entry is the dataset length
        """
        self.X_g = np.array(X_g)
        self.Y = Y

    def __len__(self):
        return self.Y.shape[0]

    def __getitem__(self, idx):
        entry = self.X_g[idx]
        first, second, third = entry[0], entry[1], entry[2]
        return first, second, third, self.Y[idx]
    
def collate_reaction(batch):
    """Collate ``(g1, g2, g3, y)`` samples for a DataLoader.

    Returns ``(graphs, labels)``: ``graphs`` is a list of three ``dgl.batch``-ed
    graphs (one per position), ``labels`` is a FloatTensor of the yields.
    """
    columns = [list(column) for column in zip(*batch)]
    graph_columns, label_column = columns[:3], columns[-1]
    batched_graphs = [dgl.batch(graphs) for graphs in graph_columns]
    return batched_graphs, torch.FloatTensor(label_column)

In [None]:
# Build the training and uncertain-test loaders for the selected model family.
# Batch size comes from --batch_size (default 256, the value previously
# hard-coded). The training loader is shuffled so batch composition changes
# between epochs; the evaluation loader keeps dataset order.
# NOTE(review): any --model value other than 'MLP' selects the graph model.
if args.model == 'MLP':
    train_loader = DataLoader(feas_ds(train_ids['features'], train_ids['yields']),
                              batch_size=args.batch_size, shuffle=True)
    test_u_loader = DataLoader(feas_ds(test_u_data['features'], test_u_data['yields']),
                               batch_size=args.batch_size)
    # NOTE(review): 1024 and 6 are MLP hyper-parameters defined by models.MLP
    # (presumably hidden width / depth) -- confirm.
    model = MLP(input_dim, 1024, 6)
else:
    train_loader = DataLoader(graph_ds(train_ids['graphs'], train_ids['yields']),
                              batch_size=args.batch_size, shuffle=True,
                              collate_fn=collate_reaction)
    test_u_loader = DataLoader(graph_ds(test_u_data['graphs'], test_u_data['yields']),
                               batch_size=args.batch_size,
                               collate_fn=collate_reaction)
    # NOTE(review): meaning of (11, 3) is defined by models.reactionMPNN -- confirm.
    model = reactionMPNN(11, 3)


In [None]:
import os

ckptpath = './logs/trained_{}.ckpt'.format(args.model)
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(ckptpath), exist_ok=True)

# Move the model first, then build the optimizer over its parameters.
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
criterion = nn.L1Loss()
# Lowest summed training loss seen so far; infinity guarantees the first epoch
# is always checkpointed (a finite sentinel could block every save).
best_loss = float('inf')

for epoch in tqdm(range(args.epochs)):
    model.train()
    train_loss = 0.0
    for datas, label in train_loader:
        optimizer.zero_grad()
        if args.model == "MLP":
            datas = datas.to(torch.float32).to(device)
        else:
            datas = [graph.to(device) for graph in datas]
        # The model is trained against z-scored labels.
        label = norm_label(label).to(device).reshape(-1, 1)
        out = model(datas)
        loss = criterion(out, label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    # Keep only the checkpoint from the epoch with the lowest training loss.
    # best_loss must be updated, otherwise every epoch overwrites the file.
    # NOTE(review): selection uses training loss, not a held-out validation set.
    if train_loss < best_loss:
        best_loss = train_loss
        torch.save(model.state_dict(), ckptpath)


In [None]:
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error


def valid(model, data_loader):
    """Evaluate `model` on `data_loader`, reporting metrics in raw yield units.

    The model is trained on z-scored labels (``norm_label``), so its
    predictions are mapped back with ``* std + mean``. The datasets return the
    raw ``yields`` values, so ground-truth labels are used unchanged.

    Returns:
        ``(r2, mae, rmse)`` computed over every batch of `data_loader`.
    """
    model.eval()
    Ys, Y_hats = [], []
    with torch.no_grad():
        for xs, ys in data_loader:
            if args.model == "MLP":
                xs = xs.to(torch.float32).to(device)
            else:
                xs = [graph.to(device) for graph in xs]
            Ys.append(ys.reshape(-1, 1).numpy())
            Y_hats.append(model(xs).detach().to("cpu").numpy())
    # Ground truth is already in yield units; do NOT de-normalise it again.
    Ys = np.concatenate(Ys, axis=0)
    # Predictions live in normalised space; map them back to yield units.
    Y_hats = np.concatenate(Y_hats, axis=0) * std + mean
    r2 = r2_score(Ys, Y_hats)
    mae = mean_absolute_error(Ys, Y_hats)
    # `squared=False` was deprecated in scikit-learn 1.4 and removed in 1.6.
    rmse = np.sqrt(mean_squared_error(Ys, Y_hats))
    return (r2, mae, rmse)

In [None]:
# Reload the checkpoint written by the training cell and score it on the
# uncertain test split; the (r2, mae, rmse) tuple is the cell's output.
saved_state = torch.load(ckptpath, map_location=device)
model.load_state_dict(saved_state)
model = model.to(device)
valid(model, test_u_loader)