<a href="https://colab.research.google.com/github/kwantlin/ai4all-medicalai/blob/main/DRTest1ipynb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab only: mount Google Drive at /content/drive so the dataset files
# stored there can be read through ordinary filesystem paths.
from google.colab import drive
drive.mount('/content/drive')

import numpy as np
import pandas as pd
import os
from PIL import Image
import random

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm  ###### For progress bar


Mounted at /content/drive


In [None]:
class APTOSDataset(Dataset):
    """Image/label dataset driven by a CSV index file.

    Each item is an ``(image, label)`` pair. The image is read from
    ``<root_dir>/train_images/<column 0 of the row>.png`` and converted to
    RGB; the label is column 1 of the same CSV row.

    When an image file cannot be read, a different row is drawn at random and
    returned instead (image and label both come from the replacement row), so
    a handful of missing or corrupt files does not abort an epoch.

    Args:
        csv_file: Path to the CSV index (column 0: image id, column 1: label).
        root_dir: Directory that contains the ``train_images`` folder.
        transform: Optional callable applied to the loaded PIL image.
    """

    # How many replacement rows are tried before giving up. Bounding the
    # retries turns "every image is unreadable" (e.g. a wrong root_dir) into
    # a clear error instead of unbounded recursion.
    MAX_RETRIES = 10

    def __init__(self, csv_file, root_dir, transform=None):
        self.data_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        Raises:
            IndexError: If ``idx`` is outside the CSV index.
            RuntimeError: If the requested image and ``MAX_RETRIES`` randomly
                chosen replacements are all unreadable.
        """
        last_error = None
        for _ in range(self.MAX_RETRIES + 1):
            # Row lookup happens outside the try block so that a bad index
            # surfaces as IndexError instead of being silently replaced.
            image_id = self.data_frame.iloc[idx, 0]
            label = self.data_frame.iloc[idx, 1]
            img_name = os.path.join(self.root_dir, 'train_images', image_id + '.png')

            try:
                # The context manager closes the file handle once the pixel
                # data has been copied by convert().
                with Image.open(img_name) as raw_image:
                    image = raw_image.convert('RGB')
            except OSError as err:
                # OSError covers missing files and PIL's UnidentifiedImageError
                # (an OSError subclass). Unlike a bare ``except`` it lets
                # KeyboardInterrupt and genuine programming errors propagate.
                last_error = err
                # randrange excludes the upper bound; randint(0, len) could
                # return len itself, which is not a valid row.
                idx = random.randrange(len(self))
                continue

            # Applied outside the try block so transform bugs are not hidden.
            if self.transform:
                image = self.transform(image)
            return image, label

        raise RuntimeError(
            f"Could not read any image after {self.MAX_RETRIES + 1} attempts; "
            f"last failure was for {img_name!r}"
        ) from last_error


In [None]:
# Dataset folder on the mounted Drive and the CSV index that sits inside it.
data_path = '/content/drive/My Drive/aptos2019-blindness-detection'
csv_file = os.path.join(data_path, 'train.csv')

# Resize to 224x224, then apply the random augmentations to the PIL image.
_augmentations = [
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
]

# Convert to a tensor and normalise each channel.
# NOTE(review): the mean/std values are presumably the ImageNet statistics
# expected by the pretrained backbone - confirm.
_tensor_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

transform = transforms.Compose(_augmentations + _tensor_steps)


In [None]:
# Dataset over the training CSV, with the augmentation pipeline attached.
train_dataset = APTOSDataset(
    csv_file=csv_file,
    root_dir=data_path,
    transform=transform,
)

# Shuffled mini-batches of 32, loaded by two worker processes.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)


In [None]:
# Pretrained ResNet-50 whose final fully connected layer is replaced by a
# single-output linear head.
model = models.resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=1)

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model = model.to(device)


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 138MB/s]


In [None]:
# The network has a single output unit and the labels are cast to float, i.e.
# the label is regressed as a number, so mean squared error is the matching
# loss. nn.CrossEntropyLoss on a squeezed (N,) output with float targets would
# instead treat the whole batch as ONE sample with N classes and
# class-probability targets, which is not a meaningful objective here.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10

model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch") as pbar:
        for i, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device)
            # The loss needs float targets with the same shape as the outputs.
            labels = labels.to(device).float()

            optimizer.zero_grad()
            # squeeze(1) drops only the output dimension: (N, 1) -> (N,).
            # A bare squeeze() would also drop the batch dimension when N == 1.
            outputs = model(inputs).squeeze(1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Progress bar shows the mean loss over the batches seen so far
            # in this epoch.
            running_loss += loss.item()
            pbar.set_postfix(loss=running_loss / (i + 1))
            pbar.update(1)

print('Finished Training')


Epoch 1/10:   0%|          | 0/115 [00:03<?, ?batch/s]


KeyboardInterrupt: 

In [None]:
# Sanity check: count the entries in the train_images folder on Drive.
!ls '/content/drive/My Drive/aptos2019-blindness-detection/train_images/' | wc -l

3661


In [None]:
# Persist only the model's state_dict (not the module object itself). The
# relative filename resolves against the notebook's current working directory.
model_weights = model.state_dict()
torch.save(model_weights, 'diabetic_retinopathy_resnet50_model.pth')
