### CNN training

In [None]:
import os
import numpy as np
import pandas as pd

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.metrics import precision_recall_fscore_support as prf

In [None]:
from PIL import Image, ExifTags, ImageOps, ImageDraw
from src import bbox2tlbr, sqrbbox, compute_IoU

In [None]:
import torch
from torchvision import transforms

### 1. Dataset & dataloader

In [None]:
from torch.utils.data import Dataset, DataLoader

In [None]:
import warnings

# Data location and label set.
# NOTE(review): this image root and the CSV loaded further down both point at the
# *test* split, and the notebook trains and evaluates on that same data — confirm
# this is intended.
_imgRoot = '../projects/ma24/data/test/images/'
# class names; a sample's integer target is its position in this list ('??' is its own class)
_classes = ['aegypti', 'albopictus', 'anopheles', 'culex', 'culiseta', 'japonicus-koreicus', '??']
_imgSize = 512  # side (px) of the square network input; pretrained at 384
# per-channel (mean, std) passed to transforms.Normalize
# NOTE(review): timm's tf_* checkpoints may expect different statistics — compare
# with model.pretrained_cfg['mean'] / ['std'].
_imgNorm = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Control image max-size: 201326592 = 192 * 1024 * 1024 pixels. Pillow emits a
# DecompressionBombWarning above this limit; promoting it to an error makes
# oversized files fail loudly instead of being decoded.
Image.MAX_IMAGE_PIXELS = 201326592
warnings.simplefilter('error', Image.DecompressionBombWarning)

In [None]:
class maDataset(Dataset):
    """Dataset of bounding-box crops described by a CSV annotation file.

    Each CSV row names an image file (``img_fName``, relative to ``_imgRoot``),
    a box (``bbx_xtl``, ``bbx_ytl``, ``bbx_xbr``, ``bbx_ybr``) and a class name
    (``class_label``, which must be one of ``_classes``).

    Items are dicts with:
        'img_fName': the row's file name,
        'image':     the cropped region resized to (_imgSize, _imgSize), converted
                     to a normalized 3-channel float tensor,
        'label':     a one-element list holding the class index in ``_classes``.
    """

    def __init__(self, csvDataFile):
        """Read the annotation table from *csvDataFile* and build the image transform."""
        self.df = pd.read_csv(csvDataFile)
        self.transform = transforms.Compose([
            transforms.Resize((_imgSize, _imgSize)),
            transforms.ToTensor(),
            transforms.Normalize(_imgNorm[0], _imgNorm[1]),
        ])

    def __len__(self):
        """Return the number of annotation rows (one sample per row)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, crop and transform the sample at position *idx*.

        Raises:
            ValueError: if the row's ``class_label`` is not in ``_classes``.
        """
        row = self.df.iloc[idx]
        imgPath = os.path.join(_imgRoot, row.img_fName)
        # the context manager guarantees the file handle is released (Pillow keeps
        # it open after loading for multi-frame formats)
        with Image.open(imgPath) as srcImg:
            # sqrbbox is a project helper; presumably it squares the box and clamps
            # it to the image size — TODO confirm against src
            bbox = sqrbbox([(row.bbx_xtl, row.bbx_ytl), (row.bbx_xbr, row.bbx_ybr)], srcImg.size)
            # force 3 channels: grayscale / RGBA / palette / CMYK sources would
            # otherwise not match the 3-channel Normalize and could not be batched
            pilImg = srcImg.crop(bbox).convert('RGB')
        # NOTE(review): EXIF orientation is not applied; confirm the box coordinates
        # refer to the stored (un-rotated) pixel grid
        torchImg = self.transform(pilImg)

        return {'img_fName': row.img_fName, 'image': torchImg, 'label': [_classes.index(row.class_label)]}
 

In [None]:
# Build the dataset from the annotation CSV (test split) and show its sample count.
csvDataFile = '../projects/ma24/data/test/phase2_test.csv'
_maDataset = maDataset(csvDataFile)
len(_maDataset)

In [None]:
# Wrap the dataset in a shuffling loader backed by worker processes.
batch_size = 4
num_workers = 8
_dataLoader = DataLoader(
    _maDataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
)

### 2. Backbone

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    _device = torch.device('cuda:0')
else:
    _device = torch.device('cpu')

In [None]:
import timm

In [None]:
# Create the backbone: timm's tf_efficientnetv2_s with its default pretrained
# weights, a classification head with len(_classes) outputs and average pooling.
model = timm.create_model('tf_efficientnetv2_s', num_classes=len(_classes), global_pool='avg', pretrained=True)

In [None]:
# Move parameters and buffers to the chosen device (Module.to returns the module itself).
model = model.to(device=_device)

### 3. Backpropagation

#### loss function

In [None]:
# Multi-class cross-entropy on raw logits, averaged over the batch.
loss_function = torch.nn.CrossEntropyLoss(weight=None, reduction='mean')

In [None]:
# echo the loss module so its configuration shows in the cell output
loss_function

#### optimizer

In [None]:
from torch import optim

In [None]:
# Adam over every model parameter, starting from a fixed learning rate.
lRate = 1e-4
optimizer = optim.Adam(params=model.parameters(), lr=lRate)

In [None]:
# echo the optimizer so its hyper-parameters show in the cell output
optimizer

#### scheduler

In [None]:
# Multiply the learning rate by `gamma` on every scheduler.step() call (step_size = 1).
gamma = 0.995
scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=gamma)

In [None]:
# echo the scheduler object in the cell output
scheduler

### 4. Training

In [None]:
import time

In [None]:
%%time
# Train the model for `max_epochs` passes over `_dataLoader`.
# Running loss / accuracy are reset at the start of each epoch, so every printed
# line describes that epoch only. The LR scheduler is stepped once per epoch.
try:

    model.train()
    torch.set_grad_enabled(True)

    max_epochs = 2
    n_samples = len(_maDataset)
    for epoch in range(max_epochs):

        start_time = time.time()
        train_loss, train_match = 0.0, 0
        for batch in _dataLoader:

            # +++ forward pass
            optimizer.zero_grad()
            inputs = batch['image'].to(_device)
            output = model(inputs)

            # +++ loss
            # default collate turns each sample's one-element label list into a
            # list holding one (batch,) tensor; concatenating yields the targets
            labels = torch.cat(tuple(batch['label']), dim = 0).to(_device)
            batch_loss = loss_function(output, labels)

            # +++ backpropagation
            batch_loss.backward()
            optimizer.step()

            # +++ evaluation
            # argmax of the logits equals argmax of their softmax
            preds = output.argmax(dim = 1)
            # the loss is a batch mean, so re-weight it by the batch size
            train_loss += batch_loss.item() * inputs.shape[0]
            train_match += (preds == labels).sum().item()

        # decay the learning rate once per epoch
        scheduler.step()

        print('+++ epoch {:3d}, {:6.4f}s Train- Loss: {:.4f} Acc: {:.4f}'.format(
            epoch, time.time() - start_time, train_loss / n_samples, train_match / n_samples))

# broad catch presumably so that interrupting the kernel stops training quietly
except BaseException as err:
    print(f"+++ train() {type(err).__name__}, {err}")


### 5. Evaluate

In [None]:
# Predict every sample of the dataset and collect one row per sample:
# its file name and the predicted class name from _classes. Predictions are
# recomputed here rather than taken from the training loop, whose per-batch
# `preds` tensor carries no file names and covers only one batch.
# NOTE(review): if an image file appears in several CSV rows, the merge on
# 'img_fName' below pairs them all-to-all — confirm img_fName is unique.
model.eval()
records = []
with torch.no_grad():
    for batch in DataLoader(_maDataset, batch_size = batch_size, shuffle = False, num_workers = num_workers):
        logits = model(batch['image'].to(_device))
        pred_idx = logits.argmax(dim = 1).cpu().tolist()
        records.extend(
            {'img_fName': fname, 'pred_label': _classes[k]}
            for fname, k in zip(batch['img_fName'], pred_idx)
        )
df_pred = pd.DataFrame(records, columns = ['img_fName', 'pred_label'])
df_pred.head()

In [None]:
# Attach each annotation row to its prediction; rows without a match are dropped.
df_ = _maDataset.df.merge(df_pred, on='img_fName', how='inner')
df_.head()

In [None]:
# For each true class, count how often each label was predicted.
df_.groupby('class_label')['pred_label'].value_counts()

#### Classification

In [None]:
# Three confusion matrices: raw counts, normalized over true labels (rows; the
# diagonal is recall) and normalized over predictions (columns; the diagonal is
# precision). `labels` pins the class order so it always matches
# `display_labels`, even when a class is absent from both truth and predictions.
_, axs = plt.subplots(1, 3, figsize = (15, 7), sharey = True)
for ax, norm in zip(axs, [None, 'true', 'pred']):
    ConfusionMatrixDisplay.from_predictions(
        df_.class_label,
        df_.pred_label,
        labels = sorted(_classes),
        normalize = norm,
        ax = ax,
        # abbreviations in sorted(_classes) order ('??' sorts before lowercase names)
        display_labels = ['??', 'aeg', 'alb', 'ano', 'clx', 'cul', 'j/k'],
        cmap = 'GnBu',
        colorbar = False
    )
plt.tight_layout()

In [None]:
# Precision / recall / F-score under three averaging schemes (undefined ratios count as 0).
avrgs = ['macro', 'micro', 'weighted']
pd.DataFrame(
    [prf(df_.class_label, df_.pred_label, average=avg, zero_division=0)[:3] for avg in avrgs],
    index=avrgs,
    columns=['precision', 'recall', 'f-score'],
)

#### check predictions

In [None]:
# reset the row cursor; the next cell increments it before use, so it starts at row 0
i = -1

In [None]:
# Advance to the next row of the merged frame and show its image with the
# annotated bounding box outlined in blue.
i += 1
row = df_.iloc[i]
pilImg = Image.open(f'{_imgRoot}/{row.img_fName}')
imgdrw = ImageDraw.Draw(pilImg)
imgdrw.rectangle(
    [(row.bbx_xtl, row.bbx_ytl), (row.bbx_xbr, row.bbx_ybr)],
    outline='blue',
    width=2,
)
plt.imshow(pilImg)
plt.axis('off');
print(f'+++{i:3d} {row.img_fName} - {row.class_label} / {row.pred_label}')

In [None]:
# Rows predicted as the '??' class, and how many there are.
chk = df_.loc[df_['pred_label'] == '??']
chk.shape[0]

In [None]:
# reset the cursor into `chk`; the next cell increments it before use, so it starts at row 0
j = -1

In [None]:
# Step to the next sample predicted as '??' and show its image with the
# annotated bounding box outlined in blue.
j += 1
row = chk.iloc[j]
pilImg = Image.open('%s/%s' %(_imgRoot, row.img_fName))
imgdrw = ImageDraw.Draw(pilImg)
imgdrw.rectangle([(row.bbx_xtl, row.bbx_ytl), (row.bbx_xbr, row.bbx_ybr)], outline = 'blue', width = 8)
plt.imshow(pilImg)
plt.axis('off');
# report this cell's own cursor `j` (the position in `chk`)
print('+++%3d - %s / %s' %(j, row.class_label, row.pred_label))