In [1]:
import os
import cv2
import time
import timm
import torch
import sklearn.metrics

from PIL import Image

import numpy as np
import pandas as pd
import torch.nn as nn

from torch.optim import SGD, lr_scheduler
from torch.utils.data import DataLoader, Dataset

# Expose only physical GPU 1 to this process (it is then addressed as cuda:0).
# This must run before CUDA is initialised; PyTorch initialises CUDA lazily,
# so setting it after `import torch` but before any CUDA call still takes effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [2]:
# Full training metadata (all rows, with a 'subset' column such as 'train'/'val')
# and the reduced "min-train" table that this notebook trains on.
metadata, min_train_metadata = (
    pd.read_csv(csv_path)
    for csv_path in (
        "SnakeCLEF2021_train_metadata_PROD.csv",
        "SnakeCLEF2021_min-train_metadata_PROD.csv",
    )
)

print(len(metadata), len(min_train_metadata))

386006 70208


In [3]:
metadata.head(n=5)

Unnamed: 0,binomial,country,continent,genus,family,UUID,source,subset,class_id,image_path
0,Pantherophis spiloides,United States of America,North America,Pantherophis,Colubridae,fbc816e9552643a2bce4f655b2f3c4e1,inaturalist,train,523,/Datasets/SnakeCLEF-2021/inaturalist/fbc816e95...
1,Masticophis taeniatus,United States of America,North America,Masticophis,Colubridae,cbc7ad7141a642f2b92ef7fe05c9d608,inaturalist,train,430,/Datasets/SnakeCLEF-2021/inaturalist/cbc7ad714...
2,Crotalus pyrrhus,United States of America,North America,Crotalus,Viperidae,fc4db72953ae4c978ac50acb33adce0c,inaturalist,train,183,/Datasets/SnakeCLEF-2021/inaturalist/fc4db7295...
3,Haldea striatula,United States of America,North America,Haldea,Colubridae,2068c79c956d43dc8a45106e0c808aed,inaturalist,train,305,/Datasets/SnakeCLEF-2021/inaturalist/2068c79c9...
4,Natrix natrix,Russia,Europe,Natrix,Colubridae,3e376aaf4f8d42e991c0c8ddc5972f95,inaturalist,train,471,/Datasets/SnakeCLEF-2021/inaturalist/3e376aaf4...


In [4]:
# Train on the reduced min-train table; validate on the rows tagged 'val'.
train_metadata = min_train_metadata
val_metadata = metadata.loc[metadata['subset'] == 'val']

print(len(train_metadata), len(val_metadata))
# Number of distinct species (binomial names) in the training table.
min_train_metadata['binomial'].nunique(dropna=False)

70208 38601


768

In [5]:
# Width of the classifier head. CrossEntropyLoss expects labels in
# [0, N_CLASSES), so this must exceed the largest `class_id` in the data.
# NOTE(review): the min-train table has only 768 distinct binomials, so
# presumably some of these logits never see a positive training example.
N_CLASSES = 772

class TrainDataset(Dataset):
    """Image/label dataset backed by a metadata DataFrame.

    Each row must provide an ``image_path`` readable by OpenCV and a
    ``class_id`` label. Images are returned in RGB channel order and, if a
    ``transform`` is given, passed through it albumentations-style: it is
    called as ``transform(image=...)`` and must return a dict with an
    ``'image'`` key.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        ``idx`` is positional (via ``.values``), not a DataFrame index label.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image at ``image_path``.
        """
        file_path = self.df['image_path'].values[idx]
        label = self.df['class_id'].values[idx]

        image = cv2.imread(file_path)
        if image is None:
            # cv2.imread reports failure by returning None rather than raising;
            # without this check the error surfaces later as an opaque
            # cvtColor assertion that does not name the offending file.
            raise FileNotFoundError(f"Could not read image: {file_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image=image)['image']

        return image, label

In [6]:
from efficientnet_pytorch import EfficientNet
# Pretrained EfficientNet-B0 backbone from efficientnet_pytorch.
model = EfficientNet.from_pretrained('efficientnet-b0')
# Replace the stock classifier with a freshly initialised linear head that
# outputs one logit per snake class.
model._fc = nn.Linear(in_features=model._fc.in_features, out_features=N_CLASSES)

Loaded pretrained weights for efficientnet-b0


In [7]:
# Network input resolution in pixels, shared by the train and valid pipelines.
HEIGHT, WIDTH = 224, 224

from albumentations import Compose, Normalize, Resize, HorizontalFlip, VerticalFlip
from albumentations.pytorch import ToTensorV2
from albumentations import RandomCrop, HorizontalFlip, VerticalFlip, RandomBrightnessContrast, CenterCrop, PadIfNeeded, RandomResizedCrop

def get_transforms(*, data):
    """Build the albumentations preprocessing pipeline for one data split.

    Args:
        data: ``'train'`` for the augmenting pipeline (random resized crop,
            flips, brightness/contrast jitter) or ``'valid'`` for the
            deterministic resize-only pipeline.

    Returns:
        A ``Compose`` that is called as ``pipeline(image=...)`` and returns a
        dict whose ``'image'`` entry is the normalised tensor from ``ToTensorV2``.

    Raises:
        AssertionError: if ``data`` is neither ``'train'`` nor ``'valid'``.
    """
    assert data in ('train', 'valid')

    # Standard ImageNet channel mean/std, shared by both pipelines.
    normalize = Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    if data == 'train':
        return Compose([
            # albumentations' positional order is (height, width).
            RandomResizedCrop(HEIGHT, WIDTH, scale=(0.8, 1.0)),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            RandomBrightnessContrast(p=0.2),
            normalize,
            ToTensorV2(),
        ])

    # NOTE(review): a plain Resize does not preserve aspect ratio, whereas the
    # train pipeline crops; consider resize-shorter-side + CenterCrop instead.
    return Compose([
        Resize(HEIGHT, WIDTH),
        normalize,
        ToTensorV2(),
    ])

In [8]:
# One dataset class for both splits; only the preprocessing pipeline differs.
train_dataset, valid_dataset = (
    TrainDataset(frame, transform=get_transforms(data=split))
    for frame, split in ((train_metadata, 'train'), (val_metadata, 'valid'))
)

In [9]:
BATCH_SIZE, EPOCHS, WORKERS = 64, 50, 8

# Only the training stream is shuffled; validation is read in dataset order.
train_loader = DataLoader(
    train_dataset,
    num_workers=WORKERS,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
valid_loader = DataLoader(
    valid_dataset,
    num_workers=WORKERS,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [10]:
from sklearn.metrics import f1_score, accuracy_score
import tqdm


n_epochs = EPOCHS
lr = 0.01

optimizer = SGD(model.parameters(), lr=lr, momentum=0.9)
# NOTE(review): StepLR multiplies the LR by gamma=0.1 (PyTorch's default, made
# explicit here) every 5 epochs: 1e-3 after epoch 5, 1e-5 after epoch 15 and
# ~1e-12 by epoch 50. The logged losses stop improving after roughly epoch 12,
# so most of the 50 epochs do no useful work; consider fewer epochs, a larger
# step_size, or a cosine schedule.
scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
criterion = nn.CrossEntropyLoss()

model.to(device)


def _train_one_epoch(net, loader, optimizer, criterion, device):
    """Run one optimisation pass over `loader`; return the mean per-batch loss."""
    net.train()
    total_loss = 0.
    # Iterating the loader directly (not enumerate(loader)) lets tqdm read
    # len(loader) and show a proper progress bar with a total.
    for images, labels in tqdm.tqdm(loader, desc='train'):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = criterion(net(images), labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(loader)


def _evaluate(net, loader, criterion, device):
    """Evaluate `net` on `loader` with gradients disabled.

    Returns:
        (mean loss per sample, predicted class ids, true class ids) — the two
        arrays are int numpy arrays in loader order, so predictions and targets
        always line up regardless of batch size or shuffling.
    """
    net.eval()
    loss_sum = 0.
    n_samples = 0
    pred_batches, target_batches = [], []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            logits = net(images)
            batch_n = labels.size(0)
            # Assumes criterion uses reduction='mean' (the nn.CrossEntropyLoss
            # default); re-weighting by batch size keeps a short final batch
            # from skewing the epoch average.
            loss_sum += criterion(logits, labels).item() * batch_n
            n_samples += batch_n
            pred_batches.append(logits.argmax(1).cpu().numpy())
            target_batches.append(labels.cpu().numpy())
    return loss_sum / n_samples, np.concatenate(pred_batches), np.concatenate(target_batches)


for epoch in range(n_epochs):
    start_time = time.time()

    avg_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device)
    avg_val_loss, preds, targets = _evaluate(model, valid_loader, criterion, device)

    scheduler.step()

    score = f1_score(targets, preds, average='macro')
    accuracy = accuracy_score(targets, preds)

    elapsed = time.time() - start_time
    print(f'  Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f} F1: {score:.6f}  Accuracy: {accuracy:.6f} time: {elapsed:.0f}s')


1097it [03:02,  6.00it/s]


  Epoch 1 - avg_train_loss: 5.0426  avg_val_loss: 3.7397 F1: 0.065554  Accuracy: 0.217274 time: 255s


1097it [03:21,  5.45it/s]


  Epoch 2 - avg_train_loss: 3.6295  avg_val_loss: 3.1958 F1: 0.131081  Accuracy: 0.283361 time: 269s


1097it [03:29,  5.23it/s]


  Epoch 3 - avg_train_loss: 3.1122  avg_val_loss: 3.0451 F1: 0.166816  Accuracy: 0.309111 time: 278s


1097it [03:36,  5.07it/s]


  Epoch 4 - avg_train_loss: 2.7885  avg_val_loss: 2.8895 F1: 0.198979  Accuracy: 0.338566 time: 291s


550it [01:55,  4.98it/s]

  Epoch 5 - avg_train_loss: 2.5378  avg_val_loss: 2.8681 F1: 0.232379  Accuracy: 0.341546 time: 285s


1097it [03:35,  5.09it/s]


  Epoch 6 - avg_train_loss: 2.1108  avg_val_loss: 2.4461 F1: 0.298629  Accuracy: 0.425611 time: 284s


1097it [03:39,  5.01it/s]


  Epoch 7 - avg_train_loss: 2.0170  avg_val_loss: 2.4184 F1: 0.302618  Accuracy: 0.431077 time: 287s


1097it [03:41,  4.96it/s]


  Epoch 8 - avg_train_loss: 1.9617  avg_val_loss: 2.4027 F1: 0.309487  Accuracy: 0.432087 time: 295s


1097it [03:39,  4.99it/s]


  Epoch 9 - avg_train_loss: 1.9289  avg_val_loss: 2.3971 F1: 0.314581  Accuracy: 0.435792 time: 289s


1097it [03:44,  4.88it/s]


  Epoch 10 - avg_train_loss: 1.8914  avg_val_loss: 2.3800 F1: 0.321931  Accuracy: 0.441051 time: 293s


1097it [03:44,  4.89it/s]


  Epoch 11 - avg_train_loss: 1.8325  avg_val_loss: 2.3572 F1: 0.326008  Accuracy: 0.444237 time: 299s


1097it [03:38,  5.03it/s]


  Epoch 12 - avg_train_loss: 1.8279  avg_val_loss: 2.3493 F1: 0.324886  Accuracy: 0.445817 time: 288s


1097it [03:45,  4.87it/s]


  Epoch 13 - avg_train_loss: 1.8217  avg_val_loss: 2.3567 F1: 0.326244  Accuracy: 0.444678 time: 294s


1097it [03:41,  4.94it/s]


  Epoch 14 - avg_train_loss: 1.8170  avg_val_loss: 2.3503 F1: 0.326443  Accuracy: 0.445921 time: 291s


1097it [03:40,  4.99it/s]


  Epoch 15 - avg_train_loss: 1.8121  avg_val_loss: 2.3465 F1: 0.327316  Accuracy: 0.447165 time: 295s


1097it [03:42,  4.93it/s]


  Epoch 16 - avg_train_loss: 1.8058  avg_val_loss: 2.3502 F1: 0.325449  Accuracy: 0.446724 time: 291s


1097it [03:43,  4.91it/s]


  Epoch 17 - avg_train_loss: 1.8077  avg_val_loss: 2.3510 F1: 0.327032  Accuracy: 0.446465 time: 292s


1097it [03:41,  4.96it/s]


  Epoch 18 - avg_train_loss: 1.8109  avg_val_loss: 2.3489 F1: 0.326987  Accuracy: 0.446595 time: 290s


1097it [03:44,  4.89it/s]


  Epoch 19 - avg_train_loss: 1.8035  avg_val_loss: 2.3484 F1: 0.327835  Accuracy: 0.447009 time: 299s


1097it [03:42,  4.94it/s]


  Epoch 20 - avg_train_loss: 1.8024  avg_val_loss: 2.3443 F1: 0.328095  Accuracy: 0.447553 time: 291s


1097it [03:42,  4.94it/s]


  Epoch 21 - avg_train_loss: 1.8020  avg_val_loss: 2.3456 F1: 0.328408  Accuracy: 0.447812 time: 290s


1097it [03:44,  4.89it/s]


  Epoch 22 - avg_train_loss: 1.8093  avg_val_loss: 2.3476 F1: 0.328879  Accuracy: 0.447450 time: 296s


1097it [03:36,  5.06it/s]


  Epoch 23 - avg_train_loss: 1.8013  avg_val_loss: 2.3482 F1: 0.326461  Accuracy: 0.446517 time: 286s


1097it [03:40,  4.98it/s]


  Epoch 24 - avg_train_loss: 1.8021  avg_val_loss: 2.3468 F1: 0.327844  Accuracy: 0.447683 time: 289s


1097it [03:39,  4.99it/s]


  Epoch 25 - avg_train_loss: 1.8026  avg_val_loss: 2.3484 F1: 0.327776  Accuracy: 0.446983 time: 289s


1097it [03:39,  4.99it/s]


  Epoch 26 - avg_train_loss: 1.8059  avg_val_loss: 2.3476 F1: 0.327871  Accuracy: 0.447035 time: 295s


1097it [03:36,  5.07it/s]


  Epoch 27 - avg_train_loss: 1.8098  avg_val_loss: 2.3472 F1: 0.328543  Accuracy: 0.447709 time: 285s


1097it [03:41,  4.95it/s]


  Epoch 28 - avg_train_loss: 1.8032  avg_val_loss: 2.3512 F1: 0.327352  Accuracy: 0.445792 time: 290s


1097it [03:41,  4.95it/s]


  Epoch 29 - avg_train_loss: 1.8064  avg_val_loss: 2.3470 F1: 0.327592  Accuracy: 0.447035 time: 290s


1097it [03:38,  5.02it/s]


  Epoch 30 - avg_train_loss: 1.8023  avg_val_loss: 2.3472 F1: 0.326482  Accuracy: 0.447242 time: 293s


1097it [03:37,  5.05it/s]


  Epoch 31 - avg_train_loss: 1.8019  avg_val_loss: 2.3471 F1: 0.327138  Accuracy: 0.447268 time: 285s


1097it [03:36,  5.07it/s]


  Epoch 32 - avg_train_loss: 1.8041  avg_val_loss: 2.3468 F1: 0.328199  Accuracy: 0.447501 time: 285s


1097it [03:37,  5.05it/s]


  Epoch 33 - avg_train_loss: 1.8066  avg_val_loss: 2.3493 F1: 0.327517  Accuracy: 0.446595 time: 286s


1097it [03:36,  5.07it/s]


  Epoch 34 - avg_train_loss: 1.8048  avg_val_loss: 2.3482 F1: 0.328058  Accuracy: 0.446854 time: 288s


1097it [03:37,  5.05it/s]


  Epoch 35 - avg_train_loss: 1.8060  avg_val_loss: 2.3486 F1: 0.326996  Accuracy: 0.447294 time: 286s


1097it [03:34,  5.10it/s]


  Epoch 36 - avg_train_loss: 1.8068  avg_val_loss: 2.3474 F1: 0.327745  Accuracy: 0.446880 time: 284s


1097it [03:33,  5.14it/s]


  Epoch 37 - avg_train_loss: 1.8012  avg_val_loss: 2.3473 F1: 0.328329  Accuracy: 0.447165 time: 287s


1097it [03:29,  5.23it/s]


  Epoch 38 - avg_train_loss: 1.8051  avg_val_loss: 2.3523 F1: 0.327051  Accuracy: 0.445999 time: 279s


1097it [03:38,  5.03it/s]


  Epoch 39 - avg_train_loss: 1.8071  avg_val_loss: 2.3478 F1: 0.327381  Accuracy: 0.446569 time: 287s


1097it [03:38,  5.02it/s]


  Epoch 40 - avg_train_loss: 1.8015  avg_val_loss: 2.3475 F1: 0.328095  Accuracy: 0.447501 time: 287s


1097it [03:37,  5.04it/s]


  Epoch 41 - avg_train_loss: 1.8068  avg_val_loss: 2.3468 F1: 0.328143  Accuracy: 0.447165 time: 293s


1097it [03:36,  5.06it/s]


  Epoch 42 - avg_train_loss: 1.8002  avg_val_loss: 2.3453 F1: 0.328042  Accuracy: 0.447709 time: 285s


1097it [03:32,  5.17it/s]


  Epoch 43 - avg_train_loss: 1.8055  avg_val_loss: 2.3464 F1: 0.327669  Accuracy: 0.446906 time: 281s


1097it [03:32,  5.15it/s]


  Epoch 44 - avg_train_loss: 1.8077  avg_val_loss: 2.3452 F1: 0.328458  Accuracy: 0.447683 time: 281s


1097it [03:37,  5.05it/s]


  Epoch 45 - avg_train_loss: 1.8027  avg_val_loss: 2.3459 F1: 0.326877  Accuracy: 0.447113 time: 292s


1097it [03:27,  5.28it/s]


  Epoch 46 - avg_train_loss: 1.8028  avg_val_loss: 2.3508 F1: 0.326720  Accuracy: 0.446880 time: 276s


1097it [03:32,  5.15it/s]


  Epoch 47 - avg_train_loss: 1.8058  avg_val_loss: 2.3500 F1: 0.326431  Accuracy: 0.447139 time: 281s


1097it [03:36,  5.06it/s]


  Epoch 48 - avg_train_loss: 1.8010  avg_val_loss: 2.3456 F1: 0.327518  Accuracy: 0.447294 time: 285s


1097it [03:35,  5.10it/s]


  Epoch 49 - avg_train_loss: 1.8064  avg_val_loss: 2.3496 F1: 0.327641  Accuracy: 0.446310 time: 290s


1097it [03:36,  5.07it/s]


  Epoch 50 - avg_train_loss: 1.8090  avg_val_loss: 2.3463 F1: 0.327178  Accuracy: 0.446543 time: 284s


In [11]:
# Derive the checkpoint name from the run configuration so it cannot drift
# from the actual input size / epoch count; with HEIGHT=224 and EPOCHS=50 this
# is 'SnakeCLEF2021-EfficientNet-B0_224-50E.pth'.
torch.save(model.state_dict(), f'SnakeCLEF2021-EfficientNet-B0_{HEIGHT}-{EPOCHS}E.pth')