##Setup

In [2]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import pandas as pd
import matplotlib.pyplot as plt
import os
import zipfile 
import gdown
from natsort import natsorted
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

##Data

In [3]:
os.getcwd()

'/Users/kainoajim/Desktop'

In [4]:
train_labels = pd.read_csv('./trainLabels.csv')


In [5]:
image_names = os.listdir("./archive/0")
image_names.sort()

In [6]:
images_path = "./archive/0"
img_path = os.path.join(images_path, train_labels.iloc[3500].image+".jpeg")
img_path

'./archive/0/4375_left.jpeg'

In [7]:
transform=transforms.Compose([
    transforms.Resize((299,299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])
])

In [8]:
class DRDataset(Dataset):
    def __init__(self, imagepath=images_path, total=None,transform=transform):
        self.df = pd.read_csv('./trainLabels.csv')
        
        self.transform = transform
        self.imagepath = imagepath
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        img_path = os.path.join(self.imagepath, self.df.iloc[index].image +".jpeg")
        img = Image.open(img_path)
        
        if(self.transform):
            img = self.transform(img)
        
        return img, torch.tensor(self.df.iloc[index].level)

In [27]:
total_data = DRDataset(total=35126)

generator = torch.Generator().manual_seed(42)
train_data,test_data = torch.utils.data.random_split(total_data, [0.7, 0.3], generator=generator)

24589
10537


In [28]:
NUM_EPOCHS = 1
NUM_CLASSES = 5
BATCH_SIZE = 32
lr = 0.001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

##Model

In [29]:

from efficientnet_pytorch import EfficientNet
model_efficient = EfficientNet.from_pretrained('efficientnet-b0')

Loaded pretrained weights for efficientnet-b0


In [30]:
model_efficient._fc = torch.nn.Linear(in_features=1280, out_features=5, bias=True)

##Optimize/Loss

In [31]:
optimizer = optim.SGD(model_efficient.parameters(), lr=lr, momentum=0.9)

In [32]:
criterion = nn.CrossEntropyLoss()

##Training

In [33]:
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
train_loader_at_eval = data.DataLoader(dataset=train_data, batch_size=2*BATCH_SIZE, shuffle=False)
test_loader = data.DataLoader(dataset=test_data, batch_size=2*BATCH_SIZE, shuffle=False)

In [34]:
for epoch in range(NUM_EPOCHS):
    
    model_efficient.train()
    for inputs, targets in tqdm(train_loader):
        # forward + backward + optimize
        optimizer.zero_grad()
        outputs = model_efficient(inputs)
        
        targets = targets.to(torch.long)
        loss = criterion(outputs, targets)
        
        loss.backward()
        optimizer.step()

100%|██████████| 769/769 [1:51:31<00:00,  8.70s/it]  


##Testing

In [35]:

def check_accuracy(model, loader):
    model_efficient.eval()
    
    correct_output = 0
    total_output = 0
    
    with torch.no_grad():
        for x, y in tqdm(loader):
            x = x.to(device=device)
            y = y.to(device=device)
            
            score = model_efficient(x)
            _,predictions = score.max(1)
            
            correct_output += (y==predictions).sum()
            total_output += predictions.shape[0]
    # model_efficient.train()
    print(f"out of {total_output} , total correct: {correct_output} with an accuracy of {float(correct_output/total_output)*100}")

print('train evaluation\n')
check_accuracy(model_efficient, train_loader_at_eval)
print('test evaluation\n')
check_accuracy(model_efficient, test_loader)

train evaluation



100%|██████████| 385/385 [36:04<00:00,  5.62s/it]


out of 24589 , total correct: 18253 with an accuracy of 74.23238158226013
test evaluation



100%|██████████| 165/165 [15:21<00:00,  5.59s/it]

out of 10537 , total correct: 7796 with an accuracy of 73.98690581321716



