### Libraries

In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime



import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

import math
import os
from PIL import Image

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
seed=22
torch.manual_seed(seed) # Seed for general torch operations
torch.cuda.manual_seed(seed) # Seed for CUDA
torch.cuda.manual_seed_all(seed)

import random
random.seed(seed)

import numpy as np
np.random.seed(seed)

### Configs

In [4]:
# CONFIGS

LOAD_MODEL = False
SAVE_MODEL = False
CHECKPOINT_MODEL = "/Users/mpekey/Desktop/Mert_SabanciUniv/CS518/HoneyPollenClassification/model_checkpoints/model_ch.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-3
NUM_EPOCHS = 15
BATCH_SIZE = 8
NUM_WORKERS = 2
IMG_CHANNELS = 3
IMG_SIZE = (232,232)
DATA_ROOT = '/content/drive/MyDrive/CS518/dataset_resized'


train_transform = T.Compose([
            #T.Resize(IMG_SIZE),  
            T.CenterCrop((224,224)),
            T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
            T.RandomHorizontalFlip(0.5),
            T.RandomVerticalFlip(0.5),
            #T.RandomRotation(degrees=(0, 90)),
            T.ToTensor(),
            T.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
            #mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225]
        ])
test_transform = T.Compose([
            #T.Resize(IMG_SIZE),
            T.CenterCrop((224,224)),
            T.ToTensor(),
            T.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
        ])

### Trainer

In [4]:
def top_accuracy(output, target, topk=(1,)):

  maxk = max(topk)
  batch_size = target.size(0)

  _, y_pred = output.topk(k=maxk, dim=1)
  y_pred = y_pred.t()
  target_reshaped = target.view(1, -1).expand_as(y_pred)
  
  correct = (y_pred == target_reshaped)


  list_topk_accs = []
  for k in topk:
      
      ind_which_topk_matched_truth = correct[:k]
      
      flattened_indicator_which_topk_matched_truth = ind_which_topk_matched_truth.reshape(-1).float()
      tot_correct_topk = flattened_indicator_which_topk_matched_truth.float().sum(dim=0, keepdim=True)
      
      topk_acc = tot_correct_topk.item() / float(batch_size)
      list_topk_accs.append(topk_acc)
  return list_topk_accs  

In [5]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    # model.load_state_dict(checkpoint)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def create_writer(experiment_name: str, model_name: str, extra: str=None):

    # Get timestamp of current date (all experiments on certain day live in same folder)
    timestamp = datetime.now().strftime("%Y-%m-%d") # returns current date in YYYY-MM-DD format

    if extra:
        # Create log directory path
        log_dir = os.path.join("runs", timestamp, experiment_name, model_name, extra)
    else:
        log_dir = os.path.join("runs", timestamp, experiment_name, model_name)
        
    print(f"[INFO] Created SummaryWriter, saving to: {log_dir}...")
    return SummaryWriter(log_dir=log_dir)

def accuracy_fn(y_true, y_pred):
  correct = torch.eq(y_true, y_pred).sum().item()
  acc = (correct / len(y_pred)) * 100
  return acc

In [6]:
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter


class Trainer():

  def __init__(self, model, criterion = None, optimizer = None, lr_scheduler = None, device = "cpu", model_name = None, experiment_name = None):
    
    self.model = model
    self.criterion = criterion
    self.optimizer = optimizer
    self.lr_scheduler = lr_scheduler
    self.device = device
    self.model_name = model_name
    self.experiment_name = experiment_name

    self.writer = create_writer(self.experiment_name, self.model_name, None)
    

  def train_step(self, train_dataloader):
    
    # Set model to training mode
    self.model.train()

    epoch_loss, epoch_accuracy = 0, 0
    topk_acc_list = 0
    for batch, (X,y) in enumerate(train_dataloader):
      X = X.to(self.device)
      y = y.to(self.device)

      # Reset Gradients
      self.optimizer.zero_grad()

      # Prediction
      out = self.model(X)

      # Calculation Loss
      loss = self.criterion(out, y)

      # Calculating Gradients
      loss.backward()

      # Update Weights
      self.optimizer.step()
      if self.lr_scheduler is not None:
        self.lr_scheduler.step()

      # Calculating Performance Metrics
      epoch_loss += loss.item()
      epoch_accuracy += accuracy_fn(y_true=y, y_pred=out.argmax(dim=1))

    epoch_loss /= len(train_dataloader)
    epoch_accuracy /= len(train_dataloader)

    return epoch_loss, epoch_accuracy


  def eval_step(self, val_dataloader):
    
    # Set model to training mode
    self.model.eval()

    epoch_loss, epoch_accuracy = 0, 0
    topk_acc_list = [0, 0, 0]
    y_trues, y_probs = [], []

    with torch.inference_mode():
      for batch, (X,y) in enumerate(val_dataloader):
        X = X.to(self.device)
        y = y.to(self.device)

        # Prediction
        out = self.model(X)
        # Calculation Loss
        loss = self.criterion(out, y)

        # Calculating Performance Metrics
        epoch_loss += loss.item()
        epoch_accuracy += accuracy_fn(y_true=y, y_pred=out.argmax(dim=1))

        topk_list = top_accuracy(out, y, topk=(1,3,5))
        topk_acc_list[0] += topk_list[0]
        topk_acc_list[1] += topk_list[1]
        topk_acc_list[2] += topk_list[2]

      topk_acc_list[0] /= len(val_dataloader)
      topk_acc_list[1] /= len(val_dataloader)
      topk_acc_list[2] /= len(val_dataloader)
      epoch_loss /= len(val_dataloader)
      epoch_accuracy /= len(val_dataloader)

    return epoch_loss, epoch_accuracy, topk_acc_list


  def predict_step(self, val_dataloader):

    self.model.eval()

    y_preds = []
    y_probs = []

    with torch.inference_mode():
      for batch, (X,y) in enumerate(val_dataloader):
        X = X.to(self.device)
        y = y.to(self.device)

        # Prediction
        out = self.model(X)

        # Calculating Performance Metrics
        y_prob = torch.softmax(out[:,0,:], dim=1)
        y_pred = torch.argmax(torch.softmax(out[:,0,:], dim=1), dim=1)

        y_probs = [y_prob[index, y_pred[index]].item() for index in range(len(y_pred))]
        y_preds.extend(y_pred)

    return y_probs, y_preds


  def train(self, train_dataloader, val_dataloader, num_epochs = 5, patience = 5):

    results = {"train_loss": [],
               "train_acc": [],
               "val_loss": [],
               "val_acc": []
    }

    if LOAD_MODEL:
        load_checkpoint(
            CHECKPOINT_MODEL,
            self.model,
            self.optimizer,
            LEARNING_RATE,
        )

    best_val_loss = np.inf

    for epoch in tqdm(range(num_epochs)):

      train_loss, train_accuracy = self.train_step(train_dataloader)
      val_loss, val_accuracy, topk_list_val = self.eval_step(val_dataloader)

      # Early stopping
      if val_loss < best_val_loss:
          best_val_loss = val_loss
          best_model = self.model
          _patience = patience
      else:
          _patience -= 1
      if not _patience:
          print("Stopping early!")
          break


      if SAVE_MODEL:
        save_checkpoint(self.model, self.optimizer, filename=CHECKPOINT_MODEL)

      # Logging
      print(
          f"Epoch: {epoch+1} | "
          f"train_loss: {train_loss:.5f}, "
          f"train_acc: {train_accuracy:.5f}, "
          f"val_loss: {val_loss:.5f}, "
          f"val_acc: {val_accuracy:.5f}, "
          f"top1_val acc: {topk_list_val[0]:.5f}, "
          f"top3_val acc: {topk_list_val[1]:.5f}, "
          f"top5_val acc: {topk_list_val[2]:.5f}"
      )

      results["train_loss"].append(train_loss)
      results["train_acc"].append(train_accuracy)
      results["val_loss"].append(val_loss)
      results["val_acc"].append(val_accuracy)

      # Add loss results to SummaryWriter
      self.writer.add_scalars(main_tag="Loss", 
                          tag_scalar_dict={"train_loss": train_loss,
                                          "val_loss": val_loss},
                          global_step=epoch)

      # Add accuracy results to SummaryWriter
      self.writer.add_scalars(main_tag="Accuracy", 
                          tag_scalar_dict={"train_acc": train_accuracy,
                                          "val_acc": val_accuracy}, 
                          global_step=epoch)
      
      # Track the PyTorch model architecture
      self.writer.add_graph(model=self.model, 
                        input_to_model=torch.randn(BATCH_SIZE, 3, IMG_SIZE[0], IMG_SIZE[1]).to(DEVICE))
    
      # Close the writer
    self.writer.close()

    # Assign best model to the last model
    return best_model

### Dataset

In [7]:
def get_pollen_family(pollen_type):

    if pollen_type in ['Acanthus dioscoridis', 'Acanthus sp']:
        return 'Acanthaceae'
    elif pollen_type in ['Sambucus ebulus', 'Sambucus nigra', 'Viburnum lantana','Viburnum lanata']:
        return 'Adoxaceae'
    elif pollen_type in ['Chenopodium foliosum','Amaranthaceae sp']:
        return 'Amaranthaceae'
    elif pollen_type in ['Allium rotundum', 'Allium ampeloprasum', 'Allium sp']:
        return 'Amaryllidaceae'
    elif pollen_type in ['Cotinus coggygria', 'Rhus coriaria', 'Rhus sp']:
        return 'Anacardiaceae'
    elif pollen_type in ['Eryngium campestre', 'Ferula orientalis', 'Malabaila lasiocarpa', 'Lecokia cretica', 'Pimpinella sp','Apiaceae sp']:
        return 'Apiaceae'
    elif pollen_type in ['Leopoldia tenuiflora','Ornithogalum narbonense','Muscari tenuifolium']:
        return 'Asparagaceae'
    elif pollen_type in ['Gundelia sp','Circium arvense','Centaurea urvellei','Cardus','Matricaria chamomila','Scorzonera latifolia','Asteraceae sp','Artemisia absinthium','Achillea arabica', 'Achilla arabica', 'Achillea millefolium', 'Achillea vermicularis','Anthemis cretica','Arctium minus','Arctium minus','Bellis perennis',
    'Carduus nutans','Centaurea bingoelensis','Centaurea iberica','Centaurea kurdica','Centaurea saligna','Centaurea solstitialis','Centaurea spectabilis',
    'Centaurea urvillei','Centaurea virgata','Chondrilla brevirostris','Chondrilla juncea','Cichorium intybus','Cirsium arvense','Cirsium yildizianum',
    'Cota altissima','Crepis sancta','Cyanus triumfettii','Echinops pungens','Gundelia tournefortii','Helichrysum arenarium','Helichrysum plicatum',
    'Iranecio eriospermus','Senecio eriospermus','Matricaria chamomilla','Onopordum acanthium','Onopordum acanthium','Tanacetum balsamita','Tanacetum zahlbruckneri',
    'Taraxacum campylodes','Taraxacum officinale','Tussilago farfara','Xeranthemum annuum','Xeranthemum longipapposum','Artemisia sp','Carduus sp','Centaurea sp',
    'Cichorium sp','Cirsium sp','Cirsium yildizianum','Echinops sp','Echinops sp','Helianthus sp','Helichrysum sp','Onopordum sp','Ptilostemon sp','Xanthium sp','Xeranthemum sp','Xeranthemum longipopposum','Xeranthemum annum']:
        return 'Asteraceae'
    elif pollen_type in ['Alkanna orientalis','Anchusa azurea','Anchusa leptophylla','Cerinthe minor','Echium italicum','Myosotis alpestris','Myosotis laxa','Myosotis stricta','Myosotis sylvatica','cyanea','Phyllocara aucheri','Anchusa sp','Echium sp']:
        return 'Boraginaceae'
    elif pollen_type in ['Capsella bursa pastoris','Aethionema grandiflorum', 'Capsella bursa-pastoris', 'Isatis glauca','Lepidium draba','Raphanus raphanistrum']:
        return 'Brassicaceae'
    elif pollen_type in ['Campanula glomerata','Campanula involucrata', 'Campanula propinqua','Campanula stricta', 'stricta', 'Campanula sp']:
        return 'Campanulaceae'
    elif pollen_type in ['Centranthus longifolius','Centranthus longiflorus', 'Morina persica', 'Scabiosa columbaria', 'Scabiosa rotata', 'Cephalaria sp', 'Scabiosa sp', 'Valeriana sp']:
        return 'Caprifoliaceae'
    elif pollen_type in ['Cerastium armeniacum', 'Saponaria prostrata', 'Saponaria viscosa', 'Saponaria viscosa', 'Silene spergulifolia','Dianthus sp','Silene sp','Silene compacta']:
        return 'Caryophyllaceae'
    elif pollen_type in ['Chenopodium sp']:
        return 'Chenopodiaceae'
    elif pollen_type in ['Convolvulus arvensis', 'Convolvulus galaticus', 'Convolvulus lineatus', 'Convolvulus sp']:
        return 'Convolvulaceae'
    elif pollen_type in ['Cornus sanguinea']:
        return 'Cornaceae'
    elif pollen_type in ['Phedimus obtusifolius']:
        return 'Crassulaceae'
    elif pollen_type in ['Elaeagnus angustifolia']:
        return 'Elaeagnaceae'
    elif pollen_type in ['Euphorbia esula','tommasiniana', 'Euphorbia macrocarpa', 'Euphorbia sp']:
        return 'Euphorbiaceae'
    elif pollen_type in ['Glycyrrhiza glabra','Astragalus topolanense','Astragalus sp','Astragalus pinetorium','Astragalus lagopoides','Securigera varia','Astracantha gummifera','Trifolium campestre-yeniden','Astragalus gummifer','Astracantha kurdica','Astragalus kurdicus','Astracantha muschiana','Astragalus muschianus','Astragalus aduncus','Astragalus bingollensis','Astragalus bustillosii','Astragalus brachycalyx','brachycalyx','Astragalus caspicus','Astrgalus lagopoides','Astrgalus lagopoides Lam','Astragalus onobrychis','Astragalus oocephalus','Astragalus pinetorum','Astragalus saganlugensis','Astragalus topalanense','Colutea cilicica','Genista aucheri','Genista aucheri','Lathyrus brachypterus','Lathyrus satdaghensis','Lotus corniculatus','Lotus gebelia','Hedysarum varium','Medicago sativa','Melilotus albus','Melilotus officinalis','Onobrychis viciifolia','Ononis spinosa','Robinia pseudoacacia','Robinia pseudoacacia','Trifolium campestre','Trifolium diffusum','Trifolium nigrescens','Trifolium pauciflorum','Trifolium pratense','Trifolium resupinatum','Vicia cracca','cracca','Astragalus gummifer','Astragalus longifolius','Astragalus topalanense','Astragalus spp','Coronilla sp','Hedysarum sp','Lathyrus sp','Lotus sp','Melilotus sp','Trifolium spp','Vicia sp']:
        return 'Fabaceae'
    elif pollen_type in ['Quercus petraea', 'pinnatiloba']:
        return 'Fagaceae'
    elif pollen_type in ['Geranium tuberosum', 'Geranium sp']:
        return 'Geraniaceae'
    elif pollen_type in ['Hypericum sp','Hypericum lydium', 'Hypericum perforatum','Hypericum scabrum','Hypericum spp']:
        return 'Hypericaceae'
    elif pollen_type in ['Ixiolirion tataricum']:
        return 'Ixioliriaceae'
    elif pollen_type in ['Marrubium astracanium','Salvia palestina','Ajuga chamaepitys','chia','Lamium album','Lamium garganicum','Lamium macrodon','Marrubium astracanicum','Marrubium vulgare','Mentha longifolia','longifolia','Mentha spicata','Nepeta baytopii','Nepeta cataria','Nepeta nuda','Nepeta trachonitica','Origanum acutidens','Origanum vulgare','gracile','Phlomis armeniaca','Phlomis herba-venti','pungens','Phlomis pungens','Phlomis kurdica','Salvia frigida','Salvia limbata','Salvia macrochlamys','Salvia multicaulis','Salvia palaestina','Salvia sclarea','Salvia staminea','Salvia trichoclada','Salvia verticillata','Salvia virgata','Satureja hortensis','Stachys annua','Stachys lavandulifolia','Teucrium chamaedrys','Teucrium orientale','Teucrium orientale','Teucrium polium','Thymus kotschyanus','Thymus pubescens','Lamium macrodon','Lamium sp','Origanum sp','Phlomis sp','Teucrium sp','Thymus spp']:
        return 'Lamiaceae'
    elif pollen_type in ['Linum mucronatum','armenum']:
        return 'Linaceae'
    elif pollen_type in ['Lythrum salicaria']:
        return 'Lythraceae'
    elif pollen_type in ['Malvaceae','Alcea apterocarpa','Alcea remotiflora','Malva neglecta','Alcea sp','Malva sp','Tilia sp']:
        return 'Malvaceae'
    elif pollen_type in ['Morus alba']:
        return 'Moraceae'
    elif pollen_type in ['Epilobium parviflorum']:
        return 'Onagraceae'
    elif pollen_type in ['Fumaria parviflora','Fumaria schleicheri','microcarpa','Papaver dubium','Papaver orientale']:
        return 'Papaveraceae'
    elif pollen_type in ['Anarrhinum orientale','Anarhinum orientale','Globularia trichosantha','Lagotis stolonifera','Linaria pyramidata','Plantago lanceolata','Plantago major','Plantago media','Linaria sp']:
        return 'Plantaginaceae'
    elif pollen_type in ['Acantholimon acerosum','Acantholimon armenum','Acantholimon calvertii','Acantholimon sp']:
        return 'Plumbaginaceae'
    elif pollen_type in ['Poaceae','Zea mays']:
        return 'Poaceae'
    elif pollen_type in ['Polygonum cognatum','Rheum ribes','Rumex acetosella','Rumex scutatus','Rumex sp']:
        return 'Polygonaceae'
    elif pollen_type in ['Lysimachia punctata', 'Lysimacha vulgaris']:
        return 'Primulaceae'
    elif pollen_type in ['Portulaca sp']:
        return 'Portulacaceae'
    elif pollen_type in ['Ranunculus kotchii','Ficaria fascicularis','Ranunculus kochii','Ranunculus heterorrhizus','Ranunculus kotschyi','Ranunculus sp']:
        return 'Ranunculaceae'
    elif pollen_type in ['Paliurus spina-christi']:
        return 'Rhamnaceae'
    elif pollen_type in ['Crateagus orientalis','Crateagus monogyna','Potentilla inclinata','Rosaceae','Sanguisorba minör','Sanguisorba min”r','Agrimonia repens','Cotoneaster nummularius','Crataegus orientalis','Crataegus monogyna','Filipendula ulmaria','Malus sylvestris','Potentilla anatolica','Potentilla argentea','Prunus divaricata','ursina','Pyrus elaeagnifolia','Rosa canina','Rosa foetida','Rubus caesius','Rubus sanctus','Sanguisorba minor','lasiocarpa','Sorbus torminalis','Filipendula sp','Potentilla sp','Rosa canina']:
        return 'Rosaceae'
    elif pollen_type in ['Galium consanguineum','Galium verum','Galium sp']:
        return 'Rubiaceae'
    elif pollen_type in ['Citrus sp']:
        return 'Rutaceae'
    elif pollen_type in ['Salix alba','Salix caprea','Salix sp']:
        return 'Salicaceae'
    elif pollen_type in ['Verbascum armenum','Verbascum diversifolium','Verbascum gimgimense','Verbascum lasianthum','Verbascum sinuatum','Verbascum spp','Verbascum sinatum']:
        return 'Scrophulariaceae'
    elif pollen_type in ['Tamarix smyrnensis','Tamarix tetrandra']:
        return 'Tamaricaceae'
    elif pollen_type in ['Eremurus spectabilis','Eremurus sp']:
        return 'Xanthorrhoeaceae'
    elif pollen_type in ['Tribulus terrestris']:
        return 'Zygophyllaceae'
    else:
        return 'notfound'

def get_dataset_roots(dir_path):

    class_dict = {}
    
    for dirpath, dirnames, filenames in os.walk(dir_path):
        class_name = dirpath.split('/')[-1]
        if class_name != 'dataset_resized':
            for filename in os.listdir(dirpath):
                class_dict[os.path.join(dir_path, class_name, filename)] = class_name
                
    data_dict = pd.DataFrame({'img_file':list(class_dict.keys()), 'type':list(class_dict.values())})
    data_dict['family'] = data_dict['type'].apply(get_pollen_family)
    return data_dict


def split_datasets(pollen_df, class_amt_df, label_name = 'type'):

    class_names = pollen_df[label_name].unique()

    first_check = True

    for t in class_names:

        img_amount = class_amt_df[class_amt_df[label_name] == t].img_num.iloc[0]

        if label_name == 'type':
            if img_amount >= 5:
                split_ratio = (0.6,0.2,0.2)
            elif img_amount == 4:
                split_ratio = (0.5,0.25,0.25)
            elif img_amount == 3:
                split_ratio = (0.33,0.33,0.33)
            elif img_amount == 2:
                split_ratio = (0.5,0.5,0.0)
        elif label_name == 'family':
            if img_amount >= 5:
                split_ratio = (0.6,0.2,0.2)
            elif img_amount == 4:
                split_ratio = (0.5,0.25,0.25)
            elif img_amount == 3:
                split_ratio = (0.33,0.33,0.33)
            elif img_amount == 2:
                split_ratio = (0.5,0.5,0.0)

        idx_list = list(pollen_df[pollen_df[label_name] == t].index)

        train_idx = np.random.choice(idx_list, size=math.ceil(len(idx_list)*split_ratio[0]), replace=False)
        remaining_idx = []
        for i in idx_list:
            if i not in train_idx:
                remaining_idx.append(i)

        val_idx = np.random.choice(remaining_idx, size=math.ceil(len(remaining_idx)*(split_ratio[1]/(1-split_ratio[0]))), replace=False)

        test_idx = []
        for i in remaining_idx:
            if i not in val_idx:
                test_idx.append(i)

        if first_check:
            train_data = pollen_df.loc[train_idx,['img_file', label_name]]
            val_data = pollen_df.loc[val_idx,['img_file', label_name]]
            test_data = pollen_df.loc[test_idx,['img_file', label_name]]
            first_check = False
        else:
            train_data = pd.concat([train_data, pollen_df.loc[train_idx,['img_file', label_name]]], axis = 0)
            val_data = pd.concat([val_data, pollen_df.loc[val_idx,['img_file', label_name]]], axis = 0)
            test_data = pd.concat([test_data, pollen_df.loc[test_idx,['img_file', label_name]]], axis = 0)

    return train_data.reset_index(drop=True), val_data.reset_index(drop=True), test_data.reset_index(drop=True)


class PollenDataset(Dataset):
    
    def __init__(self, data, transform=None, is_family=False):
        
        self.transform = transform
        self.data = data
        if is_family:
            self.class_names = self.data['family'].unique()
            self.idx_to_class = {i:j for i, j in enumerate(self.class_names)}
            self.class_to_idx = {value:key for key,value in self.idx_to_class.items()}
        else:
            self.class_names = self.data['type'].unique()
            self.idx_to_class = {i:j for i, j in enumerate(self.class_names)}
            self.class_to_idx = {value:key for key,value in self.idx_to_class.items()}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        
        img_root = self.data.iloc[idx,0]
        label = str(self.data.iloc[idx, 1])
        label_num = self.class_to_idx[label]
        
        #img = torchvision.io.read_image(img_root)
        img = Image.open(img_root)
        
        if self.transform:
            img = self.transform(img)

        return img, label_num
        

### Norm Params

In [38]:
def get_normalization_params(dataset): ## Tum dataset olarak degistir

    means = []
    stdevs = []
    
    n_channels = 3
    for c in range(n_channels):
      mean = 0
      std = 0
      for i in range(len(dataset)):
        
        mean += torch.mean(dataset[i][0][:, c])
        std += torch.std(dataset[i][0][:, c])

      means.append(mean/len(dataset))
      stdevs.append(std/len(dataset)))

    return means, stdevs

In [39]:
data_means, data_stdevs = get_normalization_params(train_dataset)

In [40]:
print(data_means)
print(data_stdevs)

[tensor(1447.2092), tensor(1440.0553), tensor(1455.5039)]
[tensor(244.5075), tensor(231.5544), tensor(229.3636)]


### Get Average Height and Width

In [None]:
DATA_ROOT = '/Users/mpekey/Desktop/Mert_SabanciUniv/CS518/HoneyPollenClassification/dataset'

pollen_df = get_dataset_roots(DATA_ROOT)
family_amt_df = pollen_df['family'].value_counts().reset_index().rename(columns={'index':'family', 'family':'img_num'})
type_amt_df = pollen_df['type'].value_counts().reset_index().rename(columns={'index':'type', 'type':'img_num'})

train_df, val_df, test_df = split_datasets(pollen_df, family_amt_df, label_name = 'family')



In [None]:
heights = 0
widths = 0
for i in range(train_df.shape[0]):
    h, w = Image.open(train_df.iloc[i,0]).size
    heights += h
    widths += w

print(f'Average Height: {heights/train_df.shape[0]}')
print(f'Average Width: {widths/train_df.shape[0]}')

class_type = 'family'

dummy_dataset = PollenDataset(data=train_df, transform=T.ToTensor(), is_family=True if class_type=='family' else False)
dummy_dataloader = DataLoader(dummy_dataset, batch_size=16, shuffle=False)

Average Height: 1853.9306197964847
Average Width: 1549.495837187789


In [None]:
means, stdevs = get_normalization_params(dummy_dataloader)

print(f'Means: {means}')
print(f'Stdevs: {stdevs}')

: 

: 

### Model

In [8]:
class CustomModelFamily(nn.Module):

    def __init__(self, class_names):
        super(CustomModelFamily, self).__init__()
        
        self.convolutional_layer = nn.Sequential(            
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1),
            #nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(in_channels=512, out_channels=len(class_names), kernel_size=3, stride=1),
            nn.ReLU()
        )

        self.avg1 = nn.AvgPool2d(kernel_size = len(class_names))
        #self.fc = nn.Linear(in_features=len(class_names), out_features=len(class_names))


    def forward(self, x):
        x = self.convolutional_layer(x)
        x = self.avg1(x)
        x = torch.flatten(x, 1)
        return x

In [9]:
def get_model(model_name, class_names, full_train = False, pretrained = False):
    
    if model_name == 'resnet50':
      if pretrained:
          model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT).to(DEVICE)
      else:
          model = torchvision.models.resnet50().to(DEVICE)
      
      if full_train == False:
          for parameter in model.parameters():
              parameter.requires_grad = False
          
      model.fc = nn.Linear(in_features=2048, out_features=len(class_names)).to(DEVICE)
      return model
    elif model_name == 'densenet161':
      if pretrained:
          model = torchvision.models.densenet161(weights=torchvision.models.DenseNet161_Weights.DEFAULT).to(DEVICE)
      else:
          model = torchvision.models.densenet161().to(DEVICE)
      
      if full_train == False:
          for parameter in model.parameters():
              parameter.requires_grad = False
          
      model.classifier = nn.Linear(in_features=2208, out_features=len(class_names), bias=True).to(DEVICE)
      return model
    
    elif model_name == 'resnext':
      if pretrained:
          model = torchvision.models.resnext101_32x8d(weights=torchvision.models.ResNeXt101_32X8D_Weights.DEFAULT).to(DEVICE)
      else:
          model = torchvision.models.resnext101_32x8d().to(DEVICE)
      
      if full_train == False:
          for parameter in model.parameters():
              parameter.requires_grad = False
          
      model.classifier = nn.Linear(in_features=2048, out_features=len(class_names), bias=True).to(DEVICE)
      
      return model

    elif model_name == 'inception':
      if pretrained:
          model = torchvision.models.inception_v3(weights=torchvision.models.Inception_V3_Weights.DEFAULT).to(DEVICE)
      else:
          model = torchvision.models.inception_v3().to(DEVICE)
      
      if full_train == False:
          for parameter in model.parameters():
              parameter.requires_grad = False
          
      model.classifier = nn.Linear(in_features=2048, out_features=len(class_names), bias=True).to(DEVICE)
      
      return model

    elif model_name == 'custom':
      model = CustomModelFamily(class_names).to(DEVICE)
      return model
      

### Train

In [19]:
# CONFIGS

LOAD_MODEL = False
SAVE_MODEL = False
CHECKPOINT_MODEL = "/Users/mpekey/Desktop/Mert_SabanciUniv/CS518/HoneyPollenClassification/model_checkpoints/model_ch.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
NUM_EPOCHS = 15
BATCH_SIZE = 8
NUM_WORKERS = 2
IMG_CHANNELS = 3
IMG_SIZE = (232,232)
DATA_ROOT = '/content/drive/MyDrive/CS518/dataset_resized'
no_augment=False

train_transform = T.Compose([
            #T.Resize(IMG_SIZE),  
            T.CenterCrop((224,224)),
            T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
            T.RandomHorizontalFlip(0.5),
            T.RandomVerticalFlip(0.5),
            T.ToTensor(),
            T.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
        ])


# For No Augmentation
if no_augment:
  train_transform = T.Compose([
              #T.Resize(IMG_SIZE),  
              T.CenterCrop((224,224)),
              T.ToTensor(),
              T.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
          ])

test_transform = T.Compose([
            #T.Resize(IMG_SIZE),
            T.CenterCrop((224,224)),
            T.ToTensor(),
            T.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
        ])


class_type = 'family' # or type

pollen_df = get_dataset_roots(DATA_ROOT)
family_amt_df = pollen_df['family'].value_counts().reset_index().rename(columns={'index':'family', 'family':'img_num'})
type_amt_df = pollen_df['type'].value_counts().reset_index().rename(columns={'index':'type', 'type':'img_num'})

# Split Data
if class_type == 'type':
  train_df, val_df, test_df = split_datasets(pollen_df, type_amt_df, label_name = class_type)
elif class_type == 'family':
  train_df, val_df, test_df = split_datasets(pollen_df, family_amt_df, label_name = class_type)

# Create Datasets
train_dataset = PollenDataset(data=train_df, transform=train_transform, is_family=True if class_type=='family' else False)
val_dataset = PollenDataset(data=val_df, transform=test_transform, is_family=True if class_type=='family' else False)
test_dataset = PollenDataset(data=test_df, transform=test_transform, is_family=True if class_type=='family' else False)

# Create Dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(train_df.shape)
print(val_df.shape)
print(test_df.shape)

print()


model_name = 'densenet161'
experiment_name = '8b_aug_1e4lr_pretrained_full_train_family_ls'

# Get Model
my_model = get_model(model_name=model_name,
                        class_names=train_dataset.class_names,
                        full_train=True,
                        pretrained=True)
optimizer = torch.optim.Adam(my_model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
#lambda1 = lambda epoch: 0.65 ** epoch
#scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
scheduler = None

model_trainer = Trainer(model = my_model,
                                criterion = criterion,
                                optimizer = optimizer, 
                                lr_scheduler = scheduler,
                                device = DEVICE,
                                model_name = model_name,
                                experiment_name = experiment_name)

best_model = model_trainer.train(train_dataloader, val_dataloader, num_epochs = 100, patience = 5)



(1081, 2)
(358, 2)
(348, 2)

[INFO] Created SummaryWriter, saving to: runs/2023-01-13/8b_aug_1e4lr_pretrained_full_train_family_ls/densenet161...


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 2.65675, train_acc: 39.79779, val_loss: 2.14998, val_acc: 52.87037, top1_val acc: 0.52870, top3_val acc: 0.73704, top5_val acc: 0.81481


  1%|          | 1/100 [00:44<1:13:40, 44.65s/it]

Epoch: 2 | train_loss: 2.02614, train_acc: 58.08824, val_loss: 1.87217, val_acc: 63.98148, top1_val acc: 0.63981, top3_val acc: 0.80741, top5_val acc: 0.87222


  2%|▏         | 2/100 [01:28<1:12:17, 44.26s/it]

Epoch: 3 | train_loss: 1.74517, train_acc: 65.80882, val_loss: 1.71164, val_acc: 66.57407, top1_val acc: 0.66574, top3_val acc: 0.84444, top5_val acc: 0.90648


  3%|▎         | 3/100 [02:12<1:11:22, 44.15s/it]

Epoch: 4 | train_loss: 1.50859, train_acc: 74.26471, val_loss: 1.49009, val_acc: 75.27778, top1_val acc: 0.75278, top3_val acc: 0.88148, top5_val acc: 0.93148


  4%|▍         | 4/100 [02:55<1:10:03, 43.79s/it]

Epoch: 5 | train_loss: 1.37761, train_acc: 78.03309, val_loss: 1.38988, val_acc: 78.14815, top1_val acc: 0.78148, top3_val acc: 0.91481, top5_val acc: 0.95463


  5%|▌         | 5/100 [03:39<1:09:13, 43.72s/it]

Epoch: 6 | train_loss: 1.21761, train_acc: 85.20221, val_loss: 1.35949, val_acc: 78.42593, top1_val acc: 0.78426, top3_val acc: 0.91574, top5_val acc: 0.95000


  6%|▌         | 6/100 [04:23<1:08:28, 43.71s/it]

Epoch: 7 | train_loss: 1.14521, train_acc: 86.76471, val_loss: 1.34764, val_acc: 80.37037, top1_val acc: 0.80370, top3_val acc: 0.93241, top5_val acc: 0.96111


  7%|▋         | 7/100 [05:06<1:07:29, 43.54s/it]

Epoch: 8 | train_loss: 1.07661, train_acc: 89.61397, val_loss: 1.22869, val_acc: 82.77778, top1_val acc: 0.82778, top3_val acc: 0.94074, top5_val acc: 0.98056


  8%|▊         | 8/100 [05:49<1:06:35, 43.42s/it]

Epoch: 9 | train_loss: 1.03892, train_acc: 90.71691, val_loss: 1.40140, val_acc: 77.50000, top1_val acc: 0.77500, top3_val acc: 0.92130, top5_val acc: 0.95556


  9%|▉         | 9/100 [06:33<1:05:53, 43.45s/it]

Epoch: 10 | train_loss: 0.94900, train_acc: 94.76103, val_loss: 1.24198, val_acc: 82.40741, top1_val acc: 0.82407, top3_val acc: 0.94722, top5_val acc: 0.96944


 10%|█         | 10/100 [07:16<1:05:13, 43.48s/it]

Epoch: 11 | train_loss: 0.91720, train_acc: 95.40441, val_loss: 1.25453, val_acc: 84.25926, top1_val acc: 0.84259, top3_val acc: 0.93519, top5_val acc: 0.96111


 11%|█         | 11/100 [07:59<1:04:21, 43.39s/it]

Epoch: 12 | train_loss: 0.89500, train_acc: 95.22059, val_loss: 1.34193, val_acc: 83.61111, top1_val acc: 0.83611, top3_val acc: 0.93056, top5_val acc: 0.94722


 12%|█▏        | 12/100 [08:43<1:03:34, 43.35s/it]

Epoch: 13 | train_loss: 0.89691, train_acc: 96.59926, val_loss: 1.22302, val_acc: 85.83333, top1_val acc: 0.85833, top3_val acc: 0.94722, top5_val acc: 0.96111


 13%|█▎        | 13/100 [09:26<1:02:55, 43.39s/it]

Epoch: 14 | train_loss: 0.87208, train_acc: 96.69118, val_loss: 1.14023, val_acc: 85.55556, top1_val acc: 0.85556, top3_val acc: 0.95278, top5_val acc: 0.98056


 14%|█▍        | 14/100 [10:10<1:02:16, 43.45s/it]

Epoch: 15 | train_loss: 0.81334, train_acc: 98.43750, val_loss: 1.08940, val_acc: 88.24074, top1_val acc: 0.88241, top3_val acc: 0.96667, top5_val acc: 0.97778


 15%|█▌        | 15/100 [10:53<1:01:25, 43.36s/it]

Epoch: 16 | train_loss: 0.83210, train_acc: 97.24265, val_loss: 1.17879, val_acc: 86.57407, top1_val acc: 0.86574, top3_val acc: 0.96667, top5_val acc: 0.97778


 16%|█▌        | 16/100 [11:36<1:00:41, 43.35s/it]

Epoch: 17 | train_loss: 0.86861, train_acc: 96.50735, val_loss: 1.26882, val_acc: 84.16667, top1_val acc: 0.84167, top3_val acc: 0.93333, top5_val acc: 0.96389


 17%|█▋        | 17/100 [12:20<1:00:10, 43.50s/it]

Epoch: 18 | train_loss: 0.87968, train_acc: 95.95588, val_loss: 1.14475, val_acc: 87.31481, top1_val acc: 0.87315, top3_val acc: 0.94722, top5_val acc: 0.97500


 18%|█▊        | 18/100 [13:05<59:56, 43.85s/it]  

Epoch: 19 | train_loss: 0.82594, train_acc: 97.51838, val_loss: 1.13529, val_acc: 87.50000, top1_val acc: 0.87500, top3_val acc: 0.95000, top5_val acc: 0.97222


 19%|█▉        | 19/100 [14:28<1:01:43, 45.72s/it]

Stopping early!





In [None]:
model_trainer.eval_step(val_dataloader)

(1.3092551021014942,
 63.23529411764706,
 [0.6323529411764706, 0.8480392156862745, 0.9117647058823529])

In [14]:
CHECKPOINT_filename = f"/content/drive/MyDrive/CS518/model_checkpoints/{model_name+'-'+experiment_name}.pth"
save_checkpoint(best_model, model_trainer.optimizer, filename=CHECKPOINT_filename)

=> Saving checkpoint


In [20]:
from google.colab import files
!zip -r /content/drive/MyDrive/CS518/tensorboard_items/densenet_8b_aug_1e4lr_pretrained_full_train_family_ls.zip /content/runs/2023-01-13/8b_aug_1e4lr_pretrained_full_train_family_ls
files.download("/content/drive/MyDrive/CS518/tensorboard_items/densenet_8b_aug_1e4lr_pretrained_full_train_family_ls.zip")

  adding: content/runs/2023-01-13/8b_aug_1e4lr_pretrained_full_train_family_ls/ (stored 0%)
  adding: content/runs/2023-01-13/8b_aug_1e4lr_pretrained_full_train_family_ls/densenet161/ (stored 0%)
  adding: content/runs/2023-01-13/8b_aug_1e4lr_pretrained_full_train_family_ls/densenet161/Loss_train_loss/ (stored 0%)
  adding: content/runs/2023-01-13/8b_aug_1e4lr_pretrained_full_train_family_ls/densenet161/Loss_train_loss/events.out.tfevents.1673609110.610578f805c1.124.21 (deflated 49%)
  adding: content/runs/2023-01-13/8b_aug_1e4lr_pretrained_full_train_family_ls/densenet161/events.out.tfevents.1673609068.610578f805c1.124.20 (deflated 96%)
  adding: content/runs/2023-01-13/8b_aug_1e4lr_pretrained_full_train_family_ls/densenet161/Accuracy_train_acc/ (stored 0%)
  adding: content/runs/2023-01-13/8b_aug_1e4lr_pretrained_full_train_family_ls/densenet161/Accuracy_train_acc/events.out.tfevents.1673609110.610578f805c1.124.23 (deflated 53%)
  adding: content/runs/2023-01-13/8b_aug_1e4lr_pretrain

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
#!rm -rf /content/runs/2023-01-12

### STEPS

In [None]:
model_name = 'resnet50'
experiment_name = '8b_aug_1e3lr_pretrained_full_train'

# Get Model
my_model = get_model(model_name=model_name,
                        class_names=train_dataset.class_names,
                        full_train=True,
                        pretrained=True)
optimizer = torch.optim.Adam(my_model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss(label_smoothing=0.0)
#lambda1 = lambda epoch: 0.65 ** epoch
#scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
scheduler = None

In [None]:
# Train
model_trainer = Trainer(model = my_model,
                                criterion = criterion,
                                optimizer = optimizer, 
                                lr_scheduler = scheduler,
                                device = DEVICE,
                                model_name = model_name,
                                experiment_name = experiment_name)

[INFO] Created SummaryWriter, saving to: runs/2023-01-12/8b_aug_1e3lr_pretrained_full_train/resnet50...


In [None]:
model_trainer.train(train_dataloader, val_dataloader, num_epochs = 100)

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 2.86174, train_acc: 27.29779, val_loss: 3.05461, val_acc: 26.66667, top1_val acc: 0.26667, top3_val acc: 0.53333, top5_val acc: 0.66389


  1%|          | 1/100 [00:40<1:07:22, 40.83s/it]

Epoch: 2 | train_loss: 2.51518, train_acc: 34.83456, val_loss: 2.54288, val_acc: 37.50000, top1_val acc: 0.37500, top3_val acc: 0.58056, top5_val acc: 0.68981


  2%|▏         | 2/100 [01:02<48:09, 29.48s/it]  

Epoch: 3 | train_loss: 2.17970, train_acc: 40.80882, val_loss: 2.12265, val_acc: 36.66667, top1_val acc: 0.36667, top3_val acc: 0.66111, top5_val acc: 0.76852


  3%|▎         | 3/100 [01:23<41:46, 25.84s/it]

Epoch: 4 | train_loss: 2.01242, train_acc: 43.29044, val_loss: 2.01685, val_acc: 42.96296, top1_val acc: 0.42963, top3_val acc: 0.66667, top5_val acc: 0.76944


  4%|▍         | 4/100 [01:45<38:35, 24.12s/it]

Epoch: 5 | train_loss: 1.83452, train_acc: 45.86397, val_loss: 1.93206, val_acc: 41.38889, top1_val acc: 0.41389, top3_val acc: 0.71389, top5_val acc: 0.82407


  5%|▌         | 5/100 [02:07<36:48, 23.24s/it]

Epoch: 6 | train_loss: 1.64845, train_acc: 49.90809, val_loss: 1.88552, val_acc: 45.09259, top1_val acc: 0.45093, top3_val acc: 0.71019, top5_val acc: 0.81852


  6%|▌         | 6/100 [02:28<35:34, 22.71s/it]

Epoch: 7 | train_loss: 1.47878, train_acc: 53.21691, val_loss: 1.66691, val_acc: 47.68519, top1_val acc: 0.47685, top3_val acc: 0.75556, top5_val acc: 0.85926


  7%|▋         | 7/100 [02:50<34:46, 22.44s/it]

Epoch: 8 | train_loss: 1.39247, train_acc: 56.34191, val_loss: 1.44687, val_acc: 54.53704, top1_val acc: 0.54537, top3_val acc: 0.81019, top5_val acc: 0.90000


  8%|▊         | 8/100 [03:12<34:03, 22.22s/it]

Epoch: 9 | train_loss: 1.29901, train_acc: 59.74265, val_loss: 2.13326, val_acc: 45.00000, top1_val acc: 0.45000, top3_val acc: 0.78704, top5_val acc: 0.89722


  9%|▉         | 9/100 [03:34<33:31, 22.10s/it]

Epoch: 10 | train_loss: 1.14716, train_acc: 65.99265, val_loss: 1.34151, val_acc: 55.55556, top1_val acc: 0.55556, top3_val acc: 0.82685, top5_val acc: 0.92778


 10%|█         | 10/100 [03:55<32:51, 21.91s/it]

Epoch: 11 | train_loss: 1.03523, train_acc: 66.26838, val_loss: 1.46525, val_acc: 56.11111, top1_val acc: 0.56111, top3_val acc: 0.82500, top5_val acc: 0.89444


 11%|█         | 11/100 [04:17<32:21, 21.82s/it]

Epoch: 12 | train_loss: 1.02098, train_acc: 67.18750, val_loss: 1.14157, val_acc: 64.72222, top1_val acc: 0.64722, top3_val acc: 0.89444, top5_val acc: 0.93333


 12%|█▏        | 12/100 [04:38<31:48, 21.69s/it]

Epoch: 13 | train_loss: 0.94642, train_acc: 71.23162, val_loss: 1.71716, val_acc: 52.96296, top1_val acc: 0.52963, top3_val acc: 0.81389, top5_val acc: 0.89722


 13%|█▎        | 13/100 [05:00<31:18, 21.59s/it]

Epoch: 14 | train_loss: 0.82135, train_acc: 73.80515, val_loss: 1.52540, val_acc: 52.22222, top1_val acc: 0.52222, top3_val acc: 0.80833, top5_val acc: 0.89444


 14%|█▍        | 14/100 [05:21<30:52, 21.54s/it]

Epoch: 15 | train_loss: 0.74748, train_acc: 75.91912, val_loss: 0.96916, val_acc: 71.94444, top1_val acc: 0.71944, top3_val acc: 0.89537, top5_val acc: 0.94630


 15%|█▌        | 15/100 [05:42<30:27, 21.50s/it]

Epoch: 16 | train_loss: 0.74011, train_acc: 77.38971, val_loss: 1.14492, val_acc: 66.66667, top1_val acc: 0.66667, top3_val acc: 0.89722, top5_val acc: 0.94167


 16%|█▌        | 16/100 [06:04<30:08, 21.53s/it]

Epoch: 17 | train_loss: 0.66321, train_acc: 78.58456, val_loss: 1.13710, val_acc: 70.00000, top1_val acc: 0.70000, top3_val acc: 0.87407, top5_val acc: 0.94630


 17%|█▋        | 17/100 [06:25<29:46, 21.53s/it]

Epoch: 18 | train_loss: 0.55618, train_acc: 81.70956, val_loss: 1.16504, val_acc: 66.20370, top1_val acc: 0.66204, top3_val acc: 0.88241, top5_val acc: 0.93889


 18%|█▊        | 18/100 [06:47<29:22, 21.50s/it]

Epoch: 19 | train_loss: 0.49705, train_acc: 84.46691, val_loss: 1.44564, val_acc: 62.03704, top1_val acc: 0.62037, top3_val acc: 0.83981, top5_val acc: 0.89352


 19%|█▉        | 19/100 [07:29<31:56, 23.66s/it]

Stopping early!





In [None]:
model_trainer.eval_step(val_dataloader)

(1.047808875557449,
 72.31481481481481,
 [0.7231481481481481, 0.8990740740740741, 0.9490740740740742])

In [None]:
CHECKPOINT_filename = f"/content/drive/MyDrive/CS518/model_checkpoints/{model_name+'-'+experiment_name}.pth"
save_checkpoint(model_trainer.model, model_trainer.optimizer, filename=CHECKPOINT_filename)

=> Saving checkpoint


In [None]:
from google.colab import files
#!zip -r /content/drive/MyDrive/CS518/tensorboard_items/8b_aug_1e3lr_pretrained_full_train.zip /content/runs/2023-01-12/8b_aug_1e3lr_pretrained_full_train
files.download("/content/drive/MyDrive/CS518/tensorboard_items/8b_aug_1e3lr_pretrained_full_train.zip")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>