In [None]:
import torch
import matplotlib.pyplot as plt

from gnn.src.data.dataset import InMemoQM9Dataset
from gnn.src.train.trainer import Trainer

from gnn.src.nn.schnet import SchNet
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR
from gnn.src.data.utils import get_property

In [None]:
# Directory holding the dataset records (relative to the notebook's working dir).
# InMemoQM9Dataset is project-local; presumably it loads all records into memory
# (per its name) — TODO confirm in gnn.src.data.dataset.
RECORDS_DIR = './records'
dataset = InMemoQM9Dataset(RECORDS_DIR)

In [None]:
# Seed torch's global RNG so repeated Restart & Run All executions draw the same
# random numbers from torch (e.g. weight initialisation in the next cell).
# NOTE(review): if Trainer shuffles the dataset with numpy or `random` rather than
# torch, those generators need seeding as well — TODO confirm in gnn.src.train.trainer.
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when one is present; otherwise fall back to the CPU instead of
# failing on an unavailable 'cuda' device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

trainer = Trainer(
    dataset=dataset,
    device=device,
    validation_split=0.1,  # presumably the fraction of samples held out for validation
    shuffle_dataset=True,
    dataset_size=len(dataset),
)

In [None]:
# Train a SchNet model to regress the property TARGET_PROPERTY with an MSE loss.
# The same constant feeds both `property=` and the `map_property` lookup, so the
# trainer's property name and the value actually extracted can no longer diverge.
TARGET_PROPERTY = 'α'
MAX_EPOCHS = 100
# NOTE(review): originally written as `10e-6`, which is 1e-5 (not 1e-6).
# The value is unchanged here so results stay comparable — confirm 1e-5 is intended.
LEARNING_RATE = 1e-5
# ExponentialLR multiplies the learning rate by this factor on every
# scheduler.step(); how often that happens is decided inside Trainer.train.
LR_DECAY = 0.9

model = SchNet(n_interactions=1, n_features=64, n_filters=64)
loss_function = torch.nn.MSELoss()
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = ExponentialLR(optimizer, gamma=LR_DECAY)

losses_train_mean, losses_val_mean = trainer.train(
    model=model,
    loss_function=loss_function,
    optimizer=optimizer,
    scheduler=scheduler,
    max_epochs=MAX_EPOCHS,
    models_dir='./models',
    checkpoint_frequency=10,
    property=TARGET_PROPERTY,
    map_property=lambda t: get_property(t, TARGET_PROPERTY),
    # Collapse the model outputs to a single value by summation
    # (presumably per-atom contributions -> molecular property; TODO confirm).
    aggregate_outputs=lambda outputs: torch.sum(outputs),
)
print(losses_train_mean, losses_val_mean)

### Loss

Mean training and validation MSE for each epoch, as printed by the training cell above.

<details>
  <summary>losses_train_mean</summary>
  [58.839726634226075, 5.518057584098331, 4.24709673092208, 3.781802625836516, 3.545804324713805, 3.3893986223054235, 3.260820652037782, 3.1640724792832753, 3.0792549730140717, 3.023277406395719, 2.9441381441288965, 2.900364953753358, 2.854506544550567, 2.81282514913081, 2.777761020674224, 2.7517463184699604, 2.7123476533829787, 2.6780886802078014, 2.6562178363268907, 2.6275893861338804, 2.585617992756757, 2.5585345910539976, 2.5361176760724646, 2.512520349982907, 2.491692825757837, 2.4647025621262886, 2.442452179340332, 2.421791355421118, 2.3927696472686475, 2.371247854624981, 2.3372035594232043, 2.3163704075814695, 2.2958879049202663, 2.2798291452965427, 2.2605932966211855, 2.2410771305552717, 2.22353428746129, 2.2021394205725815, 2.1866108476307007, 2.169274944522066, 2.1407345424446533, 2.125170769976373, 2.1053828589371664, 2.0935869917408505, 2.080493062096882, 2.0676965031655268, 2.0560521894114996, 2.0365151264835664, 2.0263295594218227, 2.011375386971819, 1.995146867729041, 1.9842671514885397, 1.9731319123726498, 1.9596805482393798, 1.9515715246343102, 1.9431638065633101, 1.9348935881057026, 1.919880099093062, 1.9125202779252373, 1.8993775171988443, 1.8846098125401751, 1.8792307032650168, 1.8717218237263837, 1.863527422024617, 1.8556848455274686, 1.8459508698408218, 1.842926103158485, 1.8357483098326492, 1.8303867314445783, 1.8215086307387254, 1.8074312125311693, 1.8016097572988354, 1.7968493412439401, 1.7903152340131, 1.784299395986577, 1.7766383430193293, 1.7729073097746957, 1.7660229937770457, 1.7627034012585185, 1.7555866638054713, 1.7432151765650972, 1.739609215557068, 1.7339853590706575, 1.7329251916975505, 1.7249594533431316, 1.7232135452066701, 1.71928586311498, 1.7126707510572206, 1.7075514769571443, 1.7039423026630423, 1.6938724343156015, 1.6899929890286505, 1.6872020201227156, 1.6818807813589671, 1.6797516750320323, 1.6745074363198023, 1.6713629234169176, 1.6681738951994916, 1.6624287065476886, 1.6595791729733538]
</details>


<details>
  <summary>losses_val_mean</summary>
  [6.696786389716581, 4.595338658394637, 4.027166699168811, 3.293074002205058, 3.2214907036797857, 2.950677817419746, 2.894027506832193, 3.2400507179663496, 2.784918612820899, 2.9258770902068223, 2.8077502372997856, 2.570631320443605, 2.724896416755456, 2.99020351985809, 2.5315422687038773, 2.42541631338714, 2.4139361921462914, 2.683047888729304, 2.3898569406221295, 2.3672316227692702, 2.4981171000864704, 2.3980367186605296, 2.371855362616904, 2.28409361816778, 2.315909606893592, 2.2666648628660258, 2.3013579498218855, 2.3958109192769257, 2.1966333753318983, 2.137990762676618, 2.255785764426122, 2.1093587500477287, 2.3399982763565883, 2.0705811582691394, 2.1146829906558904, 2.036245481788107, 2.4988876881651634, 2.1227913058721435, 2.020359470826986, 1.9695439726379624, 1.9658679475954433, 1.9521754603578527, 1.9504401019112005, 1.9502745068025766, 1.9598724165349812, 1.9216700351532763, 1.9934616280009962, 2.1256215388281503, 1.8610406403408095, 2.0023100504697964, 1.8637312236919974, 1.8614594564863627, 1.8886796107450519, 1.8737175952460738, 1.8227736833185915, 1.839103752560301, 1.8097475129512932, 1.7958113342515465, 1.9722782856045602, 1.9048077512152704, 1.7929918486892724, 1.781021163151278, 1.917300701937728, 1.7291276022279707, 1.8181967555433691, 1.7697902621283277, 1.9118928093341947, 1.7239480367655333, 1.7684764062455272, 1.7224757668349984, 1.7427164079134985, 1.7463040363205593, 1.8400747520499254, 1.8270545686951807, 1.6998314926451141, 1.6848566944340813, 1.6821685016403702, 1.6563385632058971, 1.7223719709146055, 1.6457564883237064, 1.6808052972634056, 1.7311168761135451, 1.8707231705363712, 1.6300577001674874, 1.6423261352896286, 1.6501248169935414, 1.6309168168798656, 1.6180047249594114, 1.6109822557427957, 1.606718152660372, 1.602412010450738, 1.6657163882108648, 1.6570539034989977, 1.6026585765426202, 1.5874425886427144, 1.578725052806119, 1.615730784177044, 1.584430156137688, 1.5708787461062517, 1.580210113155297]
</details>

In [None]:
# Per-epoch mean losses hard-coded here; they match the lists printed by the
# training cell above, so the figure below can be redrawn without re-running
# the 100-epoch training.
# NOTE: these assignments overwrite the lists returned by trainer.train().
# If training is re-run with different settings, refresh (or delete) these
# values, otherwise the plot will show the old run.
losses_train_mean = [58.839726634226075, 5.518057584098331, 4.24709673092208, 3.781802625836516, 3.545804324713805, 3.3893986223054235, 3.260820652037782, 3.1640724792832753, 3.0792549730140717, 3.023277406395719, 2.9441381441288965, 2.900364953753358, 2.854506544550567, 2.81282514913081, 2.777761020674224, 2.7517463184699604, 2.7123476533829787, 2.6780886802078014, 2.6562178363268907, 2.6275893861338804, 2.585617992756757, 2.5585345910539976, 2.5361176760724646, 2.512520349982907, 2.491692825757837, 2.4647025621262886, 2.442452179340332, 2.421791355421118, 2.3927696472686475, 2.371247854624981, 2.3372035594232043, 2.3163704075814695, 2.2958879049202663, 2.2798291452965427, 2.2605932966211855, 2.2410771305552717, 2.22353428746129, 2.2021394205725815, 2.1866108476307007, 2.169274944522066, 2.1407345424446533, 2.125170769976373, 2.1053828589371664, 2.0935869917408505, 2.080493062096882, 2.0676965031655268, 2.0560521894114996, 2.0365151264835664, 2.0263295594218227, 2.011375386971819, 1.995146867729041, 1.9842671514885397, 1.9731319123726498, 1.9596805482393798, 1.9515715246343102, 1.9431638065633101, 1.9348935881057026, 1.919880099093062, 1.9125202779252373, 1.8993775171988443, 1.8846098125401751, 1.8792307032650168, 1.8717218237263837, 1.863527422024617, 1.8556848455274686, 1.8459508698408218, 1.842926103158485, 1.8357483098326492, 1.8303867314445783, 1.8215086307387254, 1.8074312125311693, 1.8016097572988354, 1.7968493412439401, 1.7903152340131, 1.784299395986577, 1.7766383430193293, 1.7729073097746957, 1.7660229937770457, 1.7627034012585185, 1.7555866638054713, 1.7432151765650972, 1.739609215557068, 1.7339853590706575, 1.7329251916975505, 1.7249594533431316, 1.7232135452066701, 1.71928586311498, 1.7126707510572206, 1.7075514769571443, 1.7039423026630423, 1.6938724343156015, 1.6899929890286505, 1.6872020201227156, 1.6818807813589671, 1.6797516750320323, 1.6745074363198023, 1.6713629234169176, 1.6681738951994916, 1.6624287065476886, 1.6595791729733538]
losses_val_mean = [6.696786389716581, 4.595338658394637, 4.027166699168811, 3.293074002205058, 3.2214907036797857, 2.950677817419746, 2.894027506832193, 3.2400507179663496, 2.784918612820899, 2.9258770902068223, 2.8077502372997856, 2.570631320443605, 2.724896416755456, 2.99020351985809, 2.5315422687038773, 2.42541631338714, 2.4139361921462914, 2.683047888729304, 2.3898569406221295, 2.3672316227692702, 2.4981171000864704, 2.3980367186605296, 2.371855362616904, 2.28409361816778, 2.315909606893592, 2.2666648628660258, 2.3013579498218855, 2.3958109192769257, 2.1966333753318983, 2.137990762676618, 2.255785764426122, 2.1093587500477287, 2.3399982763565883, 2.0705811582691394, 2.1146829906558904, 2.036245481788107, 2.4988876881651634, 2.1227913058721435, 2.020359470826986, 1.9695439726379624, 1.9658679475954433, 1.9521754603578527, 1.9504401019112005, 1.9502745068025766, 1.9598724165349812, 1.9216700351532763, 1.9934616280009962, 2.1256215388281503, 1.8610406403408095, 2.0023100504697964, 1.8637312236919974, 1.8614594564863627, 1.8886796107450519, 1.8737175952460738, 1.8227736833185915, 1.839103752560301, 1.8097475129512932, 1.7958113342515465, 1.9722782856045602, 1.9048077512152704, 1.7929918486892724, 1.781021163151278, 1.917300701937728, 1.7291276022279707, 1.8181967555433691, 1.7697902621283277, 1.9118928093341947, 1.7239480367655333, 1.7684764062455272, 1.7224757668349984, 1.7427164079134985, 1.7463040363205593, 1.8400747520499254, 1.8270545686951807, 1.6998314926451141, 1.6848566944340813, 1.6821685016403702, 1.6563385632058971, 1.7223719709146055, 1.6457564883237064, 1.6808052972634056, 1.7311168761135451, 1.8707231705363712, 1.6300577001674874, 1.6423261352896286, 1.6501248169935414, 1.6309168168798656, 1.6180047249594114, 1.6109822557427957, 1.606718152660372, 1.602412010450738, 1.6657163882108648, 1.6570539034989977, 1.6026585765426202, 1.5874425886427144, 1.578725052806119, 1.615730784177044, 1.584430156137688, 1.5708787461062517, 1.580210113155297]

# Learning curves in two panels: the full history on top, and below the same
# curves with the first SKIP_EPOCHS epochs dropped, because the first training
# loss (~59) dwarfs everything after it and flattens the rest of the plot.
SKIP_EPOCHS = 10

# Derive the x-axis from the data instead of hard-coding 100, so the plot keeps
# working when max_epochs changes.
n_epochs = len(losses_train_mean)
assert len(losses_val_mean) == n_epochs, "train/val loss histories differ in length"
epochs = range(n_epochs)

fig, axs = plt.subplots(nrows=2)
for ax, start in zip(axs, (0, SKIP_EPOCHS)):
    ax.plot(epochs[start:], losses_train_mean[start:], c='g', label='Обучение')
    ax.plot(epochs[start:], losses_val_mean[start:], c='b', label='Валидация')
    ax.legend()
    ax.set(xlabel='Число эпох', ylabel='MSE α')
# tight_layout must run after labels and legends exist; called before any
# plotting (as originally), it cannot account for them and the panels overlap.
fig.tight_layout()
print(losses_val_mean[-1])
plt.show()