In [2]:
from DSPN import settoset, hungarian_loss 
import scipy.optimize
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import math
import numpy as np 
from torch.utils.data import Dataset
import pandas as pd 
import matplotlib
from matplotlib import pyplot as plt
from fspool import FSPool
class FSEncoder(nn.Module):
    ##Set encoder from the DSPN/FSencoder papers
    def __init__(self, input_channels, output_channels, dim,set_dim, mask = False):
        super().__init__()
        self.set_dim = set_dim
        n_out = 30
        if mask == True: 
            self.mask = 1 
        else: 
            self.mask = 0
        self.conv = nn.Sequential(
            nn.Conv1d(input_channels + self.mask, dim, 1),
           ## nn.ReLU(),
           ## nn.Conv1d(dim, dim, 1),
            nn.ReLU(),
            nn.Conv1d(dim, output_channels, 1),
        )
        self.pool = FSPool(output_channels, n_out, relaxed=False)

    def forward(self, x, mask=None):
        ##mask = mask.unsqueeze(1)
        ##x = torch.cat([x, mask], dim=1)  # include mask as part of set
        x = x.reshape((self.set_dim[1],-1)).unsqueeze(2).permute(2,0,1)

        x = self.conv(x)
        x = x / x.size(2)  # normalise so that activations aren't too high with big sets
        x, _ = self.pool(x)
        return x

    
    
import csv
import json 

##Dataset class for predicting new objects from the currect state
class newobjectdataset(Dataset): 
    def __init__(self, data_root, label_dim = 1):
      
        self.data = []
        
        self.label_dim = label_dim
        results = []
        with open(data_root) as csvfile:
            reader = csv.reader(csvfile, delimiter = "|") # change contents to floats
            for row in reader: # each row is a list
                dataentry = json.loads(row[0])
                results.append(dataentry)
        self.data = results
        
    def __getitem__(self,idx): 
        return torch.Tensor(self.data[2*idx]), torch.Tensor(self.data[2*idx + 1])
    def __len__(self): 
        return int(len(self.data)/2)

In [7]:
rate = .0005

dataset = newobjectdataset("newstate.csv")

setdimout = (1,2)  #(num objects, object length)
setdimin = (1,14)
iterator = torch.utils.data.DataLoader(dataset)

##initialize encoder and DSPN encoder
encoder1 = FSEncoder(14,32,32, setdimin)
encoder2 = FSEncoder(2,32,32, setdimout)
num_actions = 6

#construct settoset network
setnet = settoset(encoder1,encoder2,32,setdimout,10, masks = True)
lambda1 = 6
optimizer = optim.Adam(setnet.parameters(),lr = rate)
running_loss = 0
epochs = 12
setnet.train()
loss_func = hungarian_loss
loss_func2 = nn.MSELoss()
for e in range(epochs): 
    print(running_loss)
    running_loss = 0
    accuracy = 0
    i = 0
    torch.set_grad_enabled(True)
    for set_,targetset_ in iterator: 
        i += 1
        
        if sum(targetset_[0]) == 0: 
            continue 
        optimizer.zero_grad()
        
        #All of this is because currently it is training to predict the next player position from the last one + action
        set_.requires_grad = True
        set_ = set_[0]
        action = set_[-1*num_actions:]
        set_ = set_[0:8]
        
        ##After reshaping the data from one long list back into sets, need to transpose it to get correct shape
        targetset_ = targetset_[0][0:2].reshape(1,2).transpose(1,0)/10
        
        set_ = set_.reshape(int(len(set_)/8),8)
        
        set_ = torch.cat((set_,action.unsqueeze(0)),1)/10
        
        
        
     

        out = setnet(set_)

        #compute the loss and propagate backwards
        loss =     loss_func(out,targetset_,setdimout) ##+ loss_func2(angle,setnet.encoder(set_))
        
     
        loss.backward(retain_graph = True)
        optimizer.step()
        running_loss += loss
        if i == 1000: 
            ##for parameter in setnet.parameters(): print(parameter)
            print(running_loss) 
            print(loss)
            running_loss = 0
            print(out)
            i = 0 

0
tensor(72.9373, grad_fn=<AddBackward0>)
tensor(0.0630, grad_fn=<AddBackward0>)
tensor([0.7413, 0.9475], grad_fn=<SubBackward0>)
tensor(29.2179, grad_fn=<AddBackward0>)
tensor(0.0139, grad_fn=<AddBackward0>)
tensor([0.6569, 0.6903], grad_fn=<SubBackward0>)
tensor(24.7251, grad_fn=<AddBackward0>)
tensor(0.0535, grad_fn=<AddBackward0>)
tensor([0.9472, 0.3216], grad_fn=<SubBackward0>)
tensor(24.3917, grad_fn=<AddBackward0>)
tensor(0.0148, grad_fn=<AddBackward0>)
tensor([0.8025, 0.4348], grad_fn=<SubBackward0>)
tensor(28.7355, grad_fn=<AddBackward0>)
tensor(0.0031, grad_fn=<AddBackward0>)
tensor([0.7492, 0.5218], grad_fn=<SubBackward0>)
tensor(29.6673, grad_fn=<AddBackward0>)
tensor(0.0783, grad_fn=<AddBackward0>)
tensor([0.7324, 0.2534], grad_fn=<SubBackward0>)
tensor(26.7821, grad_fn=<AddBackward0>)
tensor(0.1035, grad_fn=<AddBackward0>)
tensor([0.8359, 0.4084], grad_fn=<SubBackward0>)
tensor(19.6304, grad_fn=<AddBackward0>)
tensor(0.0097, grad_fn=<AddBackward0>)
tensor([0.8348, 0.1741]

tensor([0.3916, 0.7694], grad_fn=<SubBackward0>)
tensor(2.7894, grad_fn=<AddBackward0>)
tensor(0.0003, grad_fn=<AddBackward0>)
tensor([0.5175, 0.6994], grad_fn=<SubBackward0>)
tensor(3.1762, grad_fn=<AddBackward0>)
tensor(0.0011, grad_fn=<AddBackward0>)
tensor([0.1218, 0.4253], grad_fn=<SubBackward0>)
tensor(2.6432, grad_fn=<AddBackward0>)
tensor(0.0030, grad_fn=<AddBackward0>)
tensor([0.1841, 0.2523], grad_fn=<SubBackward0>)
tensor(3.7165, grad_fn=<AddBackward0>)
tensor(0.0005, grad_fn=<AddBackward0>)
tensor([0.4214, 0.7988], grad_fn=<SubBackward0>)
tensor(2.5122, grad_fn=<AddBackward0>)
tensor(3.1870e-05, grad_fn=<AddBackward0>)
tensor([0.7958, 0.7962], grad_fn=<SubBackward0>)
tensor(3.8484, grad_fn=<AddBackward0>)
tensor(0.0007, grad_fn=<AddBackward0>)
tensor([0.1239, 0.3134], grad_fn=<SubBackward0>)
tensor(2.9848, grad_fn=<AddBackward0>)
tensor(0.0004, grad_fn=<AddBackward0>)
tensor([0.0901, 0.8187], grad_fn=<SubBackward0>)
tensor(3.0973, grad_fn=<AddBackward0>)
tensor(0.0016, grad

KeyboardInterrupt: 

In [19]:
set_[0][-4] = 1 

In [24]:
setnet(set_)

tensor([-0.1681,  0.7126], grad_fn=<SubBackward0>)

In [25]:
set_

tensor([[ 0.8000,  0.3000,  0.0000, -0.1000,  0.1000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  1.0000,  0.0000,  0.0000,  0.0000]],
       grad_fn=<CopySlices>)

tensor(5.)