In [3]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os
import numpy as np
import ssl
use_cuda = torch.cuda.is_available()
ssl._create_default_https_context = ssl._create_unverified_context

maskNet = torchvision.models.mobilenet_v2(pretrained=True)
maskNet.classifier[1] = torch.nn.Linear(1280, 2)

if use_cuda:
  maskNet = maskNet.cuda()

for param in maskNet.features.parameters():
    param.requires_grad = False


data_dir = './dataset'

transform = transforms.Compose([transforms.Resize(225), transforms.RandomRotation(10), transforms.RandomHorizontalFlip(), transforms.RandomResizedCrop(224), transforms.ToTensor(),
                                transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])])

data = datasets.ImageFolder(data_dir, transform=transform)

# print(data.targets) 0 - without_mask 1 - with_mask

data_loader = torch.utils.data.DataLoader(data, batch_size=30, num_workers=0, shuffle=True)

if use_cuda:
    criterion = nn.CrossEntropyLoss().cuda()
else:
    criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(maskNet.parameters(), lr=.001, momentum=0.9)

def train(epochs, model, optimize, criter, use_cuda, save_path):

    min_loss = np.Inf

    for ii in range(1, epochs+1):
        current_loss = 0

        for batch_idx, (data, target) in enumerate(data_loader):
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            optimize.zero_grad()
            output = model(data)
            loss = criter(output, target)
            loss.backward()
            optimize.step()

            current_loss = current_loss + ((1 / (batch_idx + 1)) * (loss.data - current_loss))
            print("\nEpoch: " + str(ii) + " Loss: " + str(float(loss)))



            if min_loss > current_loss:
                print("Loss went from " + str(float(min_loss)) + " -> " + str(float(current_loss)) + " Saving ...")
                min_loss = current_loss
                torch.save(model.state_dict(), save_path)

train(3, maskNet, optimizer, criterion, use_cuda, "./mask_net.pt")


Epoch: 1 Loss: 0.7786961197853088
Loss went from inf -> 0.7786961197853088 Saving ...

Epoch: 1 Loss: 0.657819926738739
Loss went from 0.7786961197853088 -> 0.7182580232620239 Saving ...

Epoch: 1 Loss: 0.758148193359375

Epoch: 1 Loss: 0.68925940990448

Epoch: 1 Loss: 0.7492282390594482

Epoch: 1 Loss: 0.65843266248703
Loss went from 0.7182580232620239 -> 0.715264081954956 Saving ...

Epoch: 1 Loss: 0.6910805106163025
Loss went from 0.715264081954956 -> 0.7118092775344849 Saving ...

Epoch: 1 Loss: 0.662446916103363
Loss went from 0.7118092775344849 -> 0.7056390047073364 Saving ...

Epoch: 1 Loss: 0.6223303079605103
Loss went from 0.7056390047073364 -> 0.696382462978363 Saving ...

Epoch: 1 Loss: 0.6707441806793213
Loss went from 0.696382462978363 -> 0.6938186287879944 Saving ...

Epoch: 1 Loss: 0.6340952515602112
Loss went from 0.6938186287879944 -> 0.688389241695404 Saving ...

Epoch: 1 Loss: 0.586267352104187
Loss went from 0.688389241695404 -> 0.6798790693283081 Saving ...

Epoch


Epoch: 1 Loss: 0.4427855610847473
Loss went from 0.4449547231197357 -> 0.4449297785758972 Saving ...

Epoch: 1 Loss: 0.39436817169189453
Loss went from 0.4449297785758972 -> 0.44435521960258484 Saving ...

Epoch: 1 Loss: 0.4329272210597992
Loss went from 0.44435521960258484 -> 0.44422680139541626 Saving ...

Epoch: 1 Loss: 0.4855231046676636

Epoch: 1 Loss: 0.3507159650325775
Loss went from 0.44422680139541626 -> 0.44365301728248596 Saving ...

Epoch: 1 Loss: 0.30045968294143677
Loss went from 0.44365301728248596 -> 0.4420965611934662 Saving ...

Epoch: 1 Loss: 0.4350825250148773
Loss went from 0.4420965611934662 -> 0.44202113151550293 Saving ...

Epoch: 1 Loss: 0.32545074820518494
Loss went from 0.44202113151550293 -> 0.44078102707862854 Saving ...

Epoch: 1 Loss: 0.21485744416713715
Loss went from 0.44078102707862854 -> 0.4384028911590576 Saving ...

Epoch: 1 Loss: 0.2975534498691559
Loss went from 0.4384028911590576 -> 0.4369357228279114 Saving ...

Epoch: 1 Loss: 0.281426578760147

KeyboardInterrupt: 