# Download the dataset from

https://www.kaggle.com/c/plant-pathology-2020-fgvc7

# and extract it into the same directory as this notebook file (the `Dataset` class below also reads from an absolute path that you may need to edit).

In [None]:
! pip install transformers
! pip install albumentations
! pip install ipywidgets
!pip install efficientnet_pytorch

from efficientnet_pytorch import EfficientNet
import torch
import torchvision.models as models
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import seaborn as sns
import numpy as np
import imageio
import torch.optim as optim
import glob
from tqdm.notebook import tqdm
from progress.bar import Bar
# Autograd anomaly detection adds extra checks to every backward pass and
# slows training considerably; switch it on only while debugging NaN/inf
# gradients.
DETECT_AUTOGRAD_ANOMALIES = False
torch.autograd.set_detect_anomaly(DETECT_AUTOGRAD_ANOMALIES)
import time
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import PIL
from torchvision import transforms
import matplotlib.pyplot as plt

from transformers import get_cosine_schedule_with_warmup
from transformers import AdamW
from albumentations import *
from albumentations.pytorch import ToTensor

import ipywidgets as widgets 


# Displaying the types of leaves

In [None]:

# One example image per class, in the order they are laid out on a 2x2 grid.
leaf_examples = [
    ('./images/Train_2.jpg', 'Healthy'),
    ('./images/Train_1.jpg', 'Multiple Diseases'),
    ('./images/Train_3.jpg', 'Rust'),
    ('./images/Train_0.jpg', 'Scab'),
]

# Read every image up front, then draw them.
leaf_images = [plt.imread(path, format='jpg') for path, _ in leaf_examples]

fig = plt.figure(figsize=(16, 10))
for position, (image, (_, title)) in enumerate(zip(leaf_images, leaf_examples), start=1):
    ax = fig.add_subplot(2, 2, position)
    ax.imshow(image)
    ax.set_title(title, fontsize=20)

# Initialising the neural net parameters

In [None]:
# Training hyper-parameters.
batch_size = 4
epoch = 50  # number of passes over the training split

# EfficientNet variant; the input resolution for it is looked up from
# efficientnet_pytorch so images are resized to match.
model_name = 'efficientnet-b5'
image_size = EfficientNet.get_image_size(model_name)

# First GPU when available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


# Defining custom dataset class

In [None]:
import cv2

class Dataset(object):
    """Mini-batch iterator over the Plant Pathology 2020 (FGVC7) data.

    Exactly one split is selected, checked in the order below:

    * ``train``    -- the first ``train_fraction`` of ``train.csv``, shuffled,
                      with random augmentation;
    * ``cross_vd`` -- the remaining rows of ``train.csv``, shuffled, without
                      augmentation;
    * ``test``     -- ``test.csv`` in file order, without augmentation.

    Iterating yields ``(batch_image, batch_target)`` numpy arrays of shape
    ``(batch_size, 3, image_size, image_size)`` and ``(batch_size, 1)``.
    Every batch is full: the last one is padded by wrapping around to the
    start of the split. Targets are class indices in ``LABEL_COLUMNS`` order
    for the labelled splits and all zero for the test split. After each full
    pass the labelled splits are reshuffled; the test split keeps its order
    so predictions stay aligned with the rows of ``test.csv``.

    Args:
        batch_size: number of images per batch.
        image_size: side length images are resized to.
        train, cross_vd, test: split selectors (see above).
        data_dir: directory holding ``train.csv``, ``test.csv`` and
            ``images/``; defaults to ``DEFAULT_DATA_DIR``.
        train_fraction: fraction of ``train.csv`` used for training; the
            remainder forms the validation split.

    Raises:
        ValueError: if none of ``train``, ``cross_vd`` or ``test`` is set.
    """

    DEFAULT_DATA_DIR = "/home/Bhattacharya/Desktop/Plant Prediction/plant-pathology-2020-fgvc7"

    # One-hot label columns of train.csv, in class-index order.
    LABEL_COLUMNS = ("healthy", "multiple_diseases", "rust", "scab")

    def __init__(self, batch_size, image_size, train=False, cross_vd=False, test=False,
                 data_dir=None, train_fraction=0.8):
        self.data_dir = self.DEFAULT_DATA_DIR if data_dir is None else data_dir
        self.image_path = self.data_dir + "/images/{0}.jpg"
        self.batch_size = batch_size
        self.train = train
        self.cross_vd = cross_vd
        self.test = test
        self.image_size = image_size
        self.train_fraction = train_fraction
        # Albumentations pipeline, built on first use and then reused; each call
        # of the pipeline still draws fresh random augmentation parameters.
        self._transform = None
        self.dataset = self.load_dataset()
        self.num_samples = len(self.dataset)
        self.num_batchs = int(np.ceil(self.num_samples / self.batch_size))
        self.batch_count = 0

    def load_dataset(self):
        """Read the CSV for the selected split and return it as a DataFrame."""
        if self.train or self.cross_vd:
            labelled = pd.read_csv(self.data_dir + "/train.csv")
            # Positional (unshuffled) split so train/validation never overlap.
            split = int(len(labelled) * self.train_fraction)
            subset = labelled[:split] if self.train else labelled[split:]
            return subset.sample(frac=1).reset_index(drop=True)
        if self.test:
            return pd.read_csv(self.data_dir + "/test.csv")
        raise ValueError("one of train, cross_vd or test must be True")

    def __iter__(self):
        return self

    def __len__(self):
        return self.num_batchs

    def __next__(self):
        if self.batch_count >= self.num_batchs:
            # End of a pass: rewind, and reshuffle only the labelled splits so
            # the test split stays aligned with test.csv.
            self.batch_count = 0
            if self.train or self.cross_vd:
                self.dataset = self.dataset.sample(frac=1).reset_index(drop=True)
            raise StopIteration

        batch_image = np.zeros((self.batch_size, 3, self.image_size, self.image_size))
        batch_target = np.zeros((self.batch_size, 1))
        labelled = self.train or self.cross_vd
        for num in range(self.batch_size):
            # Wrap past the end so the final batch is padded from the start.
            index = (self.batch_count * self.batch_size + num) % self.num_samples
            if labelled:
                batch_target[num, 0] = self._target_for(index)
            annotation = self.image_path.format(self.dataset.loc[index, "image_id"])
            batch_image[num] = self.parse_annotation(annotation)

        self.batch_count += 1
        return (batch_image, batch_target)

    def _target_for(self, index):
        """Return the class index of row ``index`` from its one-hot columns.

        Defaults to 0 when no column is set; if several are set, the last one
        in ``LABEL_COLUMNS`` order wins.
        """
        label = 0
        for class_id, column in enumerate(self.LABEL_COLUMNS):
            if np.float64(self.dataset.loc[index, column]) == 1:
                label = class_id
        return label

    def parse_annotation(self, image_path):
        """Load the image at ``image_path`` and return it transformed to CHW."""
        # Context manager closes the file handle once the pixels are copied.
        with Image.open(image_path) as input_image:
            input_image_np = np.array(input_image)
        if self._transform is None:
            self._transform = self._build_transform()
        return self._transform(image=input_image_np)["image"]

    def _build_transform(self):
        """Return the augmentation (train) or plain resize/normalise pipeline."""
        if self.train:
            # training data augmentation
            return Compose([HorizontalFlip(p=0.5),
                            VerticalFlip(p=0.5),
                            ShiftScaleRotate(rotate_limit=25.0, p=0.7),
                            OneOf([IAAEmboss(p=1),
                                   IAASharpen(p=1),
                                   Blur(p=1)], p=0.5),
                            OneOf([ElasticTransform(p=1),
                                   IAAPiecewiseAffine(p=1)], p=0.5),
                            Resize(self.image_size, self.image_size, always_apply=True),
                            Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                            ToTensor()])
        # cross vd and test data: no augmentation
        return Compose([Resize(self.image_size, self.image_size, always_apply=True),
                        Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), always_apply=True),
                        ToTensor()])

        
    

# Initialising datasets

In [None]:

# Augmented training split first, then the held-out validation split
# (constructed in this order so the shuffles happen in the same order).
dataset = Dataset(batch_size, image_size, train=True)
validation_dataset = Dataset(batch_size, image_size, cross_vd=True)

# Every batch is filled to batch_size (the last one wraps around to the
# start of its split), so these sizes include that padding.
total_minibatches = len(dataset)
total_training_size = batch_size * total_minibatches
total_validation_minibatches = len(validation_dataset)
total_validation_size = batch_size * total_validation_minibatches

print("Training Mini batches:", total_minibatches)
print("Validation Mini batches:", total_validation_minibatches)

# Defining the loss function (cross-entropy over the four classes)

In [None]:
def compute_loss(pred,target):
    """Return the batch-averaged cross-entropy of ``pred`` against ``target``.

    Args:
        pred: raw logits of shape ``(N, num_classes)``.
        target: class indices of any numeric dtype with ``N`` elements, e.g.
            shape ``(N,)``, ``(N, 1)``, or a 0-d tensor when ``N == 1``.

    Returns:
        A scalar tensor holding the mean loss over the batch.
    """
    # cross_entropy needs int64 class indices of shape (N,); flattening also
    # covers the 0-d target produced by squeezing a batch of one.
    return F.cross_entropy(pred, target.long().reshape(-1))

# Defining Neural Net

In [None]:
# Load pretrained EfficientNet weights and replace the final classifier with
# a small MLP head producing logits for the 4 leaf classes.
net = EfficientNet.from_pretrained(model_name)
# Use the backbone's actual feature width instead of hard-coding 2048 (the
# b5 value), so switching model_name to another variant keeps the head valid.
num_features = net._fc.in_features
net._fc = nn.Sequential(nn.Linear(num_features, 1024, bias=True),
                        nn.ReLU(),
                        nn.Dropout(p=0.5),
                        nn.Linear(1024, 4, bias=True))
net.to(device)
print("Total Parameters", sum(p.numel() for p in net.parameters()))

# Defining the optimisation function

In [None]:
# Adam with its default learning rate plus L2 weight decay.
optimizer = torch.optim.Adam(net.parameters(), weight_decay=1e-4)

# The scheduler is stepped once per mini-batch, so both lengths are in
# batches: 5 epochs of warmup, then cosine decay across the whole run.
warmup_steps = total_minibatches * 5
num_train_steps = int(epoch * total_minibatches)
scheduler = get_cosine_schedule_with_warmup(optimizer,
                                            num_warmup_steps=warmup_steps,
                                            num_training_steps=num_train_steps)

# Defining train and validation functions

In [None]:


def train(images,target,epoch,counter):
    """Run a single optimisation step on one mini-batch.

    Switches ``net`` to training mode, computes the loss of its prediction
    for ``images`` against ``target``, back-propagates, then advances the
    optimizer and the per-batch learning-rate scheduler.

    ``epoch`` and ``counter`` are accepted but not used.

    Returns:
        The loss tensor for this batch.
    """
    net.train()
    optimizer.zero_grad()
    batch_loss = compute_loss(net(images), target)
    batch_loss.backward()
    optimizer.step()
    scheduler.step()
    return batch_loss
    



def validate(images,target):
    """Evaluate ``net`` on one mini-batch without tracking gradients.

    Args:
        images: input batch tensor on the model's device.
        target: class indices, one per image (float or integer dtype; any
            shape that flattens to ``(N,)``, including 0-d for ``N == 1``).

    Returns:
        ``(correct_pred, loss)``: the number of images whose arg-max
        prediction equals the target (a Python int), and the batch loss.
    """
    net.eval()
    with torch.no_grad():
        pred = net(images)
        loss = compute_loss(pred, target)
        # Compare on the device in one vectorised op rather than converting each
        # element to a Python int (one device sync per element, and it fails on
        # a 0-d target).
        predicted = pred.argmax(dim=1)
        correct_pred = int((predicted == target.reshape(-1).long()).sum().item())
    return correct_pred, loss
        
        
        
    
        
        
    
    
    

# Training Network

In [None]:
import os

checkpoint_dir = "/home/Bhattacharya/Desktop/Plant Prediction/plant-pathology-2020-fgvc7/checkpoint_effnet_b5_2"
# torch.save does not create missing directories.
os.makedirs(checkpoint_dir, exist_ok=True)

pbar_epoch = tqdm(total = epoch, desc='Epoch')
pbar_train = tqdm(total = total_minibatches*epoch, desc='Training')

train_loss = []  # mean per-batch training loss of each epoch
cv_loss = []     # mean per-batch validation loss of each epoch
loss = None
for i in range(epoch):

    # ---- training pass ----
    counter = 0
    running_loss = 0
    for images, y in dataset:
        counter = counter + 1
        # images already has shape (batch_size, 3, image_size, image_size).
        images_py = torch.from_numpy(images).float().to(device)
        y_py = torch.squeeze(torch.from_numpy(y).float().to(device))

        loss = train(images_py,y_py,i,counter)
        running_loss = running_loss + loss.item()
        if counter%10 == 0:
            print("Loss:",loss.item())

        pbar_train.update()

    # compute_loss already averages over the batch, so average the per-batch
    # losses over the number of batches, not the number of samples.
    train_loss.append(running_loss / max(counter, 1))

    # ---- validation pass (every epoch) ----
    preds_count = 0
    val_batches = 0
    running_loss = 0
    pbar = tqdm(total = total_validation_minibatches, desc='Validation')
    for images, y in validation_dataset:
        images_py = torch.from_numpy(images).float().to(device)
        y_py = torch.squeeze(torch.from_numpy(y).float().to(device))
        corrected_pred, loss_v = validate(images_py,y_py)
        running_loss = running_loss + loss_v.item()
        preds_count = preds_count + corrected_pred
        val_batches = val_batches + 1
        pbar.update()
    pbar.close()

    cv_loss.append(running_loss / max(val_batches, 1))
    accuracy = (preds_count/total_validation_size)*100
    print("Accuracy:",accuracy)

    if accuracy > 95:
        # If accuracy is above 95% predict on the test set
        test_dataset = Dataset(1,image_size,test=True)
        print("saving submission file")
        pbar_test = tqdm(total = len(test_dataset))
        pandas_dict = {"image_id":[],"healthy":[],"multiple_diseases":[],"rust":[],"scab":[]}
        net.eval()
        with torch.no_grad():
            for j, (images, _) in enumerate(test_dataset):
                images_py = torch.from_numpy(images).float().to(device)
                softmax_output = F.softmax(net(images_py), dim=1)
                softmax_output = np.squeeze(softmax_output.cpu().numpy())
                softmax_output = [float("{:.2f}".format(x)) for x in softmax_output]
                # batch size 1 and an unshuffled first pass: batch j is row j.
                pandas_dict["image_id"].append(test_dataset.dataset.loc[j, "image_id"])
                pandas_dict["healthy"].append(softmax_output[0])
                pandas_dict["multiple_diseases"].append(softmax_output[1])
                pandas_dict["rust"].append(softmax_output[2])
                pandas_dict["scab"].append(softmax_output[3])
                pbar_test.update()
        pbar_test.close()

        df = pd.DataFrame.from_dict(pandas_dict)
        df.to_csv('submission_effnet_b5_{0}_{1}.csv'.format(accuracy,i), index = False)

    net.train()
    if i%10==0:
        torch.save({
                'epoch': i+1,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Plain float instead of the autograd-attached loss tensor.
                'loss': loss.item() if loss is not None else None,
                'scheduler': scheduler.state_dict()
                }, os.path.join(checkpoint_dir, "model_{0}.pth".format(i+1)))
    pbar_epoch.update()
    
   
        
    
    

        


    

# Plotting Train vs Cross Validation Loss graph

In [None]:
# Plot the per-epoch mean training and validation losses.
plt.figure()
plt.ylim(0,1.5)
# seaborn >= 0.12 accepts x / y only as keyword arguments; the x-range follows
# the recorded history so a partially completed run still plots.
sns.lineplot(x=list(range(len(train_loss))), y=train_loss)
sns.lineplot(x=list(range(len(cv_loss))), y=cv_loss)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(['Train','Validation'])