In [1]:
import torch
import torch.nn as nn
import torchvision

%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt

import time
import os
# Tell the Intel OpenMP runtime to tolerate a second copy of itself in the
# process instead of aborting. The key must be spelled exactly
# KMP_DUPLICATE_LIB_OK; the previous key contained a stray '?' and was
# therefore a different, unused variable.
# NOTE(review): this is set after torch/numpy are imported above; it may need
# to move ahead of those imports to be honoured — confirm on the target machine.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

from dataset import *
from transforms import *
from criteria import *
from torch.utils.data import DataLoader, random_split

import pytorch_lightning as pl
from model_pl import PLWrapper

# Seed PyTorch's default random generator so repeated runs start from the same
# RNG state. Kept as the last expression so the cell still echoes the generator.
SEED = 42
torch.manual_seed(SEED)

<torch._C.Generator at 0x7eff780f1cd0>

## Dataset

In [2]:
def _to_luminance(x):
    """Collapse the last axis into one plane as a weighted sum of entries 0, 1, 2.

    The weights match the Rec. 709 luma coefficients; assumes the last axis is
    in R, G, B order — TODO confirm against the dataset loader.
    """
    return 0.2126*x[...,0] + 0.7152*x[...,1] + 0.0722*x[...,2]


def _scale_to_int(x):
    """Multiply by 255 and cast to integers (truncating).

    Presumably the input is a float image in [0, 1] — verify against the loader.
    """
    return (x*255).astype(int)


def _min_max_normalize(x):
    """Linearly rescale ``x`` so its minimum maps to 0 and its maximum to 1.

    A constant image (max == min) would divide by zero and yield NaNs; in that
    case the divisor is replaced by 1 so the result is all zeros instead.
    """
    lo = x.min()
    span = x.max() - lo
    if span == 0:
        span = 1
    return (x - lo) / span


def _add_channel_axis_fp16(x):
    """Prepend a length-1 axis and cast to float16."""
    return np.expand_dims(x,0).astype(np.float16)


def _parse_label(y):
    """Turn the textual label ``y`` into an int64 NumPy array.

    ``np.int64`` replaces ``np.long``, which no longer exists in NumPy >= 1.24.
    """
    # NOTE(review): eval() executes whatever expression the label file contains.
    # If labels are plain literals, ast.literal_eval would be a safer drop-in —
    # confirm the label format before switching.
    return np.array(eval(y)).astype(np.int64)


# Image pipeline: grayscale -> integer range -> CLAHE -> resize/crop to 384x384
# -> [0, 1] normalisation -> add leading axis as float16.
x_transform = torchvision.transforms.Compose([
    _to_luminance,
    _scale_to_int,
    CLAHE(40.0, (32,32)),
    DownsampleShortAxis(384),
    PadOrCenterCrop((384,384)),
    _min_max_normalize,
    _add_channel_axis_fp16,
])
# Label pipeline: parse the label string into an integer array.
y_transform = torchvision.transforms.Compose([_parse_label])

# Root folder holding the split list files and the image directory.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
data_dir = "/media/gcodes/NVME/diabetic-retinopathy-detection/data/"
# os.path.join avoids the doubled separator that `data_dir + "/dr_imgs"` produced.
img_dir = os.path.join(data_dir, "dr_imgs")


def _make_dataset(split):
    """Build the SimpleDataset for one split, reading x_<split>.txt and t_<split>.txt."""
    return SimpleDataset(os.path.join(data_dir, f"x_{split}.txt"),
                         os.path.join(data_dir, f"t_{split}.txt"),
                         x_transform=x_transform, y_transform=y_transform,
                         x_path_prefix=img_dir)


ds_train = _make_dataset("train")
ds_val = _make_dataset("val")

# Settings shared by both loaders; only batch size and shuffling differ.
_loader_kwargs = dict(num_workers=6, pin_memory=True)
dl_train = DataLoader(ds_train, batch_size=12, shuffle=True, **_loader_kwargs)
dl_val = DataLoader(ds_val, batch_size=24, shuffle=False, **_loader_kwargs)

## Model

In [3]:
# Start from an untrained torchvision ResNet-50 and swap out both of its ends.
backbone = torchvision.models.resnet50(pretrained=False, progress=True)

# New stem accepting a single input channel, used in place of the first child module.
stem = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# Every child between the first and the last is reused as-is.
trunk = list(backbone.children())[1:-1]
# New 5-output linear layer, used in place of the last child module.
head = nn.Linear(2048,5)

model = nn.Sequential(stem,
                      *trunk,
                      nn.Flatten(),  # or shapes won't work out
                      head)

In [5]:
# Total number of scalar entries across all parameter tensors of the network.
n_params = sum(param.numel() for param in model.parameters())
n_params

23512005

In [6]:
# Optional warm start, currently disabled. If re-enabled with load_model = True,
# it would load the state dict stored in `model_file` into `model`, remapping
# tensors to the CPU when CUDA is not available.
# load_model = False
# model_file = "models_gn_8_1-9/model_e200.pkl"
# if load_model:
#     if torch.cuda.is_available():
#         model.load_state_dict(torch.load(model_file))
#     else:
#         model.load_state_dict(torch.load(model_file, map_location=torch.device('cpu')))

## Training

In [7]:
# Cross-entropy loss over the model's raw outputs, averaged over the batch
# (reduction="mean" is the library default, spelled out here for clarity).
criterion = nn.CrossEntropyLoss(reduction="mean")

In [8]:
# Checkpoint callback that tracks the logged "val_loss" metric and writes under
# the experiment folder below.
ckpt_target = "../dr_experiments/exp3/"
checkpoint = pl.callbacks.ModelCheckpoint(filepath=ckpt_target, monitor="val_loss")

In [9]:
# Single-GPU trainer running in 16-bit precision, with the checkpoint callback attached.
# amp_level is an Apex optimisation-level string ('O0'..'O3'); the integer 1
# used before is not a valid level. With native AMP the value is not used, so
# this only matters if the Apex backend is selected.
trainer = pl.Trainer(precision=16, amp_level='O1', gpus=1, checkpoint_callback=checkpoint)
# Wrap the network, loss and both data loaders for Lightning.
# NOTE(review): plot_loss=True presumably drives the live loss figure shown
# below this cell — confirm in model_pl.PLWrapper.
model_pl = PLWrapper(model, criterion, dl_train, dl_val, plot_loss=True)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [None]:
# Launch training of the wrapped model with the trainer configured above.
trainer.fit(model=model_pl)


  | Name      | Type             | Params
-----------------------------------------------
0 | model     | Sequential       | 23 M  
1 | criterion | CrossEntropyLoss | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

# Tests

In [10]:
# Write a checkpoint from the trainer to a file in the current working directory.
final_ckpt = "exp3.ckpt"
trainer.save_checkpoint(final_ckpt)