In [1]:
import os
# Expose only CUDA device 1 to this kernel. It must be set before anything
# initialises CUDA, which is why it lives in the very first cell.
os.environ.update(CUDA_VISIBLE_DEVICES="1")

In [2]:
from fastai.vision.all import *

In [3]:
path = Path('data')  # dataset root, relative to the notebook's working directory

In [4]:
# Clip-level labels; drop the "Unnamed: N" columns left over from a saved index.
df = pd.read_csv(path/'labels.csv')
df = df.loc[:, ~df.columns.str.contains('^Unnamed')]

# Deterministic per-label split: for every distinct `tools_present` value, the
# first 80% of its rows (file order, no shuffling) go to training and the rest
# to validation. A group too small to yield any training row goes wholly to
# training. Rows with a missing `tools_present` are dropped (groupby skips NaN keys).
TRAIN_FRAC = 0.8
trn_parts, val_parts = [], []
for _, grp in df.groupby('tools_present', sort=False):
    cut = int(TRAIN_FRAC * len(grp))
    if cut > 0:
        trn_parts.append(grp.iloc[:cut])
        val_parts.append(grp.iloc[cut:])
    else:
        trn_parts.append(grp)

# Concatenate once instead of inside the loop, and don't seed with an empty
# object-dtype frame (that can upcast columns to object and triggers pandas
# FutureWarnings). Slicing with iloc also avoids np.split on a DataFrame, which
# goes through the deprecated DataFrame.swapaxes.
no_rows = df.iloc[:0]
trn_df = (pd.concat(trn_parts) if trn_parts else no_rows).assign(valid=False)
val_df = (pd.concat(val_parts) if val_parts else no_rows).assign(valid=True)

# Index by clip_name: split_func and the get_usm*_tool_lbl getters look rows up
# with df.loc[parent_label(f), ...].
df = pd.concat([trn_df, val_df]).sort_values(by='clip_name').set_index('clip_name')
print(len(df))

24695


In [5]:
def split_func(f):
    """FuncSplitter predicate: the `valid` flag of the clip that image `f` belongs to.

    The clip key is `parent_label(f)`; it is looked up in the global `df`,
    which the labels cell indexes by `clip_name`.
    """
    clip_name = parent_label(f)
    return df.loc[clip_name, 'valid']

def _usm_tool_lbls(k):
    """Parse the `tools_present` string of the clip containing image `k` into a list.

    The clip row is found with `df.loc[parent_label(k), 'tools_present']`. Every '['
    and ']' is removed, the rest is split on ',' and each piece is stripped, so
    "[a, b, c, d]" becomes ['a', 'b', 'c', 'd']. Shared by the four per-arm
    getters below, which previously each repeated this parsing.
    """
    raw = df.loc[parent_label(k), 'tools_present']
    return [tool.strip() for tool in re.sub(r"[\[\]]", '', raw).split(',')]

def get_usm0_tool_lbl(k):
    """Tool label at position 0 of the clip's `tools_present` list."""
    return _usm_tool_lbls(k)[0]

def get_usm1_tool_lbl(k):
    """Tool label at position 1 of the clip's `tools_present` list."""
    return _usm_tool_lbls(k)[1]

def get_usm2_tool_lbl(k):
    """Tool label at position 2 of the clip's `tools_present` list."""
    return _usm_tool_lbls(k)[2]

def get_usm3_tool_lbl(k):
    """Tool label at position 3 of the clip's `tools_present` list."""
    return _usm_tool_lbls(k)[3]

In [9]:
def train(exp, arch, item, batch, accum=False, epochs=12, lr=0.01):
    """Fine-tune `arch` with four classification heads (one per USM arm) and score it.

    One network emits the concatenated logits of four classifiers. Each head is
    trained with its own cross-entropy, and the four losses are summed.

    Args:
        exp: experiment tag; only referenced by the commented-out export at the end.
        arch: architecture passed to `vision_learner`.
        item: item-level transform(s) for the DataBlock (`item_tfms`).
        batch: batch-level transform(s) for the DataBlock (`batch_tfms`).
        accum: NOTE(review): accepted but never used -- presumably meant to enable
            gradient accumulation; wire it up or drop it.
        epochs: epochs passed to `fine_tune` (default keeps the previous hard-coded 12).
        lr: base learning rate passed to `fine_tune` (default keeps the previous 0.01).

    Returns:
        (preds_err, tta_err): validation error summed over the four heads, from plain
        predictions and from test-time augmentation respectively.
    """
    # defining the structure of the block and creating data loaders
    dls = DataBlock(
        blocks=(ImageBlock, CategoryBlock, CategoryBlock, CategoryBlock, CategoryBlock),
        n_inp=1,
        get_items=get_image_files,
        get_y=[get_usm0_tool_lbl, get_usm1_tool_lbl, get_usm2_tool_lbl, get_usm3_tool_lbl],
        splitter=FuncSplitter(split_func),
        item_tfms=item, batch_tfms=batch
    ).dataloaders(path/'train_images_small', seed=42, num_workers=8)

    # dls.c is used as the per-target class counts: head i owns logit columns
    # edges[i]:edges[i+1]. Computed once instead of on every metric call.
    edges = [0] + [dls.c[:i].sum() for i in range(1, 5)]

    def _head(preds, i): return preds[:, edges[i]:edges[i+1]]
    def _ce(preds, targs, i): return CrossEntropyLossFlat()(_head(preds, i), targs[i])
    def _err(preds, targs, i): return error_rate(_head(preds, i), targs[i])

    # Per-head metrics. They are called as f(preds, *targets) with the four targets.
    # fastai reports each one under its function name, so the names are kept.
    def usm1_loss(preds, *targs): return _ce(preds, targs, 0)
    def usm2_loss(preds, *targs): return _ce(preds, targs, 1)
    def usm3_loss(preds, *targs): return _ce(preds, targs, 2)
    def usm4_loss(preds, *targs): return _ce(preds, targs, 3)

    def usm1_err(preds, *targs): return _err(preds, targs, 0)
    def usm2_err(preds, *targs): return _err(preds, targs, 1)
    def usm3_err(preds, *targs): return _err(preds, targs, 2)
    def usm4_err(preds, *targs): return _err(preds, targs, 3)

    # Combined objective / metric: plain sums over the four heads.
    def combo_loss(preds, *targs): return sum(_ce(preds, targs, i) for i in range(4))
    def combo_err(preds, *targs): return sum(_err(preds, targs, i) for i in range(4))

    # configuring metrics for the learner
    metrics_cfg = [usm1_loss, usm2_loss, usm3_loss, usm4_loss,
                   usm1_err, usm2_err, usm3_err, usm4_err, combo_err]

    # defining the learner and kick starting the training
    # NOTE(review): the saved logs below show 1 frozen + 10 unfrozen epochs, so they
    # were produced with a different epoch count than this default -- re-run before comparing.
    learn = vision_learner(dls, arch, loss_func=combo_loss, metrics=metrics_cfg, n_out=edges[4]).to_fp16()
    learn.fine_tune(epochs, lr)

    # validating model performance (normal vs TTA). Both calls return (preds, targs),
    # with targs holding the four target tensors. error_rate takes the argmax itself,
    # so no softmax is needed.
    preds, targs = learn.get_preds(dl=dls.valid)
    preds_err = combo_err(preds, *targs)
    tta_preds, tta_targs = learn.tta(dl=dls.valid)
    tta_err = combo_err(tta_preds, *tta_targs)

    # saving the model
    # learn.path=Path('models')
    # learn.export(f'{exp}_{arch}_crop.pkl')

    return preds_err, tta_err
    

In [10]:
# Item-level resize target shared by every config (crop, not squish/pad).
res = (180,320)

# arch name -> list of (item_tfms, size) configs to train; `size` is passed to
# aug_transforms(size=...) in the run loop. Lists rather than set literals, so
# configs run in the written order. Set order for tuples holding identity-hashed
# transforms is not reproducible.
archs = {
    'vit_small_patch16_224': [
        (Resize(res, method='crop'), 224),
    ],
    'convnext_small_in22k': [
        (Resize(res, method='crop'), (180,320)),
    ],
    'swinv2_base_window12_192_22k': [
        (Resize(res, method='crop'), 192),
    ],
    'swin_small_patch4_window7_224': [
        (Resize(res, method='crop'), 224),
    ],
}

In [11]:
# Train every (architecture, config) pair. Keep each (preds_err, tta_err) result
# keyed by (arch, size), so the runs can be compared afterwards instead of only printed.
results = {}
for arch, details in archs.items():
    for item, size in details:
        print('----',arch, ' --- ', size)
        errs = train('crop', arch, item=item, batch=aug_transforms(size=size, min_scale=1), accum=False)
        results[(arch, size)] = errs
        print(errs)

---- vit_small_patch16_224  ---  224


epoch,train_loss,valid_loss,usm1_loss,usm2_loss,usm3_loss,usm4_loss,usm1_err,usm2_err,usm3_err,usm4_err,combo_err,time
0,1.956138,1.737473,0.493824,0.259678,0.382307,0.601665,0.155326,0.075973,0.105384,0.180595,0.517278,15:38


epoch,train_loss,valid_loss,usm1_loss,usm2_loss,usm3_loss,usm4_loss,usm1_err,usm2_err,usm3_err,usm4_err,combo_err,time
0,1.844711,1.692424,0.481075,0.25392,0.366467,0.590961,0.155961,0.073935,0.101001,0.181737,0.512634,19:55
1,1.809279,1.651017,0.47067,0.24394,0.354979,0.581426,0.147603,0.072338,0.098268,0.174715,0.492925,18:50
2,1.837868,1.60408,0.455258,0.237304,0.344261,0.567257,0.142652,0.070147,0.094727,0.169865,0.477391,18:07
3,1.738854,1.526688,0.42696,0.222698,0.323927,0.553105,0.139472,0.065938,0.089409,0.169745,0.464563,18:08
4,1.774711,1.486786,0.423291,0.211962,0.317889,0.533645,0.131761,0.065169,0.088039,0.163411,0.44838,18:08
5,1.688898,1.465806,0.413072,0.210009,0.307165,0.53556,0.129363,0.064134,0.086329,0.164673,0.444499,18:08
6,1.660696,1.419305,0.397,0.203447,0.297862,0.520997,0.121706,0.064093,0.082581,0.162529,0.430909,18:08
7,1.581838,1.400658,0.398333,0.198999,0.290966,0.51236,0.124439,0.062657,0.080215,0.160498,0.427809,18:08
8,1.621,1.416838,0.398843,0.20236,0.297561,0.518074,0.12206,0.063693,0.081398,0.161807,0.428958,18:09
9,1.68622,1.40573,0.398604,0.200065,0.293571,0.513489,0.122902,0.062911,0.080676,0.157658,0.424147,18:08


---- convnext_small_in22k  ---  (180, 320)


epoch,train_loss,valid_loss,usm1_loss,usm2_loss,usm3_loss,usm4_loss,usm1_err,usm2_err,usm3_err,usm4_err,combo_err,time
0,1.816578,1.995279,0.553028,0.318364,0.452799,0.671087,0.168923,0.091407,0.121352,0.195661,0.577342,27:28


epoch,train_loss,valid_loss,usm1_loss,usm2_loss,usm3_loss,usm4_loss,usm1_err,usm2_err,usm3_err,usm4_err,combo_err,time
0,1.630949,2.077107,0.564695,0.341067,0.471712,0.699631,0.162569,0.091834,0.120283,0.192681,0.567367,36:49
1,1.6072,2.08505,0.569313,0.348247,0.474531,0.692958,0.166544,0.091513,0.119207,0.192588,0.569853,36:51
2,1.56,2.027772,0.548214,0.336784,0.464894,0.677876,0.16174,0.089883,0.117016,0.19202,0.560659,36:57
3,1.49241,1.934694,0.529814,0.315214,0.440096,0.649569,0.153509,0.086142,0.110662,0.181704,0.532017,36:55
4,1.44023,1.951909,0.52549,0.327591,0.445937,0.65289,0.152908,0.087164,0.11133,0.179225,0.530627,36:56
5,1.47366,1.819003,0.49426,0.300698,0.404919,0.619126,0.145832,0.081719,0.102183,0.172945,0.502679,36:53
6,1.347008,1.794024,0.486048,0.291815,0.400012,0.61615,0.144797,0.078966,0.099037,0.171422,0.494221,36:55
7,1.317958,1.790235,0.478865,0.296216,0.406088,0.609065,0.140821,0.079427,0.101108,0.169865,0.491221,37:08
8,1.339225,1.798666,0.483696,0.301309,0.402034,0.611628,0.142732,0.079734,0.099538,0.169023,0.491027,37:54
9,1.288875,1.766602,0.479649,0.291654,0.393864,0.601436,0.142792,0.077997,0.09774,0.168141,0.486671,37:28


---- swinv2_base_window12_192_22k  ---  192


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


epoch,train_loss,valid_loss,usm1_loss,usm2_loss,usm3_loss,usm4_loss,usm1_err,usm2_err,usm3_err,usm4_err,combo_err,time
0,1.979236,1.836246,0.527922,0.294273,0.410086,0.603965,0.16303,0.091627,0.11133,0.181811,0.547798,48:34


epoch,train_loss,valid_loss,usm1_loss,usm2_loss,usm3_loss,usm4_loss,usm1_err,usm2_err,usm3_err,usm4_err,combo_err,time
0,1.870499,1.783541,0.515723,0.285711,0.401626,0.580482,0.161059,0.086763,0.112546,0.171842,0.53221,1:00:16
1,1.725782,1.738379,0.504462,0.281662,0.388378,0.563878,0.157524,0.08671,0.10845,0.16943,0.522115,59:43
2,1.790488,1.671733,0.486545,0.263525,0.373174,0.548489,0.1434,0.080276,0.098976,0.165148,0.4878,59:22
3,1.745376,1.635582,0.475211,0.256876,0.367193,0.536302,0.14308,0.078485,0.096157,0.160371,0.478092,59:52
4,1.684959,1.610723,0.462633,0.250534,0.367434,0.530123,0.133144,0.074757,0.09603,0.15842,0.462351,59:22
5,1.613104,1.561861,0.449008,0.240053,0.355271,0.51753,0.130024,0.074436,0.098335,0.154525,0.45732,1:00:12
6,1.615857,1.514321,0.43214,0.23363,0.341958,0.506592,0.130372,0.072939,0.095115,0.151404,0.44983,1:00:54
7,1.63129,1.469901,0.418867,0.224932,0.328645,0.497454,0.122408,0.071864,0.087177,0.151612,0.43306,1:00:42
8,1.542051,1.462824,0.417603,0.222225,0.330807,0.492189,0.120998,0.070521,0.093271,0.150917,0.435706,1:01:23
9,1.526773,1.450051,0.41639,0.222911,0.321647,0.489105,0.121486,0.071683,0.085367,0.150382,0.428918,59:24


---- swin_small_patch4_window7_224  ---  224


epoch,train_loss,valid_loss,usm1_loss,usm2_loss,usm3_loss,usm4_loss,usm1_err,usm2_err,usm3_err,usm4_err,combo_err,time
0,2.114624,1.997372,0.566947,0.326752,0.441614,0.662059,0.180441,0.107381,0.124105,0.191098,0.603025,31:33


epoch,train_loss,valid_loss,usm1_loss,usm2_loss,usm3_loss,usm4_loss,usm1_err,usm2_err,usm3_err,usm4_err,combo_err,time
0,2.076123,1.966221,0.554171,0.324691,0.432758,0.654603,0.175758,0.105872,0.120083,0.188746,0.590458,39:08
1,2.032273,1.963037,0.557243,0.321938,0.428393,0.655464,0.175096,0.101849,0.118265,0.188806,0.584017,39:09
2,2.072554,1.931912,0.545745,0.317234,0.424842,0.644091,0.173105,0.10221,0.116762,0.186087,0.578164,39:06
3,2.007726,1.917529,0.540714,0.312604,0.421469,0.642746,0.169163,0.099985,0.115392,0.193276,0.577817,39:07
4,1.956358,1.889285,0.530046,0.310127,0.414582,0.634529,0.16452,0.100506,0.113127,0.183675,0.561829,39:08
5,1.912752,1.901695,0.536328,0.313496,0.418054,0.633817,0.168762,0.1008,0.113755,0.193236,0.576554,39:08
6,1.873968,1.891181,0.5353,0.312986,0.413909,0.628985,0.166738,0.097453,0.110348,0.181958,0.556497,39:11
7,1.868657,1.872571,0.52976,0.308672,0.409666,0.624471,0.163344,0.097266,0.110341,0.182773,0.553724,39:07
8,1.816563,1.845184,0.520206,0.302158,0.405922,0.616897,0.15876,0.094086,0.109352,0.182352,0.544551,39:07
9,1.865395,1.845728,0.520932,0.302037,0.40529,0.61747,0.160344,0.095542,0.108651,0.184283,0.54882,39:08
