In [1]:
import os
import pandas as pd
import json
import csv
import random
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from sklearn.decomposition import PCA

%matplotlib notebook

In [2]:
class Data():
    
    def __init__(self):
        # set file paths for future use
        self.drugbank = '/home3/jwang/druglikeness_ML/drugbank/prop/'
        self.zinc = '/home3/jwang/druglikeness_ML/zinc/prop/'
        self.likeness = '/home3/jwang/druglikeness_ML/zinc/total_50.list'
        
        self.load_smiles()
        
        
    def create_drug_dic(self):
        # returns list of drugs in the total_50.list file
        self.drug_dic = {}
        with open(self.likeness) as file:
            reader = csv.reader(file, delimiter='\t')
            for row in reader:
                location = self.drug_dic
                for char in row[0].split(' ')[0]:
                    if char not in location.keys():
                        location[char] = {}
                    location = location[char]
        self.create_drug_location_index()
                    
                
    def load_file(self, file):
        # takes file name without path as input
        # returns a dictionary of drugs and parameters if they are in the total_50 file
        # dictionary has 3 list: drug_id, mw, pka
        file_name = self.zinc + file
        with open(file_name) as file:
            reader = csv.reader(file, delimiter='\t')
            drug_list, mw, logp = [], [], []
            for row in reader:
                drug_id = row[1].split(' ')[0]
                if self.valid_drug(drug_id):
                    drug_list.append(drug_id)
                    mw.append(float(row[1].split(' ')[1]))
                    logp.append(float(row[1].split(' ')[2]))
        return (drug_list, np.array(mw), np.array(logp))
    
    
    def valid_drug(self, drug_id):
        # Checks if a drug id is in the total_50 file
        # Input is the drug id as a string
        # Outputs True if the drug is in the file, otherwise False
        location = self.drug_dic
        for char in list(drug_id):
            if char in location.keys():
                location = location[char]
            else:
                return False
        return True
    
    
    def select_data(self):
        file_list = os.listdir(self.zinc)
        self.data = {}
        for file in file_list:
            drug_list, mw, logp = self.load_file(file)
            x = np.concatenate((np.expand_dims(mw,1), np.expand_dims(logp,1)), axis=1)
            component = PCA(n_components=1).fit_transform(x)
            norm_component = (component-(np.sum(component)/component.shape[0]))/np.std(component)
            p = self.p(norm_component)
            weight = 2000/np.sum(p)
            p *= weight
            selection = p > np.random.rand(p.shape[0],1)
            selected_drugs = [drug_list[x] for x in range(p.shape[0]) if selection[x,0]]
            self.save_drugs(selected_drugs)
            self.save_distributions(np.array([np.squeeze(mw), np.squeeze(logp), np.squeeze(component), np.squeeze(selection)]), file)
            print('Number selected: {}'.format(np.sum(selection)))
            
            
    def save_drugs(self, drug_list):
        with open('selected_drugs.txt', 'a') as file:
            for drug in drug_list:
                file.write('{}\n'.format(drug))
                
                
    def save_distributions(self, data, file):
        data = np.transpose(data)
        if not 'distributions' in os.listdir():
            os.mkdir(os.getcwd()+'/distributions/')
        file_name = os.getcwd() + '/distributions/' + file + '.csv'
        df = pd.DataFrame(data=data, 
                          index=np.arange(data.shape[0]),
                          columns=['mw', 'logp', 'component', 'selection'])
        df.to_csv(file_name)
    
    
    def p(self, z):
        assert type(z) == np.ndarray
        return (1/np.sqrt(2*np.pi))*np.exp(-(z**2)/2)
    
    
    def create_drug_location_index(self):
        # Adds location of the drug to the drug_dic
        # Value of final level is the tuple: (file_name, index within file)
        for file in list(os.listdir(self.zinc)):
            drug_list, _, _ = self.load_file(file)
            drug_index = 1
            for drug_id in drug_list:
                location = self.drug_dic
                char_index = 0
                for char in list(drug_id):
                    if char in list(location.keys()):
                        if char_index+1 == len(list(drug_id)):
                            location[char] = (file, drug_index)
                        else:
                            location = location[char]
                    else:
                        break
                    char_index += 1
                drug_index += 1
                
                
    def get_smiles(self):
        self.smiles = {'real': [], 'fake':[]}
        real_file_name = self.drugbank + 'total.can'
        fake_file_name = os.getcwd() + '/initial_selected_drugs.txt'
        with open(real_file_name) as file:
            reader = csv.reader(file, delimiter='\t')
            count = 0
            for row in reader:
                self.smiles['real'].append(row[0])
                count += 1
        print('Number of real smiles: {}'.format(count))
        real_drug_list = []
        with open(fake_file_name) as file:
            reader = csv.reader(file)
            for row in reader:
                drug_id = row[0]
                tup = self.get_drug_location(drug_id)
                real_drug_list.append(tup)
        rand_array = np.random.rand(len(real_drug_list))
        selected_drugs = [tup for tup, rand in zip(real_drug_list, rand_array) if rand < 10000/len(real_drug_list)]
        selected_drugs_dic = {}
        for file, index in selected_drugs:
            if file in list(selected_drugs_dic.keys()):
                selected_drugs_dic[file].append(index)
            else:
                selected_drugs_dic[file] = []
        print('Getting smiles for real compounds...')
        file_number = 0
        selected_drugs_out = []
        for file_name in selected_drugs_dic:
            print('Loading file {} of {}'.format(file_number, len(list(selected_drugs_dic))))
            with open(self.zinc+'/'+file_name, 'r') as file:
                reader = csv.reader(file, delimiter='\t')
                row_number = 0
                for row in reader:
                    if row_number in selected_drugs_dic[file_name]:
                        self.smiles['fake'].append(row[0])
                        selected_drugs_out.append(row[1].split(' ')[0])
                    row_number += 1
            file_number += 1
        with open(os.getcwd()+'/selected_drugs.txt', 'a') as file:
            for drug in selected_drugs_out:
                file.write('{}\n'.format(drug))
        print('Number of fake smiles: {}'.format(len(self.smiles['fake'])))
        self.save_smiles()
        
        
    def save_smiles(self):
        with open('smiles.json', 'w') as file:
            json.dump(self.smiles, file)
            
    
    def load_smiles(self):
        with open('smiles.json', 'r') as file:
            self.smiles = json.load(file)
            
    
    def fix_smiles(self):
        for source in ['fake', 'real']:
            for smile, index in zip(self.smiles[source], range(len(self.smiles[source]))):
                self.smiles[source][index] = self.smiles[source][index].split(' ')[-1]
                
                
    def get_drug_location(self, drug_id):
        location = self.drug_dic
        char_index = 0
        for char in list(drug_id):
            if char_index+1 == len(list(drug_id)):
                return location[char]
            else:
                location = location[char]
            char_index += 1
            
    def save_drug_dic(self):
        with open('drug_dic.json', 'w') as file:
            json.dump(self.drug_dic, file)
            
    def load_drug_dic(self):
        with open('drug_dic.json', 'r') as file:
            self.drug_dic = json.load(file)
            
    

In [3]:
data = Data()

In [4]:
data.create_drug_dic()

In [5]:
data.get_drug_location('ZINC08769986')

KeyError: '6'