In [1]:
from par_model import *
from os.path import join
from par_model import ImageDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [3]:
label_path = join('./','data','par_datasets','training_set.txt')
img_path = join('./','data','par_datasets','training_set')

In [4]:
extractor = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
extractor.eval()
extractor = nn.Sequential(*list(extractor.children())[:-1])
for param in extractor.parameters():
    param.requires_grad = False

Using cache found in /home/adam/.cache/torch/hub/pytorch_vision_v0.10.0


In [5]:
hat_model = BinaryMobilnetClassifier(extractor)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, hat_model.parameters()),lr = 0.01)
criterion = nn.BCELoss()


In [6]:
hat_model = BinaryMobilnetClassifier(extractor)

In [7]:
hat_data_set = ImageDataset(label_path,img_path,'hat', transform=transform)

In [8]:
train_loader = torch.utils.data.DataLoader(hat_data_set,batch_size=64)

In [9]:
hat_model.to('cuda')

BinaryMobilnetClassifier(
  (extractor): Sequential(
    (0): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      )
      (1): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (2): InvertedResidual(
        (conv): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(16, 96, kernel_

In [10]:
for epoch in range(22):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(train_loader):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = hat_model(inputs)
        labels = labels.unsqueeze(1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # print statistics
        running_loss += loss.item()
        if i % 20 == 0:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss/20 :.3f}')
            running_loss = 0.0

print('Finished Training')

[1,     1] loss: 0.027
[1,    21] loss: 0.503
[1,    41] loss: 0.632
[1,    61] loss: 0.612
[1,    81] loss: 0.625
[1,   101] loss: 0.586
[1,   121] loss: 0.516
[1,   141] loss: 0.566
[1,   161] loss: 0.690
[1,   181] loss: 0.633
[1,   201] loss: 0.516
[1,   221] loss: 0.786
[1,   241] loss: 0.554
[1,   261] loss: 0.576
[1,   281] loss: 0.552
[1,   301] loss: 0.562
[1,   321] loss: 0.489
[1,   341] loss: 0.661
[1,   361] loss: 0.727
[1,   381] loss: 0.508
[1,   401] loss: 0.504
[1,   421] loss: 0.616
[1,   441] loss: 0.600
[1,   461] loss: 0.589
[1,   481] loss: 0.606
[1,   501] loss: 0.594
[1,   521] loss: 0.600
[1,   541] loss: 0.597
[1,   561] loss: 0.586
[1,   581] loss: 0.589
[1,   601] loss: 0.607
[1,   621] loss: 0.645
[1,   641] loss: 0.642
[1,   661] loss: 0.602
[1,   681] loss: 0.547
[1,   701] loss: 0.587
[1,   721] loss: 0.703
[1,   741] loss: 0.543
[1,   761] loss: 0.506
[1,   781] loss: 0.613
[1,   801] loss: 0.562
[1,   821] loss: 0.545
[1,   841] loss: 0.617
[1,   861] 

KeyboardInterrupt: 

In [None]:
torch.save(hat_model.state_dict(),'./weights/hat.pth')