# Setup

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fastai import *
from fastai.vision import *

In [3]:
gs_path = Path('./GS')
gs_path.absolute(), gs_path.exists()

(PosixPath('/home/ubuntu/off-sample/GS'), True)

# Prepare Data

In [9]:
# Check that the GS dataframe and the ion images match

In [10]:
gs_df = pd.read_csv('GS.csv')[['dsName', 'sumFormula', 'adduct', 'type']]
gs_df['dsName'] = gs_df.dsName.map(lambda name: name.replace('/', '_'))
gs_df['ion'] = gs_df.sumFormula + gs_df.adduct
gs_df.drop(labels=['sumFormula', 'adduct'], axis=1, inplace=True)
gs_df.sort_values(by=['dsName', 'type', 'ion'], inplace=True)
gs_df.reset_index(drop=True, inplace=True)
gs_df.head()

Unnamed: 0,dsName,type,ion
0,100um_noM2_001_Recal,off,C14H20N6O5S+H
1,100um_noM2_001_Recal,off,C15H19NO10+H
2,100um_noM2_001_Recal,off,C15H26N2O12+H
3,100um_noM2_001_Recal,off,C16H16O11S+Na
4,100um_noM2_001_Recal,off,C16H19N3O5S+H


In [18]:
rows = []
for ds_path in gs_path.iterdir():
    ds_name = ds_path.name
    for t in ['on', 'off']:
        for image_path in (ds_path / t).iterdir():
            ion = image_path.name.split('.')[0]
            rows.append([ds_name, t, ion])
df = (pd.DataFrame(rows, columns=['dsName', 'type', 'ion'])
        .sort_values(by=['dsName', 'type', 'ion'])
        .reset_index(drop=True))

In [19]:
np.all(gs_df == df)

True

In [20]:
# Number of datasets available
len(list(gs_path.iterdir()))

87

In [21]:
row_list = []
for ds_path in gs_path.iterdir():
    for cl in ['on', 'off']:
        for p in (ds_path / cl).iterdir():
            row_list.append([str(p), ds_path.name, cl])

image_df = pd.DataFrame(row_list, columns=['name', 'group', 'label'])
image_df.shape

(23238, 3)

In [22]:
image_df.head()

Unnamed: 0,name,group,label
0,GS/DESI quan_Swales/on/C21H43O6P+H.png,DESI quan_Swales,on
1,GS/DESI quan_Swales/on/C4H7O8P+Na.png,DESI quan_Swales,on
2,GS/DESI quan_Swales/on/C21H39O7P+H.png,DESI quan_Swales,on
3,GS/DESI quan_Swales/on/C10H11NO3+Na.png,DESI quan_Swales,on
4,GS/DESI quan_Swales/on/C24H40O4+Na.png,DESI quan_Swales,on


In [25]:
all_groups = list(image_df.group.unique())
print(len(all_groups))

87


In [29]:
pd.value_counts(image_df.label)

on     13329
off     9909
Name: label, dtype: int64

# Train and Save Model

In [38]:
from sklearn.metrics import f1_score

In [32]:
src = (ImageItemList.from_df(image_df, '', cols=0)
#        .split_by_idxs(train_inds, valid_inds)
       .no_split()  # train - all data, valid - empty
       .label_from_df(cols=2))

In [34]:
tfms = get_transforms(flip_vert=True, max_rotate=15)

In [35]:
data = (src.transform(tfms, size=224, resize_method=ResizeMethod.SQUISH, padding_mode='reflection')
        .databunch()
        .normalize(imagenet_stats))

In [62]:
data.train_ds.classes, data.classes

(['off', 'on'], ['off', 'on'])

In [39]:
arch = models.resnet50
learn = create_cnn(data, arch, metrics=[accuracy, f1_score], ps=0.5)

In [40]:
lr = 3e-3
learn.fit_one_cycle(5, slice(lr), wd=0.1)

epoch,train_loss,valid_loss,accuracy,f1_score
1,0.143511,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
2,0.099839,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
3,0.079272,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3
4,0.069701,Unnamed: 2_level_4,Unnamed: 3_level_4,Unnamed: 4_level_4
5,0.063997,Unnamed: 2_level_5,Unnamed: 3_level_5,Unnamed: 4_level_5


In [41]:
learn.unfreeze()
div = 10
learn.fit_one_cycle(5, slice(1e-5, lr/div), wd=0.1)

epoch,train_loss,valid_loss,accuracy,f1_score
1,0.063784,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
2,0.064955,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
3,0.049076,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3
4,0.038839,Unnamed: 2_level_4,Unnamed: 3_level_4,Unnamed: 4_level_4
5,0.032359,Unnamed: 2_level_5,Unnamed: 3_level_5,Unnamed: 4_level_5


In [42]:
export_fn = 'resnet-50.pkl'
learn.export(export_fn)

In [48]:
learn.path

PosixPath('.')

# Test Inference

In [95]:
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

In [44]:
ds_name = '100um_noM2_001_Recal'
test_df = image_df[image_df.group == ds_name]

In [46]:
test_itemlist = (ImageItemList.from_df(test_df, '', cols=0))
test_itemlist

ImageItemList (434 items)
Image (3, 119, 254),Image (3, 119, 254),Image (3, 119, 254),Image (3, 119, 254),Image (3, 119, 254)
Path: .

In [49]:
learn = load_learner('.', export_fn, test=test_itemlist)

In [83]:
pred_probs, _ = learn.get_preds(DatasetType.Test)

In [84]:
pred_probs.shape

torch.Size([434, 2])

In [85]:
learn.data.classes

['off', 'on']

In [86]:
preds = pred_probs[:,0] > 0.5

In [92]:
y = (test_df.label == 'off').astype(int).values
y.shape

(434,)

In [94]:
accuracy_score(y_true=y, y_pred=preds), f1_score(y_true=y, y_pred=preds)

(0.9700460829493087, 0.9304812834224598)

In [98]:
confusion_matrix(y, preds)

array([[334,   2],
       [ 11,  87]])

In [None]:
learn.pred_batch()