In [1]:
import os
from urllib import request
from zipfile import ZipFile

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import datasets, models, transforms

In [3]:
# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [4]:
from skorch import NeuralNetClassifier
from skorch.helper import predefined_split

# Seed PyTorch's global random number generator with a fixed value.
# The trailing ';' suppresses the cell's output for this expression.
torch.manual_seed(360);

In [5]:
# Data-loading settings consumed by the NeuralNetClassifier configuration below.
BATCH_SIZE = 32  # passed as `batch_size`
NUM_WORKERS = 16  # passed as `num_workers` to both the train and valid iterators

In [6]:
# Print the name of every CUDA device visible to this process (nothing on CPU-only hosts).
gpu_count = torch.cuda.device_count()
for gpu_index in range(gpu_count):
    print(torch.cuda.get_device_name(gpu_index))

Tesla V100-PCIE-32GB
Tesla V100-PCIE-32GB
Tesla V100-PCIE-32GB


# Load data

In [7]:
def download_and_extract_data(dataset_dir=""):
    """Download the hymenoptera archive (if needed) and unpack it into ``dataset_dir``.

    Nothing is done when ``<dataset_dir>/hymenoptera_data`` already exists. If the
    zip file is already present it is reused instead of being downloaded again.

    Args:
        dataset_dir: Directory that receives both the zip file and the extracted
            ``hymenoptera_data`` folder. The empty string (default) means the
            current working directory. The directory is created when missing.

    Raises:
        urllib.error.URLError: If the download fails.
        zipfile.BadZipFile: If the archive on disk is not a valid zip file.
    """
    data_zip = os.path.join(dataset_dir, "hymenoptera_data.zip")
    data_path = os.path.join(dataset_dir, "hymenoptera_data")
    url = "https://download.pytorch.org/tutorial/hymenoptera_data.zip"

    if not os.path.exists(data_path):
        # A non-empty target directory may not exist yet; create it so that
        # writing the archive does not fail with FileNotFoundError.
        if dataset_dir:
            os.makedirs(dataset_dir, exist_ok=True)

        if not os.path.exists(data_zip):
            print("Starting to download data...")
            # Stream into a temporary ".part" file and rename it only once the
            # transfer is complete, so an interrupted download is never mistaken
            # for a finished archive on the next run.
            partial_zip = data_zip + ".part"
            with request.urlopen(url, timeout=300) as response, open(partial_zip, "wb") as f:
                while True:
                    chunk = response.read(1024 * 1024)
                    if not chunk:
                        break
                    f.write(chunk)
            os.replace(partial_zip, data_zip)

        print("Starting to extract data...")
        with ZipFile(data_zip, "r") as zip_f:
            zip_f.extractall(dataset_dir)

    print("Data has been downloaded and extracted to {}.".format(dataset_dir))


# Fetch the dataset into the current working directory (default dataset_dir="").
download_and_extract_data()

Data has been downloaded and extracted to .


In [8]:
data_dir = "hymenoptera_data"

# Per-channel mean / std used by the Normalize step of both pipelines.
# NOTE(review): these look like the usual ImageNet statistics — confirm.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training pipeline: random crop + random horizontal flip, then tensor conversion.
train_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
train_transforms = transforms.Compose(train_steps)

# Validation pipeline: fixed resize + centre crop, no random steps.
val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
val_transforms = transforms.Compose(val_steps)

# Each split lives in its own sub-folder of data_dir.
train_root = os.path.join(data_dir, "train")
val_root = os.path.join(data_dir, "val")
train_ds = datasets.ImageFolder(root=train_root, transform=train_transforms)
val_ds = datasets.ImageFolder(root=val_root, transform=val_transforms)

In [9]:
class PretrainedModel(nn.Module):
    """ResNet-18 whose final fully connected layer is replaced by a new linear head.

    Args:
        output_features: Number of outputs of the replacement ``fc`` layer.
            Defaults to 2, the value that used to be hard-coded.
    """

    def __init__(self, output_features=2):
        super().__init__()
        # NOTE(review): newer torchvision releases reportedly prefer a `weights=`
        # argument over `pretrained=True` — confirm against the installed version
        # before switching.
        model = models.resnet18(pretrained=True)
        # Keep the backbone, but swap the classification head for one with the
        # requested number of outputs.
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, output_features)
        self.model = model

    def forward(self, x):
        """Return the wrapped network's output for the input batch ``x``."""
        return self.model(x)

In [10]:
# Instantiate the network defined above (pretrained ResNet-18 with a 2-output head).
model = PretrainedModel()

In [11]:
# Wrap the model in DataParallel when more than one GPU is visible.
# The isinstance guard keeps this cell idempotent: re-running it does not
# nest one DataParallel wrapper inside another.
if torch.cuda.device_count() > 1 and not isinstance(model, nn.DataParallel):
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # DataParallel splits each batch along dim 0,
    # e.g. [30, ...] -> [10, ...], [10, ...], [10, ...] on 3 GPUs.
    model = nn.DataParallel(model)

model = model.to(device)

Let's use 3 GPUs!


In [12]:
# skorch wrapper around the already-instantiated model: SGD with momentum,
# cross-entropy loss, and the held-out `val_ds` used as a fixed validation split.
net = NeuralNetClassifier(
    model,
    criterion=nn.CrossEntropyLoss,
    lr=0.001,
    batch_size=BATCH_SIZE,
    max_epochs=25,
    optimizer=optim.SGD,
    optimizer__momentum=0.9,
    iterator_train__shuffle=True,
    iterator_train__num_workers=NUM_WORKERS,
    # The valid iterator is also the one used for prediction; leaving it
    # unshuffled keeps outputs in the same order as the dataset samples.
    iterator_valid__shuffle=False,
    iterator_valid__num_workers=NUM_WORKERS,
    train_split=predefined_split(val_ds),
    device=device,  # chosen automatically above: "cuda:0" if available, else "cpu"
)

In [None]:
# Train on train_ds; y=None because no separate target array is passed
# (presumably the dataset yields the labels itself — confirm).
# The trailing ';' suppresses the cell's output for this expression.
net.fit(train_ds, y=None);

  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        [36m0.7447[0m       [32m0.6993[0m        [35m0.5894[0m  15.5822
      2        [36m0.5176[0m       [32m0.9020[0m        [35m0.3360[0m  4.3054
      3        [36m0.3752[0m       [32m0.9281[0m        [35m0.2454[0m  4.2245
      4        [36m0.3053[0m       [32m0.9346[0m        [35m0.1982[0m  4.1999
      5        [36m0.2473[0m       [32m0.9477[0m        [35m0.1771[0m  4.1333
      6        [36m0.1731[0m       0.9412        [35m0.1668[0m  4.0653
      7        0.2397       0.9412        0.1683  4.1875
      8        0.2158       0.9412        [35m0.1653[0m  4.3121
      9        [36m0.1495[0m       0.9346        [35m0.1635[0m  4.2005
     10        0.2144       0.9281        0.1767  4.1665
