In [None]:
import pickle
import pandas as pd
import re
from tqdm import tqdm
import numpy as np
import rdkit
from rdkit import Chem
from rdkit.Chem import Draw
from pprint import pprint as pp
import yaml
import plotly.graph_objects as go
import os
import pickle

# Root of the pipeline working tree. Set the PIPELINE_BASE_PATH environment variable to run
# the notebook on another machine; without it the original location is used.
BASE_PATH = os.environ.get('PIPELINE_BASE_PATH', '/Users/morgunov/batista/Summer/pipeline/')
if not BASE_PATH.endswith('/'):
    # Every path below is built by plain string concatenation, so a trailing slash is required.
    BASE_PATH += '/'

# SMILES tokenizer pattern. Written as a raw string so that regex escapes such as `\[` are not
# parsed as (invalid) Python string escapes; the compiled pattern is unchanged.
# Note: '@' is listed before '@@' in the alternation, so '@@' is emitted as two '@' tokens.
REGEX_PATTERN = r"(\[[^\]]+]|<|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|@@|\?|>|!|\*|\$|\%[0-9]{2}|[0-9])"
regex = re.compile(REGEX_PATTERN)

# One directory per pipeline stage, all rooted at BASE_PATH.
PRETRAINING_PATH = BASE_PATH + '1. Pretraining/'
GENERATION_PATH = BASE_PATH + '2. Generation/'
SAMPLING_PATH = BASE_PATH + '3. Sampling/'
DIFFDOCK_PATH = BASE_PATH + '4. DiffDock/'
SCORING_PATH = BASE_PATH + '5. Scoring/'
AL_PATH = BASE_PATH + '6. ActiveLearning/'

# Part 1. Convert all files to .csv.gz format

In [None]:
def remove_invalid(smiles):
    """Return the set of entries in `smiles` that `Chem.MolFromSmiles` can parse.

    Args:
        smiles: sized iterable of SMILES strings (a list or a set both work).

    Returns:
        set: every input string for which `Chem.MolFromSmiles` returned something
        other than None. Duplicates in the input collapse to a single entry.
    """
    progress = tqdm(smiles, total=len(smiles))
    return {candidate for candidate in progress if Chem.MolFromSmiles(candidate) is not None}

def convert_bindingdb(pretrain_path, file_name):
    """Extract the SMILES column of a BindingDB TSV dump and save the valid, unique entries.

    Reads `{pretrain_path}datasets/original_files/{file_name}`, skips the header row and
    takes the second tab-separated field of every remaining row as a SMILES string. The
    unique strings accepted by `remove_invalid` are written to
    `datasets/converted/bindingdb.csv.gz`; the counts (`entries`, `unique`, `valid`) are
    written to `dataset_metrics/bindingdb_metrics.yaml`.

    Args:
        pretrain_path: pretraining root directory, ending with a path separator.
        file_name: name of the TSV file inside `datasets/original_files/`.
    """
    smiles_list, metrics = [], {}
    print('Starting to fill smiles_list from BindingDB')
    # Stream the file instead of readlines(): the raw lines are never needed again, so there
    # is no reason to hold the whole dump in memory next to the extracted column.
    with open(f"{pretrain_path}datasets/original_files/{file_name}", 'r') as f:
        next(f, None)  # header row
        for rline in tqdm(f):
            fields = rline.rstrip().split('\t')
            if len(fields) < 2:
                # Blank or truncated row: there is no SMILES field to read.
                continue
            smiles_list.append(fields[1])
    smiles_set = set(smiles_list)
    metrics.update(dict(entries=len(smiles_list), unique=len(smiles_set)))
    print('Starting to remove invalid smiles from BindingDB')
    valid_set = remove_invalid(smiles_set)
    pd.DataFrame({"smiles": list(valid_set)}).to_csv(f"{pretrain_path}datasets/converted/bindingdb.csv.gz", compression='gzip')
    metrics.update(dict(valid=len(valid_set)))
    # Dump metrics as a yaml file
    with open(f"{pretrain_path}dataset_metrics/bindingdb_metrics.yaml", 'w') as f:
        yaml.dump(metrics, f)

def convert_smiles_to_df(pretrain_path, file_name):
    """Convert a one-SMILES-per-line `.smiles` file into a gzipped CSV.

    Args:
        pretrain_path: pretraining root directory, ending with a path separator.
        file_name: stem of the file; `{file_name}.smiles` is read from
            `datasets/original_files/` and `{file_name}.csv.gz` is written to
            `datasets/converted/`.

    Returns:
        list[str]: the lines of the input file with trailing whitespace removed, in file order.
    """
    source = f"{pretrain_path}datasets/original_files/{file_name}.smiles"
    target = f"{pretrain_path}datasets/converted/{file_name}.csv.gz"
    with open(source, 'r') as handle:
        raw_lines = handle.readlines()
    smiles = [raw.rstrip() for raw in tqdm(raw_lines, total=len(raw_lines))]
    pd.DataFrame({"smiles": smiles}).to_csv(target, compression='gzip')
    return smiles

def convert_guacamol(pretrain_path, file_name):
    """Convert the four GuacaMol partitions and record their sizes.

    Each of `{file_name}_test`, `_valid`, `_train` and `_all` is converted with
    `convert_smiles_to_df`. The `all` partition is then checked with `remove_invalid`, and
    the partition sizes are written to `dataset_metrics/guacamol_metrics.yaml`.

    Args:
        pretrain_path: pretraining root directory, ending with a path separator.
        file_name: common prefix of the GuacaMol partition files.

    Raises:
        AssertionError: if any SMILES of the `all` partition is rejected by `remove_invalid`.
    """
    partitions = {}
    for part in ("test", "valid", "train", "all"):
        print(f'{part} partition')
        partitions[part] = convert_smiles_to_df(pretrain_path, f"{file_name}_{part}")
    metrics = {part: len(smiles) for part, smiles in partitions.items()}
    print('checking that there are no invalid smiles')
    # remove_invalid returns a set, so compare against the number of *unique* inputs;
    # comparing against the raw list length would report duplicates as invalid smiles.
    unique_all = set(partitions["all"])
    valid_set = remove_invalid(unique_all)
    assert len(valid_set) == len(unique_all), f"{len(valid_set)} valid smiles out of {len(unique_all)}"

    with open(f"{pretrain_path}dataset_metrics/guacamol_metrics.yaml", 'w') as f:
        yaml.dump(metrics, f)

def convert_chembl(pretrain_path, file_name):
    """Extract the SMILES column of a ChEMBL chemreps dump and save the valid, unique entries.

    Reads `{pretrain_path}datasets/original_files/{file_name}`, skips the header row and
    takes the second tab-separated field of every remaining row as a SMILES string. The
    unique strings accepted by `remove_invalid` are written to
    `datasets/converted/chembl.csv.gz`; the counts (`entries`, `unique`, `valid`) are
    written to `dataset_metrics/chembl_metrics.yaml`.

    Args:
        pretrain_path: pretraining root directory, ending with a path separator.
        file_name: name of the tab-separated file inside `datasets/original_files/`.
    """
    smiles_list, metrics = [], {}
    print('Starting to fill smiles_list from ChemBL')
    # Stream the file instead of readlines(): only the SMILES field of each row is kept.
    with open(f"{pretrain_path}datasets/original_files/{file_name}", "r") as f:
        next(f, None)  # header row
        for rline in f:
            fields = rline.rstrip().split('\t')
            if len(fields) < 2:
                # Blank or truncated row: there is no SMILES field to read.
                continue
            smiles_list.append(fields[1])
    smiles_set = set(smiles_list)
    metrics.update(dict(entries=len(smiles_list), unique=len(smiles_set)))
    print('Starting to remove invalid smiles from ChemBL')
    valid_set = remove_invalid(smiles_set)
    pd.DataFrame({"smiles": list(valid_set)}).to_csv(f"{pretrain_path}datasets/converted/chembl.csv.gz", compression='gzip')
    metrics.update(dict(valid=len(valid_set)))
    # Dump metrics as a yaml file
    with open(f"{pretrain_path}dataset_metrics/chembl_metrics.yaml", 'w') as f:
        yaml.dump(metrics, f)

def convert_moses(pretrain_path, file_name):
    """Copy the MOSES train and test CSVs into `datasets/converted/`, renaming `SMILES` to `smiles`.

    Args:
        pretrain_path: pretraining root directory, ending with a path separator.
        file_name: common prefix; `{file_name}_train.csv.gz` and `{file_name}_test.csv.gz`
            are read from `datasets/original_files/`.
    """
    for split in ("train", "test"):
        raw = pd.read_csv(f"{pretrain_path}datasets/original_files/{file_name}_{split}.csv.gz")
        renamed = raw.rename(columns={"SMILES": "smiles"})
        renamed.to_csv(f"{pretrain_path}datasets/converted/{file_name}_{split}.csv.gz")

In [None]:
def convert_all(pretrain_path):
    """Run every dataset converter in turn, announcing each one before it starts.

    Args:
        pretrain_path: pretraining root directory, ending with a path separator.
    """
    steps = (
        ("Converting BindingDB", convert_bindingdb, "BindingDB_All.tsv"),
        ("Converting GuacaMole", convert_guacamol, "guacamol_v1"),
        ("Converting ChemBL", convert_chembl, "chembl_33_chemreps.txt"),
        ("Converting MOSES", convert_moses, "moses"),
    )
    for banner, converter, source_name in steps:
        print(banner)
        converter(pretrain_path, source_name)

In [None]:
# Convert all raw datasets found under PRETRAINING_PATH (writes into datasets/converted/ and dataset_metrics/).
convert_all(PRETRAINING_PATH)

In [None]:
def concat_all(pretrain_path):
    """Merge all converted datasets into one de-duplicated frame and save it.

    Reads the converted ChEMBL, BindingDB, MOSES (train and test) and GuacaMol files from
    `{pretrain_path}datasets/converted/`, concatenates them, removes repeated SMILES and rows
    without a SMILES, and writes the result to `datasets/converted/combined.csv.gz`.

    Args:
        pretrain_path: pretraining root directory, ending with a path separator.

    Returns:
        pandas.DataFrame: the combined frame that was written to disk.
    """
    names = ("chembl", "bindingdb", "moses_train", "moses_test", "guacamol_v1_all")
    frames = [pd.read_csv(f"{pretrain_path}datasets/converted/{name}.csv.gz") for name in names]

    combined = pd.concat(frames)
    print(f"Combined df has {len(combined)} rows")
    combined = combined.drop_duplicates(subset='smiles')
    print(f"Combined df has {len(combined)} rows after dropping duplicates")
    # Only a missing SMILES makes a row unusable. A bare dropna() would also discard every row
    # that lacks a column present in just one of the sources (concat fills those with NaN).
    combined = combined.dropna(subset=['smiles'])
    print(f"Combined df has {len(combined)} rows after dropping NaNs")
    combined.to_csv(f"{pretrain_path}datasets/converted/combined.csv.gz")
    return combined

# Build and save the merged dataset. The three lines below are presumably the output
# recorded from an earlier run - confirm before relying on the numbers.
combined = concat_all(PRETRAINING_PATH)
# Combined df has 6906124 rows
# Combined df has 5622772 rows after dropping duplicates
# Combined df has 5622771 rows after dropping NaNs

In [None]:
import itertools

def calculate_cross_repetitions(pretrain_path):
    """Count the molecules in each converted dataset and the overlap between every pair.

    The metrics are written to `dataset_metrics/cross_repetitions.yaml` and returned.
    Rows without a SMILES (NaN) are not counted as molecules.

    Args:
        pretrain_path: pretraining root directory, ending with a path separator.

    Returns:
        dict: human-readable metric name -> count.
    """
    def _load(name):
        return pd.read_csv(f"{pretrain_path}datasets/converted/{name}.csv.gz")

    def _molecules(frame):
        # Drop missing values first so that NaN is not counted as one more "molecule".
        return set(frame['smiles'].dropna())

    chembl = _load("chembl")
    bindingdb = _load("bindingdb")
    moses_combined = pd.concat([_load("moses_train"), _load("moses_test")])
    guacamol = _load("guacamol_v1_all")

    datasets = [('moses', _molecules(moses_combined)), ('bindingDB', _molecules(bindingdb)), ('ChemBL', _molecules(chembl)), ('GuacaMol', _molecules(guacamol))]
    metrics = {}
    for name, ds in datasets:
        metrics[f"Number of molecules in {name}"] = len(ds)

    for (name1, ds1), (name2, ds2) in itertools.combinations(datasets, 2):
        metrics[f"Number of molecules in both {name1} and {name2}"] = len(ds1 & ds2)

    combined = pd.concat([chembl, bindingdb, moses_combined, guacamol])
    # dropna returns a new frame; the result has to be kept for the NaN rows to actually go away.
    combined = combined.drop_duplicates(subset=['smiles']).dropna(subset=['smiles'])
    metrics["Number of molecules in all datasets"] = len(combined)
    # dump metrics to a yaml file
    with open(f"{pretrain_path}dataset_metrics/cross_repetitions.yaml", 'w') as f:
        yaml.dump(metrics, f)
    return metrics
# Compute the overlap metrics; as the last expression of the cell, the returned dict is displayed.
calculate_cross_repetitions(PRETRAINING_PATH)

# Part 2. Preprocess the datasets

In [None]:
class Graph:
    """Holds the styling defaults used for the notebook's figures and applies them to a figure."""

    def __init__(self):
        # Font sizes
        self.title_size = 20
        self.axis_title_size = 14
        self.tick_font_size = 12
        # Colours and font
        self.text_color="#333333"
        self.background = "white"
        self.grid_color = "#e2e2e2"
        self.line_color = "#000000"
        self.font_family = 'Helvetica'
        # Grid visibility
        self.show_xgrid = False
        self.show_ygrid = False
        # Canvas size
        self.width = 600
        self.height = 400
        # Texts
        self.title = ''
        self.xaxis_title = ''
        self.yaxis_title = ''

    def update_parameters(self, params):
        """Set every key of `params` as an attribute of this object (no validation of names)."""
        for name in params:
            setattr(self, name, params[name])

    def _font(self, size):
        """Return a font spec of the given size in the configured colour and family."""
        return {'size': size, 'color': self.text_color, 'family': self.font_family}

    def _axis_style(self):
        """Return the keyword arguments shared by the x-axis and y-axis updates."""
        return dict(
            title_font=self._font(self.axis_title_size),
            tickfont=self._font(self.tick_font_size),
            gridwidth=1,
            gridcolor=self.grid_color,
            linecolor=self.line_color,  # draws the axis line
            linewidth=2,
        )

    def style_figure(self, figure):
        """Apply layout, font and axis styling to `figure` and return the same object.

        `figure` must provide `update_layout`, `update_xaxes` and `update_yaxes`.
        """
        layout = {
            'margin': {'t': 50, 'b': 50, 'l': 50, 'r': 50},
            'plot_bgcolor': self.background,
            'paper_bgcolor': self.background,
            'title': {'text': self.title, 'font': self._font(self.title_size)},
            'height': self.height,
            'width': self.width,
            'font': self._font(self.tick_font_size),
            'legend': {'font': self._font(self.tick_font_size)},
        }
        figure.update_layout(layout)

        # Axis titles, tick fonts, grid and axis lines.
        figure.update_xaxes(title=self.xaxis_title, showgrid=self.show_xgrid, **self._axis_style())
        figure.update_yaxes(title=self.yaxis_title, title_standoff=0, showgrid=self.show_ygrid, **self._axis_style())
        return figure

In [None]:
class Dataset:
    """Tokenise one converted SMILES dataset and collect vocabulary and length statistics.

    Typical use: construct, call `load_and_analyze()`, then the plotting / export helpers.

    Attributes filled by `load_and_analyze`:
        df: the frame read from `datasets/converted/{file_name}.csv.gz`.
        processed_smiles: set of SMILES that passed every enabled filter.
        vocab: set of tokens seen in the kept SMILES.
        token_to_freq: token -> number of occurrences in the kept SMILES.
        block_sizes / max_block_size: token count per kept SMILES and its maximum.
        no_rare: (only with `token_freq_cutoff`) SMILES without rare tokens.
    """

    def __init__(self, base_path, file_name, remove_isotopes=False, remove_non_bio_friendly=False, token_len_cutoff=None, token_freq_cutoff=None, save_files=False):
        """Store the configuration and load the exception lists from `{base_path}exceptions/`.

        Args:
            base_path: pretraining root directory, ending with a path separator.
            file_name: stem of the converted dataset (without `.csv.gz`).
            remove_isotopes: skip SMILES containing a token from `isotope_exceptions.yaml`.
            remove_non_bio_friendly: skip SMILES containing a token from
                `non_bio_friendly_exceptions.yaml`.
            token_len_cutoff: skip SMILES with more tokens than this (None disables).
            token_freq_cutoff: after the first pass, drop SMILES containing a token seen
                fewer times than this (None disables).
            save_files: write the filtered SMILES of the first pass to disk.
        """
        self.base_path = base_path
        self.file_name = file_name
        self.vocab = set()
        self.max_block_size = 0
        self.block_sizes = []
        self.token_to_freq = {}
        self.remove_isotopes = remove_isotopes
        self.remove_non_bio_friendly = remove_non_bio_friendly
        self.save_files = save_files
        self.plot_suffix = ''
        self.token_len_cutoff = token_len_cutoff
        self.token_freq_cutoff = token_freq_cutoff
        if self.remove_isotopes:
            with open(f"{self.base_path}exceptions/isotope_exceptions.yaml", 'r') as f:
                self.isotope_exceptions = set(yaml.safe_load(f))
            self.plot_suffix += '_no_isotopes'
        if self.remove_non_bio_friendly:
            with open(f"{self.base_path}exceptions/non_bio_friendly_exceptions.yaml", 'r') as f:
                self.non_bio_friendly_exceptions = set(yaml.safe_load(f))
            self.plot_suffix += '_no_nonbio_friendly'
        with open(f"{self.base_path}exceptions/guacamole_exceptions.yaml", 'r') as f:
            self.guacamole_exceptions = set(yaml.safe_load(f))

    def load_and_analyze(self):
        """Read the dataset, apply the enabled filters and gather token statistics.

        Statistics are reset at the start, so calling this method again gives the same
        result instead of adding to the previous counts.
        """
        self.df = pd.read_csv(f"{self.base_path}datasets/converted/{self.file_name}.csv.gz")

        # Start from a clean slate so a re-run of the cell does not double every count.
        self.vocab = set()
        self.max_block_size = 0
        self.block_sizes = []
        self.token_to_freq = {}

        processed_smiles = set()

        pbar_desc = 'Analyzing ' + self.file_name
        if self.remove_isotopes:
            pbar_desc += ' (no isotopes)'
            self.smiles_with_isotopes = set()
        if self.remove_non_bio_friendly:
            pbar_desc += ' (no non bio friendly)'
            self.smiles_with_non_bio_friendly = set()
        # Stays empty: the guacamole filter below is disabled.
        self.smiles_guac_filter = set()

        pbar = tqdm(self.df['smiles'].values, total=len(self.df['smiles'].values), desc=pbar_desc)
        for smile in pbar:
            if not isinstance(smile, str):
                # e.g. NaN produced by an empty CSV field
                print(f"Found {smile=} which has {type(smile)=}")
                continue
            tokens = regex.findall(smile.strip())

            # if any([token.replace('[', '').replace(']', '').replace('-', '').replace('+', '') in self.guacamole_exceptions for token in tokens]):
                # self.smiles_guac_filter.add(smile)
                # continue
            if self.remove_isotopes and any(token in self.isotope_exceptions for token in tokens):
                self.smiles_with_isotopes.add(smile)
                continue
            if self.remove_non_bio_friendly and any(token in self.non_bio_friendly_exceptions for token in tokens):
                self.smiles_with_non_bio_friendly.add(smile)
                continue
            if self.token_len_cutoff is not None and len(tokens) > self.token_len_cutoff:
                continue

            processed_smiles.add(smile)
            for token in tokens:
                self.vocab.add(token)
                self.token_to_freq[token] = self.token_to_freq.get(token, 0) + 1

            self.max_block_size = max(self.max_block_size, len(tokens))
            self.block_sizes.append(len(tokens))

        if self.save_files:
            # Sorted so that the written file is identical from run to run (set order is not).
            pd.DataFrame({"smiles": sorted(processed_smiles)}).to_csv(f"{self.base_path}datasets/converted/{self.file_name}_processed.csv.gz", compression='gzip')
        self.processed_smiles = processed_smiles
        if self.token_freq_cutoff is not None:
            self._remove_rare_tokens()

    def _remove_rare_tokens(self):
        """Drop SMILES containing a token rarer than `token_freq_cutoff` and save the rest.

        Rebuilds `vocab` and `max_block_size` from the remaining SMILES; `block_sizes` and
        `token_to_freq` keep the values of the first pass. The output file is written
        regardless of `save_files`.
        """
        new_smiles = set()
        self.max_block_size = 0
        self.vocab = set()

        pbar = tqdm(self.processed_smiles, total=len(self.processed_smiles), desc=f"Removing rare tokens")
        for smile in pbar:
            tokens = regex.findall(smile.strip())
            if any(self.token_to_freq[token] < self.token_freq_cutoff for token in tokens): continue
            new_smiles.add(smile)
            for token in tokens:
                self.vocab.add(token)
            self.max_block_size = max(self.max_block_size, len(tokens))
        self.no_rare = new_smiles
        # Sorted so that the written file is identical from run to run (set order is not).
        pd.DataFrame({"smiles": sorted(new_smiles)}).to_csv(f"{self.base_path}datasets/converted/{self.file_name}_processed_freq{self.token_freq_cutoff}_block{self.token_len_cutoff}.csv.gz", compression='gzip')

    def plot_distribution(self, block_sizes):
        """Plot a histogram of `block_sizes` with percentile markers and a statistics box.

        Writes HTML, JPG and SVG versions under `{base_path}plots/`. Requires
        `load_and_analyze` to have been called (uses `df`, `vocab` and the filter sets).
        """
        graph = Graph()
        fig = go.Figure()
        fig.add_trace(go.Histogram(x=block_sizes, name=self.file_name, histnorm='probability density'))
        fig.update_layout(barmode='overlay', xaxis=dict(dtick=20))
        fig.update_traces(opacity=0.75)
        percentiles = [25, 50, 75, 95, 99, 99.9, 99.99]
        for i, percentile in enumerate(percentiles):
            val = np.percentile(block_sizes, percentile)
            # Stagger the labels vertically so neighbouring percentiles do not overlap.
            y_anno = -40 * (1 + i % 5)
            fig.add_annotation(x=val, y=0, text=f'{percentile} percentile<br>{val:.0f} blocks',
                showarrow=True, arrowhead=1, ax=-10, ay=y_anno)

        total = len(self.df["smiles"])

        def _fraction(count):
            # An empty input frame would otherwise raise ZeroDivisionError.
            return count / total if total else 0.0

        stats_text = f"Total smiles: {len(block_sizes)}<br>Vocab Size: {len(self.vocab)}<br>Max Block Size: {self.max_block_size}"
        guac_fraction = _fraction(len(self.smiles_guac_filter))
        stats_text += f"<br># excluded (guacamole): {len(self.smiles_guac_filter)} or {guac_fraction:.2%}"

        if self.remove_isotopes:
            isotope_fraction = _fraction(len(self.smiles_with_isotopes))
            stats_text += f"<br># excluded (isotope-containing): {len(self.smiles_with_isotopes)} or {isotope_fraction:.2%}"
        if self.remove_non_bio_friendly:
            non_bio_friendly_fraction = _fraction(len(self.smiles_with_non_bio_friendly))
            stats_text += f"<br># excluded (non bio friendly): {len(self.smiles_with_non_bio_friendly)} or {non_bio_friendly_fraction:.2%}"

        fig.add_annotation(x=0.75, y=1, text=stats_text,
                            showarrow=False, xref='paper', yref='paper', xanchor='left', yanchor='top', font=dict(size=12))
        graph.update_parameters({'title': f'Block size distribution for {self.file_name} partition. Isotopes removed: {self.remove_isotopes}',
                                'xaxis_title': 'Block size', 'yaxis_title': 'Probability density',
                                'width': 1280, 'height': 720, 'show_xgrid': True})
        graph.style_figure(fig)
        # fig.write_html(self.base_path + f'plots/block_size_distribution_{self.file_name}{self.plot_suffix}.html', include_plotlyjs='cdn')
        fig.write_html(self.base_path + f'plots/html/block_size_distribution/{self.file_name}.html', include_plotlyjs='cdn')
        fig.write_image(self.base_path + f'plots/jpg/block_size_distribution/{self.file_name}.jpg', scale=3.0)
        fig.write_image(self.base_path + f'plots/svg/block_size_distribution/{self.file_name}.svg')

    def plot_token_frequencies(self, top_n=20):
        """Plot a bar chart of the `top_n` most frequent tokens (all tokens when `top_n` is None).

        Writes HTML, JPG and SVG versions under `{base_path}plots/`.
        """
        ranked = sorted(self.token_to_freq.items(), key=lambda item: item[1], reverse=True)
        sorted_token_freq = dict(ranked[:top_n])
        # With top_n=None every token is shown; report that number instead of "Top None".
        shown = top_n if top_n is not None else len(sorted_token_freq)

        graph = Graph()
        fig = go.Figure(data=[
            go.Bar(
                x=list(sorted_token_freq.keys()),
                y=list(sorted_token_freq.values()),
                name=self.file_name)
            ])
        graph.update_parameters({'title': f'Top {shown} character frequencies for {self.file_name} partition. Isotopes removed: {self.remove_isotopes}',
                                'xaxis_title': 'Token', 'yaxis_title': 'Frequency',
                                'width': 1280, 'height': 720})

        graph.style_figure(fig)
        # fig.write_html(self.base_path + f'plots/token_frequencies_{self.file_name}{self.plot_suffix}.html', include_plotlyjs='cdn')
        fig.write_html(self.base_path + f'plots/html/token_frequencies/{self.file_name}.html', include_plotlyjs='cdn')
        fig.write_image(self.base_path + f'plots/jpg/token_frequencies/{self.file_name}.jpg', scale=3.0)
        fig.write_image(self.base_path + f'plots/svg/token_frequencies/{self.file_name}.svg')

    def export_vocab(self):
        """Write the vocabulary, one token per line, to `{base_path}vocab_{file_name}.txt`.

        Tokens are written in sorted order so the file is identical from run to run.
        """
        with open(self.base_path + f'vocab_{self.file_name}.txt', 'w') as f:
            for token in sorted(self.vocab):
                f.write(token + '\n')


In [None]:
# Dataset configurations to analyse. Each dict is passed as keyword arguments to Dataset();
# only uncommented entries are run, the others are kept as alternatives.
configs = [
    # dict(file_name='combined', remove_isotopes=True, remove_non_bio_friendly=True, token_len_cutoff=None, token_freq_cutoff=None, save_files=False),
    # dict(file_name='moses_train', remove_isotopes=False, remove_non_bio_friendly=False, token_len_cutoff=None, token_freq_cutoff=None, save_files=False),
    # dict(file_name='bindingdb', remove_isotopes=False, remove_non_bio_friendly=False, token_len_cutoff=None, token_freq_cutoff=None, save_files=False),
    # dict(file_name='chembl', remove_isotopes=False, remove_non_bio_friendly=False, token_len_cutoff=None, token_freq_cutoff=None, save_files=False),
    # dict(file_name='guacamol_v1_all', remove_isotopes=False, remove_non_bio_friendly=False, token_len_cutoff=None, token_freq_cutoff=None, save_files=False),
    # dict(file_name='combined', remove_isotopes=True, remove_non_bio_friendly=True, token_len_cutoff=133, token_freq_cutoff=100, save_files=False),
    # dict(file_name='combined', remove_isotopes=True, remove_non_bio_friendly=True, token_len_cutoff=133, token_freq_cutoff=1000, save_files=False),
    # dict(file_name='combined_processed_freq100_block133', remove_isotopes=True, remove_non_bio_friendly=True, token_len_cutoff=None, token_freq_cutoff=None, save_files=False),
    # dict(file_name='combined_processed_freq1000_block133', remove_isotopes=True, remove_non_bio_friendly=True, token_len_cutoff=None, token_freq_cutoff=None, save_files=False),
    # dict(file_name='combined', remove_isotopes=True, remove_non_bio_friendly=True, token_len_cutoff=None, token_freq_cutoff=None, save_files=False),
    dict(file_name='combined', remove_isotopes=False, remove_non_bio_friendly=False, token_len_cutoff=133, token_freq_cutoff=1000, save_files=False),
    # dict(file_name='combined', remove_isotopes=False, remove_non_bio_friendly=False, token_len_cutoff=None, token_freq_cutoff=1000, save_files=False),
    # dict(file_name='combined_processed_freq', remove_isotopes=False, remove_non_bio_friendly=False, token_len_cutoff=None, token_freq_cutoff=None, save_files=False),
]
# Analyse and plot every active configuration. `ds` is left pointing at the last Dataset
# and is read as a global by later cells.
for config in configs:
    print(config)
    ds = Dataset(PRETRAINING_PATH, **config)
    ds.load_and_analyze()
    ds.plot_distribution(ds.block_sizes)
    # top_n=None plots every token instead of only the most frequent ones
    ds.plot_token_frequencies(top_n=None)

In [None]:
# Spot-check the counts of the '[Na+]' and '[K+]' tokens (raises KeyError if a token never occurred).
ds.token_to_freq['[Na+]'], ds.token_to_freq['[K+]']

In [None]:
def sandbox(pretrain_path):
    """Print ad-hoc comparisons between the exception lists, saved descriptors and `ds.vocab`.

    Reads the global `ds` (the Dataset built by the driver loop). Prints, in order: the sizes
    of the two exception sets and of their intersection, the exception tokens present in
    `ds.vocab`, and the tokens of `ds.vocab` missing from each descriptor's `stoi` mapping.

    Args:
        pretrain_path: pretraining root directory, ending with a path separator.
    """
    def _read_yaml(relative_path):
        with open(f"{pretrain_path}{relative_path}", "r") as f:
            return yaml.load(f, Loader=yaml.FullLoader)

    nonbio = set(_read_yaml("exceptions/non_bio_friendly_exceptions.yaml"))
    isotope = set(_read_yaml("exceptions/isotope_exceptions.yaml"))
    final = _read_yaml("dataset_descriptors/combined_processed_freq1000_block133.yaml")
    final_nonan = _read_yaml("dataset_descriptors/combined_processed_freq1000_block133_nonan.yaml")

    print(f"{len(nonbio)=}, {len(isotope)=}, {len(nonbio & isotope)=}")
    all_exceptions = nonbio | isotope
    print(all_exceptions & ds.vocab)
    # Not used by anything below; kept as a scratch list.
    atoms = {"Ag","Al","Am","Ar","At","Au","D","E","Fe","G","K","L","M","Ra","Re","Rf","Rg","Rh","Ru","T","U","V","W","Xe","Y","Zr","a","d","f","g","h","k","m","si","t","te","u","v","y",}
    for descriptor in (final, final_nonan):
        known_tokens = set(descriptor['stoi'].keys())
        print(ds.vocab - known_tokens)

# Requires `ds` from the driver loop above to exist in the kernel.
sandbox(PRETRAINING_PATH)

In [None]:
# Sort the token_to_freq dictionary by frequency in reversed order.
# NOTE(review): this cell used `combined_db`, which is not defined anywhere in the notebook
# (NameError on Restart & Run All). `ds` - the Dataset built for the 'combined' config by the
# driver loop - is presumably what was meant; confirm.
sorted_token_freq = dict(sorted(ds.token_to_freq.items(), key=lambda item: item[1], reverse=True))
sorted_token_freq

In [None]:
def create_train_val_partition(pretrain_path, file_name, validation_fraction=0.05, seed=None):
    """Shuffle a converted dataset and split it into training and validation files.

    Reads `datasets/converted/{file_name}.csv.gz` and writes `{file_name}_train.csv.gz` and
    `{file_name}_val.csv.gz` into `datasets/splits/`.

    Args:
        pretrain_path: pretraining root directory, ending with a path separator.
        file_name: stem of the converted dataset (without `.csv.gz`).
        validation_fraction: share of the molecules placed in the validation file, in [0, 1].
        seed: optional seed for a reproducible split. With the default None the global
            NumPy random state is used, as before.

    Raises:
        ValueError: if `validation_fraction` is outside [0, 1].
    """
    if not 0 <= validation_fraction <= 1:
        raise ValueError(f"validation_fraction must be between 0 and 1, got {validation_fraction}")
    all_mols = pd.read_csv(f"{pretrain_path}datasets/converted/{file_name}.csv.gz", compression='gzip')['smiles'].to_numpy()
    # permutation() shuffles a copy, so the array handed out by pandas is never written to.
    rng = np.random if seed is None else np.random.default_rng(seed)
    shuffled = rng.permutation(all_mols)
    validation_index = int((1-validation_fraction)*len(shuffled))
    print(f"{validation_index=}")
    train_mols = shuffled[:validation_index]
    val_mols = shuffled[validation_index:]
    pd.DataFrame({'smiles': train_mols}).to_csv(f"{pretrain_path}datasets/splits/{file_name}_train.csv.gz", compression='gzip')
    pd.DataFrame({'smiles': val_mols}).to_csv(f"{pretrain_path}datasets/splits/{file_name}_val.csv.gz", compression='gzip')

In [None]:
# Split the selected processed dataset into train/validation files; the alternatives are kept commented out.
# create_train_val_partition(PRETRAINING_PATH, 'combined_processed_freq100_block133')
create_train_val_partition(PRETRAINING_PATH, 'combined_processed_freq1000_block133_nonan')
# create_train_val_partition(PRETRAINING_PATH, 'combined_processed_freq1000_block133')

In [None]:
# (A stray ")" was the only content of this cell; it was a SyntaxError and has been removed.)

In [None]:
def export_combined_metrics(pretrain_path, file_name):
    """Write the sizes of a train/validation split to `dataset_metrics/{file_name}_metrics.yaml`.

    Args:
        pretrain_path: pretraining root directory, ending with a path separator.
        file_name: stem of the split; `{file_name}_train.csv.gz` and `{file_name}_val.csv.gz`
            are read from `datasets/splits/`.
    """
    splits_dir = f"{pretrain_path}datasets/splits/"
    n_train = len(pd.read_csv(f"{splits_dir}{file_name}_train.csv.gz"))
    n_val = len(pd.read_csv(f"{splits_dir}{file_name}_val.csv.gz"))
    summary = {
        "total": n_train + n_val,
        "training partition": n_train,
        "validation partition": n_val,
    }
    with open(f"{pretrain_path}dataset_metrics/{file_name}_metrics.yaml", 'w') as f:
        yaml.dump(summary, f)
    

# Record split sizes for the dataset that was partitioned above; the alternative is kept commented out.
# export_combined_metrics(PRETRAINING_PATH, 'combined_processed_freq100_block133')
export_combined_metrics(PRETRAINING_PATH, 'combined_processed_freq1000_block133_nonan')