In [2]:
import load_data_multitox as ld
import dataloaders_sigma as dl
from Model_train_test_regression import Net, EarlyStopping, train_regression, train_classification, test_regression, test_classification
from sklearn.metrics import roc_auc_score


import pandas as pd
import numpy as np

import torch
from torch.utils import data as td
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


import sys 
import os
import glob

from sklearn.model_selection import train_test_split

from tensorboardX import SummaryWriter

import time
from sklearn.preprocessing import MinMaxScaler#StandardScaler

import json


# number of conformers created for every molecule
NUM_CONFS = 100

# amount of chemical elements taking into account
AMOUNT_OF_ELEM = 9

# amount of target values
TARGET_NUM = 12

#loss penalty for classification
PENALTY = torch.FloatTensor([0.1,0.2,0.4,0.4,0.4,0.2,0.2,0.6,0.2,0.3,0.6,0.2])


#dataset folder
# DATASET_PATH="~/Tox21-MultiTox/MultiTox"
DATASET_PATH = "/gpfs/gpfs0/a.alenicheva"

TOX21_STORAGE = "../Tox21_Neural_Net"
TOX21_STORAGE = "./"

MULTITOX_STORAGE = "./"

EXPERIMENTS_DATA = "./"

#logs path
LOG_PATH=os.path.join(EXPERIMENTS_DATA, "logs_sigma_right")


#models path
MODEL_PATH=os.path.join(EXPERIMENTS_DATA, "models_sigma_right")


In [3]:
EXPERIMENT_NUM=57

In [4]:
dir_path = os.path.join(LOG_PATH,'exp_'+str(EXPERIMENT_NUM))
os.makedirs(dir_path, exist_ok=True)
LOG_PATH = dir_path
dir_path = os.path.join(MODEL_PATH,'exp_'+str(EXPERIMENT_NUM))
os.makedirs(dir_path, exist_ok=True)
MODEL_PATH = dir_path

In [5]:
path="./"

In [6]:
with open(os.path.join(path,"logs_sigma_right",'exp_'+str(52),str(52)+'_parameters.json'),'r') as f:
  args = json.load(f)

In [7]:
# args['NUM_EXP']=str(EXPERIMENT_NUM)
# args['BATCH_SIZE']=64

# args['TRANSF']='w'
# args['SIGMA_TRAIN']=False

In [8]:
args['BATCH_SIZE']=32

In [9]:
args

{'EPOCHS_NUM': 100,
 'PATIENCE': 25,
 'SIGMA': 2.2,
 'BATCH_SIZE': 32,
 'TRANSF': 'g',
 'NUM_EXP': '52',
 'VOXEL_DIM': 50,
 'LEARN_RATE': 1e-05,
 'SIGMA_TRAIN': False,
 'MODE': 'c',
 'CONTINUE': 1,
 'TRLEARNING': 0}

In [10]:
f_log=open(os.path.join(LOG_PATH,args['NUM_EXP']+'_logs.txt'),'w')
f_log.close()
start_time=time.time()
writer=SummaryWriter(LOG_PATH)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
with open(os.path.join(LOG_PATH,args['NUM_EXP']+'_logs.txt'),'a') as f_log:
    f_log.write('Using device:'+str(device)+'\n')
print()
#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_cached(0)/1024**3,1), 'GB')

    with open(os.path.join(LOG_PATH,args['NUM_EXP']+'_logs.txt'),'a') as f_log:
        f_log.write(torch.cuda.get_device_name(0)+'\n'+'Memory Usage:'+'\n'+'Allocated:'+str(round(torch.cuda.memory_allocated(0)/1024**3,1))+ 'GB'+'\n'+'Cached:   '+str(round(torch.cuda.memory_cached(0)/1024**3,1))+'GB'+'\n')
print('Start loading dataset...')
with open(os.path.join(LOG_PATH,args['NUM_EXP']+'_logs.txt'),'a') as f_log:
    f_log.write('Start loading dataset...'+'\n')


Using device: cuda:0

GeForce GTX 1080 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB
Start loading dataset...


In [11]:
# get dataset without duplicates from csv
if args["MODE"] == 'r':
    data = pd.read_csv(os.path.join(MULTITOX_STORAGE,'database/data', 'MultiTox.csv'))
    props = list(data)
    props.remove("SMILES")
    print(props)
    scaler = MinMaxScaler()
    data[props]=scaler.fit_transform(data[props])
elif args["MODE"] == 'c':
    data = pd.read_csv(os.path.join(TOX21_STORAGE,'database/data', 'tox21_10k_data_all_no_salts.csv'))

# create elements dictionary
#     elements = ld.create_element_dict(data, amount=AMOUNT_OF_ELEM+1)
elements={'N':0,'C':1,'Cl':2,'I':3,'Br':4,'F':5,'O':6,'P':7,'S':8}

# read databases to dictionary
#     conf_calc = ld.reading_sql_database(database_dir='./dat/')
if args["MODE"] == "r":
    conf_calc = ld.reading_sql_database(database_dir=os.path.join(DATASET_PATH,"MultiTox"))
elif args["MODE"] == "c":
    conf_calc = ld.reading_sql_database(database_dir=os.path.join(DATASET_PATH,"Tox21", "elements_9"))

keys=list(conf_calc.keys())
print ('Initial dataset size = ', len(keys))
with open(os.path.join(LOG_PATH,args['NUM_EXP']+'_logs.txt'),'a') as f_log:
    f_log.write('Initial dataset size = '+str(len(keys))+'\n')
new_conf_calc={}
for smiles in conf_calc.keys():
    for conf_num in conf_calc[smiles]:
        if smiles in new_conf_calc.keys():
            new_conf_calc[smiles][int(conf_num)]=conf_calc[smiles][conf_num]
        else:
            new_conf_calc[smiles]={}
            new_conf_calc[smiles][int(conf_num)]=conf_calc[smiles][conf_num]

conf_calc=new_conf_calc

elems = []
for key in keys:
    conformers=list(conf_calc[key].keys())
    for conformer in conformers:
        try:
            energy = conf_calc[key][conformer]['energy']
            elems = list(set(elems+list(conf_calc[key][conformer]['coordinates'].keys())))
        except:
            del conf_calc[key][conformer]
    if set(conf_calc[key].keys())!=set(range(100)):
          del conf_calc[key]
    elif conf_calc[key]=={}:
        del conf_calc[key]

print ('Post-processed dataset size = ', len(list(conf_calc.keys())))
with open(os.path.join(LOG_PATH,args['NUM_EXP']+'_logs.txt'),'a') as f_log:
    f_log.write('Post-processed dataset size = '+str(len(list(conf_calc.keys())))+'\n')
# create indexing and label_dict for iteration
indexing, label_dict = ld.indexing_label_dict(data, conf_calc)
print('Dataset has been loaded, ', int(time.time()-start_time),' s')
with open(os.path.join(LOG_PATH,args['NUM_EXP']+'_logs.txt'),'a') as f_log:
    f_log.write('Dataset has been loaded, '+str(int(time.time()-start_time))+' s'+'\n')

start_time=time.time()
# create train and validation sets' indexes
print('Neural network initialization...')
with open(os.path.join(LOG_PATH,args['NUM_EXP']+'_logs.txt'),'a') as f_log:
    f_log.write('Neural network initialization...'+'\n')
train_indexes, test_indexes, _, _ = train_test_split(np.arange(0, len(conf_calc.keys())),
                                                     np.arange(0, len(conf_calc.keys())), test_size=0.2,
                                                     random_state=115)

Initial dataset size =  7489
Post-processed dataset size =  7481
Dataset has been loaded,  76  s
Neural network initialization...


In [12]:
train_set = dl.Cube_dataset(conf_calc, label_dict, elements, indexing, train_indexes, dim = args['VOXEL_DIM'])
train_generator = td.DataLoader(train_set, batch_size=args['BATCH_SIZE'], shuffle=True)

test_set = dl.Cube_dataset(conf_calc, label_dict, elements, indexing, test_indexes, dim = args['VOXEL_DIM'])
test_generator = td.DataLoader(test_set, batch_size=args['BATCH_SIZE'], shuffle=True)

if args["TRLEARNING"]:
    model = Net(dim=args["VOXEL_DIM"], num_elems=AMOUNT_OF_ELEM, num_targets=29, elements=elements, transformation=args["TRANSF"],device=device,sigma_0 = args["SIGMA"],sigma_trainable = args["SIGMA_TRAIN"], mode = args["MODE"])
else:
    model = Net(dim=args["VOXEL_DIM"], num_elems=AMOUNT_OF_ELEM, num_targets=TARGET_NUM, elements=elements, transformation=args["TRANSF"],device=device,sigma_0 = args["SIGMA"],sigma_trainable = args["SIGMA_TRAIN"], mode = args["MODE"])

In [13]:
model=model.to(device)

In [14]:
f_train_loss=open(os.path.join(LOG_PATH,args["NUM_EXP"]+'_log_train_loss.txt'),'w')
f_train_loss_ch=open(os.path.join(LOG_PATH,args["NUM_EXP"]+'_log_train_loss_channels.txt'),'w')
f_test_loss=open(os.path.join(LOG_PATH,args["NUM_EXP"]+'_log_test_loss.txt'),'w')
    
optimizer = torch.optim.Adam(model.parameters(), lr=args["LEARN_RATE"])
early_stopping = EarlyStopping(patience=args["PATIENCE"], verbose=True,model_path=MODEL_PATH)

In [15]:
def train_classification(model, optimizer, train_generator, epoch, device, batch_size, num_targets=12, PENALTY = None, writer = None,f_loss=None,f_loss_ch = None, elements=None, MODEL_PATH=None, LOGS_FILEPATH=None):
    """ Train model and write logs to tensorboard and .txt files

        Parameters
        ----------
        model
            torch.nn.Module object to train
        optimizer
            torch.optim object
        train_generator
            torch.utils.data.DataLoader object, contain iterable set of torch.Tensor data (num_elems, dim,dim,dim) and torch.Tensor labels (num_targets, )
        epoch
            number of trained epoch
        device
            torch.device
        batch_size
            size of batch
        num_targets
            number of labels in the task
        PENALTY
            vector of penalties for each label 
        writer
            tensorboardX.SummaryWriter
        f_loss
            .txt file for train loss saving
        f_loss_ch
            .txt file for aucs per target saving
        elements
            dictionary with {atom name : number} mapping

        Returns
        -------
        None
        """
    elems=dict([(elements[element], element) for element in elements.keys()])
    model.train()
    train_loss=0
    
    losses=np.zeros(num_targets)
    num_losses=np.zeros(num_targets)
    
    aucs=np.zeros(num_targets)
    num_aucs=np.zeros(num_targets)
#         b_accs=np.zeros(num_targets)
#         num_b_accs=np.zeros(num_targets)
    for batch_idx, (data, target) in enumerate(train_generator):
        data = data.to(device)
        target = target.to(device)
        if LOGS_FILEPATH is not None:
            with open(LOGS_FILEPATH,'a') as f_log:
                f_log.write('Batch , '+str(batch_idx)+'\n')
        # set gradients to zero
        optimizer.zero_grad()
        output = model(data)

        i=0
        for one_target,one_output in zip(target.cpu().t(),output.cpu().t()):
            with torch.no_grad(): 
                mask = (one_target == one_target)
                output_masked = torch.masked_select(one_output, mask).type_as(one_output)
                target_masked = torch.masked_select(one_target, mask).type_as(one_output)
                pred = output_masked.ge(0.5).type_as(one_output)
                try:
                    auc=roc_auc_score(target_masked.cpu().detach(),pred.cpu().detach())
                    aucs[i]+=auc
                    num_aucs[i]+=1
                except ValueError:
                    pass
            i+=1
        # calculate output vector
        
        # create mask to get rid of Nan's in target
        mask = (target == target)
        output_masked = torch.masked_select(output, mask).type_as(output)
        target_masked = torch.masked_select(target, mask).type_as(output)
        penalty_masked = torch.masked_select(PENALTY.to(device), mask).type_as(output)
        class_weights=(1-penalty_masked)*(target_masked).to(device)+penalty_masked
        print(target_masked, output_masked)
        loss = F.binary_cross_entropy(output_masked, target_masked,weight=class_weights)
        if LOGS_FILEPATH is not None:
            with open(LOGS_FILEPATH,'a') as f_log:
                f_log.write('loss , '+str(loss.cpu().detach().numpy().item())+'\n')
        if f_loss is not None:
            f_loss.write(str(epoch)+'\t'+str(batch_idx)+'\t'+str(loss.cpu().detach().numpy().item())+'\n')
        loss.backward()
        optimizer.step()
        train_loss+=loss.cpu().detach().numpy().item()
        if LOGS_FILEPATH is not None:
            with open(LOGS_FILEPATH,'a') as f_log:
                f_log.write('backward done '+'\n')
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_generator.dataset),
                       100. * batch_idx / len(train_generator), loss.item()))
            if MODEL_PATH is not None:
                if torch.cuda.device_count() > 1:
                    torch.save(model.module.state_dict(), os.path.join(MODEL_PATH,'checkpoint.pt'))
                else:
                    torch.save(model.state_dict(), os.path.join(MODEL_PATH,'checkpoint.pt'))
         
        print(aucs/num_aucs)         
    train_loss /= len(train_generator.dataset)
    train_loss *= batch_size
    if writer is not None:
        writer.add_scalar('Train/Loss/', train_loss, epoch)
    if torch.cuda.device_count() > 1:
        sigmas = model.module.sigma.cpu().detach().numpy()
    else:
        sigmas = model.sigma.cpu().detach().numpy()
    for idx,sigma in enumerate(sigmas):
        writer.add_scalar('Sigma/'+elems[idx], sigma, epoch)
    losses/=num_losses    
    aucs/=num_aucs 
    writer.add_scalar('Train/AUC/', np.mean(aucs), epoch)
    for i,auc in enumerate(aucs):
        if f_loss_ch is not None and loss==loss:
            f_loss_ch.write(str(epoch)+'\t'+str(batch_idx)+'\t'+str(i)+'\t'+str(auc)+'\n')
        if writer is not None:
            writer.add_scalar('Train/AUC/'+str(i), auc, epoch)
    return

In [16]:
def test_classification(model, test_generator,epoch,device,batch_size,num_targets=12,writer=None,f_loss=None, elements=None, PENALTY = None):
    """ Validation of trained model

        Parameters
        ----------
        model
            torch.nn.Module object to train
        test_generator
            torch.utils.data.DataLoader object, contain iterable set of torch.Tensor data (num_elems, dim,dim,dim) and torch.Tensor labels (num_targets, )
        epoch
            number of validated epoch
        device
            torch.device
        writer
            tensorboardX.SummaryWriter
        f_loss
            .txt file for test loss saving
        elements
            dictionary with {atom name : number} mapping

        Returns
        -------
        test_loss
            Loss for validation set
        """
    with torch.no_grad():
        model.eval()
        test_loss = 0
        correct = 0
        errors=0
        losses=np.zeros(num_targets)
        num_losses=np.zeros(num_targets)
        aucs=np.zeros(num_targets)
        num_aucs=np.zeros(num_targets)
        for batch_idx, (data, target) in enumerate(test_generator):
            data = data.to(device)
            target = target.to(device)
            output = model(data)   
            i=0
            for one_target,one_output in zip(target.cpu().t(),output.cpu().t()):
                with torch.no_grad():
                    mask = (one_target == one_target)
                    output_masked = torch.masked_select(one_output, mask).type_as(one_output)
                    target_masked = torch.masked_select(one_target, mask).type_as(one_output)
                    pred = output_masked.ge(0.5).type_as(one_output)
                    try:
                        auc=roc_auc_score(target_masked.cpu(),pred.cpu())
                        aucs[i]+=auc
                        num_aucs+=1
                    except ValueError:
                        pass

                    i+=1
            mask = (target == target)
            output_masked = torch.masked_select(output, mask).type_as(output)
            target_masked = torch.masked_select(target, mask).type_as(output)
            penalty_masked = torch.masked_select(PENALTY.to(device), mask).type_as(output)
            class_weights=(1-penalty_masked)*(target_masked).to(device)+penalty_masked

            loss = F.binary_cross_entropy(output_masked, target_masked,weight=class_weights)

            test_loss += loss
            pred = output_masked.ge(0.5).type_as(output)
            
            try:
                auc=roc_auc_score(target_masked.cpu(),pred.cpu())
                correct += auc
            except ValueError:
                errors+=1
            
        test_loss /= len(test_generator.dataset)
        test_loss *= batch_size

        print('\nTest set: Average loss: {:.4f}\n'
              .format(test_loss))
    if writer is not None:
        writer.add_scalar('Test/Loss/', test_loss, epoch)
    aucs/=num_aucs
    print(aucs)
    for i,auc in enumerate(aucs):
        if f_loss is not None and loss == loss:
            f_loss.write(str(epoch)+'\t'+str(batch_idx)+'\t'+str(i)+'\t'+str(auc)+'\n')
        if writer is not None:
            writer.add_scalar('Test/AUC/'+str(i), auc, epoch)
    return test_loss


In [17]:
# elements['H'] = 9

In [18]:
# from torch.autograd import Variable
# threshold = 5
# for i, (molecule, curr_target) in enumerate(train_generator):
#         molecule = Variable(molecule.to(device),requires_grad=True)
#         if i > threshold:
#             break
# # model = Net(dim=args_dict['VOXEL_DIM'], num_elems=AMOUNT_OF_ELEM, num_targets=TARGET_NUM, elements=elements, transformation=args_dict['TRANSF'],device=device,sigma_0 = args_dict['SIGMA'],sigma_trainable = args_dict['SIGMA_TRAIN'], x_trainable=True, x_input=torch.randn(1,9,50,50,50))
# # model=model.to(device)
# # model.load_state_dict(torch.load(os.path.join(MODEL_PATH_LOAD,'checkpoint.pt')))
# model.x_input=Parameter(molecule,requires_grad=True)

In [19]:
def plot_visualization_input_as_parameter(model,elements,losses, epoch):
    import matplotlib.pyplot as plt
    inv_elems = {v: k for k, v in elements.items()}
    with torch.no_grad():
        data=model.blur(model.x_input)
#     data=model.x_input
    molecules = data.cpu().detach().sum(dim=0)
    fig = plt.figure(figsize=(10,15),constrained_layout=True)
    gs = fig.add_gridspec(4, 3)
    for i,grad in enumerate(molecules):
        f_ax = fig.add_subplot(gs[i//3,i%3])
        f_ax.imshow(grad.sum(dim=0))
        f_ax.set_title(inv_elems[i],fontsize=25)
#         if inv_elems[i]=='C':
#             print(np.unique(grad))
    f_ax = fig.add_subplot(gs[-1, :])
    f_ax.plot(5*np.arange(0,len(losses),1),losses)
    f_ax.set_title('Loss function',fontsize=25)
    f_ax.set_xlabel('epochs',fontsize=25)
    f_ax.set_ylabel('loss',fontsize=25)
    fig.suptitle('Atom types in molecule',fontsize=25)
    
    plt.show()
#     fig.savefig(os.path.join(LOG_PATH_SAVE,'images','img_'+str(epoch))+'.png',dpi=150,format='png')
    _ = plt.clf()

In [None]:
# plot_visualization_input_as_parameter(model,elements,[], 1)

In [None]:
for epoch in range(args["CONTINUE"], args["EPOCHS_NUM"] + args["CONTINUE"]):
    print('Epoch , '+str(epoch)+'\n')
    try:
        if args["MODE"] == 'r':
            train_regression(model, optimizer, train_generator, epoch,device,writer=writer,f_loss=f_train_loss,f_loss_ch=f_train_loss_ch, elements=elements,batch_size = args["BATCH_SIZE"],MODEL_PATH=MODEL_PATH)
            test_loss = test_regression(model, test_generator,epoch, device,writer=writer,f_loss=f_test_loss, elements=elements,batch_size = args["BATCH_SIZE"])
            early_stopping(test_loss, model)

        elif args["MODE"] == 'c':
            train_classification(model, optimizer, train_generator, epoch,device,writer=writer,f_loss=f_train_loss,f_loss_ch=f_train_loss_ch, elements=elements,batch_size = args["BATCH_SIZE"],MODEL_PATH=MODEL_PATH, PENALTY = PENALTY)
            test_loss = test_classification(model, test_generator,epoch, device,writer=writer,f_loss=f_test_loss, elements=elements,batch_size = args["BATCH_SIZE"], PENALTY = PENALTY)
            early_stopping(test_loss, model)

        if early_stopping.early_stop:
            print(epoch,"Early stopping")
            break
        if epoch%10==0:
            torch.save(model.state_dict(), os.path.join(MODEL_PATH, args["NUM_EXP"]+'_model_'+str(epoch)))
    except KeyError:
        print(epoch,'Key Error problem')

Epoch , 1

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0



tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 



tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 1., 0., 0., 1., 1., 0., 0.,
        0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 1., 0., 0., 0., 0., 0., 0., 




Test set: Average loss: 0.1773

[0.03699552 0.03699552 0.05044843 0.03475336 0.04932735 0.03811659
 0.02578475 0.05269058 0.03587444 0.04035874 0.05269058 0.04596413]
Validation loss decreased (inf --> 0.177253).  Saving model ...
Epoch , 2

tensor([0., 1., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 1., 1., 0., 1., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0



tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 1., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 1., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 1., 0., 1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 1., 0., 1., 1., 0., 0., 



tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 0., 0.,
        0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 



tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1.,
        0., 0., 0., 1., 1., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        1., 0., 0., 0., 0., 0., 1., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 



# Draft

In [None]:
MODEL_PATH

In [None]:
 model.load_state_dict(torch.load(os.path.join('./models_sigma_right/exp_43','checkpoint.pt')))

In [None]:
train_set = dl.Cube_dataset(conf_calc, label_dict, elements, indexing, train_indexes, dim = args['VOXEL_DIM'])
train_generator = td.DataLoader(train_set, batch_size=args['BATCH_SIZE'], shuffle=True)

test_set = dl.Cube_dataset(conf_calc, label_dict, elements, indexing, test_indexes, dim = args['VOXEL_DIM'])
test_generator = td.DataLoader(test_set, batch_size=args['BATCH_SIZE'], shuffle=True)

model = Net(dim=args['VOXEL_DIM'], num_elems=AMOUNT_OF_ELEM, num_targets=TARGET_NUM, elements=elements, transformation=args['TRANSF'],device=device,sigma_0 = args['SIGMA'],sigma_trainable = args['SIGMA_TRAIN'])


if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
    print ('Run in parallel!')
    with open(os.path.join(LOG_PATH,args['NUM_EXP']+'_logs.txt'),'a') as f_log:
        f_log.write('Run in parallel!'+'\n')

# Construct our model by instantiating the class defined above

model=model.to(device)

for (batch, target) in train_generator:
    batch = batch.to(device)
    target = target.to(device)
    with open(os.path.join(LOG_PATH,args['NUM_EXP']+'_logs.txt'),'a') as f_log:
        f_log.write('Batch to device!'+'\n')
    print('Batch to device!')
    output = model(batch)
    with open(os.path.join(LOG_PATH,args['NUM_EXP']+'_logs.txt'),'a') as f_log:
        f_log.write('Batch output!'+'\n')
    print('Batch output!')
    break

In [None]:
import os

In [None]:
os.listdir('./models_sigma_right/exp_24')