In [1]:

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm, trange

In [2]:
%load_ext autoreload
%autoreload 2

from bs_helpers import *
from bs_gameclass import *
from utils import *

In [3]:
net = bs_unet()
#net.load_state_dict(torch.load('data/battleships_unet.dat'))
net.optim = torch.optim.Adam(lr=0.001, betas=(0., 0.999), params=net.parameters())

In [4]:
gamma = 0.8

In [5]:
def batchgen(games, size, verbose=1):
    
    while True:
        indlist = np.random.permutation(range(len(games)))
        minibatches = [ indlist[k*size:(k+1)*size] for k in range(len(indlist)//size) ]
        
        for mb in minibatches:
            xs = np.zeros((size, SX, SY, 3))
            ys = np.zeros((size, SX, SY, 1))
            rs = np.zeros((size))
            for k, l in enumerate(mb):
                s, (i,j), qmax, r = games[l]
                xs[k] = encode_x(s.sea, s.det)
                ys[k, i, j] = 1.
                if np.random.rand()<.5:
                    xs[k] = xs[k,:,::-1]
                    ys[k] = ys[k,:,::-1]
                if np.random.rand()<.5:
                    xs[k] = xs[k,::-1]
                    ys[k] = ys[k,::-1]
                if np.random.rand()<.5:
                    xs[k] = xs[k].transpose(1,0,2)
                    ys[k] = ys[k].transpose(1,0,2)
                rs[k] = r + gamma*qmax
            yield xs, ys, rs
        if verbose:
            print('Finished one batchgen epoch!')

In [6]:

for epoch in range(9999999):

    # Play some games
    games = []
    glengths = []
    slast = 0
    for _ in trange(100):
        s = GameState()
        h = create_sea()
        single_game = []
        while not GameClass.getEnded(s):
            #plot_sea(s.sea, s.det)
            #plt.show()
            net.predict(encode_x(s.sea, s.det))
            Qa = t2np(net.p[0,:,:,0])   # p ist der pre activation output des unets
            i, j = argmax2d(Qa-10000*s.det)
            Qmax = np.max(Qa)
            if np.random.rand() < 0.1:
                vms = GameClass.getValidActions(s)
                l = np.random.choice(len(vms))
                i, j = vms[l]
            r = h[i,j] # reward
            
            if slast:
                single_game.append((slast, alast, Qmax, rlast))
            slast = s
            alast = (i, j)
            rlast = r
            
            s = GameClass.getNextState(s, (i,j), hidden=h)
        games += single_game
        '''
        # Cumulate reward
        cr = 0.
        for s, (i,j), r in reversed(single_game):
            cr = gamma*cr + r
            games.append((s, (i,j), cr))'''
        glengths.append(len(single_game))
    
    print('Game length mean', np.mean(glengths))
    
    
    # train the nnet
    bg = batchgen(games, size=32)
    losses = []
    
    #raise
    for k in trange(1000): 
        xs, ys, qs = next(bg)
        xs, ys, qs = np2t(xs, ys, qs)
        net(xs)
        qp = net.p
        qp = torch.sum(qp * ys, dim=(1,2,3))
        loss = torch.mean((qp-qs)**2)
        #print(loss)
        loss.backward()
        net.optim.step()
        net.optim.zero_grad()
        losses.append(loss.item())
        if k%100 == 0:
            print(np.mean(losses))
            losses = []
    torch.save(net.state_dict(), 'model_qlearning.dat')

  0%|          | 0/100 [00:00<?, ?it/s]

Game length mean 96.91


  0%|          | 0/1000 [00:00<?, ?it/s]

0.8327263593673706
0.32285124614834787
0.23711337864398957
0.22522459544241427
Finished one batchgen epoch!
0.22316719606518745
0.20052776709198952
0.20846928045153618
Finished one batchgen epoch!
0.20436072513461112
0.19410697244107722
0.1992460737377405
Finished one batchgen epoch!


  0%|          | 0/100 [00:00<?, ?it/s]

Game length mean 81.35


  0%|          | 0/1000 [00:00<?, ?it/s]

0.4933244585990906
0.23783406794071196
0.2366768305003643
Finished one batchgen epoch!
0.2269740504026413
0.22391414970159532
0.2188565208017826
Finished one batchgen epoch!
0.21750738680362702
0.21721743315458297
Finished one batchgen epoch!
0.211019374281168
0.21626223631203176


  0%|          | 0/100 [00:00<?, ?it/s]

Game length mean 81.73


  0%|          | 0/1000 [00:00<?, ?it/s]

0.4753163456916809
0.2413460284471512
0.23632537126541137
Finished one batchgen epoch!
0.227835019081831
0.22807803966104984
0.22044113650918007
Finished one batchgen epoch!
0.22063058719038964
0.21806661151349543
Finished one batchgen epoch!
0.21803186371922492
0.21315196759998797


  0%|          | 0/100 [00:00<?, ?it/s]

Game length mean 78.1


  0%|          | 0/1000 [00:00<?, ?it/s]

0.3155933618545532
0.2402176721394062
0.25274561777710913
Finished one batchgen epoch!
0.23757493287324905
0.23387689664959907
Finished one batchgen epoch!
0.23534662649035454
0.23610495001077653
0.23269934631884098
Finished one batchgen epoch!
0.2274780696630478
0.22379737183451653
Finished one batchgen epoch!


  0%|          | 0/100 [00:00<?, ?it/s]

Game length mean 80.58


  0%|          | 0/1000 [00:00<?, ?it/s]

0.40117335319519043
0.22983393132686614
0.23702092036604883
Finished one batchgen epoch!
0.22911240488290788
0.2285153490304947
0.23284673020243646
Finished one batchgen epoch!
0.2218441218137741
0.22680989384651185
Finished one batchgen epoch!
0.22015192106366158
0.22647958144545555


  0%|          | 0/100 [00:00<?, ?it/s]

Game length mean 79.58


  0%|          | 0/1000 [00:00<?, ?it/s]

0.35337698459625244
0.24310083463788032
0.23050324082374574
Finished one batchgen epoch!
0.2298709098994732
0.23547763228416443
Finished one batchgen epoch!
0.23214229822158813
0.2270799335092306
0.22707001723349093
Finished one batchgen epoch!
0.23220621570944786
0.22492822512984276
Finished one batchgen epoch!


  0%|          | 0/100 [00:00<?, ?it/s]

Game length mean 79.97


  0%|          | 0/1000 [00:00<?, ?it/s]

0.20857565104961395
0.22243590883910655
0.22007428288459777
Finished one batchgen epoch!
0.22337678104639053
0.22363715082406999
Finished one batchgen epoch!
0.2208158691972494
0.21852516114711762
0.22065398022532462
Finished one batchgen epoch!
0.21723286524415017
0.21266584053635598
Finished one batchgen epoch!


  0%|          | 0/100 [00:00<?, ?it/s]

Game length mean 77.7


  0%|          | 0/1000 [00:00<?, ?it/s]

0.2517523169517517
0.2393266302347183
0.2316388449072838
Finished one batchgen epoch!
0.23648208990693093
0.23374976605176925
Finished one batchgen epoch!
0.22895076960325242
0.23297278203070163
0.2288188634812832
Finished one batchgen epoch!
0.24162893995642662
0.23291070714592935
Finished one batchgen epoch!


  0%|          | 0/100 [00:00<?, ?it/s]

Game length mean 80.75


  0%|          | 0/1000 [00:00<?, ?it/s]

0.23264707624912262
0.22057465314865113
0.22174033239483834
Finished one batchgen epoch!
0.217053599357605
0.21849729225039483
0.22202093333005904
Finished one batchgen epoch!
0.21207856252789498
0.21514873057603837
Finished one batchgen epoch!
0.21099414080381393
0.21707699194550514


  0%|          | 0/100 [00:00<?, ?it/s]

Game length mean 78.14


  0%|          | 0/1000 [00:00<?, ?it/s]

0.1759956181049347
0.2310436899960041
0.2328270949423313
Finished one batchgen epoch!
0.23309667631983758
0.22750789821147918
Finished one batchgen epoch!
0.22957227870821953
0.23084503322839736
0.22829759776592254
Finished one batchgen epoch!
0.22526306554675102
0.22052114069461823
Finished one batchgen epoch!


  0%|          | 0/100 [00:00<?, ?it/s]

Game length mean 83.32


  0%|          | 0/1000 [00:00<?, ?it/s]

0.22296249866485596
0.21623095721006394
0.21947216749191284
Finished one batchgen epoch!
0.22173598796129226
0.21337980329990386
0.21387710817158223
Finished one batchgen epoch!
0.21393259450793267
0.21996252477169037
Finished one batchgen epoch!
0.20625209458172322
0.21539974123239516


  0%|          | 0/100 [00:00<?, ?it/s]

Game length mean 78.04


  0%|          | 0/1000 [00:00<?, ?it/s]

0.15598562359809875
0.23534399434924125
0.23696477070450783
Finished one batchgen epoch!
0.23469816535711288
0.23813192456960677
Finished one batchgen epoch!
0.23346078708767892
0.23914777472615242
0.23076882392168044
Finished one batchgen epoch!
0.23555389210581779
0.23305804878473282
Finished one batchgen epoch!


  0%|          | 0/100 [00:00<?, ?it/s]

Game length mean 77.92


  0%|          | 0/1000 [00:00<?, ?it/s]

0.23605217039585114
0.2314954023063183
0.2325952585041523
Finished one batchgen epoch!
0.23539865002036095
0.22680689454078673
Finished one batchgen epoch!
0.2342188148200512
0.2275737379491329
0.2324027168750763
Finished one batchgen epoch!
0.22943222254514695
0.22722269743680953
Finished one batchgen epoch!


  0%|          | 0/100 [00:00<?, ?it/s]

Game length mean 79.86


  0%|          | 0/1000 [00:00<?, ?it/s]

0.179605633020401
0.22661725491285323
0.2228179456293583
Finished one batchgen epoch!
0.22401572547852994
0.2196181008219719
Finished one batchgen epoch!
0.23187944784760475
0.2249002207070589
0.2200329813361168
Finished one batchgen epoch!
0.2228360327333212
0.22047556221485137
Finished one batchgen epoch!


  0%|          | 0/100 [00:00<?, ?it/s]

Game length mean 77.5


  0%|          | 0/1000 [00:00<?, ?it/s]

0.33129265904426575
0.23247587889432908
0.23856990531086922
Finished one batchgen epoch!
0.23081572026014327
0.2320405387878418
Finished one batchgen epoch!
0.23326373770833014
0.2324203509092331
0.23482868060469628
Finished one batchgen epoch!
0.230390961766243
0.2323654495179653
Finished one batchgen epoch!


  0%|          | 0/100 [00:00<?, ?it/s]

Game length mean 78.94


  0%|          | 0/1000 [00:00<?, ?it/s]

0.14585994184017181
0.22747405230998993
0.2331470837444067
Finished one batchgen epoch!
0.23184709906578063
0.22155949637293815
Finished one batchgen epoch!
0.2586349196732044
0.22472122803330422
0.22866766080260276
Finished one batchgen epoch!
0.23208853706717492
0.22549274504184724
Finished one batchgen epoch!


  0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [7]:
torch.save(net.state_dict(), 'model_qlearning.dat')

In [None]:
# gegeben: sea, model
det = np.zeros((10,10))
sea = create_sea(100)
##det = create_detection()
i,j = [],[]
from vidcapture import *
vc = vidcapture('F:/$Daten/vidcaptures/battleships qlearning/frame%05d.png')

while True:
    # Welchen Detektieren??
    x = encode_x(sea, det)
    net.predict(x)
    p = net.p
    p = t2np(p)[0,:,:,0]
    i, j = argmax2d(p-det*999999)
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))
    ax1.title.set_text('Map')
    ax2.title.set_text('Neural Network prediction')
    ax1.axis('off')
    #ax1.imshow(visualize(sea, det), vmin=0., vmax=3.)
    plot_sea(sea, det, ax1)
    ax1.scatter(j, i, c='black', alpha=1, s=150)
    ax2.axis('off')
    ax2.imshow(p)
    #plt.title(text)
    vc.capture()
    plt.show()
    
    if np.sum(det*sea) >= np.sum(sea): break
        
    det[i,j] = 1