In [1]:
#!/usr/bin/python
# -*- coding: sjis -*-
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pickle
import cv2
import random
from multiboxloss import MultiBoxLoss
from mydataloader2 import *
from augmentations import SSDAugmentation

from tqdm.notebook import tqdm

In [2]:
#http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar

In [3]:
#https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth

In [4]:
## (1) データの準備と設定
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 64 #30
augment = SSDAugmentation()
prepro = PreProcess(augment)
dirpath = './VOCdevkit/VOC2012/JPEGImages/'
epoch_num = 15

In [5]:
## (2) モデルの定義
from mynet import SSD

In [6]:
## (3) モデルの生成，損失関数，最適化関数の設定
##   (3.1) モデルの生成
net = SSD()
vgg_weights = torch.load('./models/vgg16_reducedfc.pth')
net.vgg.load_state_dict(vgg_weights)
net.to(device)

##   (3.2) 損失関数 の設定
optimizer = optim.SGD(net.parameters(),
                      lr=1e-3,momentum=0.9,
                      weight_decay=5e-4)

##   (3.3) 最適化関数の設定
from multiboxloss import MultiBoxLoss
criterion = MultiBoxLoss(device=device)

In [9]:
## (4) 学習
net.train()
for ep in tqdm(range(epoch_num)):
    i = 0
    ans = pickle.load(open('ans.pkl', 'rb'))
    dataset = MyDataset(ans, dirpath, prepro)
    dataloader = DataLoader(dataset,batch_size=batch_size, 
                            shuffle=True, collate_fn=my_collate_fn)
    for xs, ys in tqdm(dataloader):
        xs  = [ torch.FloatTensor(x) for x in xs ]        
        images = torch.stack(xs, dim=0)
        images = images.to(device)        
        targets  = [ torch.FloatTensor(y).to(device) for y in ys ]        
        outputs = net(images)
        loss_l, loss_c = criterion(outputs, targets)
        loss = loss_l + loss_c
        print(i, loss_l.item(), loss_c.item())
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_value_(net.parameters(), clip_value=2.0)
        optimizer.step()
        loss_l, loss_c  = 0, 0
        xs, ys, bc = [], [], 0
        i += 1        
    outfile = "./models/ssd3-" + str(ep) + ".model"
    torch.save(net.state_dict(),outfile)
    print(outfile," saved")

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/191 [00:00<?, ?it/s]

  mode = random.choice(self.sample_options)


0 2.73567271232605 16.4898624420166
1 2.5330569744110107 15.613561630249023
2 2.582627773284912 14.135025978088379
3 2.6395750045776367 12.79157829284668
4 2.6575093269348145 12.414515495300293
5 2.7552263736724854 12.34683609008789
6 2.5352492332458496 12.368987083435059


Corrupt JPEG data: premature end of data segment


7 2.694406747817993 12.303264617919922
8 2.5824015140533447 12.771342277526855
9 2.350144863128662 12.359814643859863
10 2.8539562225341797 12.322539329528809
11 2.575579881668091 12.425379753112793
12 2.5748634338378906 12.451379776000977
13 2.7704670429229736 12.784599304199219
14 2.548245668411255 12.821991920471191
15 2.5051016807556152 12.510875701904297
16 2.7484095096588135 12.46420669555664
17 2.6903836727142334 12.313272476196289
18 2.5802645683288574 12.505946159362793
19 2.717930793762207 12.456514358520508
20 2.729543924331665 12.356840133666992
21 2.569187641143799 12.106094360351562
22 2.5528738498687744 12.065717697143555
23 2.6768057346343994 12.176640510559082
24 2.716383934020996 12.018635749816895
25 2.6001815795898438 12.097128868103027
26 2.642289638519287 12.049032211303711
27 2.4945802688598633 12.085688591003418
28 2.4744837284088135 11.985223770141602
29 2.597470760345459 12.107222557067871
30 2.589268207550049 11.942398071289062
31 2.588256359100342 12.1475172

Corrupt JPEG data: 349 extraneous bytes before marker 0xd9


163 2.456484317779541 5.4503607749938965
164 2.2173259258270264 5.351112365722656
165 2.5004429817199707 5.341372489929199
166 2.26515793800354 5.341893196105957
167 2.271876811981201 5.517574787139893
168 2.5013949871063232 5.507876396179199
169 2.420483350753784 4.733960151672363
170 2.3134357929229736 5.329891204833984
171 2.3609580993652344 5.070218086242676
172 2.3376729488372803 5.0686798095703125
173 2.2402939796447754 5.1805195808410645
174 2.343120813369751 5.168910026550293
175 2.2933943271636963 5.26286506652832
176 2.2767841815948486 5.234142303466797
177 2.384796142578125 5.167876720428467
178 2.4114949703216553 4.98444128036499
179 2.220655918121338 5.080563068389893
180 2.3056633472442627 5.044843673706055
181 2.485010862350464 5.179081439971924
182 2.303396701812744 4.967403888702393
183 2.305634021759033 5.236117839813232
184 2.27116322517395 5.124146461486816
185 2.1161835193634033 4.8558878898620605
186 2.332869529724121 5.159003734588623
187 2.3634817600250244 5.362

  0%|          | 0/191 [00:00<?, ?it/s]

0 2.1671409606933594 4.888928413391113
1 2.0909252166748047 4.728413105010986
2 2.3173019886016846 4.976984977722168
3 2.27329158782959 4.933199405670166
4 2.370997667312622 5.432089805603027
5 2.246527910232544 4.922970771789551
6 2.348895311355591 4.930974006652832
7 2.130417823791504 4.608484745025635
8 2.30410099029541 5.261539459228516
9 2.39908504486084 4.756765365600586
10 2.291679859161377 5.125567436218262
11 2.198399305343628 5.012322902679443
12 2.3355588912963867 5.1907267570495605
13 2.188678026199341 5.06380033493042
14 2.311380624771118 5.237321853637695
15 2.2807369232177734 5.263873100280762
16 2.3073341846466064 5.174554824829102
17 2.080364227294922 5.205620765686035
18 2.1775455474853516 5.179316520690918
19 2.2011823654174805 5.057979583740234
20 2.2380945682525635 4.747550010681152
21 2.0450148582458496 5.158615589141846
22 2.4028148651123047 5.094432353973389
23 2.3486690521240234 5.095494270324707
24 2.1249752044677734 4.844978332519531
25 2.189758777618408 4.95

Corrupt JPEG data: premature end of data segment


162 2.0185153484344482 4.63571834564209
163 2.2146434783935547 4.647458553314209
164 1.775320291519165 4.547956943511963
165 2.0896155834198 4.803890228271484
166 2.0283989906311035 5.013847351074219


Corrupt JPEG data: 349 extraneous bytes before marker 0xd9


167 1.8641284704208374 4.894190311431885
168 2.079322099685669 4.927749156951904
169 1.8231858015060425 5.054948806762695
170 2.0731987953186035 4.576578617095947
171 1.8038488626480103 4.78397274017334
172 1.992616057395935 4.684304237365723
173 2.1158440113067627 4.8977155685424805
174 1.9383094310760498 4.401448726654053
175 1.9521011114120483 4.527445316314697
176 1.9348198175430298 4.590099811553955
177 2.0818326473236084 4.680517673492432
178 2.0369958877563477 5.086763858795166
179 1.7746094465255737 4.616820335388184
180 1.8928859233856201 4.759129524230957
181 2.2168071269989014 4.537293910980225
182 1.8682531118392944 4.701338768005371
183 2.151304244995117 4.781604290008545
184 2.002023696899414 4.4468793869018555
185 1.9508998394012451 4.609340190887451
186 2.124115228652954 4.808782577514648
187 1.9503341913223267 4.241888046264648
188 1.964491844177246 4.8253889083862305
189 2.041619300842285 5.028930187225342
190 2.191983222961426 5.083044052124023
ssd3-1.model  saved


  0%|          | 0/191 [00:00<?, ?it/s]

0 2.0385208129882812 4.444845676422119
1 1.8681102991104126 4.270753860473633
2 2.0646591186523438 4.760782241821289
3 1.8771271705627441 5.1347150802612305
4 1.8392400741577148 4.9399733543396
5 2.0073654651641846 4.4162397384643555
6 1.9308134317398071 4.915396690368652
7 2.1122963428497314 4.705398082733154
8 2.1435883045196533 5.210374355316162
9 2.2626230716705322 4.509541988372803
10 1.9410074949264526 4.765841960906982
11 1.8261332511901855 4.791759014129639
12 1.9383610486984253 4.723002910614014
13 2.1305341720581055 4.951857566833496
14 2.0041632652282715 4.459288120269775
15 1.9945100545883179 4.914962291717529
16 1.9106462001800537 4.553580284118652
17 1.9478347301483154 4.673267364501953
18 1.8536051511764526 4.829941272735596
19 1.9823384284973145 4.621645450592041
20 2.094468832015991 4.31493616104126
21 2.062195301055908 5.021909236907959
22 2.0744853019714355 4.230087757110596
23 2.13584303855896 4.937096118927002
24 2.1355974674224854 4.748419761657715
25 1.9609742164

Corrupt JPEG data: 349 extraneous bytes before marker 0xd9


41 1.7679811716079712 4.550726413726807
42 2.021979331970215 4.9391703605651855
43 1.8728289604187012 4.6363067626953125
44 1.9947506189346313 4.584421634674072
45 1.8723564147949219 4.739198684692383
46 1.8233169317245483 4.900681018829346
47 1.7807750701904297 4.443795204162598
48 1.9792689085006714 4.669641494750977
49 1.8393656015396118 4.658182144165039
50 2.0164365768432617 4.256960868835449
51 1.8609189987182617 4.3706793785095215
52 1.811273455619812 4.616129398345947
53 1.9787042140960693 4.745967864990234
54 2.0241453647613525 4.489486217498779
55 1.9709218740463257 4.834067344665527
56 2.1065237522125244 4.690495491027832
57 2.0653598308563232 4.664525508880615
58 2.047572374343872 4.547375202178955
59 2.1207823753356934 4.511150360107422
60 1.883245825767517 4.40794563293457
61 1.8993022441864014 4.6883015632629395
62 2.027345657348633 4.241824626922607
63 1.9947344064712524 5.136812210083008
64 2.0690033435821533 4.664374351501465
65 1.8896162509918213 4.561890125274658
66

Corrupt JPEG data: premature end of data segment


181 1.7108361721038818 4.6017746925354
182 1.8006911277770996 4.373085021972656
183 1.9477931261062622 4.6281280517578125
184 1.7619270086288452 4.559062957763672
185 1.807905912399292 4.428924560546875
186 1.9339426755905151 4.088348865509033
187 1.6565483808517456 4.476376533508301
188 1.3636759519577026 4.288267612457275
189 1.7140358686447144 4.537351131439209
190 2.0073084831237793 4.5449299812316895
ssd3-2.model  saved
