In [None]:
!wget --no-check-certificate https://storage.googleapis.com/emcassavadata/cassavaleafdata.zip 
!unzip /content/cassavaleafdata.zip

In [None]:
import os
import random
import numpy as np 

import torch
import torch.nn as nn 
import torch.optim as optim
import torch.nn.functional as F 
import torch.utils.data as data 

import torchvision.transforms as transforms
import torchvision.datasets as datasets

from torchsummary import summary

import matplotlib.pyplot as plt
from PIL import Image

In [None]:
# Dataset split locations, relative to the working directory.
data_paths = {
    'train': './train',
    'valid': './valid',
    'test': './test',
}


def loader(path):
    """Load an image file as a 3-channel RGB PIL image.

    The file is opened in a ``with`` block and the image is converted before the
    block exits. ``convert`` forces PIL to decode the pixels immediately, so the
    file handle is closed on return instead of lingering until garbage
    collection. Converting to RGB also gives every sample the same channel
    count; grayscale, palette or RGBA files would otherwise yield tensors that
    cannot be batched together.

    Args:
        path: Path to an image file.

    Returns:
        PIL.Image.Image: The decoded image in ``'RGB'`` mode.
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


# Side length (pixels) that every image is resized to.
img_size = 150

# Resize to img_size x img_size, then convert to a tensor. There is no random
# augmentation here, so the same pipeline is also used for the valid/test splits.
train_transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
])


def _make_split(split):
    """Build the ImageFolder dataset for one split key of ``data_paths``."""
    return datasets.ImageFolder(
        root=data_paths[split],
        loader=loader,
        transform=train_transform,
    )


train_data = _make_split('train')
valid_data = _make_split('valid')
test_data = _make_split('test')

In [None]:
class LeNetClass(nn.Module):
    """LeNet-style CNN: two conv/avg-pool/ReLU stages followed by three linear layers.

    Args:
        num_classes: Number of output scores produced per sample.
        in_channels: Channel count of the input images (3 for RGB). Defaults to 3.
        img_size: Side length of the square input images. It determines the
            input width of ``fc_1``. Defaults to 150.

    Raises:
        ValueError: If ``img_size`` is too small for the two conv/pool stages
            (must be at least 12).

    Input: float tensor of shape (N, in_channels, img_size, img_size).
    Output: tensor of shape (N, num_classes) with raw scores (no softmax applied).
    """

    def __init__(self, num_classes, in_channels=3, img_size=150):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=5, padding='same')
        self.avgpool1 = nn.AvgPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.avgpool2 = nn.AvgPool2d(kernel_size=2)
        self.flatten = nn.Flatten()

        # Spatial side after the feature extractor:
        #   conv1 ('same' padding) keeps it, avgpool1 halves it (floor),
        #   conv2 (5x5, no padding) removes 4, avgpool2 halves it again (floor).
        # For img_size=150: 150 -> 75 -> 71 -> 35.
        feat_side = (img_size // 2 - 4) // 2
        if feat_side < 1:
            raise ValueError(
                f'img_size={img_size} is too small for LeNetClass (need >= 12)'
            )

        self.fc_1 = nn.Linear(16 * feat_side * feat_side, 120)
        self.fc_2 = nn.Linear(120, 84)
        self.fc_3 = nn.Linear(84, num_classes)

    def forward(self, inputs):
        """Return raw class scores of shape (N, num_classes) for a batch of images."""
        outputs = F.relu(self.avgpool1(self.conv1(inputs)))
        outputs = F.relu(self.avgpool2(self.conv2(outputs)))

        outputs = self.flatten(outputs)
        # Non-linearities between the dense layers; without them fc_1..fc_3
        # compose into a single linear map.
        outputs = F.relu(self.fc_1(outputs))
        outputs = F.relu(self.fc_2(outputs))
        return self.fc_3(outputs)