# MT_DL v1: Binäre Klassifikation

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision import datasets
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models

#dtype = torch.device("cuda")
dtype = torch.cuda.FloatTensor

In [2]:
class Binary(nn.Sequential):
    def __init__(self, input_dim=512, output_dim=1):
        super(Binary, self).__init__()
        self.l1 = nn.Linear(input_dim, input_dim)
        self.l2 = nn.Linear(input_dim, output_dim)
        
    def forward(self, x):
        x = F.relu(self.l1(x))
        return torch.sigmoid(self.l2(x))        

In [3]:
model = Binary()
model.cuda()

Binary(
  (l1): Linear(in_features=512, out_features=512, bias=True)
  (l2): Linear(in_features=512, out_features=1, bias=True)
)

In [4]:
scaler = transforms.Resize((224, 224))
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
to_tensor = transforms.ToTensor()

transform = transforms.Compose([
    scaler, 
    to_tensor,
    normalize
])

## Datenset

In [1]:
train_ds = datasets.ImageFolder('../../../data/bush_ds/train/', transform=transform)
train_ds.idx_to_class = {i:c for c, i in train_ds.class_to_idx.items()}
train_dl = DataLoader(train_ds, batch_size=10, num_workers=4)

valid_ds = datasets.ImageFolder('../../../data/bush_ds/valid/', transform=transform)
valid_ds.idx_to_class = {i:c for c, i in valid_ds.class_to_idx.items()}
valid_dl = DataLoader(valid_ds, batch_size=10, num_workers=4)

NameError: name 'datasets' is not defined

## Face recognition

In [6]:
def get_vectors(imgs):
    # 2. Create a PyTorch Variable with the transformed image
    t_img = Variable(imgs).type(dtype)  
    
    # 3. Create a vector of zeros that will hold our feature vector
    #    The 'avgpool' layer has an output size of 512
    my_embedding = torch.zeros(imgs.shape[0], 512)    

    # 4. Define a function that will copy the output of a layer
    def copy_data(m, i, o):
      my_embedding.copy_(o.data.squeeze())    
    
    # 5. Attach that function to our selected layer
    h = layer.register_forward_hook(copy_data)    
    # 6. Run the model on our transformed image
    resnet(t_img)    
    # 7. Detach our copy function from the layer
    h.remove()    
    # 8. Return the feature vector
    return my_embedding

In [7]:
# Load the pretrained model
resnet = models.resnet18(pretrained=True)
resnet.cuda()
# Use the model object to select the desired layer
layer = resnet._modules.get('avgpool')
_ = resnet.eval()

## Training

In [8]:
loss_func = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=1e-4) #weight_decay=1e-4 

In [10]:
epochs = 400
for e in range(epochs):
    epoch_loss = 0
    positiv_imgs = 0
    print('Epoche: ' + str(e))
    
    valid_loss = 0
    for vdata, vtarget in valid_dl:
        vdata_t = get_vectors(vdata)
        
        data_v   = Variable(data_t, requires_grad=False).type(dtype)
        target_v = Variable(target, requires_grad=False).type(dtype)

        # forward
        pred = model.forward(data_v)
        # loss
        loss = loss_func(pred, target_v.float())
        valid_loss = valid_loss + loss
        

    for data, target in train_dl:
        data_t = get_vectors(data)
        
        data_v   = Variable(data_t, requires_grad=False).type(dtype)
        target_v = Variable(target, requires_grad=False).type(dtype)
        positiv_imgs = positiv_imgs + torch.sum(target_v)
    
        # forward
        pred = model.forward(data_v)
        # zero grads
        optimizer.zero_grad()
        # calculate loss
        loss = loss_func(pred, target_v.float())
        epoch_loss = epoch_loss + loss
        # back prop
        loss.backward()
        optimizer.step()
    
  #  print('pos imgs: ' + str(positiv_imgs))
    print('loss : ' + str(epoch_loss.item()))
    print('valid loss : ' + str(valid_loss.item()))
    print('---------------------')
        
    

Epoche: 0
loss : 4.145920276641846
valid loss : 0.31541919708251953
---------------------
Epoche: 1
loss : 6.995523452758789
valid loss : 0.06145597621798515
---------------------
Epoche: 2
loss : 6.878757953643799
valid loss : 0.05796614661812782
---------------------
Epoche: 3
loss : 6.7510457038879395
valid loss : 0.055426329374313354
---------------------
Epoche: 4
loss : 6.627974987030029
valid loss : 0.053080081939697266
---------------------
Epoche: 5
loss : 6.5130534172058105
valid loss : 0.050897981971502304
---------------------
Epoche: 6
loss : 6.40562105178833
valid loss : 0.04892781376838684
---------------------
Epoche: 7
loss : 6.303030490875244
valid loss : 0.047087498009204865
---------------------
Epoche: 8
loss : 6.203808784484863
valid loss : 0.045242175459861755
---------------------
Epoche: 9
loss : 6.107645511627197
valid loss : 0.04358586668968201
---------------------
Epoche: 10
loss : 6.020188808441162
valid loss : 0.04216645285487175
---------------------
Epo

loss : 2.202859401702881
valid loss : 0.012830191291868687
---------------------
Epoche: 91
loss : 2.1783392429351807
valid loss : 0.012652414850890636
---------------------
Epoche: 92
loss : 2.1563994884490967
valid loss : 0.012423591688275337
---------------------
Epoche: 93
loss : 2.13535475730896
valid loss : 0.012325582094490528
---------------------
Epoche: 94
loss : 2.112225294113159
valid loss : 0.012143108993768692
---------------------
Epoche: 95
loss : 2.091402530670166
valid loss : 0.012028766795992851
---------------------
Epoche: 96
loss : 2.0714032649993896
valid loss : 0.011873728595674038
---------------------
Epoche: 97
loss : 2.049075126647949
valid loss : 0.011694220826029778
---------------------
Epoche: 98
loss : 2.0300893783569336
valid loss : 0.011577733792364597
---------------------
Epoche: 99
loss : 2.0086333751678467
valid loss : 0.011383001692593098
---------------------
Epoche: 100
loss : 1.9876888990402222
valid loss : 0.011259564198553562
---------------

loss : 1.0973939895629883
valid loss : 0.0024937789421528578
---------------------
Epoche: 179
loss : 1.0904426574707031
valid loss : 0.0024485220201313496
---------------------
Epoche: 180
loss : 1.0828386545181274
valid loss : 0.0023787806276232004
---------------------
Epoche: 181
loss : 1.0770821571350098
valid loss : 0.0023288659285753965
---------------------
Epoche: 182
loss : 1.0711992979049683
valid loss : 0.002283755224198103
---------------------
Epoche: 183
loss : 1.0629452466964722
valid loss : 0.0022388226352632046
---------------------
Epoche: 184
loss : 1.0582987070083618
valid loss : 0.0021994851995259523
---------------------
Epoche: 185
loss : 1.0518858432769775
valid loss : 0.002160093979910016
---------------------
Epoche: 186
loss : 1.0455657243728638
valid loss : 0.0021109359804540873
---------------------
Epoche: 187
loss : 1.0395232439041138
valid loss : 0.002052335301414132
---------------------
Epoche: 188
loss : 1.034249186515808
valid loss : 0.0020162512082

loss : 0.6710500121116638
valid loss : 0.0006405918393284082
---------------------
Epoche: 266
loss : 0.6684697866439819
valid loss : 0.000640479673165828
---------------------
Epoche: 267
loss : 0.6657224297523499
valid loss : 0.0006315165082924068
---------------------
Epoche: 268
loss : 0.6632922291755676
valid loss : 0.0006243617390282452
---------------------
Epoche: 269
loss : 0.6604516506195068
valid loss : 0.0006208384875208139
---------------------
Epoche: 270
loss : 0.6581448912620544
valid loss : 0.0006134171271696687
---------------------
Epoche: 271
loss : 0.6560685038566589
valid loss : 0.0006112419650889933
---------------------
Epoche: 272
loss : 0.6535031199455261
valid loss : 0.0006043371395207942
---------------------
Epoche: 273
loss : 0.6508875489234924
valid loss : 0.000599052757024765
---------------------
Epoche: 274
loss : 0.6488891839981079
valid loss : 0.0005964153679087758
---------------------
Epoche: 275
loss : 0.6461179852485657
valid loss : 0.00059458252

loss : 0.5280683636665344
valid loss : 0.0004136991628911346
---------------------
Epoche: 353
loss : 0.5266984105110168
valid loss : 0.00041270896326750517
---------------------
Epoche: 354
loss : 0.5256951451301575
valid loss : 0.00041296452400274575
---------------------
Epoche: 355
loss : 0.5241568088531494
valid loss : 0.00041073860484175384
---------------------
Epoche: 356
loss : 0.5234002470970154
valid loss : 0.00040808902122080326
---------------------
Epoche: 357
loss : 0.5220677256584167
valid loss : 0.0004080980143044144
---------------------
Epoche: 358
loss : 0.5214328169822693
valid loss : 0.000405552244046703
---------------------
Epoche: 359
loss : 0.5198384523391724
valid loss : 0.00040316791273653507
---------------------
Epoche: 360
loss : 0.5186781287193298
valid loss : 0.00040435264236293733
---------------------
Epoche: 361
loss : 0.5179301500320435
valid loss : 0.0004024106601718813
---------------------
Epoche: 362
loss : 0.5167883634567261
valid loss : 0.0004

In [11]:
torch.save(model.state_dict(), 'bush_v1_400.pt')