### Libraries

In [10]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transform
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import os

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


import os
import math
import pandas as pd
import numpy as np
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch.nn as nn

# import src.model.trainer as trainer
# import src.model.config as config
# import src.utils.other_utils as other_utils
# import src.data.make_dataset as data_funcs

### Trainer

In [11]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
    # model.load_state_dict(checkpoint)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

In [12]:
import torch
import torch.nn as nn


class Trainer():

  def __init__(self, model, criterion = None, optimizer = None, device = "cpu"):
    
    self.model = model
    self.criterion = criterion
    self.optimizer = optimizer
    self.device = device
    

  def train_step(self, train_dataloader):
    
    # Set model to training mode
    self.model.train()

    epoch_loss, epoch_accuracy = 0, 0

    for batch, (X,y) in enumerate(train_dataloader):
      X = X.to(self.device)
      y = y.to(self.device)

      # Reset Gradients
      self.optimizer.zero_grad()

      # Prediction
      out = self.model(X)

      # Calculation Loss
      loss = self.criterion(out, y)

      # Calculating Gradients
      loss.backward()

      # Update Weights
      self.optimizer.step()

      # Calculating Performance Metrics
      epoch_loss += loss.detach().item() / X.shape[0]
      #epoch_accuracy += (torch.argmax(torch.softmax(out[:,0,:], dim=1), dim=1) == y).sum() / X.shape[0]

    return epoch_loss#, epoch_accuracy


  def eval_step(self, val_dataloader):
    
    # Set model to training mode
    self.model.eval()

    epoch_loss, epoch_accuracy = 0, 0
    y_trues, y_probs = [], []
    #epoch_accuracy = 0
    with torch.inference_mode():
      for batch, (X,y) in enumerate(val_dataloader):
        X = X.to(self.device)
        y = y.to(self.device)

        # Prediction
        out = self.model(X)
        # Calculation Loss
        loss = self.criterion(out, y)

        # Calculating Performance Metrics
        epoch_loss += loss.item() / X.shape[0]
        #epoch_accuracy += (torch.argmax(torch.softmax(out[:,0,:], dim=1), dim=1) == y).sum() / X.shape[0]

    return epoch_loss#, epoch_accuracy

  def predict_step(self, val_dataloader):

    self.model.eval()

    y_preds = []
    y_probs = []

    with torch.inference_mode():
      for batch, (X,y) in enumerate(val_dataloader):
        X = X.to(self.device)
        y = y.to(self.device)

        # Prediction
        out = self.model(X)

        # Calculating Performance Metrics
        y_prob = torch.softmax(out[:,0,:], dim=1)
        y_pred = torch.argmax(torch.softmax(out[:,0,:], dim=1), dim=1)

        y_probs = [y_prob[index, y_pred[index]].item() for index in range(len(y_pred))]
        y_preds.extend(y_pred)

    return y_probs, y_preds


  def train(self, train_dataloader, val_dataloader, num_epochs = 5):
    xx = False
    if xx:
        load_checkpoint(
            "/Users/mpekey/Desktop/Mert_SabanciUniv/CS518/HoneyPollenClassification/model_checkpoints/model_ch.pth",
            self.model,
            self.optimizer,
            1e-4,
        )

    for epoch in range(num_epochs):

      train_loss, train_accuracy = self.train_step(train_dataloader)
      val_loss, val_accuracy = self.eval_step(val_dataloader)

      if xx:
        save_checkpoint(self.model, self.optimizer, filename="/Users/mpekey/Desktop/Mert_SabanciUniv/CS518/HoneyPollenClassification/model_checkpoints/model_ch.pth")

      # Logging
      print(
          f"Epoch: {epoch+1} | "
          f"train_loss: {train_loss:.5f}, "
          f"train_acc: {train_accuracy:.5f}, "
          f"val_loss: {val_loss:.5f}, "
          f"val_acc: {val_accuracy:.5f}"
      )

### Dataset

In [31]:
import os
import math
import pandas as pd
import numpy as np
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

def get_pollen_family(pollen_type):

    if pollen_type in ['Acanthus dioscoridis', 'Acanthus sp']:
        return 'Acanthaceae'
    elif pollen_type in ['Sambucus ebulus', 'Sambucus nigra', 'Viburnum lantana','Viburnum lanata']:
        return 'Adoxaceae'
    elif pollen_type in ['Chenopodium foliosum','Amaranthaceae sp']:
        return 'Amaranthaceae'
    elif pollen_type in ['Allium rotundum', 'Allium ampeloprasum', 'Allium sp']:
        return 'Amaryllidaceae'
    elif pollen_type in ['Cotinus coggygria', 'Rhus coriaria', 'Rhus sp']:
        return 'Anacardiaceae'
    elif pollen_type in ['Eryngium campestre', 'Ferula orientalis', 'Malabaila lasiocarpa', 'Lecokia cretica', 'Pimpinella sp','Apiaceae sp']:
        return 'Apiaceae'
    elif pollen_type in ['Leopoldia tenuiflora','Ornithogalum narbonense','Muscari tenuifolium']:
        return 'Asparagaceae'
    elif pollen_type in ['Gundelia sp','Circium arvense','Centaurea urvellei','Cardus','Matricaria chamomila','Scorzonera latifolia','Asteraceae sp','Artemisia absinthium','Achillea arabica', 'Achilla arabica', 'Achillea millefolium', 'Achillea vermicularis','Anthemis cretica','Arctium minus','Arctium minus','Bellis perennis',
    'Carduus nutans','Centaurea bingoelensis','Centaurea iberica','Centaurea kurdica','Centaurea saligna','Centaurea solstitialis','Centaurea spectabilis',
    'Centaurea urvillei','Centaurea virgata','Chondrilla brevirostris','Chondrilla juncea','Cichorium intybus','Cirsium arvense','Cirsium yildizianum',
    'Cota altissima','Crepis sancta','Cyanus triumfettii','Echinops pungens','Gundelia tournefortii','Helichrysum arenarium','Helichrysum plicatum',
    'Iranecio eriospermus','Senecio eriospermus','Matricaria chamomilla','Onopordum acanthium','Onopordum acanthium','Tanacetum balsamita','Tanacetum zahlbruckneri',
    'Taraxacum campylodes','Taraxacum officinale','Tussilago farfara','Xeranthemum annuum','Xeranthemum longipapposum','Artemisia sp','Carduus sp','Centaurea sp',
    'Cichorium sp','Cirsium sp','Cirsium yildizianum','Echinops sp','Echinops sp','Helianthus sp','Helichrysum sp','Onopordum sp','Ptilostemon sp','Xanthium sp','Xeranthemum sp','Xeranthemum longipopposum','Xeranthemum annum']:
        return 'Asteraceae'
    elif pollen_type in ['Alkanna orientalis','Anchusa azurea','Anchusa leptophylla','Cerinthe minor','Echium italicum','Myosotis alpestris','Myosotis laxa','Myosotis stricta','Myosotis sylvatica','cyanea','Phyllocara aucheri','Anchusa sp','Echium sp']:
        return 'Boraginaceae'
    elif pollen_type in ['Capsella bursa pastoris','Aethionema grandiflorum', 'Capsella bursa-pastoris', 'Isatis glauca','Lepidium draba','Raphanus raphanistrum']:
        return 'Brassicaceae'
    elif pollen_type in ['Campanula glomerata','Campanula involucrata', 'Campanula propinqua','Campanula stricta', 'stricta', 'Campanula sp']:
        return 'Campanulaceae'
    elif pollen_type in ['Centranthus longifolius','Centranthus longiflorus', 'Morina persica', 'Scabiosa columbaria', 'Scabiosa rotata', 'Cephalaria sp', 'Scabiosa sp', 'Valeriana sp']:
        return 'Caprifoliaceae'
    elif pollen_type in ['Cerastium armeniacum', 'Saponaria prostrata', 'Saponaria viscosa', 'Saponaria viscosa', 'Silene spergulifolia','Dianthus sp','Silene sp','Silene compacta']:
        return 'Caryophyllaceae'
    elif pollen_type in ['Chenopodium sp']:
        return 'Chenopodiaceae'
    elif pollen_type in ['Convolvulus arvensis', 'Convolvulus galaticus', 'Convolvulus lineatus', 'Convolvulus sp']:
        return 'Convolvulaceae'
    elif pollen_type in ['Cornus sanguinea']:
        return 'Cornaceae'
    elif pollen_type in ['Phedimus obtusifolius']:
        return 'Crassulaceae'
    elif pollen_type in ['Elaeagnus angustifolia']:
        return 'Elaeagnaceae'
    elif pollen_type in ['Euphorbia esula','tommasiniana', 'Euphorbia macrocarpa', 'Euphorbia sp']:
        return 'Euphorbiaceae'
    elif pollen_type in ['Glycyrrhiza glabra','Astragalus topolanense','Astragalus sp','Astragalus pinetorium','Astragalus lagopoides','Securigera varia','Astracantha gummifera','Trifolium campestre-yeniden','Astragalus gummifer','Astracantha kurdica','Astragalus kurdicus','Astracantha muschiana','Astragalus muschianus','Astragalus aduncus','Astragalus bingollensis','Astragalus bustillosii','Astragalus brachycalyx','brachycalyx','Astragalus caspicus','Astrgalus lagopoides','Astrgalus lagopoides Lam','Astragalus onobrychis','Astragalus oocephalus','Astragalus pinetorum','Astragalus saganlugensis','Astragalus topalanense','Colutea cilicica','Genista aucheri','Genista aucheri','Lathyrus brachypterus','Lathyrus satdaghensis','Lotus corniculatus','Lotus gebelia','Hedysarum varium','Medicago sativa','Melilotus albus','Melilotus officinalis','Onobrychis viciifolia','Ononis spinosa','Robinia pseudoacacia','Robinia pseudoacacia','Trifolium campestre','Trifolium diffusum','Trifolium nigrescens','Trifolium pauciflorum','Trifolium pratense','Trifolium resupinatum','Vicia cracca','cracca','Astragalus gummifer','Astragalus longifolius','Astragalus topalanense','Astragalus spp','Coronilla sp','Hedysarum sp','Lathyrus sp','Lotus sp','Melilotus sp','Trifolium spp','Vicia sp']:
        return 'Fabaceae'
    elif pollen_type in ['Quercus petraea', 'pinnatiloba']:
        return 'Fagaceae'
    elif pollen_type in ['Geranium tuberosum', 'Geranium sp']:
        return 'Geraniaceae'
    elif pollen_type in ['Hypericum sp','Hypericum lydium', 'Hypericum perforatum','Hypericum scabrum','Hypericum spp']:
        return 'Hypericaceae'
    elif pollen_type in ['Ixiolirion tataricum']:
        return 'Ixioliriaceae'
    elif pollen_type in ['Marrubium astracanium','Salvia palestina','Ajuga chamaepitys','chia','Lamium album','Lamium garganicum','Lamium macrodon','Marrubium astracanicum','Marrubium vulgare','Mentha longifolia','longifolia','Mentha spicata','Nepeta baytopii','Nepeta cataria','Nepeta nuda','Nepeta trachonitica','Origanum acutidens','Origanum vulgare','gracile','Phlomis armeniaca','Phlomis herba-venti','pungens','Phlomis pungens','Phlomis kurdica','Salvia frigida','Salvia limbata','Salvia macrochlamys','Salvia multicaulis','Salvia palaestina','Salvia sclarea','Salvia staminea','Salvia trichoclada','Salvia verticillata','Salvia virgata','Satureja hortensis','Stachys annua','Stachys lavandulifolia','Teucrium chamaedrys','Teucrium orientale','Teucrium orientale','Teucrium polium','Thymus kotschyanus','Thymus pubescens','Lamium macrodon','Lamium sp','Origanum sp','Phlomis sp','Teucrium sp','Thymus spp']:
        return 'Lamiaceae'
    elif pollen_type in ['Linum mucronatum','armenum']:
        return 'Linaceae'
    elif pollen_type in ['Lythrum salicaria']:
        return 'Lythraceae'
    elif pollen_type in ['Malvaceae','Alcea apterocarpa','Alcea remotiflora','Malva neglecta','Alcea sp','Malva sp','Tilia sp']:
        return 'Malvaceae'
    elif pollen_type in ['Morus alba']:
        return 'Moraceae'
    elif pollen_type in ['Epilobium parviflorum']:
        return 'Onagraceae'
    elif pollen_type in ['Fumaria parviflora','Fumaria schleicheri','microcarpa','Papaver dubium','Papaver orientale']:
        return 'Papaveraceae'
    elif pollen_type in ['Anarrhinum orientale','Anarhinum orientale','Globularia trichosantha','Lagotis stolonifera','Linaria pyramidata','Plantago lanceolata','Plantago major','Plantago media','Linaria sp']:
        return 'Plantaginaceae'
    elif pollen_type in ['Acantholimon acerosum','Acantholimon armenum','Acantholimon calvertii','Acantholimon sp']:
        return 'Plumbaginaceae'
    elif pollen_type in ['Poaceae','Zea mays']:
        return 'Poaceae'
    elif pollen_type in ['Polygonum cognatum','Rheum ribes','Rumex acetosella','Rumex scutatus','Rumex sp']:
        return 'Polygonaceae'
    elif pollen_type in ['Lysimachia punctata', 'Lysimacha vulgaris']:
        return 'Primulaceae'
    elif pollen_type in ['Portulaca sp']:
        return 'Portulacaceae'
    elif pollen_type in ['Ranunculus kotchii','Ficaria fascicularis','Ranunculus kochii','Ranunculus heterorrhizus','Ranunculus kotschyi','Ranunculus sp']:
        return 'Ranunculaceae'
    elif pollen_type in ['Paliurus spina-christi']:
        return 'Rhamnaceae'
    elif pollen_type in ['Crateagus orientalis','Crateagus monogyna','Potentilla inclinata','Rosaceae','Sanguisorba minör','Agrimonia repens','Cotoneaster nummularius','Crataegus orientalis','Crataegus monogyna','Filipendula ulmaria','Malus sylvestris','Potentilla anatolica','Potentilla argentea','Prunus divaricata','ursina','Pyrus elaeagnifolia','Rosa canina','Rosa foetida','Rubus caesius','Rubus sanctus','Sanguisorba minor','lasiocarpa','Sorbus torminalis','Filipendula sp','Potentilla sp','Rosa canina']:
        return 'Rosaceae'
    elif pollen_type in ['Galium consanguineum','Galium verum','Galium sp']:
        return 'Rubiaceae'
    elif pollen_type in ['Citrus sp']:
        return 'Rutaceae'
    elif pollen_type in ['Salix alba','Salix caprea','Salix sp']:
        return 'Salicaceae'
    elif pollen_type in ['Verbascum armenum','Verbascum diversifolium','Verbascum gimgimense','Verbascum lasianthum','Verbascum sinuatum','Verbascum spp','Verbascum sinatum']:
        return 'Scrophulariaceae'
    elif pollen_type in ['Tamarix smyrnensis','Tamarix tetrandra']:
        return 'Tamaricaceae'
    elif pollen_type in ['Eremurus spectabilis','Eremurus sp']:
        return 'Xanthorrhoeaceae'
    elif pollen_type in ['Tribulus terrestris']:
        return 'Zygophyllaceae'
    else:
        return 'notfound'

def get_dataset_roots(dir_path):

    class_dict = {}
    
    for dirpath, dirnames, filenames in os.walk(dir_path):
        class_name = dirpath.split('/')[-1]
        if class_name != 'dataset':
            for filename in os.listdir(dirpath):
                class_dict[os.path.join(dir_path, class_name, filename)] = class_name
                
    data_dict = pd.DataFrame({'img_file':list(class_dict.keys()), 'type':list(class_dict.values())})
    data_dict['family'] = data_dict['type'].apply(get_pollen_family)
    return data_dict


def split_datasets(pollen_df, class_amt_df, label_name = 'type'):

    class_names = pollen_df[label_name].unique()

    first_check = True

    for t in class_names:

        img_amount = class_amt_df[class_amt_df[label_name] == t].img_num.iloc[0]

        if label_name == 'type':
            if img_amount >= 5:
                split_ratio = (0.6,0.2,0.2) # May be changed
            elif img_amount == 4:
                split_ratio = (0.5,0.25,0.25)
            elif img_amount == 3:
                split_ratio = (0.33,0.33,0.33)
            elif img_amount == 2:
                split_ratio = (0.5,0.5,0.0)
        elif label_name == 'family':
            if img_amount >= 5:
                split_ratio = (0.6,0.2,0.2)
            elif img_amount == 4:
                split_ratio = (0.5,0.25,0.25)
            elif img_amount == 3:
                split_ratio = (0.33,0.33,0.33)
            elif img_amount == 2:
                split_ratio = (0.5,0.5,0.0)

        idx_list = list(pollen_df[pollen_df[label_name] == t].index)

        train_idx = np.random.choice(idx_list, size=math.ceil(len(idx_list)*split_ratio[0]), replace=False)
        remaining_idx = []
        for i in idx_list:
            if i not in train_idx:
                remaining_idx.append(i)

        val_idx = np.random.choice(remaining_idx, size=math.ceil(len(remaining_idx)*(split_ratio[1]/(1-split_ratio[0]))), replace=False)

        test_idx = []
        for i in remaining_idx:
            if i not in val_idx:
                test_idx.append(i)

        if first_check:
            train_data = pollen_df.loc[train_idx,['img_file', label_name]]
            val_data = pollen_df.loc[val_idx,['img_file', label_name]]
            test_data = pollen_df.loc[test_idx,['img_file', label_name]]
            first_check = False
        else:
            train_data = pd.concat([train_data, pollen_df.loc[train_idx,['img_file', label_name]]], axis = 0)
            val_data = pd.concat([val_data, pollen_df.loc[val_idx,['img_file', label_name]]], axis = 0)
            test_data = pd.concat([test_data, pollen_df.loc[test_idx,['img_file', label_name]]], axis = 0)

    return train_data.reset_index(drop=True), val_data.reset_index(drop=True), test_data.reset_index(drop=True)


class PollenDataset(Dataset):
    
    def __init__(self, data, transform=None, is_family=False):
        
        self.transform = transform
        
        self.data = data
        if is_family:
            self.class_names = self.data['family'].unique()
        else:
            self.class_names = self.data['type'].unique()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        
        img_root = self.data.iloc[idx,0]
        label = self.data.iloc[idx, 1]
        
        #img = torchvision.io.read_image(img_root)
        img = Image.open(img_root)
        
        if self.transform:
            img = self.transform(img)
        print(img.shape)
        print(label)
        return (img, label)
        

### Norm Params

In [14]:
import torchvision.transforms as T
def get_data_transform(means, stdevs, is_train = True):

    if is_train:
        data_transform = T.Compose([
            T.Resize((32,32)),
            T.Normalize(mean = means, std = stdevs),
            T.ToTensor()
        ])
    else:
        data_transform = T.Compose([
            T.Resize((32,32)),
            T.Normalize(mean = means, std = stdevs),
            T.ToTensor()
        ])
    return data_transform

In [37]:
def get_normalization_params(dataset): ## Tum dataset olarak degistir

    X, _ = dataset
    means = []
    stdevs = []
    
    n_channels = X.shape[1]
    for c in range(n_channels):
        mean = torch.mean(X[:, c])
        std = torch.std(X[:, c])

        means.append(mean)
        stdevs.append(std)

    return means, stdevs

### Train

In [33]:
class_type = 'family'
DUMMY_TRANSFORM = T.Compose([
            T.Resize((32,32)),
            T.ToTensor(),
            T.Normalize(mean = (0,0,0), std = (1,1,1))
        ])

DATA_ROOT = '/Users/mpekey/Desktop/Mert_SabanciUniv/CS518/HoneyPollenClassification/dataset'

pollen_df = get_dataset_roots(DATA_ROOT)
family_amt_df = pollen_df['family'].value_counts().reset_index().rename(columns={'index':'family', 'family':'img_num'})
type_amt_df = pollen_df['type'].value_counts().reset_index().rename(columns={'index':'type', 'type':'img_num'})

train_df, val_df, test_df = split_datasets(pollen_df, family_amt_df, label_name = class_type)

dummy_train_dataset = PollenDataset(data=train_df, transform=DUMMY_TRANSFORM, is_family=True if class_type=='family' else False)
means, stdevs = get_normalization_params(dummy_train_dataset)

train_transform = get_data_transform(means, stdevs, is_train=True)
test_transform = get_data_transform(means, stdevs, is_train=False)

# Create Datasets
#Dataset() icine koy
train_dataset = PollenDataset(data=train_df, transform=train_transform, is_family=True if class_type=='family' else False)
val_dataset = PollenDataset(data=val_df, transform=test_transform, is_family=True if class_type=='family' else False)
test_dataset = PollenDataset(data=test_df, transform=test_transform, is_family=True if class_type=='family' else False)

# Create Dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

torch.Size([3, 32, 32])
Brassicaceae
torch.Size([3, 32, 32])
Brassicaceae
torch.Size([3, 32, 32])
Brassicaceae


ValueError: too many values to unpack (expected 2)

In [9]:
# Get Model
from torchvision.models import resnet50#, ResNet50_Weights
model = resnet50()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

In [None]:
# Train
model_trainer = Trainer(model = model,
                                criterion = criterion,
                                optimizer = optimizer, 
                                device = 'cpu')

In [None]:
model_trainer.train(train_dataloader, val_dataloader, num_epochs = 1)