In [None]:
### Cell to link Notebook to Google Drive
# Colab-only step: mounts Google Drive at /content/drive/ so that the dataset and
# checkpoint paths under /content/drive/MyDrive/ used by later cells resolve.

from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [None]:
import pickle
import torch
import torch.nn as nn
import torch.optim as optim

import pandas as pd
import numpy as np
import os
import sklearn
import statistics
import scipy
import PIL
from skimage import io, transform
import random
from torchvision import transforms, utils

from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from scipy.signal import find_peaks

import torchvision

from sklearn.metrics import classification_report, plot_confusion_matrix
from sklearn.preprocessing import MinMaxScaler
import matplotlib.patches as patches
from sklearn.metrics import roc_curve, auc, confusion_matrix
from sklearn.metrics import roc_auc_score
import sklearn.metrics as metrics
import matplotlib.pyplot as plt
import warnings
from scipy import interp

from torch.optim.lr_scheduler import StepLR
from scipy.interpolate import CubicSpline

In [None]:
class PreMendeleyTrain_Dataset(Dataset):
  """Image dataset laid out as ``root_dir/<species>/{diseased,healthy}/<image>``.

  Every image is labelled with the index of its species folder in the sorted
  list of species directories; the diseased/healthy split does not affect the
  label.

  Args:
    csv_file: accepted for interface compatibility; it is not read.
    root_dir: directory that contains one sub-directory per species.
    transform: optional callable applied to the opened image.
    transform_gt: callable applied to the target when ``isSeg`` is true.
    isSeg: when true the target is an image instead of the class index. This
      class records no separate mask path, so the target is opened from the
      same file as the input image.

  Attributes:
    X: dict mapping image path -> image path (insertion order defines indices).
    Y: dict mapping image path -> integer species label.
    data_amt: number of images found.
  """

  def __init__(self, csv_file, root_dir, transform = None, transform_gt=None,isSeg = False):

    self.data_amt = 0
    self.path = root_dir
    self.isSeg = isSeg
    self.transform = transform
    self.transform_gt = transform_gt

    self.X = {}
    self.Y = {}

    # Only directories are species. Counting stray files (e.g. a README or
    # .DS_Store in root_dir) would silently shift every class index.
    self._species = sorted(
      entry for entry in os.listdir(self.path)
      if os.path.isdir(os.path.join(self.path, entry))
    )
    self._label_of = {name: index for index, name in enumerate(self._species)}
    self._keys = []

    self.grab_data()

  def grab_data(self):
    """Scan the species folders and fill ``X``, ``Y`` and ``data_amt``."""
    for species in self._species:
      label = self.name2label(species)
      # 'diseased' is scanned before 'healthy' to keep the original index order.
      for condition in ('diseased', 'healthy'):
        folder = os.path.join(self.path, species, condition)
        if not os.path.isdir(folder):
          continue
        for image in sorted(os.listdir(folder)):
          key = folder + '/' + image
          self.X[key] = key
          self.Y[key] = label
          self.data_amt += 1

    # Cache the index -> path order once so __getitem__ is O(1) instead of
    # rebuilding the key list on every access.
    self._keys = list(self.X)

  def name2label(self,name):
    """Return the integer class index of species ``name``.

    Raises:
      ValueError: if ``name`` is not a species directory under ``root_dir``.
    """
    try:
      return self._label_of[name]
    except KeyError:
      raise ValueError('%r is not a species directory in %r' % (name, self.path)) from None

  def __len__(self):
    return self.data_amt


  def __getitem__(self, idx):
    """Return ``{'segment': image, 'gt': target}`` for the ``idx``-th image."""
    if torch.is_tensor(idx):
      idx = idx.tolist()

    key = self._keys[idx]

    image = PIL.Image.open(key)

    if self.isSeg:
      # No mask path is stored for this dataset; the target is the image file itself.
      gt = PIL.Image.open(key)
    else:
      gt = self.Y[key]

    if self.transform:
      image = self.transform(image)

    if self.isSeg:
      gt = self.transform_gt(gt)

    sample = {'segment': image, 'gt': gt}

    return sample

In [None]:
class PreTrain_Dataset(Dataset):
  """Dataset laid out as ``path + 'images/<local>/<species>/<image>'``.

  With ``isSeg=False`` the target is the species index taken from the sorted
  contents of ``path + 'images/lab/'``. With ``isSeg=True`` the target is the
  mask stored under ``path + 'segmented/<local>/<species>/'``.

  Args:
    path: dataset root; it is concatenated as a string, so it must end with '/'.
    transform: optional callable applied to the opened image.
    transform_gt: callable applied to the opened mask when ``isSeg`` is true.
    isSeg: select segmentation targets instead of class labels.

  Attributes:
    X: dict mapping image path -> image path (insertion order defines indices).
    Y: dict mapping image path -> mask path (``isSeg``) or integer label.
    data_amt: number of images found.
  """

  def __init__(self, path, transform = None, transform_gt=None,isSeg = False):

    self.data_amt = 0
    self.path = path
    self.isSeg = isSeg
    self.transform = transform
    self.transform_gt = transform_gt

    self.X = {}
    self.Y = {}

    self._label_of = None
    self._keys = []

    self.grab_data()

  @staticmethod
  def _list_masks(seg_folder):
    """Return (set of file names, stem -> file name) for ``seg_folder``; empty if missing."""
    if not os.path.isdir(seg_folder):
      return set(), {}
    names = sorted(os.listdir(seg_folder))
    by_stem = {}
    for name in names:
      by_stem.setdefault(os.path.splitext(name)[0], name)
    return set(names), by_stem

  def grab_data(self):
    """Scan the image tree and fill ``X``, ``Y`` and ``data_amt``."""
    images_root = self.path+'images/'
    for local in sorted(os.listdir(images_root)):
      local_path = images_root + local + '/'
      seg_local_path = self.path+'segmented/' + local + '/'
      for species in sorted(os.listdir(local_path)):
        folder_path = local_path + species + '/'
        # Built fresh for every species. The earlier code appended to one
        # running string, so every species after the first got a wrong path.
        seg_folder = seg_local_path + species + '/'
        if self.isSeg:
          mask_names, mask_by_stem = self._list_masks(seg_folder)
        for images in sorted(os.listdir(folder_path)):
          key = folder_path+images
          self.X[key] = key
          label = self.name2label(species)
          self.data_amt += 1
          if self.isSeg:
            # Prefer a mask with the identical file name; otherwise accept one
            # with the same stem, in case masks use another extension.
            if images in mask_names:
              mask_file = images
            else:
              mask_file = mask_by_stem.get(os.path.splitext(images)[0], images)
            self.Y[key] = seg_folder + mask_file
          else:
            self.Y[key] = label

    # Cache the index -> path order once so __getitem__ is O(1).
    self._keys = list(self.X)

  def name2label(self,name):
    """Return the index of ``name`` in the sorted listing of ``images/lab/``.

    Raises:
      ValueError: if ``name`` is not present in ``images/lab/``.
    """
    if getattr(self, '_label_of', None) is None:
      # Listed once and cached; the directory was previously re-read per image.
      species = sorted(os.listdir(self.path+'images/lab/'))
      self._label_of = {entry: index for index, entry in enumerate(species)}
    try:
      return self._label_of[name]
    except KeyError:
      raise ValueError('%r is not in list' % (name,)) from None

  def __len__(self):
    return self.data_amt


  def __getitem__(self, idx):
    """Return ``{'segment': image, 'gt': target}`` for the ``idx``-th image."""
    if torch.is_tensor(idx):
      idx = idx.tolist()

    key = self._keys[idx]

    image = PIL.Image.open(key)

    if self.isSeg:
      # Open the stored mask path. The earlier code opened the dictionary key,
      # which is the input image itself, so the mask was never used.
      gt = PIL.Image.open(self.Y[key])
    else:
      gt = self.Y[key]

    if self.transform:
      image = self.transform(image)

    if self.isSeg:
      gt = self.transform_gt(gt)

    sample = {'segment': image, 'gt': gt}

    return sample

In [None]:
#### Model

import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
class Encoder(nn.Module):
    """Three-stage convolutional feature extractor.

    Each stage applies convolution, then ReLU, then batch normalisation (in
    that order). The second and third convolutions use stride 2, so the
    spatial size is reduced twice; the output has 32 channels.
    """

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

        # Stage 1: 3 -> 16 channels, spatial size preserved.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=16)

        # Stage 2: 16 -> 32 channels, stride 2.
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=2, padding=2)
        self.bn2 = nn.BatchNorm2d(num_features=32)

        # Stage 3: 32 -> 32 channels, stride 2.
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(num_features=32)

        # Average pooling with a 1x1 window: values and shape pass through unchanged.
        self.GAP = nn.AvgPool2d(kernel_size=1)

    def forward(self, x):
        """Run ``x`` through the three conv/ReLU/BN stages and the final pooling layer."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        features = x
        for conv, norm in stages:
            features = norm(self.relu(conv(features)))
        return self.GAP(features)



### TO VINCENT: Here is the classification module for the network. Change the last linear layer size to the number of classes in your dataset
class LeafSnap_Classification(nn.Module):
    """Two-layer fully connected classification head for the LeafSnap task.

    Args:
        num_classes: size of the output layer. Defaults to 185, the number of
            classes in the LeafSnap dataset.
        in_features: length of the flattened input feature map. Defaults to
            75*75*32, the Encoder output size for 300x300 inputs.
    """

    def __init__(self, num_classes=185, in_features=75*75*32):
        super(LeafSnap_Classification, self).__init__()
        self.fc1 = nn.Linear(in_features,512)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(512,num_classes)

    def forward(self, x):
        """Flatten ``x`` per sample and return unnormalised class scores."""
        # reshape (unlike view) also accepts non-contiguous feature maps.
        out = x.reshape(x.size(0), -1)

        out = self.fc1(out)
        out = self.relu4(out)
        out = self.fc2(out)

        return out

class Mendeley_Classification(nn.Module):
    """Two-layer fully connected classification head for the Mendeley task.

    Args:
        num_classes: size of the output layer. Defaults to 12.
        in_features: length of the flattened input feature map. Defaults to
            75*75*32, the Encoder output size for 300x300 inputs.
    """

    def __init__(self, num_classes=12, in_features=75*75*32):
        super(Mendeley_Classification, self).__init__()
        self.fc1 = nn.Linear(in_features,512)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(512,num_classes)

    def forward(self, x):
        """Flatten ``x`` per sample and return unnormalised class scores."""
        # reshape (unlike view) also accepts non-contiguous feature maps.
        out = x.reshape(x.size(0), -1)

        out = self.fc1(out)
        out = self.relu4(out)
        out = self.fc2(out)

        return out

class Bali_Classification(nn.Module):
    """Two-layer fully connected classification head for the main (Bali) task.

    Args:
        num_classes: size of the output layer. Defaults to 26.
        in_features: length of the flattened input feature map. Defaults to
            75*75*32, the Encoder output size for 300x300 inputs.
    """

    def __init__(self, num_classes=26, in_features=75*75*32):
        super(Bali_Classification, self).__init__()
        self.fc1 = nn.Linear(in_features,512)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(512,num_classes)

    def forward(self, x):
        """Flatten ``x`` per sample and return unnormalised class scores."""
        # reshape (unlike view) also accepts non-contiguous feature maps.
        out = x.reshape(x.size(0), -1)
        out = self.fc1(out)
        out = self.relu4(out)
        out = self.fc2(out)

        return out

class Decoder(nn.Module):
    """Transposed-convolution decoder producing a single-channel map.

    Two stride-2 transposed convolutions (each followed by ReLU, then batch
    normalisation) and a final 3x3 transposed convolution followed by ReLU.
    A 32-channel 75x75 input yields a 1-channel 300x300 output.
    """

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

        # 32 -> 32 channels, stride 2.
        self.conv1 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=32)

        # 32 -> 16 channels, stride 2, one extra output row/column.
        self.conv2 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=5, stride=2, padding=2, output_padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=16)

        # 16 -> 1 channel, no padding: grows each spatial side by 2.
        self.conv3 = nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=3)

    def forward(self, x):
        """Upsample the feature map ``x``; the result is non-negative (final ReLU)."""
        hidden = x
        for deconv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            hidden = norm(self.relu(deconv(hidden)))
        return self.relu(self.conv3(hidden))

class Leaf_FCN(nn.Module):
    """Shared encoder with three classification heads and a segmentation decoder.

    ``forward`` runs the encoder once and routes its features to exactly one
    head, chosen by the boolean flags.
    """

    def __init__(self):
        super(Leaf_FCN, self).__init__()

        self.encoder = Encoder()

        self.leafsnap = LeafSnap_Classification()
        self.main = Bali_Classification()
        self.mendelev = Mendeley_Classification()

        self.decoder = Decoder()

    def forward(self, x, isMain = True, isSnap = False, isSeg = False, isMend = False):
        """Encode ``x`` and apply the selected head.

        Flag precedence is ``isSeg`` > ``isMain`` > ``isSnap`` > ``isMend``.
        Because ``isMain`` defaults to True, it must be passed as False to
        reach the LeafSnap or Mendeley head.

        Raises:
            ValueError: if every flag is false (previously this surfaced as an
                UnboundLocalError on the return statement).
        """
        if isSeg:
            head = self.decoder
        elif isMain:
            head = self.main
        elif isSnap:
            head = self.leafsnap
        elif isMend:
            head = self.mendelev
        else:
            raise ValueError("Leaf_FCN.forward: one of isMain, isSnap, isSeg or isMend must be True")

        return head(self.encoder(x))

In [None]:
def randomInit(m):
    """Xavier-initialise ``m`` in place if it is an ``nn.Linear`` layer.

    The weight gets Xavier-uniform values and the bias, when the layer has
    one, is zeroed. Any other module type is left untouched, and the function
    does not descend into child modules: to initialise every linear layer of
    a network, call ``model.apply(randomInit)`` instead of ``randomInit(model)``.
    The status message is printed on every call, whether or not anything was
    initialised.
    """
    print("Model Randomly Initialized")
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        # Linear(..., bias=False) stores bias as None; zeros_ would fail on it.
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
        

def Joint_Loss(leafsnap_loss, leafsnap_seg_loss, mendelev_loss,mu=1):
    """Combine the classification, segmentation and Mendeley losses.

    Args:
        leafsnap_loss: classification loss, or None.
        leafsnap_seg_loss: segmentation loss, or None.
        mendelev_loss: Mendeley classification loss, or None.
        mu: weight of the segmentation term; the Mendeley term is weighted ``2*mu``.

    Returns:
        * ``leafsnap_loss`` alone when ``leafsnap_seg_loss`` is None
          (``mendelev_loss`` is ignored in this case);
        * ``leafsnap_seg_loss`` alone when only ``leafsnap_loss`` is None
          (``mendelev_loss`` is ignored in this case too);
        * otherwise ``mu*leafsnap_seg_loss + leafsnap_loss``, plus
          ``mendelev_loss*mu*2`` when ``mendelev_loss`` is given.
    """
    # Identity checks: `== None` on tensors/arrays is an elementwise comparison
    # and is not a reliable way to detect a missing loss.
    if leafsnap_seg_loss is None:
        return leafsnap_loss
    if leafsnap_loss is None:
        return leafsnap_seg_loss

    loss = mu*leafsnap_seg_loss + leafsnap_loss
    if mendelev_loss is not None:
        loss = loss + mendelev_loss*mu*2
    return loss


def _next_batch(batch_iter, loader):
    """Return ``(batch, iterator)``, (re)starting ``loader`` when needed."""
    if batch_iter is None:
        batch_iter = iter(loader)
    try:
        return next(batch_iter), batch_iter
    except StopIteration:
        # Loader exhausted: start a new pass over it.
        batch_iter = iter(loader)
        return next(batch_iter), batch_iter


def fit(model, loss1, loss2,loss3, optimizer, train_loader,seg_loader,mend_loader, val_loader, batch_size, num_epochs, scheduler = None, stat_count=100, device=None,num_ch=3,stepsize=10,PATH = '/content/drive/MyDrive/EEC205/FCN_std_path.pt'):
    """Jointly train ``model`` on three tasks and validate after every epoch.

    Before training, the weights are loaded from the hard-coded checkpoint
    ``resnet18_final_pt.pt`` and ``model.fc`` is replaced by a fresh
    ``nn.Linear(512, 26)``. Each step draws one batch from each loader, sums
    the three losses with ``Joint_Loss(..., mu=0.25)`` and takes one optimizer
    step. After each epoch the model is scored on ``val_loader`` as
    ``accuracy * prod(f1_c + 0.0001)`` over the 26 classes, and checkpoints are
    written to hard-coded paths.

    Args:
        model: network to train. It is called both as ``model(images)`` and
            with the keyword flags ``isMain``/``isSeg``/``isMend``/``isSnap``.
            NOTE(review): Leaf_FCN.forward accepts those flags, while a plain
            torchvision ResNet forward does not -- confirm which model is meant.
        loss1: criterion for the main classification batch.
        loss2: criterion for the segmentation batch.
        loss3: criterion for the Mendeley classification batch.
        optimizer: optimizer over ``model``'s parameters. The parameters of the
            replacement ``model.fc`` are added to it as a new parameter group.
        train_loader: yields batches indexed as ``batch[0]`` (images) and
            ``batch[1]`` (labels); its length defines the steps per epoch.
        seg_loader: yields dict batches with keys ``'segment'`` and ``'gt'``.
        mend_loader: yields dict batches with keys ``'segment'`` and ``'gt'``.
        val_loader: yields batches indexed like ``train_loader``.
        batch_size: unused; kept for interface compatibility.
        num_epochs: number of training epochs.
        scheduler: LR scheduler; defaults to ``StepLR(optimizer, stepsize, gamma=0.1)``.
        stat_count: print the loss every ``stat_count`` steps.
        device: target device; defaults to ``cuda:0`` when available, else CPU.
        num_ch: unused; kept for interface compatibility.
        stepsize: ``step_size`` of the default scheduler.
        PATH: unused; checkpoints go to the hard-coded ``resnet18_*.pt`` paths.

    Returns:
        ``(all_losses, loss_epoch, lastest_train_acc)``: the joint loss of every
        step, the last joint loss of every epoch, and the validation accuracy
        of every epoch.
    """
    num_classes = 26

    all_losses = []
    lastest_train_acc = []
    loss_epoch = []
    epoch_count = 0

    curr_model_score = 0

    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime;
    # load_state_dict copies the values into the model's existing tensors.
    checkpoint = torch.load('/content/drive/MyDrive/EEC205/resnet18_final_pt.pt', map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    model.fc = nn.Linear(512,num_classes)

    if device is None:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # The optimizer was built by the caller before model.fc was replaced, so
    # the new head's parameters are unknown to it and would never be updated.
    known_params = {id(p) for group in optimizer.param_groups for p in group['params']}
    new_head_params = [p for p in model.fc.parameters() if id(p) not in known_params]
    if new_head_params:
        optimizer.add_param_group({'params': new_head_params})

    # Only acts when `model` itself is an nn.Linear; otherwise it just prints.
    randomInit(model)

    if scheduler is None:
        scheduler = StepLR(optimizer, step_size=stepsize, gamma=0.1)

    num_steps = len(train_loader)

    # Iterators live across epochs; each is restarted when it runs out.
    labelled_iter = None
    seg_iter = None
    mend_iter = None

    for epoch in range(num_epochs):
        f_loss = None

        model.train()
        with torch.enable_grad():
            for train_ct in range(num_steps):
                data, labelled_iter = _next_batch(labelled_iter, train_loader)
                data_seg, seg_iter = _next_batch(seg_iter, seg_loader)
                data_mend, mend_iter = _next_batch(mend_iter, mend_loader)

                # Main classification task.
                images, labels = data[0].to(device,dtype=torch.float), data[1].to(device,dtype=torch.long)
                optimizer.zero_grad()
                outputs = model(images).float()
                loss = loss1(outputs, labels)

                # Segmentation task.
                images_seg, labels_seg = data_seg['segment'].to(device,dtype=torch.float), data_seg['gt'].to(device,dtype=torch.float)
                outputs_seg = model(images_seg, isMain=False,isSeg=True,isMend=False,isSnap=False).float()
                loss_seg = loss2(outputs_seg, labels_seg)

                # Mendeley classification task.
                images_mend, labels_mend = data_mend['segment'].to(device,dtype=torch.float), data_mend['gt'].to(device,dtype=torch.long)
                outputs_mend = model(images_mend, isMain=False,isSeg=False,isMend=True).float()
                loss_mend = loss3(outputs_mend, labels_mend)

                f_loss = Joint_Loss(loss,loss_seg,loss_mend,mu=0.25)
                f_loss.backward()
                optimizer.step()

                all_losses.append(f_loss.item())

                # Print statistics on every stat_count iteration
                if (train_ct+1) % stat_count == 0:
                    print('Epoch [%d/%d], Step [%d/%d],  Loss: %.4f'
                                %(epoch+1, num_epochs, train_ct+1,
                                len(train_loader), f_loss.item()))

        # Validation pass on the same device as training.
        model.eval()
        total_train = 0
        total_correct = 0
        train_predicted_full = []
        train_labels_full = []
        with torch.no_grad():
            for val_data in val_loader:
                valimages = val_data[0].to(device).float()
                vallabels = val_data[1].to(device)

                trainoutputs = model(valimages)

                _, trainpredicted = torch.max(trainoutputs.data, 1)
                total_train += vallabels.size(0)
                total_correct += (trainpredicted == vallabels).sum().item()

                train_predicted_full.extend(trainpredicted.cpu().numpy().tolist())
                train_labels_full.extend(vallabels.cpu().numpy().tolist())

        print("END OF EPOCH")
        class_dict = classification_report(train_labels_full, train_predicted_full, labels=list(range(num_classes)),output_dict=True)
        # Product of per-class F1 scores; the 0.0001 offset keeps a single
        # zero-F1 class from collapsing the whole product to exactly zero.
        score = 1.0
        for class_idx in range(num_classes):
            score *= class_dict[str(class_idx)]['f1-score'] + 0.0001
        total_acc = total_correct/total_train
        lastest_train_acc.append(total_acc)
        print("Model Score = ", total_acc*score)
        if curr_model_score < (total_acc*score):
            curr_model_score = total_acc*score
            print("Model Checkpoint saved!")
            torch.save({'model_state_dict': model.state_dict()}, '/content/drive/MyDrive/EEC205/resnet18_best.pt')

        epoch_count = epoch_count + 1

        # Periodic checkpoint: written every 5th epoch, despite the message text.
        if epoch_count % 5 == 0:
          print("Tenth Model Checkpoint saved!")
          torch.save({'model_state_dict': model.state_dict()}, '/content/drive/MyDrive/EEC205/resnet18_tenth.pt')

        # Save after the last epoch (the earlier `num_epochs - 1` test fired
        # one epoch early and never captured the fully trained model).
        if epoch_count == num_epochs:
          print("Final Model Checkpoint saved!")
          torch.save({'model_state_dict': model.state_dict()}, '/content/drive/MyDrive/EEC205/resnet18_final.pt')

        # Record this epoch's last loss (including the final epoch) and advance
        # the schedule so the next epoch runs with the updated learning rate.
        if f_loss is not None:
            loss_epoch.append(f_loss.item())
        scheduler.step()

    return all_losses, loss_epoch, lastest_train_acc

In [None]:
# Input pipeline for RGB images: resize to 300x300 and convert to a tensor.
# 300x300 is the input size for which the Encoder yields the 75x75x32 feature
# map that the classification heads' first linear layer expects.
transforms_all = transforms.Compose([
    transforms.Resize((300,300)),
    transforms.ToTensor()
])

# Target pipeline for segmentation masks: same resize, reduced to one
# grayscale channel before the tensor conversion.
transforms_seg = transforms.Compose([
    transforms.Resize((300,300)),
    transforms.Grayscale(),
    transforms.ToTensor()
])

In [None]:


# Classification view of the dataset under EEC205/dataset/: targets are integer species labels.
class_dataset = PreTrain_Dataset('/content/drive/MyDrive/EEC205/dataset/',transform=transforms_all,isSeg=False)

# Segmentation view of the same directory: targets are images passed through transforms_seg.
seg_dataset = PreTrain_Dataset('/content/drive/MyDrive/EEC205/dataset/',transform=transforms_all,transform_gt=transforms_seg,isSeg=True)

# Mendeley leaves; csv_file is accepted by the constructor but is not read by it.
mend_dataset = PreMendeleyTrain_Dataset(csv_file='/content/drive/MyDrive/EEC205/leaves_label.csv', root_dir='/content/drive/MyDrive/EEC205/hb74ynkjcn-1/', transform=transforms_all)

# Validation split, loaded with torchvision's ImageFolder (one sub-folder per class).
val_bali = torchvision.datasets.ImageFolder('/content/drive/MyDrive/EEC205/data/val',transform=transforms_all)

In [None]:
# Batch size shared by all loaders below.
bsize = 200


# Training-side loaders shuffle and drop the last incomplete batch.
classloader = DataLoader(class_dataset, batch_size=bsize,
                       shuffle=True, num_workers=0, drop_last = True)

segloader = DataLoader(seg_dataset, batch_size=bsize,
                       shuffle=True, num_workers=0, drop_last = True)

mendloader = DataLoader(mend_dataset, batch_size=bsize,
                       shuffle=True, num_workers=0, drop_last = True)

# Validation keeps a fixed order and every sample (no shuffle, no drop_last).
valloader = DataLoader(val_bali, batch_size=bsize,
                        shuffle=False, num_workers=0, drop_last = False)


In [None]:
# Passed to fit() as PATH; fit() currently writes to its own hard-coded
# resnet18_*.pt paths and does not use this value.
save_model_path = '/content/drive/MyDrive/EEC205/resnet18.pt'

#model = Leaf_FCN()

# ImageNet-pretrained ResNet-18 backbone.
model = torchvision.models.resnet18(pretrained=True)

# 185-way head; fit() loads a checkpoint into this model and then replaces
# model.fc with a 26-way layer.
model.fc = nn.Linear(512,185)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)

# loss1: main classification, loss2: segmentation, loss3: Mendeley classification.
loss1 = nn.CrossEntropyLoss()
loss2 = nn.MSELoss()
loss3 = nn.CrossEntropyLoss()
# Built over the parameters that exist now, i.e. before fit() swaps model.fc.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

epochs = 20

In [None]:
# Run the joint training loop.
# NOTE(review): `trainloader` is not defined in any cell shown in this notebook
# (the loaders cell defines classloader/segloader/mendloader/valloader). On a
# fresh kernel this raises NameError -- define it before this cell. fit() indexes
# its batches as batch[0]/batch[1], so it presumably wraps an ImageFolder; confirm.
all_losses, loss_epoch, lastest_train_acc = fit(model, loss1, loss2,loss3, optimizer, trainloader,segloader,mendloader, valloader, bsize, epochs, scheduler = None, stat_count=5, device=None,num_ch=3,stepsize=10,PATH = save_model_path)

Model Randomly Initialized
Epoch [1/20], Step [5/146],  Loss: 2.7011
Epoch [1/20], Step [10/146],  Loss: 2.1546
Epoch [1/20], Step [15/146],  Loss: 1.9317
Epoch [1/20], Step [20/146],  Loss: 1.4434
Epoch [1/20], Step [25/146],  Loss: 1.2672
Epoch [1/20], Step [30/146],  Loss: 1.1037
Epoch [1/20], Step [35/146],  Loss: 0.8996
Epoch [1/20], Step [40/146],  Loss: 0.8757
Epoch [1/20], Step [45/146],  Loss: 0.7156
Epoch [1/20], Step [50/146],  Loss: 0.4726
Epoch [1/20], Step [55/146],  Loss: 0.6084
Epoch [1/20], Step [60/146],  Loss: 0.4389
Epoch [1/20], Step [65/146],  Loss: 0.4317
Epoch [1/20], Step [70/146],  Loss: 0.3943
Epoch [1/20], Step [75/146],  Loss: 0.2721
Epoch [1/20], Step [80/146],  Loss: 0.3745
Epoch [1/20], Step [85/146],  Loss: 0.2922
Epoch [1/20], Step [90/146],  Loss: 0.2928
Epoch [1/20], Step [95/146],  Loss: 0.2307
Epoch [1/20], Step [100/146],  Loss: 0.2023
Epoch [1/20], Step [105/146],  Loss: 0.2018
Epoch [1/20], Step [110/146],  Loss: 0.2549
Epoch [1/20], Step [115/1