In [1]:
import os

# Set CUDA_VISIBLE_DEVICES to "0" before any GPU library is imported further
# down the notebook (by CUDA convention this exposes only the first device).
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [1]:
!pip install -Uqq fastai

In [3]:
from fastai.vision.all import *

In [4]:
# Data root, relative to the working directory; the labels CSV and the
# training-image folder are both resolved against it further down.
path = Path('data')

In [5]:
# Load the label table and drop every column whose name starts with "Unnamed".
df = pd.read_csv(path/'labels.csv')
df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
print(len(df))

# Split separately inside each distinct `tools_present` value: the first 80% of
# a group's rows (in file order) go to training, the remainder to validation.
# A group too small to yield a training slice is kept entirely for training.
#
# The pieces are sliced with `.iloc` and concatenated once at the end:
# `np.split` on a DataFrame relies on the deprecated `DataFrame.swapaxes`, and
# growing a frame with `pd.concat` from an empty seed inside the loop is both
# quadratic and deprecated behaviour in recent pandas.
TRAIN_FRACTION = .8
trn_parts, val_parts = [], []
for _, group in df.groupby('tools_present', sort=False):
    cut = int(TRAIN_FRACTION * len(group))
    trn, val = group.iloc[:cut], group.iloc[cut:]
    if len(trn) > 0:
        trn_parts.append(trn)
        val_parts.append(val)
    else:
        trn_parts.append(val)

# Fall back to an empty frame with the right columns if a side received nothing.
empty = df.iloc[:0]
trn_df = (pd.concat(trn_parts) if trn_parts else empty).assign(valid=False)
val_df = (pd.concat(val_parts) if val_parts else empty).assign(valid=True)

# Final lookup table: one row per clip, indexed by `clip_name`, with a boolean
# `valid` column marking validation rows.
df = (
    pd.concat([trn_df, val_df])
    .sort_values(by='clip_name')
    .set_index('clip_name')
)
print(len(df))


24695
24695


In [6]:
def split_func(f):
    """Return the `valid` flag stored in the module-level `df` for file `f`.

    The row is selected with `parent_label(f)` (fastai helper; presumably the
    name of the file's parent folder) as the index key.
    """
    row_key = parent_label(f)
    is_valid = df.loc[row_key, 'valid']
    return is_valid

def _tool_label(k, slot):
    """Return entry number `slot` (0-based) of the `tools_present` value for item `k`.

    The row is looked up in the module-level `df` with `parent_label(k)` as the
    index key. Square brackets are removed from the stored string, the string
    is split on commas, and the requested entry is returned with surrounding
    whitespace stripped.

    Raises:
        KeyError: if `parent_label(k)` is not in the index of `df`.
        IndexError: if the string has `slot` or fewer comma-separated entries.
    """
    raw = df.loc[parent_label(k), 'tools_present']
    return re.sub(r"[\[\]]", '', raw).split(',')[slot].strip()

def get_usm0_tool_lbl(k):
    """Label taken from entry 0 of `tools_present` for item `k`."""
    return _tool_label(k, 0)

def get_usm1_tool_lbl(k):
    """Label taken from entry 1 of `tools_present` for item `k`."""
    return _tool_label(k, 1)

def get_usm2_tool_lbl(k):
    """Label taken from entry 2 of `tools_present` for item `k`."""
    return _tool_label(k, 2)

def get_usm3_tool_lbl(k):
    """Label taken from entry 3 of `tools_present` for item `k`."""
    return _tool_label(k, 3)

In [13]:
def train(method, arch, item, batch, accum=False, epochs=12, base_lr=0.01):
    """Fine-tune `arch` as a four-headed classifier and return the learner.

    Each image gets four categorical targets, produced by
    `get_usm0_tool_lbl` .. `get_usm3_tool_lbl`. The model emits one flat vector
    whose length is the sum of the four vocabulary sizes; every head owns a
    contiguous slice of that vector. The training loss is the sum of the four
    per-head cross-entropies.

    Reported metrics, in order: `usm1_loss` .. `usm4_loss`, `usm1_err` ..
    `usm4_err`, and `combo_err` (the sum of the four error rates).

    Args:
        method: Not used by this function; kept so existing calls keep working.
        arch: Architecture handed to `vision_learner`.
        item: Item transforms for the `DataBlock` (`item_tfms`).
        batch: Batch transforms for the `DataBlock` (`batch_tfms`).
        accum: Not used by this function; kept so existing calls keep working.
        epochs: Epoch count passed to `fine_tune` (default 12, the value that
            used to be hard-coded).
        base_lr: Learning rate passed to `fine_tune` (default 0.01, the value
            that used to be hard-coded).

    Returns:
        The fitted learner. Earlier versions returned None; callers that ignore
        the return value are unaffected.
    """
    label_getters = [get_usm0_tool_lbl, get_usm1_tool_lbl,
                     get_usm2_tool_lbl, get_usm3_tool_lbl]
    n_heads = len(label_getters)

    # One image input followed by one categorical target per head.
    dls = DataBlock(
        blocks=(ImageBlock, *([CategoryBlock] * n_heads)),
        n_inp=1,
        get_items=get_image_files,
        get_y=label_getters,
        splitter=FuncSplitter(split_func),
        item_tfms=item, batch_tfms=batch
    ).dataloaders(path/'train_images_small', num_workers=8)

    # Column boundaries of the heads inside the flat prediction vector:
    # head i owns preds[:, bounds[i]:bounds[i + 1]]; bounds[-1] is the total width.
    bounds = [0]
    for vocab in dls.vocab:
        bounds.append(bounds[-1] + len(vocab))

    cross_entropy = CrossEntropyLossFlat()

    def per_head(fn, head, name):
        """Build `name(preds, *targs)` that applies `fn` to one head's slice and target."""
        lo, hi = bounds[head], bounds[head + 1]

        def scorer(preds, *targs):
            return fn(preds[:, lo:hi], targs[head])

        # The function name is what shows up as the column header in the
        # training log, so keep the historical usmN_loss / usmN_err names.
        scorer.__name__ = name
        return scorer

    head_losses = [per_head(cross_entropy, i, f'usm{i + 1}_loss') for i in range(n_heads)]
    head_errs = [per_head(error_rate, i, f'usm{i + 1}_err') for i in range(n_heads)]

    def combo_loss(preds, *targs):
        """Sum of the per-head cross-entropy losses (the training objective)."""
        return sum(loss(preds, *targs) for loss in head_losses)

    def combo_err(preds, *targs):
        """Sum of the per-head error rates."""
        return sum(err(preds, *targs) for err in head_errs)

    metrics_cfg = [*head_losses, *head_errs, combo_err]

    learn = vision_learner(dls, arch, loss_func=combo_loss, metrics=metrics_cfg,
                           n_out=bounds[-1]).to_fp16()
    learn.fine_tune(epochs, base_lr)
    return learn
    

In [35]:
# Experiment table: architecture name -> (item transform, batch transforms).
# Every experiment squishes items to 180x320; only the `size` handed to
# `aug_transforms` differs between architectures.
_aug_size_by_arch = {
    'vit_small_patch16_224': 224,
    'convnext_small_in22k': (180,320),
    'swinv2_base_window12_192_22k': 192,
    'swin_small_patch4_window7_224': 224,
}
exps = {
    arch_name: (Resize((180,320), method='squish'),
                aug_transforms(size=aug_size, min_scale=1))
    for arch_name, aug_size in _aug_size_by_arch.items()
}

In [36]:
# Train every configured architecture in turn, printing its name first so the
# log tables that follow can be told apart.
for arch_name, (item_tfm, batch_tfms) in exps.items():
    print(arch_name)
    train('squish', arch_name, item=item_tfm, batch=batch_tfms, accum=False)

vit_small_patch16_224


epoch,train_loss,valid_loss,usm1_loss,usm2_loss,usm3_loss,usm4_loss,usm1_err,usm2_err,usm3_err,usm4_err,combo_err,time
0,0.985189,1.334285,0.389099,0.228115,0.311131,0.405941,0.129002,0.068229,0.089616,0.127993,0.41484,15:10


epoch,train_loss,valid_loss,usm1_loss,usm2_loss,usm3_loss,usm4_loss,usm1_err,usm2_err,usm3_err,usm4_err,combo_err,time
0,0.228196,0.683772,0.219756,0.099113,0.144197,0.220707,0.051473,0.026424,0.035144,0.061461,0.174502,18:39
1,0.178918,0.745121,0.237934,0.099721,0.150431,0.257035,0.045152,0.025242,0.032832,0.055755,0.158981,18:40
2,0.138427,0.843635,0.284926,0.160327,0.151084,0.247297,0.057132,0.02603,0.037696,0.0486,0.169457,18:47
3,0.100924,0.899314,0.313958,0.171281,0.137353,0.276722,0.058428,0.023979,0.032524,0.058421,0.173352,18:44
4,0.093587,1.179546,0.417256,0.165967,0.223051,0.373271,0.057512,0.028121,0.033794,0.05774,0.177167,18:41
5,0.067972,0.853327,0.308973,0.117271,0.148895,0.278186,0.047992,0.018126,0.025416,0.049348,0.140881,18:37
6,0.04343,1.006356,0.362817,0.167784,0.163785,0.311968,0.051506,0.022402,0.030473,0.042159,0.14654,18:35
7,0.030442,0.835622,0.290183,0.077206,0.159413,0.308819,0.034569,0.012354,0.023805,0.039673,0.110401,18:35
8,0.031523,0.972724,0.325474,0.130072,0.19693,0.320248,0.04113,0.016924,0.027173,0.041143,0.12637,18:34
9,0.023727,1.036291,0.365961,0.152585,0.195538,0.322206,0.045132,0.019229,0.025329,0.041437,0.131127,18:47


convnext_small_in22k


epoch,train_loss,valid_loss,usm1_loss,usm2_loss,usm3_loss,usm4_loss,usm1_err,usm2_err,usm3_err,usm4_err,combo_err,time
0,0.862261,1.772771,0.483104,0.339018,0.40448,0.546168,0.131768,0.077276,0.100299,0.137935,0.447278,28:11


epoch,train_loss,valid_loss,usm1_loss,usm2_loss,usm3_loss,usm4_loss,usm1_err,usm2_err,usm3_err,usm4_err,combo_err,time
0,0.191742,1.23254,0.360211,0.213876,0.282692,0.37576,0.079901,0.043301,0.056737,0.087832,0.267772,37:39
1,0.187652,1.457065,0.409693,0.271836,0.31521,0.460326,0.066626,0.03803,0.043929,0.081151,0.229736,37:41
2,0.112578,1.452381,0.430085,0.245994,0.31407,0.462233,0.07461,0.0368,0.048967,0.088113,0.24849,37:49
3,0.073479,1.40748,0.403523,0.245317,0.282192,0.476447,0.054265,0.024447,0.033994,0.063145,0.175851,37:47
4,0.085458,1.387594,0.405217,0.258076,0.269533,0.45477,0.056136,0.029979,0.035892,0.058067,0.180074,37:51
5,0.044766,1.501917,0.41669,0.299424,0.278278,0.507521,0.052695,0.025516,0.033233,0.059677,0.171121,37:41
6,0.041736,1.262205,0.372701,0.258867,0.242742,0.387896,0.051038,0.027039,0.03501,0.051406,0.164493,37:40
7,0.034461,1.300774,0.362024,0.272014,0.231924,0.434814,0.05478,0.02895,0.039466,0.062009,0.185205,37:40
8,0.027413,1.205919,0.334733,0.263851,0.217599,0.389735,0.04737,0.028656,0.037856,0.052501,0.166384,37:41
9,0.020503,1.336686,0.378046,0.276068,0.267862,0.414708,0.050076,0.026418,0.035892,0.051272,0.163658,37:45


swinv2_base_window12_192_22k


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Downloading: "https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth" to /home/bilal/.cache/torch/hub/checkpoints/swinv2_base_patch4_window12_192_22k.pth


epoch,train_loss,valid_loss,usm1_loss,usm2_loss,usm3_loss,usm4_loss,usm1_err,usm2_err,usm3_err,usm4_err,combo_err,time
0,0.843905,1.209214,0.323877,0.215703,0.289238,0.380398,0.090137,0.050905,0.063933,0.101789,0.306764,48:11


epoch,train_loss,valid_loss,usm1_loss,usm2_loss,usm3_loss,usm4_loss,usm1_err,usm2_err,usm3_err,usm4_err,combo_err,time
0,0.155755,0.643668,0.192409,0.108843,0.138494,0.20392,0.041277,0.023364,0.028522,0.045954,0.139118,1:00:19
1,0.111803,0.742297,0.215481,0.093621,0.148827,0.284368,0.040248,0.022623,0.026551,0.05039,0.139812,1:00:15
2,0.08816,0.876818,0.260183,0.124045,0.192446,0.300144,0.035197,0.020585,0.03066,0.045994,0.132436,1:00:10
3,0.100049,0.891822,0.245939,0.125638,0.170721,0.349525,0.039673,0.020852,0.030186,0.056323,0.147035,1:00:09
4,0.07758,0.928355,0.276279,0.151974,0.167811,0.332289,0.042286,0.019316,0.023224,0.048793,0.133619,1:00:10
5,0.06513,0.920523,0.308609,0.108493,0.204573,0.298846,0.034756,0.016864,0.027634,0.037936,0.11719,1:00:09
6,0.054776,0.964783,0.316775,0.149371,0.165615,0.33302,0.03638,0.01979,0.022723,0.039286,0.118178,1:00:10
7,0.052881,0.854893,0.244873,0.109648,0.164627,0.335743,0.030193,0.014679,0.021086,0.036172,0.10213,1:00:08
8,0.03281,0.886465,0.269043,0.147637,0.170211,0.299572,0.041985,0.020438,0.024039,0.043776,0.130238,1:00:08
9,0.023796,0.917658,0.277974,0.125694,0.181606,0.332384,0.033901,0.018541,0.021981,0.03636,0.110782,1:00:03


swin_small_patch4_window7_224


Downloading: "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth" to /home/bilal/.cache/torch/hub/checkpoints/swin_small_patch4_window7_224.pth


epoch,train_loss,valid_loss,usm1_loss,usm2_loss,usm3_loss,usm4_loss,usm1_err,usm2_err,usm3_err,usm4_err,combo_err,time
0,1.035699,1.505884,0.408932,0.268584,0.333735,0.494632,0.128875,0.075231,0.097079,0.135363,0.436548,32:05


epoch,train_loss,valid_loss,usm1_loss,usm2_loss,usm3_loss,usm4_loss,usm1_err,usm2_err,usm3_err,usm4_err,combo_err,time
0,0.224984,0.747932,0.211407,0.115915,0.182303,0.238307,0.041818,0.026698,0.035351,0.052862,0.156729,39:51
1,0.150327,0.794892,0.225713,0.137558,0.156242,0.275378,0.042005,0.024661,0.029471,0.053143,0.14928,39:57
2,0.136012,0.799476,0.277132,0.097862,0.160357,0.264124,0.0425,0.019235,0.030413,0.041798,0.133946,39:49
3,0.088594,0.796553,0.267167,0.081041,0.163461,0.284883,0.036994,0.017057,0.020712,0.043889,0.118653,39:58
4,0.084798,0.735136,0.232959,0.077948,0.136212,0.288017,0.039079,0.014893,0.024661,0.044103,0.122735,40:03
5,0.057567,0.677925,0.23046,0.078098,0.14684,0.222527,0.03658,0.012681,0.024093,0.037883,0.111237,40:03
6,0.03861,1.070599,0.388204,0.225094,0.18233,0.274971,0.042032,0.021066,0.021627,0.043609,0.128334,39:56
7,0.033546,0.847059,0.296497,0.103106,0.168666,0.278788,0.03489,0.015006,0.021607,0.040448,0.111951,40:00
8,0.036644,0.88135,0.283273,0.106365,0.179685,0.312026,0.03966,0.016723,0.023558,0.038324,0.118265,40:04
9,0.033147,0.864628,0.271619,0.083858,0.181586,0.327564,0.031669,0.012875,0.022449,0.039426,0.106419,39:55
