In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pickle
import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader,random_split
from torchvision import transforms,models
from torchmetrics import Accuracy,Precision,Recall,CohenKappa,F1Score
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassConfusionMatrix
from torchmetrics.utilities.plot import plot_confusion_matrix
from config import *

In [None]:
# Training configuration: optimisation settings, split fractions and per-split batch sizes.
hyperparameters = dict(
    epochs=100,
    lr=0.005,
    patience=20,            # epochs without validation improvement before early stopping
    train_size=0.8,         # fractions passed to random_split
    val_size=0.1,
    test_size=0.1,
    train_batch_size=32,
    val_batch_size=32,
    test_batch_size=32,
)

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset over images and integer labels read from a NumPy file.

    The object returned by ``np.load(path)`` must provide an ``"image"`` and a
    ``"label"`` entry (e.g. an ``.npz`` written with
    ``np.savez(path, image=..., label=...)``). Images are kept as a uint8 NumPy
    array and passed to ``transform`` one sample at a time; labels are stored as
    a ``torch.long`` tensor.

    Args:
        path: file passed to ``np.load``.
        transform: optional callable applied to each image in ``__getitem__``.
    """

    def __init__(self,path,transform=None):
        archive = np.load(path)
        try:
            # NOTE(review): astype(np.uint8) silently wraps values outside [0, 255];
            # confirm the stored images are already in that range.
            self.image = archive["image"].astype(np.uint8)
            self.label = torch.tensor(archive["label"],dtype = torch.long)
        finally:
            # For .npz files np.load returns an NpzFile that keeps the file open
            # until closed; plain ndarrays have no close() and need nothing.
            close = getattr(archive, "close", None)
            if close is not None:
                close()
        self.transform = transform

    def __len__(self):
        """Number of samples (length of the first label dimension)."""
        return self.label.shape[0]

    def __getitem__(self,index):
        """Return ``(image, label)``; the image is transformed when a transform is set."""
        image = self.image[index]
        if self.transform:
            image = self.transform(image)
        return image,self.label[index]

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
# Label vocabulary: class name -> integer id stored in the dataset's "label" array.
class_encode = {"healthy":0,"angular_leafspot":1,
                "Calciumdeficiency":2,"Leaf_scorch":3,"leaf_spot":4}
# Derived from class_encode so the inverse map and class count cannot drift out of sync.
class_decode = {idx: name for name, idx in class_encode.items()}
num_class = len(class_encode)

# NOTE(review): ToTensor() rescales uint8 images to [0, 1], but these per-channel
# statistics (one mean is negative, stds are ~75-92) cannot describe [0, 1] or
# [0, 255] data -- confirm they were computed on the representation the model sees.
transform = transforms.Compose([
    transforms.ToTensor(),          # presumably HWC uint8 ndarray -> CHW float tensor
    transforms.Normalize(           # per-channel normalisation of a single (C,H,W) image
        mean=[20.912, 2.458, -22.453],
        std=[75.640, 92.417, 91.233],
    ),
])
dataset = CustomDataset(path=dataset_path,transform=transform)

In [None]:
# Train/validation/test split using the fractions in `hyperparameters`.
# A fixed-seed generator makes random_split return the same partition on every
# run, so re-executing this cell cannot leak training images into the test set.
split_seed = 42
split_generator = torch.Generator().manual_seed(split_seed)
train_split,val_split,test_split = random_split(
    dataset,
    [hyperparameters['train_size'], hyperparameters['val_size'], hyperparameters['test_size']],
    generator=split_generator,
)

train_loader = DataLoader(train_split,batch_size = hyperparameters['train_batch_size'],shuffle=True,num_workers =4)
val_loader = DataLoader(val_split,batch_size = hyperparameters['val_batch_size'],shuffle=False)
test_loader = DataLoader(test_split,batch_size = hyperparameters['test_batch_size'],shuffle=False)

In [None]:
#dataset class count check -- counted directly from the stored label tensor,
# which avoids running the image transform over every sample just to read labels.
flat_labels = dataset.label.reshape(-1)
if flat_labels.numel() and (flat_labels.min() < 0 or flat_labels.max() >= len(class_decode)):
    raise ValueError("dataset contains labels that are not in class_decode")
label_counts = torch.bincount(flat_labels, minlength=len(class_decode))
class_count = {name: int(label_counts[idx]) for name, idx in class_encode.items()}
print(f'{class_count} total={sum(class_count.values())}')


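## ResNet-34 model

A ResNet-34 backbone with a new classification head; only `layer4` and the head (`fc`) are fine-tuned.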

In [None]:
class Model(nn.Module):
    """ResNet-34 backbone with a custom two-layer classification head.

    Args:
        num_class: number of logits produced by the head.
        device: converted with ``torch.device`` and used as ``map_location``
            when loading ``path_to_pretrained_weight``.
        path_to_pretrained_weight: optional path to a state dict for the ResNet.
            Every entry whose key contains ``"fc"`` is skipped, so the freshly
            initialised head is kept.
    """

    def __init__(self,num_class,device,path_to_pretrained_weight=None):
        super().__init__()
        backbone = models.resnet34(weights=None).to(torch.float32)
        hidden_features = 1024
        # NOTE(review): there is no activation between the two Linear layers, so
        # the head is a single affine map at inference time. Adding a ReLU would
        # shift the Sequential indices and break loading of existing checkpoints.
        backbone.fc = nn.Sequential(
            nn.Linear(in_features=backbone.fc.in_features, out_features=hidden_features),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=hidden_features, out_features=num_class),
        )
        self.resnet = backbone
        if path_to_pretrained_weight:
            self._load_backbone_weights(path_to_pretrained_weight, device)

    def _load_backbone_weights(self, weight_path, device):
        """Overlay every non-``fc`` tensor from ``weight_path`` onto ``self.resnet``."""
        # torch.load unpickles the file -- only use checkpoints from trusted sources.
        checkpoint = torch.load(weight_path, map_location=torch.device(device))
        merged_state = self.resnet.state_dict()
        for key, tensor in checkpoint.items():
            if "fc" in key:
                continue
            merged_state[key] = tensor
        self.resnet.load_state_dict(merged_state)

    def forward(self,x):
        """Return the class logits computed by the wrapped ResNet."""
        return self.resnet(x)

In [None]:
model = Model(num_class=num_class, path_to_pretrained_weight=pretrained_model_path, device=device)
model.to(device)

# Fine-tune only the last residual stage (layer4) and the classifier head (fc);
# every other parameter is frozen.
for param_name, param in model.named_parameters():
    param.requires_grad = ('layer4' in param_name) or ('fc' in param_name)

In [None]:
# Loss for integer class targets, and AdamW over the model's parameters
# (frozen ones have requires_grad=False and so never receive gradients).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=hyperparameters['lr'],
    weight_decay=0.01,
)

In [None]:
# Metric template shared by the train / val / test splits. Accuracy and Cohen's
# kappa are built without an `average` argument; precision, recall and F1 are macro-averaged.
_macro_kwargs = dict(task="multiclass", average="macro", num_classes=num_class)
base_metrics = MetricCollection({
    "accuracy": Accuracy(task="multiclass", num_classes=num_class),
    "precision": Precision(**_macro_kwargs),
    "recall": Recall(**_macro_kwargs),
    "f1score": F1Score(**_macro_kwargs),
    "kappa": CohenKappa(task="multiclass", num_classes=num_class),
})
confmat = MulticlassConfusionMatrix(num_classes=num_class).to(device)

# One independent copy per split; the prefix is prepended to every metric key
# (e.g. "train_accuracy"), matching the keys used in log_data below.
train_metrics, val_metrics, test_metrics = (
    base_metrics.clone(prefix=f"{split}_").to(device) for split in ("train", "val", "test")
)

metric_names = (
    "accuracy",
    "precision",
    "recall",
    "f1score",
    "kappa",
    "epochs_loss_values",
    "batch_loss",
)

log_batch = 100  # compute and log metric values every `log_batch` training batches

# log_data[<split name>][<prefix>_<metric name>] -> list of logged values
split_prefixes = {"training": "train", "validation": "val", "testing": "test"}
log_data = {
    data_name: {f"{prefix}_{metric}": [] for metric in metric_names}
    for data_name, prefix in split_prefixes.items()
}

In [None]:
from datetime import datetime
def save_model(path,model_dict,model_name="model"):
    """Save ``model_dict`` with ``torch.save`` under a timestamped, filesystem-safe name.

    The file is written to ``<path>/<model_name>@<YYYY-mm-dd_HH-MM-SS-ffffff>.pth``.
    Characters that are invalid or awkward in filenames on common platforms
    (``\\ / : * ? " < > |`` and whitespace) are replaced by ``_`` in
    ``model_name``, and the timestamp contains no ``:`` or spaces.

    Args:
        path: directory to save into; created if it does not exist.
        model_dict: object passed to ``torch.save`` (typically a state dict).
        model_name: prefix of the file name.

    Returns:
        The full path of the written file.
    """
    import os
    import re

    os.makedirs(path, exist_ok=True)
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")
    safe_name = re.sub(r'[\\/:*?"<>|\s]+', "_", model_name)
    full_path = os.path.join(path, f"{safe_name}@{timestamp}.pth")
    torch.save(model_dict,full_path)
    print("Model is saved")
    return full_path

In [None]:
# Fine-tuning loop: one training pass and one validation pass per epoch, a
# checkpoint whenever the mean validation loss improves, and early stopping
# after `patience` consecutive epochs without improvement.
min_epoch_loss = float('inf')   # best (lowest) mean validation loss seen so far
curr_limit = 0                  # consecutive epochs without improvement
best_epoch = 0
best_model_path = None
num_train_batches = len(train_loader)
num_val_batches = len(val_loader)
train_metrics.reset()
val_metrics.reset()

for epoch in range(hyperparameters['epochs']):
    # ---- training pass -------------------------------------------------
    model.train()
    train_epoch_loss = 0.0
    for batch_num, (train_image, train_label) in enumerate(train_loader):
        train_image = train_image.to(device)
        # view(-1) rather than squeeze(): squeeze() would also drop the batch
        # dimension of a final batch of size 1 and break CrossEntropyLoss.
        train_label = train_label.to(device).view(-1)

        train_output = model(train_image)                   # forward propagation
        train_loss = criterion(train_output, train_label)   # calculate loss
        optimizer.zero_grad()                               # zero the gradients
        train_loss.backward()                               # backpropagate the loss
        optimizer.step()                                    # update the parameters

        batch_loss = train_loss.item()
        train_epoch_loss += batch_loss
        log_data['training']['train_batch_loss'].append(batch_loss)
        train_metrics.update(train_output.detach(), train_label)

        # Flush the running training metrics every `log_batch` batches and at the
        # end of each epoch, so epochs shorter than `log_batch` are still logged.
        if (batch_num + 1) % log_batch == 0 or batch_num + 1 == num_train_batches:
            computed_train_metrics = train_metrics.compute()
            for metric, value in computed_train_metrics.items():
                log_data['training'][metric].append(value.item())
            train_metrics.reset()
            print(f"epoch {epoch} batch {batch_num} | train loss: {batch_loss:.3f} accu {computed_train_metrics['train_accuracy'].item():.3f}")

    # ---- validation pass: once per epoch, not once per training batch ------
    model.eval()
    val_epoch_loss = 0.0
    with torch.no_grad():
        for val_image, val_label in val_loader:
            val_image = val_image.to(device)
            val_label = val_label.to(device).view(-1)
            val_output = model(val_image)
            val_loss = criterion(val_output, val_label).item()
            val_epoch_loss += val_loss
            log_data['validation']['val_batch_loss'].append(val_loss)
            val_metrics.update(val_output, val_label)
    computed_val_metrics = val_metrics.compute()
    for metric, value in computed_val_metrics.items():
        log_data['validation'][metric].append(value.item())
    val_metrics.reset()

    # Mean per-batch losses keep train and validation values on the same scale.
    train_epoch_loss /= max(num_train_batches, 1)
    val_epoch_loss /= max(num_val_batches, 1)
    log_data['training']['train_epochs_loss_values'].append(train_epoch_loss)
    log_data['validation']['val_epochs_loss_values'].append(val_epoch_loss)

    print(f'At epoch {epoch} train loss is {train_epoch_loss:.4f}')
    print(f"At epoch {epoch} validation loss is {val_epoch_loss:.4f} | val accu {computed_val_metrics['val_accuracy'].item():.3f}")

    # ---- checkpointing / early stopping ---------------------------------
    if val_epoch_loss < min_epoch_loss:
        min_epoch_loss = val_epoch_loss
        curr_limit = 0
        best_epoch = epoch
        # no ':' in the name -- it is not a valid filename character on Windows
        best_model_path = save_model(path=model_saving_path, model_dict=model.state_dict(), model_name=f'epoch_{epoch}')
    else:
        curr_limit += 1
        if curr_limit >= hyperparameters['patience']:
            print("Early stopping is triggered!")
            print(f"last model saved is in epoch {best_epoch}")
            break

print('finished training')

In [None]:
# Restore the checkpoint with the lowest validation loss before evaluating.
# best_model_path stays None when no epoch ran or the validation loss never
# dropped below inf (e.g. it was NaN); fail with a clear message in that case.
if best_model_path is None:
    raise RuntimeError("No checkpoint was saved during training; there is no best model to load.")
state_dict_best_model = torch.load(best_model_path,map_location=device)
model.load_state_dict(state_dict_best_model)

In [None]:
def display_axis(data, ax, label, color, x_axis, y_axis):
    """Plot one logged series on ``ax`` with a legend and axis labels.

    Explicit positional parameters replace the previous ``*args`` unpacking,
    so the signature is self-documenting and a wrong arity fails at the call.

    Args:
        data: sequence of values, plotted against their index.
        ax: Matplotlib Axes to draw on (one cell of the ``plt.subplots`` grid).
        label: legend entry for the line.
        color: line colour.
        x_axis: x-axis label.
        y_axis: y-axis label.

    Returns:
        The same ``ax``, for further customisation.
    """
    ax.plot(data, label=label, color=color)
    ax.set_xlabel(x_axis)
    ax.set_ylabel(y_axis)
    ax.legend()
    return ax

# Training (left, red) vs validation (right, blue) curves, one metric per row.
# Epoch losses have one point per epoch; metric rows have one point per logging step.
curve_rows = [
    # (log key suffix, title, x-axis label, y-axis label)
    ("epochs_loss_values", "loss", "epochs", "loss"),
    ("batch_loss", "batch loss", "batches", "loss"),
    ("f1score", "f1score", "logging step", "f1score"),
    ("accuracy", "accuracy", "logging step", "accuracy"),
]
curve_columns = [
    # (log_data split, key prefix, legend prefix, colour)
    ("training", "train", "training", "red"),
    ("validation", "val", "validation", "blue"),
]

fig, axes = plt.subplots(len(curve_rows), len(curve_columns), figsize=(15, 10))
for row, (suffix, title, x_label, y_label) in enumerate(curve_rows):
    for col, (split, prefix, legend_prefix, color) in enumerate(curve_columns):
        display_axis(log_data[split][f"{prefix}_{suffix}"], axes[row][col],
                     f"{legend_prefix} {title}", color, x_label, y_label)
fig.tight_layout()
plt.show()

In [None]:
def merge_itrs(*itrs):
    """Lazily yield every item of each iterable in turn (like ``itertools.chain``)."""
    for source in itrs:
        yield from source

In [None]:
# Evaluate the restored best checkpoint on the held-out test split only. The
# validation split drove checkpoint selection / early stopping, so it is
# deliberately excluded here to keep the test scores unbiased.
model.eval()
test_metrics.reset()
confmat.reset()
with torch.no_grad():
    for image, test_label in test_loader:
        image = image.to(device)
        # view(-1) keeps the batch dimension even for a final batch of size 1
        test_label = test_label.to(device).view(-1)
        output = model(image)
        test_metrics.update(output, test_label)
        confmat.update(output, test_label)

# Metrics aggregated over all test batches (one value per metric).
computed_test_metrics = test_metrics.compute()
for metric, value in computed_test_metrics.items():
    log_data['testing'][metric].append(value.item())
print({metric: round(value.item(), 4) for metric, value in computed_test_metrics.items()})

cm = confmat.compute()
fig, ax = plot_confusion_matrix(cm, labels=list(class_encode.keys()), cmap='Blues')
plt.show()
confmat.reset()

In [None]:
# Persist the logged metrics and the test confusion matrix.
# cm is moved to the CPU first: a pickled CUDA tensor cannot be unpickled on a
# machine without CUDA.
with open(metrics_saving_path,'wb') as file:
    pickle.dump({'log_data':log_data,'cm':cm.cpu()},file)