In [None]:
import os
import numpy as np
from sklearn.metrics import confusion_matrix
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import random
import shutil

SEED = 0

# Seed every RNG the notebook touches so that a fresh Restart & Run All
# reproduces the same split, initialisation and training trajectory.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Cover every visible GPU, not just the current one.
torch.cuda.manual_seed_all(SEED)
# Deterministic cuDNN kernels only; benchmark mode auto-tunes algorithm choice
# at runtime and would undo the determinism, so switch it off explicitly.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
data_path = "/kaggle/input/mushrooms-classification-common-genuss-images/Mushrooms"  # one sub-folder per genus

In [None]:
# Temporary folders for training and test images.
# Start from a clean slate so this cell can be re-run: os.mkdir would fail on
# folders left by a previous run, and stale images from an earlier split would
# otherwise leak into the new one.
shutil.rmtree('/kaggle/temp', ignore_errors=True)
os.makedirs('/kaggle/temp/train', exist_ok=True)
os.makedirs('/kaggle/temp/test', exist_ok=True)

In [None]:
# Split images (75%/25%) per class and copy them into the temporary folders.
def split_class_folders(src_root, train_root, test_root, train_fraction=0.75, seed=0):
    """Copy each ``src_root/<class>/`` folder into train/test class folders.

    File names are sorted and then shuffled with a private, seeded RNG, so the
    split is random but identical on every run. (``os.listdir`` order is
    arbitrary, so slicing it directly is not reproducible.) The global
    ``random`` state is left untouched.

    Args:
        src_root: folder containing one sub-folder per class.
        train_root: destination root for the training images.
        test_root: destination root for the test images.
        train_fraction: fraction of each class's files that go to training.
        seed: seed for the shuffling RNG.

    Returns:
        dict mapping class name -> (number of train files, number of test files).
    """
    rng = random.Random(seed)
    counts = {}
    for class_name in sorted(os.listdir(src_root)):
        class_dir = os.path.join(src_root, class_name)
        if not os.path.isdir(class_dir):
            continue  # ignore stray files next to the class folders

        files = sorted(
            name for name in os.listdir(class_dir)
            if os.path.isfile(os.path.join(class_dir, name))
        )
        rng.shuffle(files)
        n_train = int(len(files) * train_fraction)

        for subset_root, subset in ((train_root, files[:n_train]), (test_root, files[n_train:])):
            dest_dir = os.path.join(subset_root, class_name)
            os.makedirs(dest_dir, exist_ok=True)  # no chdir, safe on re-run
            for name in subset:
                shutil.copyfile(os.path.join(class_dir, name), os.path.join(dest_dir, name))

        counts[class_name] = (n_train, len(files) - n_train)
    return counts


split_counts = split_class_folders(data_path, '/kaggle/temp/train', '/kaggle/temp/test')
split_counts

In [None]:
def count_files(folder):
    """Return the number of regular files directly inside ``folder``.

    Uses absolute entries from ``os.scandir``, so it neither depends on nor
    changes the current working directory.
    """
    with os.scandir(folder) as entries:
        return sum(1 for entry in entries if entry.is_file())


# Sanity check of the split for one class.
print(count_files('/kaggle/temp/train/Lactarius'))
print(count_files('/kaggle/temp/test/Lactarius'))

In [None]:
train_root, test_root = '/kaggle/temp/train', '/kaggle/temp/test'

# Per-channel normalisation constants (the standard ImageNet statistics used by
# torchvision's pretrained models); also reused by show_input to undo Normalize.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Identical preprocessing for both splits: central 550x550 crop, resize to
# 224x224, convert to a tensor, then normalise.
# NOTE(review): CenterCrop pads with zeros if an image is smaller than 550 px.
preprocessing_steps = [
    transforms.CenterCrop(550),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
]
data_transforms = transforms.Compose(preprocessing_steps)
del preprocessing_steps

train_dataset, test_dataset = (
    torchvision.datasets.ImageFolder(root, data_transforms)
    for root in (train_root, test_root)
)

In [None]:
batch_size = 40
# Worker processes are a CPU resource and unrelated to the batch size; 40
# workers on a small machine (e.g. a Kaggle kernel) waste RAM and can exhaust
# shared memory. Cap at the available cores.
num_workers = min(4, os.cpu_count() or 1)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

# Evaluation does not need shuffling; a fixed order keeps it reproducible.
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

print('number of batches for training:', len(train_dataloader),
      '\nnumber of batches for testing:', len(test_dataloader),
      '\nnumber of training images:', len(train_dataset),
      '\nnumber of test images:', len(test_dataset))

In [None]:
# Take the class names from ImageFolder instead of hard-coding them, so the
# index -> name mapping is guaranteed to match the labels the loaders emit.
assert train_dataset.classes == test_dataset.classes, "train/test class folders differ"
class_names = train_dataset.classes

# Keys are the str() of a 0-d label tensor (e.g. "tensor(0)"), because labels
# are looked up via str(label_tensor) when previewing a batch.
ids = [str(torch.tensor(idx)) for idx in range(len(class_names))]

dict_class_names = dict(zip(ids, class_names))
print(dict_class_names)

In [None]:
def show_input(input_tensor, title):
    """Display one normalised (C, H, W) image tensor with ``title``.

    Undoes ``Normalize(mean, std)`` using the module-level ``mean``/``std``
    arrays and clips to [0, 1] before plotting.
    """
    # CHW -> HWC for imshow; detach/cpu so tensors on the GPU also work.
    image = input_tensor.detach().cpu().permute(1, 2, 0).numpy()
    image = std * image + mean
    plt.imshow(image.clip(0, 1))
    plt.title(title)
    plt.show()

# Preview one shuffled training batch, titling each image with its class name.
X_batch, y_batch = next(iter(train_dataloader))

for idx in range(len(X_batch)):
    show_input(X_batch[idx], dict_class_names[str(y_batch[idx])])

In [None]:
# Train on the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Display the selected torch device as the cell output (quick sanity check).
device

In [None]:
class LeNet(torch.nn.Module):
    """LeNet-5-style CNN for square 3-channel images.

    The network has two conv -> ReLU -> max-pool stages, followed by flatten
    and dropout, then a three-layer fully connected head producing logits.

    Args:
        num_classes: number of output logits (default 9).
        input_size: side length of the square input images. Used to size the
            first fully connected layer; the default 224 matches the
            preprocessing transforms.

    Raises:
        ValueError: if ``input_size`` is too small to survive the conv/pool
            stages.
    """

    def __init__(self, num_classes=9, input_size=224):
        super().__init__()

        # padding=2 keeps the spatial size unchanged through the 5x5 conv.
        self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, padding=2)
        self.act1 = torch.nn.ReLU()
        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=2, stride=2)

        # No padding: the 5x5 conv shrinks each side by 4.
        self.conv2 = torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, padding=0)
        self.act2 = torch.nn.ReLU()
        self.maxpool2 = torch.nn.MaxPool2d(kernel_size=2, stride=2)

        self.flatten = torch.nn.Flatten()
        self.dropout = torch.nn.Dropout(p=0.25, inplace=False)

        # Side length after conv1 (same) -> pool(/2) -> conv2 (-4) -> pool(/2).
        # For input_size=224 this is 54, i.e. 16 * 54 * 54 = 46656 features.
        side = (input_size // 2 - 4) // 2
        if side < 1:
            raise ValueError(f"input_size={input_size} is too small for LeNet")

        self.fc1 = torch.nn.Linear(in_features=16 * side * side, out_features=120)
        self.act3 = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(in_features=120, out_features=84)
        self.act4 = torch.nn.ReLU()
        self.fc3 = torch.nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x):
        """Map a (batch, 3, input_size, input_size) tensor to (batch, num_classes) logits."""
        x = self.maxpool1(self.act1(self.conv1(x)))
        x = self.maxpool2(self.act2(self.conv2(x)))

        x = self.dropout(self.flatten(x))

        x = self.act3(self.fc1(x))
        x = self.act4(self.fc2(x))
        return self.fc3(x)

# Instantiate the from-scratch baseline and move it to the training device.
lenet = LeNet().to(device=device)

In [None]:
import time

def train_model(model, num_epochs, dataloader=None, device=None,
                lr=1.0e-3, step_size=5, gamma=0.3):
    """Train ``model`` in place with Adam, a StepLR schedule and cross-entropy loss.

    Only parameters with ``requires_grad=True`` are handed to the optimizer, so
    frozen backbones stay untouched.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        num_epochs: number of passes over ``dataloader``.
        dataloader: iterable of ``(inputs, labels)`` batches; defaults to the
            notebook's global ``train_dataloader``.
        device: where to move each batch; defaults to the device of the
            model's parameters.
        lr: initial Adam learning rate.
        step_size: number of epochs between learning-rate decays.
        gamma: factor the learning rate is multiplied by at each decay.

    Returns:
        The same ``model`` object, trained in place.
    """
    if dataloader is None:
        dataloader = train_dataloader
    if device is None:
        device = next(model.parameters()).device

    criterion = torch.nn.CrossEntropyLoss()
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=lr)

    # Multiply the LR by `gamma` every `step_size` epochs (default: x0.3 every 5 epochs).
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    for epoch in range(num_epochs):
        print('Epoch {}/{}:'.format(epoch, num_epochs - 1), flush=True)
        start_time = time.time()

        model.train()  # enable dropout / batch-norm updates

        running_loss = 0.0
        running_correct = 0
        n_seen = 0

        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = model(inputs)
            loss_value = criterion(logits, labels)
            loss_value.backward()
            optimizer.step()

            # Weight by batch size so a smaller final batch is not over-counted
            # when averaging over the epoch.
            batch_n = labels.size(0)
            running_loss += loss_value.item() * batch_n
            running_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_seen += batch_n

        scheduler.step()

        epoch_loss = running_loss / max(n_seen, 1)
        epoch_acc = running_correct / max(n_seen, 1)
        print('Loss: {:.4f} Acc: {:.4f} Time: {:.4f}'.format(
            epoch_loss, epoch_acc, time.time() - start_time), flush=True)

    return model

In [None]:
from torchvision import models


def freeze_parameters(model):
    """Disable gradients for every parameter of ``model`` and return it (feature extraction)."""
    for param in model.parameters():
        param.requires_grad = False
    return model


def make_head(in_features, num_classes, hidden=256):
    """Return a fresh trainable Linear -> ReLU -> Linear classification head."""
    return torch.nn.Sequential(
        torch.nn.Linear(in_features, hidden),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden, num_classes),
    )


# One output per class folder, rather than a hard-coded count.
num_classes = len(class_names)

# NOTE: `pretrained=True` is deprecated in torchvision>=0.13 in favour of
# `weights=...`; it is kept here for compatibility with older environments.

#------------------------------------------
# ResNet152: frozen backbone, new trainable `fc` head
#------------------------------------------
resnet152 = freeze_parameters(models.resnet152(pretrained=True))
resnet152.fc = make_head(resnet152.fc.in_features, num_classes)
resnet152 = resnet152.to(device)
#------------------------------------------


#------------------------------------------
# MobileNetV2: frozen backbone, replace the final Linear of `classifier`
#------------------------------------------
mobilenet = freeze_parameters(models.mobilenet_v2(pretrained=True))
mobilenet.classifier[1] = make_head(mobilenet.classifier[1].in_features, num_classes)
mobilenet = mobilenet.to(device)
#------------------------------------------
#------------------------------------------

In [None]:
from PIL import ImageFile
# Let PIL load image files that are truncated instead of raising an error on them.
# NOTE(review): the batch preview earlier already decodes images via
# next(iter(train_dataloader)) before this flag is set; consider moving this
# line to the imports/config cell so every image load is covered.
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [None]:
# Train ResNet152's new head (the backbone is frozen).
resnet152 = train_model(model=resnet152, num_epochs=50)

In [None]:
# Train MobileNetV2's new head (the backbone is frozen).
mobilenet = train_model(model=mobilenet, num_epochs=50)

In [None]:
# Train the LeNet baseline from scratch.
lenet = train_model(model=lenet, num_epochs=10)

In [None]:
def accuracy(model, dataloader=None, device=None):
    """Return the fraction of samples in ``dataloader`` that ``model`` classifies correctly.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        dataloader: iterable of ``(inputs, labels)`` batches; defaults to the
            notebook's global ``test_dataloader``.
        device: where to run inference; defaults to the device of the model's
            parameters.

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    if dataloader is None:
        dataloader = test_dataloader
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            labels = labels.to(device)
            preds = model(inputs.to(device)).argmax(dim=1)
            # Vectorised count instead of a per-sample .cpu().numpy() round trip.
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("dataloader yielded no samples")
    return correct / total

In [None]:
# Test-set accuracy of each trained model.
for caption, trained_net in (("ResNet's accuracy: ", resnet152),
                             ("MobileNet's accuracy: ", mobilenet),
                             ("LeNet's accuracy: ", lenet)):
    print(caption, accuracy(trained_net))

In [None]:
def get_target_and_prediction(model, dataloader=None, device=None):
    """Collect true labels and predicted class indices over ``dataloader``.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        dataloader: iterable of ``(inputs, labels)`` batches; defaults to the
            notebook's global ``test_dataloader``.
        device: where to run inference; defaults to the device of the model's
            parameters.

    Returns:
        ``(targets, predictions)``: two equal-length lists of Python ints, in
        the order the dataloader yields samples.
    """
    if dataloader is None:
        dataloader = test_dataloader
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    targets = []
    predictions = []
    with torch.no_grad():
        for inputs, labels in dataloader:
            preds = model(inputs.to(device)).argmax(dim=1)
            # Whole-batch conversion instead of one .cpu().numpy() per sample.
            targets.extend(labels.tolist())
            predictions.extend(preds.cpu().tolist())

    return targets, predictions

In [None]:
def create_confusion_matrix(model, labels=None):
    """Return the test-set confusion matrix of ``model`` (rows: true, cols: predicted).

    Args:
        model: classifier evaluated via ``get_target_and_prediction``.
        labels: class indices defining the rows/columns. Defaults to every
            index of ``class_names``, so the matrix is always square with one
            row/column per class. This holds even if a class never appears
            among the targets or predictions, which keeps it aligned with the
            plot's tick labels.
    """
    if labels is None:
        labels = list(range(len(class_names)))
    targets, predictions = get_target_and_prediction(model)
    return confusion_matrix(targets, predictions, labels=labels)

In [None]:
def plot_confusion_matrix(model, normalize=False):
    """Draw the test-set confusion matrix of ``model`` on the current pyplot figure.

    Args:
        model: classifier evaluated via ``create_confusion_matrix``.
        normalize: if True, divide each row by its total so cells show
            per-true-class rates. Rows with no samples stay 0 instead of
            becoming NaN.
    """
    cm = create_confusion_matrix(model)
    cmap = plt.cm.Blues

    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Guarded division: an empty row would otherwise produce NaN/inf.
        cm = np.divide(cm.astype(float), row_totals,
                       out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)

    # Annotate each cell; use white text on dark (above half-max) cells.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Lay out after the axis labels exist so they are accounted for, not clipped.
    plt.tight_layout()

In [None]:
# ResNet152's confusion matrix on the test set
plt.figure(figsize=(10, 10))
# Set the title first so the layout step inside plot_confusion_matrix accounts for it.
plt.title("ResNet152 confusion matrix (test set)")
plot_confusion_matrix(resnet152)

In [None]:
# MobileNetV2's confusion matrix on the test set
plt.figure(figsize=(10, 10))
# Set the title first so the layout step inside plot_confusion_matrix accounts for it.
plt.title("MobileNetV2 confusion matrix (test set)")
plot_confusion_matrix(mobilenet)

In [None]:
# LeNet's confusion matrix on the test set
plt.figure(figsize=(10, 10))
# Set the title first so the layout step inside plot_confusion_matrix accounts for it.
plt.title("LeNet confusion matrix (test set)")
plot_confusion_matrix(lenet)