In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import cv2

from sklearn.model_selection import train_test_split

import torch 
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
import albumentations

from sklearn.metrics import classification_report
from sklearn.metrics import f1_score
from torchmetrics import F1
from pytorch_toolbelt import losses as L
import timm

from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.optim.lr_scheduler import StepLR, MultiStepLR, CosineAnnealingWarmRestarts, ReduceLROnPlateau
from pytorch_lightning.callbacks import ModelCheckpoint 
import torch.nn.functional as F
from torch.utils.data import WeightedRandomSampler

from pytorch_toolbelt import losses as L

gpu = 1 if torch.cuda.is_available() else 0
print(f'Using {gpu} GPUS')


import warnings
warnings.filterwarnings('ignore')

Using 1 GPUS


In [2]:
df_train = pd.read_csv('List_train.csv')
df_val = pd.read_csv('List_val.csv')
df_test = pd.read_csv('List_test.csv')

In [3]:
# From https://juansensio.com/blog/062_multihead_attention
class Dataset(torch.utils.data.Dataset):
    def __init__(self, mode, df):
        self.mode = mode
        self.df = df 
        self.mean_img = (0.485, 0.456, 0.406 )
        self.std_img = (0.229, 0.224, 0.225)
        self.classes = ['Pinus','Erica.m', 'Cistus sp', 'Lavandula', 'Citrus sp', 'Helianthus annuus',
                        'Eucalyptus sp.', 'Rosmarinus officinalis', 'Brassica', 'Cardus', 'Tilia', 'Taraxacum']
    def __crop_padding(self,img):
        ## convert to gray
        img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        ## set threshold for 0
        _,thresh = cv2.threshold(img_gray,10,255,cv2.THRESH_BINARY)
        ## find contours
        contours, hierarchy = cv2.findContours(thresh,cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)
        cnt = contours[0]
        x,y,w,h = cv2.boundingRect(cnt)
        crop = img[y:y+h,x:x+w,:]
        return crop
    def __getitem__(self, index):
        name_img = self.df['name'].iloc[index]
        label    = self.df['labels'].iloc[index]
        ## READ IMAGE
        image = plt.imread(name_img)
        image = self.__crop_padding(image)
        target = torch.tensor(self.classes.index(label))
        # print(f'Image shape: {image.shape} \t Target:{target}')
        if self.mode=='train':
            train_augm = albumentations.Compose(
              [
               albumentations.Resize(height=320,width=320),
               albumentations.Normalize(self.mean_img, self.std_img, max_pixel_value=255.0, always_apply=True),
              #  albumentations.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=15),
              #  albumentations.Flip(p=0.5)
              ]
            )
            transformed = train_augm(image=image)
            image=transformed['image']
        else:
            valid_augm = albumentations.Compose(
              [
               albumentations.Resize(height=320,width=320),
               albumentations.Normalize(self.mean_img, self.std_img, max_pixel_value=255.0, always_apply=True)
              ]
            )
            transformed = valid_augm(image=image)
            image=transformed['image']
        image = torch.from_numpy(image.transpose()).float()
        target_oh = torch.nn.functional.one_hot(target, num_classes=12).float()
        data = {"image":image,
                "target_oh":target_oh,
                'target':target,
                'class_name':label } 
        # print(f'Image shape: {image.shape} \t Target:{target}')
        return data
    def __len__(self):
        return len(self.df)

class HoneyDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 4, Dataset = Dataset):
        super().__init__()
        self.batch_size = batch_size
        self.Dataset = Dataset
        # self.train_ds =  self.Dataset(mode='train',df= df_train)
        # self.val_ds   =  self.Dataset(mode='val', df= df_val)
        self.test_ds   =  self.Dataset(mode='test', df= df_test)
    # def train_dataloader(self):
    #     return DataLoader(self.train_ds,
    #                       batch_size=4,
    #                       # shuffle=True,
    #                       num_workers=0,
    #                       pin_memory=True,
    #                       drop_last=True,
    #                       # sampler=sampler
    #                       )
    # def val_dataloader(self):
    #     return DataLoader(self.val_ds, batch_size=4, shuffle=False, num_workers=0, pin_memory=True, drop_last=True )
    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=0, pin_memory=True, drop_last=False )
    
dm = HoneyDataModule(Dataset=Dataset)

In [4]:
from torchmetrics import MatthewsCorrcoef as MCC
from torchmetrics import AUROC

In [5]:
val_epoch_loss_CE, val_epoch_acc_CE = [], []
train_epoch_loss_CE, train_epoch_acc_CE = [], []

class LitModel_Focal(pl.LightningModule):
    def __init__(self, model):
        super().__init__( )
        # self.save_hyperparameters()
        self.model = model
        # self.class_weights=class_weights.to('cuda')
        self.focal_loss = L.FocalLoss(alpha=0.25, gamma=2)
        # self.class_weights
        self.f1_score = F1(num_classes=12,average='weighted')
        self.mcc = MCC(num_classes=12)
        self.auroc = AUROC(num_classes=12,average='weighted')
        
        
    def forward(self, x):
        return self.model(x)
    def predict(self, x):
        with torch.no_grad():
            y_hat = self(x)
            return torch.argmax(y_hat, axis=1)
#     def compute_loss_and_metrics(self, batch):
#         x, y = batch['image'], batch['target']
#         # print(f'X: {x.shape} \t Y: {y.shape}')
#         y_hat = self(x)
#         # print(f'Output: {y_hat.shape}')
#         # loss = F.cross_entropy(y_hat, y, weight=self.class_weights)
#         loss = self.focal_loss(y_hat, y)
#         # acc = (torch.argmax(y_hat, axis=1) == y).sum().item() / y.shape[0]
#         # y1 = y.detach().cpu().numpy()
#         # # print(y1.shape)
#         mcc = self.mcc(y_hat, y)
#         auroc = self.auroc(y_hat, y)
#         y_hat1 = torch.argmax(y_hat, axis=1)
#         # y_hat1 = y_hat1.detach().cpu().numpy()
#         # print(y_hat1.shape)
#         f1w = self.f1_score( y, y_hat1)#, average='weighted')
#         return loss, f1w, mcc, auroc
#     def training_step(self, batch, batch_idx):
#         loss, f1w, mcc = self.compute_loss_and_metrics(batch)
#         self.log('train_loss', loss)
#         self.log('train_F1w', f1w, prog_bar=True)
#         self.log('train_mcc', mcc, prog_bar=True)
#         self.log('train_auroc', auroc, prog_bar=True)
#         #print(f'Training_step: loss> {loss} acc:{acc}')
#         return {'loss':loss,'f1w':torch.tensor(f1w), 'mcc':torch.tensor(mcc), 'auroc':torch.tensor(auroc)}
    
#     def training_epoch_end(self, outputs):
#         avg_train_loss = torch.stack([x['loss'] for x in outputs]).mean()
#         avg_train_f1w  = torch.stack([x['f1w'] for x in outputs]).mean()
#         avg_train_mcc  = torch.stack([x['mcc'] for x in outputs]).mean()
#         avg_train_auroc  = torch.stack([x['auroc'] for x in outputs]).mean()

#         train_epoch_loss_CE.append(avg_train_loss.item())
#         train_epoch_acc_CE.append(avg_train_f1w.item())
#         #print(f'Epoch {self.current_epoch} TrainLOSS:{avg_train_loss} TrainACC:{avg_train_acc}  ')
#     def validation_step(self, batch, batch_idx):
#         loss, f1w, mcc, auroc = self.compute_loss_and_metrics(batch)
#         self.log('val_loss', loss, prog_bar=True)
#         self.log('val_f1w', f1w, prog_bar=True)
#         self.log('val_mcc', mcc, prog_bar=True)
#         self.log('val_auroc', auroc, prog_bar=True)
        
#         return {'val_loss': torch.tensor(loss.item()), 'val_f1w': torch.tensor(f1w), 'val_mcc': torch.tensor(mcc), 'val_auroc': torch.tensor(auroc) }
#     def validation_epoch_end(self, outputs):
#         avg_val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
#         avg_val_f1w  = torch.stack([x['val_f1w'] for x in outputs]).mean()
#         avg_val_mcc  = torch.stack([x['val_mcc'] for x in outputs]).mean()
#         avg_val_auroc  = torch.stack([x['val_auroc'] for x in outputs]).mean()
        
#         self.log('EarlyStop_Log', avg_val_loss.detach(), on_epoch=True, sync_dist=True)
#         self.log('avg_val_f1w', avg_val_f1w.detach(), on_epoch=True, sync_dist=True)
#         self.log('avg_val_mcc', avg_val_mcc.detach(), on_epoch=True, sync_dist=True)
#         self.log('avg_val_auroc', avg_val_auroc.detach(), on_epoch=True, sync_dist=True)
        
#         val_epoch_loss_CE.append(avg_val_loss.item())
#         val_epoch_acc_CE.append(avg_val_f1w.item())
#         #print(f'VAL-Epoch {self.current_epoch} LOSS:{avg_val_loss} ACC:{avg_val_acc} ')
#     def configure_optimizers(self):
#         optimizer = torch.optim.Adam(self.parameters(), lr=0.0001)
#         lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 
#                                                                     T_0=10,
#                                                                     T_mult=1,
#                                                                     eta_min=1e-7,
#                                                                     verbose=True,
#                                                                     )

#         # lr_scheduler = {'scheduler': MultiStepLR(optimizer, milestones=[10,20,30,40], gamma=0.5,),'interval': 'epoch','frequency':1}
#         return [optimizer], [lr_scheduler]

## Probando EfficientNet

In [6]:
# timm.list_models(pretrained=True)
# eff_model = timm.create_model('tf_efficientnet_b7',pretrained='True',num_classes=12)
eff_model = timm.create_model('densenet121',pretrained='True',num_classes=12)


In [7]:
eff_model_focal_sampler  = LitModel_Focal(model=eff_model)


In [8]:
!ls /mnt/gpid08/datasets/remote_sensing/tmp_from_gpid07/honey/results/Honey_densenet121/

'Best-epoch=11-val_loss=0.04-avg_val_f1w=0.95.ckpt'   pngs    wandb
 ious						      preds


In [9]:
%%time
model_name = "Honey_densenet121/Best-epoch=11-val_loss=0.04-avg_val_f1w=0.95.ckpt"
baseline_model  = LitModel_Focal(model=eff_model)


checkpoint =  torch.load('/mnt/gpid08/datasets/remote_sensing/tmp_from_gpid07/honey/results/'+ model_name)

baseline_model.load_state_dict(checkpoint['state_dict'])

targets, preds = [],[]
baseline_model.to('cuda')
for ii, data in enumerate(dm.test_dataloader()):
    with torch.no_grad():
        targets.append(data['target'].numpy()) #torch.argmax(data['target'],dim=1).numpy()
        # y = y.reshape((-1,1))
        # print(y.shape)
        ## inference
        o = baseline_model.predict(data['image'].to('cuda')).cpu().numpy()
        # print(o.shape)
        preds.append(o)
        # print(y)
        # print(o)
        # if ii==2:
        #   break
print(ii)
preds2 = np.vstack([x for x in preds]).reshape(-1,1)
targets2 = np.vstack([x for x in targets]).reshape(-1,1)
target_names = ['Pinus','Erica.m', 'Cistus sp', 'Lavandula', 'Citrus sp', 'Helianthus annuus',
          'Eucalyptus sp.', 'Rosmarinus officinalis', 'Brassica', 'Cardus', 'Tilia', 'Taraxacum']

print(classification_report(targets2, preds2, target_names=target_names))


896
                        precision    recall  f1-score   support

                 Pinus       0.00      0.00      0.00        28
               Erica.m       0.87      0.32      0.47       184
             Cistus sp       0.10      0.72      0.18        69
             Lavandula       0.00      0.00      0.00        74
             Citrus sp       0.00      0.00      0.00        53
     Helianthus annuus       0.99      0.80      0.88        91
        Eucalyptus sp.       0.00      0.00      0.00        95
Rosmarinus officinalis       0.00      0.00      0.00        52
              Brassica       0.32      0.48      0.38       140
                Cardus       0.12      0.32      0.17        19
                 Tilia       0.00      0.00      0.00        67
             Taraxacum       0.40      0.08      0.13        25

              accuracy                           0.29       897
             macro avg       0.23      0.23      0.18       897
          weighted avg       0.35 