# import modules

In [None]:
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# config

In [None]:
# Dataset roots, one sub-folder per class (read by ImageFolder below).
TRAIN_DIR = "data/train"
TEST_DIR = "data/test"

# Mini-batch size shared by the training and test data loaders.
BATCH_SIZE = 16

# Forwarded to the model constructor's `pretrained` argument.
PRETRAINED = False

# Number of full passes over the training set.
EPOCHS = 100

# File the trained state dict is written to.
SAVE_PATH = "nsfw.pth"

# Use the first CUDA device when one is available, otherwise the CPU.
DEVICE = (
    torch.device("cuda:0")
    if torch.cuda.is_available()
    else torch.device("cpu")
)

# load dataset

In [None]:
# Channel-wise normalisation shared by the training and test pipelines.
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Training pipeline: random crop + horizontal flip as augmentation, always
# producing 224x224 tensors.
train_set = torchvision.datasets.ImageFolder(TRAIN_DIR, transform=transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
]))
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

# Test pipeline: deterministic preprocessing.
# Resize(int) only fixes the shorter edge and keeps the aspect ratio, so on
# its own it yields tensors of different sizes for non-square images and the
# default batch collation fails. CenterCrop guarantees a fixed 224x224 input.
test_set = torchvision.datasets.ImageFolder(TEST_DIR, transform=transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
]))
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,  # evaluation order does not matter; keep it deterministic
    num_workers=2,
)

# create model

In [None]:
# Build torchvision's GoogLeNet and replace its final fully connected layer
# with a fresh 148-way classifier, then move the network to DEVICE.
model = torchvision.models.googlenet(pretrained=PRETRAINED)
head_in_features = model.fc.in_features
model.fc = nn.Linear(in_features=head_in_features, out_features=148)
# Uncomment to start from weights previously written to SAVE_PATH:
# model.load_state_dict(torch.load(SAVE_PATH))
model = model.to(DEVICE)

# set loss function & optimizer

In [None]:
# Cross-entropy loss over class logits, optimised with Adam on every
# parameter of the model.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# train

In [None]:
# Train for EPOCHS passes over train_loader and print the mean batch loss
# after each epoch.

# The evaluation cell switches the model to eval mode; make sure training
# behaviour is active even when this cell is re-run afterwards.
model.train()

for epoch in range(EPOCHS):
    desc = "Epoch: %d / %d" % (epoch + 1, EPOCHS)
    loss_sum = 0.0

    for data in tqdm(train_loader, desc=desc):
        inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)

        optimizer.zero_grad()

        outputs = model(inputs)
        # In training mode with auxiliary classifiers enabled, torchvision's
        # GoogLeNet returns a namedtuple whose first field is the main
        # logits; otherwise it returns the logits tensor directly. Handle
        # both instead of assuming a fixed tuple length.
        if isinstance(outputs, tuple):
            outputs = outputs[0]

        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()

    print("loss: %.3f" % (loss_sum / len(train_loader)))

# test

In [None]:
# Per-class accuracy on the test set: correct[c] counts the samples of class
# c that were predicted as c, total[c] counts all samples of class c.
correct = [0] * 148
total = [0] * 148

model.eval()
with torch.no_grad():
    for data in tqdm(test_loader):
        inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)

        # Index of the highest logit is the predicted class.
        predictions = model(inputs).argmax(dim=1)

        # Convert once per batch instead of calling .item() per sample.
        hits = (predictions == labels).tolist()
        for label, hit in zip(labels.tolist(), hits):
            correct[label] += hit
            total[label] += 1

# NOTE(review): class names are taken from train_set while labels come from
# test_set; this assumes both folders contain the same class sub-folders.
for i in range(148):
    if total[i] == 0:
        # Class absent from the test set: avoid dividing by zero.
        print('%46s: n/a (no test samples)' % train_set.classes[i])
        continue
    print('%46s: %.2f%%' % (train_set.classes[i], 100 * correct[i] / total[i]))

# save model

In [None]:
# Persist only the model's state dict (not the whole module) to SAVE_PATH.
torch.save(obj=model.state_dict(), f=SAVE_PATH)