In [1]:
import os
import ast
import torch
import torch.nn as nn

from model import bengalimodel
from dataset import BengaliDatasetTrain
from train_fn import train,evaluate

In [2]:
# Run-wide configuration: compute device, input geometry, schedule length,
# batch sizes and the per-channel normalisation statistics handed to the datasets.

# Use the GPU when one is visible to PyTorch; otherwise fall back to the CPU so
# the notebook still runs (slowly) instead of failing at `model.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Image size passed to the datasets as img_height / img_width.
IMG_HEIGHT = 137
IMG_WIDTH = 236

# Upper bound on the number of training epochs (early stopping may end sooner).
EPOCHS = 25

# Batch sizes for the training and validation loaders respectively.
TRAIN_BATCH_SIZE = 64
TEST_BATCH_SIZE = 16

# Per-channel mean / std passed to the datasets as `mean` / `std`.
# NOTE(review): these match the commonly used ImageNet statistics - confirm
# that is what the backbone expects.
MODEL_MEAN = (0.485, 0.456, 0.406)
MODEL_STD = (0.229, 0.224, 0.225)

In [3]:
# Build everything the training loop needs for one cross-validation split:
# the model, the train/validation datasets and loaders, the optimizer and the
# learning-rate scheduler.

# Fold ids used for fitting vs. the single held-out fold used for validation.
TRAINING_FOLDS = [1, 2, 3, 4]
VALIDATION_FOLDS = [0]

# Project model with a resnet18 backbone, moved to the configured device.
model = bengalimodel(backbone='resnet18')
model = model.to(device)

train_dataset = BengaliDatasetTrain(
    folds=TRAINING_FOLDS,
    img_height=IMG_HEIGHT,
    img_width=IMG_WIDTH,
    mean=MODEL_MEAN,
    std=MODEL_STD,
)

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

# Same dataset class and preprocessing arguments, restricted to the held-out fold.
valid_dataset = BengaliDatasetTrain(
    folds=VALIDATION_FOLDS,
    img_height=IMG_HEIGHT,
    img_width=IMG_WIDTH,
    mean=MODEL_MEAN,
    std=MODEL_STD,
)

# Validation does not benefit from shuffling; a fixed order keeps evaluation
# batches deterministic from epoch to epoch.
valid_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset,
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Multiply the learning rate by 0.3 once the monitored quantity (stepped with
# the validation loss in the training loop, hence mode='min') has stopped
# improving for more than `patience` epochs. `verbose=True` prints a line on
# each reduction.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases - confirm
# against the installed version before upgrading.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    patience=5,
    factor=0.3,
    verbose=True,
)



Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


In [4]:
# Train for up to EPOCHS epochs. After each epoch: step the plateau scheduler
# on the validation loss, checkpoint the weights whenever the validation score
# improves, print a one-line summary, and stop early once the score has failed
# to improve for EARLY_STOPPING_PATIENCE consecutive epochs.

# Number of consecutive non-improving epochs tolerated before stopping.
EARLY_STOPPING_PATIENCE = 5

# Sentinel so the first epoch's score counts as an improvement
# (assumes validation scores are non-negative).
best_score = -1
# Epochs elapsed since val_score last improved.
es = 0
# Field width for right-aligning the epoch counter; constant for the whole run.
epoch_len = len(str(EPOCHS))

for epoch in range(EPOCHS):
    train_loss, train_score = train(train_dataset, train_loader, model, optimizer)
    val_loss, val_score = evaluate(valid_dataset, valid_loader, model)

    # The scheduler monitors validation loss, whereas checkpointing and early
    # stopping below track the validation score.
    scheduler.step(val_loss)

    if val_score > best_score:
        best_score = val_score
        es = 0
        # Overwrite the checkpoint for this validation fold with the best weights so far.
        torch.save(model.state_dict(), f"resnet18_fold{VALIDATION_FOLDS[0]}.pth")
    else:
        es += 1

    # Note: the epoch counter printed here is zero-based.
    print_msg = (
        f'[{epoch:>{epoch_len}}/{EPOCHS:>{epoch_len}}] '
        f'train_loss: {train_loss:.5f} '
        f'train_score: {train_score:.5f} '
        f'valid_loss: {val_loss:.5f} '
        f'valid_score: {val_score:.5f}'
    )
    print(print_msg)

    if es >= EARLY_STOPPING_PATIENCE:
        print("Early stopping")
        break

2511it [04:55,  8.50it/s]                          


recall: grapheme 0.8044096895662305, vowel 0.9098206178912704, consonant 0.9082203933649018, total 0.8567150975971582


2511it [01:41, 24.80it/s]                          


recall: grapheme 0.8980004275863047, vowel 0.9608997183843879, consonant 0.9516139165815236, total 0.9271286225346302
[ 0/25] train_loss: 0.19238 train_score: 0.85672 valid_loss: 0.09246 valid_score: 0.92713


2511it [04:55,  8.50it/s]                          


recall: grapheme 0.9157546789445652, vowel 0.9679842588161826, consonant 0.9593254908081688, total 0.9397047768783704


2511it [01:41, 24.65it/s]                          


recall: grapheme 0.924287963935349, vowel 0.9679710097457139, consonant 0.9619859284668627, total 0.9446332165208187
[ 1/25] train_loss: 0.06399 train_score: 0.93970 valid_loss: 0.06498 valid_score: 0.94463


2511it [04:57,  8.43it/s]                          


recall: grapheme 0.943944356359867, vowel 0.9764264499977948, consonant 0.9702267781552353, total 0.958635485218191


2511it [01:41, 24.71it/s]                          


recall: grapheme 0.9203614582923701, vowel 0.9729269552097911, consonant 0.9692765007116669, total 0.9457315931265496
[ 2/25] train_loss: 0.04323 train_score: 0.95864 valid_loss: 0.06137 valid_score: 0.94573


2511it [04:55,  8.49it/s]                          


recall: grapheme 0.9588351461850609, vowel 0.9821564119998143, consonant 0.9770458266577798, total 0.969218132756929


2511it [01:41, 24.72it/s]                          


recall: grapheme 0.9290041629118646, vowel 0.9750620204369329, consonant 0.9677327568122255, total 0.9502007757682219
[ 3/25] train_loss: 0.03112 train_score: 0.96922 valid_loss: 0.06437 valid_score: 0.95020


2511it [04:55,  8.48it/s]                          


recall: grapheme 0.9694693751308966, vowel 0.985185859898828, consonant 0.9808977608186973, total 0.9762555927448295


2511it [01:41, 24.72it/s]                          


recall: grapheme 0.9253252596465378, vowel 0.9750077514155554, consonant 0.9761173199991763, total 0.9504438976769518
[ 4/25] train_loss: 0.02359 train_score: 0.97626 valid_loss: 0.06992 valid_score: 0.95044


2511it [04:55,  8.49it/s]                          


recall: grapheme 0.9758268829742911, vowel 0.9880151748904104, consonant 0.9841708794959032, total 0.980959955083724


2511it [01:41, 24.72it/s]                          


recall: grapheme 0.9298108337490444, vowel 0.9800028691020739, consonant 0.9642755342927432, total 0.9509750177232265
[ 5/25] train_loss: 0.01831 train_score: 0.98096 valid_loss: 0.06404 valid_score: 0.95098


2511it [04:55,  8.49it/s]                          


recall: grapheme 0.9798182743119662, vowel 0.9891486535113543, consonant 0.9866097533842504, total 0.9838487388798842


2511it [01:41, 24.73it/s]                          


recall: grapheme 0.9337442764733941, vowel 0.9766439698321743, consonant 0.9606533896023726, total 0.9511964780953337
[ 6/25] train_loss: 0.01514 train_score: 0.98385 valid_loss: 0.06562 valid_score: 0.95120


2511it [04:55,  8.49it/s]                          


recall: grapheme 0.98276163450075, vowel 0.9911301418498547, consonant 0.9885585277316576, total 0.9863029846457532


2511it [01:41, 24.67it/s]                          


recall: grapheme 0.9333447460258265, vowel 0.9735505971139737, consonant 0.9601852605981945, total 0.9501063374409553
[ 7/25] train_loss: 0.01267 train_score: 0.98630 valid_loss: 0.07549 valid_score: 0.95011


2511it [04:55,  8.49it/s]                          


recall: grapheme 0.9844772998848345, vowel 0.9912363188782707, consonant 0.9901549645818962, total 0.987586470807459


2511it [01:41, 24.72it/s]                          


recall: grapheme 0.934741140316401, vowel 0.9756055503810918, consonant 0.9511487028196347, total 0.9490591334583821
Epoch     9: reducing learning rate of group 0 to 3.0000e-05.
[ 8/25] train_loss: 0.01141 train_score: 0.98759 valid_loss: 0.07308 valid_score: 0.94906


2511it [04:55,  8.49it/s]                          


recall: grapheme 0.9975007760987866, vowel 0.9975059531640817, consonant 0.9966315562435579, total 0.9972847654013033


2511it [01:41, 24.72it/s]                          


recall: grapheme 0.952656554360996, vowel 0.9839887934300681, consonant 0.9737405194577811, total 0.9657606054024603
[ 9/25] train_loss: 0.00269 train_score: 0.99728 valid_loss: 0.05665 valid_score: 0.96576


2511it [04:55,  8.49it/s]                          


recall: grapheme 0.9995131814979168, vowel 0.9987372287191693, consonant 0.9990232483709375, total 0.999196710021485


2511it [01:41, 24.66it/s]                          


recall: grapheme 0.9520769561805745, vowel 0.9844013775832948, consonant 0.9764981592456667, total 0.9662633622975276
[10/25] train_loss: 0.00100 train_score: 0.99920 valid_loss: 0.06020 valid_score: 0.96626


2511it [04:56,  8.48it/s]                          


recall: grapheme 0.9993057200700542, vowel 0.9992312318861516, consonant 0.9992755998263386, total 0.9992795679631497


2511it [01:41, 24.71it/s]                          

recall: grapheme 0.9511510104218162, vowel 0.9836612859618458, consonant 0.9771731117647623, total 0.9657841046425601
[11/25] train_loss: 0.00096 train_score: 0.99928 valid_loss: 0.06366 valid_score: 0.96578



2511it [04:55,  8.48it/s]                          


recall: grapheme 0.9994816353660709, vowel 0.9992948312778221, consonant 0.9991234044705595, total 0.9993453766201308


2511it [01:41, 24.69it/s]                          

recall: grapheme 0.951670612460734, vowel 0.9834929317328982, consonant 0.9734345532887121, total 0.9650671774857695
[12/25] train_loss: 0.00075 train_score: 0.99935 valid_loss: 0.06534 valid_score: 0.96507



2511it [04:55,  8.48it/s]                          


recall: grapheme 0.9994932240924331, vowel 0.999281628067565, consonant 0.9992674092278964, total 0.999383871370082


2511it [01:41, 24.67it/s]                          

recall: grapheme 0.9524345574671009, vowel 0.9835331901397684, consonant 0.9744407364236813, total 0.9657107603744128
[13/25] train_loss: 0.00075 train_score: 0.99938 valid_loss: 0.06761 valid_score: 0.96571



2511it [04:56,  8.48it/s]                          


recall: grapheme 0.9992998614563978, vowel 0.999354493307786, consonant 0.9994973905942144, total 0.9993629017036991


2511it [01:41, 24.69it/s]                          

recall: grapheme 0.9490589998771436, vowel 0.9822289124699154, consonant 0.9672586875057357, total 0.9619013999324846
[14/25] train_loss: 0.00073 train_score: 0.99936 valid_loss: 0.07072 valid_score: 0.96190



2511it [04:55,  8.49it/s]                          


recall: grapheme 0.9995318293260321, vowel 0.9993797740157163, consonant 0.9991470879353974, total 0.9993976301507945


2511it [01:41, 24.73it/s]                          

recall: grapheme 0.9511367052931412, vowel 0.9804174625160738, consonant 0.9765127422296646, total 0.9648009038330051
Epoch    16: reducing learning rate of group 0 to 9.0000e-06.
[15/25] train_loss: 0.00072 train_score: 0.99940 valid_loss: 0.07302 valid_score: 0.96480
Early stopping



