In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
from torchvision import transforms, datasets
from torchvision.utils import make_grid

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss, TopKCategoricalAccuracy
from ignite.handlers import EarlyStopping, TerminateOnNan, ModelCheckpoint, Timer

import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

In [None]:
# Per-split preprocessing: convert each image to a tensor, then normalise
# each of the 3 RGB channels with mean=[0.485, 0.456, 0.406] and
# std=[0.229, 0.224, 0.225].
# ToTensor() is what brings the pixel values into the [0, 1] range that the
# normalisation constants assume.
# The input is expected as mini-batches of 3-channel RGB images of shape
# (3 x H x W), with H and W at least 224.
# NOTE(review): no resize/crop is applied here, so the images on disk are
# presumably already large enough and uniformly sized -- confirm.
# Every split gets its own (identical) pipeline object.
data_transforms = {
    split: transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    for split in ('train', 'valid', 'test')
}

# Root directory of the image data; each split is read from the sub-folder
# of the same name (e.g. <data_loc>/train).
data_loc = '/work/m23ss/m23ss/liyiyan/plankton/CPRNet/data/'

# One ImageFolder dataset per active split, each using the preprocessing
# pipeline registered under the same key in `data_transforms`.
# Only 'train' and 'valid' are loaded here; the 'test' split is not.
zoo_datasets = {
    split: datasets.ImageFolder(root=os.path.join(data_loc, split),
                                transform=data_transforms[split])
    for split in ('train', 'valid')
}

# One DataLoader per dataset, batches of 24 images, 4 worker processes.
# NOTE(review): the validation loader is shuffled as well; this is kept
# as-is, but shuffling is usually only needed for training -- confirm intent.
dataloaders = {
    split: torch.utils.data.DataLoader(
        dataset,
        batch_size=24,
        shuffle=True,
        num_workers=4,
    )
    for split, dataset in zoo_datasets.items()
}

# Number of samples in each split.
dataset_sizes = {split: len(dataset) for split, dataset in zoo_datasets.items()}

# Class labels as reported by the training dataset's `classes` attribute.
class_names = zoo_datasets['train'].classes

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")