## Multi-label prediction with all drugs in the 12 MeSH classes

In [1]:
# Reload the autoreload extension and re-import edited modules before each cell runs.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
from fastai.vision import *
from fastai.datasets import Config
import pandas as pd
import numpy as np
from fastai.vision.data import Image
from functools import partial
import torch
from fastai.metrics import *
from fastai import utils
from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score, matthews_corrcoef, auc
from sklearn.metrics import balanced_accuracy_score, accuracy_score, average_precision_score
from sklearn.preprocessing import OneHotEncoder, LabelBinarizer, MultiLabelBinarizer, label_binarize

In [3]:
torch.cuda.set_device(1)   ### device index 1 is the SECOND GPU (indices are zero-based); use 0 for the first
torch.cuda.is_available()

True

## Getting the data

In [4]:
# Root directory for the CSV files and images. The trailing slash is required:
# later cells build file paths by plain string concatenation with PATH.
PATH = '/home/jgmeyer2/drugclass/multiclass_data/'

## Multi-label classification

In [5]:
# Load the compound table for inspection and preview the first rows.
df = pd.read_csv(PATH + 'all_chem_df.csv')
df.head()

Unnamed: 0,image_name,tags,smiles,Col3
0,pics/0,dermatologic,CC(=O)NC1C(O)OC(CO)C(O)C1O,['dermatologic']
1,pics/1,antiinfective,CCC[C@@]1(CCc2ccccc2)CC(O)=C([C@H](CC)c2cccc(N...,['antiinfective']
2,pics/2,antiinfective,CCCCC(C)C(=O)OC1C(C)C(CC)OC2(CC3CC(C/C=C(\C)CC...,['antiinfective']
3,pics/3,antineoplastic,COc1cc2c(c(OC)c1OC)-c1c(cc3c(c1OC)OCO3)C[C@H](...,['antineoplastic']
4,pics/4,antiinfective respiratorysystem,CC(=O)N[C@@H](CS)C(=O)[O-],"['antiinfective', 'respiratorysystem']"


In [6]:
# Print the unique labels.
# Each `tags` entry is a space-delimited list of labels (multi-label rows exist),
# so every token is collected rather than only the first tag of each row.
unique_labels = {tag for tags in df['tags'] for tag in tags.split(' ')}
print(len(unique_labels))
unique_labels


12


{'antiinfective',
 'antiinflammatory',
 'antineoplastic',
 'cardio',
 'cns',
 'dermatologic',
 'gastrointestinal',
 'hematologic',
 'lipidregulating',
 'reproductivecontrol',
 'respiratorysystem',
 'urological'}

To put this in a `DataBunch` using the [data block API](https://docs.fast.ai/data_block.html), we need to use `ImageItemList` (and not `ImageDataBunch`). This makes sure the model that is created has the proper loss function to deal with multiple labels per image.

In [7]:
# Augmentations: flips (including vertical), slight lighting and zoom changes, no warping.
tfms = get_transforms(
    do_flip=True,
    flip_vert=True,
    max_lighting=0.1,
    max_zoom=1.05,
    max_warp=0,
)

In [8]:
def get_val_idx_fromfile(validx_csv):
    """Return the first column of a header-less CSV file as a Python list.

    validx_csv: path (or file-like object) of a CSV without a header row.
    """
    index_table = pd.read_csv(validx_csv, header=None)
    first_column = index_table.loc[:, 0]
    return first_column.tolist()

In [13]:
# Validation row indices for the fold file with suffix 4.
vidx = get_val_idx_fromfile(f'{PATH}multilabel_iter5fold_4.csv')

In [14]:
# Number of indices read from the fold file (used below as the validation split).
len(vidx)

1670

In [15]:
# Seed NumPy's global RNG before building the dataset.
np.random.seed(42)
# Build the labelled item list:
#   - items come from all_chem_df.csv under PATH, with '.png' appended to each file name
#   - rows whose index is in `vidx` form the validation set
#   - label_delim=' ' splits the label field on spaces, i.e. multi-label targets
src = (ImageItemList.from_csv(PATH, 'all_chem_df.csv', folder='', suffix='.png')
       .split_by_idx(vidx)
       .label_from_df(label_delim=' '))

In [16]:
# Apply the augmentations at 256 px, wrap everything in a DataBunch and
# normalise with the ImageNet statistics.
data = (src.transform(tfms, size=256)
        .databunch().normalize(imagenet_stats))

In [18]:
# Set the batch size on the existing DataBunch, then echo it to confirm the change.
data.batch_size = 40
data.batch_size

40

In [19]:
# Backbone architecture handed to create_cnn below.
arch = models.resnet50

In [20]:
# Metrics shown during training. fastai v1 hands the raw model outputs (logits)
# to metric functions, so the sigmoid has to be applied inside the metric
# (sigmoid=True); otherwise the 0.5 threshold is compared against logits
# instead of probabilities.
acc_50 = partial(accuracy_thresh, thresh=0.5, sigmoid=True)
fs = partial(fbeta, thresh=0.5, sigmoid=True)
# CNN learner on the chosen backbone with dropout 0.4 in the head.
learn = create_cnn(data, arch, metrics=[acc_50, fs], ps=0.4)

In [None]:
# Five-fold cross-validation: train a fresh model per fold and collect
# validation metrics for each one.
np.random.seed(42)

# Training-time metrics. fastai v1 passes raw model outputs (logits) to metric
# functions, so the sigmoid must be applied inside the metric (sigmoid=True)
# for the 0.5 threshold to act on probabilities.
acc_50 = partial(accuracy_thresh, thresh=0.5, sigmoid=True)
fs = partial(fbeta, thresh=0.5, sigmoid=True)

# Per-fold results: thresholded accuracy, F-beta, ROC AUC, average precision.
act = []
fbe = []
roc = []
avp = []

arch = models.resnet50  # loop-invariant, so defined once

for rep in range(5):
    print(rep)
    # Validation indices for this fold.
    vidx = get_val_idx_fromfile(f'{PATH}multilabel_iter5fold_{rep}.csv')
    data = (ImageItemList.from_csv(PATH, 'all_chem_df.csv', folder='', suffix='.png')
            .split_by_idx(vidx).label_from_df(label_delim=' ')
            .transform(tfms, size=256)
            .databunch().normalize(imagenet_stats))
    data.batch_size = 40
    # NOTE(review): confirm that assigning num_workers after the DataBunch is
    # built actually reaches the underlying data loaders.
    data.num_workers = 1
    learn = create_cnn(data, arch, metrics=[acc_50, fs], ps=0.4)
    learn.unfreeze()
    learn.fit_one_cycle(127, slice(1e-4, 1e-2))
    # change to use get_preds because it works better than TTA
    y_preds, y = learn.get_preds(ds_type=DatasetType.Valid)
    # The evaluation calls below keep sigmoid=False on the get_preds output,
    # as in the original — presumably get_preds already returns probabilities;
    # verify against the fastai version in use.
    act.append(accuracy_thresh(y_preds, y, thresh=0.5, sigmoid=False).item())
    fbe.append(fbeta(y_preds, y, thresh=0.5, sigmoid=False).item())
    roc.append(roc_auc_score(y, y_preds, average="weighted"))
    avp.append(average_precision_score(y, y_preds, average="weighted"))

In [None]:
# Report mean +/- standard deviation across folds for every collected metric.
for prefix, scores in (
    ('acc thresholded mean ', act),
    ('F beta mean ', fbe),
    ('ROC AUC mean ', roc),
    ('Ave Prec Score mean ', avp),
):
    print(prefix + str(np.mean(scores)) + '+/-' + str(np.std(scores)))

In [166]:
def get_y_pred(y_prob, thresh):
    """Binarise scores: element-wise True where `y_prob` is strictly above `thresh`."""
    above_threshold = y_prob > thresh
    return above_threshold

In [169]:
# Define the decision threshold in this cell so it runs on a fresh kernel
# (it was previously only defined in a later cell).
thresh = 0.5
# Boolean predictions from the most recent get_preds output.
preds = get_y_pred(y_preds, thresh)

In [155]:
## convert binary predictions to list of classes in text format
thresh = 0.5
class_names = learn.data.classes
labelled_preds = []
for pred in preds:
    # Keep the class names whose prediction exceeds the threshold, joined by spaces.
    hits = [class_names[idx] for idx, flag in enumerate(pred) if flag > thresh]
    labelled_preds.append(' '.join(hits))

In [182]:
### write predictions for network analysis
import csv  # explicit import: do not rely on a star import to provide `csv`

# newline='' is what the csv module requires for files it writes to; without it
# extra blank rows can appear on platforms that translate line endings.
with open('multiclass_validix5pred.csv', 'w', newline='') as f:
    writer = csv.writer(f, delimiter=',')
    # One row per sample, holding the space-joined label string in a single field.
    writer.writerows([label] for label in labelled_preds)

In [183]:
# Save the current learner's weights under this name.
# NOTE(review): `learn` is whichever learner was created last (presumably the
# final cross-validation fold) — confirm before reusing the saved weights.
learn.save('multiclass_fold5final.model')

## fin