In [88]:
import numpy as np
from torch.utils.data import DataLoader
from torch import nn
import random
import math
import torch
from torch import nn
from torch import optim
from tensorboardX import SummaryWriter
from datetime import datetime
import time
import pygame
import load
from models import DynamicTopologyModel
import load_topology_dataset
from load_topology_dataset import MAX_CELL_COUNT, MIN_CELL_COUNT, IMAGE_SIZE
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device="cpu"
EXPERIMENT=2
writer = SummaryWriter('runs/experiment_{}/'.format(EXPERIMENT))

In [89]:
def save_parameters(writer, model, batch_idx):
    for k, v in model.state_dict().items():
        shape = v.shape
        # Don't do this for single weights or biases
        if np.any(np.array(shape) > 1):
            mean = torch.mean(v)
            std_dev = torch.std(v)
            maximum = torch.max(v)
            minimum = torch.min(v)
            writer.add_scalars("{}_{} ".format(k, shape), {"mean": mean,
                                                                    "std_dev": std_dev,
                                                                    "max": maximum,
                                                                    "min": minimum}, batch_idx)
            writer.add_histogram("{}_{}".format(k, shape), v, batch_idx)
            
        else:
            writer.add_scalar("{}_{}".format(k, shape), v.data, batch_idx)

In [90]:
def test_parameters_update(last_parameters_np ,parameters_np):
    for last_param, param in zip(last_parameters_np, parameters_np):
        diff = param - last_param
        #print(last_parameters.max())
        print(diff.max())
        #print(not np.all(diff))


In [91]:
def print_example(reference_cell_input, reference_cell_present_input,
                neighbourhood_cell_rel_input, neighbourhood_cell_load_input, neighbourhood_cell_present_input,
                  reference_cell_target, reference_cell_present_target, neighbourhood_cell_rel_target,
                  neighbourhood_cell_present_target):
    print("reference_cell_input") 
    print(reference_cell_input)
    print("reference_cell_present_input") 
    print(reference_cell_present_input)
    print("neighbourhood_cell_rel_input") 
    print(neighbourhood_cell_rel_input)
    print("neighbourhood_cell_load_input") 
    print(neighbourhood_cell_load_input)
    print("neighbourhood_cell_present_input") 
    print(neighbourhood_cell_present_input)
    print("reference_cell_target") 
    print(reference_cell_target)
    print("reference_cell_present_target") 
    print(reference_cell_present_target)
    print("neighbourhood_cell_rel_target") 
    print(neighbourhood_cell_rel_target)
    print("neighbourhood_cell_present_target") 
    print(neighbourhood_cell_present_target)

In [92]:
BATCH_SIZE = 16 
INPUT_SEQ_LEN = 3
TARGET_SEQ_LEN = 3
LEARNING_RATE = 1e-3
SAVE_VARIABLE_EVERY = 100
DATASET_WORKERS = 4


input_seq_len= torch.Tensor([INPUT_SEQ_LEN]).int()
target_seq_len= torch.Tensor([TARGET_SEQ_LEN]).int()

# Define Data.
cell_load_dataset = load_topology_dataset.LoadCellDataset(initial_cell_counts=(2, 8), initial_load_counts=(1,5),
                                    input_seq_len=INPUT_SEQ_LEN, target_seq_len=TARGET_SEQ_LEN, network_mutate_prob=[0.1, 0.1
                                                                                                                    ])
dataloader = DataLoader(cell_load_dataset, batch_size=BATCH_SIZE,
                        shuffle=False, num_workers=DATASET_WORKERS)

# Define Model
dynamic_topology_model = DynamicTopologyModel(neighbourhood_hidden_size=32, neighbourhood_cell_count=6,
                     neighbourhood_output_size=16, lstm_hidden_size=32, lstm_layers=2, teacher_forcing_probability=0.5, device=device).to(device)

parameters = dynamic_topology_model.parameters

# Define optimizer
optimizer = optim.Adam(parameters
                       , lr=LEARNING_RATE)
criterion = nn.MSELoss()
last_parameters_np = list(np.copy(param.cpu().detach().numpy()) for param in parameters)

# Train
start = datetime.now()
for batch_idx, data in enumerate(dataloader):
    optimizer.zero_grad()
    reference_cell_input =  data[0].to(device)
    reference_cell_present_input = data[1].to(device)
    neighbourhood_cell_rel_input = data[2].to(device)
    neighbourhood_cell_load_input = data[3].to(device)
    neighbourhood_cell_present_input = data[4].to(device)
    reference_cell_target = data[5].to(device)
    reference_cell_present_target = data[6].to(device)
    neighbourhood_cell_rel_target = data[7].to(device)
    neighbourhood_cell_present_target = data[8].to(device)
    
    
    decoder_output_seq = dynamic_topology_model.forward(input_seq_len, target_seq_len, reference_cell_input, reference_cell_present_input,
                neighbourhood_cell_rel_input, neighbourhood_cell_load_input, reference_cell_target, reference_cell_present_target, neighbourhood_cell_rel_target)
    
    loss = criterion(decoder_output_seq, reference_cell_target)
    loss.backward()
    optimizer.step()
    if batch_idx == 0:
        parameters = dynamic_topology_model.parameters
        parameters_np = list(param.cpu().detach().numpy() for param in parameters)
        test_parameters_update(last_parameters_np, parameters_np)
        inputs =(input_seq_len, target_seq_len, reference_cell_input, reference_cell_present_input,
                neighbourhood_cell_rel_input, neighbourhood_cell_load_input, reference_cell_target, reference_cell_present_target, neighbourhood_cell_rel_target)
        writer.add_graph(dynamic_topology_model, inputs, True)
    if batch_idx % SAVE_VARIABLE_EVERY == 0:
        #print_example(reference_cell_input, reference_cell_present_input,
        #            neighbourhood_cell_rel_input, neighbourhood_cell_load_input, neighbourhood_cell_present_input,
        #              reference_cell_target, reference_cell_present_target, neighbourhood_cell_rel_target,
        #              neighbourhood_cell_present_target)    
        end = datetime.now()
        print("{} Examples/sec".format((SAVE_VARIABLE_EVERY* BATCH_SIZE) / ((end - start).total_seconds()) ))
        start = datetime.now()
        print("loss: {}".format(loss))
        #print("output: {}".format(decoder_output_seq))
        #print("target: {}".format(reference_cell_target))
        writer.add_scalar("loss", loss, batch_idx)
        save_parameters(writer, dynamic_topology_model, batch_idx)


0.00091055036
0.0009315908
0.000985384
0.0009913072
0.0009810068
-0.0009690374
0.000996327
0.0009959042
0.0009999871
0.0
0.0009999648
0.0009999685
0.0009999345
0.0
0.000999976
0.0009999722
0.0
0.0
0.0009999573
0.0009999834
0.0009999909
0.0009999946
0.0009999946
0.0009999238
0.0009999871
0.0009999834
0.0009999871
0.0009999312
0.0009999946
0.0009999946
0.001000002
0.0009999983
graph(%0 : Int(1)
      %1 : Int(1)
      %2 : Float(16, 3, 1)
      %3 : Float(16, 3, 1)
      %4 : Float(16, 3, 16, 2)
      %5 : Float(16, 3, 16, 1)
      %6 : Float(16, 3, 1)
      %7 : Float(16, 3, 1)
      %8 : Float(16, 3, 16, 2)
      %9 : Float(32, 2)
      %10 : Float(32)
      %11 : Float(32, 1)
      %12 : Float(32)
      %13 : Float(1, 32)
      %14 : Float(1)
      %15 : Float(1, 32)
      %16 : Float(1)
      %17 : Float(128, 3)
      %18 : Float(128, 32)
      %19 : Float(128)
      %20 : Float(128)
      %21 : Float(128, 32)
      %22 : Float(128, 32)
      %23 : Float(128)
      %24 : Float(128)
 

KeyboardInterrupt: 

In [93]:
def add_cell_if_not_exists(new_cell, cells):
    cell_already_exists = False

    for i, cell in enumerate(cells):
        if cell[0] == new_cell[0] and cell[1] == new_cell[1]:
            cell_already_exists = True
            break
    if not cell_already_exists:
        cells.append(new_cell)
    return cells

def remove_cell_if_exists(remove_cell, cells):
    for i, cell in enumerate(cells):
        if cell[0] == remove_cell[0] and cell[1] == remove_cell[1]:
            cells.pop(i)
            break
    return cells


In [94]:
PIXELS = 640
LEFT = 1
RIGHT = 3
INPUT_SCREEN_COUNT = 3
TARGET_SCREEN_COUNT = 3
SCREEN_COUNT = INPUT_SCREEN_COUNT + TARGET_SCREEN_COUNT

PIXELS_PER_BLOCK = int(PIXELS / IMAGE_SIZE)
SCALE_ARR = np.ones((PIXELS_PER_BLOCK, PIXELS_PER_BLOCK, 1))

load_screen = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 1))
load_screen[1,12, 0] = 255
load_screen[22,22, 0] = 255
loads = [(1, 12), (22, 22)]
block_img = np.kron(load_screen, SCALE_ARR)

screen = pygame.display.set_mode((PIXELS, PIXELS))
screens = []

screen_draw_idx = 0
last_pos = (0, 0)
color = (255, 128, 0)
radius = 10
block_screen = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 3))
model_screen = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8)
frame = 0
cells_seq = [list() for i in range(SCREEN_COUNT)]
#cells_load_seq = [list() for i in range(SCREEN_COUNT)]
screen_draw_idx = 1
BATCH_SIZE=1
show_labels = True
input_seq_len= torch.Tensor([INPUT_SCREEN_COUNT]).int()
target_seq_len= torch.Tensor([TARGET_SCREEN_COUNT]).int()

load_cells_seq_input_0, load_cells_seq_target_0 = load_topology_dataset.LoadCellDataset.generate_inputs_and_targets(INPUT_SCREEN_COUNT, TARGET_SCREEN_COUNT,
                                            (2, 5), (8,8), (0.05, 0.05))

cells_load_seq = load_cells_seq_input_0 + load_cells_seq_target_0
for seq_idx in range(len(cells_load_seq)):
    for cell_idx in range(len(cells_load_seq[seq_idx])):
        cells_seq[seq_idx].append((cells_load_seq[seq_idx][cell_idx][0], cells_load_seq[seq_idx][cell_idx][1]))

while True:
    
    ##############################
    # Use model to forecast future load
    ##############################
    load_cells_seq_input = cells_load_seq[:INPUT_SCREEN_COUNT]
    load_cells_seq_target = cells_load_seq[INPUT_SCREEN_COUNT:]
    
 
    model_target = [list() for i in range(TARGET_SCREEN_COUNT)]
    
    if show_labels:
        caption = "Label frame {}".format(screen_draw_idx)
    else:
        caption = "Model frame {}".format(screen_draw_idx)
    pygame.display.set_caption(caption)
    for seq_idx  in range(TARGET_SCREEN_COUNT):
        for reference_cell in cells_seq[INPUT_SCREEN_COUNT + seq_idx]:
            #print(reference_cell)
            ref_x, ref_y = reference_cell[0], reference_cell[1]

            # Get inputs from state.
            (reference_cell_input,
             reference_cell_present_input,
             neighbourhood_cell_rel_input,
             neighbourhood_cell_load_input,
             neighbourhood_cell_present_input,
             reference_cell_target,
             reference_cell_present_target,
             neighbourhood_cell_rel_target,
             neighbourhood_cell_present_target,) = load_topology_dataset.LoadCellDataset.build_inputs_and_targets(INPUT_SCREEN_COUNT,
                                                                                                 TARGET_SCREEN_COUNT,
                                                                                                 ref_x, ref_y,
                                                                                                 load_cells_seq_input,
                                                                                                 load_cells_seq_target)
            # Infer output from model.  Add batch dimension to each input since dataloader is not used here.
            decoder_output_seq = dynamic_topology_model.forward(input_seq_len, target_seq_len,
                                                                reference_cell_input.unsqueeze(0),
                                                                reference_cell_present_input.unsqueeze(0),
                                                                neighbourhood_cell_rel_input.unsqueeze(0),
                                                                neighbourhood_cell_load_input.unsqueeze(0),
                                                                reference_cell_target.unsqueeze(0),
                                                                reference_cell_present_target.unsqueeze(0),
                                                                neighbourhood_cell_rel_target.unsqueeze(0))

            model_target[seq_idx].append((ref_x, ref_y, decoder_output_seq.detach().numpy()[0][seq_idx][0]))    
    ################################
    # DISPLAY
    #################################
    screen_draw = pygame.Surface((PIXELS, PIXELS))
    
    #print(cells_seq)
    cells = cells_seq[screen_draw_idx]
    cell_loads = cells_load_seq[screen_draw_idx]

    if (screen_draw_idx < INPUT_SCREEN_COUNT) or (show_labels and screen_draw_idx >= INPUT_SCREEN_COUNT):
        for cell in cell_loads:
            color = (cell[2] * 1024) - 256
            if color > 255:
                color = 255
            if color < 0:
                color = 0
            pygame.draw.rect(
            screen_draw,
            (color, 50, 0),
            (
                cell[0] * PIXELS_PER_BLOCK,
                cell[1] * PIXELS_PER_BLOCK,
                PIXELS_PER_BLOCK,
                PIXELS_PER_BLOCK,),
            )
    else:
        model_loads = model_target[screen_draw_idx - INPUT_SCREEN_COUNT]
        for cell in model_loads:
            color = (cell[2] * 1024) - 256
            if color > 255:
                color = 255
            if color < 0:
                color = 0
            pygame.draw.rect(
            screen_draw,
            (color, 50, 0),
            (
                cell[0] * PIXELS_PER_BLOCK,
                cell[1] * PIXELS_PER_BLOCK,
                PIXELS_PER_BLOCK,
                PIXELS_PER_BLOCK,),
            )

    e = pygame.event.poll()


    if e.type == pygame.QUIT:
        raise StopIteration
    if e.type == pygame.MOUSEBUTTONDOWN:
        leftclick, middleclick, rightclick = pygame.mouse.get_pressed()
        x, y = e.pos
        x_cell= int(x / PIXELS_PER_BLOCK)
        y_cell = int(y / PIXELS_PER_BLOCK)
        if leftclick:
            cells = add_cell_if_not_exists((x_cell, y_cell), cells)
        elif rightclick:
            cells = remove_cell_if_exists((x_cell, y_cell), cells)            
        # Update cells
        cells_seq[screen_draw_idx] = cells
        
        # Update all cell loads
        for i, cells in enumerate(cells_seq):
            cells_load_seq[i] = load.calculate_cell_load(cells, loads)
    
    if e.type == pygame.KEYDOWN:
        if e.key == pygame.K_SPACE:
            if show_labels:
                show_labels = False
            else:
                show_labels = True
                print(model_target[-1])
                print(load_cells_seq_target[-1])
        if e.key ==  pygame.K_LEFT:
            if screen_draw_idx > 0:
                screen_draw_idx -= 1
        if e.key ==  pygame.K_RIGHT:
            if screen_draw_idx < SCREEN_COUNT - 1:
                screen_draw_idx += 1
    pygame.display.flip()
    target_screen = pygame.surfarray.array3d(screen_draw)
    display_img = block_img + target_screen
    new_surf = pygame.pixelcopy.make_surface(display_img.astype(np.uint8))
    screen.blit(new_surf, (0, 0))


[(23, 12, 0.37337315), (27, 11, 0.35980818), (5, 2, 0.24901387), (31, 28, 0.34813836), (26, 14, 0.22824855), (18, 25, 0.43930426)]
[(23, 12, 0.3505999796372732), (27, 11, 0.33936809998639483), (5, 2, 0.24915755975243975), (31, 28, 0.3408943471808886), (26, 14, 0.35254846058686096), (18, 25, 0.3674315528561427)]
[(23, 12, 0.3594452), (27, 11, 0.34660953), (5, 2, 0.25358844), (31, 28, 0.34813836), (26, 14, 0.35279292), (18, 25, 0.4355334)]
[(23, 12, 0.3505999796372732), (27, 11, 0.33936809998639483), (5, 2, 0.24915755975243975), (31, 28, 0.3408943471808886), (26, 14, 0.35254846058686096), (18, 25, 0.3674315528561427)]
[(23, 12, 0.3594452), (27, 11, 0.36119333), (5, 2, 0.24901387), (31, 28, 0.34776133), (26, 14, 0.22824855), (18, 25, 0.4355334)]
[(23, 12, 0.3505999796372732), (27, 11, 0.33936809998639483), (5, 2, 0.24915755975243975), (31, 28, 0.3408943471808886), (26, 14, 0.35254846058686096), (18, 25, 0.3674315528561427)]
[(23, 12, 0.37603754), (27, 11, 0.36119333), (5, 2, 0.25055832), 

KeyboardInterrupt: 

In [None]:
for seq_idx  in range(TARGET_SCREEN_COUNT):
    print(seq_idx)
    model_target[seq_idx].append((ref_x, ref_y, decoder_output_seq.detach().numpy()[0][seq_idx][0]))

In [None]:

cell_load_dataset = load_topology_dataset.LoadCellDataset(initial_cell_counts=(2, 5), initial_load_counts=(2,24),
                                    input_seq_len=INPUT_SEQ_LEN, target_seq_len=TARGET_SEQ_LEN, network_mutate_prob=[0.2, 0.2])



In [12]:
decoder_output.detach().numpy()[seq_idx][0]

NameError: name 'decoder_output' is not defined

In [38]:
len(cells_load_seq[0])

8

In [42]:
cells_seq = [(cell_load[0], cell_load[1]) for cell_load in cell_load for cell_load_seq in cells_load_seq]

SyntaxError: invalid syntax (<ipython-input-42-8e3eb257a1ca>, line 1)