# Multitask

Recognizing variety -> recognizing disease

After the notebook on changing the head, here is an easier way to do it.

In [1]:
from fastai.vision.all import *
from fastcore.parallel import *
path = Path()
trn_path = path/'train_images'

In [2]:
df = pd.read_csv('train.csv')
df.head()

Unnamed: 0,image_id,label,variety,age
0,100330.jpg,bacterial_leaf_blight,ADT45,45
1,100365.jpg,bacterial_leaf_blight,ADT45,45
2,100382.jpg,bacterial_leaf_blight,ADT45,45
3,100632.jpg,bacterial_leaf_blight,ADT45,45
4,101918.jpg,bacterial_leaf_blight,ADT45,45


Turning DataLoaders into DataBlocks. We need to return 2 outputs, variety and disease.

In [3]:
arch = resnet18

In [4]:
# Lookup table: image file name (`image_id` column) -> `variety` column.
# Built column-wise with zip instead of `df.iterrows()`, which constructs a
# whole Series for every row just to read two fields.
img2variety = dict(zip(df.image_id, df.variety))

In [5]:
def get_variety(p):
    "Look up in `img2variety` the variety recorded for the image file at path `p`"
    image_name = p.name  # `img2variety` is keyed by file name, not by full path
    return img2variety[image_name]

## Customizing head

In [6]:
# Column index at which the model output is split: columns [0, N_DISEASE_LOGITS)
# are scored against the disease target, the remaining columns against the
# variety target.
N_DISEASE_LOGITS = 10

# Single-task loss applied separately to each slice of the output.
orig_lf = CrossEntropyLossFlat()

# NOTE: every metric/loss takes (predictions, disease target, variety target)
# even when it only uses one of the targets, so all of them can be called with
# the same arguments.
def disease_err(inp,disease,variety):
    "Error rate of the disease slice of `inp` against `disease`."
    return error_rate(inp[:,:N_DISEASE_LOGITS],disease)

def variety_err(inp,disease,variety):
    "Error rate of the variety slice of `inp` against `variety`."
    return error_rate(inp[:,N_DISEASE_LOGITS:],variety)

def disease_loss(inp,disease,variety):
    "Cross-entropy of the disease slice of `inp` against `disease`."
    return orig_lf(inp[:,:N_DISEASE_LOGITS],disease)

def variety_loss(inp,disease,variety):
    "Cross-entropy of the variety slice of `inp` against `variety`."
    return orig_lf(inp[:,N_DISEASE_LOGITS:],variety)

def loss(pred,disease,variety):
    "Combined training loss: disease loss plus variety loss, unweighted."
    return disease_loss(pred,disease,variety)+variety_loss(pred,disease,variety)

err_metrics = (disease_err,variety_err)
all_metrics = err_metrics+(disease_loss,variety_loss)

In [17]:
def get_dls(img_sizes, bs=512, item_size=224):
    """Lazily yield one `DataLoaders` per entry of `img_sizes`.

    This is a generator: nothing is built (and no banner is printed) until the
    caller asks for the next element, so only one `DataLoaders` has to be
    created at a time.

    Args:
        img_sizes: iterable of target sizes passed to `aug_transforms(size=...)`.
        bs: batch size passed to `DataBlock.dataloaders` (default 512).
        item_size: size every item is squished to by `Resize` before the batch
            transforms run (default 224).

    Yields:
        The `DataLoaders` built from `trn_path` for each size, in order. Each
        one has a single input (the image) and two targets: the parent folder
        label and the variety returned by `get_variety`.
    """
    for size in img_sizes:
        # Banner so each size's training log is easy to find in the cell output.
        print('*' * 40)
        print(f'Size: {size}'.center(40, '*'))
        print('*' * 40)
        dblock = DataBlock(
            blocks=(ImageBlock, CategoryBlock, CategoryBlock),
            n_inp=1,  # first block is the input; both CategoryBlocks are targets
            get_items=get_image_files,
            get_y=[parent_label, get_variety],
            # Seeded so that the split is intended to be the same at every size.
            splitter=RandomSplitter(seed=42),
            item_tfms=Resize(item_size, method='squish'),
            batch_tfms=aug_transforms(size=size),
        )
        yield dblock.dataloaders(trn_path, bs=bs)

In [19]:
def prog_sizing(img_sizes, epochs=1, lr=0.01):
    """Progressively use bigger images for training.

    A learner is created on the first size in `img_sizes` and fine-tuned; for
    every following size the learner's `dls` is swapped and the same learner is
    fine-tuned again.

    Args:
        img_sizes: image sizes to train on, in order (forwarded to `get_dls`).
        epochs: number of epochs passed to each `fine_tune` call.
        lr: base learning rate passed to each `fine_tune` call (default 0.01,
            the value that used to be hard-coded).

    Returns:
        The trained learner, so it can be inspected or used after training.
        In a notebook, assign the result or end the call with `;` to avoid
        displaying the learner's repr.

    Raises:
        ValueError: if `img_sizes` is empty.
    """
    img_sizes = list(img_sizes)
    if not img_sizes:
        # Without this check `next()` below would leak a bare StopIteration.
        raise ValueError('img_sizes must contain at least one image size')

    dls_gen = get_dls(img_sizes)
    first_dls = next(dls_gen)
    # n_out=20: `loss` and the metrics split the output at column 10.
    learn = vision_learner(first_dls, arch, loss_func=loss, metrics=all_metrics, n_out=20).to_fp16()
    # learn.lr_find()
    learn.fine_tune(epochs, lr)
    # The generator builds each remaining DataLoaders only when it is needed.
    for dls in dls_gen:
        learn.dls = dls
        learn.fine_tune(epochs, lr)
    return learn

In [9]:
# learn.fine_tune(20, 0.01) ## error: .033 

NameError: name 'learn' is not defined

After 20 epochs, it is better than training normally.

What would happen if we add progressive resizing?

In [16]:
prog_sizing([32, 64, 128, 224], epochs=10)  # min err: 0.024988
# It isn't training very well when the size is big; it probably needs a smaller learning rate.

****
****************Size: 32****************
****


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,5.830206,3.450318,0.725613,0.336377,2.040698,1.409619,00:15


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,3.706298,2.667388,0.557424,0.255166,1.694673,0.972715,00:13
1,3.177379,2.016477,0.457953,0.166266,1.405505,0.610972,00:13
2,2.718117,1.553778,0.368573,0.138876,1.117376,0.436402,00:15
3,2.307854,1.238139,0.301297,0.106199,0.907001,0.331138,00:14
4,1.960078,1.004146,0.247958,0.08938,0.732678,0.271468,00:14
5,1.67807,0.851616,0.195579,0.075444,0.609585,0.242031,00:14
6,1.437219,0.747382,0.177799,0.064392,0.545799,0.201582,00:13
7,1.241058,0.680201,0.158097,0.059106,0.493098,0.187103,00:13
8,1.076236,0.654065,0.150408,0.058626,0.47277,0.181295,00:13
9,0.965133,0.648171,0.148006,0.059106,0.468403,0.179767,00:13


****
****************Size: 64****************
****


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,2.543406,3.083127,0.666026,0.305142,1.904092,1.179034,00:16


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,1.626815,1.227012,0.274387,0.117251,0.867184,0.359828,00:16
1,1.307079,0.748122,0.164825,0.066314,0.53416,0.213961,00:16
2,1.029927,0.49933,0.107641,0.041326,0.360382,0.138948,00:16
3,0.835368,0.448897,0.09851,0.027391,0.326292,0.122605,00:17
4,0.677665,0.378867,0.08073,0.027871,0.265655,0.113212,00:17
5,0.543291,0.337726,0.069197,0.024507,0.243161,0.094565,00:16
6,0.436781,0.286326,0.055262,0.023546,0.197329,0.088997,00:17
7,0.352045,0.266076,0.05382,0.01778,0.183954,0.082122,00:17
8,0.286786,0.249216,0.049495,0.015858,0.173591,0.075624,00:17
9,0.237671,0.247031,0.049015,0.016338,0.172489,0.074542,00:17


****
***************Size: 128****************
****


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,1.110096,0.695265,0.165786,0.054301,0.517615,0.177651,00:24


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,0.615155,0.361478,0.086977,0.029313,0.271103,0.090374,00:28
1,0.454983,0.290732,0.063431,0.023066,0.214284,0.076449,00:29
2,0.360793,0.305207,0.063431,0.025469,0.217639,0.087568,00:28
3,0.308478,0.338534,0.054301,0.034118,0.216667,0.121867,00:28
4,0.265945,0.247525,0.050457,0.019702,0.18214,0.065385,00:28
5,0.223946,0.196427,0.045651,0.011533,0.150864,0.045563,00:28
6,0.183068,0.173904,0.034118,0.012013,0.127926,0.045977,00:28
7,0.148095,0.148388,0.031235,0.011052,0.114295,0.034093,00:28
8,0.121464,0.148832,0.031716,0.00865,0.112136,0.036697,00:28
9,0.10034,0.146956,0.030274,0.00865,0.109733,0.037222,00:28


****
***************Size: 224****************
****


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,0.352148,0.277035,0.06247,0.01778,0.207681,0.069354,00:48


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,0.227471,0.295345,0.054781,0.025949,0.198256,0.097089,01:07
1,0.194329,0.296361,0.052379,0.023546,0.215109,0.081252,01:07
2,0.190352,0.336605,0.061028,0.030274,0.225118,0.111486,01:08
3,0.18733,0.290767,0.062951,0.021624,0.216703,0.074064,01:08
4,0.172985,0.27207,0.049495,0.023066,0.188455,0.083615,01:08
5,0.151575,0.188059,0.040846,0.012975,0.143488,0.044571,01:07
6,0.124858,0.17685,0.032196,0.013936,0.124932,0.051918,01:07
7,0.102299,0.154965,0.028832,0.011052,0.109161,0.045803,01:08
8,0.082791,0.146318,0.024988,0.010572,0.102105,0.044213,01:08
9,0.068522,0.145713,0.025949,0.011533,0.101416,0.044297,01:07


In [15]:
prog_sizing([64, 128, 224], 10)

****
****************Size: 64****************
****


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,5.628244,3.294783,0.569918,0.341182,1.920324,1.374459,00:15


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,3.200235,2.234747,0.421913,0.201346,1.395548,0.839198,00:16
1,2.541065,1.44962,0.306103,0.124459,0.953555,0.496066,00:16
2,1.994637,0.936545,0.215281,0.07112,0.668757,0.267788,00:16
3,1.565975,0.653385,0.135992,0.052859,0.457729,0.195656,00:16
4,1.24137,0.503635,0.114368,0.040846,0.363817,0.139817,00:16
5,0.978848,0.413678,0.086977,0.030754,0.287513,0.126165,00:16
6,0.780515,0.39906,0.082172,0.029793,0.26759,0.13147,00:16
7,0.623618,0.337887,0.067756,0.023546,0.224834,0.113053,00:16
8,0.505693,0.32537,0.066314,0.023066,0.216888,0.108482,00:16
9,0.42066,0.321115,0.064392,0.023066,0.214511,0.106604,00:16


****
***************Size: 128****************
****


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,1.371843,2.533492,0.522826,0.265738,1.578217,0.955275,00:25


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,0.807679,0.659039,0.136473,0.063912,0.451985,0.207054,00:28
1,0.600972,0.381619,0.079769,0.031235,0.271312,0.110307,00:28
2,0.470544,0.391498,0.078328,0.03604,0.254183,0.137314,00:28
3,0.400748,0.36771,0.066795,0.030754,0.26107,0.10664,00:28
4,0.337892,0.307263,0.059587,0.024507,0.209087,0.098176,00:28
5,0.281231,0.229418,0.04469,0.015858,0.166072,0.063346,00:28
6,0.227221,0.211469,0.040365,0.014416,0.139267,0.072202,00:28
7,0.182899,0.184123,0.031716,0.012013,0.129371,0.054752,00:28
8,0.148317,0.1748,0.031235,0.013455,0.12579,0.04901,00:28
9,0.122947,0.172369,0.032677,0.012975,0.123714,0.048655,00:28


****
***************Size: 224****************
****


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,0.384562,0.304417,0.064873,0.024027,0.223663,0.080753,00:47


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,0.259126,0.279759,0.052379,0.025949,0.178342,0.101417,01:07
1,0.212142,0.237233,0.047093,0.013936,0.179813,0.05742,01:07
2,0.197016,0.385804,0.073042,0.023546,0.290273,0.095531,01:07
3,0.201611,0.392736,0.063912,0.025949,0.287332,0.105403,01:07
4,0.18852,0.253916,0.047573,0.016338,0.198615,0.055301,01:08
5,0.164273,0.193276,0.039404,0.010572,0.150139,0.043137,01:07
6,0.135389,0.18115,0.038443,0.011052,0.148339,0.03281,01:07
7,0.110961,0.158122,0.029313,0.012013,0.11696,0.041163,01:07
8,0.090154,0.156832,0.030754,0.008169,0.119786,0.037045,01:08
9,0.074944,0.159052,0.029313,0.008169,0.121561,0.03749,01:09


Training with [32, 64, 128, 224] as image sizes reduced the error rate down to 0.024988. Let's try one more time. 

When it was training, I noticed the following:
- Loss got worse for a couple of epochs when the image size was 224. I assume our learning rate was too big.
- When transitioning from size 32 to 64, loss starts out with a big value. Maybe transition slowly?

In [20]:
# Do it again (a second run with the same image sizes)
prog_sizing([32, 64, 128, 224], epochs=10)

****************************************
****************Size: 32****************
****************************************


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,5.829073,3.527646,0.742912,0.366651,2.156008,1.371637,00:13


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,3.731667,2.662712,0.556463,0.250841,1.697008,0.965704,00:14
1,3.190116,2.084851,0.489188,0.173955,1.463543,0.621308,00:14
2,2.734121,1.623267,0.380106,0.154733,1.134065,0.489203,00:14
3,2.319028,1.2632,0.297934,0.115329,0.899772,0.363428,00:14
4,1.97888,1.03738,0.246516,0.084094,0.755496,0.281883,00:13
5,1.68986,0.853025,0.20519,0.077847,0.622187,0.230839,00:13
6,1.444708,0.726997,0.174435,0.061989,0.538176,0.188821,00:14
7,1.24323,0.666091,0.159539,0.052859,0.498401,0.167691,00:13
8,1.08872,0.635889,0.144642,0.051418,0.473535,0.162354,00:13
9,0.96807,0.634013,0.146564,0.050457,0.472768,0.161244,00:13


****************************************
****************Size: 64****************
****************************************


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,2.615643,3.332381,0.677078,0.320519,1.996183,1.336199,00:15


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,1.615825,1.230323,0.296492,0.114368,0.874521,0.355802,00:17
1,1.293265,0.713892,0.174916,0.052859,0.528471,0.185421,00:17
2,1.023287,0.578141,0.13407,0.045171,0.421118,0.157023,00:16
3,0.829999,0.456702,0.096588,0.03556,0.319981,0.136721,00:16
4,0.674892,0.39678,0.079769,0.031716,0.295453,0.101327,00:17
5,0.541866,0.327606,0.066314,0.024027,0.246544,0.081063,00:17
6,0.434183,0.293301,0.061028,0.01778,0.221633,0.071668,00:17
7,0.350318,0.247975,0.050457,0.015858,0.187726,0.060248,00:16
8,0.280412,0.240303,0.049976,0.014897,0.180518,0.059785,00:17
9,0.232882,0.239566,0.048534,0.015377,0.179589,0.059977,00:17


****************************************
***************Size: 128****************
****************************************


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,1.150411,0.710937,0.170591,0.073042,0.504847,0.20609,00:24


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,0.648411,0.423235,0.096588,0.037482,0.296692,0.126543,00:29
1,0.483133,0.277622,0.065353,0.022585,0.203184,0.074438,00:29
2,0.372637,0.385703,0.085536,0.031716,0.287125,0.098578,00:28
3,0.317843,0.31704,0.064392,0.019702,0.242844,0.074197,00:28
4,0.273166,0.269459,0.062951,0.01778,0.209304,0.060155,00:28
5,0.230405,0.209609,0.042287,0.013936,0.156859,0.052749,00:28
6,0.189451,0.154419,0.034118,0.012013,0.118036,0.036383,00:28
7,0.153491,0.155311,0.031716,0.011533,0.119968,0.035343,00:28
8,0.125852,0.154879,0.031235,0.010572,0.12005,0.034828,00:28
9,0.104963,0.152816,0.030274,0.010572,0.118489,0.034326,00:28


****************************************
***************Size: 224****************
****************************************


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,0.349256,0.292107,0.063431,0.019222,0.222339,0.069768,00:48


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time
0,0.230294,0.280018,0.055262,0.027871,0.195073,0.084945,01:07
1,0.190613,0.330878,0.066795,0.025469,0.229826,0.101052,01:07
2,0.183166,0.341349,0.077847,0.017299,0.27389,0.067459,01:07
3,0.186574,0.523317,0.078328,0.046612,0.325349,0.197968,01:07
4,0.181953,0.30641,0.056223,0.023066,0.219652,0.086757,01:07
5,0.157106,0.225756,0.04469,0.013936,0.172092,0.053664,01:07
6,0.131176,0.191468,0.03556,0.012975,0.135268,0.0562,01:07
7,0.107546,0.18089,0.032677,0.013455,0.124883,0.056007,01:07
8,0.088351,0.166397,0.030754,0.012975,0.116095,0.050302,01:07
9,0.072841,0.166583,0.030754,0.012975,0.115693,0.05089,01:07


In [21]:
prog_sizing([32, 48, 64, 128, 224], epochs=10)

****************************************
****************Size: 32****************
****************************************


epoch,train_loss,valid_loss,disease_err,variety_err,disease_loss,variety_loss,time


KeyboardInterrupt: 