In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR

from BayesMultiScaleCNN import BayesMultiScaleCNN
from BBBC021 import BBBC021
from utils import confidence_score

In [None]:
def validate(data_loader, model):
    """Compute the classification accuracy of ``model`` over ``data_loader``.

    The model is switched to evaluation mode and all forward passes run with
    gradient tracking disabled.

    Args:
        data_loader: Iterable yielding ``(input, target, metadata)`` batches; it
            must support ``len()`` and expose a ``dataset`` attribute whose
            length is the total number of samples.
        model: ``torch.nn.Module`` whose output holds one score per class along
            dimension 1.

    Returns:
        float: Percentage (0-100) of samples whose arg-max prediction equals
        the target.
    """
    model.eval()

    # Run on whatever device the model lives on instead of relying on a
    # notebook-level ``device`` variable defined in a later cell.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    # The context manager closes the progress bar even if a batch raises.
    with torch.no_grad(), tqdm(total=len(data_loader), desc="Validation", leave=False) as pbar:
        for inputs, target, _ in data_loader:
            inputs = inputs.to(device)
            target = target.to(device)

            output = model(inputs)
            _, labels = output.max(1)

            # .item() keeps the counter a plain Python int, so the function
            # returns a float (like train()) rather than a device tensor.
            correct += labels.eq(target).sum().item()
            pbar.update(1)

    return correct * 100 / len(data_loader.dataset)

def train(data_loader, model, criterion, optimizer):
    """Train ``model`` for one pass over ``data_loader``.

    For every mini-batch the objective ``nll + beta * kl`` is minimised, where
    ``nll`` is ``criterion(output, target)``, ``kl`` is read from ``model.kl``
    after the forward pass, and ``beta = 1 / len(data_loader)``.

    Args:
        data_loader: Iterable yielding ``(input, target, metadata)`` batches; it
            must support ``len()`` and expose a ``dataset`` attribute whose
            length is the total number of samples.
        model: ``torch.nn.Module`` exposing a scalar tensor attribute ``kl``
            after each forward pass (presumably the KL term of a variational
            model -- confirm against the model implementation).
        criterion: Callable ``(output, target) -> scalar loss tensor``.
        optimizer: Optimizer updating ``model``'s parameters.

    Returns:
        tuple: ``(nlls, kls, accuracy)`` -- the per-batch ``nll`` and ``kl``
        values as lists of floats, and the training accuracy in percent.
    """
    model.train()

    # Run on whatever device the model lives on instead of relying on a
    # notebook-level ``device`` variable defined in a later cell.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    beta = torch.tensor(1.0 / len(data_loader), device=device)
    correct = 0
    nlls = []
    kls = []
    # The context manager closes the progress bar even if a batch raises.
    with tqdm(total=len(data_loader), desc="Loss: 0", leave=False) as pbar:
        # Iterate the loader that was passed in; the previous code silently used
        # the global ``train_loader`` and ignored this argument.
        for inputs, target, _ in data_loader:
            inputs = inputs.to(device)
            target = target.to(device)

            output = model(inputs)
            nll = criterion(output, target)
            kl = model.kl
            elbo = nll + beta * kl

            _, labels = output.max(1)
            correct += labels.eq(target).sum().item()
            nlls.append(nll.item())
            kls.append(kl.item())

            optimizer.zero_grad()
            elbo.backward()
            optimizer.step()

            pbar.set_description(f"Loss: {elbo.item():.03f}")
            pbar.update(1)

    return nlls, kls, correct * 100 / len(data_loader.dataset)

In [None]:
# Leave-one-compound-out experiment: for every compound, train a fresh model on
# all other compounds, keep the checkpoint with the best accuracy on the
# held-out compound, then append per-image predictions for that compound to a
# shared CSV file.

output_dir = "LeaveOneCompoundOut"
results_csv = f"{output_dir}/BBBC021_LeaveOneCompoundOut.csv"
os.makedirs(output_dir, exist_ok=True)

n_epochs = 80
lr = 1e-2
batch_size = 32
# NOTE(review): `milestones` (and the imported MultiStepLR) are never used
# below, so the learning rate stays constant for all epochs -- confirm whether
# a scheduler was intended before relying on these results.
milestones = [60]
# Number of forward passes averaged per batch when predicting.
n_samples = 100
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

moa = BBBC021().dataset.MOA.copy()
moa.remove('null')
compounds = list(np.unique(BBBC021(moa=moa).dataset.compounds))

# One softmax column per class (previously hard-coded to 12, which only lines
# up with the written data when len(moa) == 12), followed by prediction and
# metadata columns.
columns = [f"softmax_{i + 1}" for i in range(len(moa))] + [
    "moa_pred", "confidence", "site",
    "well", "replicate", "plate",
    "compound", "concentration", "moa"
]
# Write the header exactly once and without an index column, so that it lines
# up with the rows appended below (which are also written without an index).
pd.DataFrame(columns=columns).to_csv(results_csv, index=False)

for compound in compounds:

    # Compound names may contain "/" (e.g. 'mevinolin/lovastatin'), which would
    # be read as a directory separator in the checkpoint path.
    save_path = f"{output_dir}/{compound.replace('/', '-')}.pt"

    # Train on every compound except the held-out one.
    train_compounds = [c for c in compounds if c != compound]
    train_dataset = BBBC021(
        moa=moa,
        compound=train_compounds
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True
    )
    val_dataset = BBBC021(
        moa=moa,
        compound=compound
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True
    )

    model = BayesMultiScaleCNN(
        n_outputs=len(moa)
    )
    model.to(device)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=lr
    )
    # Per-class loss weights: dataset size divided by the class count.
    # NOTE(review): assumes the sorted order returned by np.unique matches the
    # class indices produced by the dataset -- confirm.
    _, counts = np.unique(train_dataset.dataset.moa, return_counts=True)
    weights = torch.tensor(
        len(train_dataset) / counts,
        dtype=torch.float32,
        device=device
    )
    criterion = torch.nn.CrossEntropyLoss(
        reduction="sum",
        weight=weights
    )

    nlls = []
    kls = []
    train_accs = []
    val_accs = []
    best_acc = 0
    with tqdm(total=n_epochs, desc="Loss: 0 | Accuracy: 0 % [0 %]", leave=False) as pbar:
        for epoch in range(n_epochs):

            epoch_nll, epoch_kl, train_acc = train(
                train_loader, model, criterion, optimizer
            )
            # float() accepts a Python number as well as a 0-dim tensor, so the
            # history and the checkpoint only ever hold plain floats.
            val_acc = float(validate(val_loader, model))

            # ">=" means the first epoch always writes a checkpoint, so
            # save_path is guaranteed to exist after the loop.
            if val_acc >= best_acc:
                best_acc = val_acc
                torch.save({
                    'state_dict': model.state_dict(),
                    'val_acc': val_acc,
                    'train_acc': train_acc,
                    'opt_state_dict' : optimizer.state_dict(),
                    },
                    save_path
                )

            nlls.append(epoch_nll)
            kls.append(epoch_kl)
            train_accs.append(train_acc)
            val_accs.append(val_acc)

            desc = f"Loss: {np.mean(epoch_nll)+np.mean(epoch_kl):.03f} | "
            desc += f"Accuracy: {train_acc:.02f} % [{val_acc:.02f} %]"
            pbar.set_description(desc)
            pbar.update(1)

    # Re-open the best checkpoint and replace its scalar accuracies with the
    # full training history before saving it again.
    checkpoint = torch.load(save_path, map_location=device)
    checkpoint["nll"] = nlls
    checkpoint["kl"] = kls
    checkpoint["train_acc"] = train_accs
    checkpoint["val_acc"] = val_accs
    torch.save(
        checkpoint,
        save_path
    )

    model.load_state_dict(checkpoint['state_dict'])

    # Predict the held-out compound by averaging n_samples forward passes.
    # NOTE(review): the model is in eval mode here (set by the last validate()
    # call); the passes only differ if the model still samples its weights in
    # eval mode -- confirm against BayesMultiScaleCNN.
    with torch.no_grad():
        for images, _, metadata in val_loader:
            images = images.to(device)
            output_samples = np.empty((n_samples, len(images), len(moa)))
            for j in range(n_samples):
                # torch.softmax replaces `F.softmax`: `F` was never imported,
                # so the original line raised NameError.
                output_samples[j, ...] = torch.softmax(model(images), dim=1).cpu().numpy()

            outputs = output_samples.mean(axis=0)
            confidence = confidence_score(output_samples)
            labels = np.argmax(outputs, axis=1)

            df = pd.DataFrame(columns=columns, index=range(len(images)))
            df.iloc[:, :len(moa)] = outputs
            batch_fields = {
                "moa_pred": labels,
                "confidence": confidence,
                "site": metadata[0][0].numpy(),
                "well": metadata[0][1],
                "replicate": metadata[0][2].numpy(),
                "plate": metadata[0][3],
                "compound": metadata[1][0],
                "concentration": metadata[1][1].numpy(),
                "moa": metadata[1][2],
            }
            for column, values in batch_fields.items():
                df.loc[:, column] = values

            # header=False: the header row was written once before the loop;
            # repeating it per batch would interleave header lines with data.
            df.to_csv(
                results_csv,
                mode='a',
                header=False,
                index=False
            )