In [5]:
import pickle
import re
import numpy as np
import sys
import os
import glob
import networkx as nx
import torch
import torch_geometric
import random
import yaml
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import remove_isolated_nodes
from torch import nn
from torch_geometric.nn import GCN2Conv
from torch_geometric.nn import SAGPooling
from torch_geometric.nn import MLP
from torch_geometric.nn import AttentiveFP
from torch_geometric.nn.aggr import AttentionalAggregation
from copy import deepcopy 

import torch
import torch.nn.functional as F

In [6]:
with open('/xdisk/twheeler/jgaiser/deepvs/deepvs/data/protein_config.yaml', 'r') as config_file:
    protein_config = yaml.safe_load(config_file) 

PROTEIN_ATOM_LABELS = protein_config['atom_labels']
PROTEIN_EDGE_LABELS = protein_config['edge_labels']
INTERACTION_LABELS = protein_config['interaction_labels']

dummy_index = PROTEIN_ATOM_LABELS.index('DUMMY')
voxel_edge_index = PROTEIN_EDGE_LABELS.index('voxel')
model_dir = '/xdisk/twheeler/jgaiser/deepvs/deepvs/data/models/pocket_embed/'

pocket_graph_dir = '/xdisk/twheeler/jgaiser/deepvs/deepvs/data/graph_data/pockets/'
mol_graph_dir = '/xdisk/twheeler/jgaiser/deepvs/deepvs/data/graph_data/molecules/'
positive_voxel_graph_dir = '/xdisk/twheeler/jgaiser/deepvs/deepvs/data/graph_data/voxel_positives/'
pocket_embed_dir = '/xdisk/twheeler/jgaiser/deepvs/deepvs/data/pocket_embeds/'

pocket_file_structure = pocket_graph_dir + "%s_pocket_graph.pkl"
positive_voxel_graph_file_structure = positive_voxel_graph_dir + "%s_voxel_positive_graphs.pkl"
mol_file_structure = mol_graph_dir + "%s_mol_graph.pkl"
pocket_embed_file_structure = pocket_embed_dir + "%s_pocket_embed.pkl"

holdout_complexes = ["3gdt", "3g1v", "3w07", "3g1d", "1loq", "3wjw", "2zz1", "2zz2", "1km3", "1x1z", 
                     "6cbg", "5j7q", "6cbf", "4wrb", "6b1k", "5hvs", "5hvt", "3rf5", "3rf4", "1mfi", 
                     "5efh", "6csq", "5efj", "6csr", "6css", "6csp", "5een", "5ef7", "5eek", "5eei",
                     "3ozt", "3u81", "4p58", "5k03", "3ozr", "3ozs", "3oe5", "3oe4", "3hvi", "3hvj",
                     "3g2y", "3g2z", "3g30", "3g31", "3g34", "3g32", "4de2", "3g35", "4de0", "4de1",
                     "2exm", "4i3z", "1e1v", "5jq5", "1jsv", "1e1x", "4bcp", "4eor", "1b38", "1pxp", "2xnb", "4bco", "4bcm", "1pxn", "4bcn", "1h1s", "4bck", "2fvd", "1pxo", "2xmy",
                     "4xoe", "5fs5", "1uwf", "4att", "4av4", "4av5", "4avh", "4avj", "4avi", "4auj", "4x50", "4lov", "4x5r", "4buq", "4x5p", "4css", "4xoc", "4cst", "4xo8", "4x5q",
                     "1gpk", "3zv7", "1gpn", "5bwc", "5nau", "5nap", "1h23", "1h22", "1e66", "4m0e", "4m0f", "2ha3", "2whp", "2ha6", "2ha2", "1n5r", "4arb", "4ara", "5ehq", "1q84",
                     "2z1w", "3rr4", "1s38", "1q65", "4q4q", "4q4p", "4q4r", "4kwo", "1r5y", "4leq", "4lbu", "1f3e", "4pum", "4q4s", "3gc5", "2qzr", "4q4o", "3gc4", "5jxq", "3ge7"]

In [30]:
PDB_IDS = []

for item in glob.glob(mol_graph_dir + "*"):
    pdb_id = item.split('/')[-1].split('_')[0]
    
    if pdb_id in holdout_complexes:
        continue
        
    PDB_IDS.append(pdb_id)

PDB_IDS = np.array(sorted(PDB_IDS))

In [9]:
class_weights = torch.tensor([max(class_count)/x for x in class_count])
print(class_weights)

tensor([ 47.0805,   1.4693,   1.9435,   1.0000, 113.7626,  45.2585,   9.2728,
          9.8441,  28.5378], dtype=torch.float64)


In [10]:
print(PROTEIN_EDGE_LABELS)

['DOUBLE', 'RING', 'SINGLE', 'TRIPLE', 'covalent', 'interaction', 'spatial', 'voxel']


In [11]:
GRAPH_NODE_COUNT = 12

self_edge_indices = torch.arange(GRAPH_NODE_COUNT)
self_edge_indices = torch.vstack((self_edge_indices, self_edge_indices))

self_edge_attr = torch.vstack([torch.zeros(9)]*GRAPH_NODE_COUNT)

In [12]:
def fetch_positive_voxel_corpus(pdb_ids, corpus_size, batch_size):
    sample_ids = np.random.choice(pdb_ids, corpus_size)
    corpus = []
    
    for s_id in sample_ids:
        if s_id in holdout_complexes:
            continue
            
        graph_samples = pickle.load(open(positive_voxel_graph_file_structure % s_id, 'rb'))
        
        for item in graph_samples:
            item.y = torch.unsqueeze(item.y, dim=0)
            item.edge_index = torch.hstack( (item.edge_index, self_edge_indices) )
            item.edge_attr = torch.vstack( (item.edge_attr, self_edge_attr) )
            
        corpus.extend(graph_samples)
     
    return DataLoader(corpus, batch_size=batch_size, shuffle=True) 

loader = fetch_positive_voxel_corpus(PDB_IDS, 1000, 16)

In [13]:
hidden = 512 
INTERACTION_TYPES = protein_config['interaction_labels']
NODE_DIMS = 38 
EDGE_DIMS = 9
DUMMY_INDEX = protein_config['atom_labels'].index('DUMMY')
MAX_EDGE_WEIGHT = 15.286330223083496

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(NODE_DIMS, hidden)
        
        self.conv1 = GCN2Conv(hidden, 0.2, add_self_loops=False)
        self.conv2 = GCN2Conv(hidden, 0.2, add_self_loops=False)
        self.conv3 = GCN2Conv(hidden, 0.2, add_self_loops=False)

        self.linear2 = torch.nn.Linear(hidden, len(INTERACTION_TYPES))

    def forward(self, data):
        x, edge_index, edge_weights = data.x[:,:-3], data.edge_index, data.edge_attr[:,-1] / MAX_EDGE_WEIGHT

        x = self.linear1(x)

        h = self.conv1(x, x, edge_index, edge_weights)
        h = F.relu(h)

        h = self.conv2(h, x, edge_index, edge_weights)
        h = F.relu(h)

        h = self.conv3(h, x, edge_index, edge_weights)
        h = F.relu(h)
        
        o = self.linear2(h)
        return o

model = GCN()

# criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(interaction_weights))
criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)
DUMMY_INDEX = np.where(np.array(PROTEIN_ATOM_LABELS)=='DUMMY')[0].item()

for batch in loader:
    dummy_mask = torch.where(batch.x[:, DUMMY_INDEX]==1)
    out = model(batch)[dummy_mask]
    print(criterion(out, batch.y))
    break


tensor(0.7601, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


In [42]:
hidden = 512 

class PoxelGCN(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = GCN2Conv(hidden, 0.1)
        self.conv2 = GCN2Conv(hidden, 0.1)

        self.pool1 = SAGPooling(hidden)

        self.conv3 = GCN2Conv(hidden, 0.2)
        self.conv4 = GCN2Conv(hidden, 0.2)

        self.pool2 = SAGPooling(hidden)

        # self.conv5 = GCN2Conv(hidden, 0.3)

        # self.pool3 = SAGPooling(1024)

        # self.conv6 = GCN2Conv(hidden, 0.3)

        # self.gate1 = torch.nn.Linear(1024, 1)

        # self.pooling_nn = torch.nn.Linear(1024, 1024)
        
        # self.global_pool = GlobalAttention(self.gate1, self.pooling_nn)
        gate_nn = MLP([512, 1], act='relu')
        nn = MLP([512, 512], act='relu')
        self.global_pool = AttentionalAggregation(gate_nn, nn)
        # self.global_pool = EquilibriumAggregation(1024,  1024, num_layers=[1024,1024])
        # self.global_pool = MeanAggregation()
        # self.global_pool = torch_geometric.nn.global_mean_pool

    def forward(self, data):
        x, edge_index, edge_weights, batch = data.x, data.edge_index, data.edge_attr, data.batch
        h = self.conv1(x, x, edge_index, edge_weights)
        h = F.relu(h)

        h = self.conv2(h, x, edge_index, edge_weights)
        h = F.relu(h)

        h, edge_index, edge_weights, batch, _, _ = self.pool1(h, edge_index, edge_weights, batch)

        h = self.conv3(h, x, edge_index, edge_weights)
        h = F.relu(h)

        h, edge_index, edge_weights, batch, _, _ = self.pool2(h, edge_index, edge_weights, batch)

        h = self.conv4(h, x, edge_index, edge_weights)
        h = F.relu(h)


        # h = self.conv5(h, x, edge_index, edge_weights)
        # h = F.relu(h)

        # h, edge_index, edge_weights, batch, _, _ = self.pool3(h, edge_index, edge_weights, batch)

        # h = self.conv6(h, x, edge_index, edge_weights)
        
        # h = self.global_pool(h, batch)
        h = self.global_pool(h, index=batch)
        return h

poxel_model = PoxelGCN()

In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = GCN().to(device)
optimizer = torch.optim.Adam(model.parameters())
# criterion = torch.nn.CrossEntropyLoss(weight=class_weights.to(device))
criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights.to(device))
batch_size=64

batch_idx = 0

for epoch in range(5000):
    for corpus_partition_index in range(25):
        vox_corpus = fetch_positive_voxel_corpus(PDB_IDS, int(len(PDB_IDS)/25), batch_size)
        avg_loss = []
        print("EPOCH %s" % epoch)
        
        for batch in vox_corpus:
            optimizer.zero_grad()
            dummy_mask = torch.where(batch.x[:, DUMMY_INDEX]==1)
            batch = batch.to(device)
            dummy_indices = torch.where(batch.x[:,DUMMY_INDEX] == 1)

            out = model(batch)[dummy_mask]
            
            loss = criterion(out, batch.y)
            avg_loss.append(loss.item())
            loss.backward()
            optimizer.step()
            batch_idx += 1

        print("Average loss:", sum(avg_loss) / len(avg_loss))
        avg_loss = []

EPOCH 0


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_addmm)

In [None]:
# torch.save(model.state_dict(), model_dir+"pocket_embed_12-1.m")
model.state_dict().cpu()

In [23]:
def fetch_training_batch(pdb_ids, batch_size):
    sample_ids = np.random.choice(pdb_ids, batch_size)
    pocket_graphs = [] 
    mol_graphs = []
    decoy_graphs = []
    
    for s_id in sample_ids:
        pocket_graphs.append(pickle.load(open(pocket_embed_file_structure % s_id, 'rb')))
        mol_graphs.append(pickle.load(open(mol_file_structure % s_id, 'rb')))
        decoy_graphs.append(pickle.load(open(mol_file_structure % np.random.choice(PDB_IDS), 'rb')))
    
    mol_graphs += decoy_graphs
    
    pocket_loader = DataLoader(pocket_graphs, batch_size=batch_size, shuffle=False)
    mol_loader = DataLoader(mol_graphs, batch_size=batch_size*2, shuffle=False)
    
    return next(iter(pocket_loader)), next(iter(mol_loader))


In [35]:
interaction_count = [0 for x in INTERACTION_LABELS]

for s_id in PDB_IDS:
    g = (pickle.load(open(mol_file_structure % s_id, 'rb')))
    interacting_voxel_indices = torch.where(torch.sum(g.y, dim=1) > 0)[0]
    for idx in torch.where(g.y[interacting_voxel_indices] > 0)[1]:
        interaction_count[idx.item()] += 1
        
print(INTERACTION_LABELS)
print(interaction_count)

['halogenbond', 'hbond_a', 'hbond_d', 'hydroph_interaction', 'pication_c', 'pication_r', 'pistack', 'saltbridge_n', 'saltbridge_p']
[1296, 40180, 30520, 63479, 540, 7661, 36100, 6291, 2137]


In [36]:
interaction_weights = [1/(x/max(interaction_count)) for x in interaction_count]
print(interaction_weights)

[48.98070987654321, 1.5798656047784967, 2.0799148099606817, 1.0, 117.55370370370372, 8.285993995561936, 1.758421052631579, 10.090446669845813, 29.704726251754796]


In [38]:
from torch_geometric.nn import GATConv, MessagePassing, global_add_pool

class AtomicAttentiveFP(AttentiveFP):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.voxel_classifier = nn.Linear(kwargs['hidden_channels'], 9)
        
    def forward(self, x, edge_index, edge_attr, batch):
        """"""
        # Atom Embedding:
        x = F.leaky_relu_(self.lin1(x))

        h = F.elu_(self.atom_convs[0](x, edge_index, edge_attr))
        h = F.dropout(h, p=self.dropout, training=self.training)
        x = self.atom_grus[0](h, x).relu_()

        for conv, gru in zip(self.atom_convs[1:], self.atom_grus[1:]):
            h = F.elu_(conv(x, edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            x = gru(h, x).relu_()

        # Molecule Embedding:
        row = torch.arange(batch.size(0), device=batch.device)
        edge_index = torch.stack([row, batch], dim=0)

        out = global_add_pool(x, batch).relu_()
        
        for t in range(self.num_timesteps):
            h = F.elu_(self.mol_conv((x, out), edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            out = self.mol_gru(h, out).relu_()

        # Predictor:
        out = F.dropout(out, p=self.dropout, training=self.training)
        return self.voxel_classifier(x), self.lin2(out)

In [40]:
## MOLECULE GRAPH POOLING 
#@title Molecule Pooling

mol_hidden = 512 

class MoleculePool(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = AtomicAttentiveFP(in_channels=52, 
                                 hidden_channels=mol_hidden, 
                                 out_channels=mol_hidden,
                                 edge_dim=10,
                                 num_layers=5,
                                 num_timesteps=5,
                                 dropout=0.0)
    
        
    
    def forward(self, data):
        x, edge_index, edge_weights, batch = data.x, data.edge_index, data.edge_attr, data.batch
        h = self.conv1(x, edge_index, edge_weights, batch)
        return h
    
model = MoleculePool()
pocket_batch, mol_batch = fetch_training_batch(PDB_IDS, 32)

criterion = nn.BCEWithLogitsLoss()
atom_embeds, mol_embed = model(mol_batch)
interacting_voxel_indices = torch.where(torch.sum(mol_batch.y, dim=1) > 0)[0]
criterion(atom_embeds[interacting_voxel_indices], mol_batch.y[interacting_voxel_indices])

tensor(0.6980, dtype=torch.float64,
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)

In [None]:
# BATCH_SIZE = 32
# sigmoid = nn.Sigmoid()
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(interaction_weights)).to(device)
# model = MoleculePool().to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# def mol_training_loop(epochs, batch_size):
#     for epoch in range(epochs):
#         print("EPOCH %s" % epoch)
        
#         for batch_idx in range(int(len(PDB_IDS) / BATCH_SIZE)):
#             loss_history = []
#             pocket_batch, mol_batch = fetch_training_batch(PDB_IDS, 32)
#             mol_batch = mol_batch.to(device)
            
#             optimizer.zero_grad() 

#             atom_embeds, mol_embed = model(mol_batch)

#             interacting_voxel_indices = torch.where(torch.sum(mol_batch.y, dim=1) > 0)[0]
#             loss = criterion(atom_embeds[interacting_voxel_indices], 
#                              mol_batch.y[interacting_voxel_indices].to(device))
                               
            
#             loss_history.append(loss.item())
#             loss.backward()
#             optimizer.step()
                    
#             if batch_idx % 10 == 0:
#                 print(sum(loss_history) / len(loss_history))
#                 loss_history = []

# # mol_training_loop(100, 32)

In [None]:
BATCH_SIZE = 16
sigmoid = nn.Sigmoid()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(interaction_weights)).to(device)
model = PocketGCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def mol_training_loop(epochs, batch_size):
    for epoch in range(epochs):
        print("EPOCH %s" % epoch)
        
        for batch_idx in range(int(len(PDB_IDS) / batch_size)):
            loss_history = []
            pocket_batch, mol_batch = fetch_training_batch(PDB_IDS, BATCH_SIZE)
            pocket_batch = pocket_batch.to(device)
            
            optimizer.zero_grad() 

            voxel_embeds, _ = model(pocket_batch)

            interacting_voxel_indices = torch.where(torch.sum(pocket_batch.y, dim=1) > 0)[0]
            
            loss = criterion(voxel_embeds[interacting_voxel_indices], 
                             pocket_batch.y[interacting_voxel_indices].to(device))
            
            loss_history.append(loss.item())
            loss.backward()
            optimizer.step()
                    
            if batch_idx % 10 == 0:
                print(sum(loss_history) / len(loss_history))
                loss_history = []

mol_training_loop(100, BATCH_SIZE)

In [41]:
# ACTIVE CLASSIFIER
#@title Classifier

class ActiveClassifier(torch.nn.Module):
    def __init__(self, pocket_model, poxel_model, molecule_model):
        super(ActiveClassifier, self).__init__()
        self.pox_pooler = PoxelGCN()
        self.mol_pooler = MoleculePool()

#         self.linear1 = nn.Linear(2048, 1024)
        self.linear1 = nn.Linear(1024, 512)
        self.linear2 = nn.Linear(512, 512)
        self.linear3 = nn.Linear(512, 1)

#         self.linear1 = nn.Linear(512, 256)
#         self.linear2 = nn.Linear(256, 64)
#         self.linear3 = nn.Linear(64, 1)
#         self.linear4 = nn.Linear(256, 1)

        self.relu = nn.ReLU()

    def forward(self, pocket_batch, active_batch, decoy_batch):
#         voxel_node_indices = torch.where(pocket_batch.x[:, dummy_index] == 1.0)[0]
#         non_voxel_node_indices = torch.where(pocket_batch.x[:, dummy_index] == 0)[0]
#         voxel_edge_indices = torch.where(pocket_batch.edge_attr[:, voxel_edge_index]==1.0)[0]
        
        pocket_embeds = self.pocket_model(pocket_batch)
        
#         pocket_embeds[non_voxel_node_indices] = torch.zeros(pocket_embeds.size(1))
#         trimmed_edge_index = torch.vstack((pocket_batch.edge_index[0][voxel_edge_indices],
#                                            pocket_batch.edge_index[1][voxel_edge_indices]))
        
#         trimmed_edge_attr = pocket_batch.edge_attr[voxel_edge_indices]
        
        pocket_batch.x = pocket_embeds
#         pocket_batch.edge_index = trimmed_edge_index
#         pocket_batch.edge_attr = trimmed_edge_attr
        
        poxel_embeds = self.pox_pooler(pocket_batch)
        active_embeds = self.mol_pooler(active_batch)
        decoy_embeds = self.mol_pooler(decoy_batch)
        
        poxel_actives = torch.hstack((poxel_embeds, active_embeds))
        poxel_decoys = torch.hstack((torch.cat([poxel_embeds]*len(decoy_embeds), dim=0), 
                                     decoy_embeds.repeat_interleave(poxel_embeds.size(0), dim=0)))
        
        all_embeds = torch.vstack((poxel_actives, poxel_decoys))

        x = self.linear1(all_embeds) 
        x = self.relu(x)
        
        x = self.linear2(x) 
        x = self.relu(x)
        
        x = self.linear3(x) 
#         x = self.relu(x)
        
#         o = self.linear4(x) 
        return x
    
    
# batch_size=32
# ac = ActiveClassifier(PocketGCN, PoxelPool, MoleculePool)

# corpus = pickle.load(open(corpus_dir + "pdbbind_corpus_3_75.pkl", 'rb'))
# mol_loader = DataLoader(corpus[1], batch_size=batch_size, shuffle=False)

# pocket_loader = DataLoader(corpus[2], batch_size=batch_size, shuffle=False)
# decoy_batch = fetch_decoy_batch(200)

# for active_batch, pocket_batch in zip(mol_loader, pocket_loader):
#     out = ac(pocket_batch, active_batch, decoy_batch)
#     print(out)
#     break

In [None]:
# TRAINING LOOP
sigmoid = nn.Sigmoid()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor([100])).to(device)
# criterion = nn.BCEWithLogitsLoss().to(device)
ac = ActiveClassifier(PocketGCN, PoxelPool, MoleculePool).to(device)
optimizer = torch.optim.Adam(ac.parameters(), lr=1e-3)

def training_loop(epochs, batch_size):
    for epoch in range(epochs):
        print("EPOCH %s" % epoch)
        
        for corpus_idx in range(1,75):
            loss_history = []
            corpus = pickle.load(open(corpus_dir + "pdbbind_corpus_%s_75.pkl" % corpus_idx, 'rb'))
            
            mol_loader = DataLoader(corpus[1], batch_size=batch_size, shuffle=False)
            pocket_loader = DataLoader(corpus[2], batch_size=batch_size, shuffle=False)
            
            for active_batch, pocket_batch in zip(mol_loader, pocket_loader):
                optimizer.zero_grad() 
                active_batch = active_batch.to(device)
                pocket_batch = pocket_batch.to(device)
                decoy_batch = fetch_decoy_batch(batch_size).to(device)
                
                out = ac(pocket_batch, active_batch, decoy_batch)
                
                y = torch.zeros(out.size(0))
                y[:batch_size] = 1
                y = torch.unsqueeze(y, dim=1).to(device)
                
                loss = criterion(out, y)
                loss_history.append(loss.item())
                loss.backward()
                optimizer.step()
                
            print(sum(loss_history) / len(loss_history))
            print('active:', [float("%.3f" % x) for x in torch.sigmoid(out[:7].squeeze()).tolist()])
            print('decoy:', [float("%.3f" % x) for x in torch.sigmoid(out[-7:].squeeze()).tolist()])
            print('')
    
#       for batch_idx in range(int(POXEL_COLLECTION_SIZE / batch_size)):
#         y = torch.hstack([torch.ones(batch_size), torch.zeros(batch_size*np_ratio)]).unsqueeze(dim=1).to(device) 
      
#         pox_batch, active_batch, decoy_batch, active_indices = retrieve_training_batch(batch_size, np_ratio, poxel_collection, col_idx)
    
#         # pox_batch, active_batch, decoy_batch, active_indices = retrieve_DUMMY_batch(batch_size, np_ratio, poxel_collection, COL_IDX, [Ameans, Astds], [Bmeans, Bstds])
#         out = ac(pox_batch.to(device), active_batch.to(device), decoy_batch.to(device), np_ratio)
#         loss = criterion(out, y)
#         print(batch_idx, torch.mean(sigmoid(out[0:10])).item(), torch.mean(sigmoid(out[10:])).item(), loss.item())

        # print(loss)

training_loop(100, 32)
