In [None]:
class ScrapePubMed:
    """Interactively search PubMed and append author rows to a CSV file.

    Creating an instance runs the whole workflow: it prompts for an email
    address, a search term and two batch sizes, asks PubMed how many articles
    match, collects their ids, then downloads the articles and appends one
    small table per author to ``<filepath>/<search term without spaces>.csv``.

    The class references ``Entrez``, ``dicttoxml``, ``ET``, ``pd``, ``json``
    and ``time`` but does not import them; they must be available at module
    level (``Entrez`` is presumably Biopython's ``Bio.Entrez`` -- confirm).
    """

    # Inclusive upper bound accepted by the batch-size prompts.
    MAX_BATCH_SIZE = 10000

    def __init__(self, filepath="."):
        """Run the full prompt / search / download workflow.

        Args:
            filepath: Directory the CSV file is written to. Defaults to the
                current directory.
        """
        self.initialize_variables()
        self.set_email()
        self.set_search_term()
        self.set_retmax_search()
        self.set_retmax_fetch()
        self.get_search_term_count()
        self.get_list_of_ids()
        self.download_articles(filepath)

    def initialize_variables(self):
        """Reset every piece of instance state to its empty value."""
        self.email = None
        self.search_term = None
        self.no_space_search_term = None
        self.retmax_search = None
        self.retmax_fetch = None
        self.search_term_count = None
        # One entry per esearch call; each entry is the list of ids returned.
        self.list_of_ids = []
        self.filename = None

    def set_email(self):
        """Prompt for the email address that is handed to ``Entrez.email``."""
        self.email = input("Type in your email: ")

    def set_search_term(self):
        """Prompt for the search term and derive its space-free form."""
        self.search_term = input("Type in your search term: ")
        # The space-free form is used to build the output file name.
        self.no_space_search_term = self.search_term.replace(" ", "")

    @classmethod
    def _prompt_batch_size(cls, prompt, current=None):
        """Return ``current`` if it is a valid batch size, else prompt until one is entered."""
        value = current
        while (
            not isinstance(value, int)
            or isinstance(value, bool)
            or not 1 <= value <= cls.MAX_BATCH_SIZE
        ):
            try:
                value = int(input(prompt))
            except ValueError:
                print("Please type in a number from 1-10000: ")
        return value

    def set_retmax_search(self):
        """Prompt until ``retmax_search`` is an integer in 1..10000."""
        self.retmax_search = self._prompt_batch_size(
            "Between 1-10000, how many articles to search for in each batch? (<10k): ",
            self.retmax_search,
        )

    def set_retmax_fetch(self):
        """Prompt until ``retmax_fetch`` is an integer in 1..10000."""
        self.retmax_fetch = self._prompt_batch_size(
            "Between 1-10000, how many articles to fetch in each batch? (<10k): ",
            self.retmax_fetch,
        )

    def get_search_term_count(self):
        """Store the number of PubMed records matching the search term."""
        Entrez.email = self.email

        handle = Entrez.esearch(db='pubmed', retmode='xml', term=self.search_term)
        try:
            record = Entrez.read(handle)
        finally:
            handle.close()
        self.search_term_count = int(record["Count"])

    def get_list_of_ids(self):
        """Collect matching article ids in batches of ``retmax_search``.

        Each esearch response's ``IdList`` is appended to ``list_of_ids`` as
        one batch. Paging stops when ``search_term_count`` ids were collected
        or when the service returns an empty page.
        """
        Entrez.email = self.email
        retstart = 0
        remaining = self.search_term_count

        while remaining > 0:
            handle = Entrez.esearch(
                db="pubmed",
                retmode='xml',
                term=self.search_term,
                retmax=self.retmax_search,
                retstart=retstart,
            )
            try:
                record = Entrez.read(handle)
            finally:
                handle.close()

            idlist = record['IdList']
            if not idlist:
                # Fewer ids are retrievable than "Count" reported; without
                # this guard neither counter advances and the loop never ends.
                break

            self.list_of_ids.append(idlist)
            remaining -= len(idlist)
            retstart += len(idlist)

    @staticmethod
    def _text_or_none(tree, path):
        """Return the text of the first element at ``path``, or None if absent."""
        node = tree.find(path)
        return node.text if node is not None else None

    @staticmethod
    def _author_frame(first_name, last_name, affiliation, title, publication, pmid):
        """Build the table written for one author.

        The three lists become the columns ``FirstName``, ``LastName`` and
        ``Affiliation`` (shorter lists are padded with missing values); the
        article-level values are repeated on every row.
        """
        dataframe = pd.DataFrame([first_name, last_name, affiliation]).transpose()
        dataframe.columns = ['FirstName', 'LastName', 'Affiliation']
        dataframe['Title'] = title
        dataframe['Publication'] = publication
        dataframe['PMID'] = pmid
        return dataframe

    def _append_document_rows(self, document):
        """Append one table per author of ``document`` to the CSV file."""
        xml = dicttoxml(document)
        tree = ET.ElementTree(ET.fromstring(xml))

        # Article-level fields are identical for every author, so look them
        # up once per article instead of once per author.
        title = self._text_or_none(tree, 'MedlineCitation/Article/ArticleTitle')
        publication = self._text_or_none(tree, 'MedlineCitation/Article/ArticleDate/item/Year')
        pmid = self._text_or_none(tree, 'MedlineCitation/PMID')

        for entry in tree.findall('MedlineCitation/Article/AuthorList/item'):
            first_name = [node.text for node in entry.findall('ForeName')]
            last_name = [node.text for node in entry.findall('LastName')]
            affiliation = [
                node.text
                for node in entry.findall('AffiliationInfo/item/Affiliation')
            ]

            dataframe = self._author_frame(
                first_name, last_name, affiliation, title, publication, pmid
            )
            # Append mode writes the index column and a header row on every
            # call; readers of this file have to filter the repeated headers.
            dataframe.to_csv(self.filename, mode='a')

    def download_articles(self, filepath="."):
        """Fetch every collected id and append the author rows to the CSV.

        Ids are consumed from ``list_of_ids`` in chunks of ``retmax_fetch``
        and ``search_term_count`` is decremented by the number of ids that
        were actually requested, so both reflect the work still left.

        Args:
            filepath: Directory the CSV file is written to.
        """
        total_time = time.time()

        Entrez.email = self.email
        self.filename = f"{filepath}/{self.no_space_search_term}.csv"

        for batch_index in range(len(self.list_of_ids)):
            while self.list_of_ids[batch_index]:
                chunk = self.list_of_ids[batch_index][:self.retmax_fetch]

                handle = Entrez.efetch(
                    db='pubmed',
                    id=chunk,
                    retmode='xml',
                    retmax=self.retmax_fetch,
                )
                try:
                    results = Entrez.read(handle)
                finally:
                    handle.close()

                # Round-trip through JSON to turn the parser's element
                # objects into plain dicts/lists before handing them to
                # dicttoxml.
                python_dictionary = json.loads(json.dumps(results))

                for document in python_dictionary['PubmedArticle']:
                    self._append_document_rows(document)

                self.list_of_ids[batch_index] = self.list_of_ids[batch_index][len(chunk):]
                self.search_term_count -= len(chunk)

                print(f'{self.search_term_count} articles to go')

        end_total_time = time.time()
        finish = round(end_total_time - total_time)

        print('script complete')
        print(f'script took {finish} seconds to complete')

In [None]:
class CleanData:
    """Turn a scraped author/affiliation CSV into a contact-style table.

    Creating an instance runs the whole pipeline: load the CSV, extract email
    addresses and a country from the affiliation text, split the affiliation
    on commas to find a company and a department, and finally copy the result
    into ``final_df`` with the target column layout.

    The class references ``pd``, ``np`` and ``re`` but does not import them;
    they must be available at module level.
    """

    # Email pattern used both to extract addresses and to strip them again.
    EMAIL_REGEX = r'[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+'

    # Names searched for in the affiliation text.
    # NOTE(review): 'PR' is kept as its own entry (it was fused with 'Qatar'
    # by a missing comma); it presumably targets "PR China" -- confirm.
    COUNTRIES = (
        'Afghanistan',
        'Albania',
        'Algeria',
        'Andorra',
        'Angola',
        'Antigua & Deps',
        'Argentina',
        'Armenia',
        'Australia',
        'Austria',
        'Azerbaijan',
        'Bahamas',
        'Bahrain',
        'Bangladesh',
        'Barbados',
        'Belarus',
        'Belgium',
        'Belize',
        'Benin',
        'Bhutan',
        'Bolivia',
        'Bosnia & Herzegovina',
        'Botswana',
        'Brazil',
        'Brunei',
        'Bulgaria',
        'Burkina',
        'Burundi',
        'Cambodia',
        'Cameroon',
        'Canada',
        'Cape Verde',
        'Central African Rep',
        'Chad',
        'Chile',
        'China',
        'Colombia',
        'Comoros',
        'Congo',
        'Congo {Democratic Rep}',
        'Costa Rica',
        'Croatia',
        'Cuba',
        'Cyprus',
        'Czech Republic',
        'Denmark',
        'Djibouti',
        'Dominica',
        'Dominican Republic',
        'East Timor',
        'Ecuador',
        'Egypt',
        'El Salvador',
        'Equatorial Guinea',
        'Eritrea',
        'Estonia',
        'Ethiopia',
        'Fiji',
        'Finland',
        'France',
        'Gabon',
        'Gambia',
        'Georgia',
        'Germany',
        'Ghana',
        'Greece',
        'Grenada',
        'Guatemala',
        'Guinea',
        'Guinea-Bissau',
        'Guyana',
        'Haiti',
        'Honduras',
        'Hungary',
        'Iceland',
        'India',
        'Indonesia',
        'Iran',
        'Iraq',
        'Ireland',
        'Israel',
        'Italy',
        'Ivory Coast',
        'Jamaica',
        'Japan',
        'Jordan',
        'Kazakhstan',
        'Kenya',
        'Kiribati',
        'Korea North',
        'Korea South',
        'Kosovo',
        'Kuwait',
        'Kyrgyzstan',
        'Laos',
        'Latvia',
        'Lebanon',
        'Lesotho',
        'Liberia',
        'Libya',
        'Liechtenstein',
        'Lithuania',
        'Luxembourg',
        'Macedonia',
        'Madagascar',
        'Malawi',
        'Malaysia',
        'Maldives',
        'Mali',
        'Malta',
        'Marshall Islands',
        'Mauritania',
        'Mauritius',
        'Mexico',
        'Micronesia',
        'Moldova',
        'Monaco',
        'Mongolia',
        'Montenegro',
        'Morocco',
        'Mozambique',
        'Myanmar, {Burma}',
        'Namibia',
        'Nauru',
        'Nepal',
        'Netherlands',
        'New Zealand',
        'Nicaragua',
        'Niger',
        'Nigeria',
        'Norway',
        'Oman',
        'Pakistan',
        'Palau',
        'Panama',
        'Papua New Guinea',
        'Paraguay',
        'Peru',
        'Philippines',
        'Poland',
        'Portugal',
        'PR',
        'Qatar',
        'Romania',
        'Russian Federation',
        'Rwanda',
        'St Kitts & Nevis',
        'St Lucia',
        'Saint Vincent & the Grenadines',
        'Samoa',
        'San Marino',
        'Sao Tome & Principe',
        'Saudi Arabia',
        'Senegal',
        'Serbia',
        'Seychelles',
        'Sierra Leone',
        'Singapore',
        'Slovakia',
        'Slovenia',
        'Solomon Islands',
        'Somalia',
        'South Africa',
        'South Sudan',
        'Spain',
        'Sri Lanka',
        'Sudan',
        'Suriname',
        'Swaziland',
        'Sweden',
        'Switzerland',
        'Syria',
        'Taiwan',
        'Tajikistan',
        'Tanzania',
        'Thailand',
        'Togo',
        'Tonga',
        'Trinidad & Tobago',
        'Tunisia',
        'Turkey',
        'Turkmenistan',
        'Tuvalu',
        'Uganda',
        'Ukraine',
        'United Arab Emirates',
        'UK',
        'United Kingdom',
        'United States',
        'USA',
        'Uruguay',
        'Uzbekistan',
        'Vanuatu',
        'Vatican City',
        'Venezuela',
        'Vietnam',
        'Yemen',
        'Zambia',
        'Zimbabwe',
    )

    def __init__(self, filepath=None):
        """Run the complete cleaning pipeline.

        Args:
            filepath: Path of the CSV file to clean. When omitted the path is
                requested interactively.
        """
        self.initialize_variables()
        self.filepath = filepath
        self.csv_to_dataframe()
        self.drop_unnecessary_column_row()
        self.clean_affiliation_column()
        self.make_email_column()
        self.drop_email_from_affiliation()
        self.use_pattern_searcher()
        self.combine_affiliation_per_author()
        self.separate_affiliation_entries()
        self.comp_dept()
        self.replace_text()
        self.remove_clutter()
        self.drop_email_duplicates()
        self.make_final_dataframe()
        self.invert_index()

    def initialize_variables(self):
        """Reset instance state and build the country search pattern."""
        self.df = pd.DataFrame()
        self.filepath = None
        self.series = None
        self.info_df = None
        self.final_df = None
        self.return_str = None
        self.countries = list(self.COUNTRIES)
        # Longest names first so "Nigeria" is not cut short to "Niger"; the
        # lookarounds stop a name matching inside a longer word ("India" in
        # "Indiana"); re.escape keeps "{", "&" etc. literal.
        longest_first = sorted(self.countries, key=len, reverse=True)
        alternatives = '|'.join(re.escape(name) for name in longest_first)
        self.pattern = r'(?<!\w)(?:' + alternatives + r')(?!\w)'

    def csv_to_dataframe(self):
        """Load the CSV into ``df``, prompting for the path if none was set."""
        if self.filepath is None:
            self.filepath = input("Where is your file?: ")
        self.df = pd.read_csv(self.filepath)

    def drop_unnecessary_column_row(self):
        """Drop the saved index column and any repeated header rows."""
        # errors='ignore' tolerates a CSV that was written without an index.
        self.df = self.df.drop(columns='Unnamed: 0', errors='ignore')
        # Rows whose FirstName is literally 'FirstName' are header lines that
        # ended up in the data.
        self.df = self.df[~self.df['FirstName'].isin(['FirstName'])]

    def clean_affiliation_column(self):
        """Drop rows that have no affiliation text."""
        self.df = self.df.dropna(subset=['Affiliation'])

    def make_email_column(self):
        """Extract email addresses from the affiliation into an Email column.

        Rows without any address are removed and rows with several addresses
        are expanded to one row per address.
        """
        # make new column of emails
        self.df['Email'] = self.df['Affiliation'].str.findall(self.EMAIL_REGEX)
        # remove all rows where email is an empty list
        self.df = self.df[self.df['Email'].map(len) > 0]
        # for any row that has multiple emails, make a new row for each email
        self.df = self.df.explode('Email')
        # remove any periods that are at the end of emails
        self.df['Email'] = self.df['Email'].str.rstrip('.')

    def drop_email_from_affiliation(self):
        """Remove email addresses from the Affiliation text."""
        # regex=True is spelled out because the default changed in pandas 2.0.
        self.df['Affiliation'] = self.df['Affiliation'].str.replace(
            self.EMAIL_REGEX, '', regex=True
        )

    def pattern_searcher(self, search_str: str, search_list: str):
        """Return the first match of ``search_list`` in ``search_str``.

        Args:
            search_str: Text to search.
            search_list: Regular expression (typically names joined by ``|``).

        Returns:
            The matched text, or the string ``'NaN'`` when nothing matches.
            The result is also stored in ``return_str``.
        """
        search_obj = re.search(search_list, search_str)
        if search_obj:
            self.return_str = search_str[search_obj.start(): search_obj.end()]
        else:
            self.return_str = 'NaN'

        return self.return_str

    def use_pattern_searcher(self):
        """Fill a Country column from the first country name in Affiliation."""
        self.df['Country'] = self.df['Affiliation'].apply(
            lambda text: self.pattern_searcher(search_str=text, search_list=self.pattern)
        )
        # np.nan rather than np.NaN: the alias was removed in NumPy 2.0.
        self.df['Country'] = self.df['Country'].replace('NaN', np.nan)

    def combine_affiliation_per_author(self):
        """Join the affiliations of rows sharing every other column.

        NOTE(review): groupby drops rows with a missing value in any key
        column by default, so authors without a detected country disappear
        here -- confirm that this is intended.
        """
        keys = ['PMID', 'FirstName', 'LastName', 'Title', 'Publication', 'Email', 'Country']
        self.df = self.df.groupby(keys)['Affiliation'].apply(', '.join).reset_index()

    def separate_affiliation_entries(self):
        """Split Affiliation on commas into integer-named columns 0..k-1."""
        # split string in 'Affiliation' column by each comma into new columns
        pieces = self.df['Affiliation'].str.split(',', expand=True)
        self.df = pd.concat([self.df, pieces], axis=1)

    def _split_columns(self):
        """Return the integer-named columns created by the affiliation split."""
        return [
            column
            for column in self.df.columns
            if isinstance(column, (int, np.integer)) and not isinstance(column, bool)
        ]

    def combine_by_attribute(self, dataframe, series, subject):
        """Extend ``series`` with matching cells from split columns 1 onward.

        Args:
            dataframe: Accepted for backward compatibility; the method reads
                ``self.df``.
            series: Matches already collected (from split column 0).
            subject: Patterns that are joined with ``|`` and matched
                case-insensitively.

        Returns:
            A Series holding ``series`` followed by the matches of every
            later split column, in column order.
        """
        pattern = '|'.join(subject)
        collected = [series]
        # Walk the real split columns instead of computing their count from a
        # fixed offset, which skipped the last two columns.
        for column in self._split_columns():
            if column == 0:
                continue
            cells = self.df[column]
            collected.append(cells.loc[cells.str.contains(pattern, case=False, na=False)])

        non_empty = [piece for piece in collected if not piece.empty]
        if not non_empty:
            return series
        return pd.concat(non_empty)

    def create_info_column(self, info_words, column_name):
        """Build a one-column frame of affiliation parts matching ``info_words``.

        Args:
            info_words: Patterns joined with ``|`` for a case-insensitive search.
            column_name: Name of the resulting column.

        Returns:
            The frame, which is also stored in ``info_df``. Its index refers
            to rows of ``df`` and may contain duplicates when several parts
            of one affiliation match.
        """
        pattern = '|'.join(info_words)
        if 0 in self.df.columns:
            first_column = self.df[0]
            info_series = first_column.loc[
                first_column.str.contains(pattern, case=False, na=False)
            ]
        else:
            # No split columns at all (e.g. an empty table).
            info_series = pd.Series(dtype=object)

        combined = self.combine_by_attribute(self.df, info_series, info_words)
        self.info_df = combined.to_frame(name=f'{column_name}')

        return self.info_df

    def comp_dept(self):
        """Add Company and Department columns found in the split affiliation."""
        companies = ['University|College|Institute|School|Academy|Hospital|Clinic|Medicine']
        company_dataframe = self.create_info_column(companies, 'Company')

        departments = ['Department', 'Dept.', 'Division']
        department_dataframe = self.create_info_column(departments, 'Department')

        self.df = self.df.join(company_dataframe, how='outer', rsuffix='University')
        self.df = self.df.join(department_dataframe, how='outer', rsuffix='Department')

    def replace_text(self):
        """Strip boilerplate and country names from Company/Department."""
        # replace 'Electronic address:' in all text with nothing
        self.df['Department'] = self.df['Department'].str.replace(
            'Electronic address:', '', regex=False
        )
        self.df['Company'] = self.df['Company'].str.replace(
            'Electronic address:', '', regex=False
        )
        # replace any countries in all text with nothing
        self.df['Department'] = self.df['Department'].str.replace(
            self.pattern, '', regex=True
        )
        # replace text 'PR ;' in all text with nothing
        self.df['Department'] = self.df['Department'].str.replace('PR ;', '', regex=False)

    def remove_clutter(self):
        """Drop helper columns, incomplete rows and duplicated index entries."""
        # drop the integer-named columns created by the affiliation split
        self.df = self.df.drop(columns=self._split_columns())
        # drop rows where columns 'FirstName'/'LastName' is empty
        self.df = self.df.dropna(subset=['FirstName', 'LastName'])
        # remove duplicate indexes
        self.df = self.df[~self.df.index.duplicated(keep='first')]
        # replace empty lists with NaN
        self.df = self.df.mask(self.df.astype(str).eq('[]'))

    def drop_email_duplicates(self):
        """Keep only the first row for each email address."""
        self.df = self.df.drop_duplicates(subset='Email', keep="first")

    def make_final_dataframe(self):
        """Copy the cleaned columns into ``final_df`` using the target layout."""
        self.final_df = pd.DataFrame(columns=['Salutation', 'First Name', 'Last Name', 'Phone', 'Email', 'Position', 'Company', 'Department', 'Domain', 'Comment', 'Tags', 'Source', 'Country', 'State', 'City', 'Address', 'Zipcode'])
        self.final_df['First Name'] = self.df['FirstName']
        self.final_df['Last Name'] = self.df['LastName']
        self.final_df['Field1'] = self.df['Title']
        self.final_df['Field2'] = self.df['Publication']
        self.final_df['Field3'] = self.df['PMID']
        self.final_df['Email'] = self.df['Email']
        self.final_df['Country'] = self.df['Country']
        self.final_df['Department'] = self.df['Department']
        self.final_df['Company'] = self.df['Company']
        self.final_df['Comment'] = self.df['Affiliation']

    def invert_index(self):
        """Reverse the row order of ``final_df``."""
        # reindex returns a new frame; it has to be assigned to take effect.
        self.final_df = self.final_df.reindex(index=self.final_df.index[::-1])