In [1]:
import os
import sys

proj_root = os.getcwd() + "/.."
os.chdir(proj_root)
sys.path.append(proj_root)

In [2]:
import pandas as pd
import torch

In [3]:
from src.models import train, config, generate_predictions

config = config.Config()
metadata = pd.read_csv(config.dataset_metadata)
predictions = pd.read_csv(config.model_predictions)

g = torch.Generator(device=train.device).manual_seed(0)

In [4]:
def per_well_accuracy(sampler, average=False):
    sampled = predictions[metadata.index.isin(sampler)].copy()
    compounds = ['Berberine Chloride','Brefeldin A','DFSO','Fluphenazine','Latrunculin B','Nocodazole','Rapamycin','Rotenone','Tetrandrine']
    sampled['preds'] = sampled[compounds].to_dict(orient='records')
    sampled = sampled.drop(columns=compounds)
    grouped = sampled.groupby(metadata['well_id'])
    if average:
        pred = grouped.preds.agg(lambda x: pd.DataFrame(list(x)).mean().idxmax())
    else: #majority
        pred = grouped.predicted_compound.agg(lambda x: x.value_counts().index[0])
    true = grouped.true_compound.agg(lambda x: list(x)[0])
    correct = sum(pred==true)
    total = len(true)
    return correct/total

In [5]:
for av in False, True:
    for name, sampler in zip(["Training", "Validation", "Test"], train.get_stratified_sampler(config, g)):
        print(f"{name} dataset (per-well, {'average' if av else 'majority'}):")
        print(f"\t{per_well_accuracy(sampler, av):.3%}")
    print()

Training dataset (per-well, majority):
	98.810%
Validation dataset (per-well, majority):
	95.833%
Test dataset (per-well, majority):
	98.958%

Training dataset (per-well, average):
	99.405%
Validation dataset (per-well, average):
	95.833%
Test dataset (per-well, average):
	97.917%

