In [1]:

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm, trange

In [5]:
%load_ext autoreload
%autoreload 2

from bs_helpers import *
from bs_gameclass import *
from utils import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
net = bs_unet()
#net.load_state_dict(torch.load('data/battleships_unet.dat'))
net.optim = torch.optim.Adam(lr=0.001, betas=(0., 0.999), params=net.parameters())

In [7]:
gamma = 0.8

In [8]:
def batchgen(games, size, verbose=1):
    
    while True:
        indlist = np.random.permutation(range(len(games)))
        minibatches = [ indlist[k*size:(k+1)*size] for k in range(len(indlist)//size) ]
        
        for mb in minibatches:
            xs = np.zeros((size, SX, SY, 3))
            ys = np.zeros((size, SX, SY, 1))
            rs = np.zeros((size))
            for k, l in enumerate(mb):
                s, (i,j), qmax, r = games[l]
                xs[k] = encode_x(s.sea, s.det)
                ys[k, i, j] = 1.
                if np.random.rand()<.5:
                    xs[k] = xs[k,:,::-1]
                    ys[k] = ys[k,:,::-1]
                if np.random.rand()<.5:
                    xs[k] = xs[k,::-1]
                    ys[k] = ys[k,::-1]
                if np.random.rand()<.5:
                    xs[k] = xs[k].transpose(1,0,2)
                    ys[k] = ys[k].transpose(1,0,2)
                rs[k] = r + gamma*qmax
            yield xs, ys, rs
        if verbose:
            print('Finished one batchgen epoch!')

In [None]:

for epoch in range(9999999):

    # Play some games
    games = []
    glengths = []
    slast = 0
    for _ in trange(100):
        s = GameState()
        h = create_sea()
        single_game = []
        while not GameClass.getEnded(s):
            #plot_sea(s.sea, s.det)
            #plt.show()
            net.predict(encode_x(s.sea, s.det))
            Qa = t2np(net.p[0,:,:,0])   # p ist der pre activation output des unets
            i, j = argmax2d(Qa-10000*s.det)
            Qmax = np.max(Qa)
            if np.random.rand() < 0.1:
                vms = GameClass.getValidActions(s)
                l = np.random.choice(len(vms))
                i, j = vms[l]
            r = h[i,j] # reward
            
            if slast:
                single_game.append((slast, alast, Qmax, rlast))
            slast = s
            alast = (i, j)
            rlast = r
            
            s = GameClass.getNextState(s, (i,j), hidden=h)
        games += single_game
        '''
        # Cumulate reward
        cr = 0.
        for s, (i,j), r in reversed(single_game):
            cr = gamma*cr + r
            games.append((s, (i,j), cr))'''
        glengths.append(len(single_game))
    
    print('Game length mean', np.mean(glengths))
    
    
    # train the nnet
    bg = batchgen(games, size=32)
    losses = []
    
    #raise
    for k in trange(1000): 
        xs, ys, qs = next(bg)
        xs, ys, qs = np2t(xs, ys, qs)
        net(xs)
        qp = net.p
        qp = torch.sum(qp * ys, dim=(1,2,3))
        loss = torch.mean((qp-qs)**2)
        #print(loss)
        loss.backward()
        net.optim.step()
        net.optim.zero_grad()
        losses.append(loss.item())
        if k%100 == 0:
            print(np.mean(losses))
            losses = []

  0%|          | 0/100 [00:00<?, ?it/s]

Game length mean 97.61


  0%|          | 0/1000 [00:00<?, ?it/s]

0.822239339351654
0.29377289399504664
0.22943088859319688
0.22336354568600655
Finished one batchgen epoch!
0.20833134055137634
0.20364635340869428
0.2053680918365717
Finished one batchgen epoch!
0.19291525095701217
0.19506750755012037


In [18]:
qs.shape

torch.Size([32])

In [29]:
vc = vidcapture('F:/$Daten/vidcaptures/battleships cnn/frame%05d.png')

s = GameState()
hdn = create_sea()
while not GameClass.getEnded(s):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))
    plot_sea(hdn, s.det, ax=ax1)
    x = encode_x(s.sea, s.det)
    net.predict(x)
    p = net.p
    p = t2np(p)[0,:,:,0]
    i, j = argmax2d(p-s.det*9999)
    s = GameClass.getNextState(s, (i,j), hdn)
    ax2.imshow(p)
    ax2.axis('off')
    ax2.scatter(j, i, c='black', alpha=1, s=150)
    plt.show()

NameError: name 'vidcapture' is not defined

In [None]:
# gegeben: sea, model
det = np.zeros((10,10))
sea = create_sea(100)
##det = create_detection()
i,j = [],[]
from vidcapture import *
vc = vidcapture('F:/$Daten/vidcaptures/battleships qlearning/frame%05d.png')

while True:
    prob = net.predict(encode_x(sea, det))
    prob[det > 0] = 0
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))
    ax1.title.set_text('Map')
    ax2.title.set_text('Neural Network prediction')
    ax1.axis('off')
    #ax1.imshow(visualize(sea, det), vmin=0., vmax=3.)
    plot_sea(sea, det, ax1)
    ax1.scatter(j, i, c='black', alpha=1, s=150)
    ax2.axis('off')
    ax2.imshow(prob)
    #plt.title(text)
    vc.capture()
    plt.show()
    
    if np.sum(det*sea) >= np.sum(sea): break
        
    # Welchen Detektieren??
    x = encode_x(s.sea, s.det)
    net.predict(x)
    p = net.p
    p = t2np(p)[0,:,:,0]
    i, j = argmax2d(p-s.det*9999)
    