In [1]:
import itertools
import pickle as pickle
import plotly.graph_objects as go
from scipy.ndimage.filters import gaussian_filter1d
from torch_geometric.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler


from analysis.graphanalysis import *

# Loading the data

In [2]:
# Load the pre-processed structure data. CifProcessor is pulled in by the star
# import of analysis.graphanalysis; its internals are not visible in this notebook.
p = CifProcessor()
p.read_pkl_metainfo()
# Read the pickled entries stored under data/processed/.
p.read_pkl(mode='r', folder='data/processed/')

  3%|███                                                                                                               | 15/557 [00:00<00:03, 144.20it/s]

Reading files with generic numbers on receptors.


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 557/557 [00:03<00:00, 141.32it/s]


# Create graph analysis class

In [3]:
# Build the graph-analysis object on top of the loaded processor `p`.
# NOTE(review): `d=[]` is passed as an empty list; its meaning is defined in GraphProcessor - confirm.
gp = GraphProcessor(d=[], p=p)

Reading data from data/couplings/families_coupling.xls!


  obj = obj._drop_axis(labels, axis, level=level, errors=errors)


Initialized Affinity Processor!
Please set a group --------------  ['GPCRdb', 'Inoue', 'Bouvier'].
please set label type -----------  ['Guide to Pharmacology', 'Log(Emax/EC50)', 'pEC50', 'Emax'].

Selected label type 'Log(Emax/EC50)'.


Selected data of group 'GPCRdb'.



In [4]:
# Configure the atom list on the graph processor and then apply it as a filter.
# NOTE(review): which atoms are kept is decided inside GraphProcessor - confirm there.
gp.set_atom_list()
gp.apply_atom_list_filter()

In [5]:
# Reduce the loaded structures before graph construction (prints a progress bar).
# NOTE(review): the exact simplification is implemented in GraphProcessor - confirm.
gp.simplify()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 368/368 [00:02<00:00, 131.25it/s]


# Create training settings

In [6]:
# Candidate residue labels of interest.
# NOTE(review): `pois` is not referenced by any other cell visible in this notebook.
pois = ['1.55', '2.39', '3.46', '6.37', '7.55']

In [7]:
from itertools import compress, product

def combinations(items):
    """Lazily yield every subset of `items` as a set (the power set).

    Subsets are produced in the order of the binary masks generated by
    itertools.product, starting with the empty set and ending with all items.
    """
    n_items = len(items)
    return (
        set(compress(items, selector))
        for selector in product([0, 1], repeat=n_items)
    )

In [8]:
# Sanity check: materialise all 2**4 subsets of {0, 1, 2, 3}.
list(combinations(range(4)))

[set(),
 {3},
 {2},
 {2, 3},
 {1},
 {1, 3},
 {1, 2},
 {1, 2, 3},
 {0},
 {0, 3},
 {0, 2},
 {0, 2, 3},
 {0, 1},
 {0, 1, 3},
 {0, 1, 2},
 {0, 1, 2, 3}]

In [9]:
# Grid-search space: every entry maps a setting name to the list of candidate
# values; the Cartesian product of these lists is expanded into `configs` below.
hyper_params = dict(
    learning_rate=[0.00001],
    cons_r_res=[
        ['3.53', '7.55'],
        ['3.53', '6.37'],
        ['6.37', '7.55'],
        ['2.55', '3.53', '7.53'],
    ],
    radius=[10],
    max_edge_dist=[7, 9],
    batch_size=[2],
)

In [10]:
# Expand the grid into one dict per combination, tagging each with a running 'index'.
values = tuple(hyper_params.values())
keys = ('index', *hyper_params.keys())
configs = [
    dict(zip(keys, (idx, *combo)))
    for idx, combo in enumerate(itertools.product(*values))
]

In [19]:
# Inspect the first generated configuration.
# NOTE(review): this cell's execution count (19) is out of sequence with its neighbours.
configs[0]

{'index': 0,
 'learning_rate': 1e-05,
 'cons_r_res': ['3.53', '7.55'],
 'radius': 10,
 'max_edge_dist': 7,
 'batch_size': 2}

# Run Analysis

In [12]:
def run_training(model, n_epochs, lr, train_loader, validation_loader, patience = 50):
    """Train a freshly initialised H5Net on CUDA and keep the best validation checkpoint.

    Parameters
    ----------
    model : str
        Label of the run; only used as a prefix in the progress-bar descriptions.
    n_epochs : int
        Maximum number of epochs.
    lr : float
        Learning rate of the Adam optimizer.
    train_loader, validation_loader : iterable
        Yield batches that expose ``x``, ``y`` and ``to(device)``.
    patience : int
        Patience of the ReduceLROnPlateau scheduler. Training stops early once
        more than ``4 * patience`` epochs pass without a new best validation error.

    Returns
    -------
    tuple
        ``(ckpt, t_e_losses, t_mse_losses, best_training, v_e_losses,
        v_mse_losses, best_validation)``. ``ckpt`` is a deep copy of the state
        dict at the best validation epoch, or ``None`` if no epoch improved on
        the initial threshold. The ``*_e_losses`` are ``sqrt(MSE) * 12`` rounded
        to two decimals; the ``*_mse_losses`` are the normalised epoch MSEs.
    """
    import copy  # local import: only needed to snapshot the best checkpoint

    n_outputs = 4
    target_scale = 12  # very crude normalization! (should subtract upper & lower limit etc)
    h5n = H5Net(n_outputs=n_outputs, aggr='add')
    h5n.to('cuda')
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(h5n.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=patience, 
                                                           threshold=0.0001, threshold_mode='rel', 
                                                           cooldown=0, min_lr=0, eps=1e-8, verbose=True)

    print("Initialized model!")

    t_e_losses = []
    t_mse_losses = []
    v_e_losses = []
    v_mse_losses = []

    # An epoch only counts as an improvement when its rescaled error drops below 12.
    best_validation = 12
    best_training = 12

    # Defined up front so the early-stop check and the return statement never hit
    # an unbound name when no epoch improves (or when n_epochs == 0).
    ckpt = None
    last_update = 0
    
    for e in range(n_epochs):
        # TRAINING
        h5n.train()
        training_loss = 0
        t_bar = tqdm(enumerate(train_loader), desc='TRAINING')
        for b,  batch in t_bar:
            batch.z = batch.x
            batch.to('cuda')
            # Reset gradients; without this they accumulate over all previous batches.
            optimizer.zero_grad()
            pred = h5n(batch)
            target = batch.y / target_scale
            target = target.reshape(-1, n_outputs)
            loss = criterion(pred, target)
            loss.backward()
            optimizer.step()
            # Detach so the running mean does not retain every batch's autograd graph.
            training_loss = (training_loss * b + loss.detach()) / (b + 1)
            if b % 5 == 0:
                t_bar.set_description(model + ' | E: {} | TRAINING: MSE = '.format(e)
                                      +str(round(float(training_loss), 6)))
        
        t_mse_loss = training_loss.cpu().detach().numpy()
        # Square root of the epoch MSE, scaled back to the original target units.
        t_e_loss = round(torch.sqrt(training_loss).cpu().detach().numpy() * target_scale, 2)
        t_e_losses.append(t_e_loss)
        t_mse_losses.append(t_mse_loss)
        
        if t_e_loss < best_training:
            best_training = t_e_loss            
        
        # VALIDATION
        h5n.eval()
        with torch.no_grad():
            validation_loss = 0
            t_bar = tqdm(enumerate(validation_loader), desc='VALIDATION')
            for b,  batch in t_bar:
                batch.z = batch.x
                batch.to('cuda')
                pred = h5n(batch)
                target = batch.y / target_scale
                target = target.reshape(-1, n_outputs)
                loss = criterion(pred, target)
                validation_loss = (validation_loss * b + loss) / (b + 1)
                if b % 5 == 0:
                    t_bar.set_description(model + ' | E: {} | VALIDATION: MSE = '.format(e)
                                          +str(round(float(validation_loss), 6)))
        scheduler.step(validation_loss)
        v_mse_loss = validation_loss.cpu().detach().numpy()
        v_e_loss = round(torch.sqrt(validation_loss).cpu().detach().numpy() * target_scale, 2)
        v_e_losses.append(v_e_loss)
        v_mse_losses.append(v_mse_loss)
        
        if v_e_loss < best_validation:
            print("New best validation perfomance: MSE={} | mean epoch error={}".format(v_mse_loss, v_e_loss))
            best_validation = v_e_loss
            # state_dict() hands out references to the live parameters; copy them so
            # the checkpoint keeps the weights of this epoch rather than the last one.
            ckpt = copy.deepcopy(h5n.state_dict())
            last_update = e
        
        if e - (4 * patience) > last_update:
            print("STOP (no further improvement recorded after {} epochs)".format(4 * patience))
            break       
            
    return ckpt, t_e_losses, t_mse_losses, best_training, v_e_losses, v_mse_losses, best_validation

In [13]:
from torch_geometric.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler


def run(model, gp, config, performance, n_epochs=250, save=True):
    """Build the graphs for one configuration, train on them and record the results.

    Parameters
    ----------
    model : str
        Run label; used in progress output and in the name of the pickle file.
    gp : GraphProcessor
        Graph processor; ``create_graph`` is called on it and it is then used as
        the dataset for both data loaders.
    config : dict
        One entry of ``configs``. Reads 'cons_r_res', 'radius', 'max_edge_dist',
        'batch_size', 'learning_rate' and 'index'.
    performance : dict
        Mutated in place: the results are stored under ``config['index']``.
    n_epochs : int
        Maximum number of training epochs.
    save : bool
        If true, pickle ``performance`` to 'models/performances<model>.pkl'.

    Returns
    -------
    tuple
        ``(performance, ckpt)`` where ``ckpt`` is what ``run_training`` returned.
    """
    gp.create_graph(filter_by_chain=True,
                gpcr=True,
                gprotein=False,
                auxilary=False,
                node_criteria='Interaction Site', 
                edge_criteria='radius',
                h5start=13,
                cons_r_res=config['cons_r_res'], 
                radius=config['radius'],
                max_edge_dist=config['max_edge_dist'])
    print("Finished creating Graphs!")
    validation_split = .2
    shuffle_dataset = True

    # Creating data indices for training and validation splits
    dataset_size = len(gp)
    indices = list(range(dataset_size))
    split = int(np.floor(validation_split * dataset_size))
    if shuffle_dataset :
        # Uses the global NumPy RNG, so seed it beforehand for a reproducible split.
        np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]

    # Creating PT data samplers and loaders
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)
    
    print("Initializing DataLoader ==> Batchsize: {}".format(config['batch_size']))
    train_loader = DataLoader(gp, batch_size=int(config['batch_size']), sampler=train_sampler)
    validation_loader = DataLoader(gp, batch_size=int(config['batch_size']), sampler=valid_sampler)
    
    print("Starting Training ({} Epochs)!".format(n_epochs))
    ckpt, t_e_losses, t_mse_losses, best_training, v_e_losses, v_mse_losses, best_validation = run_training(
        model=model,
        n_epochs=n_epochs, 
        lr = config['learning_rate'],
        train_loader=train_loader, 
        validation_loader=validation_loader)
    performance[config['index']] = {
        'best_training': best_training,
        'training_epoch_losses': t_e_losses,
        'training_mse_losses': t_mse_losses,
        'best_validation': best_validation,
        'validation_epoch_losses': v_e_losses,
        'validation_mse_losses': v_mse_losses
    }
    if save:
        # Context manager guarantees the file is flushed and closed; the original
        # left the handle open, so the pickle could end up truncated on disk.
        with open(str('models/performances'+model+'.pkl'),'wb') as picklefile:
            pickle.dump(performance, picklefile)
    return performance, ckpt

In [14]:
# The grid-search driver below is disabled by wrapping it in a string literal, so
# running this cell only echoes the text and does not define `performance`.
# NOTE(review): the loop filters on `top5`, which is only computed further down
# from `performance` itself - it cannot run top-to-bottom on a fresh kernel.
"""performance = {}

for c, config in enumerate(configs):
    if c in top5:
        values = list(config.values())
        model = 'model_' + "_".join([str(x) if not isinstance(x, list) else "_".\
                                     join([str(round(float(y), 2)).replace(".", "") for y in x]) for x in values])
        np.random.seed(seed=c)
        torch.manual_seed(c)
        print("\n\n\nRUNNING NEW CONFIGURATION:\n",config)
        performance, ckpt = run(model, gp, config, performance, n_epochs=1000)"""

'performance = {}\n\nfor c, config in enumerate(configs):\n    if c in top5:\n        values = list(config.values())\n        model = \'model_\' + "_".join([str(x) if not isinstance(x, list) else "_".                                     join([str(round(float(y), 2)).replace(".", "") for y in x]) for x in values])\n        np.random.seed(seed=c)\n        torch.manual_seed(c)\n        print("\n\n\nRUNNING NEW CONFIGURATION:\n",config)\n        performance, ckpt = run(model, gp, config, performance, n_epochs=1000)'

In [1]:
# picklefile = open(str('models/performances.pkl'),'wb')

In [2]:
# pickle.dump(performance, picklefile)

In [3]:
# with open("models/performances.pkl", "rb") as input_file:
#     data = pickle.load(input_file)

# Visualization

In [None]:

from plotly.subplots import make_subplots

def loss_plot(performance, config, save=False):
    """Plot training vs. validation curves of one run, raw and Gaussian-smoothed.

    Two figures are shown: one for the per-epoch error curves
    ('*_epoch_losses') and one for the MSE curves ('*_mse_losses').

    Parameters
    ----------
    performance : dict
        Result record of a single configuration with the keys
        'validation_epoch_losses', 'training_epoch_losses',
        'validation_mse_losses' and 'training_mse_losses'.
    config : dict
        The configuration of that run; its values are joined into the plot
        title and the file-name stem.
    save : bool
        If true, write both figures to 'plots/<model>_mae.png' and
        'plots/<model>_mse.png'.
    """
    values = list(config.values())
    model = 'model_' + "_".join([str(x) if not isinstance(x, list) else "_".\
                                 join([str(round(float(y), 2)).replace(".", "") for y in x]) for x in values])

    def make_plot(fig, a, b, loss='MSE'):
        # `a` is the validation curve, `b` the training curve.
        fig.add_trace(go.Scatter(
            y=a,
            mode='lines',
            name='Validation loss'))

        fig.add_trace(go.Scatter(
            y=b,
            mode='lines',
            name='Training loss'))

        y_title = loss + ' [log(Emax/EC50)]'
        fig.update_layout(
            title=model,
            xaxis_title='Epoch',
            yaxis_title=y_title
            )

        # Smoothed overlays; named so the legend does not show anonymous traces.
        vysmoothed = gaussian_filter1d(a, sigma=2)
        fig.add_trace(go.Scatter(y=vysmoothed, name='Validation loss (smoothed)'))

        tysmoothed = gaussian_filter1d(b, sigma=2)
        fig.add_trace(go.Scatter(y=tysmoothed, name='Training loss (smoothed)'))
        return fig

    # NOTE(review): the '*_epoch_losses' are computed upstream as sqrt(MSE) * 12,
    # so 'Mean Absolute Error' may not be the accurate name for this axis - confirm.
    fig = make_plot(go.Figure(), performance['validation_epoch_losses'], performance['training_epoch_losses'], loss='Mean Absolute Error (MAE) \n')
    fig.show()
    if save:
        # Was "mae_.png", which glued the suffix onto the model name; now it
        # follows the same "<model>_<metric>.png" pattern as the MSE figure.
        fig.write_image("plots/" + model + "_mae.png")

    fig = make_plot(go.Figure(), performance['validation_mse_losses'], performance['training_mse_losses'], loss='Mean Squared Error (MSE)')
    fig.show()
    if save:
        fig.write_image("plots/" + model + "_mse.png")

In [None]:
# Plot and save the loss curves of every recorded run.
# `data` is expected to be the performance dict; the cell that would load it from
# models/performances.pkl is commented out above, so it must already be in scope.
for idx, run_record in data.items():
    loss_plot(run_record, config=configs[idx], save=True)

In [None]:
def get_best_config(performance):
    """Rank the recorded runs by their best validation error, lowest first.

    Parameters
    ----------
    performance : dict
        Maps a config index to a result record holding 'best_validation'. The
        value may be a tensor (moved to CPU first) or a NumPy / Python scalar.

    Returns
    -------
    list of (key, best_validation)
        Pairs sorted ascending by the error rounded to three decimals. ``key``
        is the dict key of the run, so it stays valid as an index into
        ``configs`` even when only a subset of configurations was trained. For
        keys 0..n-1 in insertion order this equals the positional index.
    """
    def _as_float(value):
        # Tensors expose .cpu(); the scalars stored by run_training do not.
        if hasattr(value, 'cpu'):
            value = value.cpu().numpy()
        return round(float(value), 3)

    scored = [(key, _as_float(record['best_validation'])) for key, record in performance.items()]
    # Stable sort on the score alone: ties keep insertion order.
    return sorted(scored, key=lambda pair: pair[1])

In [None]:
# Indices of the five runs with the lowest best-validation error.
top5 = [idx for idx, _ in get_best_config(performance)[:5]]

In [None]:
# Display the selected run indices.
top5

In [None]:
# Look up the hyper-parameter configuration behind each selected index.
[configs[x] for x in top5]

In [None]:
def analyse_single_residue_variation(model, sample, device='cuda'):
    """Scan single-residue substitutions and record the predicted affinities.

    For every node of `sample` each of the 20 residue codes (0..19) is
    substituted in turn, with all other nodes left at their original value, and
    the model output for the first graph in the batch (times 12) is stored.

    Parameters
    ----------
    model : callable
        Called as ``model(sample)``; the first row of its output is read. The
        column layout assumes four outputs per node (gs, gi/o, gq/11, g12/13).
    sample : object
        Graph batch exposing ``x`` (one residue code per node) and
        ``to(device)``. ``sample.z`` is set to ``sample.x`` on return.
    device : str
        Device passed to ``sample.to``; defaults to 'cuda' as before.

    Returns
    -------
    pandas.DataFrame
        Rows are residue codes 0..19; columns are '<residue>_<output>' in node
        order, four per node.
    """
    aas = list(range(20))
    suffixes = ['_gs', '_gi/o', '_gq/11', '_g12/13']
    with torch.no_grad():
        sample.to(device)
        original = sample.x
        sample.z = original
        residues = [int(x) for x in list(original)]
        cols = [str(res) + suffix for res in residues for suffix in suffixes]
        df = pd.DataFrame(columns=cols, index=aas)
        affinities = list(model(sample)[0].cpu().numpy()*12)
        print(residues)
        try:
            for r, res in enumerate(residues):
                for i, aa in enumerate(aas):
                    if res == aa:
                        # Unmutated case: reuse the baseline prediction.
                        row = affinities
                    else:
                        # Mutate a private copy. The original wrote into sample.z,
                        # which aliases sample.x, so every substitution stuck and
                        # later residues were scored on an already-mutated graph.
                        mutated = original.clone()
                        mutated[r] = aa
                        sample.z = mutated
                        row = list(model(sample)[0].cpu().numpy()*12)
                    for a, aff in enumerate(row):
                        # Positional write: column labels repeat whenever two nodes
                        # carry the same residue code, which label lookup cannot resolve.
                        df.iat[i, r*4+a] = aff
        finally:
            # Leave the sample as it was handed in.
            sample.z = original
    return df

In [None]:
# Analyse one graph at a time.
batch_size  = 1

# Number of graphs currently held by the graph processor.
# NOTE(review): `dataset_size` is not used by any later cell visible here.
dataset_size = len(gp)

# Shuffled loader over the whole dataset (no train/validation split is made here).
analyse_loader = DataLoader(gp, batch_size=batch_size, shuffle=True)

# In silico mutation analysis

In [None]:
# Draw one graph for the mutation scan from the analysis loader built above.
# `validation_loader` only exists as a local variable inside run(), so using it
# at notebook scope raises NameError on a fresh kernel.
sample = next(iter(analyse_loader))

In [None]:
# Run the single-residue substitution scan on the drawn sample.
# NOTE(review): `h5n` is only created as a local inside run_training(); a trained
# model must be bound to that name at notebook scope before this cell can run.
df = analyse_single_residue_variation(h5n, sample)

In [None]:
# Select the Gs columns by their explicit suffix; the previous substring test
# ('s' in x) would also match any other label that happens to contain an "s".
gs_cols = [col for col in df.columns if col.endswith('_gs')]

In [None]:
# Largest standard deviation across the 20 substitutions over all node/output
# columns (transposing makes each original column a row, hence axis=1).
df.T.std(axis=1).max()