In [None]:
from torchvision import models
import torch.nn as nn

def model_A(num_classes):
    """Build a ResNet18 whose backbone is frozen and whose head is trainable.

    The network is created with ``pretrained=True`` and its final fully
    connected layer is swapped for a new ``nn.Linear`` that outputs
    ``num_classes`` scores.

    Args:
        num_classes: Number of output units of the replacement ``fc`` layer.

    Returns:
        The modified ResNet18; only the parameters of ``fc`` have
        ``requires_grad`` set to ``True``.
    """
    # pretrained=True asks torchvision for the pretrained ResNet18 parameters.
    backbone = models.resnet18(pretrained=True)

    # Freeze every parameter that came with the network.
    for weight in backbone.parameters():
        weight.requires_grad_(False)

    # Attach a new classification head sized for our task. A freshly built
    # nn.Linear has requires_grad=True, so it is the only part left trainable.
    head_inputs = backbone.fc.in_features
    backbone.fc = nn.Linear(head_inputs, num_classes)
    return backbone

In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os

## Note that: here we provide a basic solution for loading data and transforming data.
## You can directly change it if you find something wrong or not good enough.

## the mean and standard deviation of the ImageNet dataset
## mean_vals = [0.485, 0.456, 0.406]
## std_vals = [0.229, 0.224, 0.225]

def load_data(data_dir="./data/", input_size=224, batch_size=36,
              train_subdir="2-Medium-Scale", valid_subdir="test",
              num_workers=0):
    """Create the training and validation data loaders.

    Both datasets are read with ``datasets.ImageFolder`` from sub-folders of
    ``data_dir``.

    Args:
        data_dir: Root folder that contains the training and validation
            sub-folders.
        input_size: Side length used for the crops fed to the network.
        batch_size: Batch size used by both loaders.
        train_subdir: Name of the training sub-folder inside ``data_dir``.
        valid_subdir: Name of the validation sub-folder inside ``data_dir``.
        num_workers: Number of worker processes handed to each ``DataLoader``.

    Returns:
        A ``(train_loader, valid_loader)`` tuple. The training loader
        shuffles its samples; the validation loader does not.
    """
    # Normalisation constants are the ImageNet mean / standard deviation.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    data_transforms = {
        # Training: random crop + random flip as data augmentation.
        'train': transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        # Validation: deterministic resize followed by a centre crop.
        'test': transforms.Compose([
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize,
        ]),
    }

    image_dataset_train = datasets.ImageFolder(
        os.path.join(data_dir, train_subdir), data_transforms['train'])
    image_dataset_valid = datasets.ImageFolder(
        os.path.join(data_dir, valid_subdir), data_transforms['test'])

    train_loader = DataLoader(image_dataset_train, batch_size=batch_size,
                              shuffle=True, num_workers=num_workers)
    valid_loader = DataLoader(image_dataset_valid, batch_size=batch_size,
                              shuffle=False, num_workers=num_workers)

    return train_loader, valid_loader

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import os

## Note that: here we provide a basic solution for training and validation.
## You can directly change it if you find something wrong or not good enough.

def train_model(model, train_loader, valid_loader, criterion, optimizer,
                num_epochs=20, device=None):
    """Train ``model`` and save it whenever validation accuracy improves.

    Each epoch runs one pass over ``train_loader`` (with parameter updates)
    and one pass over ``valid_loader`` (without gradients), printing the
    loss and accuracy of both. When the validation accuracy exceeds the best
    value seen so far, the whole model is written to ``best_model.pt`` with
    ``torch.save``.

    Args:
        model: The ``nn.Module`` to train; expected to already live on
            ``device``.
        train_loader: Iterable of ``(inputs, labels)`` training batches that
            exposes a ``dataset`` attribute supporting ``len()``.
        valid_loader: Same as ``train_loader`` but for validation.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        optimizer: Optimizer that updates the model parameters.
        num_epochs: Number of epochs to run.
        device: Device the batches are moved to. When ``None`` it is taken
            from the model's first parameter (CPU if the model has none).

    Returns:
        None.
    """
    if device is None:
        # Follow the model instead of relying on a module-level global.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    def run_epoch(loader, training):
        """Run one pass over ``loader``; return (mean loss, accuracy)."""
        model.train(training)
        total_loss = 0.0
        total_correct = 0

        # Gradients are only needed while training; disabling them during
        # validation avoids building an autograd graph that is never used.
        with torch.set_grad_enabled(training):
            for inputs, labels in loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                if training:
                    optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                _, predictions = torch.max(outputs, 1)  # highest-scoring class
                if training:
                    loss.backward()
                    optimizer.step()

                # The loss is weighted by the batch size so that the division
                # below yields a per-sample average.
                total_loss += loss.item() * inputs.size(0)
                total_correct += torch.sum(predictions == labels).item()

        num_samples = len(loader.dataset)
        return total_loss / num_samples, total_correct / num_samples

    best_acc = 0.0
    for epoch in range(num_epochs):
        print('*' * 100)
        print('epoch:{:d}/{:d}'.format(epoch, num_epochs))
        train_loss, train_acc = run_epoch(train_loader, True)
        print("training: loss:   {:.4f}, accuracy: {:.4f}".format(train_loss, train_acc))
        valid_loss, valid_acc = run_epoch(valid_loader, False)
        print("validation: loss: {:.4f}, accuracy: {:.4f}".format(valid_loss, valid_acc))
        # save the best model
        if valid_acc > best_acc:
            best_acc = valid_acc
            torch.save(model, 'best_model.pt')

In [None]:
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"

## about model
num_classes = 10

## about data
data_dir = "data" ## You may need to specify the data_dir first
input_size = 224
inupt_size = input_size ## alias kept for code that still uses the old misspelled name
batch_size = 18

## about training
num_epochs = 20
lr = 0.001

## model initialization
model = model_A(num_classes=num_classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device:', device)
model = model.to(device)

## optimizer
## model_A freezes everything except the final fc layer, so only the
## parameters that still require gradients are handed to SGD.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=lr, momentum=0.9)
## loss function
criterion = nn.CrossEntropyLoss()

## data preparation
train_loader, valid_loader = load_data(data_dir=data_dir, input_size=input_size, batch_size=batch_size)
# train
train_model(model, train_loader, valid_loader, criterion, optimizer, num_epochs=num_epochs)
