In [1]:
import os
import cv2
import time
import random
import numpy as np
import torch
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
list_rs = []
sub_dir = os.listdir('../data/pill_recog')

for classname in sub_dir:
    if classname == '107':
        continue
    for filename in os.listdir('../data/pill_recog/' + classname):
        filepath = os.path.join(classname, filename)
        list_rs.append([filepath, classname])
df = pd.DataFrame(list_rs, columns=['filepath', 'label'])
df.to_csv('../data/pill_recog.csv')

In [2]:
df1 = pd.read_csv('../data/crop/train_crop.csv')
print(df1.head())
df2 = pd.read_csv('../data/crop/val_crop.csv')
print(df2.head())

                                            filepath  label
0  ../data/crop/train_crop/74/VAIPE_P_1045_1_pill...     74
1  ../data/crop/train_crop/74/VAIPE_P_63_18_pill1...     74
2  ../data/crop/train_crop/74/VAIPE_P_304_21_pill...     74
3  ../data/crop/train_crop/74/VAIPE_P_768_3_pill7...     74
4  ../data/crop/train_crop/74/VAIPE_P_889_1_pill6...     74
                                          filepath  label
0  ../data/crop/val_crop/74/VAIPE_P_105_1_2469.jpg     74
1  ../data/crop/val_crop/74/VAIPE_P_146_0_3030.jpg     74
2  ../data/crop/val_crop/74/VAIPE_P_134_1_2906.jpg     74
3  ../data/crop/val_crop/74/VAIPE_P_134_0_2905.jpg     74
4  ../data/crop/val_crop/74/VAIPE_P_105_1_2470.jpg     74


In [19]:
frames = [df1, df2]
df = pd.concat(frames).reset_index(drop=True)
len(df)
df.head()

Unnamed: 0,filepath,label
0,../data/crop/train_crop/74/VAIPE_P_1045_1_pill...,74
1,../data/crop/train_crop/74/VAIPE_P_63_18_pill1...,74
2,../data/crop/train_crop/74/VAIPE_P_304_21_pill...,74
3,../data/crop/train_crop/74/VAIPE_P_768_3_pill7...,74
4,../data/crop/train_crop/74/VAIPE_P_889_1_pill6...,74


In [4]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
def get_img(path):
    im_bgr = cv2.imread(path)
    im_rgb = im_bgr[:, :, ::-1]
    return im_rgb

In [5]:
from torch.utils.data import Dataset, DataLoader


class PillDataset(Dataset):
    def __init__(self, df, data_root, 
                 transforms=None, 
                 output_label=True, 
                ):
        
        super().__init__()
        self.df = df.reset_index(drop=True).copy()
        self.transforms = transforms
        self.data_root = data_root
        
        self.output_label = output_label
        if output_label:
            self.labels = self.df['label'].values
            
        """
        images = []
        labels = []
        
        for filename in os.listdir(os.path.join(data_root, sub_dir)):
            filepath = os.path.join(data_root, sub_dir, filename)
            images.append(filepath)
            labels.append(sub_dir)
            
        self.images = images
        self.labels = labels
        """ 
    def __len__(self):
        return self.df.shape[0]
    
    def __getitem__(self, index: int):
        if self.output_label:
            target = self.labels[index]
            # target = dict_label[target]
        imgname = self.df.loc[index]['filepath']
        img  = get_img(f"{self.data_root}/{self.df.loc[index]['filepath']}")

        if self.transforms:
            img = self.transforms(image=img)['image']
                            
        # do label smoothing
        if self.output_label:
            return img, target
        else:
            return img

In [6]:
from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)
from albumentations.pytorch import ToTensorV2

In [7]:
def get_train_transforms():
    return Compose([
            RandomResizedCrop(CFG['img_size'], CFG['img_size']),
            Transpose(p=0.5),
            HorizontalFlip(p=0.5),
            # VerticalFlip(p=0.5),
            # RandomRotate90(p=0.5),
            # ShiftScaleRotate(shift_limit=0.0, scale_limit=0.3, rotate_limit=10, border_mode=0, p=0.7),
            # HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
            # RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
            CoarseDropout(p=0.5),
            Cutout(p=0.1),
            ToTensorV2(p=1.0),
        ], p=1.0)
  
        
def get_valid_transforms():
    return Compose([
            Resize(CFG['img_size'], CFG['img_size']),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
            ToTensorV2(p=1.0),
        ], p=1.0)

In [18]:
# model: vit_large_patch16_384
CFG = {
    'fold_num': 10,
    'seed': 719,
    'model_arch': 'tf_efficientnet_b7_ns',
    'img_size': 224,
    'epochs': 10,
    'train_bs': 8,
    'valid_bs': 16,
    'T_0': 10,
    'lr': 1e-4,
    'min_lr': 1e-6,
    'weight_decay':1e-6,
    'num_workers': 4,
    'accum_iter': 2,
    'verbose_step': 1,
    'device': 'cuda:0'
}

In [9]:
import timm
import torch.nn as nn

In [10]:
class PillClassifier(nn.Module):
    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_arch, pretrained=pretrained)
        # n_features = self.model.head.out_features
        # n_features = self.model.head.fc.out_features
        # self.out_layer = nn.Linear(n_features, n_class)
        n_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(n_features, n_class)

    def forward(self, x):
        x = self.model(x)
        # x = self.out_layer(x)
        return x

In [11]:
def prepare_dataloader(df, trn_idx, val_idx, data_root='../data/'):
    
    # from catalyst.data.sampler import BalanceClassSampler
    
    train_ = df.loc[trn_idx,:].reset_index(drop=True)
    valid_ = df.loc[val_idx,:].reset_index(drop=True)
        
    train_ds = PillDataset(train_, data_root, transforms=get_train_transforms(),
                              output_label=True)
    valid_ds = PillDataset(valid_, data_root, transforms=get_valid_transforms(),output_label=True)
    
    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=CFG['train_bs'],
        pin_memory=False,
        drop_last=False,
        shuffle=True, 
        num_workers=CFG['num_workers'],
    )
    val_loader = torch.utils.data.DataLoader(
        valid_ds,
        batch_size=CFG['valid_bs'],
        num_workers=CFG['num_workers'],
        shuffle=False,
        pin_memory=False,
    )
    return train_loader, val_loader

def train_one_epoch(epoch, model, loss_fn, optimizer, train_loader, device, logger, scheduler=None, schd_batch_update=False):
    model.train()

    t = time.time()
    running_loss = None
    
    pbar = tqdm(enumerate(train_loader), total=len(train_loader))
    for step, (imgs, image_labels) in pbar:
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()

        with autocast():
            image_preds = model(imgs)

            loss = loss_fn(image_preds, image_labels)
            
            scaler.scale(loss).backward()

            if running_loss is None:
                running_loss = loss.item()
            else:
                running_loss = running_loss * .99 + loss.item() * .01

            if ((step + 1) %  CFG['accum_iter'] == 0) or ((step + 1) == len(train_loader)):
                # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)

                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad() 
                
                if scheduler is not None and schd_batch_update:
                    scheduler.step()

            if ((step + 1) % CFG['verbose_step'] == 0) or ((step + 1) == len(train_loader)):
                description = f'epoch {epoch} loss: {running_loss:.4f}'
                # logger.info('Epoch {} loss: {}'.format(epoch, running_loss))
                pbar.set_description(description)
                
    if scheduler is not None and not schd_batch_update:
        scheduler.step()
        
def valid_one_epoch(epoch, model, loss_fn, val_loader, device, logger, scheduler=None, schd_loss_update=False):
    model.eval()

    t = time.time()
    loss_sum = 0
    sample_num = 0
    image_preds_all = []
    image_targets_all = []
    
    pbar = tqdm(enumerate(val_loader), total=len(val_loader))
    for step, (imgs, image_labels) in pbar:
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()
        
        image_preds = model(imgs)
        image_preds_all += [torch.argmax(image_preds, 1).detach().cpu().numpy()]
        image_targets_all += [image_labels.detach().cpu().numpy()]
        
        loss = loss_fn(image_preds, image_labels)
        
        loss_sum += loss.item()*image_labels.shape[0]
        sample_num += image_labels.shape[0]  

        if ((step + 1) % CFG['verbose_step'] == 0) or ((step + 1) == len(val_loader)):
            description = f'epoch {epoch} loss: {loss_sum/sample_num:.4f}'
            # logger.info('Epoch {} loss: {}'.format(epoch, loss_sum/sample_num))
            pbar.set_description(description)
    
    image_preds_all = np.concatenate(image_preds_all)
    image_targets_all = np.concatenate(image_targets_all)
    print(f'validation multi-class accuracy = {(image_preds_all==image_targets_all).mean():.4f}')
    logger.info('Validation multi-class accuracy: {}'.format((image_preds_all == image_targets_all).mean()))
    if scheduler is not None:
        if schd_loss_update:
            scheduler.step(loss_sum/sample_num)
        else:
            scheduler.step()

In [20]:
import pandas as pd
train = df
train.head()

Unnamed: 0,filepath,label
0,../data/crop/train_crop/74/VAIPE_P_1045_1_pill...,74
1,../data/crop/train_crop/74/VAIPE_P_63_18_pill1...,74
2,../data/crop/train_crop/74/VAIPE_P_304_21_pill...,74
3,../data/crop/train_crop/74/VAIPE_P_768_3_pill7...,74
4,../data/crop/train_crop/74/VAIPE_P_889_1_pill6...,74


In [22]:
%%time
from tqdm import tqdm
from sklearn.model_selection import StratifiedKFold
from torch.cuda.amp import autocast, GradScaler
from pathlib import Path
from logult import setup_log
logger = setup_log(save_dir='saved/final')
save_path = 'weights/' + CFG['model_arch'] + '_final' 
Path(save_path).mkdir(parents=True, exist_ok=True)



if __name__ == '__main__':
     # for training only, need nightly build pytorch
    seed_everything(CFG['seed'])
    
    stratifiedKFold = StratifiedKFold(n_splits=CFG['fold_num'], shuffle=True, random_state=CFG['seed'])
    folds = stratifiedKFold.split(np.arange(train.shape[0]), train.label.values)
    
    for fold, (trn_idx, val_idx) in enumerate(folds):
        # we'll train fold 0 first
        # if fold > 0:
        #     break

        logger.info(f'Training with {fold} started')
        logger.info('Found dataset with {} train sample, {} val sample'.format(len(trn_idx), len(val_idx)))
        # print(len(trn_idx), len(val_idx))

        train_loader, val_loader = prepare_dataloader(train, trn_idx, val_idx, 
                                                      data_root='./')
        
        device = torch.device(CFG['device'])
        
        model = PillClassifier(CFG['model_arch'], train.label.nunique(), pretrained=True).to(device)
        scaler = GradScaler()
        optimizer = torch.optim.Adam(model.parameters(), lr=CFG['lr'], weight_decay=CFG['weight_decay'])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=CFG['T_0'], T_mult=1, eta_min=CFG['min_lr'], last_epoch=-1)

        loss_tr = nn.CrossEntropyLoss().to(device)
        loss_fn = nn.CrossEntropyLoss().to(device)

        for epoch in range(CFG['epochs']):
            train_one_epoch(epoch, model, loss_tr, optimizer, train_loader, device, logger,
                            scheduler=scheduler, schd_batch_update=False)

            with torch.no_grad():
                valid_one_epoch(epoch, model, loss_fn, val_loader, device, logger,
                                scheduler=None, schd_loss_update=False)
        
            torch.save(model.state_dict(), f"weights/{CFG['model_arch']}_final/fold_{fold}_{epoch}.pth")

        del model, optimizer, train_loader, val_loader, scheduler, scaler
        torch.cuda.empty_cache()
        # if fold == 2:
        #     break

[2022-08-31 10:09:20,093 - urllib3.connectionpool - DEBUG] - Starting new HTTPS connection (1): raw.githubusercontent.com:443
[2022-08-31 10:09:20,252 - urllib3.connectionpool - DEBUG] - https://raw.githubusercontent.com:443 "GET /cuongngm/logult/why/config/logger_config.json HTTP/1.1" 200 372
[2022-08-31 10:09:20,270 - root - INFO] - Training with 0 started
[2022-08-31 10:09:20,271 - root - INFO] - Found dataset with 25056 train sample, 2785 val sample




[2022-08-31 10:09:21,457 - timm.models.helpers - INFO] - Loading pretrained weights from url (https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth)


epoch 0 loss: 0.8422: 100%|██████████| 3132/3132 [11:57<00:00,  4.37it/s]
epoch 0 loss: 0.4714: 100%|██████████| 175/175 [00:21<00:00,  8.08it/s]

validation multi-class accuracy = 0.8740
[2022-08-31 10:21:42,624 - root - INFO] - Validation multi-class accuracy: 0.873967684021544



epoch 1 loss: 0.5770: 100%|██████████| 3132/3132 [11:40<00:00,  4.47it/s]
epoch 1 loss: 0.3379: 100%|██████████| 175/175 [00:19<00:00,  8.99it/s]

validation multi-class accuracy = 0.9063
[2022-08-31 10:33:43,844 - root - INFO] - Validation multi-class accuracy: 0.9062836624775583



epoch 2 loss: 0.4869: 100%|██████████| 3132/3132 [12:07<00:00,  4.30it/s]
epoch 2 loss: 0.2668: 100%|██████████| 175/175 [00:19<00:00,  9.10it/s]

validation multi-class accuracy = 0.9189
[2022-08-31 10:46:12,052 - root - INFO] - Validation multi-class accuracy: 0.918850987432675



epoch 3 loss: 0.4544: 100%|██████████| 3132/3132 [11:43<00:00,  4.45it/s]
epoch 3 loss: 0.2257: 100%|██████████| 175/175 [00:19<00:00,  9.03it/s]

validation multi-class accuracy = 0.9300
[2022-08-31 10:58:15,804 - root - INFO] - Validation multi-class accuracy: 0.9299820466786356



epoch 4 loss: 0.4121: 100%|██████████| 3132/3132 [11:39<00:00,  4.48it/s]
epoch 4 loss: 0.2348: 100%|██████████| 175/175 [00:19<00:00,  8.99it/s]

validation multi-class accuracy = 0.9246
[2022-08-31 11:10:16,175 - root - INFO] - Validation multi-class accuracy: 0.9245960502692998



epoch 5 loss: 0.2816: 100%|██████████| 3132/3132 [12:03<00:00,  4.33it/s]
epoch 5 loss: 0.2026: 100%|██████████| 175/175 [00:19<00:00,  9.02it/s]

validation multi-class accuracy = 0.9364
[2022-08-31 11:22:40,172 - root - INFO] - Validation multi-class accuracy: 0.9364452423698384



epoch 6 loss: 0.2926: 100%|██████████| 3132/3132 [11:39<00:00,  4.48it/s]
epoch 6 loss: 0.1886: 100%|██████████| 175/175 [00:19<00:00,  9.02it/s]

validation multi-class accuracy = 0.9422
[2022-08-31 11:34:40,374 - root - INFO] - Validation multi-class accuracy: 0.9421903052064632



epoch 7 loss: 0.2222: 100%|██████████| 3132/3132 [11:57<00:00,  4.37it/s]
epoch 7 loss: 0.1809: 100%|██████████| 175/175 [00:19<00:00,  8.98it/s]

validation multi-class accuracy = 0.9454
[2022-08-31 11:46:58,312 - root - INFO] - Validation multi-class accuracy: 0.9454219030520646



epoch 8 loss: 0.2375: 100%|██████████| 3132/3132 [12:00<00:00,  4.35it/s]
epoch 8 loss: 0.1658: 100%|██████████| 175/175 [00:19<00:00,  9.00it/s]

validation multi-class accuracy = 0.9494
[2022-08-31 11:59:19,162 - root - INFO] - Validation multi-class accuracy: 0.9493716337522442



epoch 9 loss: 0.2573: 100%|██████████| 3132/3132 [11:45<00:00,  4.44it/s]
epoch 9 loss: 0.1535: 100%|██████████| 175/175 [00:19<00:00,  8.99it/s]

validation multi-class accuracy = 0.9508
[2022-08-31 12:11:25,467 - root - INFO] - Validation multi-class accuracy: 0.9508078994614003





[2022-08-31 12:11:26,123 - root - INFO] - Training with 1 started
[2022-08-31 12:11:26,124 - root - INFO] - Found dataset with 25057 train sample, 2784 val sample
[2022-08-31 12:11:27,373 - timm.models.helpers - INFO] - Loading pretrained weights from url (https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth)


epoch 0 loss: 0.8877: 100%|██████████| 3133/3133 [11:57<00:00,  4.36it/s]
epoch 0 loss: 0.4821: 100%|██████████| 174/174 [00:19<00:00,  9.01it/s]

validation multi-class accuracy = 0.8635
[2022-08-31 12:23:47,360 - root - INFO] - Validation multi-class accuracy: 0.8635057471264368



epoch 1 loss: 0.6066: 100%|██████████| 3133/3133 [11:43<00:00,  4.45it/s]
epoch 1 loss: 0.3555: 100%|██████████| 174/174 [00:19<00:00,  9.01it/s]

validation multi-class accuracy = 0.9030
[2022-08-31 12:35:51,743 - root - INFO] - Validation multi-class accuracy: 0.9030172413793104



epoch 2 loss: 0.5284: 100%|██████████| 3133/3133 [11:48<00:00,  4.42it/s]
epoch 2 loss: 0.2944: 100%|██████████| 174/174 [00:19<00:00,  9.02it/s]

validation multi-class accuracy = 0.9159
[2022-08-31 12:48:00,915 - root - INFO] - Validation multi-class accuracy: 0.915948275862069



epoch 3 loss: 0.5080: 100%|██████████| 3133/3133 [11:45<00:00,  4.44it/s]
epoch 3 loss: 0.2744: 100%|██████████| 174/174 [00:19<00:00,  9.00it/s]

validation multi-class accuracy = 0.9221
[2022-08-31 13:00:07,286 - root - INFO] - Validation multi-class accuracy: 0.9220545977011494



epoch 4 loss: 0.4009: 100%|██████████| 3133/3133 [11:58<00:00,  4.36it/s]
epoch 4 loss: 0.2595: 100%|██████████| 174/174 [00:19<00:00,  8.99it/s]

validation multi-class accuracy = 0.9253
[2022-08-31 13:12:26,894 - root - INFO] - Validation multi-class accuracy: 0.9252873563218391



epoch 5 loss: 0.3492: 100%|██████████| 3133/3133 [11:47<00:00,  4.43it/s]
epoch 5 loss: 0.2084: 100%|██████████| 174/174 [00:19<00:00,  9.03it/s]

validation multi-class accuracy = 0.9418
[2022-08-31 13:24:35,294 - root - INFO] - Validation multi-class accuracy: 0.9418103448275862



epoch 6 loss: 0.2582: 100%|██████████| 3133/3133 [11:50<00:00,  4.41it/s]
epoch 6 loss: 0.2174: 100%|██████████| 174/174 [00:20<00:00,  8.63it/s]

validation multi-class accuracy = 0.9411
[2022-08-31 13:36:47,397 - root - INFO] - Validation multi-class accuracy: 0.9410919540229885



epoch 7 loss: 0.2569: 100%|██████████| 3133/3133 [12:01<00:00,  4.34it/s]
epoch 7 loss: 0.1987: 100%|██████████| 174/174 [00:20<00:00,  8.69it/s]

validation multi-class accuracy = 0.9418
[2022-08-31 13:49:10,372 - root - INFO] - Validation multi-class accuracy: 0.9418103448275862



epoch 8 loss: 0.2871: 100%|██████████| 3133/3133 [11:35<00:00,  4.50it/s]
epoch 8 loss: 0.2105: 100%|██████████| 174/174 [00:19<00:00,  8.95it/s]

validation multi-class accuracy = 0.9418
[2022-08-31 14:01:06,797 - root - INFO] - Validation multi-class accuracy: 0.9418103448275862



epoch 9 loss: 0.2606: 100%|██████████| 3133/3133 [11:49<00:00,  4.41it/s]
epoch 9 loss: 0.1972: 100%|██████████| 174/174 [00:19<00:00,  8.96it/s]

validation multi-class accuracy = 0.9440
[2022-08-31 14:13:17,521 - root - INFO] - Validation multi-class accuracy: 0.9439655172413793





[2022-08-31 14:13:18,335 - root - INFO] - Training with 2 started
[2022-08-31 14:13:18,336 - root - INFO] - Found dataset with 25057 train sample, 2784 val sample
[2022-08-31 14:13:19,682 - timm.models.helpers - INFO] - Loading pretrained weights from url (https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth)


epoch 0 loss: 4.6852:   1%|          | 27/3133 [00:06<12:22,  4.19it/s]


KeyboardInterrupt: 